Delete ifdyn and ifunspec combinators (#103596)

Replaced with expect tests for ease of updating.
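The change applies the same pattern throughout the PR; a representative before/after sketch (lifted from the first test hunk below, where `counts` is a CompileCounter):

# Before: a combinator picked the expected value based on the dynamo config.
expected_op_count = ifdynstaticdefault(1, 11)
self.assertEqual(counts.op_count, expected_op_count)

# After: branch on the config explicitly and use an inline expect test;
# stale values can be rewritten automatically by re-running the test with
# EXPECTTEST_ACCEPT=1.
if torch._dynamo.config.assume_static_by_default:
    self.assertExpectedInline(counts.op_count, """1""")
else:
    self.assertExpectedInline(counts.op_count, """11""")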

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103596
Approved by: https://github.com/voznesenskym
Edward Z. Yang 2023-06-14 11:57:11 -07:00 committed by PyTorch MergeBot
parent e82616d900
commit ddf4cd69ec
6 changed files with 96 additions and 52 deletions


@@ -39,7 +39,7 @@ from torch._dynamo.testing import (
unsupported,
)
from torch._dynamo.utils import CompileProfiler, ifdyn, ifdynstaticdefault, ifunspec
from torch._dynamo.utils import CompileProfiler, ifdynstaticdefault
from torch.ao.quantization import MinMaxObserver
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.qconfig import QConfig
@@ -351,8 +351,10 @@ class MiscTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(ref, res))
self.assertEqual(counts.frame_count, 1)
expected_op_count = ifdynstaticdefault(1, 11)
self.assertEqual(counts.op_count, expected_op_count)
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(counts.op_count, """1""")
else:
self.assertExpectedInline(counts.op_count, """11""")
@torch._dynamo.config.patch(dynamic_shapes=True)
def test_user_defined_binop(self):
@@ -377,8 +379,10 @@ class MiscTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(ref, res))
self.assertEqual(counts.frame_count, 1)
expected_op_count = ifdynstaticdefault(1, 4)
self.assertEqual(counts.op_count, expected_op_count)
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(counts.op_count, """1""")
else:
self.assertExpectedInline(counts.op_count, """4""")
def test_compare_shapes_eq(self):
def compare_shapes(a, b, to_list):
@@ -941,8 +945,12 @@ class MiscTests(torch._dynamo.test_case.TestCase):
# output anything and none of the traced operations have side
# effects. Probably need better heuristic for bailing on
# dynamo if there are no outputs
self.assertEqual(cnts.frame_count, ifunspec(1, 0))
self.assertEqual(cnts.op_count, ifunspec(2, 0))
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(cnts.frame_count, """0""")
self.assertExpectedInline(cnts.op_count, """0""")
else:
self.assertExpectedInline(cnts.frame_count, """1""")
self.assertExpectedInline(cnts.op_count, """2""")
def test_list_slice_mul(self):
def fn(count):
@@ -953,8 +961,12 @@ class MiscTests(torch._dynamo.test_case.TestCase):
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
self.assertEqual(opt_fn(2), [2, 3] * 4)
self.assertEqual(cnts.frame_count, ifunspec(1, 0))
self.assertEqual(cnts.op_count, ifunspec(2, 0))
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(cnts.frame_count, """0""")
self.assertExpectedInline(cnts.op_count, """0""")
else:
self.assertExpectedInline(cnts.frame_count, """1""")
self.assertExpectedInline(cnts.op_count, """2""")
def test_tuple_mul(self):
def fn(count):
@@ -964,8 +976,12 @@ class MiscTests(torch._dynamo.test_case.TestCase):
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
self.assertEqual(opt_fn(2), (2, 3) * 4)
self.assertEqual(cnts.frame_count, ifunspec(1, 0))
self.assertEqual(cnts.op_count, ifunspec(ifdynstaticdefault(2, 2), 0))
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(cnts.frame_count, """0""")
self.assertExpectedInline(cnts.op_count, """0""")
else:
self.assertExpectedInline(cnts.frame_count, """1""")
self.assertExpectedInline(cnts.op_count, """2""")
def test_tuple_mul_with_shape(self):
def fn(a):
@@ -2377,14 +2393,20 @@ def fn():
opt_m(data, correct_ref_id)
# Extra op is the recorded equality test (although once
# the trace is flattened this is dead!)
self.assertEqual(cnts.op_count, ifunspec(3, 2))
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(cnts.op_count, """2""")
else:
self.assertExpectedInline(cnts.op_count, """3""")
torch._dynamo.reset()
cnts = torch._dynamo.testing.CompileCounter()
incorrect_ref_id = id(m) + 1
opt_m = torch._dynamo.optimize(cnts, nopython=True)(m)
opt_m(data, incorrect_ref_id)
self.assertEqual(cnts.op_count, ifunspec(2, 1))
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(cnts.op_count, """1""")
else:
self.assertExpectedInline(cnts.op_count, """2""")
def test_inline_func_jump_on_tensor_condition(self):
def f1(input):
@@ -4441,7 +4463,10 @@ def fn():
ref = fn(x, y)
res = opt_fn(x, y)
self.assertTrue(same(ref, res))
self.assertEqual(cnt.frame_count, ifunspec(ifdyn(1, 5), 5))
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(cnt.frame_count, """5""")
else:
self.assertExpectedInline(cnt.frame_count, """1""")
# specifically test for tensor.attribute -> torch.something()
def test_real_imag_tensor_attribute(self):


@@ -31,7 +31,6 @@ import torch.library
from torch import nn
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import rand_strided, requires_static_shapes, same
from torch._dynamo.utils import ifdyn, ifdynstaticdefault, ifunspec
from torch.nn import functional as F
@@ -877,8 +876,12 @@ class ReproTests(torch._dynamo.test_case.TestCase):
# repeat_interleave is a dynamic shape operator we do not execute/
# In the future, we could reduce the frame_count down to 1
# by guarding on the exact values of `Tensor repeats` arg
self.assertEqual(cnt.frame_count, 4)
self.assertEqual(cnt.op_count, ifdyn(16, 10))
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(cnt.frame_count, """4""")
self.assertExpectedInline(cnt.op_count, """10""")
else:
self.assertExpectedInline(cnt.frame_count, """4""")
self.assertExpectedInline(cnt.op_count, """16""")
def test_boxes_len(self):
def fn(boxes):
@@ -889,8 +892,12 @@ class ReproTests(torch._dynamo.test_case.TestCase):
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
self.assertTrue(same(opt_fn(boxes1), boxes1.tensor + 4.0))
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(1, 6), 1))
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(cnt.frame_count, """1""")
self.assertExpectedInline(cnt.op_count, """1""")
else:
self.assertExpectedInline(cnt.frame_count, """1""")
self.assertExpectedInline(cnt.op_count, """6""")
def _reformer(self, nopython):
input = torch.randn([1, 64, 256])
@@ -964,8 +971,12 @@ class ReproTests(torch._dynamo.test_case.TestCase):
with torch.enable_grad():
cnt = self._reformer(nopython=False)
# cant inline torch.autograd.Function means graph break
self.assertEqual(cnt.frame_count, ifunspec(ifdyn(3, 1), 3))
self.assertEqual(cnt.op_count, ifunspec(ifdyn(10, 11), 10))
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(cnt.frame_count, """3""")
self.assertExpectedInline(cnt.op_count, """10""")
else:
self.assertExpectedInline(cnt.frame_count, """3""")
self.assertExpectedInline(cnt.op_count, """10""")
def test_longformer_chunk(self):
input1 = torch.randn([1, 4096, 1])
@@ -980,10 +991,12 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(opt_fn(input1), correct1))
self.assertTrue(same(opt_fn(input2), correct2))
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(
cnt.op_count, ifunspec(35, ifdyn(ifdynstaticdefault(4, 20), 4))
)
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(cnt.frame_count, """2""")
self.assertExpectedInline(cnt.op_count, """4""")
else:
self.assertExpectedInline(cnt.frame_count, """2""")
self.assertExpectedInline(cnt.op_count, """35""")
def test_hf_t5_forward(self):
input = torch.randn([1, 2048, 512])
@@ -993,8 +1006,12 @@ class ReproTests(torch._dynamo.test_case.TestCase):
opt_model = torch._dynamo.optimize_assert(cnt)(model)
self.assertTrue(same(opt_model(input), correct))
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(11, 12), 11))
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(cnt.frame_count, """1""")
self.assertExpectedInline(cnt.op_count, """11""")
else:
self.assertExpectedInline(cnt.frame_count, """1""")
self.assertExpectedInline(cnt.op_count, """12""")
def test_module_in_skipfiles(self):
model = nn.Linear(10, 10)
@@ -1072,7 +1089,10 @@ class ReproTests(torch._dynamo.test_case.TestCase):
for _ in range(10):
self.assertTrue(same(opt_model(a, b, c, d), correct))
# self.assertEqual(cnt.frame_count, ifdyn(3, 2))
# if torch._dynamo.config.assume_static_by_default:
# self.assertExpectedInline(cnt.frame_count, """2""")
# else:
# self.assertExpectedInline(cnt.frame_count, """3""")
# TODO(jansel): figure out why op count depends on imports
self.assertIn(cnt.op_count, (36, 35, 34, 29, 28, 27))
@@ -1091,9 +1111,9 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(opt_model(a, b, c, d), correct))
if torch._dynamo.config.assume_static_by_default:
self.assertEqual(cnt.frame_count, ifdyn(2, 4))
self.assertExpectedInline(cnt.frame_count, """2""")
else:
self.assertEqual(cnt.frame_count, ifdyn(3, 6))
self.assertExpectedInline(cnt.frame_count, """3""")
def test_hf_model_output(self):
ex = ModelOutput(a=torch.randn(10), b=torch.randn(10), c=torch.randn(10))
@@ -1271,9 +1291,13 @@ class ReproTests(torch._dynamo.test_case.TestCase):
cnt = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
self.assertEqual(opt_fn(cfg), 64)
self.assertEqual(cnt.frame_count, 1)
# With unspec int, maximum computation is preserved
self.assertEqual(cnt.op_count, ifunspec(4, 3))
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(cnt.frame_count, """1""")
self.assertExpectedInline(cnt.op_count, """3""")
else:
self.assertExpectedInline(cnt.frame_count, """1""")
self.assertExpectedInline(cnt.op_count, """4""")
def test_reformer_sorting(self):
x = torch.zeros([1, 12, 4096], dtype=torch.int64)
@@ -1283,8 +1307,12 @@ class ReproTests(torch._dynamo.test_case.TestCase):
cnt = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
self.assertTrue(same(opt_fn(x), correct))
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(14, 27), 14))
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(cnt.frame_count, """1""")
self.assertExpectedInline(cnt.op_count, """14""")
else:
self.assertExpectedInline(cnt.frame_count, """1""")
self.assertExpectedInline(cnt.op_count, """27""")
def test_recursive_map(self):
# https://github.com/pytorch/torchdynamo/issues/132


@@ -8,7 +8,7 @@ import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import config
from torch._dynamo.testing import unsupported
from torch._dynamo.utils import ifdyn, ifdynstaticdefault, ifunspec
from torch._dynamo.utils import ifdynstaticdefault
globalmod = torch.nn.ReLU()
@@ -328,7 +328,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
# means we fail to unroll the loop.
# TODO: Consider forcing specialization when we iterate over
# the loop
self._common(fn, 2, ifunspec(1, 4))
self._common(fn, 2, ifdynstaticdefault(4, 1))
def test_restore_range_iter(self):
def fn(a, b):
@@ -598,7 +598,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
b = b + x * i
return b
self._common(fn, 1, ifdyn(ifdynstaticdefault(2, 7), 2))
self._common(fn, 1, ifdynstaticdefault(2, 7))
if __name__ == "__main__":


@@ -35,7 +35,9 @@ cache_size_limit = 64
# whether or not to specialize on int inputs. This only has an effect with
# dynamic_shapes; when dynamic_shapes is False, we ALWAYS specialize on int
# inputs
# inputs. Note that assume_static_by_default will also cause ints to get
# specialized, so this is mostly useful for export, where we want inputs
# to be dynamic, but accesses to ints should NOT get promoted into inputs.
specialize_int = False
# Assume these functions return constants
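A rough illustration of the interaction the new comment describes; this is a hedged sketch (the function, shapes, and backend choice are illustrative, only the config flags appear in this diff):

import torch
import torch._dynamo

def fn(x, n):
    # n is a plain Python int argument; whether it stays a baked-in constant
    # or becomes a symbolic input depends on specialize_int together with
    # assume_static_by_default.
    return x * n

with torch._dynamo.config.patch(
    dynamic_shapes=True,
    assume_static_by_default=False,
    specialize_int=False,
):
    opt_fn = torch._dynamo.optimize("eager")(fn)
    opt_fn(torch.randn(4), 3)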


@@ -1398,13 +1398,6 @@ def fqn(obj: Any):
return f"{obj.__module__}.{obj.__qualname__}"
def ifdyn(count1, count2):
if torch._dynamo.config.dynamic_shapes:
return count1
else:
return count2
def ifdynstaticdefault(count1, count2):
if torch._dynamo.config.assume_static_by_default:
return count1
@@ -1412,13 +1405,6 @@ def ifdynstaticdefault(count1, count2):
return count2
def ifunspec(count1, count2):
if torch._dynamo.config.dynamic_shapes and not torch._dynamo.config.specialize_int:
return count1
else:
return count2
def import_submodule(mod: types.ModuleType):
"""
Ensure all the files in a given submodule are imported
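For reference, the surviving helper still does a plain two-way selection, so a call such as `self._common(fn, 1, ifdynstaticdefault(2, 7))` in the subgraph tests above expects 2 ops when assume_static_by_default is set and 7 otherwise. A minimal worked sketch:

import torch._dynamo
from torch._dynamo.utils import ifdynstaticdefault

# The helper reads the config flag at call time, so patching it flips the result.
with torch._dynamo.config.patch(assume_static_by_default=True):
    assert ifdynstaticdefault(2, 7) == 2
with torch._dynamo.config.patch(assume_static_by_default=False):
    assert ifdynstaticdefault(2, 7) == 7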


@@ -2159,6 +2159,9 @@ class TestCase(expecttest.TestCase):
def enforceNonDefaultStream(self):
return CudaNonDefaultStream()
def assertExpectedInline(self, actual, expect, skip=0):
return super().assertExpectedInline(actual if isinstance(actual, str) else str(actual), expect, skip + 1)
def assertLogs(self, logger=None, level=None):
if logger is None:
logger = logging.getLogger("torch")
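With the override above, the dynamo tests can pass raw counter values straight through: non-string actuals are str()-ified before delegating to expecttest, and the `skip + 1` keeps expecttest pointing at the caller's source line when it rewrites expected values. A minimal sketch of the resulting usage (the test class and function here are hypothetical):

import torch
import torch._dynamo.test_case
import torch._dynamo.testing


class ExampleTest(torch._dynamo.test_case.TestCase):
    def test_op_count(self):
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(lambda x: x + 1)
        opt_fn(torch.randn(3))
        # cnts.op_count is an int; the override converts it to "1" before
        # comparing, and running with EXPECTTEST_ACCEPT=1 rewrites the quoted
        # value in place if it ever goes stale.
        self.assertExpectedInline(cnts.op_count, """1""")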