pytorch/test/jit/test_ignore_context_manager.py

106 lines
3.3 KiB
Python

# Owner(s): ["oncall: jit"]
import os
import sys
import unittest
import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.jit.frontend import _IS_ASTUNPARSE_INSTALLED
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestIgnoreContextManager(JitTestCase):
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
def test_with_ignore_context_manager_with_inp_out(self):
class A(torch.nn.Module):
def forward(self):
a: int = 4
b: int = 5
c: int = 0
d: int = 6
with torch.jit._IgnoreContextManager(
a="inp:int", b="inp:int", c="out:int", d="out:int"
):
l = [2 for i in range(a) if i > 2]
c = l[0] + a + b
d = 9
return c + d
model = A()
s = torch.jit.script(model)
self.assertEqual(s(), model())
self.assertEqual(s(), 20)
class B(torch.nn.Module):
def forward(self):
a: int = 4
b: int = 5
c: int = 0
with torch.jit._IgnoreContextManager(
a="inp:int", b="inp:int", c="out:int"
):
l = [2 for i in range(a) if i > 2]
c = l[0] + a + b
return c
model = B()
s = torch.jit.script(model)
self.assertEqual(s(), 11)
self.assertEqual(s(), model())
class C(torch.nn.Module):
def forward(self):
a: int = 4
b: int = 5
with torch.jit._IgnoreContextManager(a="inp:int", b="out:int"):
l = [2 for i in range(a) if i > 2]
b = l[0] + a
return b
model = C()
s = torch.jit.script(model)
self.assertEqual(s(), 6)
self.assertEqual(s(), model())
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
def test_with_ignore_context_manager_with_just_inp(self):
class A(torch.nn.Module):
def forward(self):
a: int = 4
b: int = 5
with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int"):
l = [2 + b for i in range(a) if i > 2] # noqa: F841
return a
model = A()
s = torch.jit.script(model)
self.assertEqual(s(), 4)
self.assertEqual(s(), model())
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
def test_with_ignore_context_manager_with_just_out(self):
class A(torch.nn.Module):
def forward(self):
with torch.jit._IgnoreContextManager(c="out:List[int]"):
c = [2 for i in range(7) if i > 2]
c[0] = 3
return c[0] + c[1]
model = A()
s = torch.jit.script(model)
self.assertEqual(s(), 5)
self.assertEqual(s(), model())