pytorch/test/jit/test_convert_activation.py
Xuehai Pan 046e88a291 [BE] [3/3] Rewrite super() calls in test (#94592)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94592
Approved by: https://github.com/ezyang, https://github.com/seemethere
2023-02-12 22:20:53 +00:00

178 lines
6.1 KiB
Python

# Owner(s): ["oncall: jit"]
import os
import sys
from itertools import product
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing import FileCheck
import unittest
# torchvision is an optional dependency: remember whether the import
# succeeded so that tests which need it can be skipped instead of making
# the whole module fail to import.
try:
    import torchvision
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
# Decorator that skips the decorated test when torchvision is unavailable.
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
# Refuse to run as a standalone script; the error message names the
# supported entry point.
if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")
# Functional activations exercised by the tests in this file.
# The tests call every entry both as `activation(y)` and as
# `activation(y, inplace=True)`, and interpolate its `__name__` into the
# FileCheck patterns "aten::<name>(" and "aten::<name>_", so each entry must
# accept `inplace` and be named like its aten operator.
activations = [
    F.celu,
    F.elu,
    F.hardsigmoid,
    F.hardswish,
    F.hardtanh,
    F.leaky_relu,
    F.relu,
    F.relu6,
    F.rrelu,
    F.selu,
    F.silu,
]
class TestFunctionalToInplaceActivation(JitTestCase):
    """Tests for the ``functional_to_inplace_activation`` JIT pass.

    Each test scripts a small function or module, runs the pass on its graph
    and uses FileCheck to verify which ``aten::`` activation nodes remain.
    """

    def test_check_no_type_promotion(self):
        # Verify that every activation returns a tensor with the dtype of
        # its input for each dtype it supports.
        dtypes = [
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float32,
            torch.float64,
        ]
        # restore_mutation.h contains a mapping from activation operators
        # to whether they allow type conversion. Use this checking to
        # guard the mapping, and if any later change breaks the assumption
        # we need to update the mapping correspondingly.
        for activation, dtype in product(activations, dtypes):
            # subTest reports which (activation, dtype) pair failed.
            with self.subTest(activation=activation.__name__, dtype=dtype):
                inp = torch.normal(0, 5, size=(4, 4)).to(dtype)
                # Keep only the activation call inside the try so that a
                # RuntimeError raised while asserting is never swallowed.
                try:
                    out = activation(inp)
                except RuntimeError:
                    # Skip the not implemented error
                    continue
                self.assertEqual(dtype, out.dtype)

    def test_functional_to_inplace_activation(self):
        # For every activation: the scripted graph must contain the
        # functional op before the pass and only the in-place op afterwards,
        # and the rewritten graph must match eager execution.
        for activation in activations:
            with self.subTest(activation=activation.__name__):
                def test_basic(x):
                    y = x + 1
                    z = activation(y)
                    return z

                fn = torch.jit.script(test_basic)
                self.run_pass("inline", fn.graph)
                self.run_pass("constant_propagation", fn.graph)
                FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)
                self.run_pass('functional_to_inplace_activation', fn.graph)
                FileCheck().check_not(f"aten::{activation.__name__}(").run(fn.graph)
                FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
                inp = torch.rand([2, 2])
                self.assertEqual(fn(inp), test_basic(inp))

    def test_no_functional_to_inplace(self):
        # Cases in which the pass must leave the functional op untouched.

        # inplace conversion should not happen because sigmoid may
        # perform type conversion
        def test1():
            y = torch.ones([2, 2])
            z = torch.sigmoid(y)
            return z

        fn = torch.jit.script(test1)
        self.run_pass('functional_to_inplace_activation', fn.graph)
        FileCheck().check_not("aten::sigmoid_").run(fn.graph)

        # inplace conversion should not happen because y is an alias of
        # the input x
        def test2(x):
            y = x[0]
            z = torch.relu(y)
            return z

        fn = torch.jit.script(test2)
        self.run_pass('functional_to_inplace_activation', fn.graph)
        FileCheck().check_not("aten::relu_").run(fn.graph)

        # inplace conversion should not happen because self.x is
        # at the global scope
        class Test3(nn.Module):
            def __init__(self, x):
                super().__init__()
                self.x = x

            def forward(self):
                y = torch.relu(self.x)
                return y

        fn = torch.jit.script(Test3(torch.rand([2, 2])).eval())
        self.run_pass('functional_to_inplace_activation', fn.graph)
        FileCheck().check_not("aten::relu_").run(fn.graph)

    @skipIfNoTorchVision
    def test_resnet18_correctness(self):
        # Running the pass on a frozen, scripted resnet18 must not change
        # its output compared with the eager model.
        model = torchvision.models.resnet18()
        frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
        N, C, H, W = 10, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        self.run_pass('functional_to_inplace_activation', frozen_model.graph)
        self.assertEqual(model(inp), frozen_model(inp))
class TestInplaceToFunctionalActivation(JitTestCase):
    """Tests for the ``inplace_to_functional_activation`` JIT pass.

    Each test scripts a function that applies an in-place activation, runs
    the pass on its graph and uses FileCheck to verify that the in-place
    ``aten::`` node was replaced by its functional counterpart.
    """

    def test_inplace_to_functional_activation(self):
        # torch.nn.functional activations invoked with inplace=True.
        for activation in activations:
            with self.subTest(activation=activation.__name__):
                def test_basic(x):
                    y = x + 1
                    activation(y, inplace=True)
                    return y

                fn = torch.jit.script(test_basic)
                self.run_pass("inline", fn.graph)
                self.run_pass("constant_propagation", fn.graph)
                FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
                self.run_pass('inplace_to_functional_activation', fn.graph)
                FileCheck().check_not(f"aten::{activation.__name__}_").run(fn.graph)
                FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)
                # Also compare against eager execution, as is done for the
                # torch.*_ variants below, so that a rewrite which produces
                # the expected node names but wrong values is caught.
                inp = torch.rand([2, 2])
                self.assertEqual(fn(inp), test_basic(inp))

        # In-place tensor ops whose name already ends with "_".
        for activation in [
            torch.relu_,
            torch.sigmoid_,
            torch.tanh_,
        ]:
            with self.subTest(activation=activation.__name__):
                def test_basic(x):
                    y = x + 1
                    activation(y)
                    return y

                fn = torch.jit.script(test_basic)
                self.run_pass("inline", fn.graph)
                self.run_pass("constant_propagation", fn.graph)
                FileCheck().check(f"aten::{activation.__name__}").run(fn.graph)
                self.run_pass('inplace_to_functional_activation', fn.graph)
                FileCheck().check_not(f"aten::{activation.__name__}").run(fn.graph)
                # Dropping the trailing "_" yields the functional op name.
                FileCheck().check(f"aten::{activation.__name__[:-1]}(").run(fn.graph)
                inp = torch.rand([2, 2])
                self.assertEqual(fn(inp), test_basic(inp))

    @skipIfNoTorchVision
    def test_resnet18_correctness(self):
        # Running the pass on a frozen, scripted resnet18 must not change
        # its output compared with the eager model.
        model = torchvision.models.resnet18()
        frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
        N, C, H, W = 10, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        self.run_pass('inplace_to_functional_activation', frozen_model.graph)
        self.assertEqual(model(inp), frozen_model(inp))