Consolidate stack trace in Tracer (#156257)

Summary:
- Consolidate the stack trace recording code in TracerBase and PythonKeyTracer
- Change `make_fx`'s argument name to be consistent with the `TracerBase` member name `record_stack_traces`

We move the stack trace logic from `create_proxy` to `create_node` so that all subclasses of TracerBase re-use the same stack trace logic.
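For illustration, a minimal usage sketch of the renamed argument (mirroring the updated test hunk below; `Foo` here is a stand-in module):

```
import torch
from torch.fx.experimental.proxy_tensor import make_fx

class Foo(torch.nn.Module):
    def forward(self, x):
        return x + 1

# record_stack_traces replaces the old stack_trace kwarg
gm = make_fx(Foo(), record_stack_traces=True)(torch.randn(4, 4))

for node in gm.graph.nodes:
    if node.op not in ("placeholder", "output"):
        print(node.meta.get("stack_trace"))
```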

Test Plan:
```
buck run caffe2/test:test_export -- -r test_stack_trace
```
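(Outside of Buck, the equivalent would be roughly `python test/export/test_export.py -k test_stack_trace`, assuming the standard OSS test layout.)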

Rollback Plan:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156257
Approved by: https://github.com/angelayi, https://github.com/zou3519
Shangdi Yu 2025-06-25 23:07:07 +00:00 committed by PyTorch MergeBot
parent 653c52fe52
commit 204db27a0c
4 changed files with 57 additions and 69 deletions


@@ -11586,7 +11586,9 @@ graph():
return x
inp = torch.randn(4, 4)
gm = torch.fx.experimental.proxy_tensor.make_fx(Foo(), stack_trace=True)(
gm = torch.fx.experimental.proxy_tensor.make_fx(
Foo(), record_stack_traces=True
)(
inp,
)


@@ -11,7 +11,6 @@ import functools
import inspect
import logging
import operator
import traceback
import typing
import typing_extensions
import weakref
@@ -67,7 +66,6 @@ from torch.utils._python_dispatch import (
)
from torch.utils._stats import count
from torch.utils._thunk import Thunk
from torch.utils._traceback import CapturedTraceback
from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary, WeakTensorKeyDictionary
from ._backward_state import BackwardState
@@ -1017,7 +1015,6 @@ class PythonKeyTracer(Tracer):
tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
torch_fn_counts: dict[OpOverload, int]
enable_thunkify: bool = False
stack_trace: bool = False
def __init__(self) -> None:
super().__init__(autowrap_modules=()) # type: ignore[arg-type]
@@ -1100,39 +1097,6 @@
) -> torch.fx.Node:
node = super().create_node(kind, target, args, kwargs, name, type_expr) # type: ignore[arg-type]
# stack_trace
if (
self.stack_trace
and "stack_trace" not in node.meta
and node.op not in ["placeholder", "output"]
):
user_frame_summary = CapturedTraceback.extract().summary()
if user_frame_summary:
# we retain frames from forward() calls, or ops
# located in torch/__init__.py (e.g. sym_int, sym_constrain_range, vmap)
stack_trace = [
frame
for frame in user_frame_summary
if (
frame.name == "forward"
or frame.filename.endswith("torch/__init__.py")
)
]
# filter out forward() frames from fx/_symbolic_trace.py, export/_trace.py
# this is hardcoded, but leads to a much cleaner stack trace
stack_trace = [
frame
for frame in stack_trace
if not frame.filename.endswith(
("fx/_symbolic_trace.py", "export/_trace.py")
)
]
if (
stack_trace
): # empty list for strict mode, dynamo should handle stack_trace
stack_trace = traceback.StackSummary.from_list(stack_trace)
node.meta["stack_trace"] = "".join(stack_trace.format()).strip()
if kind == "get_attr":
assert isinstance(target, str)
attr = getattr(self.root, target)
@@ -1698,7 +1662,8 @@ class _ModuleStackTracer(PythonKeyTracer):
def __init__(self, scope_root: GraphModule) -> None:
super().__init__()
self.stack_trace = True
self.record_stack_traces = True
self._record_forward_stack_traces_only = True
self.scope_root = scope_root
self.enable_attr_proxy = False
self.submodule_paths = {}
@@ -1962,7 +1927,7 @@ class _MakefxTracer:
record_module_stack: bool,
_allow_fake_constant: bool,
_error_on_data_dependent_ops: bool,
stack_trace: bool = False,
record_stack_traces: bool = False,
) -> None:
# Configurations that are used to initialize the context managers and their states.
# Should not modify them during tracing.
@@ -1993,7 +1958,7 @@ class _MakefxTracer:
self.torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode] = (
nullcontext()
)
self.stack_trace = stack_trace
self.record_stack_traces = record_stack_traces
def _checkpoint_modes(self) -> list[Any]:
return [
@@ -2033,10 +1998,13 @@ class _MakefxTracer:
if hasattr(f, "_orig_mod") and self.record_module_stack:
scope_root = f._orig_mod
# _ModuleStackTracer always tries to preserve stack traces
# in forward functions
self.fx_tracer = _ModuleStackTracer(scope_root)
else:
self.fx_tracer = PythonKeyTracer()
self.fx_tracer.stack_trace = self.stack_trace
self.fx_tracer.record_stack_traces = self.record_stack_traces
if self.record_stack_traces:
self.fx_tracer._record_forward_stack_traces_only = True
if self.tracing_mode == "fake":
import torch._dynamo
@@ -2288,14 +2256,14 @@ def make_fx(
record_module_stack: bool = False,
_allow_fake_constant: bool = False,
_error_on_data_dependent_ops: bool = True,
stack_trace: bool = False,
record_stack_traces: bool = False,
) -> Callable[..., GraphModule]:
"""
Given a function f, return a new function which when executed with valid
arguments to f, returns an FX GraphModule representing the set of operations that
were executed during the course of execution.
If stack_trace is True, the stack_trace will be preserved on node.meta["stack_trace"]
If record_stack_traces is True, the stack trace will be preserved on node.meta["stack_trace"]
"""
assert tracing_mode in ["real", "fake", "symbolic"]
@@ -2310,7 +2278,7 @@ def make_fx(
record_module_stack,
_allow_fake_constant,
_error_on_data_dependent_ops,
stack_trace=stack_trace or config.trace.enabled,
record_stack_traces=record_stack_traces or config.trace.enabled,
)
@functools.wraps(f)


@@ -311,12 +311,15 @@ def uninteresting_files() -> set[str]:
import torch._logging
import torch._subclasses.fake_tensor
import torch._subclasses.meta_utils
import torch.export._trace
mods = [
sys.modules[__name__],
torch.export._trace,
torch.fx.experimental.recording,
torch.fx.experimental.sym_node,
torch.fx.interpreter,
torch.fx._symbolic_trace,
torch,
torch._compile,
torch._dynamo.eval_frame,


@@ -124,6 +124,10 @@ _COPY_META_FIELDS = [
class TracerBase:
graph: Graph
record_stack_traces: bool = False
# When record_stack_traces is True, only record stack traces
# with forward function names.
# This helps when we want stack traces back to model code
_record_forward_stack_traces_only: bool = False
# Feature flag for mutable schema checking
# Enabled by default in 1.12
check_mutable_operations: bool = False
@@ -204,6 +208,42 @@ class TracerBase:
elif self.module_stack:
node.meta["nn_module_stack"] = copy.copy(self.module_stack)
if self.record_stack_traces and not node.stack_trace:
from torch.fx.experimental.symbolic_shapes import uninteresting_files
user_frame_summary = CapturedTraceback.extract().summary()
if user_frame_summary:
if self._record_forward_stack_traces_only:
user_frame_summary = [
frame
for frame in user_frame_summary
if (
frame.name == "forward"
or frame.filename.endswith("torch/__init__.py")
)
]
else:
first_forward = -1
for i, frame in enumerate(user_frame_summary):
if frame.name == "forward":
user_frame_summary = user_frame_summary[i:]
first_forward = i
break
# Not having a "forward" call in the stacktrace implies the
# stacktrace will probably be irrelevant
if first_forward == -1:
user_frame_summary = []
stack_trace = [
frame
for frame in user_frame_summary
if frame.filename not in uninteresting_files()
]
if stack_trace:
stack_trace = traceback.StackSummary.from_list(stack_trace)
node.stack_trace = "".join(stack_trace.format()).strip()
log.debug("create_node %s", node)
return node
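With the recording consolidated in `TracerBase.create_node`, any tracer inheriting from `TracerBase` opts in by setting the flags above; a hedged sketch (`MyModule` is illustrative, the flags are the ones added in this diff):

```
import torch

class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

tracer = torch.fx.Tracer()  # inherits from TracerBase
tracer.record_stack_traces = True
tracer._record_forward_stack_traces_only = True  # keep forward() frames only

graph = tracer.trace(MyModule())
for node in graph.nodes:
    if node.op not in ("placeholder", "output"):
        print(node.stack_trace)
```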
@@ -245,31 +285,6 @@ class TracerBase:
else:
proxy = proxy_factory_fn(node)
if self.record_stack_traces and not proxy.node.stack_trace:
from torch.fx.experimental.symbolic_shapes import uninteresting_files
user_frame_summary = CapturedTraceback.extract().summary()
if user_frame_summary:
first_forward = -1
for i, frame in enumerate(user_frame_summary):
if frame.name == "forward":
user_frame_summary = user_frame_summary[i:]
first_forward = i
break
# Not having a "forward" call in the stacktrace implies the
# stacktrace will probably be irrelevant
if first_forward == -1:
user_frame_summary = []
stack_trace = [
frame
for frame in user_frame_summary
if frame.filename not in uninteresting_files()
]
stack_trace = traceback.StackSummary.from_list(stack_trace)
proxy.node.stack_trace = "".join(stack_trace.format()).strip()
return proxy
def _find_user_frame(self):