#include #include #include #include #include #include #include namespace torch { namespace jit { TEST(CustomClassTest, TorchbindIValueAPI) { script::Module m("m"); // test make_custom_class API auto custom_class_obj = make_custom_class>( std::vector{"foo", "bar"}); m.define(R"( def forward(self, s : __torch__.torch.classes._TorchScriptTesting._StackString): return s.pop(), s )"); auto test_with_obj = [&m](IValue obj, std::string expected) { auto res = m.run_method("forward", obj); auto tup = res.toTuple(); AT_ASSERT(tup->elements().size() == 2); auto str = tup->elements()[0].toStringRef(); auto other_obj = tup->elements()[1].toCustomClass>(); AT_ASSERT(str == expected); auto ref_obj = obj.toCustomClass>(); AT_ASSERT(other_obj.get() == ref_obj.get()); }; test_with_obj(custom_class_obj, "bar"); // test IValue() API auto my_new_stack = c10::make_intrusive>( std::vector{"baz", "boo"}); auto new_stack_ivalue = c10::IValue(my_new_stack); test_with_obj(new_stack_ivalue, "boo"); } } // namespace jit } // namespace torch