pytorch/test/mobile/test_lite_script_module.py
Zhengxu Chen 75ad6fea66 [jit][edge] Pass through dynamic type for DictType. (#74025)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74025

When users inspect IValues returned from the Lite Interpreter, dynamic types are still attached to them, so torch::jit::toPyObject fails on these dynamic types when converting dictionary keys.
In this corner case we simply let dynamic types pass through, since nothing downstream uses them.
ghstack-source-id: 151051826

Test Plan:
buck test //caffe2/test:mobile -- -r 'test_bundled_input_with_dynamic_type'

without patch:
```
BUILD SUCCEEDED
Tpx test run coordinator for Facebook. See https://fburl.com/tpx for details.
Running with tpx session id: c6693277-2dad-4882-97c7-f69c58f67259
Trace available for this run at /tmp/tpx-20220310-000040.948069-c6693277-2dad-4882-97c7-f69c58f67259/trace.log
RemoteExecution session id: reSessionID-c6693277-2dad-4882-97c7-f69c58f67259-tpx
Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/6473924544183693
    ✓ ListingSuccess: caffe2/test:mobile : 40 tests discovered (2.122)
    ✗ Fail: caffe2/test:mobile - test_bundled_input_with_dynamic_type (mobile.test_lite_script_module.TestLiteScriptQuantizedModule) (3.059)
Test output:
> RuntimeError: Cannot create dict for key type 'Dynamic<8>', only int, float, complex, Tensor, device and string keys are supported
  File "/usr/local/fbcode/platform009/lib/python3.8/unittest/case.py", line 60, in testPartExecutor
    yield
  File "/usr/local/fbcode/platform009/lib/python3.8/unittest/case.py", line 676, in run
    self._callTestMethod(testMethod)
  File "/usr/local/fbcode/platform009/lib/python3.8/unittest/case.py", line 633, in _callTestMethod
    method()
  File "/data/users/zhxchen17/fbsource/fbcode/buck-out/dbg/gen/caffe2/test/mobile#binary,link-tree/mobile/test_lite_script_module.py", line 558, in test_bundled_input_with_dynamic_type
    i = mobile_module.run_method("get_all_bundled_inputs")
  File "/data/users/zhxchen17/fbsource/fbcode/buck-out/dbg/gen/caffe2/test/mobile#binary,link-tree/torch/jit/mobile/__init__.py", line 69, in run_method
    return self._c.run_method(method_name, input)
stdout:

stderr:

Summary
  Fail: 1
    ✗ caffe2/test:mobile - test_bundled_input_with_dynamic_type (mobile.test_lite_script_module.TestLiteScriptQuantizedModule)
  ListingSuccess: 1
If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/6473924544183693
```

Reviewed By: cccclai

Differential Revision: D34780805

fbshipit-source-id: 88b139c5e91becc031e4b06e055a78a52a429c09
(cherry picked from commit 41abbacf3025cf8fc82516a3e1cefe8b4081a4b6)
2022-03-11 02:18:04 +00:00

571 lines
21 KiB
Python

# Owner(s): ["oncall: mobile"]
import inspect
import io
from typing import Dict, List

import torch
import torch.utils.bundled_inputs
from torch.jit.mobile import _export_operator_list, _load_for_lite_interpreter
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import (
    AnnotatedNestedModel,
    AnnotatedSingleLayerLinearModel,
    QuantizationLiteTestCase,
    TwoLayerLinearModel,
)
from torch.testing._internal.common_utils import TemporaryFileName, TestCase, run_tests
class TestLiteScriptModule(TestCase):
def getScriptExportImportCopy(self, m, save_mobile_debug_info=True, also_test_file=False):
m_scripted = torch.jit.script(m)
if not also_test_file:
buffer = io.BytesIO(m_scripted._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=save_mobile_debug_info))
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
return mobile_module
with TemporaryFileName() as fname:
m_scripted._save_for_lite_interpreter(fname, _save_mobile_debug_info=save_mobile_debug_info)
mobile_module = _load_for_lite_interpreter(fname)
return mobile_module
def test_load_mobile_module(self):
class MyTestModule(torch.nn.Module):
def __init__(self):
super(MyTestModule, self).__init__()
def forward(self, x):
return x + 10
input = torch.tensor([1])
script_module = torch.jit.script(MyTestModule())
script_module_result = script_module(input)
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
mobile_module_result = mobile_module(input)
torch.testing.assert_close(script_module_result, mobile_module_result)
mobile_module_forward_result = mobile_module.forward(input)
torch.testing.assert_close(script_module_result, mobile_module_forward_result)
mobile_module_run_method_result = mobile_module.run_method("forward", input)
torch.testing.assert_close(script_module_result, mobile_module_run_method_result)
def test_save_mobile_module_with_debug_info_with_trace(self):
class A(torch.nn.Module):
def __init__(self):
super(A, self).__init__()
def forward(self, x, y):
return x * y
class B(torch.nn.Module):
def __init__(self):
super(B, self).__init__()
self.A0 = A()
self.A1 = A()
def forward(self, x, y, z):
return self.A0(x, y) + self.A1(y, z)
for export_method in ['trace', 'script']:
x = torch.rand((2, 3))
y = torch.rand((2, 3))
z = torch.rand((2, 3))
if export_method == 'trace':
trace_module = torch.jit.trace(B(), [x, y, z])
else:
trace_module = torch.jit.script(B())
exported_module = trace_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True)
buffer = io.BytesIO(exported_module)
buffer.seek(0)
assert(b"callstack_debug_map.pkl" in exported_module)
mobile_module = _load_for_lite_interpreter(buffer)
with self.assertRaisesRegex(RuntimeError, r"Module hierarchy:top\(B\)::<unknown>.A0\(A\)::forward.aten::mul"):
x = torch.rand((2, 3))
y = torch.rand((8, 10))
z = torch.rand((8, 10))
mobile_module(x, y, z)
with self.assertRaisesRegex(RuntimeError, r"Module hierarchy:top\(B\)::<unknown>.A1\(A\)::forward.aten::mul"):
x = torch.rand((2, 3))
y = torch.rand((2, 3))
z = torch.rand((8, 10))
mobile_module(x, y, z)
def test_load_mobile_module_with_debug_info(self):
class MyTestModule(torch.nn.Module):
def __init__(self):
super(MyTestModule, self).__init__()
def forward(self, x):
return x + 5
input = torch.tensor([3])
script_module = torch.jit.script(MyTestModule())
script_module_result = script_module(input)
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True))
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
mobile_module_result = mobile_module(input)
torch.testing.assert_close(script_module_result, mobile_module_result)
mobile_module_forward_result = mobile_module.forward(input)
torch.testing.assert_close(script_module_result, mobile_module_forward_result)
mobile_module_run_method_result = mobile_module.run_method("forward", input)
torch.testing.assert_close(script_module_result, mobile_module_run_method_result)
def test_find_and_run_method(self):
class MyTestModule(torch.nn.Module):
def forward(self, arg):
return arg
input = (torch.tensor([1]), )
script_module = torch.jit.script(MyTestModule())
script_module_result = script_module(*input)
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs")
self.assertFalse(has_bundled_inputs)
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
script_module, [input], [])
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs")
self.assertTrue(has_bundled_inputs)
bundled_inputs = mobile_module.run_method("get_all_bundled_inputs")
mobile_module_result = mobile_module.forward(*bundled_inputs[0])
torch.testing.assert_close(script_module_result, mobile_module_result)
def test_method_calls_with_optional_arg(self):
class A(torch.nn.Module):
def __init__(self):
super(A, self).__init__()
# opt arg in script-to-script invocation
def forward(self, x, two: int = 2):
return x + two
class B(torch.nn.Module):
def __init__(self):
super(B, self).__init__()
self.A0 = A()
# opt arg in Python-to-script invocation
def forward(self, x, one: int = 1):
return self.A0(x) + one
script_module = torch.jit.script(B())
buffer = io.BytesIO(
script_module._save_to_buffer_for_lite_interpreter()
)
mobile_module = _load_for_lite_interpreter(buffer)
input = torch.tensor([5])
script_module_forward_result = script_module.forward(input)
mobile_module_forward_result = mobile_module.forward(input)
torch.testing.assert_close(
script_module_forward_result,
mobile_module_forward_result
)
# change ref only
script_module_forward_result = script_module.forward(input, 2)
self.assertFalse(
(script_module_forward_result == mobile_module_forward_result)
.all()
.item()
)
# now both match again
mobile_module_forward_result = mobile_module.forward(input, 2)
torch.testing.assert_close(
script_module_forward_result,
mobile_module_forward_result
)
def test_unsupported_classtype(self):
class Foo():
def __init__(self):
return
def func(self, x: int, y: int):
return x + y
class MyTestModule(torch.nn.Module):
def forward(self, arg):
f = Foo()
return f.func(1, 2)
script_module = torch.jit.script(MyTestModule())
with self.assertRaisesRegex(RuntimeError,
r"Workaround: instead of using arbitrary class type \(class Foo\(\)\), "
r"define a pytorch class \(class Foo\(torch\.nn\.Module\)\)\. "
r"The problematic type is: "):
script_module._save_to_buffer_for_lite_interpreter()
def test_unsupported_return_list_with_module_class(self):
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
class MyTestModuleForListWithModuleClass(torch.nn.Module):
def __init__(self):
super(MyTestModuleForListWithModuleClass, self).__init__()
self.foo = Foo()
def forward(self):
my_list: List[Foo] = [self.foo]
return my_list
script_module = torch.jit.script(MyTestModuleForListWithModuleClass())
with self.assertRaisesRegex(RuntimeError,
r"^Returining a list or dictionary with pytorch class type "
r"is not supported in mobile module "
r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. "
r"Workaround\: instead of using pytorch class as their element type\, "
r"use a combination of list\, dictionary\, and single types\.$"):
script_module._save_to_buffer_for_lite_interpreter()
def test_unsupported_return_dict_with_module_class(self):
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
class MyTestModuleForDictWithModuleClass(torch.nn.Module):
def __init__(self):
super(MyTestModuleForDictWithModuleClass, self).__init__()
self.foo = Foo()
def forward(self):
my_dict: Dict[int, Foo] = {1: self.foo}
return my_dict
script_module = torch.jit.script(MyTestModuleForDictWithModuleClass())
with self.assertRaisesRegex(RuntimeError,
r"^Returining a list or dictionary with pytorch class type "
r"is not supported in mobile module "
r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. "
r"Workaround\: instead of using pytorch class as their element type\, "
r"use a combination of list\, dictionary\, and single types\.$"):
script_module._save_to_buffer_for_lite_interpreter()
def test_module_export_operator_list(self):
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
self.weight = torch.ones((20, 1, 5, 5))
self.bias = torch.ones(20)
def forward(self, input):
x1 = torch.zeros(2, 2)
x2 = torch.empty_like(torch.empty(2, 2))
x3 = torch._convolution(
input,
self.weight,
self.bias,
[1, 1],
[0, 0],
[1, 1],
False,
[0, 0],
1,
False,
False,
True,
True,
)
return (x1, x2, x3)
m = torch.jit.script(Foo())
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
expected_ops = {
"aten::_convolution",
"aten::empty.memory_format",
"aten::empty_like",
"aten::zeros",
}
actual_ops = _export_operator_list(mobile_module)
self.assertEqual(actual_ops, expected_ops)
def test_source_range_simple(self):
class FooTest(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x, w):
return torch.mm(x, w.t())
ft = FooTest()
loaded = self.getScriptExportImportCopy(ft)
_, lineno = inspect.getsourcelines(FooTest)
with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)):
loaded(torch.rand(3, 4), torch.rand(30, 40))
def test_source_range_raise_exception(self):
class FooTest2(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self):
raise RuntimeError('foo')
_, lineno = inspect.getsourcelines(FooTest2)
# In C++ code, the type of exception thrown is torch::jit::JITException
# which does not extend c10::Error, and hence it isn't possible to add
# additional context to the exception message and preserve the correct
# C++ stack trace for symbolication. i.e. it isn't possible to add
# the debug handle string to show where in the Python code the exception
# occured w/o first changing
# torch::jit::JITException to extend c10::Error.
with self.assertRaisesRegex(torch.jit.Error, 'foo'):
ft = FooTest2()
loaded = self.getScriptExportImportCopy(ft)
loaded()
def test_source_range_function_call(self):
class FooTest3(torch.jit.ScriptModule):
@torch.jit.script_method
def add_method(self, x, w):
return x + w
@torch.jit.script_method
def forward(self, x, y, w):
x = x * y
x = x + 2
return self.add_method(x, w)
ft = FooTest3()
loaded = self.getScriptExportImportCopy(ft)
_, lineno = inspect.getsourcelines(FooTest3)
try:
loaded(torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40))
except RuntimeError as e:
error_message = f"{e}"
self.assertTrue('test_lite_script_module.py\", line {}'.format(lineno + 3) in error_message)
self.assertTrue('test_lite_script_module.py\", line {}'.format(lineno + 9) in error_message)
self.assertTrue('top(FooTest3)' in error_message)
def test_source_range_no_debug_info(self):
class FooTest4(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x, w):
return torch.mm(x, w.t())
ft = FooTest4()
loaded = self.getScriptExportImportCopy(ft, save_mobile_debug_info=False)
try:
loaded(torch.rand(3, 4), torch.rand(30, 40))
except RuntimeError as e:
error_message = f"{e}"
self.assertTrue("test_lite_script_module.py" not in error_message)
def test_source_range_raise_exc(self):
class FooTest5(torch.jit.ScriptModule):
def __init__(self, val: int):
super(FooTest5, self).__init__()
self.val = val
@torch.jit.script_method
def add_method(self, val: int, x, w):
if (val == self.val):
raise RuntimeError('self.val and val are same')
return x + w
@torch.jit.script_method
def forward(self, val: int, x, y, w):
x = x * y
x = x + 2
return self.add_method(val, x, w)
ft = FooTest5(42)
loaded = self.getScriptExportImportCopy(ft)
_, lineno = inspect.getsourcelines(FooTest5)
try:
loaded(42, torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40))
except torch.jit.Error as e:
error_message = f"{e}"
# In C++ code, the type of exception thrown is torch::jit::JITException
# which does not extend c10::Error, and hence it isn't possible to add
# additional context to the exception message and preserve the correct
# C++ stack trace for symbolication. i.e. it isn't possible to add
# the debug handle string to show where in the Python code the exception
# occured w/o first changing
# torch::jit::JITException to extend c10::Error.
self.assertTrue('self.val and val are same' in error_message)
def test_stacktrace_interface_call(self):
@torch.jit.interface
class Forward(torch.nn.Module):
def forward(self, x) -> torch.Tensor:
pass
def forwardError(self, x) -> torch.Tensor:
pass
class B(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
def forwardError(self, x):
return self.call() + x
def call(self):
return torch.ones(-1)
class A(torch.nn.Module):
b : Forward
def __init__(self):
super().__init__()
self.b = B()
def forward(self):
self.b.forward(torch.ones(1))
self.b.forwardError(torch.ones(1))
a = torch.jit.script(A())
torch._C._enable_mobile_interface_call_export()
buffer = io.BytesIO(a._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True))
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
try:
mobile_module()
self.assertTrue(False)
except RuntimeError as exp:
FileCheck().check("Trying to create tensor with negative dimension") \
.check("Traceback of TorchScript") \
.check("self.b.forwardError").check_next("~~~~~~~~~~~~~~~~~~~ <--- HERE") \
.check("return self.call").check_next("~~~~~~~~~ <--- HERE") \
.check("return torch.ones").check_next("~~~~~~~~~~ <--- HERE").run(str(exp))
class TestLiteScriptQuantizedModule(QuantizationLiteTestCase):
def test_single_layer(self):
input = torch.rand(2, 5, dtype=torch.float)
quantized_model = self._create_quantized_model(model_class=AnnotatedSingleLayerLinearModel, qengine="qnnpack")
self._compare_script_and_mobile(model=quantized_model, input=input)
def test_two_layer(self):
input = torch.rand(2, 5, dtype=torch.float)
quantized_model = self._create_quantized_model(model_class=TwoLayerLinearModel)
self._compare_script_and_mobile(model=quantized_model, input=input)
def test_annotated_nested(self):
input = torch.rand(2, 5, dtype=torch.float)
quantized_model = self._create_quantized_model(model_class=AnnotatedNestedModel, qengine="qnnpack")
self._compare_script_and_mobile(model=quantized_model, input=input)
def test_quantization_example(self):
# From the example in Static Quantization section of https://pytorch.org/docs/stable/quantization.html
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.quant = torch.ao.quantization.QuantStub()
self.conv = torch.nn.Conv2d(1, 1, 1)
self.relu = torch.nn.ReLU()
self.dequant = torch.ao.quantization.DeQuantStub()
def forward(self, x):
x = self.quant(x)
x = self.conv(x)
x = self.relu(x)
x = self.dequant(x)
return x
model_fp32 = M()
model_fp32.eval()
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
input = torch.randn(4, 1, 4, 4)
self._compare_script_and_mobile(model=model_int8, input=input)
def test_bundled_input_with_dynamic_type(self):
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(
self,
x: Dict[int, torch.Tensor],
y: Dict[int, torch.Tensor],
z: Dict[int, torch.Tensor],
):
return x
model = Model()
script_module = torch.jit.script(model)
sample_input = {
script_module.forward: [
(
{0: torch.ones(1)},
{1: torch.ones(1)},
{2: torch.ones(1)},
)
]
}
bundled_model = torch.utils.bundled_inputs.bundle_inputs(
script_module, sample_input
)
buf = bundled_model._save_to_buffer_for_lite_interpreter()
mobile_module = _load_for_lite_interpreter(io.BytesIO(buf))
i = mobile_module.run_method("get_all_bundled_inputs")
self.assertEqual(
i[0],
(
{0: torch.ones(1)},
{1: torch.ones(1)},
{2: torch.ones(1)},
),
)
if __name__ == "__main__":
    # Executed directly as a script: hand control to PyTorch's shared test runner.
    run_tests()