mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Consolidate stack trace in Tracer (#156257)
Summary: - Consolidate the stack trace recording code in TracerBase and PythonKeyTracer - Change `make_fx`'s arg name to be consistent with TracerBase member name `record_stack_traces` We move the stack trace logic from `create_proxy` to `create_node` so all inherited classes of TracerBase and re-use the same stack trace logic. Test Plan: ``` buck run caffe2/test:test_export -- -r test_stack_trace ``` Rollback Plan: Pull Request resolved: https://github.com/pytorch/pytorch/pull/156257 Approved by: https://github.com/angelayi, https://github.com/zou3519
This commit is contained in:
parent
653c52fe52
commit
204db27a0c
|
|
@ -11586,7 +11586,9 @@ graph():
|
|||
return x
|
||||
|
||||
inp = torch.randn(4, 4)
|
||||
gm = torch.fx.experimental.proxy_tensor.make_fx(Foo(), stack_trace=True)(
|
||||
gm = torch.fx.experimental.proxy_tensor.make_fx(
|
||||
Foo(), record_stack_traces=True
|
||||
)(
|
||||
inp,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import functools
|
|||
import inspect
|
||||
import logging
|
||||
import operator
|
||||
import traceback
|
||||
import typing
|
||||
import typing_extensions
|
||||
import weakref
|
||||
|
|
@ -67,7 +66,6 @@ from torch.utils._python_dispatch import (
|
|||
)
|
||||
from torch.utils._stats import count
|
||||
from torch.utils._thunk import Thunk
|
||||
from torch.utils._traceback import CapturedTraceback
|
||||
from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary, WeakTensorKeyDictionary
|
||||
|
||||
from ._backward_state import BackwardState
|
||||
|
|
@ -1017,7 +1015,6 @@ class PythonKeyTracer(Tracer):
|
|||
tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
|
||||
torch_fn_counts: dict[OpOverload, int]
|
||||
enable_thunkify: bool = False
|
||||
stack_trace: bool = False
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(autowrap_modules=()) # type: ignore[arg-type]
|
||||
|
|
@ -1100,39 +1097,6 @@ class PythonKeyTracer(Tracer):
|
|||
) -> torch.fx.Node:
|
||||
node = super().create_node(kind, target, args, kwargs, name, type_expr) # type: ignore[arg-type]
|
||||
|
||||
# stack_trace
|
||||
if (
|
||||
self.stack_trace
|
||||
and "stack_trace" not in node.meta
|
||||
and node.op not in ["placeholder", "output"]
|
||||
):
|
||||
user_frame_summary = CapturedTraceback.extract().summary()
|
||||
if user_frame_summary:
|
||||
# we retain frames from forward() calls, or ops
|
||||
# located in torch/__init__.py (e.g. sym_int, sym_constrain_range, vmap)
|
||||
stack_trace = [
|
||||
frame
|
||||
for frame in user_frame_summary
|
||||
if (
|
||||
frame.name == "forward"
|
||||
or frame.filename.endswith("torch/__init__.py")
|
||||
)
|
||||
]
|
||||
# filter out forward() frames from fx/_symbolic_trace.py, export/_trace.py
|
||||
# this is hardcoded, but leads to a much cleaner stack trace
|
||||
stack_trace = [
|
||||
frame
|
||||
for frame in stack_trace
|
||||
if not frame.filename.endswith(
|
||||
("fx/_symbolic_trace.py", "export/_trace.py")
|
||||
)
|
||||
]
|
||||
if (
|
||||
stack_trace
|
||||
): # empty list for strict mode, dynamo should handle stack_trace
|
||||
stack_trace = traceback.StackSummary.from_list(stack_trace)
|
||||
node.meta["stack_trace"] = "".join(stack_trace.format()).strip()
|
||||
|
||||
if kind == "get_attr":
|
||||
assert isinstance(target, str)
|
||||
attr = getattr(self.root, target)
|
||||
|
|
@ -1698,7 +1662,8 @@ class _ModuleStackTracer(PythonKeyTracer):
|
|||
|
||||
def __init__(self, scope_root: GraphModule) -> None:
|
||||
super().__init__()
|
||||
self.stack_trace = True
|
||||
self.record_stack_traces = True
|
||||
self._record_forward_stack_traces_only = True
|
||||
self.scope_root = scope_root
|
||||
self.enable_attr_proxy = False
|
||||
self.submodule_paths = {}
|
||||
|
|
@ -1962,7 +1927,7 @@ class _MakefxTracer:
|
|||
record_module_stack: bool,
|
||||
_allow_fake_constant: bool,
|
||||
_error_on_data_dependent_ops: bool,
|
||||
stack_trace: bool = False,
|
||||
record_stack_traces: bool = False,
|
||||
) -> None:
|
||||
# Configurations that are used to initialize the context managers and their states.
|
||||
# Should not modify them during tracing.
|
||||
|
|
@ -1993,7 +1958,7 @@ class _MakefxTracer:
|
|||
self.torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode] = (
|
||||
nullcontext()
|
||||
)
|
||||
self.stack_trace = stack_trace
|
||||
self.record_stack_traces = record_stack_traces
|
||||
|
||||
def _checkpoint_modes(self) -> list[Any]:
|
||||
return [
|
||||
|
|
@ -2033,10 +1998,13 @@ class _MakefxTracer:
|
|||
if hasattr(f, "_orig_mod") and self.record_module_stack:
|
||||
scope_root = f._orig_mod
|
||||
# _ModuleStackTracer always try to preserve stack trace
|
||||
# in forward functions
|
||||
self.fx_tracer = _ModuleStackTracer(scope_root)
|
||||
else:
|
||||
self.fx_tracer = PythonKeyTracer()
|
||||
self.fx_tracer.stack_trace = self.stack_trace
|
||||
self.fx_tracer.record_stack_traces = self.record_stack_traces
|
||||
if self.record_stack_traces:
|
||||
self.fx_tracer._record_forward_stack_traces_only = True
|
||||
|
||||
if self.tracing_mode == "fake":
|
||||
import torch._dynamo
|
||||
|
|
@ -2288,14 +2256,14 @@ def make_fx(
|
|||
record_module_stack: bool = False,
|
||||
_allow_fake_constant: bool = False,
|
||||
_error_on_data_dependent_ops: bool = True,
|
||||
stack_trace: bool = False,
|
||||
record_stack_traces: bool = False,
|
||||
) -> Callable[..., GraphModule]:
|
||||
"""
|
||||
Given a function f, return a new function which when executed with valid
|
||||
arguments to f, returns an FX GraphModule representing the set of operations that
|
||||
were executed during the course of execution.
|
||||
|
||||
If stack_trace is True, the stack_trace will be preserved on node.meta["stack_trace"]
|
||||
If record_stack_traces is True, the stack trace will be preserved on node.meta["stack_trace"]
|
||||
"""
|
||||
|
||||
assert tracing_mode in ["real", "fake", "symbolic"]
|
||||
|
|
@ -2310,7 +2278,7 @@ def make_fx(
|
|||
record_module_stack,
|
||||
_allow_fake_constant,
|
||||
_error_on_data_dependent_ops,
|
||||
stack_trace=stack_trace or config.trace.enabled,
|
||||
record_stack_traces=record_stack_traces or config.trace.enabled,
|
||||
)
|
||||
|
||||
@functools.wraps(f)
|
||||
|
|
|
|||
|
|
@ -311,12 +311,15 @@ def uninteresting_files() -> set[str]:
|
|||
import torch._logging
|
||||
import torch._subclasses.fake_tensor
|
||||
import torch._subclasses.meta_utils
|
||||
import torch.export._trace
|
||||
|
||||
mods = [
|
||||
sys.modules[__name__],
|
||||
torch.export._trace,
|
||||
torch.fx.experimental.recording,
|
||||
torch.fx.experimental.sym_node,
|
||||
torch.fx.interpreter,
|
||||
torch.fx._symbolic_trace,
|
||||
torch,
|
||||
torch._compile,
|
||||
torch._dynamo.eval_frame,
|
||||
|
|
|
|||
|
|
@ -124,6 +124,10 @@ _COPY_META_FIELDS = [
|
|||
class TracerBase:
|
||||
graph: Graph
|
||||
record_stack_traces: bool = False
|
||||
# When record_stack_traces is True, only reocrd stack traces
|
||||
# with forward function names.
|
||||
# This helps when we want stack trace back to model code
|
||||
_record_forward_stack_traces_only: bool = False
|
||||
# Feature flag for mutable schema checking
|
||||
# Enableby default in 1.12
|
||||
check_mutable_operations: bool = False
|
||||
|
|
@ -204,6 +208,42 @@ class TracerBase:
|
|||
elif self.module_stack:
|
||||
node.meta["nn_module_stack"] = copy.copy(self.module_stack)
|
||||
|
||||
if self.record_stack_traces and not node.stack_trace:
|
||||
from torch.fx.experimental.symbolic_shapes import uninteresting_files
|
||||
|
||||
user_frame_summary = CapturedTraceback.extract().summary()
|
||||
if user_frame_summary:
|
||||
if self._record_forward_stack_traces_only:
|
||||
user_frame_summary = [
|
||||
frame
|
||||
for frame in user_frame_summary
|
||||
if (
|
||||
frame.name == "forward"
|
||||
or frame.filename.endswith("torch/__init__.py")
|
||||
)
|
||||
]
|
||||
else:
|
||||
first_forward = -1
|
||||
for i, frame in enumerate(user_frame_summary):
|
||||
if frame.name == "forward":
|
||||
user_frame_summary = user_frame_summary[i:]
|
||||
first_forward = i
|
||||
break
|
||||
|
||||
# Not having a "forward" call in the stacktrace implies the
|
||||
# stacktrace will probably be irrelevant
|
||||
if first_forward == -1:
|
||||
user_frame_summary = []
|
||||
|
||||
stack_trace = [
|
||||
frame
|
||||
for frame in user_frame_summary
|
||||
if frame.filename not in uninteresting_files()
|
||||
]
|
||||
if stack_trace:
|
||||
stack_trace = traceback.StackSummary.from_list(stack_trace)
|
||||
node.stack_trace = "".join(stack_trace.format()).strip()
|
||||
|
||||
log.debug("create_node %s", node)
|
||||
return node
|
||||
|
||||
|
|
@ -245,31 +285,6 @@ class TracerBase:
|
|||
else:
|
||||
proxy = proxy_factory_fn(node)
|
||||
|
||||
if self.record_stack_traces and not proxy.node.stack_trace:
|
||||
from torch.fx.experimental.symbolic_shapes import uninteresting_files
|
||||
|
||||
user_frame_summary = CapturedTraceback.extract().summary()
|
||||
if user_frame_summary:
|
||||
first_forward = -1
|
||||
for i, frame in enumerate(user_frame_summary):
|
||||
if frame.name == "forward":
|
||||
user_frame_summary = user_frame_summary[i:]
|
||||
first_forward = i
|
||||
break
|
||||
|
||||
# Not having a "forward" call in the stacktrace implies the
|
||||
# stacktrace will probably be irrelevant
|
||||
if first_forward == -1:
|
||||
user_frame_summary = []
|
||||
|
||||
stack_trace = [
|
||||
frame
|
||||
for frame in user_frame_summary
|
||||
if frame.filename not in uninteresting_files()
|
||||
]
|
||||
stack_trace = traceback.StackSummary.from_list(stack_trace)
|
||||
proxy.node.stack_trace = "".join(stack_trace.format()).strip()
|
||||
|
||||
return proxy
|
||||
|
||||
def _find_user_frame(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user