mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
 class MyModule(nn.Module):
-    def __init__(self):
-        super().__init__()
-
     def forward(self, ...):
         ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

- f152a79be9/caffe2/python/net_printer.py (L184-L190)
- f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94592
Approved by: https://github.com/ezyang, https://github.com/seemethere
178 lines
6.1 KiB
Python
178 lines
6.1 KiB
Python
# Owner(s): ["oncall: jit"]

import os
import sys
import unittest
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing import FileCheck

# torchvision is optional: detect it once, and build a decorator that skips
# the tests needing it when the import fails.
try:
    import torchvision
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True

skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

# Make the helper files in test/ importable. The resolved path of this file
# has its last two components stripped, i.e. the parent of this file's
# directory is appended to the import search path.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

# Running this module as a script is refused; the error message points the
# user at test/test_jit.py as the entry point for these tests.
if __name__ == '__main__':
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
# Functional activations from torch.nn.functional exercised by the tests
# below, both as plain calls and with ``inplace=True``.
activations = [
    F.celu, F.elu, F.hardsigmoid, F.hardswish,
    F.hardtanh, F.leaky_relu, F.relu, F.relu6,
    F.rrelu, F.selu, F.silu,
]
class TestFunctionalToInplaceActivation(JitTestCase):
    # Tests for the 'functional_to_inplace_activation' JIT pass.

    def test_check_no_type_promotion(self):
        # restore_mutation.h maps each activation operator to whether it
        # allows type conversion. This test guards that mapping: if a later
        # change breaks the "output dtype == input dtype" assumption for any
        # activation here, the mapping has to be updated to match.
        dtypes = (
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float32,
            torch.float64,
        )
        for activation, dtype in product(activations, dtypes):
            sample = torch.normal(0, 5, size=(4, 4)).to(dtype)
            try:
                result = activation(sample)
                self.assertEqual(dtype, result.dtype)
            except RuntimeError:
                # Skip the not-implemented error for this dtype.
                continue

    def test_functional_to_inplace_activation(self):
        for activation in activations:
            op_name = activation.__name__

            def add_one_then_activate(x):
                y = x + 1
                z = activation(y)
                return z

            scripted = torch.jit.script(add_one_then_activate)
            graph = scripted.graph
            self.run_pass("inline", graph)
            self.run_pass("constant_propagation", graph)
            FileCheck().check(f"aten::{op_name}(").run(graph)

            # After the pass, the functional op must be gone and the in-place
            # variant (trailing underscore) must appear instead.
            self.run_pass('functional_to_inplace_activation', graph)
            FileCheck().check_not(f"aten::{op_name}(").run(graph)
            FileCheck().check(f"aten::{op_name}_").run(graph)

            sample = torch.rand([2, 2])
            self.assertEqual(scripted(sample), add_one_then_activate(sample))

    def test_no_functional_to_inplace(self):
        # sigmoid may perform type conversion, so it must not be made in-place.
        def sigmoid_of_ones():
            y = torch.ones([2, 2])
            z = torch.sigmoid(y)
            return z

        scripted = torch.jit.script(sigmoid_of_ones)
        self.run_pass('functional_to_inplace_activation', scripted.graph)
        FileCheck().check_not("aten::sigmoid_").run(scripted.graph)

        # y aliases the input x, so relu must not be made in-place.
        def relu_of_view(x):
            y = x[0]
            z = torch.relu(y)
            return z

        scripted = torch.jit.script(relu_of_view)
        self.run_pass('functional_to_inplace_activation', scripted.graph)
        FileCheck().check_not("aten::relu_").run(scripted.graph)

        # self.x is at the global scope, so relu must not be made in-place.
        class ReluOfAttribute(nn.Module):
            def __init__(self, x):
                super().__init__()
                self.x = x

            def forward(self):
                y = torch.relu(self.x)
                return y

        scripted = torch.jit.script(ReluOfAttribute(torch.rand([2, 2])).eval())
        self.run_pass('functional_to_inplace_activation', scripted.graph)
        FileCheck().check_not("aten::relu_").run(scripted.graph)

    @skipIfNoTorchVision
    def test_resnet18_correctness(self):
        eager = torchvision.models.resnet18()
        frozen = torch.jit.freeze(torch.jit.script(eager.eval()))
        batch = torch.randn(10, 3, 224, 224)
        self.run_pass('functional_to_inplace_activation', frozen.graph)
        self.assertEqual(eager(batch), frozen(batch))
class TestInplaceToFunctionalActivation(JitTestCase):
    # Tests for the 'inplace_to_functional_activation' JIT pass.

    def test_inplace_to_functional_activation(self):
        # torch.nn.functional activations called with inplace=True must be
        # rewritten from the in-place aten op ("<name>_") to the functional
        # one ("<name>(").
        for activation in activations:
            def test_basic(x):
                y = x + 1
                activation(y, inplace=True)
                return y

            fn = torch.jit.script(test_basic)
            self.run_pass("inline", fn.graph)
            self.run_pass("constant_propagation", fn.graph)
            FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
            self.run_pass('inplace_to_functional_activation', fn.graph)
            FileCheck().check_not(f"aten::{activation.__name__}_").run(fn.graph)
            FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)

            # Graph text alone does not prove the rewrite is sound: the
            # function returns y, the tensor the in-place call mutated, so the
            # rewritten graph must still return the activated values. Check
            # numerics here too, as the torch.*_ loop below does.
            inp = torch.rand([2, 2])
            self.assertEqual(fn(inp), test_basic(inp))

        # Tensor-level in-place ops (torch.relu_ etc.) must be rewritten to the
        # functional op whose name drops the trailing underscore.
        for activation in [
            torch.relu_,
            torch.sigmoid_,
            torch.tanh_,
        ]:
            def test_basic(x):
                y = x + 1
                activation(y)
                return y

            fn = torch.jit.script(test_basic)
            self.run_pass("inline", fn.graph)
            self.run_pass("constant_propagation", fn.graph)
            FileCheck().check(f"aten::{activation.__name__}").run(fn.graph)
            self.run_pass('inplace_to_functional_activation', fn.graph)
            FileCheck().check_not(f"aten::{activation.__name__}").run(fn.graph)
            FileCheck().check(f"aten::{activation.__name__[:-1]}(").run(fn.graph)

            # Verify numerics for every op in the list, not only the last one.
            inp = torch.rand([2, 2])
            self.assertEqual(fn(inp), test_basic(inp))

    @skipIfNoTorchVision
    def test_resnet18_correctness(self):
        # The pass applied to a frozen, scripted resnet18 must not change its
        # output relative to the eager model (both in eval mode).
        model = torchvision.models.resnet18()
        frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
        N, C, H, W = 10, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        self.run_pass('inplace_to_functional_activation', frozen_model.graph)
        self.assertEqual(model(inp), frozen_model(inp))