mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62030 Remove dtype tracking from the Python Storage interface, remove all the different `<type>Storage` classes except for `ByteStorage`, and update serialization accordingly, while maintaining as much FC/BC as possible. Fixes https://github.com/pytorch/pytorch/issues/47442 * **THE SERIALIZATION FORMAT IS FULLY FC/BC.** We worked very hard to make sure this is the case. We will probably want to break FC at some point to make the serialization structure of tensors make more sense, but not today. * There is now only a single torch.ByteStorage class. Methods like `Tensor.set_` no longer check that the dtype of the storage is appropriate. * As we no longer know what the dtype of a storage is, we've **removed** the size method from Storage, replacing it with nbytes. This is to help catch otherwise silent errors where you confuse the number of elements with the number of bytes. * `Storage._new_shared` takes an `nbytes` kwarg and will reject previous positional-only calls. `Storage._new_with_file` and `_set_from_file` require explicit element size arguments. * It's no longer possible to convert storages to different types using the float/double/etc methods. Instead, do the conversion using a tensor. * It's no longer possible to allocate a typed storage directly using FloatStorage/DoubleStorage/etc constructors. Instead, construct a tensor and extract its storage. The classes still exist but they are used purely for unpickling. * The preexisting serialization format stores dtype with storage, and in fact this dtype is used to determine the dtype of the tensor overall. To accommodate this case, we introduce a new TypedStorage concept that exists only at unpickling time and is used to temporarily store the dtype so we can construct a tensor. **If you overrode the handling of pickling/unpickling, you MUST add handling for TypedStorage** or your serialization code will degrade to standard file-based serialization.
Original pull request: https://github.com/pytorch/pytorch/pull/59671 Reviewed By: soulitzer, ngimel Differential Revision: D29466819 Pulled By: ezyang fbshipit-source-id: 4a14e5d3c2b08e06e558683d97f7378a3180b00e
769 lines
27 KiB
Python
769 lines
27 KiB
Python
from io import BytesIO
|
|
from textwrap import dedent
|
|
from unittest import skipIf
|
|
|
|
import torch
|
|
from torch.package import PackageExporter, PackageImporter
|
|
from torch.testing._internal.common_utils import (
|
|
IS_FBCODE,
|
|
IS_SANDCASTLE,
|
|
run_tests,
|
|
)
|
|
|
|
# PackageTestCase lives in the sibling ``common`` module. The package-relative
# import is tried first; when this file is executed directly as a script there
# is no parent package, so fall back to a plain absolute import.
try:
    from .common import PackageTestCase
except ImportError:
    from common import PackageTestCase

# torchvision is optional. Record whether ``resnet18`` could be imported so
# that tests which need it can be skipped instead of erroring.
try:
    from torchvision.models import resnet18
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True

# Decorator for tests that construct torchvision models.
skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")
|
|
|
|
|
|
class TestPackageScript(PackageTestCase):
    """Tests for compatibility with TorchScript.

    Covers packaging TorchScript interfaces and script classes, saving and
    loading ScriptModules through ``torch.package`` (alone, nested, repeated,
    inside containers, shared between eager modules), de-duplication of the
    serialized TorchScript code and storages, and storage identity of tensors
    after loading.
    """

    def test_package_interface(self):
        """Packaging an interface class should work correctly."""
        import package_a.fake_interface as fake

        uses_interface = fake.UsesInterface()
        scripted = torch.jit.script(uses_interface)
        scripted.proxy_mod = torch.jit.script(fake.NewModule())

        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", uses_interface)
        # The package is only fully written once the exporter is closed.
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.load_pickle("model", "model.pkl")

        scripted_loaded = torch.jit.script(loaded)
        scripted_loaded.proxy_mod = torch.jit.script(fake.NewModule())

        input = torch.tensor(1)

        self.assertEqual(scripted(input), scripted_loaded(input))

    def test_different_package_interface(self):
        """Test a case where the interface defined in the package is
        different from the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        # Import one version of the interface
        import package_a.fake_interface as fake

        # Simulate a package that contains a different version of the
        # interface, with the exact same name.
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch
                    from torch import Tensor

                    @torch.jit.interface
                    class ModuleInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            pass

                    class ImplementsInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            return inp1 + 1

                    class UsesInterface(torch.nn.Module):
                        proxy_mod: ModuleInterface

                        def __init__(self):
                            super().__init__()
                            self.proxy_mod = ImplementsInterface()

                        def forward(self, input: Tensor) -> Tensor:
                            return self.proxy_mod.one(input)
                    """
                ),
            )
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)
        # We should be able to script successfully.
        torch.jit.script(diff_fake.UsesInterface())

    def test_package_script_class(self):
        """Packaging a module that defines a TorchScript class should give a
        loaded module that behaves like the original one.
        """
        import package_a.fake_script_class as fake

        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_module(fake.__name__)
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.import_module(fake.__name__)

        input = torch.tensor(1)
        self.assertTrue(
            torch.allclose(
                fake.uses_script_class(input), loaded.uses_script_class(input)
            )
        )

    def test_different_package_script_class(self):
        """Test a case where the script class defined in the package is
        different from the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        import package_a.fake_script_class as fake

        # Simulate a package that contains a different version of the
        # script class, with the attribute `bar` instead of `foo`
        buffer = BytesIO()
        with PackageExporter(buffer) as pe2:
            pe2.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch

                    @torch.jit.script
                    class MyScriptClass:
                        def __init__(self, x):
                            self.bar = x
                    """
                ),
            )
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)
        input = torch.rand(2, 3)
        loaded_script_class = diff_fake.MyScriptClass(input)
        orig_script_class = fake.MyScriptClass(input)
        self.assertEqual(loaded_script_class.bar, orig_script_class.foo)

    def test_save_scriptmodule(self):
        """
        Test basic saving of ScriptModule.
        """
        from package_a.test_module import ModWithTensor

        scripted_mod = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod = importer.load_pickle("res", "mod.pkl", map_location="cpu")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod(input), scripted_mod(input))

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_save_scriptmodule_file(self):
        """
        Test basic saving of ScriptModule in file.
        """
        from package_a.test_module import ModWithTensor

        scripted_mod = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        importer = PackageImporter(filename)
        loaded_mod = importer.load_pickle("res", "mod.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod(input), scripted_mod(input))

    def test_save_scriptmodule_with_submods(self):
        """
        Test basic saving of ScriptModule with submodule.
        """
        from package_a.test_module import ModWithSubmod, ModWithTensor

        scripted_mod = torch.jit.script(
            ModWithSubmod(ModWithTensor(torch.rand(1, 2, 3)))
        )

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod = importer.load_pickle("res", "mod.pkl", map_location="cpu")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod(input), scripted_mod(input))

    def test_save_scriptmodules_submod_redefinition(self):
        """
        Test to verify saving multiple ScriptModules with the same top module
        but different submodules works. The submodule class is redefined
        between the two instantiations of the top module to check that the
        different concrete types of the modules are correctly distinguished
        by the serialization code.
        """

        class Submod(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input: str):
                input = input + "_submod"
                return input

        class TopMod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.modB = Submod()

            def forward(self, input: str):
                return self.modB(input)

        scripted_mod_0 = torch.jit.script(TopMod())

        # Redefinition is intentional: changing a single string literal in
        # forward() should trigger a new module type.
        class Submod(torch.nn.Module):  # noqa: F811
            def __init__(self):
                super().__init__()

            def forward(self, input: str):
                input = input + "_submod(changed)"
                return input

        scripted_mod_1 = torch.jit.script(TopMod())

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)
            e.save_pickle("res", "mod2.pkl", scripted_mod_1)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_0 = importer.load_pickle("res", "mod1.pkl")
        loaded_mod_1 = importer.load_pickle("res", "mod2.pkl")
        self.assertEqual(loaded_mod_0("input"), scripted_mod_0("input"))
        self.assertEqual(loaded_mod_1("input"), scripted_mod_1("input"))
        self.assertNotEqual(loaded_mod_0("input"), loaded_mod_1("input"))

    def test_save_independent_scriptmodules(self):
        """
        Test to verify saving multiple ScriptModules with completely
        separate code works.
        """
        from package_a.test_module import ModWithTensor, SimpleTest

        scripted_mod_0 = torch.jit.script(SimpleTest())
        scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)
            e.save_pickle("res", "mod2.pkl", scripted_mod_1)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_0 = importer.load_pickle("res", "mod1.pkl")
        loaded_mod_1 = importer.load_pickle("res", "mod2.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod_0(input), scripted_mod_0(input))
        self.assertEqual(loaded_mod_1(input), scripted_mod_1(input))

    def test_save_repeat_scriptmodules(self):
        """
        Test to verify saving multiple different modules and
        repeats of the same ScriptModule in a package works. Also tests that
        PyTorchStreamReader isn't having code hidden from
        PyTorchStreamWriter writing ScriptModule code files multiple times.
        """
        from package_a.test_module import (
            ModWithSubmodAndTensor,
            ModWithTensor,
            SimpleTest,
        )

        scripted_mod_0 = torch.jit.script(SimpleTest())
        scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
        scripted_mod_2 = torch.jit.script(
            ModWithSubmodAndTensor(
                torch.rand(1, 2, 3), ModWithTensor(torch.rand(1, 2, 3))
            )
        )

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod0.pkl", scripted_mod_0)
            e.save_pickle("res", "mod1.pkl", scripted_mod_1)
            # Intentional repeats of modules that were already saved.
            e.save_pickle("res", "mod2.pkl", scripted_mod_0)
            e.save_pickle("res", "mod3.pkl", scripted_mod_1)
            e.save_pickle("res", "mod4.pkl", scripted_mod_2)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_0 = importer.load_pickle("res", "mod0.pkl")
        loaded_mod_1 = importer.load_pickle("res", "mod3.pkl")
        loaded_mod_2 = importer.load_pickle("res", "mod4.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod_0(input), scripted_mod_0(input))
        self.assertEqual(loaded_mod_1(input), scripted_mod_1(input))
        self.assertEqual(loaded_mod_2(input), scripted_mod_2(input))

    def test_scriptmodules_repeat_save(self):
        """
        Test to verify saving and loading the same ScriptModule object works
        across multiple packages.
        """
        from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

        scripted_mod_0 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
        scripted_mod_1 = torch.jit.script(
            ModWithSubmodAndTensor(
                torch.rand(1, 2, 3), ModWithTensor(torch.rand(1, 2, 3))
            )
        )

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)
        loaded_module_0 = importer_0.load_pickle("res", "mod1.pkl")

        # Re-save a module that was itself loaded from the first package.
        buffer_1 = BytesIO()
        with PackageExporter(buffer_1) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_1)
            e.save_pickle("res", "mod2.pkl", loaded_module_0)

        buffer_1.seek(0)
        importer_1 = PackageImporter(buffer_1)
        loaded_module_1 = importer_1.load_pickle("res", "mod1.pkl")
        reloaded_module_0 = importer_1.load_pickle("res", "mod2.pkl")

        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_module_0(input), scripted_mod_0(input))
        self.assertEqual(loaded_module_0(input), reloaded_module_0(input))
        self.assertEqual(loaded_module_1(input), scripted_mod_1(input))

    @skipIfNoTorchVision
    def test_save_scriptmodule_only_necessary_code(self):
        """
        Test to verify that, when saving multiple packages with the same CU,
        the packages don't include unnecessary TorchScript code files.
        The TorchVision code should only be saved in the package that
        relies on it.
        """
        from package_a.test_module import ModWithTensor

        class ModWithTorchVision(torch.nn.Module):
            def __init__(self, name: str):
                super().__init__()
                self.tvmod = resnet18()

            def forward(self, input):
                return input * 4

        scripted_mod_0 = torch.jit.script(ModWithTorchVision("foo"))
        scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)

        buffer_1 = BytesIO()
        with PackageExporter(buffer_1) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_1)

        buffer_1.seek(0)
        importer_1 = PackageImporter(buffer_1)

        # Only the package holding the torchvision-backed module should
        # contain torchvision code.
        self.assertIn("torchvision", str(importer_0.file_structure()))
        self.assertNotIn("torchvision", str(importer_1.file_structure()))

    def test_save_scriptmodules_in_container(self):
        """
        Test saving of ScriptModules inside of a container. Checks that
        relations between shared modules are upheld.
        """
        from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

        scripted_mod_a = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
        scripted_mod_b = torch.jit.script(
            ModWithSubmodAndTensor(torch.rand(1, 2, 3), scripted_mod_a)
        )
        script_mods_list = [scripted_mod_a, scripted_mod_b]

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "list.pkl", script_mods_list)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_list = importer.load_pickle("res", "list.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod_list[0](input), scripted_mod_a(input))
        self.assertEqual(loaded_mod_list[1](input), scripted_mod_b(input))

    def test_save_eager_mods_sharing_scriptmodule(self):
        """
        Test saving of a single ScriptModule shared by multiple
        eager modules (the ScriptModule should be saved just once
        even though it is contained in multiple pickles).
        """
        from package_a.test_module import ModWithSubmod, SimpleTest

        scripted_mod = torch.jit.script(SimpleTest())

        mod1 = ModWithSubmod(scripted_mod)
        mod2 = ModWithSubmod(scripted_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", mod1)
            e.save_pickle("res", "mod2.pkl", mod2)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        file_structure = importer.file_structure()
        # Exactly one TorchScript code entry should have been written.
        self.assertTrue(file_structure.has_file(".data/ts_code/0"))
        self.assertFalse(file_structure.has_file(".data/ts_code/1"))

    def test_load_shared_scriptmodules(self):
        """
        Test loading of a single ScriptModule shared by multiple eager
        modules in a single pickle (ScriptModule objects should be the same).
        """
        from package_a.test_module import (
            ModWithMultipleSubmods,
            ModWithSubmod,
            SimpleTest,
        )

        scripted_mod = torch.jit.script(SimpleTest())

        mod1 = ModWithSubmod(scripted_mod)
        mod2 = ModWithSubmod(scripted_mod)

        mod_parent = ModWithMultipleSubmods(mod1, mod2)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "mod.pkl", mod_parent)

        buffer.seek(0)
        importer = PackageImporter(buffer)

        loaded_mod = importer.load_pickle("res", "mod.pkl")
        # assertIs reports both objects on failure, unlike assertTrue(id == id).
        self.assertIs(loaded_mod.mod1.script_mod, loaded_mod.mod2.script_mod)

    def test_save_shared_tensors(self):
        """
        Test tensors shared across eager and ScriptModules are serialized once.
        """
        from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

        shared_tensor = torch.rand(2, 3, 4)
        scripted_mod = torch.jit.script(ModWithTensor(shared_tensor))

        mod1 = ModWithSubmodAndTensor(shared_tensor, scripted_mod)
        mod2 = ModWithSubmodAndTensor(shared_tensor, scripted_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "tensor", shared_tensor)
            e.save_pickle("res", "mod1.pkl", mod1)
            e.save_pickle("res", "mod2.pkl", mod2)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")

        # assert that there is only one storage stored in package
        file_structure = importer.file_structure(include=".data/*.storage")
        self.assertEqual(len(file_structure.children[".data"].children), 1)

        input = torch.rand(2, 3, 4)
        self.assertEqual(loaded_mod_1(input), mod1(input))

    def test_load_shared_tensors(self):
        """
        Test tensors shared across eager and ScriptModules on load
        are the same.
        """
        from package_a.test_module import (
            ModWithTensor,
            ModWithTwoSubmodsAndTensor,
        )

        shared_tensor = torch.ones(3, 3)

        scripted_mod_0 = torch.jit.script(ModWithTensor(shared_tensor))
        scripted_mod_1 = torch.jit.script(ModWithTensor(shared_tensor))

        mod1 = ModWithTwoSubmodsAndTensor(shared_tensor, scripted_mod_0, scripted_mod_1)

        # Sanity check: before saving, all three tensors share one storage.
        self.assertEqual(
            shared_tensor.storage()._cdata,
            scripted_mod_0.tensor.storage()._cdata,
        )
        self.assertEqual(
            shared_tensor.storage()._cdata,
            scripted_mod_1.tensor.storage()._cdata,
        )

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", mod1)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")

        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_0.tensor.storage()._cdata,
        )
        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_1.tensor.storage()._cdata,
        )

        # An in-place update through one reference must be visible through
        # the others if the storage is truly shared.
        loaded_mod_1.tensor.add_(torch.ones(3, 3))

        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_0.tensor)
        )
        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_1.tensor)
        )

    def test_load_shared_tensors_repackaged(self):
        """
        Test tensors shared across eager and ScriptModules on load
        are the same across multiple package saves and loads. This is
        an important test because not all of the tensor information is restored
        in python between packages. The python identity is not maintained, but
        the backing cpp TensorImpl is. We load/save storages based off of this
        cpp TensorImpl and not the python identity.
        """
        from package_a.test_module import (
            ModWithTensor,
            ModWithTwoSubmodsAndTensor,
        )

        shared_tensor = torch.ones(3, 3)

        scripted_mod_0 = torch.jit.script(ModWithTensor(shared_tensor))
        scripted_mod_1 = torch.jit.script(ModWithTensor(shared_tensor))

        mod1 = ModWithTwoSubmodsAndTensor(shared_tensor, scripted_mod_0, scripted_mod_1)

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", mod1)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)
        loaded_mod_0 = importer_0.load_pickle("res", "mod1.pkl")

        # Repackage the loaded module, resolving its code via importer_0.
        buffer_1 = BytesIO()
        with PackageExporter(buffer_1, importer=importer_0) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", loaded_mod_0)

        buffer_1.seek(0)
        importer = PackageImporter(buffer_1)
        loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")

        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_0.tensor.storage()._cdata,
        )
        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_1.tensor.storage()._cdata,
        )

        loaded_mod_1.tensor.add_(
            torch.ones(3, 3)
        )  # all tensors should reflect this change

        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_0.tensor)
        )
        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_1.tensor)
        )

    def test_saving_and_scripting_packaged_mod(self):
        """
        Test scripting a module loaded from a package
        and saving it in a new package as a script object.
        """
        from package_a.test_module import SimpleTest

        orig_mod = SimpleTest()

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.intern("**")
            e.save_pickle("model", "model.pkl", orig_mod)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)
        loaded_mod = importer_0.load_pickle("model", "model.pkl")

        input = torch.rand(2, 3)
        self.assertEqual(loaded_mod(input), orig_mod(input))

        scripted_mod = torch.jit.script(loaded_mod)

        buffer_1 = BytesIO()
        with PackageExporter(buffer_1, importer=importer_0) as e:
            e.intern("**")
            e.save_pickle("res", "scripted_mod.pkl", scripted_mod)

        buffer_1.seek(0)
        importer_1 = PackageImporter(buffer_1)
        loaded_mod_scripted = importer_1.load_pickle("res", "scripted_mod.pkl")

        self.assertEqual(loaded_mod_scripted(input), orig_mod(input))

    def test_mixing_packaged_and_inline_modules(self):
        """
        Test saving inline and imported modules in the same package with
        independent code.
        """

        class InlineMod(torch.nn.Module):
            def __init__(self, name: str):
                super().__init__()
                self.name = name
                self.tensor = torch.rand(1, 2, 3)

            def forward(self, input: str):
                input = input + "_modInline:" + self.name
                return input, (self.tensor * 4)

        inline_mod = InlineMod("inline")
        scripted_inline = torch.jit.script(inline_mod)

        from package_a.test_module import SimpleTest

        imported_mod = SimpleTest()
        scripted_imported = torch.jit.script(imported_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("model", "inline.pkl", scripted_inline)
            e.save_pickle("model", "imported.pkl", scripted_imported)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_inline = importer.load_pickle("model", "inline.pkl")
        loaded_imported = importer.load_pickle("model", "imported.pkl")

        input = torch.rand(2, 3)
        self.assertEqual(loaded_imported(input), imported_mod(input))
        self.assertEqual(loaded_inline("input"), inline_mod("input"))

    @skipIfNoTorchVision
    def test_mixing_packaged_and_inline_modules_shared_code(self):
        """
        Test saving inline and imported modules in the same package that
        share code.
        """

        class TorchVisionTestInline(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tvmod = resnet18()

            def forward(self, x):
                x = a_non_torch_leaf(x, x)
                return torch.relu(x + 3.0)

        # Defined after the class but before scripting, which is when the
        # name referenced in forward() is resolved.
        def a_non_torch_leaf(a, b):
            return a + b

        inline_mod = TorchVisionTestInline()
        scripted_inline = torch.jit.script(inline_mod)

        from package_c.test_module import TorchVisionTest

        imported_mod = TorchVisionTest()
        scripted_imported = torch.jit.script(imported_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("model", "inline.pkl", scripted_inline)
            e.save_pickle("model", "imported.pkl", scripted_imported)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_inline = importer.load_pickle("model", "inline.pkl")
        loaded_imported = importer.load_pickle("model", "imported.pkl")

        input = torch.rand(2, 3)
        self.assertEqual(loaded_imported(input), imported_mod(input))
        self.assertEqual(loaded_inline(input), inline_mod(input))

    def test_tensor_sharing_pickle(self):
        """Test that saving a ScriptModule and separately saving a tensor
        object causes no issues.
        """

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.foo = torch.ones(2, 3)

            def forward(self):
                return self.foo

        scripted_m = torch.jit.script(M())
        original_tensor = torch.ones(0)

        f = BytesIO()
        with torch.package.PackageExporter(f) as exporter:
            exporter.save_pickle("model", "model.pkl", scripted_m)
            exporter.save_pickle("model", "input.pkl", original_tensor)

        f.seek(0)
        # Should be able to load correctly
        importer = PackageImporter(f)
        loaded_m = importer.load_pickle("model", "model.pkl")
        loaded_tensor = importer.load_pickle("model", "input.pkl")

        self.assertEqual(scripted_m.foo, loaded_m.foo)
        self.assertEqual(original_tensor, loaded_tensor)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test file directly; run_tests is imported from
    # torch.testing._internal.common_utils at the top of the file.
    run_tests()
|