mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Refactor stack_trace preservation for node meta preservation (#90803)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90803 Approved by: https://github.com/jerryzh168, https://github.com/albanD
This commit is contained in:
parent
1e768c63c1
commit
0f1302eeae
|
|
@ -178,7 +178,7 @@ class TestFunctionalization(TestCase):
|
||||||
from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks
|
from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks
|
||||||
import torch.fx.traceback as fx_traceback
|
import torch.fx.traceback as fx_traceback
|
||||||
setup_stacktrace_preservation_hooks([loss.grad_fn])
|
setup_stacktrace_preservation_hooks([loss.grad_fn])
|
||||||
with fx_traceback.override_stack_trace():
|
with fx_traceback.preserve_node_meta():
|
||||||
loss.backward()
|
loss.backward()
|
||||||
return x.grad
|
return x.grad
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -667,7 +667,7 @@ def export(
|
||||||
if aten_graph:
|
if aten_graph:
|
||||||
# Running graph with interpreter is needed for propagating the stack_trace
|
# Running graph with interpreter is needed for propagating the stack_trace
|
||||||
def graph_with_interpreter(*args):
|
def graph_with_interpreter(*args):
|
||||||
with torch.fx.traceback.override_stack_trace():
|
with torch.fx.traceback.preserve_node_meta():
|
||||||
return torch.fx.Interpreter(graph).run(*args)
|
return torch.fx.Interpreter(graph).run(*args)
|
||||||
|
|
||||||
graph = make_fx(
|
graph = make_fx(
|
||||||
|
|
|
||||||
|
|
@ -790,7 +790,7 @@ def create_joint_forward_backward_functionalized(
|
||||||
backward_out = []
|
backward_out = []
|
||||||
# Call the backwards pass
|
# Call the backwards pass
|
||||||
if grad_primals:
|
if grad_primals:
|
||||||
with fx_traceback.override_stack_trace():
|
with fx_traceback.preserve_node_meta():
|
||||||
backward_out = torch.autograd.grad(
|
backward_out = torch.autograd.grad(
|
||||||
needed_outs,
|
needed_outs,
|
||||||
grad_primals,
|
grad_primals,
|
||||||
|
|
@ -2319,7 +2319,7 @@ def aot_module_simplified(
|
||||||
mod, pytree.tree_unflatten(args[:params_len], params_spec)
|
mod, pytree.tree_unflatten(args[:params_len], params_spec)
|
||||||
):
|
):
|
||||||
if isinstance(mod, torch.fx.GraphModule):
|
if isinstance(mod, torch.fx.GraphModule):
|
||||||
with fx_traceback.override_stack_trace(), warnings.catch_warnings():
|
with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
|
||||||
warnings.filterwarnings(
|
warnings.filterwarnings(
|
||||||
"ignore", "Anomaly Detection has been enabled."
|
"ignore", "Anomaly Detection has been enabled."
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -153,7 +153,7 @@ class Interpreter:
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _set_current_node(self, node):
|
def _set_current_node(self, node):
|
||||||
with fx_traceback.append_stack_trace(node.stack_trace), fx_traceback.set_current_meta(node.meta):
|
with fx_traceback.set_current_meta(node.meta):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
|
|
@ -477,7 +477,7 @@ class Transformer(Interpreter):
|
||||||
Transform ``self.module`` and return the transformed
|
Transform ``self.module`` and return the transformed
|
||||||
``GraphModule``.
|
``GraphModule``.
|
||||||
"""
|
"""
|
||||||
with fx_traceback.override_stack_trace():
|
with fx_traceback.preserve_node_meta():
|
||||||
result = super().run(enable_io_processing=False)
|
result = super().run(enable_io_processing=False)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
def strip_proxy(a : Union[Argument, Proxy]) -> Any:
|
def strip_proxy(a : Union[Argument, Proxy]) -> Any:
|
||||||
|
|
|
||||||
|
|
@ -76,10 +76,19 @@ class TracerBase:
|
||||||
proxy = proxy_factory_fn(node)
|
proxy = proxy_factory_fn(node)
|
||||||
|
|
||||||
# Optionally set stack trace on the created Node for debugging purposes
|
# Optionally set stack trace on the created Node for debugging purposes
|
||||||
if fx_traceback.is_stack_trace_overridden():
|
if fx_traceback.has_preserved_node_meta():
|
||||||
proxy.node.meta = fx_traceback.get_current_meta()
|
current_meta: Dict[str, Any] = fx_traceback.get_current_meta()
|
||||||
stacks = fx_traceback.format_stack()
|
|
||||||
proxy.node.stack_trace = '\n'.join(reversed(stacks))
|
# Explicitly set the stack_trace and nn_module_stack on the node.meta
|
||||||
|
# If other meta fields are needed, they can be added here
|
||||||
|
stack_trace = current_meta.get("stack_trace")
|
||||||
|
if stack_trace:
|
||||||
|
proxy.node.stack_trace = stack_trace
|
||||||
|
|
||||||
|
nn_module_stack = current_meta.get("nn_module_stack")
|
||||||
|
if nn_module_stack:
|
||||||
|
proxy.node.meta["nn_module_stack"] = nn_module_stack
|
||||||
|
|
||||||
elif self.record_stack_traces:
|
elif self.record_stack_traces:
|
||||||
user_frame = self._find_user_frame()
|
user_frame = self._find_user_frame()
|
||||||
if user_frame:
|
if user_frame:
|
||||||
|
|
|
||||||
|
|
@ -1,66 +1,49 @@
|
||||||
import traceback
|
import traceback
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Optional, List, Any, Dict
|
from typing import List, Any, Dict
|
||||||
from ._compatibility import compatibility
|
from ._compatibility import compatibility
|
||||||
|
|
||||||
__all__ = ['override_stack_trace', 'set_stack_trace', 'append_stack_trace', 'format_stack',
|
__all__ = ['preserve_node_meta', 'has_preserved_node_meta',
|
||||||
'is_stack_trace_overridden', 'get_current_meta', 'set_current_meta']
|
'set_stack_trace', 'format_stack',
|
||||||
|
'set_current_meta', 'get_current_meta']
|
||||||
|
|
||||||
|
|
||||||
current_stack: List[str] = []
|
|
||||||
current_meta: Dict[str, Any] = {}
|
current_meta: Dict[str, Any] = {}
|
||||||
is_overridden = False
|
should_preserve_node_meta = False
|
||||||
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=False)
|
@compatibility(is_backward_compatible=False)
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def override_stack_trace():
|
def preserve_node_meta():
|
||||||
global is_overridden
|
global should_preserve_node_meta
|
||||||
|
|
||||||
saved_is_overridden = is_overridden
|
saved_should_preserve_node_meta = should_preserve_node_meta
|
||||||
try:
|
try:
|
||||||
is_overridden = True
|
should_preserve_node_meta = True
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
is_overridden = saved_is_overridden
|
should_preserve_node_meta = saved_should_preserve_node_meta
|
||||||
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=False)
|
@compatibility(is_backward_compatible=False)
|
||||||
def set_stack_trace(stack : List[str]):
|
def set_stack_trace(stack : List[str]):
|
||||||
global current_stack
|
global current_meta
|
||||||
|
|
||||||
if is_overridden and stack:
|
if should_preserve_node_meta and stack:
|
||||||
current_stack = stack
|
current_meta["stack_trace"] = "".join(stack)
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=False)
|
|
||||||
@contextmanager
|
|
||||||
def append_stack_trace(stack : Optional[str]):
|
|
||||||
"""
|
|
||||||
The content of stack here is an entire stacktraces as a string
|
|
||||||
"""
|
|
||||||
global current_stack
|
|
||||||
|
|
||||||
if is_overridden and stack:
|
|
||||||
try:
|
|
||||||
current_stack.append(stack)
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
current_stack.pop()
|
|
||||||
else:
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=False)
|
@compatibility(is_backward_compatible=False)
|
||||||
def format_stack() -> List[str]:
|
def format_stack() -> List[str]:
|
||||||
if is_overridden:
|
if should_preserve_node_meta:
|
||||||
return current_stack.copy()
|
return [current_meta.get("stack_trace", "")]
|
||||||
else:
|
else:
|
||||||
# fallback to traceback.format_stack()
|
# fallback to traceback.format_stack()
|
||||||
return traceback.format_list(traceback.extract_stack()[:-1])
|
return traceback.format_list(traceback.extract_stack()[:-1])
|
||||||
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=False)
|
@compatibility(is_backward_compatible=False)
|
||||||
def is_stack_trace_overridden() -> bool:
|
def has_preserved_node_meta() -> bool:
|
||||||
return is_overridden
|
return should_preserve_node_meta
|
||||||
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=False)
|
@compatibility(is_backward_compatible=False)
|
||||||
|
|
@ -68,13 +51,13 @@ def is_stack_trace_overridden() -> bool:
|
||||||
def set_current_meta(meta : Dict[str, Any]):
|
def set_current_meta(meta : Dict[str, Any]):
|
||||||
global current_meta
|
global current_meta
|
||||||
|
|
||||||
old_meta = current_meta
|
if should_preserve_node_meta and meta:
|
||||||
if is_overridden and meta:
|
saved_meta = current_meta
|
||||||
try:
|
try:
|
||||||
current_meta = meta
|
current_meta = meta
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
current_meta = old_meta
|
current_meta = saved_meta
|
||||||
else:
|
else:
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user