#include #include #include #include #include namespace torch { namespace jit { namespace { struct Foo : torch::CustomClassHolder { int x, y; Foo() : x(0), y(0) {} Foo(int x_, int y_) : x(x_), y(y_) {} int64_t info() { return this->x * this->y; } int64_t add(int64_t z) { return (x + y) * z; } void increment(int64_t z) { this->x += z; this->y += z; } int64_t combine(c10::intrusive_ptr b) { return this->info() + b->info(); } ~Foo() { // std::cout<<"Destroying object with values: "< struct MyStackClass : torch::CustomClassHolder { std::vector stack_; MyStackClass(std::vector init) : stack_(init.begin(), init.end()) {} void push(T x) { stack_.push_back(x); } T pop() { auto val = stack_.back(); stack_.pop_back(); return val; } c10::intrusive_ptr clone() const { return c10::make_intrusive(stack_); } void merge(const c10::intrusive_ptr& c) { for (auto& elem : c->stack_) { push(elem); } } std::tuple return_a_tuple() const { return std::make_tuple(1337.0f, 123); } }; struct PickleTester : torch::CustomClassHolder { PickleTester(std::vector vals) : vals(std::move(vals)) {} std::vector vals; }; static auto test = torch::class_("_TorchScriptTesting", "_Foo") .def(torch::init()) // .def(torch::init<>()) .def("info", &Foo::info) .def("increment", &Foo::increment) .def("add", &Foo::add) .def("combine", &Foo::combine); static auto testStack = torch::class_>( "_TorchScriptTesting", "_StackString") .def(torch::init>()) .def("push", &MyStackClass::push) .def("pop", &MyStackClass::pop) .def("clone", &MyStackClass::clone) .def("merge", &MyStackClass::merge) .def_pickle( [](const c10::intrusive_ptr>& self) { return self->stack_; }, [](std::vector state) { // __setstate__ return c10::make_intrusive>( std::vector{"i", "was", "deserialized"}); }) .def("return_a_tuple", &MyStackClass::return_a_tuple) .def( "top", [](const c10::intrusive_ptr>& self) -> std::string { return self->stack_.back(); }); // clang-format off // The following will fail with a static assert telling you you have to // take an intrusive_ptr as the first argument. // .def("foo", [](int64_t a) -> int64_t{ return 3;}); // clang-format on static auto testPickle = torch::class_("_TorchScriptTesting", "_PickleTester") .def(torch::init>()) .def_pickle( [](c10::intrusive_ptr self) { // __getstate__ return std::vector{1, 3, 3, 7}; }, [](std::vector state) { // __setstate__ return c10::make_intrusive(std::move(state)); }) .def( "top", [](const c10::intrusive_ptr& self) { return self->vals.back(); }) .def("pop", [](const c10::intrusive_ptr& self) { auto val = self->vals.back(); self->vals.pop_back(); return val; }); at::Tensor take_an_instance(const c10::intrusive_ptr& instance) { return torch::zeros({instance->vals.back(), 4}); } torch::RegisterOperators& register_take_instance() { static auto instance_registry = torch::RegisterOperators().op( torch::RegisterOperators::options() .schema( "_TorchScriptTesting::take_an_instance(__torch__.torch.classes._TorchScriptTesting._PickleTester x) -> Tensor Y") .catchAllKernel()); return instance_registry; } static auto& ensure_take_instance_registered = register_take_instance(); } // namespace void testTorchbindIValueAPI() { script::Module m("m"); // test make_custom_class API auto custom_class_obj = make_custom_class>( std::vector{"foo", "bar"}); m.define(R"( def forward(self, s : __torch__.torch.classes._TorchScriptTesting._StackString): return s.pop(), s )"); auto test_with_obj = [&m](IValue obj, std::string expected) { auto res = m.run_method("forward", obj); auto tup = res.toTuple(); AT_ASSERT(tup->elements().size() == 2); auto str = tup->elements()[0].toStringRef(); auto other_obj = tup->elements()[1].toCustomClass>(); AT_ASSERT(str == expected); auto ref_obj = obj.toCustomClass>(); AT_ASSERT(other_obj.get() == ref_obj.get()); }; test_with_obj(custom_class_obj, "bar"); // test IValue() API auto my_new_stack = c10::make_intrusive>( std::vector{"baz", "boo"}); auto new_stack_ivalue = c10::IValue(my_new_stack); test_with_obj(new_stack_ivalue, "boo"); } } // namespace jit } // namespace torch