[export] Beef up guard_added logs (#149465)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149465
Approved by: https://github.com/pianpwk
angelayi 2025-03-20 23:02:03 +00:00 committed by PyTorch MergeBot
parent 1d3c50fcc5
commit bf34e228c5
4 changed files with 35 additions and 29 deletions

View File

@@ -265,7 +265,6 @@ class TestDraftExport(TestCase):
self.assertEqual(
report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR
)
self.assertTrue(len(report.expressions_created) >= 4)
for _ep in [ep, ep.run_decompositions()]:
# check data-dependent asserts
assert_scalar_nodes = [

View File

@@ -1294,7 +1294,7 @@ def dtrace_structured(
     *,
     payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
     suppress_context: bool = False,
-    expect_trace_id: bool = True,  # Whether or not we expect to have a current trace id
+    expect_trace_id: bool = False,  # Whether or not we expect to have a current trace id
     record_logging_overhead: bool = True,  # Whether or not to record the time spent on structured logging
 ):
     """
View File

@@ -291,6 +291,22 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
self.prev_get_dtrace = False
def emit(self, record: Any) -> None:
def _log_expression_created(
emit_func: Callable[[Any], None], sym_node_id: int
) -> None:
# Log all the relevant expression_created logs
if sym_node_id is None:
return
if res := self.expression_created_logs.get(sym_node_id, None):
# Don't log the expression if we have already
# printed it beforehand
if not res.visited:
res.visited = True
for arg in res.argument_ids:
_log_expression_created(emit_func, arg)
emit_func(res.record)
metadata = record.metadata
for key in self.specific_log_keys:
if key in metadata:
@@ -306,27 +322,23 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
metadata[key].get("argument_ids", []),
record,
)
continue
return
elif key == "propagate_real_tensors_provenance":
_log_expression_created(
super().emit, metadata[key].get("expr_node_id")
)
def _log_expression_created(
emit_func: Callable[[Any], None], sym_node_id: int
) -> None:
# Log all the relevant expression_created logs
if sym_node_id is None:
return
if res := self.expression_created_logs.get(
sym_node_id, None
):
# Don't log the expression if we have already
# printed it beforehand
if not res.visited:
res.visited = True
for arg in res.argument_ids:
_log_expression_created(emit_func, arg)
emit_func(res.record)
elif key == "guard_added":
if len(metadata[key]["symbol_to_sources"]) == 0:
# We only want to include guards added that are relevant to
# the symbolic shapes corresponding to the inputs which were
# specified in the dynamic_shapes arg. These have a source.
return
elif metadata[key]["prefix"] == "runtime_assert":
# This should've been captured by a
# propagate_real_tensors log
return
_log_expression_created(
super().emit, metadata[key].get("expr_node_id")
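
The core of the emit() change above is that the memoized _log_expression_created helper is hoisted to the top of emit, so both the propagate_real_tensors_provenance branch and the new guard_added branch can replay the chain of expression_created records behind a symbolic expression. Below is a standalone sketch of that replay pattern; the dataclass and the module-level dict are illustrative stand-ins for the handler's internal state.

from dataclasses import dataclass, field
from typing import Any, Callable, Optional


@dataclass
class ExprCreatedLog:
    record: Any                                             # the buffered expression_created record
    argument_ids: list[int] = field(default_factory=list)   # sym node ids it was built from
    visited: bool = False                                    # set once the record has been emitted


# Buffered expression_created logs, keyed by sym node id.
expression_created_logs: dict[int, ExprCreatedLog] = {}


def log_expression_created(
    emit_func: Callable[[Any], None], sym_node_id: Optional[int]
) -> None:
    # Emit every expression_created record this expression depends on,
    # inputs first, and each record at most once.
    if sym_node_id is None:
        return
    res = expression_created_logs.get(sym_node_id)
    if res is not None and not res.visited:
        res.visited = True  # mark before recursing so shared arguments aren't re-emitted
        for arg in res.argument_ids:
            log_expression_created(emit_func, arg)
        emit_func(res.record)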
@@ -409,12 +421,6 @@ def draft_export(
continue
failure_type = FailureType.CONSTRAINT_VIOLATION_ERROR
if len(log_contents["symbol_to_sources"]) == 0:
# We only want to include guards added that are relevant to
# the symbolic shapes corresponding to the inputs which were
# specified in the dynamic_shapes arg. These have a source.
continue
log_contents["new_dynamic_shapes"] = new_shapes
elif log_name == "missing_fake_kernel":
failure_type = FailureType.MISSING_FAKE_KERNEL
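
The lines removed here are the other half of the same cleanup: the "does this guard touch an input-backed symbol?" check now lives in CaptureStructuredTrace.emit (the guard_added branch above), so irrelevant records never reach draft_export's failure loop. The helper name below is hypothetical; the two conditions mirror the ones in the emit branch.

from typing import Any


def should_report_guard_added(log_contents: dict[str, Any]) -> bool:
    # Only guards tied to a source (i.e. to a symbolic shape coming from an
    # input named in dynamic_shapes) are worth surfacing as failures.
    if not log_contents.get("symbol_to_sources"):
        return False
    # runtime_assert guards are already covered by propagate_real_tensors logs.
    if log_contents.get("prefix") == "runtime_assert":
        return False
    return True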

View File

@@ -6593,9 +6593,10 @@ class ShapeEnv:
             "guard_added",
             metadata_fn=lambda: {
                 "expr": str(g),
-                "stack": structured.from_traceback(
-                    CapturedTraceback.extract(skip=1).summary()
-                ),
+                "prefix": prefix,
+                "expr_node_id": self._expr_sym_node_id,
+                "user_stack": structured.get_user_stack(3),
+                "stack": structured.get_framework_stack(3),
                 "symbol_to_sources": {
                     str(v): k
                     for k, v in self.source_to_var.items()
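
Taken together, the beefed-up guard_added payload now records the guard prefix, the sym node id of the expression (which is what lets the trace handler replay expression_created provenance), and separate user and framework stacks. Purely for illustration, a record's metadata might look roughly like the dict below; every value is made up and only the keys mirror the metadata_fn above.

example_guard_added_metadata = {
    "expr": "Eq(s0, 2*s1)",             # made-up guard expression
    "prefix": "runtime_assert",         # the prefix value referenced elsewhere in this diff
    "expr_node_id": 42,                 # hypothetical sym node id used to replay provenance
    "user_stack": [],                   # user-code frames (contents omitted)
    "stack": [],                        # framework frames (contents omitted)
    "symbol_to_sources": {"s0": "L['x'].size()[0]"},  # illustrative source string
}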