From bf34e228c52d3075f23a211623a569ea9be5e5ee Mon Sep 17 00:00:00 2001 From: angelayi Date: Thu, 20 Mar 2025 23:02:03 +0000 Subject: [PATCH] [export] Beef up guard_added logs (#149465) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149465 Approved by: https://github.com/pianpwk --- test/export/test_draft_export.py | 1 - torch/_logging/_internal.py | 2 +- torch/export/_draft_export.py | 54 +++++++++++++----------- torch/fx/experimental/symbolic_shapes.py | 7 +-- 4 files changed, 35 insertions(+), 29 deletions(-) diff --git a/test/export/test_draft_export.py b/test/export/test_draft_export.py index 81a9354ea1f..6fda3fcdb0a 100644 --- a/test/export/test_draft_export.py +++ b/test/export/test_draft_export.py @@ -265,7 +265,6 @@ class TestDraftExport(TestCase): self.assertEqual( report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR ) - self.assertTrue(len(report.expressions_created) >= 4) for _ep in [ep, ep.run_decompositions()]: # check data-dependent asserts assert_scalar_nodes = [ diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index d07c8277c7a..0889b1ee23e 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -1294,7 +1294,7 @@ def dtrace_structured( *, payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None, suppress_context: bool = False, - expect_trace_id: bool = True, # Whether or not we expect to have a current trace id + expect_trace_id: bool = False, # Whether or not we expect to have a current trace id record_logging_overhead: bool = True, # Whether or not to record the time spent on structured logging ): """ diff --git a/torch/export/_draft_export.py b/torch/export/_draft_export.py index d7b036a0a95..604f865a2b0 100644 --- a/torch/export/_draft_export.py +++ b/torch/export/_draft_export.py @@ -291,6 +291,22 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler): self.prev_get_dtrace = False def emit(self, record: Any) -> None: + def _log_expression_created( + emit_func: Callable[[Any], None], sym_node_id: int + ) -> None: + # Log all the relevant expression_created logs + if sym_node_id is None: + return + if res := self.expression_created_logs.get(sym_node_id, None): + # Don't log the expression if we have already + # printed it beforehand + if not res.visited: + res.visited = True + for arg in res.argument_ids: + _log_expression_created(emit_func, arg) + + emit_func(res.record) + metadata = record.metadata for key in self.specific_log_keys: if key in metadata: @@ -306,27 +322,23 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler): metadata[key].get("argument_ids", []), record, ) - continue + return elif key == "propagate_real_tensors_provenance": + _log_expression_created( + super().emit, metadata[key].get("expr_node_id") + ) - def _log_expression_created( - emit_func: Callable[[Any], None], sym_node_id: int - ) -> None: - # Log all the relevant expression_created logs - if sym_node_id is None: - return - if res := self.expression_created_logs.get( - sym_node_id, None - ): - # Don't log the expression if we have already - # printed it beforehand - if not res.visited: - res.visited = True - for arg in res.argument_ids: - _log_expression_created(emit_func, arg) - - emit_func(res.record) + elif key == "guard_added": + if len(metadata[key]["symbol_to_sources"]) == 0: + # We only want to include guards added that are relevant to + # the symbolic shapes corresponding to the inputs which were + # specified in the dynamic_shapes arg. These have a source. + return + elif metadata[key]["prefix"] == "runtime_assert": + # This should've been captured by a + # propagate_real_tensors log + return _log_expression_created( super().emit, metadata[key].get("expr_node_id") @@ -409,12 +421,6 @@ def draft_export( continue failure_type = FailureType.CONSTRAINT_VIOLATION_ERROR - if len(log_contents["symbol_to_sources"]) == 0: - # We only want to include guards added that are relevant to - # the symbolic shapes corresponding to the inputs which were - # specified in the dynamic_shapes arg. These have a source. - continue - log_contents["new_dynamic_shapes"] = new_shapes elif log_name == "missing_fake_kernel": failure_type = FailureType.MISSING_FAKE_KERNEL diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 951dd084f35..ca84f6f44a9 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -6593,9 +6593,10 @@ class ShapeEnv: "guard_added", metadata_fn=lambda: { "expr": str(g), - "stack": structured.from_traceback( - CapturedTraceback.extract(skip=1).summary() - ), + "prefix": prefix, + "expr_node_id": self._expr_sym_node_id, + "user_stack": structured.get_user_stack(3), + "stack": structured.get_framework_stack(3), "symbol_to_sources": { str(v): k for k, v in self.source_to_var.items()