mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54284

In order to bring mobile deployment, via the lite interpreter, to feature parity with JIT with respect to model-level debug information, we must make that debug information available to the mobile runtime.

At the moment, model-level debug information is stored in SourceRange, which associates nodes of a graph with where they come from in the original Python source code. This information is serialized as part of debug_pkl and deserialized when JIT loads the model and reads the model code. In the lite interpreter we do not have access to all the functionality of JIT, so we cannot load the model the way JIT does, i.e. by reading code, constructing the module hierarchy, the graphs corresponding to module methods, etc. Instead, the lite interpreter saves only the bytecode corresponding to the compiled graph (Code). Thus, in order to annotate ops in the bytecode with equivalent SourceRange information, we do the following:
1. During model serialization, create a unique tag for each source range of the model.
2. Create a map of <SourceRange, tag>.
3. During debug_pkl serialization, save the tag along with the SourceRange, on top of the byte offset.
4. During bytecode generation, the methods of the top module are lowered. During this process methods are inlined. In the inlined graph, when a node of the graph is lowered to bytecode, we query the node's source range and look it up against the map.
5. The resulting source range tag is serialized in module_debug_info.
6. During model deserialization, read all the debug_pkl records in the archive and create a map of <tag, SourceRange>.
7. This map can be used to find source code information.

During mobile runtime:
1. We read all the debug_pkl records and create a <tag=debug_handle, SourceRange> map.
   1.1. This map, MobileDebugInfo, is a member of the mobile Module.
2. The interpreter catches appropriate exceptions, sets the thread-local debug handle, and rethrows the exception.
3. In Function's run method we catch the exception and query the current debug handle, i.e. where the exception happened.
4. We query MobileDebugInfo with the debug handle to retrieve the source range and augment the error with source range info.

This information is still incomplete, as it does not contain the entire callstack. In following diffs we will serialize InlinedCallStack directly.

Note that compilation is gated by the SYMBOLICATE_MOBILE_DEBUG_HANDLE macro, so that mobile builds can avoid building MobileDebugInfo, source range, and the source range pickler/unpickler. Later we will add a path where, if building without debug support, the stack trace will contain only debug handles; they can be symbolicated later.

Test Plan: Ported a bunch of source range tests from test_jit.py. Added one more test in test_lite_interpreter.py

Imported from OSS

Reviewed By: raziel

Differential Revision: D27174722

fbshipit-source-id: a7b7c6088ce16dec37e823c7fefa4f0b61047e12
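For context, a minimal sketch (not part of the diff) of the save/load round trip this change enables, using the private underscore-prefixed lite-interpreter APIs exercised in the test file below; AddTen is a placeholder module and these internal interfaces may change:

import io

import torch
from torch.jit.mobile import _load_for_lite_interpreter


class AddTen(torch.nn.Module):  # placeholder module for illustration
    def forward(self, x):
        return x + 10


# Script the module, then serialize it for the lite interpreter with
# _save_mobile_debug_info=True so that source range tags (debug handles)
# are written into the archive alongside the bytecode.
scripted = torch.jit.script(AddTen())
buffer = io.BytesIO(
    scripted._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True)
)
buffer.seek(0)

# On load, the mobile runtime reads the debug_pkl records and builds the
# <debug_handle, SourceRange> map (MobileDebugInfo), so runtime errors can
# be augmented with the original Python source location.
mobile_module = _load_for_lite_interpreter(buffer)
print(mobile_module(torch.tensor([1])))  # tensor([11])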
458 lines
18 KiB
Python
import torch
import torch.utils.bundled_inputs
from torch.utils.mobile_optimizer import optimize_for_mobile
import io
from typing import Dict, List, NamedTuple
from collections import namedtuple
import inspect

from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list
from torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName


class TestLiteScriptModule(TestCase):

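    # Helper: script `m`, save it for the lite interpreter (optionally with
    # mobile debug info), and reload it, either from an in-memory buffer or,
    # when also_test_file is True, from a temporary file on disk.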
    def getScriptExportImportCopy(self, m, save_mobile_debug_info=True, also_test_file=False):
        m_scripted = torch.jit.script(m)

        if not also_test_file:
            buffer = io.BytesIO(m_scripted._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=save_mobile_debug_info))
            buffer.seek(0)
            mobile_module = _load_for_lite_interpreter(buffer)
            return mobile_module

        with TemporaryFileName() as fname:
            m_scripted._save_for_lite_interpreter(fname, _save_mobile_debug_info=save_mobile_debug_info)
            mobile_module = _load_for_lite_interpreter(fname)
            return mobile_module

    def test_load_mobile_module(self):
        class MyTestModule(torch.nn.Module):
            def __init__(self):
                super(MyTestModule, self).__init__()

            def forward(self, x):
                return x + 10

        input = torch.tensor([1])

        script_module = torch.jit.script(MyTestModule())
        script_module_result = script_module(input)

        buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)

        mobile_module_result = mobile_module(input)
        torch.testing.assert_allclose(script_module_result, mobile_module_result)

        mobile_module_forward_result = mobile_module.forward(input)
        torch.testing.assert_allclose(script_module_result, mobile_module_forward_result)

        mobile_module_run_method_result = mobile_module.run_method("forward", input)
        torch.testing.assert_allclose(script_module_result, mobile_module_run_method_result)

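    # The next three tests save a traced or scripted module hierarchy with
    # _save_mobile_debug_info=True and assert that the exported archive
    # contains mobile_debug.pkl, module_debug_info, and the expected
    # top(<Module>)...forward debug strings for each submodule.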
    def test_save_mobile_module_with_debug_info_with_trace(self):
        class A(torch.nn.Module):
            def __init__(self):
                super(A, self).__init__()

            def forward(self, x):
                return x + 1

        class B(torch.nn.Module):
            def __init__(self):
                super(B, self).__init__()
                self.A0 = A()
                self.A1 = A()

            def forward(self, x):
                return self.A0(x) + self.A1(x)

        input = torch.tensor([5])
        trace_module = torch.jit.trace(B(), input)
        exported_module = trace_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True)

        assert(b"mobile_debug.pkl" in exported_module)
        assert(b"module_debug_info" in exported_module)
        assert(b"top(B).forward" in exported_module)
        assert(b"top(B).A0(A).forward" in exported_module)
        assert(b"top(B).A1(A).forward" in exported_module)

    def test_save_mobile_module_with_debug_info_with_script_duplicate_class(self):
        class A(torch.nn.Module):
            def __init__(self):
                super(A, self).__init__()

            def forward(self, x):
                return x + 1

        class B(torch.nn.Module):
            def __init__(self):
                super(B, self).__init__()
                self.A0 = A()
                self.A1 = A()

            def forward(self, x):
                return self.A0(x) + self.A1(x)

        input_data = torch.tensor([5])
        scripted_module = torch.jit.script(B(), input_data)
        exported_module = scripted_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True)

        assert(b"mobile_debug.pkl" in exported_module)
        assert(b"module_debug_info" in exported_module)
        assert(b"top(B).forward" in exported_module)
        assert(b"top(B).A0(A).forward" in exported_module)
        assert(b"top(B).A1(A).forward" in exported_module)

    def test_save_mobile_module_with_debug_info_with_script_nested_call(self):
        class A(torch.nn.Module):
            def __init__(self):
                super(A, self).__init__()

            def forward(self, x):
                return x + 1

        class B(torch.nn.Module):
            def __init__(self):
                super(B, self).__init__()

            def forward(self, x):
                return x + 2

        class C(torch.nn.Module):
            def __init__(self):
                super(C, self).__init__()
                self.A0 = A()
                self.B0 = B()

            def forward(self, x):
                return self.A0(self.B0(x)) + 1

        input = torch.tensor([5])
        scripted_module = torch.jit.script(C(), input)

        optimized_scripted_module = optimize_for_mobile(scripted_module)

        exported_module = scripted_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True)
        optimized_exported_module = optimized_scripted_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True)

        assert(b"mobile_debug.pkl" in exported_module)
        assert(b"module_debug_info" in exported_module)
        assert(b"top(C).forward" in exported_module)
        assert(b"top(C).A0(A).forward" in exported_module)
        assert(b"top(C).B0(B).forward" in exported_module)

        assert(b"mobile_debug.pkl" in optimized_exported_module)
        assert(b"module_debug_info" in optimized_exported_module)
        assert(b"top(C).forward" in optimized_exported_module)
        assert(b"top(C).A0(A).forward" in optimized_exported_module)
        assert(b"top(C).B0(B).forward" in optimized_exported_module)

    def test_load_mobile_module_with_debug_info(self):
        class MyTestModule(torch.nn.Module):
            def __init__(self):
                super(MyTestModule, self).__init__()

            def forward(self, x):
                return x + 5

        input = torch.tensor([3])

        script_module = torch.jit.script(MyTestModule())
        script_module_result = script_module(input)

        buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True))
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)

        mobile_module_result = mobile_module(input)
        torch.testing.assert_allclose(script_module_result, mobile_module_result)

        mobile_module_forward_result = mobile_module.forward(input)
        torch.testing.assert_allclose(script_module_result, mobile_module_forward_result)

        mobile_module_run_method_result = mobile_module.run_method("forward", input)
        torch.testing.assert_allclose(script_module_result, mobile_module_run_method_result)

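    # find_method should report get_all_bundled_inputs only after bundled
    # inputs have been attached; run_method then retrieves them for a
    # forward call on the mobile module.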
    def test_find_and_run_method(self):
        class MyTestModule(torch.nn.Module):
            def forward(self, arg):
                return arg

        input = (torch.tensor([1]), )

        script_module = torch.jit.script(MyTestModule())
        script_module_result = script_module(*input)

        buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)

        has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs")
        self.assertFalse(has_bundled_inputs)

        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            script_module, [input], [])

        buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)

        has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs")
        self.assertTrue(has_bundled_inputs)

        bundled_inputs = mobile_module.run_method("get_all_bundled_inputs")
        mobile_module_result = mobile_module.forward(*bundled_inputs[0])
        torch.testing.assert_allclose(script_module_result, mobile_module_result)

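    # Default-argument handling across both call boundaries: Python-to-script
    # (B.forward's `one`) and script-to-script (A.forward's `two`).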
    def test_method_calls_with_optional_arg(self):
        class A(torch.nn.Module):
            def __init__(self):
                super(A, self).__init__()

            # opt arg in script-to-script invocation
            def forward(self, x, two: int = 2):
                return x + two

        class B(torch.nn.Module):
            def __init__(self):
                super(B, self).__init__()
                self.A0 = A()

            # opt arg in Python-to-script invocation
            def forward(self, x, one: int = 1):
                return self.A0(x) + one

        script_module = torch.jit.script(B())
        buffer = io.BytesIO(
            script_module._save_to_buffer_for_lite_interpreter()
        )
        mobile_module = _load_for_lite_interpreter(buffer)

        input = torch.tensor([5])
        script_module_forward_result = script_module.forward(input)
        mobile_module_forward_result = mobile_module.forward(input)
        torch.testing.assert_allclose(
            script_module_forward_result,
            mobile_module_forward_result
        )

        # change ref only
        script_module_forward_result = script_module.forward(input, 2)
        self.assertFalse(
            (script_module_forward_result == mobile_module_forward_result)
            .all()
            .item()
        )

        # now both match again
        mobile_module_forward_result = mobile_module.forward(input, 2)
        torch.testing.assert_allclose(
            script_module_forward_result,
            mobile_module_forward_result
        )

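    # The following four tests assert that exporting unsupported constructs
    # (arbitrary Python classes, NamedTuple returns, lists/dicts holding
    # module types) fails with an actionable error message. The regexes match
    # the runtime's messages verbatim, including their typos ("itmes",
    # "Returining"), so do not "fix" the strings here.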
    def test_unsupported_classtype(self):
        class Foo():
            def __init__(self):
                return

            def func(self, x: int, y: int):
                return x + y

        class MyTestModule(torch.nn.Module):
            def forward(self, arg):
                f = Foo()
                return f.func(1, 2)

        script_module = torch.jit.script(MyTestModule())
        with self.assertRaisesRegex(RuntimeError,
                                    r"Workaround: instead of using arbitrary class type \(class Foo\(\)\), "
                                    r"define a pytorch class \(class Foo\(torch\.nn\.Module\)\)\.$"):
            script_module._save_to_buffer_for_lite_interpreter()

    def test_unsupported_return_typing_namedtuple(self):
        myNamedTuple = NamedTuple('myNamedTuple', [('a', torch.Tensor)])

        class MyTestModule(torch.nn.Module):
            def forward(self):
                return myNamedTuple(torch.randn(1))

        script_module = torch.jit.script(MyTestModule())
        with self.assertRaisesRegex(RuntimeError,
                                    r"A named tuple type is not supported in mobile module. "
                                    r"Workaround: instead of using a named tuple type\'s fields, "
                                    r"use a dictionary type\'s key-value pair itmes or "
                                    r"a pytorch class \(class Foo\(torch\.nn\.Module\)\)\'s attributes."):
            script_module._save_to_buffer_for_lite_interpreter()

    def test_unsupported_return_collections_namedtuple(self):
        myNamedTuple = namedtuple('myNamedTuple', [('a')])

        class MyTestModule(torch.nn.Module):
            def forward(self):
                return myNamedTuple(torch.randn(1))

        script_module = torch.jit.script(MyTestModule())
        with self.assertRaisesRegex(RuntimeError,
                                    r"A named tuple type is not supported in mobile module. "
                                    r"Workaround: instead of using a named tuple type\'s fields, "
                                    r"use a dictionary type\'s key-value pair itmes or "
                                    r"a pytorch class \(class Foo\(torch\.nn\.Module\)\)\'s attributes."):
            script_module._save_to_buffer_for_lite_interpreter()

    def test_unsupported_return_list_with_module_class(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super(Foo, self).__init__()

        class MyTestModuleForListWithModuleClass(torch.nn.Module):
            def __init__(self):
                super(MyTestModuleForListWithModuleClass, self).__init__()
                self.foo = Foo()

            def forward(self):
                my_list: List[Foo] = [self.foo]
                return my_list

        script_module = torch.jit.script(MyTestModuleForListWithModuleClass())
        with self.assertRaisesRegex(RuntimeError,
                                    r"^Returining a list or dictionary with pytorch class type "
                                    r"is not supported in mobile module "
                                    r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. "
                                    r"Workaround\: instead of using pytorch class as their element type\, "
                                    r"use a combination of list\, dictionary\, and single types\.$"):
            script_module._save_to_buffer_for_lite_interpreter()

    def test_unsupported_return_dict_with_module_class(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super(Foo, self).__init__()

        class MyTestModuleForDictWithModuleClass(torch.nn.Module):
            def __init__(self):
                super(MyTestModuleForDictWithModuleClass, self).__init__()
                self.foo = Foo()

            def forward(self):
                my_dict: Dict[int, Foo] = {1: self.foo}
                return my_dict

        script_module = torch.jit.script(MyTestModuleForDictWithModuleClass())
        with self.assertRaisesRegex(RuntimeError,
                                    r"^Returining a list or dictionary with pytorch class type "
                                    r"is not supported in mobile module "
                                    r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. "
                                    r"Workaround\: instead of using pytorch class as their element type\, "
                                    r"use a combination of list\, dictionary\, and single types\.$"):
            script_module._save_to_buffer_for_lite_interpreter()

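    # The set of root operators recorded in the bytecode archive should be
    # recoverable via _export_operator_list.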
    def test_module_export_operator_list(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super(Foo, self).__init__()
                self.weight = torch.ones((20, 1, 5, 5))
                self.bias = torch.ones(20)

            def forward(self, input):
                x1 = torch.zeros(2, 2)
                x2 = torch.empty_like(torch.empty(2, 2))
                x3 = torch._convolution(
                    input,
                    self.weight,
                    self.bias,
                    [1, 1],
                    [0, 0],
                    [1, 1],
                    False,
                    [0, 0],
                    1,
                    False,
                    False,
                    True,
                    True,
                )
                return (x1, x2, x3)

        m = torch.jit.script(Foo())

        buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)

        expected_ops = {
            "aten::_convolution",
            "aten::empty.memory_format",
            "aten::empty_like",
            "aten::zeros",
        }
        actual_ops = _export_operator_list(mobile_module)
        self.assertEqual(actual_ops, expected_ops)

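    # The source-range tests below exercise the debug-handle symbolication
    # described in the summary: with debug info saved, a runtime error should
    # point at the offending Python line; without it, it should not.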
    def test_source_range_simple(self):

        class FooTest(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, w):
                return torch.mm(x, w.t())

        ft = FooTest()
        loaded = self.getScriptExportImportCopy(ft)
        _, lineno = inspect.getsourcelines(FooTest)

        with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)):
            loaded(torch.rand(3, 4), torch.rand(30, 40))

    def test_source_range_raise_exception(self):

        class FooTest2(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self):
                raise RuntimeError('foo')

        _, lineno = inspect.getsourcelines(FooTest2)

        with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)):
            ft = FooTest2()
            loaded = self.getScriptExportImportCopy(ft)
            loaded()

    def test_source_range_function_call(self):
        class FooTest3(torch.jit.ScriptModule):
            @torch.jit.script_method
            def add_method(self, x, w):
                return x + w

            @torch.jit.script_method
            def forward(self, x, y, w):
                x = x * y
                x = x + 2
                return self.add_method(x, w)

        ft = FooTest3()
        loaded = self.getScriptExportImportCopy(ft)
        _, lineno = inspect.getsourcelines(FooTest3)

        with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)):
            loaded(torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40))

    def test_source_range_no_debug_info(self):

        class FooTest4(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, w):
                return torch.mm(x, w.t())

        ft = FooTest4()
        loaded = self.getScriptExportImportCopy(ft, save_mobile_debug_info=False)

        # Without debug info in the archive, the error must not be
        # symbolicated back to this file.
        error_message = ""
        try:
            loaded(torch.rand(3, 4), torch.rand(30, 40))
        except RuntimeError as e:
            error_message = f"{e}"
        self.assertTrue("test_lite_script_module.py" not in error_message)


if __name__ == '__main__':
    run_tests()