mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Creating an untyped generic list is deprecated, we always want type information to be present. This fixes test cases and removes one that used lists with ambiguous types. Pull Request resolved: https://github.com/pytorch/pytorch/pull/23192 ghstack-source-id: 86972891 Differential Revision: D16431482 fbshipit-source-id: 4ca5cd142118a3f0a4dcb8cd77383127c54abb29
113 lines
3.3 KiB
C++
113 lines
3.3 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <torch/jit.h>
|
|
#include <torch/types.h>
|
|
|
|
#include <string>
|
|
|
|
// Compiles several TorchScript functions from a single source string and
// checks that each one can be invoked by name through run_method.
TEST(TorchScriptTest, CanCompileMultipleFunctions) {
  auto module = torch::jit::compile(R"JIT(
      def test_mul(a, b):
        return a * b
      def test_relu(a, b):
        return torch.relu(a + b)
      def test_while(a, i):
        while bool(i < 10):
          a += a
          i += 1
        return a
      def test_len(a : List[int]):
        return len(a)
    )JIT");
  auto a = torch::ones(1);
  auto b = torch::ones(1);

  // 1 * 1
  ASSERT_EQ(1, module->run_method("test_mul", a, b).toTensor().item<int64_t>());

  // relu(1 + 1)
  ASSERT_EQ(2, module->run_method("test_relu", a, b).toTensor().item<int64_t>());

  // `a` is doubled once per loop iteration while `i` counts from 1 up to 10:
  // nine doublings, so the result is 2^9 == 0x200. `+=` on a tensor may
  // update the passed-in tensors in place, so clones are passed to keep `a`
  // and `b` unchanged regardless of where this call sits in the test.
  ASSERT_EQ(
      0x200,
      module->run_method("test_while", a.clone(), b.clone())
          .toTensor()
          .item<int64_t>());

  // A typed int list (not a deprecated untyped generic list) must match the
  // List[int] annotation.
  at::IValue list = c10::List<int64_t>({3, 4});
  ASSERT_EQ(2, module->run_method("test_len", list).toInt());
}
|
|
|
|
|
|
// Checks run_method argument matching for a List[List[Tensor]] parameter:
// statically typed nested lists and typed GenericLists (with an empty inner
// list) are accepted, and a list nested one level too deep is rejected with a
// descriptive error.
TEST(TorchScriptTest, TestNestedIValueModuleArgMatching) {
  auto module = torch::jit::compile(R"JIT(
      def nested_loop(a: List[List[Tensor]], b: int):
        return torch.tensor(1.0) + b
    )JIT");

  auto b = 3;

  // Statically typed List[List[Tensor]] holding one 4x4 tensor.
  torch::List<torch::Tensor> list({torch::rand({4, 4})});
  torch::List<torch::List<torch::Tensor>> list_of_lists;
  list_of_lists.push_back(list);
  module->run_method("nested_loop", list_of_lists, b);

  // The same nesting built from type-erased GenericLists: an outer list whose
  // element type is List[Tensor], containing one empty List[Tensor].
  auto generic_list = c10::impl::GenericList(at::TensorType::get());
  auto empty_generic_list =
      c10::impl::GenericList(at::ListType::create(at::TensorType::get()));
  empty_generic_list.push_back(generic_list);
  module->run_method("nested_loop", empty_generic_list, b);

  // One extra level of nesting must be rejected during argument matching.
  auto too_many_lists = c10::impl::GenericList(
      at::ListType::create(at::ListType::create(at::TensorType::get())));
  too_many_lists.push_back(empty_generic_list);
  try {
    module->run_method("nested_loop", too_many_lists, b);
    // FAIL() records the failure and returns. AT_ASSERT(false) would instead
    // throw a c10::Error that the handler below catches, turning a
    // "did not throw" failure into a confusing message mismatch.
    FAIL() << "Expected run_method to reject a list nested one level too deep";
  } catch (const c10::Error& error) {
    const std::string message = error.what_without_backtrace();
    // The error message must start with this text.
    EXPECT_EQ(
        0u,
        message.find(
            "nested_loop() Expected a value of type 'List[List[Tensor]]'"
            " for argument 'a' but instead found type "
            "'List[List[List[t]]]'"))
        << message;
  }
}
|
|
|
|
|
|
// Passes a c10::Dict<std::string, Tensor> to a TorchScript function annotated
// with Dict[str, Tensor] and checks that the value stored under the key is
// returned.
TEST(TorchScriptTest, TestDictArgMatching) {
  auto module = torch::jit::compile(R"JIT(
      def dict_op(a: Dict[str, Tensor], b: str):
        return a[b]
    )JIT");

  c10::Dict<std::string, at::Tensor> name_to_tensor;
  name_to_tensor.insert("hello", torch::ones({2}));

  const auto looked_up =
      module->run_method("dict_op", name_to_tensor, std::string("hello"))
          .toTensor();
  const auto first_element = looked_up[0].item<int64_t>();
  ASSERT_EQ(1, first_element);
}
|
|
|
|
// Checks that a tuple wrapping a typed int list is accepted for a parameter
// annotated Tuple[List[int]]. Only argument matching is exercised; the return
// value is not inspected.
TEST(TorchScriptTest, TestTupleArgMatching) {
  auto module = torch::jit::compile(R"JIT(
      def tuple_op(a: Tuple[List[int]]):
        return a
    )JIT");

  c10::List<int64_t> ints({1});
  auto wrapped = c10::ivalue::Tuple::create({ints});

  // Passing here means the call did not throw during argument matching.
  module->run_method("tuple_op", wrapped);
}
|
|
|
|
// Exercises argument matching for an Optional[Tuple[int, str]] parameter with
// both a concrete tuple and a None value.
TEST(TorchScriptTest, TestOptionalArgMatching) {
  auto module = torch::jit::compile(R"JIT(
      def optional_tuple_op(a: Optional[Tuple[int, str]]):
        if a is None:
          return 0
        else:
          return a[0]
    )JIT");

  // Present: the function returns the tuple's first element.
  auto present = c10::ivalue::Tuple::create({2, std::string("hi")});
  const auto first = module->run_method("optional_tuple_op", present).toInt();
  ASSERT_EQ(2, first);

  // Absent: a default-constructed IValue takes the `a is None` branch.
  const torch::jit::IValue absent{};
  const auto fallback = module->run_method("optional_tuple_op", absent).toInt();
  ASSERT_EQ(0, fallback);
}
|