mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Remove L scoping for recompilation messages (#148917)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148917 Approved by: https://github.com/williamwen42
This commit is contained in:
parent
992838e702
commit
f1787ee0f7
|
|
@ -318,7 +318,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
|||
compare_equal_outs_and_grads(self, F(), fxy, (x, y))
|
||||
compare_equal_outs_and_grads(self, F(), fxy, (x, z))
|
||||
self.assertIn(
|
||||
"""tensor 'L['y']' requires_grad mismatch. expected requires_grad=1""",
|
||||
"""tensor 'y' requires_grad mismatch. expected requires_grad=1""",
|
||||
failure_reason,
|
||||
)
|
||||
|
||||
|
|
@ -436,7 +436,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
|||
fxx(x3, x3)
|
||||
fxx(x4, y4)
|
||||
self.assertEqual(cc.frame_count, 2)
|
||||
self.assertIn("""L['x'] is L['y']""", failure_reason)
|
||||
self.assertIn("""x is y""", failure_reason)
|
||||
|
||||
@patch("torch._functorch.config.debug_assert", True)
|
||||
def test_arg_dupe_via_dynamo_recompiles_many_args_param_non_tensor_arg(self):
|
||||
|
|
@ -470,7 +470,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
|||
f(a2, b2, 2, 2)
|
||||
self.assertEqual(cc.frame_count, 2)
|
||||
self.assertIn(
|
||||
"""L['a'] is L['b']""",
|
||||
"""a is b""",
|
||||
failure_reason,
|
||||
)
|
||||
|
||||
|
|
@ -487,7 +487,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
|||
f(c3, c3, 3, 3)
|
||||
f(c4, d4, 3, 3)
|
||||
self.assertEqual(cc.frame_count, 2)
|
||||
self.assertIn("""L['a'] is L['b']""", failure_reason)
|
||||
self.assertIn("""a is b""", failure_reason)
|
||||
|
||||
@patch("torch._functorch.config.debug_assert", True)
|
||||
def test_arg_dupe_via_dynamo_recompiles_many_with_global(self):
|
||||
|
|
@ -524,7 +524,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
|||
f(a2, b2, 2, 2)
|
||||
self.assertEqual(cc.frame_count, 2)
|
||||
self.assertIn(
|
||||
"""L['a'] is L['b']""",
|
||||
"""a is b""",
|
||||
failure_reason,
|
||||
)
|
||||
|
||||
|
|
@ -560,7 +560,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
|||
f([3, 2, 1], [4, 5, 6], a2, b2)
|
||||
self.assertEqual(cc.frame_count, 2)
|
||||
self.assertIn(
|
||||
"""L['a'] is L['b']""",
|
||||
"""a is b""",
|
||||
failure_reason,
|
||||
)
|
||||
|
||||
|
|
@ -610,7 +610,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
|||
f(a2, b2)
|
||||
self.assertEqual(cc.frame_count, 2)
|
||||
self.assertIn(
|
||||
"""L['a'] is L['b']""",
|
||||
"""a is b""",
|
||||
failure_reason,
|
||||
)
|
||||
|
||||
|
|
@ -627,7 +627,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
|||
f(c3, c3)
|
||||
f(c4, d4)
|
||||
self.assertEqual(cc.frame_count, 2)
|
||||
self.assertIn("""L['a'] is L['b']""", failure_reason)
|
||||
self.assertIn("""a is b""", failure_reason)
|
||||
|
||||
@patch("torch._functorch.config.debug_assert", True)
|
||||
def test_arg_dupe_via_dynamo_recompiles_many_args(self):
|
||||
|
|
@ -659,7 +659,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
|||
f(a2, b2, b2, b2)
|
||||
self.assertEqual(cc.frame_count, 2)
|
||||
self.assertIn(
|
||||
"""L['a'] is L['b']""",
|
||||
"""a is b""",
|
||||
failure_reason,
|
||||
)
|
||||
|
||||
|
|
@ -676,7 +676,7 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
|||
f(a3, b3, c3, c3)
|
||||
f(a4, b4, c4, d4)
|
||||
self.assertEqual(cc.frame_count, 2)
|
||||
self.assertIn("""L['c'] is L['d']""", failure_reason)
|
||||
self.assertIn("""c is d""", failure_reason)
|
||||
|
||||
def test_alias_inputs(self):
|
||||
def fn():
|
||||
|
|
@ -1523,7 +1523,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
|
|||
)
|
||||
self.assertExpectedInline(
|
||||
guard_failure,
|
||||
"""0/0: check_overlapping(overlapping=[L['args'][1], L['args'][2]], non_overlapping=[L['args'][0]])""",
|
||||
"""0/0: check_overlapping(overlapping=[args[1], args[2]], non_overlapping=[args[0]])""",
|
||||
)
|
||||
|
||||
def test_different_inputs_overlapping_set_with_mutation(self):
|
||||
|
|
@ -1546,7 +1546,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
|
|||
)
|
||||
self.assertExpectedInline(
|
||||
guard_failure,
|
||||
"""0/0: check_overlapping(overlapping=[L['a'], L['b']], non_overlapping=[L['c'], L['d']])""",
|
||||
"""0/0: check_overlapping(overlapping=[a, b], non_overlapping=[c, d])""",
|
||||
)
|
||||
|
||||
def _test_no_storage_overlap_guards(self, f, argsfn):
|
||||
|
|
|
|||
|
|
@ -897,7 +897,7 @@ class MiscTests(torch._inductor.test_case.TestCase):
|
|||
opt_fn(torch.randn([3, 4]))
|
||||
opt_fn(torch.randn([4, 3]))
|
||||
self.assertIn(
|
||||
"""tensor 'L['a']' size mismatch at index 0. expected 3, actual 4""",
|
||||
"""tensor 'a' size mismatch at index 0. expected 3, actual 4""",
|
||||
guard_failure.reason,
|
||||
)
|
||||
|
||||
|
|
@ -6657,11 +6657,11 @@ utils_device.CURRENT_DEVICE == None""".split(
|
|||
first_guard_failure = guard_failure[0].partition("\n")[0]
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertIn(
|
||||
"""tensor 'L['x']' size mismatch at index 0. expected 2, actual 5""",
|
||||
"""tensor 'x' size mismatch at index 0. expected 2, actual 5""",
|
||||
first_guard_failure,
|
||||
)
|
||||
else:
|
||||
self.assertIn("""L['x'].size()[0] < 3""", first_guard_failure)
|
||||
self.assertIn("""x.size()[0] < 3""", first_guard_failure)
|
||||
|
||||
def test_guard_failure_fn2(self):
|
||||
def fn(x, y):
|
||||
|
|
@ -6690,7 +6690,7 @@ utils_device.CURRENT_DEVICE == None""".split(
|
|||
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertIn(
|
||||
"""tensor 'L['x']' size mismatch at index 0. expected 2, actual 3""",
|
||||
"""tensor 'x' size mismatch at index 0. expected 2, actual 3""",
|
||||
guard_failure[0],
|
||||
)
|
||||
else:
|
||||
|
|
@ -6725,7 +6725,7 @@ utils_device.CURRENT_DEVICE == None""".split(
|
|||
# guard is expected for both static and dynamic shapes
|
||||
self.assertTrue(guard_failure is not None)
|
||||
self.assertIn(
|
||||
"""len(L['x']) == 10""",
|
||||
"""len(x) == 10""",
|
||||
guard_failure[0],
|
||||
)
|
||||
|
||||
|
|
@ -6782,7 +6782,7 @@ utils_device.CURRENT_DEVICE == None""".split(
|
|||
opt_out = opt_fn(args2)
|
||||
self.assertEqual(out, opt_out)
|
||||
self.assertTrue(guard_failure is not None)
|
||||
self.assertIn("""tensor 'L['x']' size mismatch at index 0""", guard_failure[0])
|
||||
self.assertIn("""tensor 'x' size mismatch at index 0""", guard_failure[0])
|
||||
|
||||
def test_restore_graphstate(self):
|
||||
# This function does some guard accumulation,
|
||||
|
|
@ -6851,9 +6851,7 @@ utils_device.CURRENT_DEVICE == None""".split(
|
|||
x = torch.randn(3)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
self.assertTrue(guard_failure is not None)
|
||||
self.assertIn(
|
||||
"""tensor 'L['rank']' size mismatch at index 0""", guard_failure[0]
|
||||
)
|
||||
self.assertIn("""tensor 'rank' size mismatch at index 0""", guard_failure[0])
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
||||
def test_symint_as_device_kwarg_non_strict_export(self):
|
||||
|
|
|
|||
|
|
@ -161,28 +161,26 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
|
|||
cache_fail_test(
|
||||
a,
|
||||
a[0:2, :, :],
|
||||
"tensor 'L['a']' size mismatch at index 0. expected 3, actual 2",
|
||||
"tensor 'a' size mismatch at index 0. expected 3, actual 2",
|
||||
)
|
||||
cache_fail_test(
|
||||
a,
|
||||
a.clone().as_strided((3, 4, 5), stride=(1, 3, 12)),
|
||||
"tensor 'L['a']' stride mismatch at index 0. expected 20, actual 1",
|
||||
"tensor 'a' stride mismatch at index 0. expected 20, actual 1",
|
||||
)
|
||||
cache_fail_test(
|
||||
a, a[0, :, :], "tensor 'L['a']' rank mismatch. expected 3, actual 2"
|
||||
)
|
||||
cache_fail_test(a, a.to("meta"), "tensor 'L['a']' dispatch key set mismatch.")
|
||||
cache_fail_test(a, a[0, :, :], "tensor 'a' rank mismatch. expected 3, actual 2")
|
||||
cache_fail_test(a, a.to("meta"), "tensor 'a' dispatch key set mismatch.")
|
||||
cache_fail_test(
|
||||
a,
|
||||
a.to(torch.float16),
|
||||
"tensor 'L['a']' dtype mismatch. expected Float, actual Half",
|
||||
"tensor 'a' dtype mismatch. expected Float, actual Half",
|
||||
)
|
||||
a_grad = a.clone()
|
||||
a_grad.requires_grad = True
|
||||
cache_fail_test(
|
||||
a,
|
||||
a_grad,
|
||||
"tensor 'L['a']' requires_grad mismatch. expected requires_grad=0",
|
||||
"tensor 'a' requires_grad mismatch. expected requires_grad=0",
|
||||
)
|
||||
|
||||
def test_mismatched_type(self):
|
||||
|
|
@ -201,7 +199,7 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
|
|||
opt_func(a, 1)
|
||||
self.assert_single_log_contains(
|
||||
logs,
|
||||
"expected type of 'L['b']' to be a tensor type, ' but found <class 'int'>",
|
||||
"expected type of 'b' to be a tensor type, ' but found <class 'int'>",
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(recompile_limit=1, fail_on_recompile_limit_hit=True)
|
||||
|
|
@ -237,10 +235,10 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
failure_str = "\n".join(failure_reasons)
|
||||
for line in """\
|
||||
tensor 'L['x']' size mismatch at index 0. expected 11, actual 12
|
||||
tensor 'L['x']' size mismatch at index 0. expected 10, actual 12
|
||||
tensor 'L['x']' size mismatch at index 0. expected 9, actual 12
|
||||
tensor 'L['x']' size mismatch at index 0. expected 8, actual 12""".split(
|
||||
tensor 'x' size mismatch at index 0. expected 11, actual 12
|
||||
tensor 'x' size mismatch at index 0. expected 10, actual 12
|
||||
tensor 'x' size mismatch at index 0. expected 9, actual 12
|
||||
tensor 'x' size mismatch at index 0. expected 8, actual 12""".split(
|
||||
"\n"
|
||||
):
|
||||
self.assertIn(
|
||||
|
|
@ -278,7 +276,7 @@ tensor 'L['x']' size mismatch at index 0. expected 8, actual 12""".split(
|
|||
opt_f([7, 8])
|
||||
|
||||
for line in """\
|
||||
len(L['x']) == 3""".split(
|
||||
len(x) == 3""".split(
|
||||
"\n"
|
||||
):
|
||||
self.assertIn(line, filter_reasons())
|
||||
|
|
@ -287,8 +285,8 @@ len(L['x']) == 3""".split(
|
|||
opt_f([9])
|
||||
|
||||
for line in """\
|
||||
len(L['x']) == 2
|
||||
len(L['x']) == 3""".split(
|
||||
len(x) == 2
|
||||
len(x) == 3""".split(
|
||||
"\n"
|
||||
):
|
||||
self.assertIn(line, filter_reasons())
|
||||
|
|
|
|||
|
|
@ -4429,7 +4429,7 @@ class CommonTemplate:
|
|||
gemm_opt(x1, y1)
|
||||
self.assertTrue(failed_guard is not None)
|
||||
self.assertTrue(
|
||||
"tensor 'L['x']' Tensor device index mismatch. Expected device index to be"
|
||||
"tensor 'x' Tensor device index mismatch. Expected device index to be"
|
||||
in failed_guard.reason
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -2846,6 +2846,19 @@ def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
|
|||
return [f"Duplicate tensors found: {reason}"]
|
||||
|
||||
|
||||
def strip_local_scope(s: str) -> str:
|
||||
"""
|
||||
Replace occurrences of L[...] with just the inner content.
|
||||
Handles both single and double quotes.
|
||||
|
||||
This is to generate user friendly recompilation messages.
|
||||
"""
|
||||
import re
|
||||
|
||||
pattern = r"L\[\s*['\"](.*?)['\"]\s*\]"
|
||||
return re.sub(pattern, r"\1", s)
|
||||
|
||||
|
||||
def get_guard_fail_reason_helper(
|
||||
guard_manager: GuardFn,
|
||||
f_locals: dict[str, object],
|
||||
|
|
@ -2910,7 +2923,7 @@ def get_guard_fail_reason_helper(
|
|||
break
|
||||
|
||||
reason_str = f"{compile_id}: " + "; ".join(reasons)
|
||||
return reason_str
|
||||
return strip_local_scope(reason_str)
|
||||
|
||||
|
||||
def get_guard_fail_reason(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user