[export] Beef up guard_added logs (#149465)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149465
Approved by: https://github.com/pianpwk
parent 1d3c50fcc5
commit bf34e228c5
@@ -265,7 +265,6 @@ class TestDraftExport(TestCase):
        self.assertEqual(
            report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR
        )
        self.assertTrue(len(report.expressions_created) >= 4)
        for _ep in [ep, ep.run_decompositions()]:
            # check data-dependent asserts
            assert_scalar_nodes = [
@@ -1294,7 +1294,7 @@ def dtrace_structured(
    *,
    payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
    suppress_context: bool = False,
-   expect_trace_id: bool = True,  # Whether or not we expect to have a current trace id
+   expect_trace_id: bool = False,  # Whether or not we expect to have a current trace id
    record_logging_overhead: bool = True,  # Whether or not to record the time spent on structured logging
):
    """
@@ -291,6 +291,22 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
        self.prev_get_dtrace = False

    def emit(self, record: Any) -> None:
        def _log_expression_created(
            emit_func: Callable[[Any], None], sym_node_id: int
        ) -> None:
            # Log all the relevant expression_created logs
            if sym_node_id is None:
                return
            if res := self.expression_created_logs.get(sym_node_id, None):
                # Don't log the expression if we have already
                # printed it beforehand
                if not res.visited:
                    res.visited = True
                    for arg in res.argument_ids:
                        _log_expression_created(emit_func, arg)

                    emit_func(res.record)

        metadata = record.metadata
        for key in self.specific_log_keys:
            if key in metadata:
@@ -306,27 +322,23 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
                        metadata[key].get("argument_ids", []),
                        record,
                    )
                    continue
                    return

                elif key == "propagate_real_tensors_provenance":
                    _log_expression_created(
                        super().emit, metadata[key].get("expr_node_id")
                    )

                    def _log_expression_created(
                        emit_func: Callable[[Any], None], sym_node_id: int
                    ) -> None:
                        # Log all the relevant expression_created logs
                        if sym_node_id is None:
                            return
                        if res := self.expression_created_logs.get(
                            sym_node_id, None
                        ):
                            # Don't log the expression if we have already
                            # printed it beforehand
                            if not res.visited:
                                res.visited = True
                                for arg in res.argument_ids:
                                    _log_expression_created(emit_func, arg)

                                emit_func(res.record)
                elif key == "guard_added":
                    if len(metadata[key]["symbol_to_sources"]) == 0:
                        # We only want to include guards added that are relevant to
                        # the symbolic shapes corresponding to the inputs which were
                        # specified in the dynamic_shapes arg. These have a source.
                        return
                    elif metadata[key]["prefix"] == "runtime_assert":
                        # This should've been captured by a
                        # propagate_real_tensors log
                        return

                    _log_expression_created(
                        super().emit, metadata[key].get("expr_node_id")
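Both the propagate_real_tensors_provenance branch and the new guard_added branch above funnel into _log_expression_created, which is now defined once at the top of emit (previous hunk): given a sym node id, it first emits the expression_created records of the node's arguments, then the node's own record, and the visited flag keeps any record from being emitted twice. A standalone sketch of that dependency-first, emit-once traversal, with a hypothetical record store standing in for self.expression_created_logs:

# Minimal sketch of the dependency-first, emit-once traversal used above.
# ExprRecord and LOGS are stand-ins for the handler's internal state.
from dataclasses import dataclass, field
from typing import Callable, Optional


@dataclass
class ExprRecord:
    record: str                                   # the structured-log record to emit
    argument_ids: list = field(default_factory=list)
    visited: bool = False                         # ensures each record is emitted once


LOGS = {
    1: ExprRecord("s0 created"),
    2: ExprRecord("s1 created"),
    3: ExprRecord("Eq(s0 + s1, 4) created", argument_ids=[1, 2]),
}


def log_expression_created(emit: Callable[[str], None], sym_node_id: Optional[int]) -> None:
    if sym_node_id is None:
        return
    res = LOGS.get(sym_node_id)
    if res is not None and not res.visited:
        res.visited = True
        for arg in res.argument_ids:              # emit the arguments' records first
            log_expression_created(emit, arg)
        emit(res.record)                          # then the node's own record


log_expression_created(print, 3)  # prints: s0 created, s1 created, Eq(s0 + s1, 4) created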
@@ -409,12 +421,6 @@ def draft_export(
                    continue

                failure_type = FailureType.CONSTRAINT_VIOLATION_ERROR
                if len(log_contents["symbol_to_sources"]) == 0:
                    # We only want to include guards added that are relevant to
                    # the symbolic shapes corresponding to the inputs which were
                    # specified in the dynamic_shapes arg. These have a source.
                    continue

                log_contents["new_dynamic_shapes"] = new_shapes
            elif log_name == "missing_fake_kernel":
                failure_type = FailureType.MISSING_FAKE_KERNEL
@@ -6593,9 +6593,10 @@ class ShapeEnv:
            "guard_added",
            metadata_fn=lambda: {
                "expr": str(g),
                "stack": structured.from_traceback(
                    CapturedTraceback.extract(skip=1).summary()
                ),
                "prefix": prefix,
                "expr_node_id": self._expr_sym_node_id,
                "user_stack": structured.get_user_stack(3),
                "stack": structured.get_framework_stack(3),
                "symbol_to_sources": {
                    str(v): k
                    for k, v in self.source_to_var.items()
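Putting the pieces together, a guard_added record now carries the guard expression, its prefix, the id of the sym node that produced it, user and framework stacks, and a symbol-to-sources map built from self.source_to_var. The hunk is truncated, so the real metadata_fn may contain more than is shown; the sketch below only illustrates the payload shape that the CaptureStructuredTrace handler filters on, with made-up values.

# Illustrative guard_added payload, based only on the keys visible in the hunk
# above; every value (and the exact stack format) is made up.
guard_added_metadata = {
    "expr": "Eq(s0, 3)",                       # stringified guard expression
    "prefix": "eval",                          # made-up prefix; "runtime_assert" guards are skipped
    "expr_node_id": 42,                        # links back to an expression_created record
    "user_stack": ["my_model.py:17, in forward"],      # structured.get_user_stack(3)
    "stack": ["<three innermost framework frames>"],   # structured.get_framework_stack(3)
    "symbol_to_sources": {"s0": "L['x'].size()[0]"},
}

# Mirrors the guard_added branch in CaptureStructuredTrace.emit: guards with no
# sources, or with prefix == "runtime_assert", are dropped from the report.
if guard_added_metadata["symbol_to_sources"] and guard_added_metadata["prefix"] != "runtime_assert":
    print("this guard would surface in the draft export report")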