#include #include #include #include #include #include #include namespace torch { namespace jit { using namespace script; void testSaveExtraFilesHook() { // no secrets { std::stringstream ss; { Module m("__torch__.m"); ExtraFilesMap extra; extra["metadata.json"] = "abc"; m.save(ss, extra); } ss.seekg(0); { ExtraFilesMap extra; extra["metadata.json"] = ""; extra["secret.json"] = ""; jit::load(ss, c10::nullopt, extra); ASSERT_EQ(extra["metadata.json"], "abc"); ASSERT_EQ(extra["secret.json"], ""); } } // some secret { std::stringstream ss; { SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap { return {{"secret.json", "topsecret"}}; }); Module m("__torch__.m"); ExtraFilesMap extra; extra["metadata.json"] = "abc"; m.save(ss, extra); SetExportModuleExtraFilesHook(nullptr); } ss.seekg(0); { ExtraFilesMap extra; extra["metadata.json"] = ""; extra["secret.json"] = ""; jit::load(ss, c10::nullopt, extra); ASSERT_EQ(extra["metadata.json"], "abc"); ASSERT_EQ(extra["secret.json"], "topsecret"); } } } // TODO: Re-enable when add_type_tags is true void testTypeTags() { // auto list = c10::List>(); // list.push_back(c10::List({1, 2, 3})); // list.push_back(c10::List({4, 5, 6})); // // auto dict = c10::Dict(); // dict.insert("Hello", torch::ones({2, 2})); // // auto dict_list = c10::List>(); // for (size_t i = 0; i < 5; i++) { // auto another_dict = c10::Dict(); // another_dict.insert("Hello" + std::to_string(i), torch::ones({2, 2})); // dict_list.push_back(another_dict); // } // // auto tuple = std::tuple(2, "hi"); // // struct TestItem { // IValue value; // TypePtr expected_type; // }; // std::vector items = { // {list, ListType::create(ListType::create(IntType::get()))}, // {2, IntType::get()}, // {dict, DictType::create(StringType::get(), TensorType::get())}, // {dict_list, ListType::create(DictType::create(StringType::get(), TensorType::get()))}, // {tuple, TupleType::create({IntType::get(), StringType::get()})} // }; // // for (auto item : items) { // auto bytes = torch::pickle_save(item.value); // auto loaded = torch::pickle_load(bytes); // ASSERT_TRUE(loaded.type()->isSubtypeOf(item.expected_type)); // ASSERT_TRUE(item.expected_type->isSubtypeOf(loaded.type())); // } } } // namespace jit } // namespace torch