mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: **Summary** This commit adds support for serialization and deserialization of `ScriptModules` that have been lowered to a specific backend. Nothing special was required to accomplish this, other than removing some code in `unpickler.cpp` that guarded against the deserialization of `Any` type objects. Now that lists and dicts are tagged with their types during serialization, this check is no longer necessary. **Test Plan** This commit adds a unit test for testing that a lowered module still produces the same results as Python and regular JIT after saving and loading. **Fixes** This pull request fixes part of https://github.com/pytorch/pytorch/issues/37841. Pull Request resolved: https://github.com/pytorch/pytorch/pull/38893 Differential Revision: D21825813 Pulled By: SplitInfinity fbshipit-source-id: 77a7b84504e0dddf14c89b3ed5dd6b438c086f66
111 lines
3.9 KiB
Python
111 lines
3.9 KiB
Python
from torch.testing._internal.jit_utils import JitTestCase
|
|
import io
|
|
import os
|
|
import sys
|
|
|
|
import torch
|
|
import torch._C
|
|
|
|
# Make the helper files in test/ importable: this file's directory is resolved
# first, then its parent is appended to the import search path.
_this_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(_this_dir)
sys.path.append(pytorch_test_dir)

# Refuse direct execution and point the user at the intended entry point.
if __name__ == '__main__':
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
|
|
|
|
|
|
def to_test_backend(module, method_compile_spec):
    """Lower ``module`` through ``torch._C._jit_to_test_backend`` for ``forward`` only.

    Args:
        module: The module object handed to the binding unchanged.
        method_compile_spec: The compile spec for the ``forward`` method; it is
            wrapped in a one-entry dict keyed by ``"forward"`` before the call.

    Returns:
        Whatever ``torch._C._jit_to_test_backend`` returns.
    """
    forward_only_spec = {"forward": method_compile_spec}
    return torch._C._jit_to_test_backend(module, forward_only_spec)
|
|
|
|
|
|
def to_test_backend_multi(module, method_compile_spec):
    """Lower ``module`` through ``torch._C._jit_to_test_backend`` with a full spec.

    Unlike ``to_test_backend``, the spec is forwarded as-is rather than being
    wrapped under a ``"forward"`` key.

    Args:
        module: The module object handed to the binding unchanged.
        method_compile_spec: The compile spec passed straight through. The
            caller in this file supplies a dict keyed by method name.

    Returns:
        Whatever ``torch._C._jit_to_test_backend`` returns.
    """
    lower = torch._C._jit_to_test_backend
    return lower(module, method_compile_spec)
|
|
|
|
|
|
class MyModule(torch.nn.Module):
    """Parameter-free module exposing an add, a subtract and a combined forward."""

    def __init__(self):
        super().__init__()

    def forward(self, x, h):
        """Return the tuple ``(accum(x, h), sub_accum(x, h))``."""
        summed = self.accum(x, h)
        difference = self.sub_accum(x, h)
        return summed, difference

    def accum(self, x, h):
        """Return ``x + h``."""
        return x + h

    def sub_accum(self, x, h):
        """Return ``x - h``."""
        return x - h
|
|
|
|
|
|
class TestBackends(JitTestCase):
    """Compare MyModule across eager Python, TorchScript and the test backend."""

    def setUp(self):
        """Build the Python, JIT and backend-lowered versions of MyModule."""
        super().setUp()
        self.module = MyModule()
        self.scripted_module = torch.jit.script(MyModule())
        # Each of the three methods is lowered with the same placeholder spec.
        compile_specs = {
            "accum": {"": ""},
            "sub_accum": {"": ""},
            "forward": {"": ""},
        }
        self.lowered_module = to_test_backend_multi(self.scripted_module._c, compile_specs)

    def compare_py_jit_backend(self, name, input):
        """
        Invoke the method called 'name' on self.module (Python), self.scripted_module (JIT)
        and self.lowered_module (backend), passing 'input' as both arguments, and assert
        that the backend result equals both the Python and the JIT result.
        """
        # Look the method up on each flavour of the module. The attribute hooks
        # are called explicitly here rather than through getattr().
        python_method = self.module.__getattribute__(name)
        jit_method = self.scripted_module.__getattr__(name)
        backend_method = self.lowered_module.__getattr__(name)

        # Execute all three with identical arguments.
        call_args = (input, input)
        python_output = python_method(*call_args)
        jit_output = jit_method(*call_args)
        backend_output = backend_method(*call_args)

        # The backend answer must agree with Python first, then with JIT.
        for reference_output in (python_output, jit_output):
            self.assertEqual(reference_output, backend_output)

    def test_simple(self):
        """
        Lower MyModule to the test backend and check that every method returns the same
        answer as its Python and JIT counterparts.
        """
        sample = torch.randn(5)

        # Exercise all three module methods.
        for method_name in ("accum", "sub_accum", "forward"):
            self.compare_py_jit_backend(method_name, sample)

    def test_save_load(self):
        """
        Check that a lowered module still produces the same output as a Python module and
        a ScriptModule after it has been saved and loaded again.
        """
        # Serialize the lowered module into an in-memory buffer.
        serialized = io.BytesIO()
        torch.jit.save(self.lowered_module, serialized)

        # Remember the compile spec so it can be compared with the reloaded one.
        spec_before = self.lowered_module.__getattr__("__method_compile_spec")

        # Rewind and deserialize, replacing the lowered module under test.
        serialized.seek(0)
        self.lowered_module = torch.jit.load(serialized)

        # The compile spec must survive the round trip.
        spec_after = self.lowered_module.__getattr__("__method_compile_spec")
        self.assertEqual(spec_before, spec_after)

        # The reloaded module must still agree with Python and JIT.
        sample = torch.randn(5)

        # Exercise all three module methods.
        for method_name in ("accum", "sub_accum", "forward"):
            self.compare_py_jit_backend(method_name, sample)
|