Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary:
Generally, wildcard imports are bad for the reasons described here: https://www.flake8rules.com/rules/F403.html

This PR replaces wildcard imports with an explicit list of imported items where possible, and adds a `# noqa: F403` comment in the other cases (mostly re-exports in `__init__.py` files).

This is a prerequisite for https://github.com/pytorch/pytorch/issues/55816, because currently [`tools/codegen/dest/register_dispatch_key.py` simply fails if you sort its imports](https://github.com/pytorch/pytorch/actions/runs/742505908).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55838

Test Plan: CI. You can also run `flake8` locally.

Reviewed By: jbschlosser

Differential Revision: D27724232

Pulled By: samestep

fbshipit-source-id: 269fb09cb4168f8a51fd65bfaacc6cda7fb87c34
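As a minimal illustration of the pattern this PR applies (the module and names below are hypothetical examples, not taken from the diff):

# Before: flake8 F403 — the set of imported names is implicit
from torch.nn.functional import *

# After, where possible: an explicit import list
from torch.nn.functional import relu, softmax

# After, for deliberate re-exports (e.g. in an __init__.py): keep the
# wildcard but silence the check explicitly
from torch.nn.functional import *  # noqa: F403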
379 lines · 15 KiB · Python
import torch
import torch.utils.bundled_inputs
from torch.utils.mobile_optimizer import optimize_for_mobile
import io
from typing import Dict, List, NamedTuple
from collections import namedtuple

from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list
from torch.testing._internal.common_utils import TestCase, run_tests


class TestLiteScriptModule(TestCase):

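    # Round-trip check: script a module, save it for the lite interpreter,
    # reload it, and verify that calling the module, forward(), and
    # run_method("forward") all reproduce the ScriptModule's output.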
    def test_load_mobile_module(self):
        class MyTestModule(torch.nn.Module):
            def __init__(self):
                super(MyTestModule, self).__init__()

            def forward(self, x):
                return x + 10

        input = torch.tensor([1])

        script_module = torch.jit.script(MyTestModule())
        script_module_result = script_module(input)

        buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)

        mobile_module_result = mobile_module(input)
        torch.testing.assert_allclose(script_module_result, mobile_module_result)

        mobile_module_forward_result = mobile_module.forward(input)
        torch.testing.assert_allclose(script_module_result, mobile_module_forward_result)

        mobile_module_run_method_result = mobile_module.run_method("forward", input)
        torch.testing.assert_allclose(script_module_result, mobile_module_run_method_result)

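    # Saving a traced module with _save_mobile_debug_info=True should embed
    # mobile_debug.pkl with fully qualified debug names for every submodule.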
    def test_save_mobile_module_with_debug_info_with_trace(self):
        class A(torch.nn.Module):
            def __init__(self):
                super(A, self).__init__()

            def forward(self, x):
                return x + 1

        class B(torch.nn.Module):
            def __init__(self):
                super(B, self).__init__()
                self.A0 = A()
                self.A1 = A()

            def forward(self, x):
                return self.A0(x) + self.A1(x)

        input = torch.tensor([5])
        trace_module = torch.jit.trace(B(), input)
        exported_module = trace_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True)

        assert(b"mobile_debug.pkl" in exported_module)
        assert(b"module_debug_info" in exported_module)
        assert(b"top(B).forward" in exported_module)
        assert(b"top(B).A0(A).forward" in exported_module)
        assert(b"top(B).A1(A).forward" in exported_module)

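    # Same debug-info check for a scripted module that instantiates the same
    # submodule class twice; A0 and A1 must get distinct debug entries.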
    def test_save_mobile_module_with_debug_info_with_script_duplicate_class(self):
        class A(torch.nn.Module):
            def __init__(self):
                super(A, self).__init__()

            def forward(self, x):
                return x + 1

        class B(torch.nn.Module):
            def __init__(self):
                super(B, self).__init__()
                self.A0 = A()
                self.A1 = A()

            def forward(self, x):
                return self.A0(x) + self.A1(x)

        scripted_module = torch.jit.script(B())
        exported_module = scripted_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True)

        assert(b"mobile_debug.pkl" in exported_module)
        assert(b"module_debug_info" in exported_module)
        assert(b"top(B).forward" in exported_module)
        assert(b"top(B).A0(A).forward" in exported_module)
        assert(b"top(B).A1(A).forward" in exported_module)

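    # Debug info should survive a nested submodule call (C calls A0 on the
    # result of B0) and should also be preserved through optimize_for_mobile.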
    def test_save_mobile_module_with_debug_info_with_script_nested_call(self):
        class A(torch.nn.Module):
            def __init__(self):
                super(A, self).__init__()

            def forward(self, x):
                return x + 1

        class B(torch.nn.Module):
            def __init__(self):
                super(B, self).__init__()

            def forward(self, x):
                return x + 2

        class C(torch.nn.Module):
            def __init__(self):
                super(C, self).__init__()
                self.A0 = A()
                self.B0 = B()

            def forward(self, x):
                return self.A0(self.B0(x)) + 1

        scripted_module = torch.jit.script(C())

        optimized_scripted_module = optimize_for_mobile(scripted_module)

        exported_module = scripted_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True)
        optimized_exported_module = optimized_scripted_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True)

        assert(b"mobile_debug.pkl" in exported_module)
        assert(b"module_debug_info" in exported_module)
        assert(b"top(C).forward" in exported_module)
        assert(b"top(C).A0(A).forward" in exported_module)
        assert(b"top(C).B0(B).forward" in exported_module)

        assert(b"mobile_debug.pkl" in optimized_exported_module)
        assert(b"module_debug_info" in optimized_exported_module)
        assert(b"top(C).forward" in optimized_exported_module)
        assert(b"top(C).A0(A).forward" in optimized_exported_module)
        assert(b"top(C).B0(B).forward" in optimized_exported_module)

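    # Loading a module saved with debug info must not change its behavior:
    # all three invocation paths still match the original ScriptModule.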
    def test_load_mobile_module_with_debug_info(self):
        class MyTestModule(torch.nn.Module):
            def __init__(self):
                super(MyTestModule, self).__init__()

            def forward(self, x):
                return x + 5

        input = torch.tensor([3])

        script_module = torch.jit.script(MyTestModule())
        script_module_result = script_module(input)

        buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True))
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)

        mobile_module_result = mobile_module(input)
        torch.testing.assert_allclose(script_module_result, mobile_module_result)

        mobile_module_forward_result = mobile_module.forward(input)
        torch.testing.assert_allclose(script_module_result, mobile_module_forward_result)

        mobile_module_run_method_result = mobile_module.run_method("forward", input)
        torch.testing.assert_allclose(script_module_result, mobile_module_run_method_result)

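    # find_method() reports whether a named method exists on the loaded
    # module; after augment_model_with_bundled_inputs, get_all_bundled_inputs
    # is present and callable via run_method().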
    def test_find_and_run_method(self):
        class MyTestModule(torch.nn.Module):
            def forward(self, arg):
                return arg

        input = (torch.tensor([1]), )

        script_module = torch.jit.script(MyTestModule())
        script_module_result = script_module(*input)

        buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)

        has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs")
        self.assertFalse(has_bundled_inputs)

        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            script_module, [input], [])

        buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)

        has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs")
        self.assertTrue(has_bundled_inputs)

        bundled_inputs = mobile_module.run_method("get_all_bundled_inputs")
        mobile_module_result = mobile_module.forward(*bundled_inputs[0])
        torch.testing.assert_allclose(script_module_result, mobile_module_result)

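    # Optional/default arguments must round-trip through the lite
    # interpreter, both for Python-to-script calls (B.forward) and for
    # script-to-script calls (A.forward).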
    def test_method_calls_with_optional_arg(self):
        class A(torch.nn.Module):
            def __init__(self):
                super(A, self).__init__()

            # optional arg in a script-to-script invocation
            def forward(self, x, two: int = 2):
                return x + two

        class B(torch.nn.Module):
            def __init__(self):
                super(B, self).__init__()
                self.A0 = A()

            # optional arg in a Python-to-script invocation
            def forward(self, x, one: int = 1):
                return self.A0(x) + one

        script_module = torch.jit.script(B())
        buffer = io.BytesIO(
            script_module._save_to_buffer_for_lite_interpreter()
        )
        mobile_module = _load_for_lite_interpreter(buffer)

        input = torch.tensor([5])
        script_module_forward_result = script_module.forward(input)
        mobile_module_forward_result = mobile_module.forward(input)
        torch.testing.assert_allclose(
            script_module_forward_result,
            mobile_module_forward_result
        )

        # update only the reference result, so the two no longer match
        script_module_forward_result = script_module.forward(input, 2)
        self.assertFalse(
            (script_module_forward_result == mobile_module_forward_result)
            .all()
            .item()
        )

        # now both match again
        mobile_module_forward_result = mobile_module.forward(input, 2)
        torch.testing.assert_allclose(
            script_module_forward_result,
            mobile_module_forward_result
        )

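    # Exporting a module that instantiates an arbitrary (non-nn.Module)
    # class in forward() should fail with an actionable error message.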
    def test_unsupported_classtype(self):
        class Foo():
            def __init__(self):
                return

            def func(self, x: int, y: int):
                return x + y

        class MyTestModule(torch.nn.Module):
            def forward(self, arg):
                f = Foo()
                return f.func(1, 2)

        script_module = torch.jit.script(MyTestModule())
        with self.assertRaisesRegex(RuntimeError,
                                    r"Workaround: instead of using arbitrary class type \(class Foo\(\)\), "
                                    r"define a pytorch class \(class Foo\(torch\.nn\.Module\)\)\.$"):
            script_module._save_to_buffer_for_lite_interpreter()

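    # typing.NamedTuple return values are rejected by the mobile exporter.
    # The "itmes" spelling in the regex below presumably mirrors a typo in
    # the runtime error message, so it is left as-is to keep the match exact.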
    def test_unsupported_return_typing_namedtuple(self):
        myNamedTuple = NamedTuple('myNamedTuple', [('a', torch.Tensor)])

        class MyTestModule(torch.nn.Module):
            def forward(self):
                return myNamedTuple(torch.randn(1))

        script_module = torch.jit.script(MyTestModule())
        with self.assertRaisesRegex(RuntimeError,
                                    r"A named tuple type is not supported in mobile module. "
                                    r"Workaround: instead of using a named tuple type\'s fields, "
                                    r"use a dictionary type\'s key-value pair itmes or "
                                    r"a pytorch class \(class Foo\(torch\.nn\.Module\)\)\'s attributes."):
            script_module._save_to_buffer_for_lite_interpreter()

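    # collections.namedtuple return values hit the same exporter restriction
    # as typing.NamedTuple above.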
    def test_unsupported_return_collections_namedtuple(self):
        myNamedTuple = namedtuple('myNamedTuple', ['a'])

        class MyTestModule(torch.nn.Module):
            def forward(self):
                return myNamedTuple(torch.randn(1))

        script_module = torch.jit.script(MyTestModule())
        with self.assertRaisesRegex(RuntimeError,
                                    r"A named tuple type is not supported in mobile module. "
                                    r"Workaround: instead of using a named tuple type\'s fields, "
                                    r"use a dictionary type\'s key-value pair itmes or "
                                    r"a pytorch class \(class Foo\(torch\.nn\.Module\)\)\'s attributes."):
            script_module._save_to_buffer_for_lite_interpreter()

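    # Returning List[Foo] where Foo is an nn.Module is rejected. As above,
    # the "Returining" spelling in the regex appears to mirror the runtime
    # message verbatim, so it is kept for the match.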
    def test_unsupported_return_list_with_module_class(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super(Foo, self).__init__()

        class MyTestModuleForListWithModuleClass(torch.nn.Module):
            def __init__(self):
                super(MyTestModuleForListWithModuleClass, self).__init__()
                self.foo = Foo()

            def forward(self):
                my_list: List[Foo] = [self.foo]
                return my_list

        script_module = torch.jit.script(MyTestModuleForListWithModuleClass())
        with self.assertRaisesRegex(RuntimeError,
                                    r"^Returining a list or dictionary with pytorch class type "
                                    r"is not supported in mobile module "
                                    r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. "
                                    r"Workaround\: instead of using pytorch class as their element type\, "
                                    r"use a combination of list\, dictionary\, and single types\.$"):
            script_module._save_to_buffer_for_lite_interpreter()

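    # Dict[int, Foo] with an nn.Module value type is rejected the same way.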
    def test_unsupported_return_dict_with_module_class(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super(Foo, self).__init__()

        class MyTestModuleForDictWithModuleClass(torch.nn.Module):
            def __init__(self):
                super(MyTestModuleForDictWithModuleClass, self).__init__()
                self.foo = Foo()

            def forward(self):
                my_dict: Dict[int, Foo] = {1: self.foo}
                return my_dict

        script_module = torch.jit.script(MyTestModuleForDictWithModuleClass())
        with self.assertRaisesRegex(RuntimeError,
                                    r"^Returining a list or dictionary with pytorch class type "
                                    r"is not supported in mobile module "
                                    r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. "
                                    r"Workaround\: instead of using pytorch class as their element type\, "
                                    r"use a combination of list\, dictionary\, and single types\.$"):
            script_module._save_to_buffer_for_lite_interpreter()

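    # _export_operator_list should report exactly the operator set the
    # module uses: zeros, empty/empty_like, and _convolution.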
    def test_module_export_operator_list(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super(Foo, self).__init__()
                self.weight = torch.ones((20, 1, 5, 5))
                self.bias = torch.ones(20)

            def forward(self, input):
                x1 = torch.zeros(2, 2)
                x2 = torch.empty_like(torch.empty(2, 2))
                x3 = torch._convolution(
                    input,
                    self.weight,
                    self.bias,
                    [1, 1],
                    [0, 0],
                    [1, 1],
                    False,
                    [0, 0],
                    1,
                    False,
                    False,
                    True,
                    True,
                )
                return (x1, x2, x3)

        m = torch.jit.script(Foo())

        buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)

        expected_ops = {
            "aten::_convolution",
            "aten::empty.memory_format",
            "aten::empty_like",
            "aten::zeros",
        }
        actual_ops = _export_operator_list(mobile_module)
        self.assertEqual(actual_ops, expected_ops)


if __name__ == '__main__':
    run_tests()