Refactor stack_trace preservation for node meta preservation (#90803)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90803
Approved by: https://github.com/jerryzh168, https://github.com/albanD
This commit is contained in:
Sherlock Huang 2023-01-09 18:01:36 +00:00 committed by PyTorch MergeBot
parent 1e768c63c1
commit 0f1302eeae
6 changed files with 40 additions and 48 deletions

View File

@ -178,7 +178,7 @@ class TestFunctionalization(TestCase):
from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks
import torch.fx.traceback as fx_traceback import torch.fx.traceback as fx_traceback
setup_stacktrace_preservation_hooks([loss.grad_fn]) setup_stacktrace_preservation_hooks([loss.grad_fn])
with fx_traceback.override_stack_trace(): with fx_traceback.preserve_node_meta():
loss.backward() loss.backward()
return x.grad return x.grad

View File

@ -667,7 +667,7 @@ def export(
if aten_graph: if aten_graph:
# Running graph with interpreter is needed for propagating the stack_trace # Running graph with interpreter is needed for propagating the stack_trace
def graph_with_interpreter(*args): def graph_with_interpreter(*args):
with torch.fx.traceback.override_stack_trace(): with torch.fx.traceback.preserve_node_meta():
return torch.fx.Interpreter(graph).run(*args) return torch.fx.Interpreter(graph).run(*args)
graph = make_fx( graph = make_fx(

View File

@ -790,7 +790,7 @@ def create_joint_forward_backward_functionalized(
backward_out = [] backward_out = []
# Call the backwards pass # Call the backwards pass
if grad_primals: if grad_primals:
with fx_traceback.override_stack_trace(): with fx_traceback.preserve_node_meta():
backward_out = torch.autograd.grad( backward_out = torch.autograd.grad(
needed_outs, needed_outs,
grad_primals, grad_primals,
@ -2319,7 +2319,7 @@ def aot_module_simplified(
mod, pytree.tree_unflatten(args[:params_len], params_spec) mod, pytree.tree_unflatten(args[:params_len], params_spec)
): ):
if isinstance(mod, torch.fx.GraphModule): if isinstance(mod, torch.fx.GraphModule):
with fx_traceback.override_stack_trace(), warnings.catch_warnings(): with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
warnings.filterwarnings( warnings.filterwarnings(
"ignore", "Anomaly Detection has been enabled." "ignore", "Anomaly Detection has been enabled."
) )

View File

@ -153,7 +153,7 @@ class Interpreter:
@contextmanager @contextmanager
def _set_current_node(self, node): def _set_current_node(self, node):
with fx_traceback.append_stack_trace(node.stack_trace), fx_traceback.set_current_meta(node.meta): with fx_traceback.set_current_meta(node.meta):
yield yield
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
@ -477,7 +477,7 @@ class Transformer(Interpreter):
Transform ``self.module`` and return the transformed Transform ``self.module`` and return the transformed
``GraphModule``. ``GraphModule``.
""" """
with fx_traceback.override_stack_trace(): with fx_traceback.preserve_node_meta():
result = super().run(enable_io_processing=False) result = super().run(enable_io_processing=False)
if result is not None: if result is not None:
def strip_proxy(a : Union[Argument, Proxy]) -> Any: def strip_proxy(a : Union[Argument, Proxy]) -> Any:

View File

@ -76,10 +76,19 @@ class TracerBase:
proxy = proxy_factory_fn(node) proxy = proxy_factory_fn(node)
# Optionally set stack trace on the created Node for debugging purposes # Optionally set stack trace on the created Node for debugging purposes
if fx_traceback.is_stack_trace_overridden(): if fx_traceback.has_preserved_node_meta():
proxy.node.meta = fx_traceback.get_current_meta() current_meta: Dict[str, Any] = fx_traceback.get_current_meta()
stacks = fx_traceback.format_stack()
proxy.node.stack_trace = '\n'.join(reversed(stacks)) # Explicitly set the stack_trace and nn_module_stack on the node.meta
# If other meta fields are needed, they can be added here
stack_trace = current_meta.get("stack_trace")
if stack_trace:
proxy.node.stack_trace = stack_trace
nn_module_stack = current_meta.get("nn_module_stack")
if nn_module_stack:
proxy.node.meta["nn_module_stack"] = nn_module_stack
elif self.record_stack_traces: elif self.record_stack_traces:
user_frame = self._find_user_frame() user_frame = self._find_user_frame()
if user_frame: if user_frame:

View File

@ -1,66 +1,49 @@
import traceback import traceback
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional, List, Any, Dict from typing import List, Any, Dict
from ._compatibility import compatibility from ._compatibility import compatibility
__all__ = ['override_stack_trace', 'set_stack_trace', 'append_stack_trace', 'format_stack', __all__ = ['preserve_node_meta', 'has_preserved_node_meta',
'is_stack_trace_overridden', 'get_current_meta', 'set_current_meta'] 'set_stack_trace', 'format_stack',
'set_current_meta', 'get_current_meta']
current_stack: List[str] = []
current_meta: Dict[str, Any] = {} current_meta: Dict[str, Any] = {}
is_overridden = False should_preserve_node_meta = False
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
@contextmanager @contextmanager
def override_stack_trace(): def preserve_node_meta():
global is_overridden global should_preserve_node_meta
saved_is_overridden = is_overridden saved_should_preserve_node_meta = should_preserve_node_meta
try: try:
is_overridden = True should_preserve_node_meta = True
yield yield
finally: finally:
is_overridden = saved_is_overridden should_preserve_node_meta = saved_should_preserve_node_meta
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
def set_stack_trace(stack : List[str]): def set_stack_trace(stack : List[str]):
global current_stack global current_meta
if is_overridden and stack: if should_preserve_node_meta and stack:
current_stack = stack current_meta["stack_trace"] = "".join(stack)
@compatibility(is_backward_compatible=False)
@contextmanager
def append_stack_trace(stack : Optional[str]):
"""
The content of stack here is an entire stacktraces as a string
"""
global current_stack
if is_overridden and stack:
try:
current_stack.append(stack)
yield
finally:
current_stack.pop()
else:
yield
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
def format_stack() -> List[str]: def format_stack() -> List[str]:
if is_overridden: if should_preserve_node_meta:
return current_stack.copy() return [current_meta.get("stack_trace", "")]
else: else:
# fallback to traceback.format_stack() # fallback to traceback.format_stack()
return traceback.format_list(traceback.extract_stack()[:-1]) return traceback.format_list(traceback.extract_stack()[:-1])
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
def is_stack_trace_overridden() -> bool: def has_preserved_node_meta() -> bool:
return is_overridden return should_preserve_node_meta
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
@ -68,13 +51,13 @@ def is_stack_trace_overridden() -> bool:
def set_current_meta(meta : Dict[str, Any]): def set_current_meta(meta : Dict[str, Any]):
global current_meta global current_meta
old_meta = current_meta if should_preserve_node_meta and meta:
if is_overridden and meta: saved_meta = current_meta
try: try:
current_meta = meta current_meta = meta
yield yield
finally: finally:
current_meta = old_meta current_meta = saved_meta
else: else:
yield yield