mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Delete ifdyn and ifunspec combinators (#103596)
Replaced with expect tests for ease of updating. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/103596 Approved by: https://github.com/voznesenskym
This commit is contained in:
parent
e82616d900
commit
ddf4cd69ec
|
|
@ -39,7 +39,7 @@ from torch._dynamo.testing import (
|
|||
unsupported,
|
||||
)
|
||||
|
||||
from torch._dynamo.utils import CompileProfiler, ifdyn, ifdynstaticdefault, ifunspec
|
||||
from torch._dynamo.utils import CompileProfiler, ifdynstaticdefault
|
||||
from torch.ao.quantization import MinMaxObserver
|
||||
from torch.ao.quantization.fake_quantize import FakeQuantize
|
||||
from torch.ao.quantization.qconfig import QConfig
|
||||
|
|
@ -351,8 +351,10 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertTrue(same(ref, res))
|
||||
self.assertEqual(counts.frame_count, 1)
|
||||
|
||||
expected_op_count = ifdynstaticdefault(1, 11)
|
||||
self.assertEqual(counts.op_count, expected_op_count)
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(counts.op_count, """1""")
|
||||
else:
|
||||
self.assertExpectedInline(counts.op_count, """11""")
|
||||
|
||||
@torch._dynamo.config.patch(dynamic_shapes=True)
|
||||
def test_user_defined_binop(self):
|
||||
|
|
@ -377,8 +379,10 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
self.assertTrue(same(ref, res))
|
||||
self.assertEqual(counts.frame_count, 1)
|
||||
expected_op_count = ifdynstaticdefault(1, 4)
|
||||
self.assertEqual(counts.op_count, expected_op_count)
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(counts.op_count, """1""")
|
||||
else:
|
||||
self.assertExpectedInline(counts.op_count, """4""")
|
||||
|
||||
def test_compare_shapes_eq(self):
|
||||
def compare_shapes(a, b, to_list):
|
||||
|
|
@ -941,8 +945,12 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
# output anything and none of the traced operations have side
|
||||
# effects. Probably need better heuristic for bailing on
|
||||
# dynamo if there are no outputs
|
||||
self.assertEqual(cnts.frame_count, ifunspec(1, 0))
|
||||
self.assertEqual(cnts.op_count, ifunspec(2, 0))
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(cnts.frame_count, """0""")
|
||||
self.assertExpectedInline(cnts.op_count, """0""")
|
||||
else:
|
||||
self.assertExpectedInline(cnts.frame_count, """1""")
|
||||
self.assertExpectedInline(cnts.op_count, """2""")
|
||||
|
||||
def test_list_slice_mul(self):
|
||||
def fn(count):
|
||||
|
|
@ -953,8 +961,12 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
self.assertEqual(opt_fn(2), [2, 3] * 4)
|
||||
self.assertEqual(cnts.frame_count, ifunspec(1, 0))
|
||||
self.assertEqual(cnts.op_count, ifunspec(2, 0))
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(cnts.frame_count, """0""")
|
||||
self.assertExpectedInline(cnts.op_count, """0""")
|
||||
else:
|
||||
self.assertExpectedInline(cnts.frame_count, """1""")
|
||||
self.assertExpectedInline(cnts.op_count, """2""")
|
||||
|
||||
def test_tuple_mul(self):
|
||||
def fn(count):
|
||||
|
|
@ -964,8 +976,12 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
self.assertEqual(opt_fn(2), (2, 3) * 4)
|
||||
self.assertEqual(cnts.frame_count, ifunspec(1, 0))
|
||||
self.assertEqual(cnts.op_count, ifunspec(ifdynstaticdefault(2, 2), 0))
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(cnts.frame_count, """0""")
|
||||
self.assertExpectedInline(cnts.op_count, """0""")
|
||||
else:
|
||||
self.assertExpectedInline(cnts.frame_count, """1""")
|
||||
self.assertExpectedInline(cnts.op_count, """2""")
|
||||
|
||||
def test_tuple_mul_with_shape(self):
|
||||
def fn(a):
|
||||
|
|
@ -2377,14 +2393,20 @@ def fn():
|
|||
opt_m(data, correct_ref_id)
|
||||
# Extra op is the recorded equality test (although once
|
||||
# the trace is flattened this is dead!)
|
||||
self.assertEqual(cnts.op_count, ifunspec(3, 2))
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(cnts.op_count, """2""")
|
||||
else:
|
||||
self.assertExpectedInline(cnts.op_count, """3""")
|
||||
|
||||
torch._dynamo.reset()
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
incorrect_ref_id = id(m) + 1
|
||||
opt_m = torch._dynamo.optimize(cnts, nopython=True)(m)
|
||||
opt_m(data, incorrect_ref_id)
|
||||
self.assertEqual(cnts.op_count, ifunspec(2, 1))
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(cnts.op_count, """1""")
|
||||
else:
|
||||
self.assertExpectedInline(cnts.op_count, """2""")
|
||||
|
||||
def test_inline_func_jump_on_tensor_condition(self):
|
||||
def f1(input):
|
||||
|
|
@ -4441,7 +4463,10 @@ def fn():
|
|||
ref = fn(x, y)
|
||||
res = opt_fn(x, y)
|
||||
self.assertTrue(same(ref, res))
|
||||
self.assertEqual(cnt.frame_count, ifunspec(ifdyn(1, 5), 5))
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(cnt.frame_count, """5""")
|
||||
else:
|
||||
self.assertExpectedInline(cnt.frame_count, """1""")
|
||||
|
||||
# specifically test for tensor.attribute -> torch.something()
|
||||
def test_real_imag_tensor_attribute(self):
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ import torch.library
|
|||
from torch import nn
|
||||
from torch._dynamo.debug_utils import same_two_models
|
||||
from torch._dynamo.testing import rand_strided, requires_static_shapes, same
|
||||
from torch._dynamo.utils import ifdyn, ifdynstaticdefault, ifunspec
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
|
|
@ -877,8 +876,12 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
# repeat_interleave is a dynamic shape operator we do not execute/
|
||||
# In the future, we could reduce the frame_count down to 1
|
||||
# by guarding on the exact values of `Tensor repeats` arg
|
||||
self.assertEqual(cnt.frame_count, 4)
|
||||
self.assertEqual(cnt.op_count, ifdyn(16, 10))
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(cnt.frame_count, """4""")
|
||||
self.assertExpectedInline(cnt.op_count, """10""")
|
||||
else:
|
||||
self.assertExpectedInline(cnt.frame_count, """4""")
|
||||
self.assertExpectedInline(cnt.op_count, """16""")
|
||||
|
||||
def test_boxes_len(self):
|
||||
def fn(boxes):
|
||||
|
|
@ -889,8 +892,12 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
||||
self.assertTrue(same(opt_fn(boxes1), boxes1.tensor + 4.0))
|
||||
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(1, 6), 1))
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(cnt.frame_count, """1""")
|
||||
self.assertExpectedInline(cnt.op_count, """1""")
|
||||
else:
|
||||
self.assertExpectedInline(cnt.frame_count, """1""")
|
||||
self.assertExpectedInline(cnt.op_count, """6""")
|
||||
|
||||
def _reformer(self, nopython):
|
||||
input = torch.randn([1, 64, 256])
|
||||
|
|
@ -964,8 +971,12 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
with torch.enable_grad():
|
||||
cnt = self._reformer(nopython=False)
|
||||
# cant inline torch.autograd.Function means graph break
|
||||
self.assertEqual(cnt.frame_count, ifunspec(ifdyn(3, 1), 3))
|
||||
self.assertEqual(cnt.op_count, ifunspec(ifdyn(10, 11), 10))
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(cnt.frame_count, """3""")
|
||||
self.assertExpectedInline(cnt.op_count, """10""")
|
||||
else:
|
||||
self.assertExpectedInline(cnt.frame_count, """3""")
|
||||
self.assertExpectedInline(cnt.op_count, """10""")
|
||||
|
||||
def test_longformer_chunk(self):
|
||||
input1 = torch.randn([1, 4096, 1])
|
||||
|
|
@ -980,10 +991,12 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertTrue(same(opt_fn(input1), correct1))
|
||||
self.assertTrue(same(opt_fn(input2), correct2))
|
||||
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
self.assertEqual(
|
||||
cnt.op_count, ifunspec(35, ifdyn(ifdynstaticdefault(4, 20), 4))
|
||||
)
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(cnt.frame_count, """2""")
|
||||
self.assertExpectedInline(cnt.op_count, """4""")
|
||||
else:
|
||||
self.assertExpectedInline(cnt.frame_count, """2""")
|
||||
self.assertExpectedInline(cnt.op_count, """35""")
|
||||
|
||||
def test_hf_t5_forward(self):
|
||||
input = torch.randn([1, 2048, 512])
|
||||
|
|
@ -993,8 +1006,12 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
opt_model = torch._dynamo.optimize_assert(cnt)(model)
|
||||
self.assertTrue(same(opt_model(input), correct))
|
||||
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(11, 12), 11))
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(cnt.frame_count, """1""")
|
||||
self.assertExpectedInline(cnt.op_count, """11""")
|
||||
else:
|
||||
self.assertExpectedInline(cnt.frame_count, """1""")
|
||||
self.assertExpectedInline(cnt.op_count, """12""")
|
||||
|
||||
def test_module_in_skipfiles(self):
|
||||
model = nn.Linear(10, 10)
|
||||
|
|
@ -1072,7 +1089,10 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
for _ in range(10):
|
||||
self.assertTrue(same(opt_model(a, b, c, d), correct))
|
||||
|
||||
# self.assertEqual(cnt.frame_count, ifdyn(3, 2))
|
||||
# if torch._dynamo.config.assume_static_by_default:
|
||||
# self.assertExpectedInline(cnt.frame_count, """2""")
|
||||
# else:
|
||||
# self.assertExpectedInline(cnt.frame_count, """3""")
|
||||
# TODO(jansel): figure out why op count depends on imports
|
||||
self.assertIn(cnt.op_count, (36, 35, 34, 29, 28, 27))
|
||||
|
||||
|
|
@ -1091,9 +1111,9 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertTrue(same(opt_model(a, b, c, d), correct))
|
||||
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertEqual(cnt.frame_count, ifdyn(2, 4))
|
||||
self.assertExpectedInline(cnt.frame_count, """2""")
|
||||
else:
|
||||
self.assertEqual(cnt.frame_count, ifdyn(3, 6))
|
||||
self.assertExpectedInline(cnt.frame_count, """3""")
|
||||
|
||||
def test_hf_model_output(self):
|
||||
ex = ModelOutput(a=torch.randn(10), b=torch.randn(10), c=torch.randn(10))
|
||||
|
|
@ -1271,9 +1291,13 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
cnt = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
||||
self.assertEqual(opt_fn(cfg), 64)
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
# With unspec int, maximum computation is preserved
|
||||
self.assertEqual(cnt.op_count, ifunspec(4, 3))
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(cnt.frame_count, """1""")
|
||||
self.assertExpectedInline(cnt.op_count, """3""")
|
||||
else:
|
||||
self.assertExpectedInline(cnt.frame_count, """1""")
|
||||
self.assertExpectedInline(cnt.op_count, """4""")
|
||||
|
||||
def test_reformer_sorting(self):
|
||||
x = torch.zeros([1, 12, 4096], dtype=torch.int64)
|
||||
|
|
@ -1283,8 +1307,12 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
cnt = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
||||
self.assertTrue(same(opt_fn(x), correct))
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(14, 27), 14))
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(cnt.frame_count, """1""")
|
||||
self.assertExpectedInline(cnt.op_count, """14""")
|
||||
else:
|
||||
self.assertExpectedInline(cnt.frame_count, """1""")
|
||||
self.assertExpectedInline(cnt.op_count, """27""")
|
||||
|
||||
def test_recursive_map(self):
|
||||
# https://github.com/pytorch/torchdynamo/issues/132
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import torch._dynamo.test_case
|
|||
import torch._dynamo.testing
|
||||
from torch._dynamo import config
|
||||
from torch._dynamo.testing import unsupported
|
||||
from torch._dynamo.utils import ifdyn, ifdynstaticdefault, ifunspec
|
||||
from torch._dynamo.utils import ifdynstaticdefault
|
||||
|
||||
globalmod = torch.nn.ReLU()
|
||||
|
||||
|
|
@ -328,7 +328,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
|
|||
# means we fail to unroll the loop.
|
||||
# TODO: Consider forcing specialization when we iterate over
|
||||
# the loop
|
||||
self._common(fn, 2, ifunspec(1, 4))
|
||||
self._common(fn, 2, ifdynstaticdefault(4, 1))
|
||||
|
||||
def test_restore_range_iter(self):
|
||||
def fn(a, b):
|
||||
|
|
@ -598,7 +598,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
|
|||
b = b + x * i
|
||||
return b
|
||||
|
||||
self._common(fn, 1, ifdyn(ifdynstaticdefault(2, 7), 2))
|
||||
self._common(fn, 1, ifdynstaticdefault(2, 7))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -35,7 +35,9 @@ cache_size_limit = 64
|
|||
|
||||
# whether or not to specialize on int inputs. This only has an effect with
|
||||
# dynamic_shapes; when dynamic_shapes is False, we ALWAYS specialize on int
|
||||
# inputs
|
||||
# inputs. Note that assume_static_by_default will also cause ints to get
|
||||
# specialized, so this is mostly useful for export, where we want inputs
|
||||
# to be dynamic, but accesses to ints should NOT get promoted into inputs.
|
||||
specialize_int = False
|
||||
|
||||
# Assume these functions return constants
|
||||
|
|
|
|||
|
|
@ -1398,13 +1398,6 @@ def fqn(obj: Any):
|
|||
return f"{obj.__module__}.{obj.__qualname__}"
|
||||
|
||||
|
||||
def ifdyn(count1, count2):
|
||||
if torch._dynamo.config.dynamic_shapes:
|
||||
return count1
|
||||
else:
|
||||
return count2
|
||||
|
||||
|
||||
def ifdynstaticdefault(count1, count2):
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
return count1
|
||||
|
|
@ -1412,13 +1405,6 @@ def ifdynstaticdefault(count1, count2):
|
|||
return count2
|
||||
|
||||
|
||||
def ifunspec(count1, count2):
|
||||
if torch._dynamo.config.dynamic_shapes and not torch._dynamo.config.specialize_int:
|
||||
return count1
|
||||
else:
|
||||
return count2
|
||||
|
||||
|
||||
def import_submodule(mod: types.ModuleType):
|
||||
"""
|
||||
Ensure all the files in a given submodule are imported
|
||||
|
|
|
|||
|
|
@ -2159,6 +2159,9 @@ class TestCase(expecttest.TestCase):
|
|||
def enforceNonDefaultStream(self):
|
||||
return CudaNonDefaultStream()
|
||||
|
||||
def assertExpectedInline(self, actual, expect, skip=0):
|
||||
return super().assertExpectedInline(actual if isinstance(actual, str) else str(actual), expect, skip + 1)
|
||||
|
||||
def assertLogs(self, logger=None, level=None):
|
||||
if logger is None:
|
||||
logger = logging.getLogger("torch")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user