#include #include #include #include #include #include using namespace torch::jit; namespace { struct Foo : torch::CustomClassHolder { int x, y; Foo() : x(0), y(0) {} Foo(int x_, int y_) : x(x_), y(y_) {} int64_t info() { return this->x * this->y; } int64_t add(int64_t z) { return (x + y) * z; } void increment(int64_t z) { this->x += z; this->y += z; } int64_t combine(c10::intrusive_ptr b) { return this->info() + b->info(); } ~Foo() { // std::cout<<"Destroying object with values: "<x - this->y; } }; struct NoInit : torch::CustomClassHolder { int64_t x; }; struct PickleTester : torch::CustomClassHolder { PickleTester(std::vector vals) : vals(std::move(vals)) {} std::vector vals; }; at::Tensor take_an_instance(const c10::intrusive_ptr& instance) { return torch::zeros({instance->vals.back(), 4}); } struct ElementwiseInterpreter : torch::CustomClassHolder { using InstructionType = std::tuple< std::string /*op*/, std::vector /*inputs*/, std::string /*output*/>; ElementwiseInterpreter() {} // Load a list of instructions into the interpreter. As specified above, // instructions specify the operation (currently support "add" and "mul"), // the names of the input values, and the name of the single output value // from this instruction void setInstructions(std::vector instructions) { instructions_ = std::move(instructions); } // Add a constant. The interpreter maintains a set of constants across // calls. They are keyed by name, and constants can be referenced in // Instructions by the name specified void addConstant(const std::string& name, at::Tensor value) { constants_.insert_or_assign(name, std::move(value)); } // Set the string names for the positional inputs to the function this // interpreter represents. When invoked, the interpreter will assign // the positional inputs to the names in the corresponding position in // input_names. void setInputNames(std::vector input_names) { input_names_ = std::move(input_names); } // Specify the output name for the function this interpreter represents. This // should match the "output" field of one of the instructions in the // instruction list, typically the last instruction. void setOutputName(std::string output_name) { output_name_ = std::move(output_name); } // Invoke this interpreter. This takes a list of positional inputs and returns // a single output. Currently, inputs and outputs must all be Tensors. at::Tensor __call__(std::vector inputs) { // Environment to hold local variables std::unordered_map environment; // Load inputs according to the specified names if (inputs.size() != input_names_.size()) { std::stringstream err; err << "Expected " << input_names_.size() << " inputs, but got " << inputs.size() << "!"; throw std::runtime_error(err.str()); } for (size_t i = 0; i < inputs.size(); ++i) { environment[input_names_[i]] = inputs[i]; } for (InstructionType& instr : instructions_) { // Retrieve all input values for this op std::vector inputs; for (const auto& input_name : std::get<1>(instr)) { // Operator output values shadow constants. // Imagine all constants are defined in statements at the beginning // of a function (a la K&R C). Any definition of an output value must // necessarily come after constant definition in textual order. Thus, // We look up values in the environment first then the constant table // second to implement this shadowing behavior if (environment.find(input_name) != environment.end()) { inputs.push_back(environment.at(input_name)); } else if (constants_.find(input_name) != constants_.end()) { inputs.push_back(constants_.at(input_name)); } else { std::stringstream err; err << "Instruction referenced unknown value " << input_name << "!"; throw std::runtime_error(err.str()); } } // Run the specified operation at::Tensor result; const auto& op = std::get<0>(instr); if (op == "add") { if (inputs.size() != 2) { throw std::runtime_error("Unexpected number of inputs for add op!"); } result = inputs[0] + inputs[1]; } else if (op == "mul") { if (inputs.size() != 2) { throw std::runtime_error("Unexpected number of inputs for mul op!"); } result = inputs[0] * inputs[1]; } else { std::stringstream err; err << "Unknown operator " << op << "!"; throw std::runtime_error(err.str()); } // Write back result into environment const auto& output_name = std::get<2>(instr); environment[output_name] = std::move(result); } if (!output_name_) { throw std::runtime_error("Output name not specififed!"); } return environment.at(*output_name_); } // Ser/De infrastructure. See // https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html#defining-serialization-deserialization-methods-for-custom-c-classes // for more info. // This is the type we will use to marshall information on disk during // ser/de. It is a simple tuple composed of primitive types and simple // collection types like vector, optional, and dict. using SerializationType = std::tuple< std::vector /*input_names_*/, c10::optional /*output_name_*/, c10::Dict /*constants_*/, std::vector /*instructions_*/ >; // This function yields the SerializationType instance for `this`. SerializationType __getstate__() const { return SerializationType{ input_names_, output_name_, constants_, instructions_}; } // This function will create an instance of `ElementwiseInterpreter` given // an instance of `SerializationType`. static c10::intrusive_ptr __setstate__( SerializationType state) { auto instance = c10::make_intrusive(); std::tie( instance->input_names_, instance->output_name_, instance->constants_, instance->instructions_) = std::move(state); return instance; } // Class members std::vector input_names_; c10::optional output_name_; c10::Dict constants_; std::vector instructions_; }; TORCH_LIBRARY(_TorchScriptTesting, m) { m.class_("_Foo") .def(torch::init()) // .def(torch::init<>()) .def("info", &Foo::info) .def("increment", &Foo::increment) .def("add", &Foo::add) .def("combine", &Foo::combine); m.class_("_LambdaInit") .def(torch::init([](int64_t x, int64_t y, bool swap) { if (swap) { return c10::make_intrusive(y, x); } else { return c10::make_intrusive(x, y); } })) .def("diff", &LambdaInit::diff); m.class_("_NoInit").def( "get_x", [](const c10::intrusive_ptr& self) { return self->x; }); m.class_>("_StackString") .def(torch::init>()) .def("push", &MyStackClass::push) .def("pop", &MyStackClass::pop) .def("clone", &MyStackClass::clone) .def("merge", &MyStackClass::merge) .def_pickle( [](const c10::intrusive_ptr>& self) { return self->stack_; }, [](std::vector state) { // __setstate__ return c10::make_intrusive>( std::vector{"i", "was", "deserialized"}); }) .def("return_a_tuple", &MyStackClass::return_a_tuple) .def( "top", [](const c10::intrusive_ptr>& self) -> std::string { return self->stack_.back(); }) .def( "__str__", [](const c10::intrusive_ptr>& self) { std::stringstream ss; ss << "["; for (size_t i = 0; i < self->stack_.size(); ++i) { ss << self->stack_[i]; if (i != self->stack_.size() - 1) { ss << ", "; } } ss << "]"; return ss.str(); }); // clang-format off // The following will fail with a static assert telling you you have to // take an intrusive_ptr as the first argument. // .def("foo", [](int64_t a) -> int64_t{ return 3;}); // clang-format on m.class_("_PickleTester") .def(torch::init>()) .def_pickle( [](c10::intrusive_ptr self) { // __getstate__ return std::vector{1, 3, 3, 7}; }, [](std::vector state) { // __setstate__ return c10::make_intrusive(std::move(state)); }) .def( "top", [](const c10::intrusive_ptr& self) { return self->vals.back(); }) .def("pop", [](const c10::intrusive_ptr& self) { auto val = self->vals.back(); self->vals.pop_back(); return val; }); m.def( "take_an_instance(__torch__.torch.classes._TorchScriptTesting._PickleTester x) -> Tensor Y", take_an_instance); // test that schema inference is ok too m.def("take_an_instance_inferred", take_an_instance); m.class_("_ElementwiseInterpreter") .def(torch::init<>()) .def("set_instructions", &ElementwiseInterpreter::setInstructions) .def("add_constant", &ElementwiseInterpreter::addConstant) .def("set_input_names", &ElementwiseInterpreter::setInputNames) .def("set_output_name", &ElementwiseInterpreter::setOutputName) .def("__call__", &ElementwiseInterpreter::__call__) .def_pickle( /* __getstate__ */ [](const c10::intrusive_ptr& self) { return self->__getstate__(); }, /* __setstate__ */ [](ElementwiseInterpreter::SerializationType state) { return ElementwiseInterpreter::__setstate__(std::move(state)); }); } } // namespace