Extend jit::load to work on flatbuffer file; Take 2 (#75256)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75256

ghstack-source-id: 153138970

Test Plan: CI

Reviewed By: iseeyuan

Differential Revision: D35399581

fbshipit-source-id: dafe9d301009d3f70986ed92bfe06d160ab90ba0
(cherry picked from commit ccc860fd07946de5aae12bc179a0b8bbba83b997)
This commit is contained in:
  parent ff7051781f
  commit f984e50f39
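For orientation, here is a minimal sketch of the behavior this commit enables, written against the Python surface of the same APIs the tests below exercise. It is an illustration, not code from the diff: it assumes a build configured with ENABLE_FLATBUFFER, and the Adder module is made up; the save call is the private _save_to_buffer_for_lite_interpreter(_use_flatbuffer=True) hook that the test helper below relies on.

import io

import torch


class Adder(torch.nn.Module):
    def forward(self, x):
        return x + 1


# Script the module and serialize it in the mobile flatbuffer format.
scripted = torch.jit.script(Adder())
buf = io.BytesIO(
    scripted._save_to_buffer_for_lite_interpreter(_use_flatbuffer=True)
)
buf.seek(0)

# With this change, the ordinary JIT loader sniffs the format and
# accepts the flatbuffer payload directly.
reloaded = torch.jit.load(buf)

x = torch.ones(3)
assert torch.equal(scripted(x), reloaded(x))
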
@@ -355,6 +355,7 @@ TEST(SerializeTest, ErrorOnMissingKey) {
   // We want the errors to contain hierarchy information, too.
   ASSERT_THROWS_WITH(
       torch::load(model2, stream), "No such serialized tensor 'a.b.x'");
+  stream.seekg(0, stream.beg);
   ASSERT_THROWS_WITH(
       torch::load(model3, stream), "No such serialized submodule: 'a.x'");
 }

@@ -276,6 +276,7 @@ TEST(BackendTest, TestConsistencyOfCompositeWithSetStates) {
   c._save_for_mobile(ss);
   auto mc = _load_for_mobile(ss);
   auto res_mobile = mc.forward(inputs);
+  ss.seekg(0, ss.beg);

   // check if the methods names are always the same
   // by reloading the script module and saving it back as mobile

@@ -1177,6 +1177,7 @@ Module jitModuleFromBuffer(void* data) {
       mobilem._ivalue(), files, constants, 8);
 }

+#if defined(ENABLE_FLATBUFFER)
 TEST(TestSourceFlatbuffer, UpsampleNearest2d) {
   Module m("m");
   m.define(R"(

@@ -1189,20 +1190,21 @@ TEST(TestSourceFlatbuffer, UpsampleNearest2d) {
   inputs.emplace_back(at::Scalar(2.0));
   auto ref = m.forward(inputs);

   auto data = save_jit_module_to_bytes(m);
-  Module m2 = jitModuleFromBuffer(data.data());
+  std::stringstream ss;
+  m._save_for_mobile(ss, {}, false, /*use_flatbuffer=*/true);
+  auto mm = _load_for_mobile(ss);
+  auto m2 = load(ss);

   auto res = m2.forward(inputs);
+  auto resm = mm.forward(inputs);

   auto resd = res.toTensor();
   auto refd = ref.toTensor();
+  auto resmd = resm.toTensor();
   ASSERT_TRUE(resd.equal(refd));

   mobile::Module m3 = parse_mobile_module(data.data(), data.size());
   res = m3.forward(inputs);
   resd = res.toTensor();
   refd = ref.toTensor();
   ASSERT_TRUE(resd.equal(refd));
+  ASSERT_TRUE(resmd.equal(refd));
 }
+#endif

 TEST(TestSourceFlatbuffer, CheckAttrAccess) {
   Module m("m");

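The test above feeds one flatbuffer stream to both the mobile loader and the full JIT loader and compares the results against eager execution. A rough Python analogue of that check, assuming torch.jit.mobile._load_for_lite_interpreter is available in the build (the Doubler module is illustrative):

import io

import torch
from torch.jit.mobile import _load_for_lite_interpreter


class Doubler(torch.nn.Module):
    def forward(self, x):
        return x * 2


m = torch.jit.script(Doubler())
payload = m._save_to_buffer_for_lite_interpreter(_use_flatbuffer=True)
x = torch.randn(4)
ref = m(x)

# The same bytes serve the lite interpreter...
mobile_out = _load_for_lite_interpreter(io.BytesIO(payload)).forward(x)
# ...and, after this change, the full JIT loader as well.
jit_out = torch.jit.load(io.BytesIO(payload))(x)

assert torch.equal(ref, mobile_out) and torch.equal(ref, jit_out)
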
@@ -1,20 +1,22 @@
 # Owner(s): ["oncall: jit"]

+from typing import NamedTuple, Optional
 import io
 import os
 import pathlib
 import sys
+import unittest
-from typing import NamedTuple, Optional

+import torch
 from torch import Tensor
 from torch.testing._internal.common_utils import TemporaryFileName
-import torch

 # Make the helper files in test/ importable
 pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
 sys.path.append(pytorch_test_dir)
-from torch.testing._internal.jit_utils import (JitTestCase,
-                                               clear_class_registry)
+from torch.testing._internal.jit_utils import JitTestCase, clear_class_registry

+ENABLE_FLATBUFFER = os.environ.get("ENABLE_FLATBUFFER", "0") == "1"
+
 if __name__ == "__main__":
     raise RuntimeError(

@@ -23,12 +25,14 @@ if __name__ == "__main__":
         "instead."
     )

+
 class TestSaveLoad(JitTestCase):
     def test_different_modules(self):
         """
         Exercise the situation where we have the same qualified name
         in two different CompilationUnits on save/load.
         """
+
         class Foo(torch.nn.Module):
             def __init__(self):
                 super(Foo, self).__init__()

@@ -64,7 +68,8 @@ class TestSaveLoad(JitTestCase):
         clear_class_registry()

         self.assertEqual(
-            first_script_module._c.qualified_name, second_script_module._c.qualified_name
+            first_script_module._c.qualified_name,
+            second_script_module._c.qualified_name,
         )

         class ContainsBoth(torch.nn.Module):

@@ -89,6 +94,7 @@ class TestSaveLoad(JitTestCase):
         Exercise the situation where we have the same qualified name
         in two different CompilationUnits on save/load.
         """
+
         def lol(x):
             return x

@@ -118,7 +124,8 @@ class TestSaveLoad(JitTestCase):
         clear_class_registry()

         self.assertEqual(
-            first_script_module._c.qualified_name, second_script_module._c.qualified_name
+            first_script_module._c.qualified_name,
+            second_script_module._c.qualified_name,
         )

         class ContainsBoth(torch.nn.Module):

@@ -143,6 +150,7 @@ class TestSaveLoad(JitTestCase):
         Exercise the situation where we have the same qualified name
         in two different CompilationUnits on save/load.
         """
+
         @torch.jit.interface
         class MyInterface(object):
             def bar(self, x: Tensor) -> Tensor:

@@ -204,7 +212,8 @@ class TestSaveLoad(JitTestCase):
         clear_class_registry()

         self.assertEqual(
-            first_script_module._c.qualified_name, second_script_module._c.qualified_name
+            first_script_module._c.qualified_name,
+            second_script_module._c.qualified_name,
         )

         class ContainsBoth(torch.nn.Module):

@@ -261,7 +270,6 @@ class TestSaveLoad(JitTestCase):

             return x, MyCoolNamedTuple(a=5)

-
         first_script_module = torch.jit.script(Foo())
         first_saved_module = io.BytesIO()
         torch.jit.save(first_script_module, first_saved_module)

@@ -310,7 +318,8 @@ class TestSaveLoad(JitTestCase):
         clear_class_registry()

         self.assertEqual(
-            first_script_module._c.qualified_name, second_script_module._c.qualified_name
+            first_script_module._c.qualified_name,
+            second_script_module._c.qualified_name,
         )

         class ContainsBoth(torch.nn.Module):

@@ -340,44 +349,44 @@ class TestSaveLoad(JitTestCase):
         value = b"bar\x00\xffbaz"

         expected_extra_files = {}
-        expected_extra_files['foo'] = value
+        expected_extra_files["foo"] = value
         # verify that str to bytes conversion also works
-        expected_extra_files['foo2'] = "bar"
+        expected_extra_files["foo2"] = "bar"
         m = MyMod()

         # Save to file.
         with TemporaryFileName() as fname:
             m.save(fname, _extra_files=expected_extra_files)
             # values don't matter
-            extra_files = {'foo': '', 'foo2': None}
+            extra_files = {"foo": "", "foo2": None}
             torch.jit.load(fname, _extra_files=extra_files)
-            self.assertEqual(value, extra_files['foo'])
+            self.assertEqual(value, extra_files["foo"])
             # results come back always as bytes
-            self.assertEqual(b"bar", extra_files['foo2'])
+            self.assertEqual(b"bar", extra_files["foo2"])

             # Use torch.jit API
             torch.jit.save(m, fname, _extra_files=expected_extra_files)
-            extra_files['foo'] = ''
+            extra_files["foo"] = ""
             torch.jit.load(fname, _extra_files=extra_files)
-            self.assertEqual(value, extra_files['foo'])
+            self.assertEqual(value, extra_files["foo"])

         # Save to buffer.
         buffer = io.BytesIO(m.save_to_buffer(_extra_files=expected_extra_files))
-        extra_files = {'foo': ''}
+        extra_files = {"foo": ""}
         torch.jit.load(buffer, _extra_files=extra_files)
-        self.assertEqual(value, extra_files['foo'])
+        self.assertEqual(value, extra_files["foo"])

         # Use torch.jit API
         buffer = io.BytesIO()
         torch.jit.save(m, buffer, _extra_files=expected_extra_files)
         buffer.seek(0)
-        extra_files = {'foo': ''}
+        extra_files = {"foo": ""}
         torch.jit.load(buffer, _extra_files=extra_files)
-        self.assertEqual(value, extra_files['foo'])
+        self.assertEqual(value, extra_files["foo"])

         # Non-existent file 'bar'
         with self.assertRaises(RuntimeError):
-            extra_files['bar'] = ''
+            extra_files["bar"] = ""
             torch.jit.load(buffer, _extra_files=extra_files)

     def test_save_load_using_pathlib(self):

@@ -394,7 +403,7 @@ class TestSaveLoad(JitTestCase):
             m.save(path)
             m2 = torch.jit.load(path)

-            x = torch.tensor([1., 2., 3., 4.])
+            x = torch.tensor([1.0, 2.0, 3.0, 4.0])
             self.assertTrue(torch.equal(m(x), m2(x)))

     def test_save_nonexit_file(self):

@@ -455,7 +464,9 @@ class TestSaveLoad(JitTestCase):
             def __init__(self):
                 super().__init__()
                 self.add_module("submodule_a", Submodule())
-                self.register_parameter("parameter_a", torch.nn.Parameter(torch.randn(4)))
+                self.register_parameter(
+                    "parameter_a", torch.nn.Parameter(torch.randn(4))
+                )
                 self.register_buffer("buffer", torch.randn(4))
                 self.t = torch.rand(4)  # not buffer

@@ -466,7 +477,9 @@ class TestSaveLoad(JitTestCase):
         m_loaded = self.getExportImportCopy(torch.jit.script(m))

         # Check submodules.
-        self.assertEqual(len(list(m.named_modules())), len(list(m_loaded.named_modules())))
+        self.assertEqual(
+            len(list(m.named_modules())), len(list(m_loaded.named_modules()))
+        )
         for m_s, loaded_s in zip(m.named_modules(), m_loaded.named_modules()):
             m_name, _ = m_s
             loaded_name, _ = loaded_s

@@ -478,7 +491,9 @@ class TestSaveLoad(JitTestCase):
             self.assertEqual(m_p, loaded_p)

         # Check buffers.
-        self.assertEqual(len(list(m.named_buffers())), len(list(m_loaded.named_buffers())))
+        self.assertEqual(
+            len(list(m.named_buffers())), len(list(m_loaded.named_buffers()))
+        )
         for m_b, loaded_b in zip(m.named_buffers(), m_loaded.named_buffers()):
             m_name, m_buffer = m_b
             loaded_name, loaded_buffer = loaded_b

@@ -490,6 +505,7 @@ class TestSaveLoad(JitTestCase):
        Check that parameters, buffers, and submodules are the same after loading
        for a module with parameters and buffers that are meta tensors
        """
+
        class Foo(torch.nn.Module):
            def __init__(self):
                super(Foo, self).__init__()

@@ -505,8 +521,13 @@ class TestSaveLoad(JitTestCase):
         m = Foo()
         m_loaded = self.getExportImportCopy(torch.jit.script(m))
         # Check submodules.
-        self.assertEqual(len(list(m.named_modules())), len(list(m_loaded.named_modules())))
-        self.assertEqual(set(name for name, _ in m.named_modules()), set(name for name, _ in m_loaded.named_modules()))
+        self.assertEqual(
+            len(list(m.named_modules())), len(list(m_loaded.named_modules()))
+        )
+        self.assertEqual(
+            set(name for name, _ in m.named_modules()),
+            set(name for name, _ in m_loaded.named_modules()),
+        )
         # Check parameters.
         m_params = dict(m.named_parameters())
         m_loaded_params = dict(m_loaded.named_parameters())

@@ -518,24 +539,36 @@ class TestSaveLoad(JitTestCase):
         self.assertEqual(len(m_buffers), len(m_loaded_buffers))
         self.assertEqual(m_buffers, m_loaded_buffers)
         # Check params and buffers that are/are not meta tensors
-        self.assertTrue(m_params['foo.weight'].is_meta)
-        self.assertTrue(m_loaded_params['foo.weight'].is_meta)
-        self.assertTrue(m_params['foo.bias'].is_meta)
-        self.assertTrue(m_loaded_params['foo.bias'].is_meta)
-        self.assertFalse(m_params['bar.weight'].is_meta)
-        self.assertFalse(m_loaded_params['bar.weight'].is_meta)
-        self.assertFalse(m_params['bar.bias'].is_meta)
-        self.assertFalse(m_loaded_params['bar.bias'].is_meta)
-        self.assertTrue(m_buffers['buffer'].is_meta)
-        self.assertTrue(m_loaded_buffers['buffer'].is_meta)
+        self.assertTrue(m_params["foo.weight"].is_meta)
+        self.assertTrue(m_loaded_params["foo.weight"].is_meta)
+        self.assertTrue(m_params["foo.bias"].is_meta)
+        self.assertTrue(m_loaded_params["foo.bias"].is_meta)
+        self.assertFalse(m_params["bar.weight"].is_meta)
+        self.assertFalse(m_loaded_params["bar.weight"].is_meta)
+        self.assertFalse(m_params["bar.bias"].is_meta)
+        self.assertFalse(m_loaded_params["bar.bias"].is_meta)
+        self.assertTrue(m_buffers["buffer"].is_meta)
+        self.assertTrue(m_loaded_buffers["buffer"].is_meta)


+def script_module_to_buffer(script_module):
+    module_buffer = io.BytesIO(
+        script_module._save_to_buffer_for_lite_interpreter(_use_flatbuffer=True)
+    )
+    module_buffer.seek(0)
+    return module_buffer
+
+
+@unittest.skipIf(
+    not ENABLE_FLATBUFFER, "Need to enable flatbuffer to run the below tests"
+)
 class TestSaveLoadFlatbuffer(JitTestCase):
     def test_different_modules(self):
         """
         Exercise the situation where we have the same qualified name
         in two different CompilationUnits on save/load.
         """
+
         class Foo(torch.nn.Module):
             def __init__(self):
                 super(Foo, self).__init__()

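A condensed sketch of the collision scenario these tests exercise, reusing the script_module_to_buffer helper defined above; the two Foo classes and their bodies are illustrative, and a flatbuffer-enabled build is assumed:

import torch
from torch.testing._internal.jit_utils import clear_class_registry


class Foo(torch.nn.Module):
    def forward(self, x):
        return x + 1


first = script_module_to_buffer(torch.jit.script(Foo()))
clear_class_registry()  # forget the first Foo's compilation unit


class Foo(torch.nn.Module):  # noqa: F811 -- same qualified name on purpose
    def forward(self, x):
        return x - 1


second = script_module_to_buffer(torch.jit.script(Foo()))

# Both payloads carry the same qualified name yet must load side by side.
a = torch.jit.load(first)
b = torch.jit.load(second)
x = torch.zeros(2)
assert torch.equal(a(x), x + 1) and torch.equal(b(x), x - 1)
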
@@ -548,9 +581,7 @@ class TestSaveLoadFlatbuffer(JitTestCase):
                 return x

         first_script_module = torch.jit.script(Foo())
-        first_saved_module = io.BytesIO()
-        torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module)
-        first_saved_module.seek(0)
+        first_saved_module = script_module_to_buffer(first_script_module)

         clear_class_registry()

@@ -564,21 +595,24 @@ class TestSaveLoadFlatbuffer(JitTestCase):
                 return x

         second_script_module = torch.jit.script(Foo())
-        second_saved_module = io.BytesIO()
-        torch.jit.save_jit_module_to_flatbuffer(torch.jit.script(Foo()), second_saved_module)
-        second_saved_module.seek(0)
+        second_saved_module = script_module_to_buffer(second_script_module)

         clear_class_registry()

         self.assertEqual(
-            first_script_module._c.qualified_name, second_script_module._c.qualified_name
+            first_script_module._c.qualified_name,
+            second_script_module._c.qualified_name,
         )

         class ContainsBoth(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.add_module("second", torch.jit.jit_module_from_flatbuffer(second_saved_module))
-                self.add_module("first", torch.jit.jit_module_from_flatbuffer(first_saved_module))
+                self.add_module(
+                    "second", torch.jit.load(second_saved_module)
+                )
+                self.add_module(
+                    "first", torch.jit.load(first_saved_module)
+                )

             def forward(self, x):
                 x = self.first(x)

@@ -586,16 +620,15 @@ class TestSaveLoadFlatbuffer(JitTestCase):
                 return x

         sm = torch.jit.script(ContainsBoth())
-        contains_both = io.BytesIO()
-        torch.jit.save_jit_module_to_flatbuffer(sm, contains_both)
-        contains_both.seek(0)
-        sm = torch.jit.jit_module_from_flatbuffer(contains_both)
+        contains_both = script_module_to_buffer(sm)
+        sm = torch.jit.load(contains_both)

     def test_different_functions(self):
         """
         Exercise the situation where we have the same qualified name
         in two different CompilationUnits on save/load.
         """
+
         def lol(x):
             return x

@@ -604,10 +637,7 @@ class TestSaveLoadFlatbuffer(JitTestCase):
                 return lol(x)

         first_script_module = torch.jit.script(Foo())
-        first_saved_module = io.BytesIO()
-        torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module)
-        first_saved_module.seek(0)
-
+        first_saved_module = script_module_to_buffer(first_script_module)
         clear_class_registry()

         def lol(x):  # noqa: F811

@@ -618,21 +648,24 @@ class TestSaveLoadFlatbuffer(JitTestCase):
                 return lol(x)

         second_script_module = torch.jit.script(Foo())
-        second_saved_module = io.BytesIO()
-        torch.jit.save_jit_module_to_flatbuffer(torch.jit.script(Foo()), second_saved_module)
-        second_saved_module.seek(0)
+        second_saved_module = script_module_to_buffer(second_script_module)

         clear_class_registry()

         self.assertEqual(
-            first_script_module._c.qualified_name, second_script_module._c.qualified_name
+            first_script_module._c.qualified_name,
+            second_script_module._c.qualified_name,
         )

         class ContainsBoth(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.add_module("second", torch.jit.jit_module_from_flatbuffer(second_saved_module))
-                self.add_module("first", torch.jit.jit_module_from_flatbuffer(first_saved_module))
+                self.add_module(
+                    "second", torch.jit.load(second_saved_module)
+                )
+                self.add_module(
+                    "first", torch.jit.load(first_saved_module)
+                )

             def forward(self, x):
                 x = self.first(x)

@@ -640,16 +673,15 @@ class TestSaveLoadFlatbuffer(JitTestCase):
                 return x

         sm = torch.jit.script(ContainsBoth())
-        contains_both = io.BytesIO()
-        torch.jit.save_jit_module_to_flatbuffer(sm, contains_both)
-        contains_both.seek(0)
-        sm = torch.jit.jit_module_from_flatbuffer(contains_both)
+        contains_both = script_module_to_buffer(sm)
+        sm = torch.jit.load(contains_both)

     def test_different_interfaces(self):
         """
         Exercise the situation where we have the same qualified name
         in two different CompilationUnits on save/load.
         """
+
         @torch.jit.interface
         class MyInterface(object):
             def bar(self, x: Tensor) -> Tensor:

@@ -674,10 +706,7 @@ class TestSaveLoadFlatbuffer(JitTestCase):
                 return self.interface.bar(x)

         first_script_module = torch.jit.script(Foo())
-        first_saved_module = io.BytesIO()
-        torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module)
-        first_saved_module.seek(0)
-
+        first_saved_module = script_module_to_buffer(first_script_module)
         clear_class_registry()

         @torch.jit.interface

@@ -704,21 +733,24 @@ class TestSaveLoadFlatbuffer(JitTestCase):
                 return self.interface.not_bar(x)

         second_script_module = torch.jit.script(Foo())
-        second_saved_module = io.BytesIO()
-        torch.jit.save_jit_module_to_flatbuffer(torch.jit.script(Foo()), second_saved_module)
-        second_saved_module.seek(0)
+        second_saved_module = script_module_to_buffer(second_script_module)

         clear_class_registry()

         self.assertEqual(
-            first_script_module._c.qualified_name, second_script_module._c.qualified_name
+            first_script_module._c.qualified_name,
+            second_script_module._c.qualified_name,
         )

         class ContainsBoth(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.add_module("second", torch.jit.jit_module_from_flatbuffer(second_saved_module))
-                self.add_module("first", torch.jit.jit_module_from_flatbuffer(first_saved_module))
+                self.add_module(
+                    "second", torch.jit.load(second_saved_module)
+                )
+                self.add_module(
+                    "first", torch.jit.load(first_saved_module)
+                )

             def forward(self, x):
                 x = self.first(x)

@@ -726,10 +758,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
                 return x

         sm = torch.jit.script(ContainsBoth())
-        contains_both = io.BytesIO()
-        torch.jit.save_jit_module_to_flatbuffer(sm, contains_both)
-        contains_both.seek(0)
-        sm = torch.jit.jit_module_from_flatbuffer(contains_both)
+        contains_both = script_module_to_buffer(sm)
+        sm = torch.jit.load(contains_both)

     def test_many_collisions(self):
         class MyCoolNamedTuple(NamedTuple):

@@ -768,11 +798,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):

             return x, MyCoolNamedTuple(a=5)

-
         first_script_module = torch.jit.script(Foo())
-        first_saved_module = io.BytesIO()
-        torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module)
-        first_saved_module.seek(0)
+        first_saved_module = script_module_to_buffer(first_script_module)

         clear_class_registry()

@@ -810,21 +837,24 @@ class TestSaveLoadFlatbuffer(JitTestCase):
                 return x, MyCoolNamedTuple(a="hello")

         second_script_module = torch.jit.script(Foo())
-        second_saved_module = io.BytesIO()
-        torch.jit.save_jit_module_to_flatbuffer(second_script_module, second_saved_module)
-        second_saved_module.seek(0)
+        second_saved_module = script_module_to_buffer(second_script_module)

         clear_class_registry()

         self.assertEqual(
-            first_script_module._c.qualified_name, second_script_module._c.qualified_name
+            first_script_module._c.qualified_name,
+            second_script_module._c.qualified_name,
         )

         class ContainsBoth(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.add_module("second", torch.jit.jit_module_from_flatbuffer(second_saved_module))
-                self.add_module("first", torch.jit.jit_module_from_flatbuffer(first_saved_module))
+                self.add_module(
+                    "second", torch.jit.load(second_saved_module)
+                )
+                self.add_module(
+                    "first", torch.jit.load(first_saved_module)
+                )

             def forward(self, x):
                 x, named_tuple_1 = self.first(x)

@@ -832,10 +862,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
                 return len(x + named_tuple_2.a) + named_tuple_1.a

         sm = torch.jit.script(ContainsBoth())
-        contains_both = io.BytesIO()
-        torch.jit.save_jit_module_to_flatbuffer(sm, contains_both)
-        contains_both.seek(0)
-        sm = torch.jit.jit_module_from_flatbuffer(contains_both)
+        contains_both = script_module_to_buffer(sm)
+        sm = torch.jit.load(contains_both)

     def test_save_load_using_pathlib(self):
         class MyMod(torch.jit.ScriptModule):

@@ -849,9 +877,9 @@ class TestSaveLoadFlatbuffer(JitTestCase):
         with TemporaryFileName() as fname:
             path = pathlib.Path(fname)
             torch.jit.save_jit_module_to_flatbuffer(m, path)
-            m2 = torch.jit.jit_module_from_flatbuffer(path)
+            m2 = torch.jit.load(path)

-            x = torch.tensor([1., 2., 3., 4.])
+            x = torch.tensor([1.0, 2.0, 3.0, 4.0])
             self.assertTrue(torch.equal(m(x), m2(x)))

     def test_save_namedtuple_input_only(self):

@@ -903,7 +931,9 @@ class TestSaveLoadFlatbuffer(JitTestCase):
             def __init__(self):
                 super().__init__()
                 self.add_module("submodule_a", Submodule())
-                self.register_parameter("parameter_a", torch.nn.Parameter(torch.randn(4)))
+                self.register_parameter(
+                    "parameter_a", torch.nn.Parameter(torch.randn(4))
+                )
                 self.register_buffer("buffer", torch.randn(4))
                 self.t = torch.rand(4)  # not buffer

@@ -914,7 +944,9 @@ class TestSaveLoadFlatbuffer(JitTestCase):
         m_loaded = self.getExportImportCopy(torch.jit.script(m))

         # Check submodules.
-        self.assertEqual(len(list(m.named_modules())), len(list(m_loaded.named_modules())))
+        self.assertEqual(
+            len(list(m.named_modules())), len(list(m_loaded.named_modules()))
+        )
         for m_s, loaded_s in zip(m.named_modules(), m_loaded.named_modules()):
             m_name, _ = m_s
             loaded_name, _ = loaded_s

@@ -926,7 +958,9 @@ class TestSaveLoadFlatbuffer(JitTestCase):
             self.assertEqual(m_p, loaded_p)

         # Check buffers.
-        self.assertEqual(len(list(m.named_buffers())), len(list(m_loaded.named_buffers())))
+        self.assertEqual(
+            len(list(m.named_buffers())), len(list(m_loaded.named_buffers()))
+        )
         for m_b, loaded_b in zip(m.named_buffers(), m_loaded.named_buffers()):
             m_name, m_buffer = m_b
             loaded_name, loaded_buffer = loaded_b

@@ -540,6 +540,7 @@ mobile::Module _load_for_mobile(
     std::istream& in,
     c10::optional<at::Device> device,
     ExtraFilesMap& extra_files) {
+  in.seekg(0, in.beg);
   auto format = getFileFormat(in);
   switch (format) {
     case FileFormat::ZipFileFormat: {

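The seekg added above exists because the loader now sniffs the leading bytes of the stream to pick a format, so a stream positioned past the start would be misidentified. The same care applies to Python file-like objects, as in this sketch (the Identity module is illustrative):

import io

import torch


class Identity(torch.nn.Module):
    def forward(self, x):
        return x


buf = io.BytesIO()
torch.jit.save(torch.jit.script(Identity()), buf)

# Format sniffing reads from the current position, so rewind any stream
# that has just been written to (or already read) before loading from it.
buf.seek(0)
m = torch.jit.load(buf)
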
@@ -797,9 +797,13 @@ void save_mobile_module_to(
     const ExtraFilesMap& extra_files,
     bool save_mobile_debug_info,
     const std::function<size_t(const void*, size_t)>& writer_func) {
+  ExtraFilesMap jitFiles;
   CompilationOptions options = getOptionsFromGlobal();
+  std::vector<IValue> constants;
+  jitModuleToPythonCodeAndConstants(module, &jitFiles, &constants);
   mobile::Module mod = jitModuleToMobile(module, options);
-  auto buffer = save_mobile_module_to_bytes(mod, extra_files);
+  auto buffer =
+      save_mobile_module_to_bytes(mod, extra_files, jitFiles, constants);
   writer_func(reinterpret_cast<void*>(buffer.data()), buffer.size());
 }
 #endif

@@ -789,6 +789,14 @@ Module load_jit_module_from_file(
       std::move(std::get<0>(data)), std::get<1>(data), device);
 }

+Module load_jit_module_from_stream(
+    std::istream& in,
+    c10::optional<at::Device> device) {
+  auto data = get_stream_content(in);
+  return parse_and_initialize_jit_module(
+      std::move(std::get<0>(data)), std::get<1>(data), device);
+}
+
 void save_jit_module(
     const Module& module,
     const std::string& filename,

@@ -47,5 +47,9 @@ TORCH_API Module load_jit_module_from_file(
     const std::string& filename,
     c10::optional<at::Device> device = c10::nullopt);

+TORCH_API Module load_jit_module_from_stream(
+    std::istream& in,
+    c10::optional<at::Device> device = c10::nullopt);
+
 } // namespace jit
 } // namespace torch

@@ -10,6 +10,7 @@
 #endif
 #include <torch/csrc/jit/frontend/script_type_parser.h>
 #include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/mobile/file_format.h>
 #include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
 #include <torch/csrc/jit/serialization/import_read.h>

@@ -18,6 +19,10 @@
 #include <torch/csrc/jit/serialization/source_range_serialization.h>
 #include <torch/csrc/jit/serialization/unpickler.h>

+#if defined(ENABLE_FLATBUFFER)
+#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
+#endif
+
 #include <caffe2/serialize/file_adapter.h>
 #include <caffe2/serialize/inline_container.h>
 #include <caffe2/serialize/istream_adapter.h>

@@ -290,9 +295,26 @@ Module import_ir_module(
     std::istream& in,
     c10::optional<at::Device> device,
     ExtraFilesMap& extra_files) {
-  auto reader = torch::make_unique<PyTorchStreamReader>(&in);
-  ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
-  return deserializer.deserialize(device, extra_files);
+  in.seekg(0, in.beg);
+  auto format = getFileFormat(in);
+  switch (format) {
+    case FileFormat::FlatbufferFileFormat: {
+#if defined(ENABLE_FLATBUFFER)
+      return load_jit_module_from_stream(in, device);
+#else
+      TORCH_CHECK(
+          false, "Flatbuffer input file but the build hasn't enable flatbuffer")
+#endif
+    }
+    case FileFormat::ZipFileFormat: {
+      auto reader = torch::make_unique<PyTorchStreamReader>(&in);
+      ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
+      return deserializer.deserialize(device, extra_files);
+    }
+
+    default:
+      TORCH_CHECK(false, "Unrecognized data format");
+  }
 }

 // For reading unified serialization format from torch.Package.

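getFileFormat dispatches on the first bytes of the input. As an illustration only (the magic values are assumptions about the two container formats, not taken from this diff: ZIP archives start with b"PK\x03\x04", and a FlatBuffers file identifier such as "PTMF" sits at bytes 4..8), the check can be approximated in Python:

def guess_module_format(path):
    # Rough Python re-creation of the C++ format sniffing; illustrative only.
    with open(path, "rb") as f:
        header = f.read(8)
    if header[:4] == b"PK\x03\x04":
        return "zip"  # regular torch.jit.save archive
    if header[4:8] == b"PTMF":
        return "flatbuffer"  # mobile flatbuffer payload (assumed identifier)
    return "unknown"
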
@@ -325,9 +347,25 @@ Module import_ir_module(
     const std::string& filename,
     c10::optional<at::Device> device,
     ExtraFilesMap& extra_files) {
-  auto reader = torch::make_unique<PyTorchStreamReader>(filename);
-  ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
-  return deserializer.deserialize(device, extra_files);
+  auto format = getFileFormat(filename);
+  switch (format) {
+    case FileFormat::FlatbufferFileFormat: {
+#if defined(ENABLE_FLATBUFFER)
+      return load_jit_module_from_file(filename, device);
+#else
+      TORCH_CHECK(
+          false, "Flatbuffer input file but the build hasn't enable flatbuffer")
+#endif
+    }
+    case FileFormat::ZipFileFormat: {
+      auto reader = torch::make_unique<PyTorchStreamReader>(filename);
+      ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
+      return deserializer.deserialize(device, extra_files);
+    }
+
+    default:
+      TORCH_CHECK(false, "Unrecognized data format");
+  }
 }

 Module import_ir_module(

@@ -357,9 +395,27 @@ Module load(
     std::istream& in,
     c10::optional<at::Device> device,
     ExtraFilesMap& extra_files) {
-  std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
-  auto module = load(std::move(rai), device, extra_files);
-  return module;
+  in.seekg(0, in.beg);
+  auto format = getFileFormat(in);
+  switch (format) {
+    case FileFormat::FlatbufferFileFormat: {
+#if defined(ENABLE_FLATBUFFER)
+      return load_jit_module_from_stream(in, device);
+#else
+      TORCH_CHECK(
+          false, "Flatbuffer input file but the build hasn't enable flatbuffer")
+#endif
+    }
+    case FileFormat::ZipFileFormat: {
+      std::unique_ptr<IStreamAdapter> rai =
+          std::make_unique<IStreamAdapter>(&in);
+      auto module = load(std::move(rai), device, extra_files);
+      return module;
+    }
+
+    default:
+      TORCH_CHECK(false, "Unrecognized data format");
+  }
 }

 Module load(const std::string& filename, c10::optional<at::Device> device) {

@@ -371,9 +427,27 @@ Module load(
     const std::string& filename,
     c10::optional<at::Device> device,
     ExtraFilesMap& extra_files) {
-  std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
-  auto module = load(std::move(rai), device, extra_files);
-  return module;
+  auto format = getFileFormat(filename);
+  switch (format) {
+    case FileFormat::FlatbufferFileFormat: {
+#if defined(ENABLE_FLATBUFFER)
+      return load_jit_module_from_file(filename, device);
+#else
+      TORCH_CHECK(
+          false, "Flatbuffer input file but the build hasn't enable flatbuffer")
+#endif
+    }
+    case FileFormat::ZipFileFormat: {
+      std::unique_ptr<FileAdapter> rai =
+          std::make_unique<FileAdapter>(filename);
+      auto module = load(std::move(rai), device, extra_files);
+      return module;
+    }
+
+    default:
+      TORCH_CHECK(false, "Unrecognized data format");
+  }
 }

 Module load(

@@ -387,8 +461,8 @@ Module load(
     std::shared_ptr<ReadAdapterInterface> rai,
     c10::optional<c10::Device> device,
     ExtraFilesMap& extra_files) {
-  // Verify that we're loading a zip archive and not a torch.save pickle archive
-  // (marked by the 0x80 0x02 bytes at the start)
+  // Verify that we're loading a zip archive and not a torch.save pickle
+  // archive (marked by the 0x80 0x02 bytes at the start)
   // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
   TORCH_CHECK(
       check_zip_file(rai),