# pytorch/test/dynamo/test_config.py
# Edward Z. Yang bc6ec97e02 Switch dynamic_shapes to True by default (#103597)
# Signed-off-by: Edward Z. Yang <ezyang@meta.com>
#
# Pull Request resolved: https://github.com/pytorch/pytorch/pull/103597
# Approved by: https://github.com/voznesenskym
# 2023-06-15 15:16:20 +00:00
#
# 66 lines
# 2.2 KiB
# Python

# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.utils import disable_cache_limit
# NB: do NOT include this test class in test_dynamic_shapes.py
class ConfigTests(torch._dynamo.test_case.TestCase):
@disable_cache_limit()
def test_no_automatic_dynamic(self):
def fn(a, b):
return a - b * 10
torch._dynamo.reset()
cnt_static = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch(
automatic_dynamic_shapes=False, assume_static_by_default=True
):
opt_fn = torch._dynamo.optimize(cnt_static)(fn)
for i in range(2, 12):
opt_fn(torch.randn(i), torch.randn(i))
self.assertEqual(cnt_static.frame_count, 10)
@disable_cache_limit()
def test_automatic_dynamic(self):
def fn(a, b):
return a - b * 10
torch._dynamo.reset()
cnt_dynamic = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=True
):
opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
# NB: must not do 0, 1 as they specialized
for i in range(2, 12):
opt_fn(torch.randn(i), torch.randn(i))
# two graphs now rather than 10
self.assertEqual(cnt_dynamic.frame_count, 2)
@disable_cache_limit()
def test_no_assume_static_by_default(self):
def fn(a, b):
return a - b * 10
torch._dynamo.reset()
cnt_dynamic = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=False
):
opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
# NB: must not do 0, 1 as they specialized
for i in range(2, 12):
opt_fn(torch.randn(i), torch.randn(i))
# one graph now, as we didn't wait for recompile
self.assertEqual(cnt_dynamic.frame_count, 1)
if __name__ == "__main__":
    # Bind the test runner only when this file is executed as a script.
    from torch._dynamo.test_case import run_tests

    run_tests()