#pragma once

#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/op_registration/infer_schema.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/stack.h>
#include <c10/util/C++17.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
#include <c10/util/TypeTraits.h>
#include <torch/csrc/jit/custom_class.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/compilation_unit.h>

#include <iostream>
#include <sstream>
#include <string>
#include <utility>

namespace torch {
namespace jit {

template <class... Types>
detail::types<void, Types...> init() {
  return detail::types<void, Types...>{};
}

// To bind custom classes into Torchscript, use an API very similar to Pybind's.
// Currently exposes one class `torch::jit::class_<T>` and 2 methods.
// - Constructing `torch::jit::class_<Foo>` registers `Foo` in Python and
//   Torchscript, and puts it under `torch.classes.Foo` in Python.
// - torch::jit::class_<Foo>.def("method1", &Foo::method1) does some template
//   metaprogramming to introspect the function types and register the operator
//   for use in Torchscript.
// - torch::jit::class_<Foo>.def(torch::jit::init<int64_t, int64_t>()) registers
//   the Foo(int, int) constructor.
// see test/custom_operator/classes.cpp and
// test/custom_operator/test_custom_classes.py for example usages

template <class CurClass>
class class_ {
  static_assert(
      std::is_base_of<CustomClassHolder, CurClass>::value,
      "torch::jit::class_<T> requires T to inherit from CustomClassHolder");

  std::string className;
  std::string qualClassName;
  ClassTypePtr classTypePtr;

  const std::string parentModule = "classes";
  const std::string topModule = "__torch__.torch";

 public:
  class_(std::string className_) : className(std::move(className_)) {
    qualClassName = topModule + "." + parentModule + "."
        + className;

    // We currently represent custom classes as torchscript classes with a
    // capsule attribute
    classTypePtr =
        ClassType::create(c10::QualifiedName(qualClassName), classCU());
    classTypePtr->addAttribute("capsule", CapsuleType::get());

    c10::getCustomClassTypeMap().insert(
        {typeid(c10::intrusive_ptr<CurClass>).name(),
         c10::StrongTypePtr(classCU(), classTypePtr)});
    c10::getCustomClassTypeMap().insert(
        {typeid(c10::tagged_capsule<CurClass>).name(),
         c10::StrongTypePtr(classCU(), classTypePtr)});

    classCU()->register_type(classTypePtr);
  }

  template <typename... Types>
  class_& def(detail::types<void, Types...>) { // Used in combination with
                                               // torch::jit::init<...>()
    auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
      auto classObj = c10::make_intrusive<CurClass>(args...);
      auto genericPtr = c10::static_intrusive_pointer_cast<CustomClassHolder>(
          std::move(classObj));
      auto capsule = IValue(std::move(genericPtr));
      auto object = std::move(self.ivalue).toObject();
      object->setSlot(0, std::move(capsule));
    };

    defineMethod("__init__", std::move(func));
    return *this;
  }

  template <typename Func>
  class_& def(std::string name, Func f) {
    auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f));
    defineMethod(std::move(name), std::move(wrapped_f));
    return *this;
  }

  // Pickle
  template <typename GetStateFn, typename SetStateFn>
  class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) {
    static_assert(
        c10::guts::is_stateless_lambda<std::decay_t<GetStateFn>>::value &&
            c10::guts::is_stateless_lambda<std::decay_t<SetStateFn>>::value,
        "torch::jit::pickle_ currently only supports lambdas as "
        "__getstate__ and __setstate__ arguments.");
    def("__getstate__", std::forward<GetStateFn>(get_state));

    // __setstate__ needs to be registered with some custom handling:
    // We need to wrap the invocation of the user-provided function
    // such that we take the return value (i.e. c10::intrusive_ptr<CurClass>)
    // and assign it to the `capsule` attribute.
    using SetStateTraits =
        c10::guts::infer_function_traits_t<std::decay_t<SetStateFn>>;
    using SetStateArg = typename c10::guts::typelist::head_t<
        typename SetStateTraits::parameter_types>;
    auto setstate_wrapper = [set_state = std::move(set_state)](
                                c10::tagged_capsule<CurClass> self,
                                SetStateArg&& arg) {
      c10::intrusive_ptr<CurClass> classObj =
          at::guts::invoke(set_state, std::forward<SetStateArg>(arg));
      auto genericPtr = c10::static_intrusive_pointer_cast<CustomClassHolder>(
          classObj);
      auto capsule = IValue(genericPtr);
      auto object = self.ivalue.toObject();
      object->setSlot(0, capsule);
    };
    defineMethod(
        "__setstate__",
        detail::wrap_func<CurClass, decltype(setstate_wrapper)>(
            std::move(setstate_wrapper)));

    // type validation
    auto getstate_schema = classTypePtr->getMethod("__getstate__")->getSchema();
    auto format_getstate_schema = [&getstate_schema]() {
      std::stringstream ss;
      ss << getstate_schema;
      return ss.str();
    };
    TORCH_CHECK(
        getstate_schema.arguments().size() == 1,
        "__getstate__ should take exactly one argument: self. Got: ",
        format_getstate_schema());
    auto first_arg_type = getstate_schema.arguments().at(0).type();
    TORCH_CHECK(
        *first_arg_type == *classTypePtr,
        "self argument of __getstate__ must be the custom class type. Got ",
        first_arg_type->python_str());
    TORCH_CHECK(
        getstate_schema.returns().size() == 1,
        "__getstate__ should return exactly one value for serialization. Got: ",
        format_getstate_schema());
    auto ser_type = getstate_schema.returns().at(0).type();

    auto setstate_schema = classTypePtr->getMethod("__setstate__")->getSchema();
    auto arg_type = setstate_schema.arguments().at(1).type();
    TORCH_CHECK(
        (*arg_type == *ser_type),
        "__setstate__'s argument should be the same type as the "
        "return value of __getstate__. Got ",
        arg_type->python_str(),
        " but expected ",
        ser_type->python_str());

    return *this;
  }

 private:
  template <typename Func>
  void defineMethod(std::string name, Func func) {
    auto graph = std::make_shared<Graph>();
    auto qualFuncName = className + "::" + name;
    ensure_c10_registerer_defined();
    registeredOps().push_back(
        torch::RegisterOperators().op(qualFuncName, std::move(func)));
    auto func_symbol = c10::Symbol::fromQualString(qualFuncName);
    auto ops = torch::jit::getAllOperatorsFor(func_symbol);
    TORCH_CHECK(ops.size() == 1);
    auto& schema = ops[0]->schema();

    for (const auto& arg : schema.arguments()) {
      graph->addInput()->setType(arg.type());
    }

    auto opCall = graph->insertNode(
        graph->create(func_symbol, graph->inputs(), schema.returns().size()));
    Value* res;
    if (schema.returns().size() > 1) {
      const auto& returns = schema.returns();
      size_t op_invocation_idx = 0;
      for (const auto& ret : returns) {
        opCall->output(op_invocation_idx++)->setType(ret.type());
      }
      res = graph->insertNode(graph->createTuple(opCall->outputs()))->output();
    } else if (schema.returns().size() == 1) {
      const auto& returns = schema.returns();
      res = opCall->output()->setType(returns[0].type());
    } else {
      res = graph->insertConstant(IValue())->setType(NoneType::get());
    }
    graph->registerOutput(res);

    auto method = classCU()->create_function(qualClassName + "." + name, graph);
    classTypePtr->addMethod(method);
  }
};

} // namespace jit
} // namespace torch