pytorch/test/jit/test_recursive_script.py
Michael Suo 58005382c8 fix @property (#28395)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28395

Currently, property methods are broken in TorchScript because the
existing path treats them as attributes: we evaluate the method once
and store the result as the value forever.

Since the lack of property support is easy to work around (just make it
a method), I've opted to raise an explicit error to avoid confusion. If
people want it, they can file an issue and we can look at their use
case.

This also lets us clean up some parts of the ScriptModule conversion
path.

Test Plan: Imported from OSS

Reviewed By: shannonzhu

Differential Revision: D18054946

Pulled By: suo

fbshipit-source-id: 7e927836ae687cd2f13a94b9f0af399437fae422
2019-11-06 23:51:07 -08:00

580 lines
16 KiB
Python

import unittest
import os
import sys
from typing import List, Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.testing import FileCheck
# Make the helper files in test/ importable.
# Two dirname() calls walk up from this file's directory (test/jit/) to its
# parent (test/), which must be on sys.path *before* jit_utils is imported.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from jit_utils import JitTestCase, _tmp_donotuse_dont_inline_everything

# These tests are meant to be collected through test/test_jit.py; refuse to
# run this module as a standalone script.
if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")
class TestRecursiveScript(JitTestCase):
    """Tests for recursive scripting.

    ``torch.jit.script`` is applied to plain Python modules, functions and
    classes; whatever they reference (submodules, attribute functions,
    helper functions, Python classes) has to be compiled along with them.
    """

    def test_inferred_nonetype(self):
        """A module attribute holding None can be scripted and read in forward."""
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.x = None

            def forward(self):
                assert self.x is None

        torch.jit.script(M())
        self.checkModule(M(), ())

    def test_script_function_attribute(self):
        """Already-scripted functions stored on a module are callable from forward."""
        @torch.jit.script
        def fn1(x):
            return x + x

        @torch.jit.script
        def fn2(x):
            return x - x

        class M(torch.nn.Module):
            def __init__(self, fn):
                super(M, self).__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        fn1_mod = M(fn1)
        fn2_mod = M(fn2)

        self.checkModule(fn1_mod, (torch.randn(2, 2),))
        self.checkModule(fn2_mod, (torch.randn(2, 2),))

    def test_python_function_attribute(self):
        """A plain Python function stored on a module is callable from scripted forward."""
        class M(torch.nn.Module):
            def __init__(self, fn):
                super(M, self).__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        mod = M(F.sigmoid)

        self.checkModule(mod, (torch.randn(2, 2),))

    def test_failed_function_compilation(self):
        """An attribute function that cannot compile makes scripting fail."""
        def fn(x):
            return i_dont_exist  # noqa: F821 -- deliberately undefined

        class M(torch.nn.Module):
            def __init__(self, fn):
                super(M, self).__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        m = M(fn)
        with self.assertRaisesRegex(RuntimeError, "failed to compile"):
            torch.jit.script(m)

    def test_init_error(self):
        """Scripting a module whose __init__ skips nn.Module.__init__ is an error."""
        class M(nn.Module):
            def __init__(self):
                # Deliberately no super().__init__() call.
                self.x = 2

            def forward(self):
                pass

        with self.assertRaisesRegex(RuntimeError, "has not been initialized"):
            torch.jit.script(M())

    def test_script_after_eval(self):
        """Scripting snapshots the module's `training` flag at script time."""
        class M(nn.Module):
            def forward(self):
                if self.training:
                    return 2
                else:
                    return 0

        m = M()
        sm1 = torch.jit.script(m)
        m.eval()
        sm2 = torch.jit.script(m)

        # m is in eval mode, training should be False
        self.assertFalse(m.training)

        # sm1 was created while m had training = True
        self.assertTrue(sm1.training)
        self.assertEqual(sm1.training, sm1._c._get_attribute('training'))
        self.assertEqual(sm1(), 2)

        # sm2 was created after m was eval'ed
        self.assertFalse(sm2.training)
        self.assertEqual(sm2.training, sm2._c._get_attribute('training'))
        self.assertEqual(sm2(), 0)

    def test_module_name(self):
        """The scripted graph's class type is named after the Python class."""
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.x = 2

            def forward(self, t):
                return t + self.x

        m = torch.jit.script(MyModule())
        FileCheck().check("ClassType<MyModule>").run(m.graph)

    def test_repeated_error_stack(self):
        """A failed compilation leaves no stale entries in the next error's call stack."""
        def d(x):
            return "a" - 2

        def c(x):
            return d(x)

        def b(x):
            return c(x)

        def a(x):
            return b(x)

        # Script the same failing chain twice. assertRaises makes the test
        # fail if scripting unexpectedly succeeds (a bare try/except would
        # silently pass). The second attempt makes sure that no entries are
        # left over from the previous failure.
        for _ in range(2):
            with self.assertRaises(Exception) as cm:
                torch.jit.script(a)
            FileCheck().check_count("is being compiled", 2).run(str(cm.exception))

    @unittest.skipIf(True, "Class annotations are a thing in > 3.5, need to fix for < 3.7")
    def test_constants_with_final(self):
        """Exercise `torch.jit.Final`-annotated attributes (currently skipped)."""
        class M(torch.nn.Module):
            # TODO: Use this (see below)
            # x : torch.jit.Final[int]

            def __init__(self):
                super(M, self).__init__()
                self.x = 2

            def forward(self, t):
                return t + self.x

        # TODO: Fix this test so that we can actually define the class like
        #   class M(torch.nn.Module):
        #       x : torch.jit.Final[int]
        M.__annotations__ = {'x': torch.jit.Final[int]}

        self.checkModule(M(), (torch.randn(2, 2),))

    def test_ignore_class(self):
        """A class marked @torch.jit.ignore cannot be instantiated from script."""
        @torch.jit.ignore
        class MyScriptClass(object):
            def unscriptable(self):
                return "a" + 200

        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()

            def forward(self, x):
                return MyScriptClass()

        with self.assertRaisesRegex(torch.jit.frontend.FrontendError, "Cannot instantiate class"):
            torch.jit.script(TestModule())

    def test_method_call(self):
        """forward may call another, unannotated method of the same module."""
        class M(nn.Module):
            def test(self, x):
                return x

            def forward(self, z):
                y = self.test(z)
                return z + 20 + y

        self.checkModule(M(), (torch.randn(2, 2),))

    def test_module_repr(self):
        """print() of a scripted module names its submodules; original_name is kept."""
        class Submodule(nn.Module):
            def forward(self, x):
                return x

        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.conv = nn.Conv2d(10, 10, 3)
                self.lin = nn.Linear(10, 10)
                self.sub = Submodule()

            def forward(self, x):
                return self.lin(x) + self.sub(x) + self.conv(x)

        m = torch.jit.script(MyModule())

        with self.capture_stdout() as out:
            print(m)

        f = FileCheck()
        f.check('MyModule')
        f.check('Conv2d')
        f.check('Linear')
        f.check('Submodule')
        f.run(out[0])

        self.assertEqual(m.original_name, 'MyModule')

    def test_class_compile(self):
        """A Python class used in forward is compiled with the functions it calls."""
        def other_fn(a, b):
            # type: (int, Tensor) -> Tensor
            return a * b

        class B(object):
            def __init__(self, x):
                self.x = 2

            def helper(self, a):
                return self.x + a + other_fn(self.x, a)

        class N(torch.nn.Module):
            def __init__(self):
                super(N, self).__init__()

            def forward(self, x):
                b = B(x)
                return b.helper(x)

        self.checkModule(N(), (torch.randn(2, 2),))

    def test_error_stack(self):
        """A compile error reports the chain of calls that led to compiling it."""
        def d(x):
            # type: (int) -> int
            return x + 10

        def c(x):
            return d("hello") + d(x)

        def b(x):
            return c(x)

        def a(x):
            return b(x)

        # assertRaises fails the test if scripting unexpectedly succeeds;
        # a bare try/except would silently pass in that case.
        with self.assertRaises(RuntimeError) as cm:
            torch.jit.script(a)

        checker = FileCheck()
        checker.check("Expected a value of type 'int'")
        checker.check("def c(x)")
        checker.check("def b(x)")
        checker.check("def a(x)")
        checker.run(str(cm.exception))

    def test_error_stack_module(self):
        """Compile-error call stacks also span module methods and submodules."""
        def d(x):
            # type: (int) -> int
            return x + 10

        def c(x):
            return d("hello") + d(x)

        def b(x):
            return c(x)

        class Submodule(torch.nn.Module):
            def __init__(self):
                super(Submodule, self).__init__()

            def forward(self, x):
                return b(x)

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.submodule = Submodule()

            def some_method(self, y):
                return y + self.submodule(y)

            def forward(self, x):
                return self.some_method(x)

        # assertRaises fails the test if scripting unexpectedly succeeds.
        with self.assertRaises(RuntimeError) as cm:
            torch.jit.script(M())

        checker = FileCheck()
        checker.check("Expected a value of type 'int'")
        checker.check("'c' is being compiled since it was called from 'b'")
        checker.check("'b' is being compiled since it was called from")
        checker.run(str(cm.exception))

    @_tmp_donotuse_dont_inline_everything
    def test_script_basic(self):
        """A Python function called from a script function is compiled and called."""
        def a_python_fn(a, b, c):
            return a + b + c

        @torch.jit.script
        def a_script_fn(d, e, f):
            return a_python_fn(d, e, f)

        # The call is expected to remain a prim::CallFunction node; presumably
        # the decorator above disables inlining so it stays visible.
        graph = str(a_script_fn.graph)
        FileCheck().check("prim::CallFunction").run(graph)
        FileCheck().check_not("^a_python_fn").run(graph)
        t = torch.ones(2, 2)
        self.assertEqual(a_script_fn(t, t, t), t + t + t)

    def test_error_stack_class(self):
        """An unsupported construct in a class method reports the compilation stack."""
        class X(object):
            def bad_fn(self):
                import pdb  # noqa

        def fn(x):
            return X(10)

        # assertRaises fails the test if scripting unexpectedly succeeds.
        with self.assertRaises(Exception) as cm:
            torch.jit.script(fn)

        checker = FileCheck()
        checker.check("import statements")
        checker.check("is being compiled since it was called from")
        checker.run(str(cm.exception))

    def test_module_basic(self):
        """Submodules with __constants__ and parameters are scripted recursively."""
        class Other(torch.nn.Module):
            __constants__ = ['x']

            def __init__(self, x):
                super(Other, self).__init__()
                self.x = x
                self.param = torch.nn.Parameter(torch.ones(2, 2))

            # Rebinds `a` from int to list; the test presumably relies on this
            # method (unreachable from forward) not being compiled.
            def some_unscriptable_method(self):
                a = 2
                a = [2]
                return a

            def forward(self, t):
                return t + self.x + self.param

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.other = Other(200)

            def forward(self, t):
                return self.other(t) * 2

        self.checkModule(M(), (torch.ones(2, 2),))

    def test_module_function_export(self):
        """A submodule with an @torch.jit.export method that forward never calls can be scripted."""
        class Other(torch.nn.Module):
            __constants__ = ['x']

            def __init__(self, x):
                super(Other, self).__init__()
                self.x = x
                self.param = torch.nn.Parameter(torch.ones(2, 2))

            @torch.jit.export
            def some_entry_point(self, y):
                return y + 20

            def forward(self, t):
                return t + self.x + self.param

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.other = Other(200)

            def forward(self, t):
                return self.other(t) * 2

        self.checkModule(M(), (torch.ones(2, 2),))

    def test_iterable_modules(self):
        """forward can iterate an nn.ModuleList and call a nested nn.Sequential."""
        class Inner(torch.nn.Module):
            def forward(self, x):
                return x + 10

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.sequential = nn.Sequential(
                    Inner(),
                    Inner(),
                    nn.Sequential(Inner(), Inner())
                )
                self.module_list = nn.ModuleList([Inner(), Inner()])

            def forward(self, x):
                for mod in self.module_list:
                    x += mod(x)
                x += self.sequential(x)
                return x

        self.checkModule(M(), (torch.randn(5, 5),))

    def test_attributes(self):
        """Module attributes of many Python types survive scripting.

        ``untyped_values`` carry no annotation. ``typed_values`` (empty
        containers, None, script-class instances) are listed in the
        ``M.__annotations__`` set below.
        """
        @torch.jit.script
        class Inner(object):
            def __init__(self):
                self.b = "a string"

        @torch.jit.script
        class Foo(object):
            def __init__(self):
                self.a = 4
                self.inner = Inner()

        @torch.jit.script
        class SFoo(object):
            def __init__(self):
                self.a = 4
                self.inner = Inner()

            def __setstate__(self, obj):
                # type: (Tuple[int, Inner]) -> None
                a, inner = obj
                self.a = a
                self.inner = inner

            def __getstate__(self):
                return (self.a, self.inner)

        untyped_values = (
            ('my_dict', {"I": "am", "a test": "test"}),
            ('my_float', 2.3),
            ('my_int', 99),
            ('my_bool', False),
            ('my_tuple', (1, 2, 3, 4)),
            ('my_list', [(1, 2), (3, 4)]),
            # ('my_tensor', torch.randn(2, 2)),
            ('my_int_list', [1, 2, 3, 4]),
            # ('my_tensor_list', [torch.ones(2, 2) + i for i in range(4)]),
            ('my_bool_list', [True, True, False, True]),
            ('my_float_list', [1., 2., 3., 4.]),
            ('my_str_list', ['hello', 'bye']),
        )
        typed_values = (
            ('my_empty_list', []),
            ('my_empty_dict', {}),
            ('my_none', None),
            ('my_object', Foo()),
            ('my_object2', SFoo()),
        )

        class M(torch.nn.Module):
            # TODO: re-enable this once this test is in a Python 3-only syntax
            # file
            # my_empty_list : List[int]
            # my_empty_dict : Dict[str, int]
            # my_none : Optional[int]

            def __init__(self):
                super(M, self).__init__()

            def forward(self, x):
                # NOTE(review): my_tuple and my_list are set on the module but
                # are not read here.
                return (
                    self.my_dict,
                    self.my_float,
                    self.my_int,
                    self.my_bool,
                    # self.my_tensor,
                    self.my_int_list,
                    # self.my_tensor_list,
                    self.my_bool_list,
                    self.my_float_list,
                    self.my_str_list,
                    self.my_empty_list,
                    self.my_empty_dict,
                    self.my_none,
                    self.my_object.a,
                    self.my_object.inner.b,
                    # Check the __getstate__/__setstate__ object too (this
                    # previously re-read self.my_object.a by mistake).
                    self.my_object2.a,
                    self.my_object2.inner.b,
                )

        # TODO: as a followup, fix this test
        # We can't define class attributes like we should be doing:
        #   class M(torch.nn.Module):
        #       my_empty_list : List[int]
        #       my_empty_dict : Dict[str, int]
        #       my_none : Optional[int]
        #       my_out_of_line_attribute: List[int] = [1, 2, 3]
        # since there's no string frontend for Python classes (so the `define`)
        # trick doesn't work.
        M.__annotations__ = {
            'my_empty_list': List[int],
            'my_empty_dict': Dict[str, int],
            'my_none': Optional[int],
            'my_object': Foo,
            'my_object2': SFoo,
        }

        m = M()
        for name, value in untyped_values + typed_values:
            setattr(m, name, value)

        self.checkModule(m, (torch.randn(5, 5),))

    def test_function_attribute_in_submodule(self):
        """A submodule holding a function attribute that forward never uses can be scripted."""
        class N(nn.Module):
            def __init__(self, norm):
                super(N, self).__init__()
                self.activation = torch.nn.functional.relu
                self.norm = norm

            def forward(self, src):
                output = src
                output = self.norm(output)
                return output

        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                encoder_norm = nn.ReLU()
                self.encoder = N(encoder_norm)

            def forward(self, x):
                return self.encoder(x)

        m = M()
        self.checkModule(m, (torch.randn(5, 5), ))

    def test_property(self):
        """Modules using @property are rejected with an error mentioning 'property'."""
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.x = 0

            @property
            def x_and_1(self):
                return self.x + 1

            def forward(self, new_x):
                # type: (int) -> int
                self.x = new_x
                return self.x_and_1

        with self.assertRaisesRegex(RuntimeError, "property"):
            torch.jit.script(M())