mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Although the sun is setting for torchscript, it is not [officially deprecated](https://github.com/pytorch/pytorch/issues/103841#issuecomment-1605017153) since nothing currently fully replaces it. Thus, "downstream" libraries like TorchVision that started offering torchscript support still need to support it for BC. torchscript has forced us to use workaround after workaround since forever. Although this makes the code harder to read and maintain, we made our peace with it. However, we are currently looking into more elaborate API designs that are severely hampered by our torchscript BC guarantees. Although likely not intended as such, while looking for ways to enable our design while keeping a subset of it scriptable, we found the undocumented `__prepare_scriptable__` escape hatch (`torch/jit/_script.py`, L977 at commit 0cf918947d). One can define this method and if you call `torch.jit.script` on the object, the returned object of the method will be scripted rather than the original object. In TorchVision we are using exactly this mechanism to enable BC (`torchvision/transforms/v2/_transform.py`, L122-L136 at commit 3966f9558b) while allowing the object in eager mode to be a lot more flexible (`*args, **kwargs`, dynamic dispatch, ...). Unfortunately, this escape hatch is only available for `nn.Module`s (`torch/jit/_script.py`, L1279-L1283 at commit 0cf918947d). This was fine for the example above since we were subclassing from `nn.Module` anyway. However, we recently also hit a case [where this wasn't the case](https://github.com/pytorch/vision/pull/7747#issuecomment-1642045479). Given the frozen state on JIT, would it be possible to give us a general escape hatch so that we can move forward with the design unconstrained while still keeping BC? This PR implements just this by re-using the `__prepare_scriptable__` hook. Pull Request resolved: https://github.com/pytorch/pytorch/pull/106229 Approved by: https://github.com/lezcano, https://github.com/ezyang
776 lines
22 KiB
Python
776 lines
22 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import os
|
|
import re
|
|
import sys
|
|
import types
|
|
import typing
|
|
import typing_extensions
|
|
from typing import List, Dict, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.jit.frontend
|
|
import torch.nn as nn
|
|
from torch import Tensor
|
|
from torch.testing import FileCheck
|
|
from collections import OrderedDict
|
|
|
|
# Make the helper files in test/ importable: resolve this file's real path,
# walk two ``dirname`` steps up from it, and append that directory to
# ``sys.path``.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
|
|
from torch.testing._internal.jit_utils import JitTestCase, _tmp_donotuse_dont_inline_everything
|
|
|
|
# Refuse direct execution and point the user at the entry point named in the
# error message instead.
if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")
|
|
|
|
class TestRecursiveScript(JitTestCase):
|
|
def test_inferred_nonetype(self):
    """Script a module whose only attribute is ``None`` and carries no annotation."""
    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.x = None

        def forward(self):
            assert self.x is None

    # The result is not used afterwards; the call itself verifies that
    # scripting does not raise.
    m = torch.jit.script(M())
    # forward() takes no inputs, hence the empty argument tuple.
    self.checkModule(M(), ())
|
|
|
|
def test_script_function_attribute(self):
    """Store already-scripted functions as module attributes and call them in ``forward``."""
    @torch.jit.script
    def fn1(x):
        return x + x

    @torch.jit.script
    def fn2(x):
        return x - x

    class M(torch.nn.Module):
        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            return self.fn(x)

    # Two instances of the same Python class, each holding a different
    # script function.
    fn1_mod = M(fn1)
    fn2_mod = M(fn2)

    self.checkModule(fn1_mod, (torch.randn(2, 2),))
    self.checkModule(fn2_mod, (torch.randn(2, 2),))
|
|
|
|
def test_python_function_attribute(self):
    """Store a non-scripted callable (``torch.sigmoid``) as a module attribute and call it in ``forward``."""
    class M(torch.nn.Module):
        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            return self.fn(x)

    mod = M(torch.sigmoid)

    self.checkModule(mod, (torch.randn(2, 2),))
|
|
|
|
def test_failed_function_compilation(self):
    """Scripting a module whose function attribute cannot be compiled must raise.

    Expects a ``RuntimeError`` matching "failed to compile"; the highlight
    argument passed to the assertion helper is ``i_dont_exist``.
    """
    def fn(x):
        # Deliberately references a name that is not defined anywhere.
        return i_dont_exist

    class M(torch.nn.Module):
        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            return self.fn(x)

    m = M(fn)
    with self.assertRaisesRegexWithHighlight(RuntimeError, "failed to compile", "i_dont_exist"):
        torch.jit.script(m)
|
|
|
|
def test_init_error(self):
    """Scripting a module that was never initialized must raise.

    Expects a ``RuntimeError`` matching "has not been initialized".
    """
    class M(nn.Module):
        def __init__(self):
            # Intentionally does NOT call super().__init__().
            self.x = 2

        def forward(self):
            pass

    with self.assertRaisesRegex(RuntimeError, "has not been initialized"):
        torch.jit.script(M())
|
|
|
|
def test_script_after_eval(self):
    """The ``training`` flag of a scripted module reflects the eager module's mode at scripting time.

    The same eager module is scripted once before and once after ``eval()``;
    each scripted copy must keep the flag it was created with.
    """
    class M(nn.Module):
        def forward(self):
            if self.training:
                return 2
            else:
                return 0

    m = M()
    sm1 = torch.jit.script(m)
    m.eval()
    sm2 = torch.jit.script(m)

    # m is in eval mode, training should be False
    self.assertFalse(m.training)

    # sm1 was created while m had training = True
    self.assertTrue(sm1.training)
    # The Python-side flag must agree with the attribute stored on the
    # underlying ``_c`` object.
    self.assertEqual(sm1.training, sm1._c.getattr('training'))
    self.assertEqual(sm1(), 2)

    # sm2 was created after m was eval'ed
    self.assertFalse(sm2.training)
    self.assertEqual(sm2.training, sm2._c.getattr('training'))
    self.assertEqual(sm2(), 0)
|
|
|
|
def test_module_name(self):
    """The Python class name ``MyModule`` must appear in the scripted module's graph."""
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.x = 2

        def forward(self, t):
            return t + self.x

    m = torch.jit.script(MyModule())
    FileCheck().check("MyModule").run(m.graph)
|
|
|
|
def test_repeated_error_stack(self):
    """Compiling the same failing call chain twice yields the same error stack.

    ``a -> b -> c -> d`` fails inside ``d``. The error text must contain
    "is being compiled" the same number of times on both attempts, i.e. no
    entries may be left over from the previous failure.

    ``assertRaises`` is used instead of a bare ``try/except`` so the test
    fails if scripting unexpectedly succeeds rather than passing without
    checking anything.
    """
    def d(x):
        return "a" - 2

    def c(x):
        return d(x)

    def b(x):
        return c(x)

    def a(x):
        return b(x)

    with self.assertRaises(Exception) as first_failure:
        torch.jit.script(a)
    FileCheck().check_count("is being compiled", 2).run(str(first_failure.exception))

    with self.assertRaises(Exception) as second_failure:
        torch.jit.script(a)
    # Make sure that no entries are left over from the previous failure
    FileCheck().check_count("is being compiled", 2).run(str(second_failure.exception))
|
|
|
|
def test_constants_with_final(self):
    """Run modules whose attribute is annotated ``Final[int]`` through ``checkModule``.

    The three otherwise identical modules differ only in where ``Final``
    comes from: ``torch.jit``, ``typing_extensions`` and ``typing``.
    """
    class M1(torch.nn.Module):
        x : torch.jit.Final[int]

        def __init__(self):
            super().__init__()
            self.x = 2

        def forward(self, t):
            return t + self.x

    self.checkModule(M1(), (torch.randn(2, 2),))

    class M2(torch.nn.Module):
        x : typing_extensions.Final[int]

        def __init__(self):
            super().__init__()
            self.x = 2

        def forward(self, t):
            return t + self.x

    self.checkModule(M2(), (torch.randn(2, 2),))

    class M3(torch.nn.Module):
        x : typing.Final[int]

        def __init__(self):
            super().__init__()
            self.x = 2

        def forward(self, t):
            return t + self.x

    self.checkModule(M3(), (torch.randn(2, 2),))
|
|
|
|
def test_ignore_class(self):
    """Instantiating a ``@torch.jit.ignore``-decorated class inside ``forward`` must be rejected.

    Expects a ``FrontendError`` matching "Cannot instantiate class"; the
    highlight argument passed to the assertion helper is ``MyScriptClass``.
    """
    @torch.jit.ignore
    class MyScriptClass:
        def unscriptable(self):
            return "a" + 200


    class TestModule(torch.nn.Module):
        def forward(self, x):
            return MyScriptClass()

    with self.assertRaisesRegexWithHighlight(torch.jit.frontend.FrontendError, "Cannot instantiate class", "MyScriptClass"):
        t = torch.jit.script(TestModule())
|
|
|
|
def test_method_call(self):
    """``forward`` calls another, undecorated method (``test``) of the same module."""
    class M(nn.Module):
        def test(self, x):
            return x

        def forward(self, z):
            y = self.test(z)
            return z + 20 + y

    self.checkModule(M(), (torch.randn(2, 2),))
|
|
|
|
def test_module_repr(self):
    """Printing a scripted module names the module and each of its submodules.

    The captured output must mention, in order, ``MyModule``, ``Conv2d``,
    ``Linear`` and ``Submodule``; ``original_name`` must be ``MyModule``.
    """
    class Submodule(nn.Module):
        def forward(self, x):
            return x

    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(10, 10, 3)
            self.lin = nn.Linear(10, 10)
            self.sub = Submodule()

        def forward(self, x):
            return self.lin(x) + self.sub(x) + self.conv(x)

    scripted = torch.jit.script(MyModule())

    with self.capture_stdout() as captured:
        print(scripted)

    # Register the expected names in the order they were checked originally.
    checker = FileCheck()
    for expected_name in ('MyModule', 'Conv2d', 'Linear', 'Submodule'):
        checker.check(expected_name)
    checker.run(captured[0])

    self.assertEqual(scripted.original_name, 'MyModule')
|
|
|
|
def test_dir(self):
    """Every name in ``dir(eager_module)`` is also in ``dir(scripted_module)``.

    A fixed list of names is exempt from the comparison. The check runs for
    a plain module and for two container modules.
    """
    def test_module_dir(mod):
        # Snapshot the eager names before scripting, as the original did.
        eager_names = dir(mod)
        scripted_names = set(dir(torch.jit.script(mod)))
        # Names that are not currently copied over to the scripted module.
        exempt = ["training", "__delitem__", "__setitem__", "clear", "items",
                  "keys", "pop", "update", "values"]
        for name in eager_names:
            if name not in exempt:
                self.assertTrue(name in scripted_names, name)

    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(10, 10, 3)
            self.lin = nn.Linear(10, 10)

        def forward(self, x):
            return self.lin(x) + self.conv(x)

    test_module_dir(MyModule())

    # test custom __dir__ for containers
    conv = nn.Conv2d(10, 10, 3)
    linear = nn.Linear(10, 10)
    named_children = OrderedDict([("conv", conv), ("linear", linear)])

    test_module_dir(nn.Sequential(conv, linear))
    test_module_dir(nn.ModuleDict(named_children))
|
|
|
|
def test_class_compile(self):
    """A plain Python class is instantiated inside ``forward`` and its method calls a free function."""
    def other_fn(a: int, b: Tensor) -> Tensor:
        return a * b

    class B:
        def __init__(self, x):
            # The constructor argument is not stored; ``x`` is always 2.
            self.x = 2

        def helper(self, a):
            return self.x + a + other_fn(self.x, a)


    class N(torch.nn.Module):
        def forward(self, x):
            b = B(x)
            return b.helper(x)

    self.checkModule(N(), (torch.randn(2, 2),))
|
|
|
|
def test_error_stack(self):
    """A type error deep in a call chain reports every function on the way there.

    ``a -> b -> c`` where ``c`` calls ``d("hello")`` although ``d`` is
    annotated to take an ``int``. The error text must mention the type
    mismatch followed by the definitions of ``c``, ``b`` and ``a``.

    ``assertRaises`` is used instead of a bare ``try/except`` so the test
    fails if scripting unexpectedly succeeds rather than passing without
    checking anything.
    """
    def d(x: int) -> int:
        return x + 10

    def c(x):
        return d("hello") + d(x)

    def b(x):
        return c(x)

    def a(x):
        return b(x)

    with self.assertRaises(RuntimeError) as failure:
        torch.jit.script(a)

    checker = FileCheck()
    checker.check("Expected a value of type 'int'")
    checker.check("def c(x)")
    checker.check("def b(x)")
    checker.check("def a(x)")
    checker.run(str(failure.exception))
|
|
|
|
def test_error_stack_module(self):
    """A type error reached through module methods reports the compilation chain.

    ``M.forward -> M.some_method -> Submodule.forward -> b -> c`` where
    ``c`` calls ``d("hello")`` although ``d`` is annotated to take an
    ``int``. The error text must mention the type mismatch and the
    "is being compiled since it was called from" notes for ``c`` and ``b``.

    ``assertRaises`` is used instead of a bare ``try/except`` so the test
    fails if scripting unexpectedly succeeds rather than passing without
    checking anything.
    """
    def d(x: int) -> int:
        return x + 10

    def c(x):
        return d("hello") + d(x)

    def b(x):
        return c(x)

    class Submodule(torch.nn.Module):
        def forward(self, x):
            return b(x)

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.submodule = Submodule()

        def some_method(self, y):
            return y + self.submodule(y)

        def forward(self, x):
            return self.some_method(x)

    with self.assertRaises(RuntimeError) as failure:
        torch.jit.script(M())

    checker = FileCheck()
    checker.check("Expected a value of type 'int'")
    checker.check("'c' is being compiled since it was called from 'b'")
    checker.check("'b' is being compiled since it was called from")
    checker.run(str(failure.exception))
|
|
|
|
# NOTE(review): the decorator is defined in jit_utils; presumably it disables
# inlining so the call below stays visible as prim::CallFunction — confirm.
@_tmp_donotuse_dont_inline_everything
def test_script_basic(self):
    """A script function calling a plain Python function compiles the callee too.

    The graph must contain ``prim::CallFunction``, must not match the
    pattern ``^a_python_fn``, and the result must equal the eager sum.
    """
    def a_python_fn(a, b, c):
        return a + b + c

    @torch.jit.script
    def a_script_fn(d, e, f):
        return a_python_fn(d, e, f)

    graph = str(a_script_fn.graph)
    FileCheck().check("prim::CallFunction").run(graph)
    FileCheck().check_not("^a_python_fn").run(graph)
    t = torch.ones(2, 2)
    self.assertEqual(a_script_fn(t, t, t), t + t + t)
|
|
|
|
def test_error_stack_class(self):
    """An unsupported statement inside a class method is reported with the compilation chain.

    ``fn`` constructs ``X``, whose ``bad_fn`` contains an ``import``
    statement. The error text must mention "import statements" and the
    "is being compiled since it was called from" note.

    ``assertRaises`` is used instead of a bare ``try/except`` so the test
    fails if scripting unexpectedly succeeds rather than passing without
    checking anything.
    """
    class X:
        def bad_fn(self):
            import pdb  # noqa: F401

    def fn(x) -> X:
        return X(10)

    with self.assertRaises(Exception) as failure:
        torch.jit.script(fn)

    checker = FileCheck()
    checker.check("import statements")
    checker.check("is being compiled since it was called from")
    checker.run(str(failure.exception))
|
|
|
|
def test_error_stack_annotation(self):
    """An unsupported statement in a class used as a return annotation points at the annotation.

    ``fn`` is annotated ``-> X`` and ``X.bad_fn`` contains an ``import``
    statement. The error text must mention "import statements", the
    "is being compiled since it was called from" note, and ``-> X``.

    ``assertRaises`` is used instead of a bare ``try/except`` so the test
    fails if scripting unexpectedly succeeds rather than passing without
    checking anything.
    """
    class X:
        def bad_fn(self):
            import pdb  # noqa: F401

    def fn(x) -> X:
        return X(10)

    with self.assertRaises(Exception) as failure:
        torch.jit.script(fn)

    checker = FileCheck()
    checker.check("import statements")
    checker.check("is being compiled since it was called from")
    checker.check("-> X")
    checker.run(str(failure.exception))
|
|
|
|
def test_module_basic(self):
    """A module with a constant, a parameter and a submodule goes through ``checkModule``."""
    class Other(torch.nn.Module):
        __constants__ = ['x']

        def __init__(self, x):
            super().__init__()
            self.x = x
            self.param = torch.nn.Parameter(torch.ones(2, 2))

        # Rebinds ``a`` from an int to a list and is never called from
        # ``forward``.
        # NOTE(review): presumably it is therefore never compiled — confirm.
        def some_unscriptable_method(self):
            a = 2
            a = [2]
            return a

        def forward(self, t):
            return t + self.x + self.param


    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.other = Other(200)

        def forward(self, t):
            return self.other(t) * 2

    self.checkModule(M(), (torch.ones(2, 2),))
|
|
|
|
def test_module_function_export(self):
    """A submodule with a ``@torch.jit.export`` method goes through ``checkModule``.

    The exported method ``some_entry_point`` is not called from ``forward``
    and is not invoked directly by this test.
    """
    class Other(torch.nn.Module):
        __constants__ = ['x']

        def __init__(self, x):
            super().__init__()
            self.x = x
            self.param = torch.nn.Parameter(torch.ones(2, 2))

        @torch.jit.export
        def some_entry_point(self, y):
            return y + 20

        def forward(self, t):
            return t + self.x + self.param


    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.other = Other(200)

        def forward(self, t):
            return self.other(t) * 2

    self.checkModule(M(), (torch.ones(2, 2),))
|
|
|
|
def test_iterable_modules(self):
    """``forward`` iterates an ``nn.ModuleList`` and calls a nested ``nn.Sequential``."""
    class Inner(torch.nn.Module):
        def forward(self, x):
            return x + 10

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # The third entry is itself a Sequential, so containers nest.
            self.sequential = nn.Sequential(
                Inner(),
                Inner(),
                nn.Sequential(Inner(), Inner())
            )
            self.module_list = nn.ModuleList([Inner(), Inner()])

        def forward(self, x):
            for mod in self.module_list:
                x += mod(x)
            x += self.sequential(x)
            return x

    self.checkModule(M(), (torch.randn(5, 5),))
|
|
|
|
def test_prepare_scriptable_basic(self):
    """``__prepare_scriptable__`` on a module substitutes the object that gets scripted.

    The eager module is a SELU, but its hook returns a ReLU, so the eager
    and scripted outputs for the same input are expected to differ.
    """
    class SeluButReluWhenScripted(torch.nn.SELU):
        def __prepare_scriptable__(self):
            return nn.ReLU()

    t = torch.randn(5, 5)
    m = SeluButReluWhenScripted()
    sm = torch.jit.script(m)
    eager_out = m(t)
    script_out = sm(t)
    self.assertNotEqual(eager_out, script_out)
|
|
|
|
def test_prepare_scriptable_iterable_modules(self):
    """``__prepare_scriptable__`` on modules held inside containers.

    Instances of the hooked module sit in an ``nn.Sequential``, a nested
    ``nn.Sequential`` and an ``nn.ModuleList``; one instance (``shared``)
    appears in all three. Eager and scripted outputs are expected to differ.
    """
    class SeluButReluWhenScripted(torch.nn.SELU):
        def __prepare_scriptable__(self):
            return nn.ReLU()

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # The same instance is referenced from several containers below.
            shared = SeluButReluWhenScripted()
            self.sequential = nn.Sequential(
                SeluButReluWhenScripted(),
                SeluButReluWhenScripted(),
                nn.Sequential(SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()),
                shared,
            )
            self.module_list = nn.ModuleList([SeluButReluWhenScripted(),
                                              shared,
                                              SeluButReluWhenScripted()])

        def forward(self, x):
            for mod in self.module_list:
                x += mod(x)
            x += self.sequential(x)
            return x

    t = torch.randn(5, 5)
    m = M()
    # forward() updates its input in place (``x += ...``), so each call
    # gets its own clone of ``t``.
    eager_out = m(t.clone())
    sm = torch.jit.script(m)
    script_out = sm(t.clone())
    self.assertNotEqual(eager_out, script_out)
|
|
|
|
def test_prepare_scriptable_cycle(self):
    """Scripting a module that is part of a reference cycle must not raise.

    ``p`` and ``c`` reference each other through entries written directly
    into their ``__dict__``. Only successful completion of
    ``torch.jit.script`` is checked; no assertion is made on the result.
    """
    t = torch.randn(5, 5)
    c = torch.nn.Module()
    p = torch.nn.Module()
    # Writing to __dict__ bypasses the modules' __setattr__.
    c.__dict__["_p"] = p
    p.__dict__["_c"] = c

    sm = torch.jit.script(p)
|
|
|
|
def test_prepare_scriptable_escape_hatch(self):
    """``__prepare_scriptable__`` works on objects that are not ``nn.Module`` instances.

    A callable class taking ``*args`` cannot be scripted directly. A
    subclass whose hook returns a fixed-arity function can be scripted,
    and the scripted object then accepts exactly two arguments.
    """
    class NonJitableClass:
        def __call__(self, int1, int2, *args):
            total = int1 + int2
            for arg in args:
                total += arg
            return total

    obj = NonJitableClass()

    # Eager behaviour: any number of extra arguments is summed.
    self.assertEqual(obj(1, 2), 3)
    self.assertEqual(obj(1, 2, 3, 4), 10)
    # Without the hook, scripting is rejected because of ``*args``.
    with self.assertRaisesRegex(
        torch.jit.frontend.NotSupportedError, expected_regex="can't take variable number of arguments"
    ):
        torch.jit.script(obj)

    def escape_hatch(int1: int, int2: int) -> int:
        return int1 + int2

    class NonJitableClassWithEscapeHatch(NonJitableClass):
        def __prepare_scriptable__(self):
            return escape_hatch

    jit_obj = torch.jit.script(NonJitableClassWithEscapeHatch())

    self.assertEqual(jit_obj(1, 2), 3)
    # The scripted object has the two-argument signature of ``escape_hatch``.
    with self.assertRaisesRegex(
        RuntimeError, expected_regex=re.escape("expected at most 2 argument(s) but received 4 argument(s)")
    ):
        jit_obj(1, 2, 3, 4)
|
|
|
|
def test_attributes(self):
    """Attributes of many Python types set on a module instance go through ``checkModule``.

    ``untyped_values`` are set without any annotation. ``typed_values``
    (empty containers, ``None`` and script-class instances) get their types
    from ``M.__annotations__``, which is assigned after the class body.
    """
    @torch.jit.script
    class Inner2:
        def __init__(self):
            self.b = "a string"

    @torch.jit.script
    class Foo:
        def __init__(self):
            self.a = 4
            self.inner = Inner2()

    # Same fields as Foo, plus explicit __getstate__/__setstate__.
    @torch.jit.script
    class SFoo:
        def __init__(self):
            self.a = 4
            self.inner = Inner2()

        def __setstate__(self, obj: Tuple[int, Inner2]) -> None:
            a, inner = obj
            self.a = a
            self.inner = inner

        def __getstate__(self):
            return (self.a, self.inner)


    untyped_values = (
        ('my_dict', {"I": "am", "a test": "test"}),
        ('my_float', 2.3),
        ('my_int', 99),
        ('my_bool', False),
        ('my_tuple', (1, 2, 3, 4)),
        ('my_list', [(1, 2), (3, 4)]),
        # ('my_tensor', torch.randn(2, 2)),
        ('my_int_list', [1, 2, 3, 4]),
        # ('my_tensor_list', [torch.ones(2, 2) + i for i in range(4)]),
        ('my_bool_list', [True, True, False, True]),
        ('my_float_list', [1., 2., 3., 4.]),
        ('my_str_list', ['hello', 'bye']),
    )
    typed_values = (
        ('my_empty_list', []),
        ('my_empty_dict', {}),
        ('my_none', None),
        ('my_object', Foo()),
        ('my_object2', SFoo()),
    )

    # ``my_tuple`` and ``my_list`` are set on the instance but are not part
    # of the tuple returned by ``forward``.
    # NOTE(review): ``self.my_object.a`` is returned twice and
    # ``self.my_object2.a`` never; the second occurrence looks like it was
    # meant to be ``self.my_object2.a`` — confirm before changing.
    class M(torch.nn.Module):
        # TODO: re-enable this once this test is in a Python 3-only syntax
        # file
        # my_empty_list : List[int]
        # my_empty_dict : Dict[str, int]
        # my_none : Optional[int]

        def forward(self, x):
            return (
                self.my_dict,
                self.my_float,
                self.my_int,
                self.my_bool,
                # self.my_tensor,
                self.my_int_list,
                # self.my_tensor_list,
                self.my_bool_list,
                self.my_float_list,
                self.my_str_list,
                self.my_empty_list,
                self.my_empty_dict,
                self.my_none,
                self.my_object.a,
                self.my_object.inner.b,
                self.my_object.a,
                self.my_object2.inner.b,
            )

    # TODO: as a followup, fix this test
    # We can't define class attributes like we should be doing:
    #   class M(torch.nn.Module):
    #       my_empty_list : List[int]
    #       my_empty_dict : Dict[str, int]
    #       my_none : Optional[int]
    #       my_out_of_line_attribute: List[int] = [1, 2, 3]
    # since there's no string frontend for Python classes (so the `define`)
    # trick doesn't work.
    M.__annotations__ = {
        'my_empty_list': List[int],
        'my_empty_dict': Dict[str, int],
        'my_none': Optional[int],
        'my_object': Foo,
        'my_object2': SFoo,
    }

    m = M()
    # Attach every value to the instance (not the class) before checking.
    for name, value in untyped_values + typed_values:
        setattr(m, name, value)

    self.checkModule(m, (torch.randn(5, 5),))
|
|
|
|
def test_function_attribute_in_submodule(self):
    """A submodule holds a Python function attribute that its ``forward`` never uses."""
    class N(nn.Module):
        def __init__(self, norm):
            super().__init__()
            # Stored but not referenced in forward().
            self.activation = torch.nn.functional.relu
            self.norm = norm

        def forward(self, src):
            output = src
            output = self.norm(output)
            return output

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            encoder_norm = nn.ReLU()
            self.encoder = N(encoder_norm)

        def forward(self, x):
            return self.encoder(x)

    m = M()
    self.checkModule(m, (torch.randn(5, 5), ))
|
|
|
|
def test_inner_traced_module(self):
    """A traced module placed in an ``nn.ModuleList`` is iterated by the outer ``forward``."""
    class Dummy(nn.Module):
        def forward(self, x):
            return x

    class Model(nn.Module):
        def __init__(self, dummies):
            super().__init__()
            self._dummies = dummies

        def forward(self, x):
            out = []
            for dummy in self._dummies:
                out.append(dummy(x))
            return out

    # The inner module is produced by tracing, not scripting.
    dummy = torch.jit.trace(Dummy(), torch.randn(1, 2))
    dummies = nn.ModuleList([dummy])
    model = Model(dummies)
    self.checkModule(model, (torch.rand(5, 5), ))
|
|
|
|
def test_script_loaded_module(self):
    """
    Test that we can hold a loaded ScriptModule as a submodule.
    """
    class Dummy(nn.Module):
        def forward(self, x):
            return x

    dummy = torch.jit.script(Dummy())
    # Replace the scripted module with the copy returned by the
    # getExportImportCopy helper.
    # NOTE(review): the helper is defined outside this file; presumably it
    # saves and reloads the module — confirm.
    dummy = self.getExportImportCopy(dummy)

    class ContainsLoaded(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = dummy

        def forward(self, input):
            return self.encoder(input)

    self.checkModule(ContainsLoaded(), (torch.rand(2, 3), ))
|
|
|
|
def test_optional_module(self):
    """``forward`` guards a submodule with ``is not None``; checked with the submodule present and absent."""
    class Dummy(nn.Module):
        def __init__(self):
            super().__init__()
            self.foo = nn.Linear(2, 2)

        def forward(self, x):
            if self.foo is not None:
                return self.foo(x)
            return x

    mod = Dummy()
    self.checkModule(mod, (torch.rand(2, 2),))
    # Same instance again, now with the submodule replaced by None.
    mod.foo = None
    self.checkModule(mod, (torch.rand(2, 2),))
|
|
|
|
def test_override_instance_method_ignore(self):
    """An ignored method overridden on a single instance is the one seen after scripting."""
    class M(torch.nn.Module):
        @torch.jit.ignore
        def i_am_ignored(self):
            return "old"

    m = M()

    # Override the ignored method by binding a new method to this instance.
    @torch.jit.ignore
    def i_am_ignored(self):
        return "new"

    m.i_am_ignored = types.MethodType(i_am_ignored, m)
    self.assertEqual(m.i_am_ignored(), "new")

    # ScriptModule should correctly reflect the override.
    s = torch.jit.script(m)
    self.assertEqual(s.i_am_ignored(), "new")
|