#include #include #include #include #include namespace torch { namespace jit { namespace { struct Foo : torch::jit::CustomClassHolder { int x, y; Foo() : x(0), y(0) {} Foo(int x_, int y_) : x(x_), y(y_) {} int64_t info() { return this->x * this->y; } int64_t add(int64_t z) { return (x + y) * z; } void increment(int64_t z) { this->x += z; this->y += z; } int64_t combine(c10::intrusive_ptr b) { return this->info() + b->info(); } ~Foo() { // std::cout<<"Destroying object with values: "< struct Stack : torch::jit::CustomClassHolder { std::vector stack_; Stack(std::vector init) : stack_(init.begin(), init.end()) {} void push(T x) { stack_.push_back(x); } T pop() { auto val = stack_.back(); stack_.pop_back(); return val; } c10::intrusive_ptr clone() const { return c10::make_intrusive(stack_); } void merge(const c10::intrusive_ptr& c) { for (auto& elem : c->stack_) { push(elem); } } std::tuple return_a_tuple() const { return std::make_tuple(1337.0f, 123); } }; struct PickleTester : torch::jit::CustomClassHolder { PickleTester(std::vector vals) : vals(std::move(vals)) {} std::vector vals; }; static auto test = torch::jit::class_("_TorchScriptTesting_Foo") .def(torch::jit::init()) // .def(torch::jit::init<>()) .def("info", &Foo::info) .def("increment", &Foo::increment) .def("add", &Foo::add) .def("combine", &Foo::combine); static auto testStack = torch::jit::class_>("_TorchScriptTesting_StackString") .def(torch::jit::init>()) .def("push", &Stack::push) .def("pop", &Stack::pop) .def("clone", &Stack::clone) .def("merge", &Stack::merge) .def_pickle( [](const c10::intrusive_ptr>& self) { return self->stack_; }, [](std::vector state) { // __setstate__ return c10::make_intrusive>( std::vector{"i", "was", "deserialized"}); }) .def("return_a_tuple", &Stack::return_a_tuple) .def( "top", [](const c10::intrusive_ptr>& self) -> std::string { return self->stack_.back(); }); // clang-format off // The following will fail with a static assert telling you you have to // take an intrusive_ptr as the first argument. // .def("foo", [](int64_t a) -> int64_t{ return 3;}); // clang-format on static auto testPickle = torch::jit::class_("_TorchScriptTesting_PickleTester") .def(torch::jit::init>()) .def_pickle( [](c10::intrusive_ptr self) { // __getstate__ return std::vector{1, 3, 3, 7}; }, [](std::vector state) { // __setstate__ return c10::make_intrusive(std::move(state)); }) .def( "top", [](const c10::intrusive_ptr& self) { return self->vals.back(); }) .def("pop", [](const c10::intrusive_ptr& self) { auto val = self->vals.back(); self->vals.pop_back(); return val; }); at::Tensor take_an_instance(const c10::intrusive_ptr& instance) { return torch::zeros({instance->vals.back(), 4}); } torch::RegisterOperators& register_take_instance() { static auto instance_registry = torch::RegisterOperators().op( torch::RegisterOperators::options() .schema( "_TorchScriptTesting::take_an_instance(__torch__.torch.classes._TorchScriptTesting_PickleTester x) -> Tensor Y") .catchAllKernel()); return instance_registry; } static auto& ensure_take_instance_registered = register_take_instance(); } // namespace } // namespace jit } // namespace torch