mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136964 Approved by: https://github.com/justinchuby, https://github.com/albanD
106 lines
3.3 KiB
Python
106 lines
3.3 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import os
|
|
import sys
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
|
|
# Make the helper files in test/ importable
|
|
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|
sys.path.append(pytorch_test_dir)
|
|
from torch.jit.frontend import _IS_ASTUNPARSE_INSTALLED
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
|
|
if __name__ == "__main__":
|
|
raise RuntimeError(
|
|
"This test file is not meant to be run directly, use:\n\n"
|
|
"\tpython test/test_jit.py TESTNAME\n\n"
|
|
"instead."
|
|
)
|
|
|
|
|
|
class TestIgnoreContextManager(JitTestCase):
|
|
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
|
|
def test_with_ignore_context_manager_with_inp_out(self):
|
|
class A(torch.nn.Module):
|
|
def forward(self):
|
|
a: int = 4
|
|
b: int = 5
|
|
c: int = 0
|
|
d: int = 6
|
|
with torch.jit._IgnoreContextManager(
|
|
a="inp:int", b="inp:int", c="out:int", d="out:int"
|
|
):
|
|
l = [2 for i in range(a) if i > 2]
|
|
c = l[0] + a + b
|
|
d = 9
|
|
return c + d
|
|
|
|
model = A()
|
|
s = torch.jit.script(model)
|
|
self.assertEqual(s(), model())
|
|
self.assertEqual(s(), 20)
|
|
|
|
class B(torch.nn.Module):
|
|
def forward(self):
|
|
a: int = 4
|
|
b: int = 5
|
|
c: int = 0
|
|
with torch.jit._IgnoreContextManager(
|
|
a="inp:int", b="inp:int", c="out:int"
|
|
):
|
|
l = [2 for i in range(a) if i > 2]
|
|
c = l[0] + a + b
|
|
return c
|
|
|
|
model = B()
|
|
s = torch.jit.script(model)
|
|
self.assertEqual(s(), 11)
|
|
self.assertEqual(s(), model())
|
|
|
|
class C(torch.nn.Module):
|
|
def forward(self):
|
|
a: int = 4
|
|
b: int = 5
|
|
with torch.jit._IgnoreContextManager(a="inp:int", b="out:int"):
|
|
l = [2 for i in range(a) if i > 2]
|
|
b = l[0] + a
|
|
return b
|
|
|
|
model = C()
|
|
s = torch.jit.script(model)
|
|
self.assertEqual(s(), 6)
|
|
self.assertEqual(s(), model())
|
|
|
|
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
|
|
def test_with_ignore_context_manager_with_just_inp(self):
|
|
class A(torch.nn.Module):
|
|
def forward(self):
|
|
a: int = 4
|
|
b: int = 5
|
|
with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int"):
|
|
l = [2 + b for i in range(a) if i > 2] # noqa: F841
|
|
return a
|
|
|
|
model = A()
|
|
s = torch.jit.script(model)
|
|
self.assertEqual(s(), 4)
|
|
self.assertEqual(s(), model())
|
|
|
|
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
|
|
def test_with_ignore_context_manager_with_just_out(self):
|
|
class A(torch.nn.Module):
|
|
def forward(self):
|
|
with torch.jit._IgnoreContextManager(c="out:List[int]"):
|
|
c = [2 for i in range(7) if i > 2]
|
|
c[0] = 3
|
|
return c[0] + c[1]
|
|
|
|
model = A()
|
|
s = torch.jit.script(model)
|
|
self.assertEqual(s(), 5)
|
|
self.assertEqual(s(), model())
|