mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[export] Beef up guard_added logs (#149465)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149465 Approved by: https://github.com/pianpwk
This commit is contained in:
parent
1d3c50fcc5
commit
bf34e228c5
|
|
@ -265,7 +265,6 @@ class TestDraftExport(TestCase):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR
|
report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR
|
||||||
)
|
)
|
||||||
self.assertTrue(len(report.expressions_created) >= 4)
|
|
||||||
for _ep in [ep, ep.run_decompositions()]:
|
for _ep in [ep, ep.run_decompositions()]:
|
||||||
# check data-dependent asserts
|
# check data-dependent asserts
|
||||||
assert_scalar_nodes = [
|
assert_scalar_nodes = [
|
||||||
|
|
|
||||||
|
|
@ -1294,7 +1294,7 @@ def dtrace_structured(
|
||||||
*,
|
*,
|
||||||
payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
|
payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
|
||||||
suppress_context: bool = False,
|
suppress_context: bool = False,
|
||||||
expect_trace_id: bool = True, # Whether or not we expect to have a current trace id
|
expect_trace_id: bool = False, # Whether or not we expect to have a current trace id
|
||||||
record_logging_overhead: bool = True, # Whether or not to record the time spent on structured logging
|
record_logging_overhead: bool = True, # Whether or not to record the time spent on structured logging
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -291,6 +291,22 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
|
||||||
self.prev_get_dtrace = False
|
self.prev_get_dtrace = False
|
||||||
|
|
||||||
def emit(self, record: Any) -> None:
|
def emit(self, record: Any) -> None:
|
||||||
|
def _log_expression_created(
|
||||||
|
emit_func: Callable[[Any], None], sym_node_id: int
|
||||||
|
) -> None:
|
||||||
|
# Log all the relevant expression_created logs
|
||||||
|
if sym_node_id is None:
|
||||||
|
return
|
||||||
|
if res := self.expression_created_logs.get(sym_node_id, None):
|
||||||
|
# Don't log the expression if we have already
|
||||||
|
# printed it beforehand
|
||||||
|
if not res.visited:
|
||||||
|
res.visited = True
|
||||||
|
for arg in res.argument_ids:
|
||||||
|
_log_expression_created(emit_func, arg)
|
||||||
|
|
||||||
|
emit_func(res.record)
|
||||||
|
|
||||||
metadata = record.metadata
|
metadata = record.metadata
|
||||||
for key in self.specific_log_keys:
|
for key in self.specific_log_keys:
|
||||||
if key in metadata:
|
if key in metadata:
|
||||||
|
|
@ -306,27 +322,23 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
|
||||||
metadata[key].get("argument_ids", []),
|
metadata[key].get("argument_ids", []),
|
||||||
record,
|
record,
|
||||||
)
|
)
|
||||||
continue
|
return
|
||||||
|
|
||||||
elif key == "propagate_real_tensors_provenance":
|
elif key == "propagate_real_tensors_provenance":
|
||||||
|
_log_expression_created(
|
||||||
|
super().emit, metadata[key].get("expr_node_id")
|
||||||
|
)
|
||||||
|
|
||||||
def _log_expression_created(
|
elif key == "guard_added":
|
||||||
emit_func: Callable[[Any], None], sym_node_id: int
|
if len(metadata[key]["symbol_to_sources"]) == 0:
|
||||||
) -> None:
|
# We only want to include guards added that are relevant to
|
||||||
# Log all the relevant expression_created logs
|
# the symbolic shapes corresponding to the inputs which were
|
||||||
if sym_node_id is None:
|
# specified in the dynamic_shapes arg. These have a source.
|
||||||
|
return
|
||||||
|
elif metadata[key]["prefix"] == "runtime_assert":
|
||||||
|
# This should've been captured by a
|
||||||
|
# propagate_real_tensors log
|
||||||
return
|
return
|
||||||
if res := self.expression_created_logs.get(
|
|
||||||
sym_node_id, None
|
|
||||||
):
|
|
||||||
# Don't log the expression if we have already
|
|
||||||
# printed it beforehand
|
|
||||||
if not res.visited:
|
|
||||||
res.visited = True
|
|
||||||
for arg in res.argument_ids:
|
|
||||||
_log_expression_created(emit_func, arg)
|
|
||||||
|
|
||||||
emit_func(res.record)
|
|
||||||
|
|
||||||
_log_expression_created(
|
_log_expression_created(
|
||||||
super().emit, metadata[key].get("expr_node_id")
|
super().emit, metadata[key].get("expr_node_id")
|
||||||
|
|
@ -409,12 +421,6 @@ def draft_export(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
failure_type = FailureType.CONSTRAINT_VIOLATION_ERROR
|
failure_type = FailureType.CONSTRAINT_VIOLATION_ERROR
|
||||||
if len(log_contents["symbol_to_sources"]) == 0:
|
|
||||||
# We only want to include guards added that are relevant to
|
|
||||||
# the symbolic shapes corresponding to the inputs which were
|
|
||||||
# specified in the dynamic_shapes arg. These have a source.
|
|
||||||
continue
|
|
||||||
|
|
||||||
log_contents["new_dynamic_shapes"] = new_shapes
|
log_contents["new_dynamic_shapes"] = new_shapes
|
||||||
elif log_name == "missing_fake_kernel":
|
elif log_name == "missing_fake_kernel":
|
||||||
failure_type = FailureType.MISSING_FAKE_KERNEL
|
failure_type = FailureType.MISSING_FAKE_KERNEL
|
||||||
|
|
|
||||||
|
|
@ -6593,9 +6593,10 @@ class ShapeEnv:
|
||||||
"guard_added",
|
"guard_added",
|
||||||
metadata_fn=lambda: {
|
metadata_fn=lambda: {
|
||||||
"expr": str(g),
|
"expr": str(g),
|
||||||
"stack": structured.from_traceback(
|
"prefix": prefix,
|
||||||
CapturedTraceback.extract(skip=1).summary()
|
"expr_node_id": self._expr_sym_node_id,
|
||||||
),
|
"user_stack": structured.get_user_stack(3),
|
||||||
|
"stack": structured.get_framework_stack(3),
|
||||||
"symbol_to_sources": {
|
"symbol_to_sources": {
|
||||||
str(v): k
|
str(v): k
|
||||||
for k, v in self.source_to_var.items()
|
for k, v in self.source_to_var.items()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user