# pytorch/test/dynamo/test_recompiles.py
# Owner(s): ["module: dynamo"]
from unittest.mock import patch

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
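
# This suite exercises torch.compile / Dynamo recompilation behavior: automatic
# dynamic shapes, aliasing guards, the recompile limit (formerly the "cache
# limit"), nn.Parameter shape handling, and autocast-state guards.
#
# A minimal sketch of the counting pattern used throughout (illustrative only,
# not part of the original file; assumes default dynamo configs):
#
#     cnt = torch._dynamo.testing.CompileCounter()
#     fn = torch.compile(lambda t: t * 2, backend=cnt)
#     fn(torch.randn(2))
#     fn(torch.randn(3))  # a new size typically recompiles once under defaults
#     assert cnt.frame_count == 2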


class RecompileTests(torch._dynamo.test_case.TestCase):
    def test_automatic_dynamic_reduce_recompiles(self):
        # Test the counterfactual, lots of recompiles without this config
        def foo(x, y):
            return x * y

        def run_foo_6_times_and_count_recompiles(dynamic=None):
            cnt = torch._dynamo.testing.CompileCounter()

            x = torch.randn([2])
            y = torch.randn([2])
            opt = torch.compile(foo, backend=cnt, dynamic=dynamic)
            opt(x, y)
            x = torch.randn([3])
            y = torch.randn([3])
            opt(x, y)
            x = torch.randn([4])
            y = torch.randn([4])
            opt(x, y)
            opt(x, y)
            x = torch.randn([5])
            y = torch.randn([5])
            opt(x, y)
            opt(x, y)
            x = torch.randn([6])
            y = torch.randn([6])
            opt(x, y)

            return cnt

        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_without_automatic():
            return run_foo_6_times_and_count_recompiles()

        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_with_automatic():
            return run_foo_6_times_and_count_recompiles()

        without = run_without_automatic()
        self.assertEqual(without.frame_count, 5)
        self.assertEqual(without.op_count, 5)
        torch._dynamo.reset()
        without = run_foo_6_times_and_count_recompiles(dynamic=False)
        self.assertEqual(without.frame_count, 5)
        self.assertEqual(without.op_count, 5)
        torch._dynamo.reset()
        with_automatic = run_with_automatic()
        self.assertEqual(with_automatic.frame_count, 2)
        self.assertEqual(with_automatic.op_count, 2)
        torch._dynamo.reset()
        with_automatic = run_foo_6_times_and_count_recompiles(dynamic=None)
        self.assertEqual(with_automatic.frame_count, 2)
        self.assertEqual(with_automatic.op_count, 2)
        torch._dynamo.reset()
        with_dynamic = run_foo_6_times_and_count_recompiles(dynamic=True)
        self.assertEqual(with_dynamic.frame_count, 1)
        self.assertEqual(with_dynamic.op_count, 1)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def test_recompiles_true_false_flop(self):
# Test the counterfactual, lots of recompiles without this config
def foo(x, y):
if x:
return y * 2
else:
return y * y
def run_foo_6_times_and_count_recompiles():
cnt = torch._dynamo.testing.CompileCounter()
opt = torch.compile(foo, backend=cnt, fullgraph=True)
x = True
y = torch.randn([2])
opt(x, y)
x = False
y = torch.randn([2])
opt(x, y)
x = True
y = torch.randn([3])
opt(x, y)
x = True
y = torch.randn([4])
opt(x, y)
x = True
y = torch.randn([5])
opt(x, y)
return cnt
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_without_automatic():
return run_foo_6_times_and_count_recompiles()
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_with_automatic():
return run_foo_6_times_and_count_recompiles()
without = run_without_automatic()
self.assertEqual(without.frame_count, 5)
self.assertEqual(without.op_count, 5)
torch._dynamo.reset()
with_automatic = run_with_automatic()
self.assertEqual(with_automatic.frame_count, 3)
self.assertEqual(with_automatic.op_count, 3)

    def test_automatic_dynamic_tensor_scalar_change(self):
        # Test the counterfactual, lots of recompiles without this config
        def foo(x, y):
            return x * y

        def run_foo_6_times_and_count_recompiles_swap_types():
            cnt = torch._dynamo.testing.CompileCounter()

            x = torch.randn([2])
            y = torch.randn([2])
            opt = torch.compile(foo, backend=cnt)
            opt(x, y)
            x = torch.randn([3])
            y = 3
            opt(x, y)
            x = torch.randn([4])
            y = torch.randn([4])
            opt(x, y)
            opt(x, y)
            x = torch.randn([5])
            y = 4
            opt(x, y)
            opt(x, y)
            x = torch.randn([6])
            y = torch.randn([6])
            opt(x, y)

            return cnt

        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_without_automatic():
            return run_foo_6_times_and_count_recompiles_swap_types()

        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_with_automatic():
            return run_foo_6_times_and_count_recompiles_swap_types()

        without = run_without_automatic()
        self.assertEqual(without.frame_count, 5)
        self.assertEqual(without.op_count, 5)
        torch._dynamo.reset()
        with_automatic = run_with_automatic()
        self.assertEqual(with_automatic.frame_count, 3)
        self.assertEqual(with_automatic.op_count, 3)
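
    # The two aliasing tests below check that Dynamo guards on the aliasing
    # relationships between inputs (including closed-over tensors): as long as
    # the relationship is unchanged no recompile happens, but passing the same
    # tensor where distinct tensors were seen before triggers one.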

    def test_aliasing_guard_failures(self):
        def foo(a, b, c):
            a.add_(b)
            return c + 1

        cnt = torch._dynamo.testing.CompileCounter()
        compiled_foo = torch.compile(foo, backend=cnt, fullgraph=True)

        x = torch.randn([3])
        y = torch.randn([3])
        z = torch.randn([3])

        cmp_result = compiled_foo(
            x.detach().clone(), y.detach().clone(), z.detach().clone()
        )
        eager_result = foo(x.detach().clone(), y.detach().clone(), z.detach().clone())
        self.assertEqual(cmp_result, eager_result)
        self.assertEqual(cnt.frame_count, 1)

        cmp_result = compiled_foo(
            z.detach().clone(), y.detach().clone(), x.detach().clone()
        )
        eager_result = foo(z.detach().clone(), y.detach().clone(), x.detach().clone())
        self.assertEqual(cmp_result, eager_result)
        # No recompile, alias preserved
        self.assertEqual(cnt.frame_count, 1)

        x_clone = x.detach().clone()
        cmp_result = compiled_foo(x_clone, y.detach().clone(), x_clone)
        x_clone = x.detach().clone()
        eager_result = compiled_foo(x_clone, y.detach().clone(), x_clone)
        self.assertEqual(cmp_result, eager_result)
        # Recompile, alias changed
        self.assertEqual(cnt.frame_count, 2)

    def test_aliasing_guard_failures_with_globals(self):
        g1 = torch.randn([3])
        g2 = torch.randn([3])

        def foo(a):
            a.add_(g1)
            return g2 + 1

        cnt = torch._dynamo.testing.CompileCounter()
        compiled_foo = torch.compile(foo, backend=cnt, fullgraph=True)

        z = torch.randn([3])
        cmp_result = compiled_foo(z.detach().clone())
        eager_result = foo(z.detach().clone())
        self.assertEqual(cmp_result, eager_result)
        self.assertEqual(cnt.frame_count, 1)

        g1 = g1.detach().clone()
        cmp_result = compiled_foo(g1)
        g1 = g1.detach().clone()
        eager_result = compiled_foo(g1)
        self.assertEqual(cmp_result, eager_result)
        # Recompile, alias changed
        self.assertEqual(cnt.frame_count, 2)

    def test_dynamic_shape_parameter_recompile(self):
        # Test matrix multiplication with torch.nn.Parameter inputs.
        # With force_parameter_static_shapes (the default), Parameter shapes are
        # treated as static, so every new input shape leads to a recompilation;
        # only disabling it together with automatic_dynamic_shapes avoids that.
        w = torch.nn.Parameter(torch.randn(3, 2))

        def foo(x):
            return x @ w

        def run_foo_6_times_and_count_recompiles():
            cnt = torch._dynamo.testing.CompileCounter()

            opt = torch.compile(foo, backend=cnt, fullgraph=True)

            x = torch.nn.Parameter(torch.randn(1, 3))
            opt(x)
            x = torch.nn.Parameter(torch.randn(10, 3))
            opt(x)
            x = torch.nn.Parameter(torch.randn(11, 3))
            opt(x)
            x = torch.nn.Parameter(torch.randn(15, 3))
            opt(x)
            x = torch.nn.Parameter(torch.randn(15, 3))
            opt(x)

            return cnt

        @patch.object(torch._dynamo.config, "force_parameter_static_shapes", True)
        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_static_comp_default_param():
            return run_foo_6_times_and_count_recompiles()

        @patch.object(torch._dynamo.config, "force_parameter_static_shapes", True)
        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_dynamic_comp_default_param():
            return run_foo_6_times_and_count_recompiles()

        @patch.object(torch._dynamo.config, "force_parameter_static_shapes", False)
        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_static_comp_dynamic_param():
            return run_foo_6_times_and_count_recompiles()

        @patch.object(torch._dynamo.config, "force_parameter_static_shapes", False)
        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_dynamic_comp_dynamic_param():
            return run_foo_6_times_and_count_recompiles()

        torch._dynamo.reset()
        static_comp_default_param = run_static_comp_default_param()
        self.assertEqual(static_comp_default_param.frame_count, 4)
        self.assertEqual(static_comp_default_param.op_count, 4)

        torch._dynamo.reset()
        dynamic_comp_default_param = run_dynamic_comp_default_param()
        self.assertEqual(dynamic_comp_default_param.frame_count, 4)
        self.assertEqual(dynamic_comp_default_param.op_count, 4)

        torch._dynamo.reset()
        static_comp_dynamic_param = run_static_comp_dynamic_param()
        self.assertEqual(static_comp_dynamic_param.frame_count, 4)
        self.assertEqual(static_comp_dynamic_param.op_count, 4)

        torch._dynamo.reset()
        dynamic_comp_dynamic_param = run_dynamic_comp_dynamic_param()
        self.assertEqual(dynamic_comp_dynamic_param.frame_count, 2)
        self.assertEqual(dynamic_comp_dynamic_param.op_count, 2)

    def test_simple_module_recompile(self):
        class SimpleDropout(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dropout = torch.nn.Dropout(0.5)
                self.linear = torch.nn.Linear(10, 1)

            def forward(self, x):
                return self.dropout(self.linear(x))

        model = SimpleDropout()
        x = torch.randn(10)
        counter = torch._dynamo.testing.CompileCounter()
        model = torch.compile(model, backend=counter, fullgraph=True)
        for _ in range(20):
            model.eval()
            model(x)
            model.train()
            model(x)
        self.assertEqual(counter.frame_count, 2)
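
    # The next two tests cover torch._dynamo.config.recompile_limit (previously
    # "cache_limit"): once a frame hits the limit, Dynamo stops recompiling it.
    # Inputs that fail the existing guards fall back to eager execution, while
    # already-compiled entries keep being reused in "run mode", and inner
    # frames are not compiled recursively.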
@patch.object(torch._dynamo.config, "recompile_limit", 2)
def test_no_recursive_compile_after_cache_limit_hit(self):
def f(x, n):
x = x + n
return g(x, n)
def g(x, n):
x = x + n
return h(x, n)
def h(x, n):
return x + n
counter = torch._dynamo.testing.CompileCounter()
opt_f = torch.compile(f, backend=counter, dynamic=False)
for i in range(10):
opt_f(torch.ones(3), i)
self.assertEqual(counter.frame_count, 2)

    def test_automatic_dynamic_on_closed_ints(self):
        def f(x):
            def g(y):
                return y + x

            return g

        counter = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=counter)
        def h(x, g):
            return g(x)

        for i in range(10):
            h(torch.randn(5), f(i))
        self.assertEqual(counter.frame_count, 2)
@patch.object(torch._dynamo.config, "recompile_limit", 2)
def test_run_mode_after_cache_limit_hit(self):
def f(x, n):
x = x + n
if torch._dynamo.is_compiling():
x = x + 1
return g(x, n)
def g(x, n):
x = x + n
if torch._dynamo.is_compiling():
x = x + 2
return x
counter = torch._dynamo.testing.CompileCounter()
opt_f = torch.compile(f, backend=counter, dynamic=False)
# compiles
self.assertEqual(opt_f(torch.ones(3), 0), torch.ones(3) + 3)
self.assertEqual(opt_f(torch.ones(3), 1), torch.ones(3) + 5)
# cache limit hit
self.assertEqual(opt_f(torch.ones(3), 2), torch.ones(3) + 4)
self.assertEqual(opt_f(torch.ones(3), 3), torch.ones(3) + 6)
# run mode
self.assertEqual(opt_f(torch.ones(3), 0), torch.ones(3) + 3)
self.assertEqual(opt_f(torch.ones(3), 1), torch.ones(3) + 5)
self.assertEqual(counter.frame_count, 2)
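
    # The next tests exercise automatic_dynamic_shapes_mark_as: when automatic
    # dynamic shapes kicks in, the newly-dynamic dimension can be marked as
    # "unbacked" or treated size-obliviously, so that sizes 0 and 1 do not
    # force additional specializations (and hence extra recompiles).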

    @torch._dynamo.config.patch(automatic_dynamic_shapes_mark_as="unbacked")
    def test_automatic_dynamic_shapes_mark_as_unbacked(self):
        counter = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=counter)
        def f(x):
            return x * x

        f(torch.randn(3))
        f(torch.randn(2))
        f(torch.randn(1))
        f(torch.randn(0))
        self.assertEqual(counter.frame_count, 2)  # not three or four!

    @torch._dynamo.config.patch(automatic_dynamic_shapes_mark_as="oblivious")
    def test_automatic_dynamic_shapes_mark_as_oblivious(self):
        counter = torch._dynamo.testing.CompileCounter()

        def f(x):
            if x.size(0) < 10:
                return x * 1
            else:
                return x + 10

        opt_f = torch.compile(backend=counter, fullgraph=True)(f)

        for i in [3, 2, 1, 0]:
            self.assertEqual(f(torch.zeros(i)), opt_f(torch.zeros(i)))
        self.assertEqual(counter.frame_count, 2)  # not three or four!

    @torch._dynamo.config.patch(automatic_dynamic_shapes_mark_as="oblivious")
    def test_automatic_dynamic_shapes_mark_as_oblivious_fail_counterfactual(self):
        counter = torch._dynamo.testing.CompileCounter()

        def f(x):
            if x.size(0) < 2:
                return x * 1
            else:
                return x + 10

        opt_f = torch.compile(backend=counter, fullgraph=True)(f)
        opt_f(torch.randn(1))
        with self.assertRaises(torch._dynamo.exc.UserError):
            opt_f(torch.randn(0))
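
    # The autocast tests below check that the ambient autocast state is part of
    # the guard set: entering autocast regions with different dtypes forces a
    # recompile, and autocast queries that get constant-folded into the graph
    # remain correct because the folded value is guarded on.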

    def test_ambient_autocast_recompile(self):
        weights = torch.randn(10, 10)
        counter = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        @torch.compile(backend=counter, fullgraph=True)
        def fn(x):
            return torch.mm(x, weights)

        x = torch.randn(1, 10)
        self.assertEqual(fn(x).dtype, torch.float32)
        with torch.autocast("cpu", torch.float16):
            self.assertEqual(fn(x).dtype, torch.float16)
        with torch.autocast("cpu", torch.bfloat16):
            self.assertEqual(fn(x).dtype, torch.bfloat16)
        # should recompile each time
        self.assertEqual(counter.frame_count, 3)

    def test_autocast_constant_fold(self):
        # Constant-folded autocast queries should stay correct: the compiled
        # graph bakes in the taken branch, which is only sound if the global
        # autocast state is guarded on.
        weights = torch.randn(10, 10)
        counter = torch._dynamo.testing.CompileCounterWithBackend("eager")

        def fn(x):
            if torch.get_autocast_dtype("cpu") == torch.float16:
                x = x + 1
            else:
                x = x - 1
            return torch.mm(x, weights)

        opt_fn = torch.compile(fn, backend=counter, fullgraph=True)
        x = torch.randn(1, 10)
        with torch.autocast("cpu", torch.float16):
            self.assertEqual(fn(x), opt_fn(x))
        with torch.autocast("cpu", torch.bfloat16):
            self.assertEqual(fn(x), opt_fn(x))
        self.assertEqual(counter.frame_count, 2)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()