#include #include #include #include #include #include #include #include #include #include #include #include #include "caffe2/serialize/istream_adapter.h" namespace torch { namespace jit { namespace { Module roundtripThroughMobile(const Module& m) { ExtraFilesMap files; std::vector constants; jitModuleToPythonCodeAndConstants(m, &files, &constants); CompilationOptions options; mobile::Module mobilem = jitModuleToMobile(m, options); return jitModuleFromSourceAndConstants( mobilem._ivalue(), files, constants, 8); } template inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) { try { std::forward(functor)(); } catch (const Error& e) { EXPECT_STREQ(e.what_without_backtrace(), expectedMessage); return; } ADD_FAILURE() << "Expected to throw exception with message \"" << expectedMessage << "\" but didn't throw"; } } // namespace TEST(SerializationTest, ExtraFilesHookPreference) { // Tests that an extra file written explicitly has precedence over // extra files written by a hook // TODO: test for the warning, too const auto script = R"JIT( def forward(self): x = torch.rand(5, 5) x = x.mm(x) return x )JIT"; auto module = std::make_shared("Module", std::make_shared()); module->define(script); std::ostringstream oss; std::unordered_map extra_files; extra_files["metadata.json"] = "abc"; SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap { return {{"metadata.json", "def"}}; }); module->save(oss, extra_files); SetExportModuleExtraFilesHook(nullptr); std::istringstream iss(oss.str()); caffe2::serialize::IStreamAdapter adapter{&iss}; std::unordered_map loaded_extra_files; loaded_extra_files["metadata.json"] = ""; auto loaded_module = torch::jit::load(iss, torch::kCPU, loaded_extra_files); ASSERT_EQ(loaded_extra_files["metadata.json"], "abc"); } TEST(SerializationTest, ExtraFileHooksNoSecret) { // no secrets std::stringstream ss; { Module m("__torch__.m"); ExtraFilesMap extra; extra["metadata.json"] = "abc"; m.save(ss, extra); } ss.seekg(0); { ExtraFilesMap extra; extra["metadata.json"] = ""; extra["secret.json"] = ""; jit::load(ss, c10::nullopt, extra); ASSERT_EQ(extra["metadata.json"], "abc"); ASSERT_EQ(extra["secret.json"], ""); } } TEST(SerializationTest, ExtraFileHooksWithSecret) { std::stringstream ss; { SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap { return {{"secret.json", "topsecret"}}; }); Module m("__torch__.m"); ExtraFilesMap extra; extra["metadata.json"] = "abc"; m.save(ss, extra); SetExportModuleExtraFilesHook(nullptr); } ss.seekg(0); { ExtraFilesMap extra; extra["metadata.json"] = ""; extra["secret.json"] = ""; jit::load(ss, c10::nullopt, extra); ASSERT_EQ(extra["metadata.json"], "abc"); ASSERT_EQ(extra["secret.json"], "topsecret"); } } TEST(SerializationTest, TypeTags) { auto list = c10::List>(); list.push_back(c10::List({1, 2, 3})); list.push_back(c10::List({4, 5, 6})); auto dict = c10::Dict(); dict.insert("Hello", torch::ones({2, 2})); auto dict_list = c10::List>(); for (size_t i = 0; i < 5; i++) { auto another_dict = c10::Dict(); another_dict.insert("Hello" + std::to_string(i), torch::ones({2, 2})); dict_list.push_back(another_dict); } auto tuple = std::tuple(2, "hi"); struct TestItem { IValue value; TypePtr expected_type; }; std::vector items = { {list, ListType::create(ListType::create(IntType::get()))}, {2, IntType::get()}, {dict, DictType::create(StringType::get(), TensorType::get())}, {dict_list, ListType::create( DictType::create(StringType::get(), TensorType::get()))}, {tuple, TupleType::create({IntType::get(), StringType::get()})}}; // NOLINTNEXTLINE(performance-for-range-copy) for (auto item : items) { auto bytes = torch::pickle_save(item.value); auto loaded = torch::pickle_load(bytes); ASSERT_TRUE(loaded.type()->isSubtypeOf(*item.expected_type)); ASSERT_TRUE(item.expected_type->isSubtypeOf(*loaded.type())); } } TEST(SerializationTest, SaveStateDict) { // Requires the state_dict that should have been written in tests_setup.py // Refer: SaveStateDict in test/cpp/jit/tests_setup.py std::ifstream file("state_dict.pt", std::ios::binary); std::vector data( (std::istreambuf_iterator(file)), std::istreambuf_iterator()); auto dict = torch::pickle_load(data).toGenericDict(); for (auto& el : dict) { auto key = el.key().toStringRef(); auto ten = el.value().toTensor(); if (key == "weight") { ASSERT_TRUE(ten.eq(2.0).all().item().toBool()); } else if (key == "bias") { ASSERT_TRUE(ten.eq(3.0).all().item().toBool()); } else { ASSERT_TRUE(false); } } } TEST(SerializationTest, TestJitStream_CUDA) { torch::jit::Module model; std::vector inputs; // Deserialize the ScriptModule from a file using torch::jit::load(). // Load the scripted model. This should have been generated by tests_setup.py // Refer: TorchSaveJitStream_CUDA in test/cpp/jit/tests_setup.py model = torch::jit::load("saved_stream_model.pt"); auto output = model.forward(inputs); const auto& list_of_elements = output.toTupleRef().elements(); auto is_stream_s = list_of_elements[0].toBool(); // a,b: These are the two input tensors // c: This is output tensor generated by the operation torch.cat(a,b) auto a = list_of_elements[1].toTensor(); auto b = list_of_elements[2].toTensor(); auto c = list_of_elements[3].toTensor(); // op: this is used to verify if the cat operation produced the same results // as that on the GPU with torch.cat auto op = at::cat({a, b}, 0); // Check if the stream is set ASSERT_TRUE(is_stream_s); // Check if the sizes of the outputs (op and c) is same on the GPU and CPU ASSERT_EQ(op.sizes(), c.sizes()); // Check if both the output tensors are equal ASSERT_TRUE(op.equal(c)); } TEST(TestSourceRoundTrip, UpsampleNearest2d) { Module m("m"); m.define(R"( def forward(self, input: Tensor, scale:float): return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale)) )"); std::vector inputs; inputs.emplace_back(torch::rand({1, 3, 128, 128})); inputs.emplace_back(at::Scalar(2.0)); auto ref = m.forward(inputs); Module m2 = roundtripThroughMobile(m); auto res = m2.forward(inputs); auto resd = res.toTensor(); auto refd = ref.toTensor(); ASSERT_TRUE(resd.equal(refd)); } TEST(TestSourceRoundTrip, CheckAttrAccess) { Module m("m"); m.register_attribute("mobile_optimized", BoolType::get(), true); Module m2 = roundtripThroughMobile(m); bool mobile_optimized = m2.attr("mobile_optimized", false).toBool(); AT_ASSERT(mobile_optimized); } TEST(TestSourceRoundTrip, MethodInvocation) { // NOLINT (use =delete in gtest) const std::vector test_programs{ // test invoking a method with default parameter R"( def test_func(self, x, b : int = 4): return self.foo + x + b )", // inner method call with default parameter (gets inlined) R"( def add_with_default_arg(self, x, b : int = 4): return self.foo + x + b def test_func(self, x): return self.add_with_default_arg(x) # invoke method w/ default arg )", // simple method call R"( def test_func(self, x): b = 4 return self.foo + x + b )", }; for (const auto& test_program : test_programs) { Module m("m"); m.register_parameter("foo", torch::ones({}), false); m.define(test_program); const int fortyTwo = 42; // (keep linter happy) auto minput = fortyTwo * torch::ones({}); auto ref = m.run_method("test_func", minput); Module m2 = roundtripThroughMobile(m); const auto& test_func = m2.get_method("test_func"); IValue res; for (int i = 0; i < 3; ++i) { res = test_func({minput}); } auto resd = res.toTensor().item(); auto refd = ref.toTensor().item(); AT_ASSERT(resd == refd); } } TEST(SerializationTest, ParentDirNotExist) { expectThrowsEq( []() { auto t = torch::nn::Linear(5, 5); torch::save(t, "./doesnotexist/file.pt"); }, "Parent directory ./doesnotexist does not exist."); } TEST(SerializationTest, CalculateNecessaryArgsTest) { auto schema = torch::schema( "sync_stream(int stream_id = -1) -> ()", c10::AliasAnalysisKind::CONSERVATIVE); auto graph = std::make_shared(); auto one_val = graph->insertConstant(-1); auto necessary = CalculateNecessaryArgs(schema.arguments(), {one_val}, true); EXPECT_EQ(0, necessary.first); EXPECT_EQ(0, necessary.second); } TEST(TestSaveLoad, LoadWithoutDebugInfo) { // NOLINT (use =delete in gtest) Module m("m"); m.register_parameter("foo", torch::ones({}), false); m.define( R"( def test_func(self, x): b = 4 return self.foo + x + b )"); m.define( R"( def exception(self): assert False, "message" )"); std::stringstream ss; m.save(ss); ss.seekg(0); caffe2::serialize::PyTorchStreamReader reader(&ss); reader.setShouldLoadDebugSymbol(true); EXPECT_TRUE(reader.hasRecord("code/__torch__.py.debug_pkl")); reader.setShouldLoadDebugSymbol(false); EXPECT_FALSE(reader.hasRecord("code/__torch__.py.debug_pkl")); ss.seekg(0); Module m2 = torch::jit::load(ss); std::string error_msg = R"( def exception(self): assert False, "message" ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE)"; ASSERT_THROWS_WITH_MESSAGE(m2.run_method("exception"), error_msg); ss.seekg(0); // NO DEBUG trace so error message points to torchscript generated // source instead of original python source. std::string error2 = R"( def exception(self: __torch__.m) -> NoneType: _0 = uninitialized(NoneType) ops.prim.RaiseException("AssertionError: message") ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE return _0 )"; Module m3 = torch::jit::load(ss, c10::nullopt, false); ASSERT_THROWS_WITH_MESSAGE(m3.run_method("exception"), error2); } } // namespace jit } // namespace torch