[export] Beef up guard_added logs (#149465)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149465
Approved by: https://github.com/pianpwk
This commit is contained in:
angelayi 2025-03-20 23:02:03 +00:00 committed by PyTorch MergeBot
parent 1d3c50fcc5
commit bf34e228c5
4 changed files with 35 additions and 29 deletions

View File

@@ -265,7 +265,6 @@ class TestDraftExport(TestCase):
self.assertEqual( self.assertEqual(
report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR
) )
self.assertTrue(len(report.expressions_created) >= 4)
for _ep in [ep, ep.run_decompositions()]: for _ep in [ep, ep.run_decompositions()]:
# check data-dependent asserts # check data-dependent asserts
assert_scalar_nodes = [ assert_scalar_nodes = [

View File

@@ -1294,7 +1294,7 @@ def dtrace_structured(
*, *,
payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None, payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
suppress_context: bool = False, suppress_context: bool = False,
expect_trace_id: bool = True, # Whether or not we expect to have a current trace id expect_trace_id: bool = False, # Whether or not we expect to have a current trace id
record_logging_overhead: bool = True, # Whether or not to record the time spent on structured logging record_logging_overhead: bool = True, # Whether or not to record the time spent on structured logging
): ):
""" """

View File

@@ -291,6 +291,22 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
self.prev_get_dtrace = False self.prev_get_dtrace = False
def emit(self, record: Any) -> None: def emit(self, record: Any) -> None:
def _log_expression_created(
emit_func: Callable[[Any], None], sym_node_id: int
) -> None:
# Log all the relevant expression_created logs
if sym_node_id is None:
return
if res := self.expression_created_logs.get(sym_node_id, None):
# Don't log the expression if we have already
# printed it beforehand
if not res.visited:
res.visited = True
for arg in res.argument_ids:
_log_expression_created(emit_func, arg)
emit_func(res.record)
metadata = record.metadata metadata = record.metadata
for key in self.specific_log_keys: for key in self.specific_log_keys:
if key in metadata: if key in metadata:
@@ -306,27 +322,23 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
metadata[key].get("argument_ids", []), metadata[key].get("argument_ids", []),
record, record,
) )
continue return
elif key == "propagate_real_tensors_provenance": elif key == "propagate_real_tensors_provenance":
_log_expression_created(
super().emit, metadata[key].get("expr_node_id")
)
def _log_expression_created( elif key == "guard_added":
emit_func: Callable[[Any], None], sym_node_id: int if len(metadata[key]["symbol_to_sources"]) == 0:
) -> None: # We only want to include guards added that are relevant to
# Log all the relevant expression_created logs # the symbolic shapes corresponding to the inputs which were
if sym_node_id is None: # specified in the dynamic_shapes arg. These have a source.
return return
if res := self.expression_created_logs.get( elif metadata[key]["prefix"] == "runtime_assert":
sym_node_id, None # This should've been captured by a
): # propagate_real_tensors log
# Don't log the expression if we have already return
# printed it beforehand
if not res.visited:
res.visited = True
for arg in res.argument_ids:
_log_expression_created(emit_func, arg)
emit_func(res.record)
_log_expression_created( _log_expression_created(
super().emit, metadata[key].get("expr_node_id") super().emit, metadata[key].get("expr_node_id")
@@ -409,12 +421,6 @@ def draft_export(
continue continue
failure_type = FailureType.CONSTRAINT_VIOLATION_ERROR failure_type = FailureType.CONSTRAINT_VIOLATION_ERROR
if len(log_contents["symbol_to_sources"]) == 0:
# We only want to include guards added that are relevant to
# the symbolic shapes corresponding to the inputs which were
# specified in the dynamic_shapes arg. These have a source.
continue
log_contents["new_dynamic_shapes"] = new_shapes log_contents["new_dynamic_shapes"] = new_shapes
elif log_name == "missing_fake_kernel": elif log_name == "missing_fake_kernel":
failure_type = FailureType.MISSING_FAKE_KERNEL failure_type = FailureType.MISSING_FAKE_KERNEL

View File

@@ -6593,9 +6593,10 @@ class ShapeEnv:
"guard_added", "guard_added",
metadata_fn=lambda: { metadata_fn=lambda: {
"expr": str(g), "expr": str(g),
"stack": structured.from_traceback( "prefix": prefix,
CapturedTraceback.extract(skip=1).summary() "expr_node_id": self._expr_sym_node_id,
), "user_stack": structured.get_user_stack(3),
"stack": structured.get_framework_stack(3),
"symbol_to_sources": { "symbol_to_sources": {
str(v): k str(v): k
for k, v in self.source_to_var.items() for k, v in self.source_to_var.items()