diff --git a/test/cpp/api/serialize.cpp b/test/cpp/api/serialize.cpp index f1dc43c0b46..ecad2348674 100644 --- a/test/cpp/api/serialize.cpp +++ b/test/cpp/api/serialize.cpp @@ -355,6 +355,7 @@ TEST(SerializeTest, ErrorOnMissingKey) { // We want the errors to contain hierarchy information, too. ASSERT_THROWS_WITH( torch::load(model2, stream), "No such serialized tensor 'a.b.x'"); + stream.seekg(0, stream.beg); ASSERT_THROWS_WITH( torch::load(model3, stream), "No such serialized submodule: 'a.x'"); } diff --git a/test/cpp/jit/test_backend.cpp b/test/cpp/jit/test_backend.cpp index a6961a2e403..978daa08d94 100644 --- a/test/cpp/jit/test_backend.cpp +++ b/test/cpp/jit/test_backend.cpp @@ -276,6 +276,7 @@ TEST(BackendTest, TestConsistencyOfCompositeWithSetStates) { c._save_for_mobile(ss); auto mc = _load_for_mobile(ss); auto res_mobile = mc.forward(inputs); + ss.seekg(0, ss.beg); // check if the methods names are always the same // by reloading the script module and saving it back as mobile diff --git a/test/cpp/jit/test_flatbuffer.cpp b/test/cpp/jit/test_flatbuffer.cpp index 522ded9ab4d..e9803dbe950 100644 --- a/test/cpp/jit/test_flatbuffer.cpp +++ b/test/cpp/jit/test_flatbuffer.cpp @@ -1177,6 +1177,7 @@ Module jitModuleFromBuffer(void* data) { mobilem._ivalue(), files, constants, 8); } +#if defined(ENABLE_FLATBUFFER) TEST(TestSourceFlatbuffer, UpsampleNearest2d) { Module m("m"); m.define(R"( @@ -1189,20 +1190,21 @@ TEST(TestSourceFlatbuffer, UpsampleNearest2d) { inputs.emplace_back(at::Scalar(2.0)); auto ref = m.forward(inputs); - auto data = save_jit_module_to_bytes(m); - Module m2 = jitModuleFromBuffer(data.data()); + std::stringstream ss; + m._save_for_mobile(ss, {}, false, /*use_fatbuffer=*/true); + auto mm = _load_for_mobile(ss); + auto m2 = load(ss); + auto res = m2.forward(inputs); + auto resm = mm.forward(inputs); auto resd = res.toTensor(); auto refd = ref.toTensor(); + auto resmd = resm.toTensor(); ASSERT_TRUE(resd.equal(refd)); - - mobile::Module m3 = parse_mobile_module(data.data(), data.size()); - res = m3.forward(inputs); - resd = res.toTensor(); - refd = ref.toTensor(); - ASSERT_TRUE(resd.equal(refd)); + ASSERT_TRUE(resmd.equal(refd)); } +#endif TEST(TestSourceFlatbuffer, CheckAttrAccess) { Module m("m"); diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py index 45e65e5e818..bbe7e0a7016 100644 --- a/test/jit/test_save_load.py +++ b/test/jit/test_save_load.py @@ -1,20 +1,22 @@ # Owner(s): ["oncall: jit"] -from typing import NamedTuple, Optional import io import os import pathlib import sys +import unittest +from typing import NamedTuple, Optional +import torch from torch import Tensor from torch.testing._internal.common_utils import TemporaryFileName -import torch # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from torch.testing._internal.jit_utils import (JitTestCase, - clear_class_registry) +from torch.testing._internal.jit_utils import JitTestCase, clear_class_registry + +ENABLE_FLATBUFFER = os.environ.get("ENABLE_FLATBUFFER", "0") == "1" if __name__ == "__main__": raise RuntimeError( @@ -23,12 +25,14 @@ if __name__ == "__main__": "instead." ) + class TestSaveLoad(JitTestCase): def test_different_modules(self): """ Exercise the situation where we have the same qualified name in two different CompilationUnits on save/load. """ + class Foo(torch.nn.Module): def __init__(self): super(Foo, self).__init__() @@ -64,7 +68,8 @@ class TestSaveLoad(JitTestCase): clear_class_registry() self.assertEqual( - first_script_module._c.qualified_name, second_script_module._c.qualified_name + first_script_module._c.qualified_name, + second_script_module._c.qualified_name, ) class ContainsBoth(torch.nn.Module): @@ -89,6 +94,7 @@ class TestSaveLoad(JitTestCase): Exercise the situation where we have the same qualified name in two different CompilationUnits on save/load. """ + def lol(x): return x @@ -118,7 +124,8 @@ class TestSaveLoad(JitTestCase): clear_class_registry() self.assertEqual( - first_script_module._c.qualified_name, second_script_module._c.qualified_name + first_script_module._c.qualified_name, + second_script_module._c.qualified_name, ) class ContainsBoth(torch.nn.Module): @@ -143,6 +150,7 @@ class TestSaveLoad(JitTestCase): Exercise the situation where we have the same qualified name in two different CompilationUnits on save/load. """ + @torch.jit.interface class MyInterface(object): def bar(self, x: Tensor) -> Tensor: @@ -204,7 +212,8 @@ class TestSaveLoad(JitTestCase): clear_class_registry() self.assertEqual( - first_script_module._c.qualified_name, second_script_module._c.qualified_name + first_script_module._c.qualified_name, + second_script_module._c.qualified_name, ) class ContainsBoth(torch.nn.Module): @@ -261,7 +270,6 @@ class TestSaveLoad(JitTestCase): return x, MyCoolNamedTuple(a=5) - first_script_module = torch.jit.script(Foo()) first_saved_module = io.BytesIO() torch.jit.save(first_script_module, first_saved_module) @@ -310,7 +318,8 @@ class TestSaveLoad(JitTestCase): clear_class_registry() self.assertEqual( - first_script_module._c.qualified_name, second_script_module._c.qualified_name + first_script_module._c.qualified_name, + second_script_module._c.qualified_name, ) class ContainsBoth(torch.nn.Module): @@ -340,44 +349,44 @@ class TestSaveLoad(JitTestCase): value = b"bar\x00\xffbaz" expected_extra_files = {} - expected_extra_files['foo'] = value + expected_extra_files["foo"] = value # verify that str to bytes conversion also works - expected_extra_files['foo2'] = "bar" + expected_extra_files["foo2"] = "bar" m = MyMod() # Save to file. with TemporaryFileName() as fname: m.save(fname, _extra_files=expected_extra_files) # values don't matter - extra_files = {'foo': '', 'foo2': None} + extra_files = {"foo": "", "foo2": None} torch.jit.load(fname, _extra_files=extra_files) - self.assertEqual(value, extra_files['foo']) + self.assertEqual(value, extra_files["foo"]) # results come back always as bytes - self.assertEqual(b"bar", extra_files['foo2']) + self.assertEqual(b"bar", extra_files["foo2"]) # Use torch.jit API torch.jit.save(m, fname, _extra_files=expected_extra_files) - extra_files['foo'] = '' + extra_files["foo"] = "" torch.jit.load(fname, _extra_files=extra_files) - self.assertEqual(value, extra_files['foo']) + self.assertEqual(value, extra_files["foo"]) # Save to buffer. buffer = io.BytesIO(m.save_to_buffer(_extra_files=expected_extra_files)) - extra_files = {'foo': ''} + extra_files = {"foo": ""} torch.jit.load(buffer, _extra_files=extra_files) - self.assertEqual(value, extra_files['foo']) + self.assertEqual(value, extra_files["foo"]) # Use torch.jit API buffer = io.BytesIO() torch.jit.save(m, buffer, _extra_files=expected_extra_files) buffer.seek(0) - extra_files = {'foo': ''} + extra_files = {"foo": ""} torch.jit.load(buffer, _extra_files=extra_files) - self.assertEqual(value, extra_files['foo']) + self.assertEqual(value, extra_files["foo"]) # Non-existent file 'bar' with self.assertRaises(RuntimeError): - extra_files['bar'] = '' + extra_files["bar"] = "" torch.jit.load(buffer, _extra_files=extra_files) def test_save_load_using_pathlib(self): @@ -394,7 +403,7 @@ class TestSaveLoad(JitTestCase): m.save(path) m2 = torch.jit.load(path) - x = torch.tensor([1., 2., 3., 4.]) + x = torch.tensor([1.0, 2.0, 3.0, 4.0]) self.assertTrue(torch.equal(m(x), m2(x))) def test_save_nonexit_file(self): @@ -455,7 +464,9 @@ class TestSaveLoad(JitTestCase): def __init__(self): super().__init__() self.add_module("submodule_a", Submodule()) - self.register_parameter("parameter_a", torch.nn.Parameter(torch.randn(4))) + self.register_parameter( + "parameter_a", torch.nn.Parameter(torch.randn(4)) + ) self.register_buffer("buffer", torch.randn(4)) self.t = torch.rand(4) # not buffer @@ -466,7 +477,9 @@ class TestSaveLoad(JitTestCase): m_loaded = self.getExportImportCopy(torch.jit.script(m)) # Check submodules. - self.assertEqual(len(list(m.named_modules())), len(list(m_loaded.named_modules()))) + self.assertEqual( + len(list(m.named_modules())), len(list(m_loaded.named_modules())) + ) for m_s, loaded_s in zip(m.named_modules(), m_loaded.named_modules()): m_name, _ = m_s loaded_name, _ = loaded_s @@ -478,7 +491,9 @@ class TestSaveLoad(JitTestCase): self.assertEqual(m_p, loaded_p) # Check buffers. - self.assertEqual(len(list(m.named_buffers())), len(list(m_loaded.named_buffers()))) + self.assertEqual( + len(list(m.named_buffers())), len(list(m_loaded.named_buffers())) + ) for m_b, loaded_b in zip(m.named_buffers(), m_loaded.named_buffers()): m_name, m_buffer = m_b loaded_name, loaded_buffer = loaded_b @@ -490,6 +505,7 @@ class TestSaveLoad(JitTestCase): Check that parameters, buffers, and submodules are the same after loading for a module with parameters and buffers that are meta tensors """ + class Foo(torch.nn.Module): def __init__(self): super(Foo, self).__init__() @@ -505,8 +521,13 @@ class TestSaveLoad(JitTestCase): m = Foo() m_loaded = self.getExportImportCopy(torch.jit.script(m)) # Check submodules. - self.assertEqual(len(list(m.named_modules())), len(list(m_loaded.named_modules()))) - self.assertEqual(set(name for name, _ in m.named_modules()), set(name for name, _ in m_loaded.named_modules())) + self.assertEqual( + len(list(m.named_modules())), len(list(m_loaded.named_modules())) + ) + self.assertEqual( + set(name for name, _ in m.named_modules()), + set(name for name, _ in m_loaded.named_modules()), + ) # Check parameters. m_params = dict(m.named_parameters()) m_loaded_params = dict(m_loaded.named_parameters()) @@ -518,24 +539,36 @@ class TestSaveLoad(JitTestCase): self.assertEqual(len(m_buffers), len(m_loaded_buffers)) self.assertEqual(m_buffers, m_loaded_buffers) # Check params and buffers that are/are not meta tensors - self.assertTrue(m_params['foo.weight'].is_meta) - self.assertTrue(m_loaded_params['foo.weight'].is_meta) - self.assertTrue(m_params['foo.bias'].is_meta) - self.assertTrue(m_loaded_params['foo.bias'].is_meta) - self.assertFalse(m_params['bar.weight'].is_meta) - self.assertFalse(m_loaded_params['bar.weight'].is_meta) - self.assertFalse(m_params['bar.bias'].is_meta) - self.assertFalse(m_loaded_params['bar.bias'].is_meta) - self.assertTrue(m_buffers['buffer'].is_meta) - self.assertTrue(m_loaded_buffers['buffer'].is_meta) + self.assertTrue(m_params["foo.weight"].is_meta) + self.assertTrue(m_loaded_params["foo.weight"].is_meta) + self.assertTrue(m_params["foo.bias"].is_meta) + self.assertTrue(m_loaded_params["foo.bias"].is_meta) + self.assertFalse(m_params["bar.weight"].is_meta) + self.assertFalse(m_loaded_params["bar.weight"].is_meta) + self.assertFalse(m_params["bar.bias"].is_meta) + self.assertFalse(m_loaded_params["bar.bias"].is_meta) + self.assertTrue(m_buffers["buffer"].is_meta) + self.assertTrue(m_loaded_buffers["buffer"].is_meta) +def script_module_to_buffer(script_module): + module_buffer = io.BytesIO( + script_module._save_to_buffer_for_lite_interpreter(_use_flatbuffer=True) + ) + module_buffer.seek(0) + return module_buffer + + +@unittest.skipIf( + not ENABLE_FLATBUFFER, "Need to enable flatbuffer to run the below tests" +) class TestSaveLoadFlatbuffer(JitTestCase): def test_different_modules(self): """ Exercise the situation where we have the same qualified name in two different CompilationUnits on save/load. """ + class Foo(torch.nn.Module): def __init__(self): super(Foo, self).__init__() @@ -548,9 +581,7 @@ class TestSaveLoadFlatbuffer(JitTestCase): return x first_script_module = torch.jit.script(Foo()) - first_saved_module = io.BytesIO() - torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module) - first_saved_module.seek(0) + first_saved_module = script_module_to_buffer(first_script_module) clear_class_registry() @@ -564,21 +595,24 @@ class TestSaveLoadFlatbuffer(JitTestCase): return x second_script_module = torch.jit.script(Foo()) - second_saved_module = io.BytesIO() - torch.jit.save_jit_module_to_flatbuffer(torch.jit.script(Foo()), second_saved_module) - second_saved_module.seek(0) + second_saved_module = script_module_to_buffer(second_script_module) clear_class_registry() self.assertEqual( - first_script_module._c.qualified_name, second_script_module._c.qualified_name + first_script_module._c.qualified_name, + second_script_module._c.qualified_name, ) class ContainsBoth(torch.nn.Module): def __init__(self): super().__init__() - self.add_module("second", torch.jit.jit_module_from_flatbuffer(second_saved_module)) - self.add_module("first", torch.jit.jit_module_from_flatbuffer(first_saved_module)) + self.add_module( + "second", torch.jit.load(second_saved_module) + ) + self.add_module( + "first", torch.jit.load(first_saved_module) + ) def forward(self, x): x = self.first(x) @@ -586,16 +620,15 @@ class TestSaveLoadFlatbuffer(JitTestCase): return x sm = torch.jit.script(ContainsBoth()) - contains_both = io.BytesIO() - torch.jit.save_jit_module_to_flatbuffer(sm, contains_both) - contains_both.seek(0) - sm = torch.jit.jit_module_from_flatbuffer(contains_both) + contains_both = script_module_to_buffer(sm) + sm = torch.jit.load(contains_both) def test_different_functions(self): """ Exercise the situation where we have the same qualified name in two different CompilationUnits on save/load. """ + def lol(x): return x @@ -604,10 +637,7 @@ class TestSaveLoadFlatbuffer(JitTestCase): return lol(x) first_script_module = torch.jit.script(Foo()) - first_saved_module = io.BytesIO() - torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module) - first_saved_module.seek(0) - + first_saved_module = script_module_to_buffer(first_script_module) clear_class_registry() def lol(x): # noqa: F811 @@ -618,21 +648,24 @@ class TestSaveLoadFlatbuffer(JitTestCase): return lol(x) second_script_module = torch.jit.script(Foo()) - second_saved_module = io.BytesIO() - torch.jit.save_jit_module_to_flatbuffer(torch.jit.script(Foo()), second_saved_module) - second_saved_module.seek(0) + second_saved_module = script_module_to_buffer(second_script_module) clear_class_registry() self.assertEqual( - first_script_module._c.qualified_name, second_script_module._c.qualified_name + first_script_module._c.qualified_name, + second_script_module._c.qualified_name, ) class ContainsBoth(torch.nn.Module): def __init__(self): super().__init__() - self.add_module("second", torch.jit.jit_module_from_flatbuffer(second_saved_module)) - self.add_module("first", torch.jit.jit_module_from_flatbuffer(first_saved_module)) + self.add_module( + "second", torch.jit.load(second_saved_module) + ) + self.add_module( + "first", torch.jit.load(first_saved_module) + ) def forward(self, x): x = self.first(x) @@ -640,16 +673,15 @@ class TestSaveLoadFlatbuffer(JitTestCase): return x sm = torch.jit.script(ContainsBoth()) - contains_both = io.BytesIO() - torch.jit.save_jit_module_to_flatbuffer(sm, contains_both) - contains_both.seek(0) - sm = torch.jit.jit_module_from_flatbuffer(contains_both) + contains_both = script_module_to_buffer(sm) + sm = torch.jit.load(contains_both) def test_different_interfaces(self): """ Exercise the situation where we have the same qualified name in two different CompilationUnits on save/load. """ + @torch.jit.interface class MyInterface(object): def bar(self, x: Tensor) -> Tensor: @@ -674,10 +706,7 @@ class TestSaveLoadFlatbuffer(JitTestCase): return self.interface.bar(x) first_script_module = torch.jit.script(Foo()) - first_saved_module = io.BytesIO() - torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module) - first_saved_module.seek(0) - + first_saved_module = script_module_to_buffer(first_script_module) clear_class_registry() @torch.jit.interface @@ -704,21 +733,24 @@ class TestSaveLoadFlatbuffer(JitTestCase): return self.interface.not_bar(x) second_script_module = torch.jit.script(Foo()) - second_saved_module = io.BytesIO() - torch.jit.save_jit_module_to_flatbuffer(torch.jit.script(Foo()), second_saved_module) - second_saved_module.seek(0) + second_saved_module = script_module_to_buffer(second_script_module) clear_class_registry() self.assertEqual( - first_script_module._c.qualified_name, second_script_module._c.qualified_name + first_script_module._c.qualified_name, + second_script_module._c.qualified_name, ) class ContainsBoth(torch.nn.Module): def __init__(self): super().__init__() - self.add_module("second", torch.jit.jit_module_from_flatbuffer(second_saved_module)) - self.add_module("first", torch.jit.jit_module_from_flatbuffer(first_saved_module)) + self.add_module( + "second", torch.jit.load(second_saved_module) + ) + self.add_module( + "first", torch.jit.load(first_saved_module) + ) def forward(self, x): x = self.first(x) @@ -726,10 +758,8 @@ class TestSaveLoadFlatbuffer(JitTestCase): return x sm = torch.jit.script(ContainsBoth()) - contains_both = io.BytesIO() - torch.jit.save_jit_module_to_flatbuffer(sm, contains_both) - contains_both.seek(0) - sm = torch.jit.jit_module_from_flatbuffer(contains_both) + contains_both = script_module_to_buffer(sm) + sm = torch.jit.load(contains_both) def test_many_collisions(self): class MyCoolNamedTuple(NamedTuple): @@ -768,11 +798,8 @@ class TestSaveLoadFlatbuffer(JitTestCase): return x, MyCoolNamedTuple(a=5) - first_script_module = torch.jit.script(Foo()) - first_saved_module = io.BytesIO() - torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module) - first_saved_module.seek(0) + first_saved_module = script_module_to_buffer(first_script_module) clear_class_registry() @@ -810,21 +837,24 @@ class TestSaveLoadFlatbuffer(JitTestCase): return x, MyCoolNamedTuple(a="hello") second_script_module = torch.jit.script(Foo()) - second_saved_module = io.BytesIO() - torch.jit.save_jit_module_to_flatbuffer(second_script_module, second_saved_module) - second_saved_module.seek(0) + second_saved_module = script_module_to_buffer(second_script_module) clear_class_registry() self.assertEqual( - first_script_module._c.qualified_name, second_script_module._c.qualified_name + first_script_module._c.qualified_name, + second_script_module._c.qualified_name, ) class ContainsBoth(torch.nn.Module): def __init__(self): super().__init__() - self.add_module("second", torch.jit.jit_module_from_flatbuffer(second_saved_module)) - self.add_module("first", torch.jit.jit_module_from_flatbuffer(first_saved_module)) + self.add_module( + "second", torch.jit.load(second_saved_module) + ) + self.add_module( + "first", torch.jit.load(first_saved_module) + ) def forward(self, x): x, named_tuple_1 = self.first(x) @@ -832,10 +862,8 @@ class TestSaveLoadFlatbuffer(JitTestCase): return len(x + named_tuple_2.a) + named_tuple_1.a sm = torch.jit.script(ContainsBoth()) - contains_both = io.BytesIO() - torch.jit.save_jit_module_to_flatbuffer(sm, contains_both) - contains_both.seek(0) - sm = torch.jit.jit_module_from_flatbuffer(contains_both) + contains_both = script_module_to_buffer(sm) + sm = torch.jit.load(contains_both) def test_save_load_using_pathlib(self): class MyMod(torch.jit.ScriptModule): @@ -849,9 +877,9 @@ class TestSaveLoadFlatbuffer(JitTestCase): with TemporaryFileName() as fname: path = pathlib.Path(fname) torch.jit.save_jit_module_to_flatbuffer(m, path) - m2 = torch.jit.jit_module_from_flatbuffer(path) + m2 = torch.jit.load(path) - x = torch.tensor([1., 2., 3., 4.]) + x = torch.tensor([1.0, 2.0, 3.0, 4.0]) self.assertTrue(torch.equal(m(x), m2(x))) def test_save_namedtuple_input_only(self): @@ -903,7 +931,9 @@ class TestSaveLoadFlatbuffer(JitTestCase): def __init__(self): super().__init__() self.add_module("submodule_a", Submodule()) - self.register_parameter("parameter_a", torch.nn.Parameter(torch.randn(4))) + self.register_parameter( + "parameter_a", torch.nn.Parameter(torch.randn(4)) + ) self.register_buffer("buffer", torch.randn(4)) self.t = torch.rand(4) # not buffer @@ -914,7 +944,9 @@ class TestSaveLoadFlatbuffer(JitTestCase): m_loaded = self.getExportImportCopy(torch.jit.script(m)) # Check submodules. - self.assertEqual(len(list(m.named_modules())), len(list(m_loaded.named_modules()))) + self.assertEqual( + len(list(m.named_modules())), len(list(m_loaded.named_modules())) + ) for m_s, loaded_s in zip(m.named_modules(), m_loaded.named_modules()): m_name, _ = m_s loaded_name, _ = loaded_s @@ -926,7 +958,9 @@ class TestSaveLoadFlatbuffer(JitTestCase): self.assertEqual(m_p, loaded_p) # Check buffers. - self.assertEqual(len(list(m.named_buffers())), len(list(m_loaded.named_buffers()))) + self.assertEqual( + len(list(m.named_buffers())), len(list(m_loaded.named_buffers())) + ) for m_b, loaded_b in zip(m.named_buffers(), m_loaded.named_buffers()): m_name, m_buffer = m_b loaded_name, loaded_buffer = loaded_b diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index 3746406a0bd..c3b8d291f02 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -540,6 +540,7 @@ mobile::Module _load_for_mobile( std::istream& in, c10::optional device, ExtraFilesMap& extra_files) { + in.seekg(0, in.beg); auto format = getFileFormat(in); switch (format) { case FileFormat::ZipFileFormat: { diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index 25e5299b420..f9cef185496 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -797,9 +797,13 @@ void save_mobile_module_to( const ExtraFilesMap& extra_files, bool save_mobile_debug_info, const std::function& writer_func) { + ExtraFilesMap jitFiles; CompilationOptions options = getOptionsFromGlobal(); + std::vector constants; + jitModuleToPythonCodeAndConstants(module, &jitFiles, &constants); mobile::Module mod = jitModuleToMobile(module, options); - auto buffer = save_mobile_module_to_bytes(mod, extra_files); + auto buffer = + save_mobile_module_to_bytes(mod, extra_files, jitFiles, constants); writer_func(reinterpret_cast(buffer.data()), buffer.size()); } #endif diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp index 20942c4cb7b..08a7beffa1e 100644 --- a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp +++ b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp @@ -789,6 +789,14 @@ Module load_jit_module_from_file( std::move(std::get<0>(data)), std::get<1>(data), device); } +Module load_jit_module_from_stream( + std::istream& in, + c10::optional device) { + auto data = get_stream_content(in); + return parse_and_initialize_jit_module( + std::move(std::get<0>(data)), std::get<1>(data), device); +} + void save_jit_module( const Module& module, const std::string& filename, diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer.h b/torch/csrc/jit/serialization/flatbuffer_serializer.h index 83a38fe9424..3f1d7ec0d4c 100644 --- a/torch/csrc/jit/serialization/flatbuffer_serializer.h +++ b/torch/csrc/jit/serialization/flatbuffer_serializer.h @@ -47,5 +47,9 @@ TORCH_API Module load_jit_module_from_file( const std::string& filename, c10::optional device = c10::nullopt); +TORCH_API Module load_jit_module_from_stream( + std::istream& in, + c10::optional device = c10::nullopt); + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/serialization/import.cpp b/torch/csrc/jit/serialization/import.cpp index 65f69a9844c..d6309fbe748 100644 --- a/torch/csrc/jit/serialization/import.cpp +++ b/torch/csrc/jit/serialization/import.cpp @@ -10,6 +10,7 @@ #endif #include #include +#include #include #include #include @@ -18,6 +19,10 @@ #include #include +#if defined(ENABLE_FLATBUFFER) +#include +#endif + #include #include #include @@ -290,9 +295,26 @@ Module import_ir_module( std::istream& in, c10::optional device, ExtraFilesMap& extra_files) { - auto reader = torch::make_unique(&in); - ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader)); - return deserializer.deserialize(device, extra_files); + in.seekg(0, in.beg); + auto format = getFileFormat(in); + switch (format) { + case FileFormat::FlatbufferFileFormat: { +#if defined(ENABLE_FLATBUFFER) + return load_jit_module_from_stream(in, device); +#else + TORCH_CHECK( + false, "Flatbuffer input file but the build hasn't enable flatbuffer") +#endif + } + case FileFormat::ZipFileFormat: { + auto reader = torch::make_unique(&in); + ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader)); + return deserializer.deserialize(device, extra_files); + } + + default: + TORCH_CHECK(false, "Unrecognized data format"); + } } // For reading unified serialization format from torch.Package. @@ -325,9 +347,25 @@ Module import_ir_module( const std::string& filename, c10::optional device, ExtraFilesMap& extra_files) { - auto reader = torch::make_unique(filename); - ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader)); - return deserializer.deserialize(device, extra_files); + auto format = getFileFormat(filename); + switch (format) { + case FileFormat::FlatbufferFileFormat: { +#if defined(ENABLE_FLATBUFFER) + return load_jit_module_from_file(filename, device); +#else + TORCH_CHECK( + false, "Flatbuffer input file but the build hasn't enable flatbuffer") +#endif + } + case FileFormat::ZipFileFormat: { + auto reader = torch::make_unique(filename); + ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader)); + return deserializer.deserialize(device, extra_files); + } + + default: + TORCH_CHECK(false, "Unrecognized data format"); + } } Module import_ir_module( @@ -357,9 +395,27 @@ Module load( std::istream& in, c10::optional device, ExtraFilesMap& extra_files) { - std::unique_ptr rai = std::make_unique(&in); - auto module = load(std::move(rai), device, extra_files); - return module; + in.seekg(0, in.beg); + auto format = getFileFormat(in); + switch (format) { + case FileFormat::FlatbufferFileFormat: { +#if defined(ENABLE_FLATBUFFER) + return load_jit_module_from_stream(in, device); +#else + TORCH_CHECK( + false, "Flatbuffer input file but the build hasn't enable flatbuffer") +#endif + } + case FileFormat::ZipFileFormat: { + std::unique_ptr rai = + std::make_unique(&in); + auto module = load(std::move(rai), device, extra_files); + return module; + } + + default: + TORCH_CHECK(false, "Unrecognized data format"); + } } Module load(const std::string& filename, c10::optional device) { @@ -371,9 +427,27 @@ Module load( const std::string& filename, c10::optional device, ExtraFilesMap& extra_files) { - std::unique_ptr rai = std::make_unique(filename); - auto module = load(std::move(rai), device, extra_files); - return module; + auto format = getFileFormat(filename); + switch (format) { + case FileFormat::FlatbufferFileFormat: { +#if defined(ENABLE_FLATBUFFER) + return load_jit_module_from_file(filename, device); +#else + TORCH_CHECK( + false, "Flatbuffer input file but the build hasn't enable flatbuffer") +#endif + + case FileFormat::ZipFileFormat: { + std::unique_ptr rai = + std::make_unique(filename); + auto module = load(std::move(rai), device, extra_files); + return module; + } + + default: + TORCH_CHECK(false, "Unrecognized data format"); + } + } } Module load( @@ -387,8 +461,8 @@ Module load( std::shared_ptr rai, c10::optional device, ExtraFilesMap& extra_files) { - // Verify that we're loading a zip archive and not a torch.save pickle archive - // (marked by the 0x80 0x02 bytes at the start) + // Verify that we're loading a zip archive and not a torch.save pickle + // archive (marked by the 0x80 0x02 bytes at the start) // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) TORCH_CHECK( check_zip_file(rai),