[dynamo] Remove L scoping for recompilation messages (#148917)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148917
Approved by: https://github.com/williamwen42
Authored by Animesh Jain on 2025-03-10 20:55:41 -07:00; committed by PyTorch MergeBot
parent 992838e702
commit f1787ee0f7
5 changed files with 48 additions and 39 deletions
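
The change itself is mechanical: guard-failure strings that used to spell out Dynamo's local-scope source, e.g. tensor 'L['x']' size mismatch, are now rewritten to the bare variable name, e.g. tensor 'x' size mismatch, before being surfaced as recompilation messages. A minimal, self-contained sketch of the rewrite (it mirrors the strip_local_scope helper added in the last hunk of this diff; nothing here imports torch):

import re

def strip_local_scope(s: str) -> str:
    # Drop the L['...'] / L["..."] wrapper and keep only the inner name.
    return re.sub(r"L\[\s*['\"](.*?)['\"]\s*\]", r"\1", s)

print(strip_local_scope("tensor 'L['x']' size mismatch at index 0. expected 2, actual 5"))
# -> tensor 'x' size mismatch at index 0. expected 2, actual 5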


@@ -318,7 +318,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
compare_equal_outs_and_grads(self, F(), fxy, (x, y))
compare_equal_outs_and_grads(self, F(), fxy, (x, z))
self.assertIn(
"""tensor 'L['y']' requires_grad mismatch. expected requires_grad=1""",
"""tensor 'y' requires_grad mismatch. expected requires_grad=1""",
failure_reason,
)
@@ -436,7 +436,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
fxx(x3, x3)
fxx(x4, y4)
self.assertEqual(cc.frame_count, 2)
self.assertIn("""L['x'] is L['y']""", failure_reason)
self.assertIn("""x is y""", failure_reason)
@patch("torch._functorch.config.debug_assert", True)
def test_arg_dupe_via_dynamo_recompiles_many_args_param_non_tensor_arg(self):
@@ -470,7 +470,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
f(a2, b2, 2, 2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
"""a is b""",
failure_reason,
)
@@ -487,7 +487,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
f(c3, c3, 3, 3)
f(c4, d4, 3, 3)
self.assertEqual(cc.frame_count, 2)
self.assertIn("""L['a'] is L['b']""", failure_reason)
self.assertIn("""a is b""", failure_reason)
@patch("torch._functorch.config.debug_assert", True)
def test_arg_dupe_via_dynamo_recompiles_many_with_global(self):
@@ -524,7 +524,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
f(a2, b2, 2, 2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
"""a is b""",
failure_reason,
)
@@ -560,7 +560,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
f([3, 2, 1], [4, 5, 6], a2, b2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
"""a is b""",
failure_reason,
)
@@ -610,7 +610,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
f(a2, b2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
"""a is b""",
failure_reason,
)
@@ -627,7 +627,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
f(c3, c3)
f(c4, d4)
self.assertEqual(cc.frame_count, 2)
self.assertIn("""L['a'] is L['b']""", failure_reason)
self.assertIn("""a is b""", failure_reason)
@patch("torch._functorch.config.debug_assert", True)
def test_arg_dupe_via_dynamo_recompiles_many_args(self):
@@ -659,7 +659,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
f(a2, b2, b2, b2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
"""a is b""",
failure_reason,
)
@@ -676,7 +676,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
f(a3, b3, c3, c3)
f(a4, b4, c4, d4)
self.assertEqual(cc.frame_count, 2)
self.assertIn("""L['c'] is L['d']""", failure_reason)
self.assertIn("""c is d""", failure_reason)
def test_alias_inputs(self):
def fn():
@@ -1523,7 +1523,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
)
self.assertExpectedInline(
guard_failure,
"""0/0: check_overlapping(overlapping=[L['args'][1], L['args'][2]], non_overlapping=[L['args'][0]])""",
"""0/0: check_overlapping(overlapping=[args[1], args[2]], non_overlapping=[args[0]])""",
)
def test_different_inputs_overlapping_set_with_mutation(self):
@@ -1546,7 +1546,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
)
self.assertExpectedInline(
guard_failure,
"""0/0: check_overlapping(overlapping=[L['a'], L['b']], non_overlapping=[L['c'], L['d']])""",
"""0/0: check_overlapping(overlapping=[a, b], non_overlapping=[c, d])""",
)
def _test_no_storage_overlap_guards(self, f, argsfn):


@@ -897,7 +897,7 @@ class MiscTests(torch._inductor.test_case.TestCase):
opt_fn(torch.randn([3, 4]))
opt_fn(torch.randn([4, 3]))
self.assertIn(
"""tensor 'L['a']' size mismatch at index 0. expected 3, actual 4""",
"""tensor 'a' size mismatch at index 0. expected 3, actual 4""",
guard_failure.reason,
)
@@ -6657,11 +6657,11 @@ utils_device.CURRENT_DEVICE == None""".split(
first_guard_failure = guard_failure[0].partition("\n")[0]
if torch._dynamo.config.assume_static_by_default:
self.assertIn(
"""tensor 'L['x']' size mismatch at index 0. expected 2, actual 5""",
"""tensor 'x' size mismatch at index 0. expected 2, actual 5""",
first_guard_failure,
)
else:
self.assertIn("""L['x'].size()[0] < 3""", first_guard_failure)
self.assertIn("""x.size()[0] < 3""", first_guard_failure)
def test_guard_failure_fn2(self):
def fn(x, y):
@@ -6690,7 +6690,7 @@ utils_device.CURRENT_DEVICE == None""".split(
if torch._dynamo.config.assume_static_by_default:
self.assertIn(
"""tensor 'L['x']' size mismatch at index 0. expected 2, actual 3""",
"""tensor 'x' size mismatch at index 0. expected 2, actual 3""",
guard_failure[0],
)
else:
@@ -6725,7 +6725,7 @@ utils_device.CURRENT_DEVICE == None""".split(
# guard is expected for both static and dynamic shapes
self.assertTrue(guard_failure is not None)
self.assertIn(
"""len(L['x']) == 10""",
"""len(x) == 10""",
guard_failure[0],
)
@@ -6782,7 +6782,7 @@ utils_device.CURRENT_DEVICE == None""".split(
opt_out = opt_fn(args2)
self.assertEqual(out, opt_out)
self.assertTrue(guard_failure is not None)
self.assertIn("""tensor 'L['x']' size mismatch at index 0""", guard_failure[0])
self.assertIn("""tensor 'x' size mismatch at index 0""", guard_failure[0])
def test_restore_graphstate(self):
# This function does some guard accumulation,
@@ -6851,9 +6851,7 @@ utils_device.CURRENT_DEVICE == None""".split(
x = torch.randn(3)
self.assertEqual(fn(x), opt_fn(x))
self.assertTrue(guard_failure is not None)
self.assertIn(
"""tensor 'L['rank']' size mismatch at index 0""", guard_failure[0]
)
self.assertIn("""tensor 'rank' size mismatch at index 0""", guard_failure[0])
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
def test_symint_as_device_kwarg_non_strict_export(self):


@@ -161,28 +161,26 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
cache_fail_test(
a,
a[0:2, :, :],
"tensor 'L['a']' size mismatch at index 0. expected 3, actual 2",
"tensor 'a' size mismatch at index 0. expected 3, actual 2",
)
cache_fail_test(
a,
a.clone().as_strided((3, 4, 5), stride=(1, 3, 12)),
"tensor 'L['a']' stride mismatch at index 0. expected 20, actual 1",
"tensor 'a' stride mismatch at index 0. expected 20, actual 1",
)
cache_fail_test(
a, a[0, :, :], "tensor 'L['a']' rank mismatch. expected 3, actual 2"
)
cache_fail_test(a, a.to("meta"), "tensor 'L['a']' dispatch key set mismatch.")
cache_fail_test(a, a[0, :, :], "tensor 'a' rank mismatch. expected 3, actual 2")
cache_fail_test(a, a.to("meta"), "tensor 'a' dispatch key set mismatch.")
cache_fail_test(
a,
a.to(torch.float16),
"tensor 'L['a']' dtype mismatch. expected Float, actual Half",
"tensor 'a' dtype mismatch. expected Float, actual Half",
)
a_grad = a.clone()
a_grad.requires_grad = True
cache_fail_test(
a,
a_grad,
"tensor 'L['a']' requires_grad mismatch. expected requires_grad=0",
"tensor 'a' requires_grad mismatch. expected requires_grad=0",
)
def test_mismatched_type(self):
@@ -201,7 +199,7 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
opt_func(a, 1)
self.assert_single_log_contains(
logs,
"expected type of 'L['b']' to be a tensor type, ' but found <class 'int'>",
"expected type of 'b' to be a tensor type, ' but found <class 'int'>",
)
@torch._dynamo.config.patch(recompile_limit=1, fail_on_recompile_limit_hit=True)
@@ -237,10 +235,10 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
failure_str = "\n".join(failure_reasons)
for line in """\
tensor 'L['x']' size mismatch at index 0. expected 11, actual 12
tensor 'L['x']' size mismatch at index 0. expected 10, actual 12
tensor 'L['x']' size mismatch at index 0. expected 9, actual 12
tensor 'L['x']' size mismatch at index 0. expected 8, actual 12""".split(
tensor 'x' size mismatch at index 0. expected 11, actual 12
tensor 'x' size mismatch at index 0. expected 10, actual 12
tensor 'x' size mismatch at index 0. expected 9, actual 12
tensor 'x' size mismatch at index 0. expected 8, actual 12""".split(
"\n"
):
self.assertIn(
@@ -278,7 +276,7 @@ tensor 'L['x']' size mismatch at index 0. expected 8, actual 12""".split(
opt_f([7, 8])
for line in """\
len(L['x']) == 3""".split(
len(x) == 3""".split(
"\n"
):
self.assertIn(line, filter_reasons())
@@ -287,8 +285,8 @@ len(L['x']) == 3""".split(
opt_f([9])
for line in """\
len(L['x']) == 2
len(L['x']) == 3""".split(
len(x) == 2
len(x) == 3""".split(
"\n"
):
self.assertIn(line, filter_reasons())


@@ -4429,7 +4429,7 @@ class CommonTemplate:
gemm_opt(x1, y1)
self.assertTrue(failed_guard is not None)
self.assertTrue(
"tensor 'L['x']' Tensor device index mismatch. Expected device index to be"
"tensor 'x' Tensor device index mismatch. Expected device index to be"
in failed_guard.reason
)


@@ -2846,6 +2846,19 @@ def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
return [f"Duplicate tensors found: {reason}"]
def strip_local_scope(s: str) -> str:
"""
Replace occurrences of L[...] with just the inner content.
Handles both single and double quotes.
This is used to generate user-friendly recompilation messages.
"""
import re
pattern = r"L\[\s*['\"](.*?)['\"]\s*\]"
return re.sub(pattern, r"\1", s)
def get_guard_fail_reason_helper(
guard_manager: GuardFn,
f_locals: dict[str, object],
@@ -2910,7 +2923,7 @@ def get_guard_fail_reason_helper(
break
reason_str = f"{compile_id}: " + "; ".join(reasons)
return reason_str
return strip_local_scope(reason_str)
def get_guard_fail_reason(
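
The helper is applied only at message-formatting time, in the return value of get_guard_fail_reason_helper, so guard construction and the guarded sources themselves are unchanged; only the text shown to users loses the L[...] wrapper. For reference, a few of the before/after pairs asserted in the updated tests above can be reproduced with the same regex (illustrative only, no torch required):

import re

def strip_local_scope(s: str) -> str:
    # Same pattern as the helper added above.
    return re.sub(r"L\[\s*['\"](.*?)['\"]\s*\]", r"\1", s)

for old in (
    "L['x'] is L['y']",
    "len(L['x']) == 3",
    "0/0: check_overlapping(overlapping=[L['args'][1], L['args'][2]], non_overlapping=[L['args'][0]])",
):
    print(strip_local_scope(old))
# -> x is y
# -> len(x) == 3
# -> 0/0: check_overlapping(overlapping=[args[1], args[2]], non_overlapping=[args[0]])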