#include #include #include #include #include namespace torch { namespace jit { namespace { c10::OperatorOptions aliasAnalysisInternalSpecialCase() { c10::OperatorOptions options; options.setAliasAnalysis(AliasAnalysisKind::INTERNAL_SPECIAL_CASE); return options; } } // namespace Value* insertConstant( Graph& g, const IValue& val, const c10::TypePtr& result_type, c10::optional loc, c10::optional scope) { auto value = tryInsertConstant(g, val, result_type, loc, scope); if (value) { return *value; } throw constant_not_supported_error( "Unsupported value kind: " + val.tagKind()); } // IValue -> Constant node c10::optional tryInsertConstant( Graph& g, const IValue& val, const c10::TypePtr& result_type, c10::optional loc, c10::optional scope) { Node* n = g.create(prim::Constant); if (val.isTensor()) { at::Tensor ref = val.toTensor(); if (!ref.defined()) { n->destroy(); return g.insertNode(g.createNone(TensorType::get()))->output(); } // TODO: fix all cases where we are not passing in a variable, // and then change this to an AT_ASSERT if (!ref.is_variable()) { ref = autograd::make_variable(ref, /*requires_grad=*/false); } else { AT_ASSERT(!ref.requires_grad()); } n->output()->inferTypeFrom( ref); // note: before t_ because of std::move(ref) n->t_(attr::value, std::move(ref)); } else if (val.isInt()) { n->i_(attr::value, val.toInt()); n->output()->setType(IntType::get()); } else if (val.isDouble()) { n->f_(attr::value, val.toDouble()); n->output()->setType(FloatType::get()); } else if (val.isBool()) { n->i_(attr::value, val.toBool()); n->output()->setType(BoolType::get()); } else if (val.isBoolList()) { auto bool_list = val.toBoolList(); n->is_( attr::value, std::vector(bool_list.begin(), bool_list.end())); n->output()->setType(ListType::ofBools()); } else if (val.isIntList()) { n->is_(attr::value, val.toIntListRef().vec()); n->output()->setType(ListType::ofInts()); } else if (val.isTensorList()) { n->ts_( attr::value, fmap(val.toTensorListRef(), [](const at::Tensor& t) { AT_ASSERT(t.is_variable() && !t.requires_grad()); return t; })); n->output()->setType(ListType::ofTensors()); } else if (val.isString()) { n->s_(attr::value, val.toString()->string()); n->output()->setType(StringType::get()); } else if (val.isDevice()) { std::stringstream ss; ss << val.toDevice(); n->s_(attr::value, ss.str()); n->output()->setType(DeviceObjType::get()); } else if (val.isNone()) { n->output()->setType(NoneType::get()); } else { n->destroy(); return c10::nullopt; } if (loc) n->setSourceRange(*loc); if (scope) n->setScope(*scope); if (result_type) { auto inferred_type = n->output()->type(); if (inferred_type->isSubtypeOf(NoneType::get()) && !inferred_type->isSubtypeOf(result_type)) { // None doesn't subtype Optional, but an Optional can be None, so handle // that here TORCH_CHECK( result_type->kind() == TypeKind::OptionalType, "Expected OptionalType or NoneType, got ", result_type->python_str()); n->output()->setType(result_type); } else if (!(inferred_type->isSubtypeOf(TensorType::get()) && result_type->isSubtypeOf(inferred_type))) { // Retain more type information in case of tensor constant n->output()->setType(result_type); } } return g.insertNode(n)->output(); } RegisterOperators reg({ Operator( FunctionSchema( prim::Constant, "", {}, {}, /*is_vararg=*/false, /*is_varret=*/true), [](const Node* node) -> Operation { TypePtr type = node->output()->type(); if (type->isSubtypeOf(TensorType::get())) { auto t = node->t(attr::value); return [t](Stack& stack) { push(stack, t); return 0; }; } else if (type->isSubtypeOf(BoolType::get())) { bool b = node->i(attr::value); return [b](Stack& stack) { push(stack, b); return 0; }; } else if ( type->isSubtypeOf(NumberType::get()) && node->kindOf(attr::value) == AttributeKind::i) { auto i = node->i(attr::value); return [i](Stack& stack) { push(stack, i); return 0; }; } else if ( type->isSubtypeOf(NumberType::get()) && node->kindOf(attr::value) == AttributeKind::f) { auto f = node->f(attr::value); return [f](Stack& stack) { push(stack, f); return 0; }; } else if (type->isSubtypeOf(ListType::ofInts())) { const auto& is = node->is(attr::value); return [is](Stack& stack) { push(stack, is); return 0; }; } else if (type->isSubtypeOf(ListType::ofBools())) { const auto bs = fmap(node->is(attr::value)); return [bs](Stack& stack) { push(stack, bs); return 0; }; } else if (type->isSubtypeOf(ListType::ofTensors())) { const auto& ts = node->ts(attr::value); return [ts](Stack& stack) { push(stack, ts); return 0; }; } else if (type == StringType::get()) { const auto& s = node->s(attr::value); return [s](Stack& stack) { push(stack, s); return 0; }; } else if (type == DeviceObjType::get()) { auto d = c10::Device(node->s(attr::value)); return [d](Stack& stack) { push(stack, d); return 0; }; } else if (node->mustBeNone()) { return [](Stack& stack) { push(stack, IValue()); return 0; }; } else { std::stringstream ss; ss << "constant literal not supported for: " << type->str(); throw std::runtime_error(ss.str()); } }, aliasAnalysisInternalSpecialCase()), }); c10::optional toIValue(const Value* v) { if (v->node()->kind() != prim::Constant) { return c10::nullopt; } // use implemenation of prim::Constant to compute the output IValue auto op = getOperation(v->node()); Stack stack; op(stack); return stack.back(); } } // namespace jit } // namespace torch