mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45249 Reland of https://github.com/pytorch/pytorch/pull/45055 and https://github.com/pytorch/pytorch/pull/45020 See https://github.com/pytorch/pytorch/pull/45018 for context. Test Plan: Imported from OSS Reviewed By: jamesr66a Differential Revision: D23892645 Pulled By: suo fbshipit-source-id: e7fe58d5e1a5a0c44f4e2aec9694145afabde0fd
49 lines
1.4 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <test/cpp/jit/test_custom_class_registrations.h>
|
|
#include <torch/custom_class.h>
|
|
#include <torch/script.h>
|
|
|
|
#include <iostream>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
TEST(CustomClassTest, TorchbindIValueAPI) {
|
|
script::Module m("m");
|
|
|
|
// test make_custom_class API
|
|
auto custom_class_obj = make_custom_class<MyStackClass<std::string>>(
|
|
std::vector<std::string>{"foo", "bar"});
|
|
m.define(R"(
|
|
def forward(self, s : __torch__.torch.classes._TorchScriptTesting._StackString):
|
|
return s.pop(), s
|
|
)");
|
|
|
|
auto test_with_obj = [&m](IValue obj, std::string expected) {
|
|
auto res = m.run_method("forward", obj);
|
|
auto tup = res.toTuple();
|
|
AT_ASSERT(tup->elements().size() == 2);
|
|
auto str = tup->elements()[0].toStringRef();
|
|
auto other_obj =
|
|
tup->elements()[1].toCustomClass<MyStackClass<std::string>>();
|
|
AT_ASSERT(str == expected);
|
|
auto ref_obj = obj.toCustomClass<MyStackClass<std::string>>();
|
|
AT_ASSERT(other_obj.get() == ref_obj.get());
|
|
};
|
|
|
|
test_with_obj(custom_class_obj, "bar");
|
|
|
|
// test IValue() API
|
|
auto my_new_stack = c10::make_intrusive<MyStackClass<std::string>>(
|
|
std::vector<std::string>{"baz", "boo"});
|
|
auto new_stack_ivalue = c10::IValue(my_new_stack);
|
|
|
|
test_with_obj(new_stack_ivalue, "boo");
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|