Extend jit::load to work on flatbuffer file; Take 2 (#75256)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75256

ghstack-source-id: 153138970

Test Plan: CI

Reviewed By: iseeyuan

Differential Revision: D35399581

fbshipit-source-id: dafe9d301009d3f70986ed92bfe06d160ab90ba0
(cherry picked from commit ccc860fd07946de5aae12bc179a0b8bbba83b997)
This commit is contained in:
Pavithran Ramachandran 2022-04-06 10:41:25 -07:00 committed by PyTorch MergeBot
parent ff7051781f
commit f984e50f39
9 changed files with 249 additions and 120 deletions

View File

@ -355,6 +355,7 @@ TEST(SerializeTest, ErrorOnMissingKey) {
// We want the errors to contain hierarchy information, too.
ASSERT_THROWS_WITH(
torch::load(model2, stream), "No such serialized tensor 'a.b.x'");
stream.seekg(0, stream.beg);
ASSERT_THROWS_WITH(
torch::load(model3, stream), "No such serialized submodule: 'a.x'");
}

View File

@ -276,6 +276,7 @@ TEST(BackendTest, TestConsistencyOfCompositeWithSetStates) {
c._save_for_mobile(ss);
auto mc = _load_for_mobile(ss);
auto res_mobile = mc.forward(inputs);
ss.seekg(0, ss.beg);
// check if the methods names are always the same
// by reloading the script module and saving it back as mobile

View File

@ -1177,6 +1177,7 @@ Module jitModuleFromBuffer(void* data) {
mobilem._ivalue(), files, constants, 8);
}
#if defined(ENABLE_FLATBUFFER)
TEST(TestSourceFlatbuffer, UpsampleNearest2d) {
Module m("m");
m.define(R"(
@ -1189,20 +1190,21 @@ TEST(TestSourceFlatbuffer, UpsampleNearest2d) {
inputs.emplace_back(at::Scalar(2.0));
auto ref = m.forward(inputs);
auto data = save_jit_module_to_bytes(m);
Module m2 = jitModuleFromBuffer(data.data());
std::stringstream ss;
m._save_for_mobile(ss, {}, false, /*use_fatbuffer=*/true);
auto mm = _load_for_mobile(ss);
auto m2 = load(ss);
auto res = m2.forward(inputs);
auto resm = mm.forward(inputs);
auto resd = res.toTensor();
auto refd = ref.toTensor();
auto resmd = resm.toTensor();
ASSERT_TRUE(resd.equal(refd));
mobile::Module m3 = parse_mobile_module(data.data(), data.size());
res = m3.forward(inputs);
resd = res.toTensor();
refd = ref.toTensor();
ASSERT_TRUE(resd.equal(refd));
ASSERT_TRUE(resmd.equal(refd));
}
#endif
TEST(TestSourceFlatbuffer, CheckAttrAccess) {
Module m("m");

View File

@ -1,20 +1,22 @@
# Owner(s): ["oncall: jit"]
from typing import NamedTuple, Optional
import io
import os
import pathlib
import sys
import unittest
from typing import NamedTuple, Optional
import torch
from torch import Tensor
from torch.testing._internal.common_utils import TemporaryFileName
import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import (JitTestCase,
clear_class_registry)
from torch.testing._internal.jit_utils import JitTestCase, clear_class_registry
ENABLE_FLATBUFFER = os.environ.get("ENABLE_FLATBUFFER", "0") == "1"
if __name__ == "__main__":
raise RuntimeError(
@ -23,12 +25,14 @@ if __name__ == "__main__":
"instead."
)
class TestSaveLoad(JitTestCase):
def test_different_modules(self):
"""
Exercise the situation where we have the same qualified name
in two different CompilationUnits on save/load.
"""
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
@ -64,7 +68,8 @@ class TestSaveLoad(JitTestCase):
clear_class_registry()
self.assertEqual(
first_script_module._c.qualified_name, second_script_module._c.qualified_name
first_script_module._c.qualified_name,
second_script_module._c.qualified_name,
)
class ContainsBoth(torch.nn.Module):
@ -89,6 +94,7 @@ class TestSaveLoad(JitTestCase):
Exercise the situation where we have the same qualified name
in two different CompilationUnits on save/load.
"""
def lol(x):
return x
@ -118,7 +124,8 @@ class TestSaveLoad(JitTestCase):
clear_class_registry()
self.assertEqual(
first_script_module._c.qualified_name, second_script_module._c.qualified_name
first_script_module._c.qualified_name,
second_script_module._c.qualified_name,
)
class ContainsBoth(torch.nn.Module):
@ -143,6 +150,7 @@ class TestSaveLoad(JitTestCase):
Exercise the situation where we have the same qualified name
in two different CompilationUnits on save/load.
"""
@torch.jit.interface
class MyInterface(object):
def bar(self, x: Tensor) -> Tensor:
@ -204,7 +212,8 @@ class TestSaveLoad(JitTestCase):
clear_class_registry()
self.assertEqual(
first_script_module._c.qualified_name, second_script_module._c.qualified_name
first_script_module._c.qualified_name,
second_script_module._c.qualified_name,
)
class ContainsBoth(torch.nn.Module):
@ -261,7 +270,6 @@ class TestSaveLoad(JitTestCase):
return x, MyCoolNamedTuple(a=5)
first_script_module = torch.jit.script(Foo())
first_saved_module = io.BytesIO()
torch.jit.save(first_script_module, first_saved_module)
@ -310,7 +318,8 @@ class TestSaveLoad(JitTestCase):
clear_class_registry()
self.assertEqual(
first_script_module._c.qualified_name, second_script_module._c.qualified_name
first_script_module._c.qualified_name,
second_script_module._c.qualified_name,
)
class ContainsBoth(torch.nn.Module):
@ -340,44 +349,44 @@ class TestSaveLoad(JitTestCase):
value = b"bar\x00\xffbaz"
expected_extra_files = {}
expected_extra_files['foo'] = value
expected_extra_files["foo"] = value
# verify that str to bytes conversion also works
expected_extra_files['foo2'] = "bar"
expected_extra_files["foo2"] = "bar"
m = MyMod()
# Save to file.
with TemporaryFileName() as fname:
m.save(fname, _extra_files=expected_extra_files)
# values don't matter
extra_files = {'foo': '', 'foo2': None}
extra_files = {"foo": "", "foo2": None}
torch.jit.load(fname, _extra_files=extra_files)
self.assertEqual(value, extra_files['foo'])
self.assertEqual(value, extra_files["foo"])
# results come back always as bytes
self.assertEqual(b"bar", extra_files['foo2'])
self.assertEqual(b"bar", extra_files["foo2"])
# Use torch.jit API
torch.jit.save(m, fname, _extra_files=expected_extra_files)
extra_files['foo'] = ''
extra_files["foo"] = ""
torch.jit.load(fname, _extra_files=extra_files)
self.assertEqual(value, extra_files['foo'])
self.assertEqual(value, extra_files["foo"])
# Save to buffer.
buffer = io.BytesIO(m.save_to_buffer(_extra_files=expected_extra_files))
extra_files = {'foo': ''}
extra_files = {"foo": ""}
torch.jit.load(buffer, _extra_files=extra_files)
self.assertEqual(value, extra_files['foo'])
self.assertEqual(value, extra_files["foo"])
# Use torch.jit API
buffer = io.BytesIO()
torch.jit.save(m, buffer, _extra_files=expected_extra_files)
buffer.seek(0)
extra_files = {'foo': ''}
extra_files = {"foo": ""}
torch.jit.load(buffer, _extra_files=extra_files)
self.assertEqual(value, extra_files['foo'])
self.assertEqual(value, extra_files["foo"])
# Non-existent file 'bar'
with self.assertRaises(RuntimeError):
extra_files['bar'] = ''
extra_files["bar"] = ""
torch.jit.load(buffer, _extra_files=extra_files)
def test_save_load_using_pathlib(self):
@ -394,7 +403,7 @@ class TestSaveLoad(JitTestCase):
m.save(path)
m2 = torch.jit.load(path)
x = torch.tensor([1., 2., 3., 4.])
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
self.assertTrue(torch.equal(m(x), m2(x)))
def test_save_nonexit_file(self):
@ -455,7 +464,9 @@ class TestSaveLoad(JitTestCase):
def __init__(self):
super().__init__()
self.add_module("submodule_a", Submodule())
self.register_parameter("parameter_a", torch.nn.Parameter(torch.randn(4)))
self.register_parameter(
"parameter_a", torch.nn.Parameter(torch.randn(4))
)
self.register_buffer("buffer", torch.randn(4))
self.t = torch.rand(4) # not buffer
@ -466,7 +477,9 @@ class TestSaveLoad(JitTestCase):
m_loaded = self.getExportImportCopy(torch.jit.script(m))
# Check submodules.
self.assertEqual(len(list(m.named_modules())), len(list(m_loaded.named_modules())))
self.assertEqual(
len(list(m.named_modules())), len(list(m_loaded.named_modules()))
)
for m_s, loaded_s in zip(m.named_modules(), m_loaded.named_modules()):
m_name, _ = m_s
loaded_name, _ = loaded_s
@ -478,7 +491,9 @@ class TestSaveLoad(JitTestCase):
self.assertEqual(m_p, loaded_p)
# Check buffers.
self.assertEqual(len(list(m.named_buffers())), len(list(m_loaded.named_buffers())))
self.assertEqual(
len(list(m.named_buffers())), len(list(m_loaded.named_buffers()))
)
for m_b, loaded_b in zip(m.named_buffers(), m_loaded.named_buffers()):
m_name, m_buffer = m_b
loaded_name, loaded_buffer = loaded_b
@ -490,6 +505,7 @@ class TestSaveLoad(JitTestCase):
Check that parameters, buffers, and submodules are the same after loading
for a module with parameters and buffers that are meta tensors
"""
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
@ -505,8 +521,13 @@ class TestSaveLoad(JitTestCase):
m = Foo()
m_loaded = self.getExportImportCopy(torch.jit.script(m))
# Check submodules.
self.assertEqual(len(list(m.named_modules())), len(list(m_loaded.named_modules())))
self.assertEqual(set(name for name, _ in m.named_modules()), set(name for name, _ in m_loaded.named_modules()))
self.assertEqual(
len(list(m.named_modules())), len(list(m_loaded.named_modules()))
)
self.assertEqual(
set(name for name, _ in m.named_modules()),
set(name for name, _ in m_loaded.named_modules()),
)
# Check parameters.
m_params = dict(m.named_parameters())
m_loaded_params = dict(m_loaded.named_parameters())
@ -518,24 +539,36 @@ class TestSaveLoad(JitTestCase):
self.assertEqual(len(m_buffers), len(m_loaded_buffers))
self.assertEqual(m_buffers, m_loaded_buffers)
# Check params and buffers that are/are not meta tensors
self.assertTrue(m_params['foo.weight'].is_meta)
self.assertTrue(m_loaded_params['foo.weight'].is_meta)
self.assertTrue(m_params['foo.bias'].is_meta)
self.assertTrue(m_loaded_params['foo.bias'].is_meta)
self.assertFalse(m_params['bar.weight'].is_meta)
self.assertFalse(m_loaded_params['bar.weight'].is_meta)
self.assertFalse(m_params['bar.bias'].is_meta)
self.assertFalse(m_loaded_params['bar.bias'].is_meta)
self.assertTrue(m_buffers['buffer'].is_meta)
self.assertTrue(m_loaded_buffers['buffer'].is_meta)
self.assertTrue(m_params["foo.weight"].is_meta)
self.assertTrue(m_loaded_params["foo.weight"].is_meta)
self.assertTrue(m_params["foo.bias"].is_meta)
self.assertTrue(m_loaded_params["foo.bias"].is_meta)
self.assertFalse(m_params["bar.weight"].is_meta)
self.assertFalse(m_loaded_params["bar.weight"].is_meta)
self.assertFalse(m_params["bar.bias"].is_meta)
self.assertFalse(m_loaded_params["bar.bias"].is_meta)
self.assertTrue(m_buffers["buffer"].is_meta)
self.assertTrue(m_loaded_buffers["buffer"].is_meta)
def script_module_to_buffer(script_module):
module_buffer = io.BytesIO(
script_module._save_to_buffer_for_lite_interpreter(_use_flatbuffer=True)
)
module_buffer.seek(0)
return module_buffer
@unittest.skipIf(
not ENABLE_FLATBUFFER, "Need to enable flatbuffer to run the below tests"
)
class TestSaveLoadFlatbuffer(JitTestCase):
def test_different_modules(self):
"""
Exercise the situation where we have the same qualified name
in two different CompilationUnits on save/load.
"""
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
@ -548,9 +581,7 @@ class TestSaveLoadFlatbuffer(JitTestCase):
return x
first_script_module = torch.jit.script(Foo())
first_saved_module = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module)
first_saved_module.seek(0)
first_saved_module = script_module_to_buffer(first_script_module)
clear_class_registry()
@ -564,21 +595,24 @@ class TestSaveLoadFlatbuffer(JitTestCase):
return x
second_script_module = torch.jit.script(Foo())
second_saved_module = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(torch.jit.script(Foo()), second_saved_module)
second_saved_module.seek(0)
second_saved_module = script_module_to_buffer(second_script_module)
clear_class_registry()
self.assertEqual(
first_script_module._c.qualified_name, second_script_module._c.qualified_name
first_script_module._c.qualified_name,
second_script_module._c.qualified_name,
)
class ContainsBoth(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module("second", torch.jit.jit_module_from_flatbuffer(second_saved_module))
self.add_module("first", torch.jit.jit_module_from_flatbuffer(first_saved_module))
self.add_module(
"second", torch.jit.load(second_saved_module)
)
self.add_module(
"first", torch.jit.load(first_saved_module)
)
def forward(self, x):
x = self.first(x)
@ -586,16 +620,15 @@ class TestSaveLoadFlatbuffer(JitTestCase):
return x
sm = torch.jit.script(ContainsBoth())
contains_both = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(sm, contains_both)
contains_both.seek(0)
sm = torch.jit.jit_module_from_flatbuffer(contains_both)
contains_both = script_module_to_buffer(sm)
sm = torch.jit.load(contains_both)
def test_different_functions(self):
"""
Exercise the situation where we have the same qualified name
in two different CompilationUnits on save/load.
"""
def lol(x):
return x
@ -604,10 +637,7 @@ class TestSaveLoadFlatbuffer(JitTestCase):
return lol(x)
first_script_module = torch.jit.script(Foo())
first_saved_module = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module)
first_saved_module.seek(0)
first_saved_module = script_module_to_buffer(first_script_module)
clear_class_registry()
def lol(x): # noqa: F811
@ -618,21 +648,24 @@ class TestSaveLoadFlatbuffer(JitTestCase):
return lol(x)
second_script_module = torch.jit.script(Foo())
second_saved_module = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(torch.jit.script(Foo()), second_saved_module)
second_saved_module.seek(0)
second_saved_module = script_module_to_buffer(second_script_module)
clear_class_registry()
self.assertEqual(
first_script_module._c.qualified_name, second_script_module._c.qualified_name
first_script_module._c.qualified_name,
second_script_module._c.qualified_name,
)
class ContainsBoth(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module("second", torch.jit.jit_module_from_flatbuffer(second_saved_module))
self.add_module("first", torch.jit.jit_module_from_flatbuffer(first_saved_module))
self.add_module(
"second", torch.jit.load(second_saved_module)
)
self.add_module(
"first", torch.jit.load(first_saved_module)
)
def forward(self, x):
x = self.first(x)
@ -640,16 +673,15 @@ class TestSaveLoadFlatbuffer(JitTestCase):
return x
sm = torch.jit.script(ContainsBoth())
contains_both = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(sm, contains_both)
contains_both.seek(0)
sm = torch.jit.jit_module_from_flatbuffer(contains_both)
contains_both = script_module_to_buffer(sm)
sm = torch.jit.load(contains_both)
def test_different_interfaces(self):
"""
Exercise the situation where we have the same qualified name
in two different CompilationUnits on save/load.
"""
@torch.jit.interface
class MyInterface(object):
def bar(self, x: Tensor) -> Tensor:
@ -674,10 +706,7 @@ class TestSaveLoadFlatbuffer(JitTestCase):
return self.interface.bar(x)
first_script_module = torch.jit.script(Foo())
first_saved_module = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module)
first_saved_module.seek(0)
first_saved_module = script_module_to_buffer(first_script_module)
clear_class_registry()
@torch.jit.interface
@ -704,21 +733,24 @@ class TestSaveLoadFlatbuffer(JitTestCase):
return self.interface.not_bar(x)
second_script_module = torch.jit.script(Foo())
second_saved_module = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(torch.jit.script(Foo()), second_saved_module)
second_saved_module.seek(0)
second_saved_module = script_module_to_buffer(second_script_module)
clear_class_registry()
self.assertEqual(
first_script_module._c.qualified_name, second_script_module._c.qualified_name
first_script_module._c.qualified_name,
second_script_module._c.qualified_name,
)
class ContainsBoth(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module("second", torch.jit.jit_module_from_flatbuffer(second_saved_module))
self.add_module("first", torch.jit.jit_module_from_flatbuffer(first_saved_module))
self.add_module(
"second", torch.jit.load(second_saved_module)
)
self.add_module(
"first", torch.jit.load(first_saved_module)
)
def forward(self, x):
x = self.first(x)
@ -726,10 +758,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
return x
sm = torch.jit.script(ContainsBoth())
contains_both = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(sm, contains_both)
contains_both.seek(0)
sm = torch.jit.jit_module_from_flatbuffer(contains_both)
contains_both = script_module_to_buffer(sm)
sm = torch.jit.load(contains_both)
def test_many_collisions(self):
class MyCoolNamedTuple(NamedTuple):
@ -768,11 +798,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
return x, MyCoolNamedTuple(a=5)
first_script_module = torch.jit.script(Foo())
first_saved_module = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module)
first_saved_module.seek(0)
first_saved_module = script_module_to_buffer(first_script_module)
clear_class_registry()
@ -810,21 +837,24 @@ class TestSaveLoadFlatbuffer(JitTestCase):
return x, MyCoolNamedTuple(a="hello")
second_script_module = torch.jit.script(Foo())
second_saved_module = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(second_script_module, second_saved_module)
second_saved_module.seek(0)
second_saved_module = script_module_to_buffer(second_script_module)
clear_class_registry()
self.assertEqual(
first_script_module._c.qualified_name, second_script_module._c.qualified_name
first_script_module._c.qualified_name,
second_script_module._c.qualified_name,
)
class ContainsBoth(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module("second", torch.jit.jit_module_from_flatbuffer(second_saved_module))
self.add_module("first", torch.jit.jit_module_from_flatbuffer(first_saved_module))
self.add_module(
"second", torch.jit.load(second_saved_module)
)
self.add_module(
"first", torch.jit.load(first_saved_module)
)
def forward(self, x):
x, named_tuple_1 = self.first(x)
@ -832,10 +862,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
return len(x + named_tuple_2.a) + named_tuple_1.a
sm = torch.jit.script(ContainsBoth())
contains_both = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(sm, contains_both)
contains_both.seek(0)
sm = torch.jit.jit_module_from_flatbuffer(contains_both)
contains_both = script_module_to_buffer(sm)
sm = torch.jit.load(contains_both)
def test_save_load_using_pathlib(self):
class MyMod(torch.jit.ScriptModule):
@ -849,9 +877,9 @@ class TestSaveLoadFlatbuffer(JitTestCase):
with TemporaryFileName() as fname:
path = pathlib.Path(fname)
torch.jit.save_jit_module_to_flatbuffer(m, path)
m2 = torch.jit.jit_module_from_flatbuffer(path)
m2 = torch.jit.load(path)
x = torch.tensor([1., 2., 3., 4.])
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
self.assertTrue(torch.equal(m(x), m2(x)))
def test_save_namedtuple_input_only(self):
@ -903,7 +931,9 @@ class TestSaveLoadFlatbuffer(JitTestCase):
def __init__(self):
super().__init__()
self.add_module("submodule_a", Submodule())
self.register_parameter("parameter_a", torch.nn.Parameter(torch.randn(4)))
self.register_parameter(
"parameter_a", torch.nn.Parameter(torch.randn(4))
)
self.register_buffer("buffer", torch.randn(4))
self.t = torch.rand(4) # not buffer
@ -914,7 +944,9 @@ class TestSaveLoadFlatbuffer(JitTestCase):
m_loaded = self.getExportImportCopy(torch.jit.script(m))
# Check submodules.
self.assertEqual(len(list(m.named_modules())), len(list(m_loaded.named_modules())))
self.assertEqual(
len(list(m.named_modules())), len(list(m_loaded.named_modules()))
)
for m_s, loaded_s in zip(m.named_modules(), m_loaded.named_modules()):
m_name, _ = m_s
loaded_name, _ = loaded_s
@ -926,7 +958,9 @@ class TestSaveLoadFlatbuffer(JitTestCase):
self.assertEqual(m_p, loaded_p)
# Check buffers.
self.assertEqual(len(list(m.named_buffers())), len(list(m_loaded.named_buffers())))
self.assertEqual(
len(list(m.named_buffers())), len(list(m_loaded.named_buffers()))
)
for m_b, loaded_b in zip(m.named_buffers(), m_loaded.named_buffers()):
m_name, m_buffer = m_b
loaded_name, loaded_buffer = loaded_b

View File

@ -540,6 +540,7 @@ mobile::Module _load_for_mobile(
std::istream& in,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
in.seekg(0, in.beg);
auto format = getFileFormat(in);
switch (format) {
case FileFormat::ZipFileFormat: {

View File

@ -797,9 +797,13 @@ void save_mobile_module_to(
const ExtraFilesMap& extra_files,
bool save_mobile_debug_info,
const std::function<size_t(const void*, size_t)>& writer_func) {
ExtraFilesMap jitFiles;
CompilationOptions options = getOptionsFromGlobal();
std::vector<IValue> constants;
jitModuleToPythonCodeAndConstants(module, &jitFiles, &constants);
mobile::Module mod = jitModuleToMobile(module, options);
auto buffer = save_mobile_module_to_bytes(mod, extra_files);
auto buffer =
save_mobile_module_to_bytes(mod, extra_files, jitFiles, constants);
writer_func(reinterpret_cast<void*>(buffer.data()), buffer.size());
}
#endif

View File

@ -789,6 +789,14 @@ Module load_jit_module_from_file(
std::move(std::get<0>(data)), std::get<1>(data), device);
}
Module load_jit_module_from_stream(
std::istream& in,
c10::optional<at::Device> device) {
auto data = get_stream_content(in);
return parse_and_initialize_jit_module(
std::move(std::get<0>(data)), std::get<1>(data), device);
}
void save_jit_module(
const Module& module,
const std::string& filename,

View File

@ -47,5 +47,9 @@ TORCH_API Module load_jit_module_from_file(
const std::string& filename,
c10::optional<at::Device> device = c10::nullopt);
TORCH_API Module load_jit_module_from_stream(
std::istream& in,
c10::optional<at::Device> device = c10::nullopt);
} // namespace jit
} // namespace torch

View File

@ -10,6 +10,7 @@
#endif
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/serialization/import_read.h>
@ -18,6 +19,10 @@
#include <torch/csrc/jit/serialization/source_range_serialization.h>
#include <torch/csrc/jit/serialization/unpickler.h>
#if defined(ENABLE_FLATBUFFER)
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#endif
#include <caffe2/serialize/file_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <caffe2/serialize/istream_adapter.h>
@ -290,9 +295,26 @@ Module import_ir_module(
std::istream& in,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
auto reader = torch::make_unique<PyTorchStreamReader>(&in);
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files);
in.seekg(0, in.beg);
auto format = getFileFormat(in);
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if defined(ENABLE_FLATBUFFER)
return load_jit_module_from_stream(in, device);
#else
TORCH_CHECK(
false, "Flatbuffer input file but the build hasn't enable flatbuffer")
#endif
}
case FileFormat::ZipFileFormat: {
auto reader = torch::make_unique<PyTorchStreamReader>(&in);
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files);
}
default:
TORCH_CHECK(false, "Unrecognized data format");
}
}
// For reading unified serialization format from torch.Package.
@ -325,9 +347,25 @@ Module import_ir_module(
const std::string& filename,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
auto reader = torch::make_unique<PyTorchStreamReader>(filename);
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files);
auto format = getFileFormat(filename);
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if defined(ENABLE_FLATBUFFER)
return load_jit_module_from_file(filename, device);
#else
TORCH_CHECK(
false, "Flatbuffer input file but the build hasn't enable flatbuffer")
#endif
}
case FileFormat::ZipFileFormat: {
auto reader = torch::make_unique<PyTorchStreamReader>(filename);
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files);
}
default:
TORCH_CHECK(false, "Unrecognized data format");
}
}
Module import_ir_module(
@ -357,9 +395,27 @@ Module load(
std::istream& in,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
auto module = load(std::move(rai), device, extra_files);
return module;
in.seekg(0, in.beg);
auto format = getFileFormat(in);
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if defined(ENABLE_FLATBUFFER)
return load_jit_module_from_stream(in, device);
#else
TORCH_CHECK(
false, "Flatbuffer input file but the build hasn't enable flatbuffer")
#endif
}
case FileFormat::ZipFileFormat: {
std::unique_ptr<IStreamAdapter> rai =
std::make_unique<IStreamAdapter>(&in);
auto module = load(std::move(rai), device, extra_files);
return module;
}
default:
TORCH_CHECK(false, "Unrecognized data format");
}
}
Module load(const std::string& filename, c10::optional<at::Device> device) {
@ -371,9 +427,27 @@ Module load(
const std::string& filename,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
auto module = load(std::move(rai), device, extra_files);
return module;
auto format = getFileFormat(filename);
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if defined(ENABLE_FLATBUFFER)
return load_jit_module_from_file(filename, device);
#else
TORCH_CHECK(
false, "Flatbuffer input file but the build hasn't enable flatbuffer")
#endif
case FileFormat::ZipFileFormat: {
std::unique_ptr<FileAdapter> rai =
std::make_unique<FileAdapter>(filename);
auto module = load(std::move(rai), device, extra_files);
return module;
}
default:
TORCH_CHECK(false, "Unrecognized data format");
}
}
}
Module load(
@ -387,8 +461,8 @@ Module load(
std::shared_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files) {
// Verify that we're loading a zip archive and not a torch.save pickle archive
// (marked by the 0x80 0x02 bytes at the start)
// Verify that we're loading a zip archive and not a torch.save pickle
// archive (marked by the 0x80 0x02 bytes at the start)
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
TORCH_CHECK(
check_zip_file(rai),