Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
[JIT] Factor out peephole to own test file (#50220)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50220

Test Plan: Imported from OSS

Reviewed By: tugsbayasgalan

Differential Revision: D25856263

Pulled By: eellison

fbshipit-source-id: f3d918d860e64e788e0bb9b9cb85125660f834c6
parent 6971149326
commit a69f008cb7
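Note: as the if __name__ == '__main__' guard in the new file spells out, test/jit/test_peephole.py is not meant to be run directly. The relocated tests are driven through the main JIT runner, for example (the test name is just an illustration):

    python test/test_jit.py TestPeephole.test_peephole_list_ops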
test/jit/test_peephole.py (new file, 187 lines)
@@ -0,0 +1,187 @@
+import torch
+from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA, _inline_everything
+
+from torch.testing import FileCheck
+
+import unittest
+
+if __name__ == '__main__':
+    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
+                       "\tpython test/test_jit.py TESTNAME\n\n"
+                       "instead.")
+
+class TestPeephole(JitTestCase):
+    def test_peephole_with_writes(self):
+        def test_write(x):
+            s = 0
+            s += x
+            s += x
+            return s
+
+        self.checkScript(test_write, (torch.ones(4, 4),))
+
+    def test_peephole_with_non_output_writes(self):
+        @torch.jit.ignore
+        def nomnom(x):
+            pass
+
+        def test_write(x):
+            t = torch.ones_like(x)
+            z = x.clone()
+            y = z + 0
+            z.add_(t)
+            # this makes sure z isn't blasted out of existence
+            # because it isn't returned or used in a side-effectful
+            # way
+            nomnom(z)
+            return y + y
+
+        a = torch.ones(4, 4)
+        j = self.checkScript(test_write, (a,))
+
+    def test_peephole_no_output_aliasing(self):
+        def test_peephole(x):
+            y = x + 0
+            return x, y
+
+        a = torch.ones(4, 4)
+        j = self.checkScript(test_peephole, (a,))
+        r1, r2 = j(a)
+        self.assertNotEqual(r1.data_ptr(), r2.data_ptr())
+
+    def test_peephole(self):
+        a = torch.tensor([0.4])
+        b = torch.tensor([0.7])
+        c = torch.tensor([0], dtype=torch.int32)
+
+        def f(x, y):
+            return x.type_as(y)
+
+        tf = torch.jit.trace(f, (a, b))
+        FileCheck().check("type_as").run(str(tf.graph))
+        self.run_pass('peephole', tf.graph)
+        FileCheck().check_not("type_as").run(str(tf.graph))
+        tf2 = torch.jit.trace(f, (a, c))
+        s = str(tf2.graph)
+        self.run_pass('peephole', tf2.graph)
+        self.assertEqual(s, str(tf2.graph))
+
+    def test_peephole_dynamic(self):
+        def f(x, y):
+            return x.type_as(y)
+
+        fn = torch.jit.script(f)
+        s = str(fn.graph)
+        torch._C._jit_pass_peephole(fn.graph)
+        self.assertEqual(s, str(fn.graph))
+
+    def test_peephole_list_ops(self):
+        @torch.jit.script
+        def foo(x, y, z):
+            return len([x, y, z])
+
+        self.run_pass('peephole', foo.graph)
+        FileCheck().check("value=3").check_next("return").run(foo.graph)
+
+        @torch.jit.script
+        def foo(x, y, z):
+            li = [x, y, z]
+            for i in range(len(x)):
+                li.append(x)
+            return len([x, y, z])
+
+        self.run_pass('peephole', foo.graph)
+        FileCheck().check_not("aten::len").run(foo.graph)
+
+        @torch.jit.script
+        def foo(x, y, z):
+            li = [x, y, z]
+            return li[1], li[-2]
+
+        FileCheck().check("aten::__getitem__").run(foo.graph)
+        self.run_pass('peephole', foo.graph)
+        FileCheck().check_not("aten::__getitem__").run(foo.graph)
+
+        @torch.jit.script
+        def foo(x, y, z):
+            li = [x, y, z]
+            return li[-7]
+
+        self.run_pass('peephole', foo.graph)
+        FileCheck().check("aten::__getitem__").run(foo.graph)
+
+        @torch.jit.script
+        def foo(x, y, z):
+            li = [x, y, z]
+            for i in range(len(x)):
+                li.append(x)
+            return li[-2]
+
+        self.run_pass('peephole', foo.graph)
+        FileCheck().check("aten::__getitem__").run(foo.graph)
+
+    @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
+    def test_peephole_cuda(self):
+        a = torch.tensor([0.4], device='cpu')
+        b = torch.tensor([0.7], device='cuda')
+        c = torch.tensor([0.7], device='cuda')
+
+        def f(x, y):
+            return x.type_as(y)
+
+        trace = torch.jit.trace(f, (a, c))
+        s = str(trace.graph)
+        self.run_pass('peephole', trace.graph)
+        self.assertEqual(s, str(trace.graph))
+        trace = torch.jit.trace(f, (b, c))
+        self.run_pass('peephole', trace.graph)
+        self.run_pass('dce', trace.graph)
+        FileCheck().check_not("type_as").run(str(trace.graph))
+
+    @_inline_everything
+    def test_peephole_type_refinements(self):
+        def refine(x):
+            # type: (Optional[Tensor]) -> Tensor
+            return x if x is not None else torch.tensor(3)
+
+        @torch.jit.script
+        def test():
+            return refine(torch.tensor(4))
+
+        FileCheck().check("prim::unchecked_cast").run(test.graph)
+        self.run_pass('peephole', test.graph)
+        FileCheck().check_not("prim::unchecked_cast").run(test.graph)
+
+        # refinement not optimized out
+        def is_int_tensor(x):
+            scalar = x.item()
+            if isinstance(scalar, int):
+                return scalar + 3
+            else:
+                return 8
+
+        self.checkScript(is_int_tensor, (torch.tensor(2),))
+        self.checkScript(is_int_tensor, (torch.tensor(2.5),))
+        graph = torch.jit.script(is_int_tensor).graph
+        self.run_pass('peephole', graph)
+        FileCheck().check("prim::unchecked_cast").run(graph)
+
+    def test_short_circuit_optimization(self):
+        @torch.jit.script
+        def const_expressions(x):
+            # type: (int) -> Tuple[bool, bool]
+            return x == 1 and False, x == 1 or True
+        self.run_pass('constant_propagation', const_expressions.graph)
+        FileCheck().check_not("prim::If").check_not("aten::eq").run(const_expressions.graph)
+        self.assertEqual(const_expressions(1), (False, True))
+
+        @torch.jit.script
+        def redundant_expressions(x):
+            # type: (int) -> Tuple[bool, bool]
+            return x == 1 and True, x == 1 or False
+
+        self.run_pass('peephole', redundant_expressions.graph)
+        self.assertEqual(redundant_expressions(1), (True, True))
+        self.assertEqual(redundant_expressions(0), (False, False))
+        # and True / or False are removed from graph
+        FileCheck().check("aten::eq").check_not("prim::If").run(redundant_expressions.graph)
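For context, the pass exercised here can also be run on a graph outside the JitTestCase harness, using the same internal binding that test_peephole_dynamic calls above. A minimal sketch (the example function is my own choice, and torch._C._jit_pass_peephole is an internal, unstable API):

    import torch

    def f(x, y):
        return x.type_as(y)

    fn = torch.jit.script(f)
    print(fn.graph)                        # TorchScript IR before optimization
    torch._C._jit_pass_peephole(fn.graph)  # mutates the graph in place
    print(fn.graph)                        # IR after the peephole pass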
test/test_jit.py (177 lines changed)
@@ -20,6 +20,7 @@ from jit.test_class_type import TestClassType  # noqa: F401
 from jit.test_builtins import TestBuiltins, TestTensorBuiltins  # noqa: F401
 from jit.test_unsupported_ops import TestUnsupportedOps  # noqa: F401
 from jit.test_freezing import TestFreezing, TestFrozenOptimizations  # noqa: F401
+from jit.test_peephole import TestPeephole  # noqa: F401
 from jit.test_save_load import TestSaveLoad  # noqa: F401
 from jit.test_module_containers import TestModuleContainers  # noqa: F401
 from jit.test_python_ir import TestPythonIr  # noqa: F401
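The one-line addition above is all that keeps the relocated tests in the suite: test_jit.py imports TestPeephole (the # noqa: F401 silences the unused-import lint), and unittest's loader then discovers the class when test_jit.py runs, so no further registration is needed.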
@@ -481,134 +482,6 @@ class TestJit(JitTestCase):
         self.assertTrue(m2.b0.is_shared())
         self.assertEqual(m2.b0.storage().data_ptr(), m2.p0.storage().data_ptr())
 
-    def test_peephole_with_writes(self):
-        def test_write(x):
-            s = 0
-            s += x
-            s += x
-            return s
-
-        self.checkScript(test_write, (torch.ones(4, 4),))
-
-    def test_peephole_with_non_output_writes(self):
-
-        @torch.jit.ignore
-        def nomnom(x):
-            pass
-
-        def test_write(x):
-            t = torch.ones_like(x)
-            z = x.clone()
-            y = z + 0
-            z.add_(t)
-            # this makes sure z isn't blasted out of existence
-            # because it isn't returned or used in a side-effectful
-            # way
-            nomnom(z)
-            return y + y
-
-        a = torch.ones(4, 4)
-        j = self.checkScript(test_write, (a,))
-
-    def test_peephole_no_output_aliasing(self):
-        def test_peephole(x):
-            y = x + 0
-            return x, y
-
-        a = torch.ones(4, 4)
-        j = self.checkScript(test_peephole, (a,))
-        r1, r2 = j(a)
-        self.assertNotEqual(r1.data_ptr(), r2.data_ptr())
-
-    def test_peephole(self):
-        a = torch.tensor([0.4])
-        b = torch.tensor([0.7])
-        c = torch.tensor([0], dtype=torch.int32)
-
-        def f(x, y):
-            return x.type_as(y)
-
-        tf = torch.jit.trace(f, (a, b))
-        FileCheck().check("type_as").run(str(tf.graph))
-        self.run_pass('peephole', tf.graph)
-        FileCheck().check_not("type_as").run(str(tf.graph))
-        tf2 = torch.jit.trace(f, (a, c))
-        s = str(tf2.graph)
-        self.run_pass('peephole', tf2.graph)
-        self.assertEqual(s, str(tf2.graph))
-
-    def test_peephole_dynamic(self):
-        def f(x, y):
-            return x.type_as(y)
-
-        fn = torch.jit.script(f)
-        s = str(fn.graph)
-        torch._C._jit_pass_peephole(fn.graph)
-        self.assertEqual(s, str(fn.graph))
-
-    def test_peephole_list_ops(self):
-        @torch.jit.script
-        def foo(x, y, z):
-            return len([x, y, z])
-
-        self.run_pass('peephole', foo.graph)
-        FileCheck().check("value=3").check_next("return").run(foo.graph)
-
-        @torch.jit.script
-        def foo(x, y, z):
-            li = [x, y, z]
-            for i in range(len(x)):
-                li.append(x)
-            return len([x, y, z])
-
-        self.run_pass('peephole', foo.graph)
-        FileCheck().check_not("aten::len").run(foo.graph)
-
-        @torch.jit.script
-        def foo(x, y, z):
-            li = [x, y, z]
-            return li[1], li[-2]
-
-        FileCheck().check("aten::__getitem__").run(foo.graph)
-        self.run_pass('peephole', foo.graph)
-        FileCheck().check_not("aten::__getitem__").run(foo.graph)
-
-        @torch.jit.script
-        def foo(x, y, z):
-            li = [x, y, z]
-            return li[-7]
-
-        self.run_pass('peephole', foo.graph)
-        FileCheck().check("aten::__getitem__").run(foo.graph)
-
-        @torch.jit.script
-        def foo(x, y, z):
-            li = [x, y, z]
-            for i in range(len(x)):
-                li.append(x)
-            return li[-2]
-
-        self.run_pass('peephole', foo.graph)
-        FileCheck().check("aten::__getitem__").run(foo.graph)
-
-    @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
-    def test_peephole_cuda(self):
-        a = torch.tensor([0.4], device='cpu')
-        b = torch.tensor([0.7], device='cuda')
-        c = torch.tensor([0.7], device='cuda')
-
-        def f(x, y):
-            return x.type_as(y)
-
-        trace = torch.jit.trace(f, (a, c))
-        s = str(trace.graph)
-        self.run_pass('peephole', trace.graph)
-        self.assertEqual(s, str(trace.graph))
-        trace = torch.jit.trace(f, (b, c))
-        self.run_pass('peephole', trace.graph)
-        self.run_pass('dce', trace.graph)
-        FileCheck().check_not("type_as").run(str(trace.graph))
-
     def test_add_relu_fusion(self):
         class M(torch.nn.Module):
             def __init__(self, relu_op):
@@ -1789,54 +1662,6 @@ graph(%Ra, %Rb):
 
         FileCheck().check_not("prim::If").run(fn.graph)
 
-    def test_short_circuit_optimization(self):
-        @torch.jit.script
-        def const_expressions(x):
-            # type: (int) -> Tuple[bool, bool]
-            return x == 1 and False, x == 1 or True
-        self.run_pass('constant_propagation', const_expressions.graph)
-        FileCheck().check_not("prim::If").check_not("aten::eq").run(const_expressions.graph)
-        self.assertEqual(const_expressions(1), (False, True))
-
-        @torch.jit.script
-        def redundant_expressions(x):
-            # type: (int) -> Tuple[bool, bool]
-            return x == 1 and True, x == 1 or False
-
-        self.run_pass('peephole', redundant_expressions.graph)
-        self.assertEqual(redundant_expressions(1), (True, True))
-        self.assertEqual(redundant_expressions(0), (False, False))
-        # and True / or False are removed from graph
-        FileCheck().check("aten::eq").check_not("prim::If").run(redundant_expressions.graph)
-
-    @_inline_everything
-    def test_peephole_type_refinements(self):
-        def refine(x):
-            # type: (Optional[Tensor]) -> Tensor
-            return x if x is not None else torch.tensor(3)
-
-        @torch.jit.script
-        def test():
-            return refine(torch.tensor(4))
-
-        FileCheck().check("prim::unchecked_cast").run(test.graph)
-        self.run_pass('peephole', test.graph)
-        FileCheck().check_not("prim::unchecked_cast").run(test.graph)
-
-        # refinement not optimized out
-        def is_int_tensor(x):
-            scalar = x.item()
-            if isinstance(scalar, int):
-                return scalar + 3
-            else:
-                return 8
-
-        self.checkScript(is_int_tensor, (torch.tensor(2),))
-        self.checkScript(is_int_tensor, (torch.tensor(2.5),))
-        graph = torch.jit.script(is_int_tensor).graph
-        self.run_pass('peephole', graph)
-        FileCheck().check("prim::unchecked_cast").run(graph)
-
     def test_unchecked_cast(self):
         def test(cond):
             # type: (bool)
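A note on the FileCheck assertions used throughout these tests: check(s) requires the string s to appear in the printed IR, check_not(s) requires it to be absent, and check_next(s) requires it on the line immediately after the previous match. For instance, this pattern from the removed short-circuit test asserts that the integer comparison survives while the branching it fed has been folded away:

    from torch.testing import FileCheck
    FileCheck().check("aten::eq").check_not("prim::If").run(redundant_expressions.graph)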