mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
 class MyModule(nn.Module):
-    def __init__(self):
-        super().__init__()
-
     def forward(self, ...):
         ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

- f152a79be9/caffe2/python/net_printer.py (L184-L190)
- f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94592
Approved by: https://github.com/ezyang, https://github.com/seemethere
92 lines
3.2 KiB
Python
92 lines
3.2 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import os
|
|
import sys
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
# Make the helper files in test/ importable: resolve this file's real path,
# step up two directory levels, and append that directory to sys.path.
_this_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(_this_dir)
sys.path.append(pytorch_test_dir)
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
from torch.jit.frontend import _IS_ASTUNPARSE_INSTALLED
|
|
|
|
if __name__ == "__main__":
    # Direct execution is rejected; the message names test/test_jit.py as the
    # command to use instead.
    _direct_run_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_direct_run_message)
|
|
|
|
class TestIgnoreContextManager(JitTestCase):
    """Check ``torch.jit._IgnoreContextManager`` blocks inside scripted modules.

    Each ``with`` body below contains a filtered list comprehension.  The
    context manager's keyword arguments name surrounding locals and tag them
    with ``"inp:<type>"`` or ``"out:<type>"`` strings.  Every case asserts that
    the scripted module returns the expected constant and agrees with the
    eager (unscripted) module.
    """

    def _check_scripted(self, eager_module, expected):
        """Script ``eager_module`` and compare its result with ``expected`` and eager mode."""
        scripted = torch.jit.script(eager_module)
        self.assertEqual(scripted(), expected)
        self.assertEqual(scripted(), eager_module())

    @unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
    def test_with_ignore_context_manager_with_inp_out(self):
        """Blocks that carry both ``inp``-tagged and ``out``-tagged variables."""

        class TwoInputsTwoOutputs(torch.nn.Module):
            def forward(self):
                a: int = 4
                b: int = 5
                c: int = 0
                d: int = 6
                with torch.jit._IgnoreContextManager(
                    a="inp:int", b="inp:int", c="out:int", d="out:int"
                ):
                    twos = [2 for i in range(a) if i > 2]
                    c = twos[0] + a + b
                    d = 9
                return c + d

        # c becomes 2 + 4 + 5 = 11 and d becomes 9.
        self._check_scripted(TwoInputsTwoOutputs(), 20)

        class TwoInputsOneOutput(torch.nn.Module):
            def forward(self):
                a: int = 4
                b: int = 5
                c: int = 0
                with torch.jit._IgnoreContextManager(
                    a="inp:int", b="inp:int", c="out:int"
                ):
                    twos = [2 for i in range(a) if i > 2]
                    c = twos[0] + a + b
                return c

        self._check_scripted(TwoInputsOneOutput(), 11)

        class OneInputOneOutput(torch.nn.Module):
            def forward(self):
                a: int = 4
                b: int = 5
                with torch.jit._IgnoreContextManager(a="inp:int", b="out:int"):
                    twos = [2 for i in range(a) if i > 2]
                    b = twos[0] + a
                return b

        self._check_scripted(OneInputOneOutput(), 6)

    @unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
    def test_with_ignore_context_manager_with_just_inp(self):
        """A block whose variables are all ``inp``-tagged; the result comes from outside it."""

        class InputsOnly(torch.nn.Module):
            def forward(self):
                a: int = 4
                b: int = 5
                with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int"):
                    # Built and discarded: nothing is tagged as an output.
                    shifted = [2 + b for i in range(a) if i > 2]
                return a

        self._check_scripted(InputsOnly(), 4)

    @unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
    def test_with_ignore_context_manager_with_just_out(self):
        """A block with a single ``out``-tagged list that is first assigned inside it."""

        class OutputOnly(torch.nn.Module):
            def forward(self):
                with torch.jit._IgnoreContextManager(c="out:List[int]"):
                    c = [2 for i in range(7) if i > 2]
                    c[0] = 3
                return c[0] + c[1]

        # c is [3, 2, 2, 2] after the block.
        self._check_scripted(OutputOnly(), 5)
|