Enhancements to recompiles logs (#130043)

----

- We now record on CacheEntry what the compile id that populated it was, so now we can say why a specific frame was rejected
- Add structured log for recompiles under name artifact "recompile_reasons". As it stands, it's not terribly structured, but this was the easiest thing I could do to start
- Slightly reformat multi-reason printing; since we only report one guard failure seems better to have it as a single line

Example output:

```
V0703 10:34:13.273000 140345997743104 torch/_dynamo/guards.py:2590] [0/1] [__recompiles] Recompiling function f in /data/users/ezyang/a/pytorch/b.py:3
V0703 10:34:13.273000 140345997743104 torch/_dynamo/guards.py:2590] [0/1] [__recompiles]     triggered by the following guard failure(s):
V0703 10:34:13.273000 140345997743104 torch/_dynamo/guards.py:2590] [0/1] [__recompiles]     - 0/0: tensor 'L['x']' size mismatch at index 0. expected 4, actual 5
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130043
Approved by: https://github.com/anijain2305
This commit is contained in:
Edward Z. Yang 2024-07-06 05:51:36 -07:00 committed by PyTorch MergeBot
parent 29861779ce
commit e836ee1955
12 changed files with 50 additions and 56 deletions

View File

@ -2,6 +2,7 @@
import torch
import torch._dynamo.test_case
from torch._guards import CompileId
def target_with_varkwargs(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs):
@ -89,7 +90,7 @@ class FrameInitTests(torch._dynamo.test_case.TestCase):
if frame.f_code in code_map1:
transformed_code = code_map1[frame.f_code]
return torch._dynamo.types.GuardedCode(
transformed_code, lambda f_locals: True
transformed_code, lambda f_locals: True, CompileId(0, 0)
)
return None
@ -97,7 +98,7 @@ class FrameInitTests(torch._dynamo.test_case.TestCase):
if frame.f_code in code_map2:
transformed_code = code_map2[frame.f_code]
return torch._dynamo.types.GuardedCode(
transformed_code, lambda f_locals: True
transformed_code, lambda f_locals: True, CompileId(0, 0)
)
return None

View File

@ -2470,9 +2470,7 @@ class HigherOrderOpVmapGuardTests(LoggingTestCase):
self.assertGreater(len(records), 0)
record = self.getRecord(records, "pyfunctorch")
self.assertIn(
"""\
triggered by the following guard failure(s):
- torch._functorch.pyfunctorch.compare_functorch_state([])""",
"""torch._functorch.pyfunctorch.compare_functorch_state([])""",
munge_exc(record.getMessage()),
)
@ -2502,9 +2500,7 @@ class HigherOrderOpVmapGuardTests(LoggingTestCase):
self.assertGreater(len(records), 0)
record = self.getRecord(records, "forward_ad")
self.assertIn(
"""\
triggered by the following guard failure(s):
- torch.autograd.forward_ad._current_level == -1""",
"""torch.autograd.forward_ad._current_level == -1""",
munge_exc(record.getMessage()),
)
@ -2534,17 +2530,13 @@ class HigherOrderOpVmapGuardTests(LoggingTestCase):
if self.hasRecord(records, "pyfunctorch"):
record = self.getRecord(records, "pyfunctorch")
self.assertIn(
"""\
triggered by the following guard failure(s):
- torch._functorch.pyfunctorch.compare_functorch_state([])""",
"""torch._functorch.pyfunctorch.compare_functorch_state([])""",
munge_exc(record.getMessage()),
)
elif self.hasRecord(records, "forward_ad"):
record = self.getRecord(records, "forward_ad")
self.assertIn(
"""\
triggered by the following guard failure(s):
- torch.autograd.forward_ad._current_level == -1""",
"""torch.autograd.forward_ad._current_level == -1""",
munge_exc(record.getMessage()),
)
@ -2588,9 +2580,7 @@ class HigherOrderOpVmapGuardTests(LoggingTestCase):
self.assertGreater(len(records), 0)
record = self.getRecord(records, "pyfunctorch")
self.assertIn(
"""\
triggered by the following guard failure(s):
- torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'same')])""",
"""torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'same')])""",
record.getMessage(),
)
@ -2614,9 +2604,7 @@ class HigherOrderOpVmapGuardTests(LoggingTestCase):
self.assertGreater(len(records), 0)
record = self.getRecord(records, "pyfunctorch")
self.assertIn(
"""\
triggered by the following guard failure(s):
- torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""",
"""torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""",
record.getMessage(),
)
@ -2644,9 +2632,7 @@ class HigherOrderOpVmapGuardTests(LoggingTestCase):
self.assertGreater(len(records), 0)
record = self.getRecord(records, "pyfunctorch")
self.assertIn(
"""\
triggered by the following guard failure(s):
- torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""",
"""torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""",
munge_exc(record.getMessage()),
)
@ -2665,9 +2651,7 @@ class HigherOrderOpVmapGuardTests(LoggingTestCase):
self.assertGreater(len(records), 0)
record = self.getRecord(records, "pyfunctorch")
self.assertIn(
"""\
triggered by the following guard failure(s):
- torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'same')])""",
"""torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'same')])""",
munge_exc(record.getMessage()),
)

View File

@ -614,14 +614,11 @@ print("arf")
record_str = "\n".join(r.getMessage() for r in records)
self.assertIn(
"""\
L['zs'][0] == 3.0 # for y, z in zip(ys, zs):""",
"""L['zs'][0] == 3.0""",
record_str,
)
self.assertIn(
"""\
triggered by the following guard failure(s):\n\
- len(L['ys']) == 2 # for y, z in zip(ys, zs):""",
"len(L['ys']) == 2",
record_str,
)

View File

@ -2200,7 +2200,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
m._forward_hooks[handle.id] = new_forward_hook
self.assertEqual(compiled_func(inp), outer_func(inp))
self.assertEqual(compiled_func(inp).item(), 16)
self.assertRegex(failure_reason, r"^___check_obj_id\(L\['m'\]._forward_hooks")
self.assertRegex(failure_reason, r"___check_obj_id\(L\['m'\]._forward_hooks")
@patch.object(torch._dynamo.config, "guard_nn_modules", False)
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True)

View File

@ -152,7 +152,6 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
""", # noqa: B950
@ -176,7 +175,6 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
""", # noqa: B950
@ -207,9 +205,9 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"artifact": {"name": "recompile_reasons", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
@ -219,7 +217,6 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
""", # noqa: B950
@ -242,7 +239,6 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
""", # noqa: B950
@ -339,7 +335,6 @@ class StructuredTraceTest(TestCase):
self.buffer.getvalue(),
"""\
{"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"dynamo_guards": {}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1}
{"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0}
@ -365,7 +360,6 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_guards": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0}
""", # noqa: B950
@ -445,7 +439,6 @@ class StructuredTraceTest(TestCase):
self.buffer.getvalue(),
"""\
{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 1}
{"dynamo_start": {"stack": "STACK"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0}
@ -457,7 +450,6 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_guards": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 1, "frame_compile_id": 0, "attempt": 0}
""", # noqa: B950
@ -488,9 +480,9 @@ class StructuredTraceTest(TestCase):
{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [20, 30], "is_leaf": true, "stride": [30, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['b']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_a_": [10, 20], "l_b_": [20, 30], "matmul": [10, 30]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"artifact": {"name": "recompile_reasons", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_storage": {"id": 0, "describer_id": "ID", "size": 200}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [5, 10], "is_leaf": true, "stride": [10, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
@ -499,7 +491,6 @@ class StructuredTraceTest(TestCase):
{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [10, 15], "is_leaf": true, "stride": [15, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['b']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_a_": ["s0", "s1"], "l_b_": ["s1", "s3"], "matmul": ["s0", "s3"]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
""", # noqa: B950
@ -532,15 +523,14 @@ class StructuredTraceTest(TestCase):
{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_x_": [1], "x": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"artifact": {"name": "recompile_reasons", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_x_": [1], "x": [1]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
""", # noqa: B950
@ -572,7 +562,6 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
@ -583,7 +572,6 @@ class StructuredTraceTest(TestCase):
{"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
""", # noqa: B950

View File

@ -728,7 +728,7 @@ def _compile(
hooks.guard_fail_fn if hooks else None,
)
guarded_code = GuardedCode(out_code, check_fn.check_fn)
guarded_code = GuardedCode(out_code, check_fn.check_fn, compile_id)
if not output.is_empty_graph() and hooks.guard_export_fn is not None:
# We should not run the guard_export_fn when Dynamo does not

View File

@ -49,6 +49,8 @@ from torch._dynamo.source import (
TensorPropertySource,
)
from torch._guards import (
CompileContext,
CompileId,
DuplicateInputs,
Guard,
GuardBuilderBase,
@ -2133,6 +2135,7 @@ class CheckFunctionManager:
reasons = get_guard_fail_reason_helper(
self.guard_manager, # type: ignore[arg-type]
output_graph.local_scope,
CompileContext.current_compile_id(),
)
raise AssertionError(f"Guard check failed: {reasons}")
@ -2300,9 +2303,10 @@ class CheckFunctionManager:
add_code_part(code, gcl.guard, config.enable_cpp_guard_manager)
# OK, all done generating guards
torch._logging.trace_structured(
"dynamo_guards", payload_fn=lambda: [f() for f in structured_guard_fns]
)
if structured_guard_fns:
torch._logging.trace_structured(
"dynamo_guards", payload_fn=lambda: [f() for f in structured_guard_fns]
)
global_state = convert_frame.initial_global_state
if global_state is None:
@ -2472,6 +2476,7 @@ def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
def get_guard_fail_reason_helper(
guard_fn: GuardFn,
f_locals: Dict[str, object],
compile_id: CompileId,
) -> str:
"""
Return the reason why `guard_fn` failed.
@ -2536,7 +2541,7 @@ def get_guard_fail_reason_helper(
if not is_recompiles_verbose_enabled():
break
reason_str = "\n".join(reasons)
reason_str = f"{compile_id}: " + "; ".join(reasons)
return reason_str
@ -2544,8 +2549,9 @@ def get_guard_fail_reason(
guard_fn: GuardFn,
code: types.CodeType,
f_locals: Dict[str, object],
compile_id: CompileId,
) -> str:
reason_str = get_guard_fail_reason_helper(guard_fn, f_locals)
reason_str = get_guard_fail_reason_helper(guard_fn, f_locals, compile_id)
guard_failures[orig_code_map[code]].append(reason_str)
try:
@ -2572,7 +2578,10 @@ def get_and_maybe_log_recompilation_reason(
reasons = []
while cache_entry is not None:
reason = get_guard_fail_reason(
cache_entry.check_fn, cache_entry.code, frame.f_locals
cache_entry.check_fn,
cache_entry.code,
frame.f_locals,
cache_entry.compile_id,
)
if reason:
reasons.append(reason)
@ -2606,6 +2615,15 @@ def get_and_maybe_log_recompilation_reason(
if config.error_on_recompile:
raise exc.RecompileError(message)
torch._logging.trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "recompile_reasons",
"encoding": "json",
},
payload_fn=lambda: reasons,
)
return reasons

View File

@ -29,7 +29,7 @@ from .bytecode_transformation import (
is_generator,
transform_code_object,
)
from .guards import CheckFunctionManager, GuardedCode
from .guards import CheckFunctionManager, CompileId, GuardedCode
from .utils import same
unsupported = eval_frame.unsupported
@ -163,7 +163,7 @@ def debug_insert_nops(
f_code=frame.f_code,
)
return GuardedCode(code, CheckFunctionManager(graph).check_fn)
return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))
class CompileCounter:

View File

@ -13,6 +13,7 @@ else:
DynamoFrameType: TypeAlias = types.FrameType
import torch
from torch._guards import CompileId
# This class has a `check_fn` field for the guard,
# and a `code` field for the code object.
@ -50,6 +51,7 @@ class GuardFn(Protocol):
class GuardedCode:
code: types.CodeType
check_fn: GuardFn
compile_id: CompileId
class DynamoCallbackFn(Protocol):

View File

@ -7,6 +7,7 @@
CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) {
this->check_fn = guarded_code.attr("check_fn");
this->code = guarded_code.attr("code");
this->compile_id = guarded_code.attr("compile_id");
this->backend = backend;
// TODO - clean this up when enable_cpp_guard_manager is True by default
if (py::hasattr(this->check_fn, "root")) {

View File

@ -41,6 +41,8 @@ typedef struct VISIBILITY_HIDDEN CacheEntry {
py::object check_fn;
// modified user bytecode (protected by check_fn's guards)
py::object code;
// CompileId corresponding to this compilation
py::object compile_id;
// root guard manager if exists
void* root_mgr{nullptr};
// backend used to create this cache entry

View File

@ -60,6 +60,7 @@ void initDynamoBindings(PyObject* torch) {
py::class_<CacheEntry>(m, "_CacheEntry")
.def_readonly("check_fn", &CacheEntry::check_fn)
.def_readonly("code", &CacheEntry::code)
.def_readonly("compile_id", &CacheEntry::compile_id)
.def_property_readonly("next", &CacheEntry::next);
py::class_<ExtraState>(m, "_ExtraState")