diff --git a/functorch/functorch/_src/aot_autograd.py b/functorch/functorch/_src/aot_autograd.py
index fffc986de3d..743ac584843 100644
--- a/functorch/functorch/_src/aot_autograd.py
+++ b/functorch/functorch/_src/aot_autograd.py
@@ -1,3 +1,4 @@
+import collections
 import dataclasses
 import warnings
 from contextlib import contextmanager, nullcontext
@@ -59,6 +60,70 @@ def preserve_rng_state():
             torch.cuda.set_rng_state(cuda_rng_state)
 
 
+# Set up hooks so that during backward the fx's stack_trace is properly set
+callback_set = False
+
+
+def setup_stacktrace_preservation_hooks(roots: List):
+    def iter_graph(roots):
+        if not roots:
+            return
+        seen = set()
+        q = collections.deque()
+        for node in roots:
+            if node is not None:
+                seen.add(node)
+                q.append(node)
+
+        while q:
+            node = q.popleft()
+            for fn, _idx in node.next_functions:
+                if fn in seen or fn is None:
+                    continue
+                seen.add(fn)
+                q.append(fn)
+
+            yield node
+
+    def get_callback(saved_stack_):
+        def callback():
+            global callback_set
+            fx_traceback.set_stack_trace(saved_stack_)
+            callback_set = False
+
+        return callback
+
+    def get_prehook(stack_):
+        def prehook(grad_output):
+            global callback_set
+
+            if not callback_set:
+                torch.autograd.variable.Variable._execution_engine.queue_callback(
+                    get_callback(fx_traceback.format_stack())
+                )
+                callback_set = True
+
+            fx_traceback.set_stack_trace(stack_)
+
+        return prehook
+
+    def get_posthook(special_stack_):
+        def posthook(grad_input, grad_output):
+            fx_traceback.set_stack_trace(special_stack_)
+
+        return posthook
+
+    for node in iter_graph(roots):
+        forward_node_stack = node.metadata.get("traceback_", [])
+        node.register_prehook(get_prehook(forward_node_stack))
+
+        special_stack = forward_node_stack.copy()
+        special_stack.append(
+            "Gradient addition node due to multiple use of tensor around:"
+        )
+        node.register_hook(get_posthook(special_stack))
+
+
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
         primals: List[Any], tangents: List[Any]
@@ -82,15 +147,19 @@ def create_joint_forward_backward(fn):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
                 needed_tangents.append(tangent)
+
+        setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs])
+
         backward_out = []
         # Call the backwards pass
         if grad_primals:
-            backward_out = torch.autograd.grad(
-                needed_outs,
-                grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
-            )
+            with fx_traceback.override_stack_trace():
+                backward_out = torch.autograd.grad(
+                    needed_outs,
+                    grad_primals,
+                    grad_outputs=needed_tangents,
+                    allow_unused=True,
+                )
         backward_out_iter = iter(backward_out)
         return outs, [
             next(backward_out_iter) if i else None for i in inputs_needs_grads
@@ -735,7 +804,9 @@ def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
             mod, pytree.tree_unflatten(args[:params_len], params_spec)
         ):
             if isinstance(mod, torch.fx.GraphModule):
-                with fx_traceback.override_stack_trace():
+                with fx_traceback.override_stack_trace(), torch.autograd.detect_anomaly(
+                    check_nan=False
+                ):
                     out = Interpreter(mod).run(*args[params_len:], **kwargs)
             else:
                 out = mod(*args[params_len:], **kwargs)
diff --git a/functorch/test/test_pythonkey.py b/functorch/test/test_pythonkey.py
index 1b5c933e835..5deeac1eb27 100644
--- a/functorch/test/test_pythonkey.py
+++ b/functorch/test/test_pythonkey.py
@@ -598,14 +598,16 @@ class TestAOTModuleSimplified(AOTTestCase):
         assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad)
 
     def test_aot_module_simplified_preserves_stack_trace(self):
-
         class MockModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.linear = torch.nn.Linear(20, 30)
 
             def forward(self, x, y):
-                return (self.linear(x) + y, )
+                z = self.linear(x)
+                z = z + y
+                z = z.relu()
+                return (z, )
 
         tracer = torch.fx.Tracer()
         tracer.record_stack_traces = True
@@ -626,7 +628,7 @@ class TestAOTModuleSimplified(AOTTestCase):
                 assert 'test_pythonkey.py' in node.stack_trace
             return gm.forward  # return a python callable
 
-        aot_mod = aot_module_simplified(mod, fw_compiler=assert_compiler, bw_compiler=nop)
+        aot_mod = aot_module_simplified(mod, fw_compiler=assert_compiler, bw_compiler=assert_compiler)
 
         x = torch.randn(128, 20, requires_grad=True)
         y = torch.randn(128, 30, requires_grad=True)
diff --git a/torch/csrc/autograd/python_anomaly_mode.cpp b/torch/csrc/autograd/python_anomaly_mode.cpp
index ec5dfe1b099..3c91316c06f 100644
--- a/torch/csrc/autograd/python_anomaly_mode.cpp
+++ b/torch/csrc/autograd/python_anomaly_mode.cpp
@@ -16,7 +16,7 @@ namespace autograd {
 
 void PyAnomalyMetadata::store_stack() {
   pybind11::gil_scoped_acquire gil;
-  THPObjectPtr mod(PyImport_ImportModule("traceback"));
+  THPObjectPtr mod(PyImport_ImportModule("torch.fx.traceback"));
   if (!mod) {
     throw python_error();
   }
diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py
index aac94a85e17..a07b36b997b 100644
--- a/torch/fx/traceback.py
+++ b/torch/fx/traceback.py
@@ -3,7 +3,7 @@ from contextlib import contextmanager
 from typing import Optional, List
 from ._compatibility import compatibility
 
-__all__ = ['override_stack_trace', 'append_stack_trace', 'format_stack', 'is_stack_trace_overridden']
+__all__ = ['override_stack_trace', 'set_stack_trace', 'append_stack_trace', 'format_stack', 'is_stack_trace_overridden']
 
 current_stack: List[str] = []
 
@@ -23,6 +23,13 @@ def override_stack_trace():
         is_overridden = saved_is_overridden
 
 
+@compatibility(is_backward_compatible=False)
+def set_stack_trace(stack : List[str]):
+    global current_stack
+
+    if is_overridden and stack:
+        current_stack = stack
+
 @compatibility(is_backward_compatible=False)
 @contextmanager
 def append_stack_trace(stack : Optional[str]):
@@ -44,7 +51,7 @@ def append_stack_trace(stack : Optional[str]):
 @compatibility(is_backward_compatible=False)
 def format_stack() -> List[str]:
     if is_overridden:
-        return current_stack
+        return current_stack.copy()
     else:
         # fallback to traceback.format_stack()
         return traceback.format_stack()
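
Usage sketch (not part of the patch): the snippet below mirrors the updated
test_aot_module_simplified_preserves_stack_trace and shows how the preserved
stack traces can be observed. It assumes the functorch API exercised by that
test (aot_module_simplified taking fw_compiler/bw_compiler callbacks, and
torch.fx.Tracer with record_stack_traces); print_stack_traces is a hypothetical
stand-in for the test's assert_compiler.

# Sketch only -- mirrors the test above; functorch.compile import path assumed.
import torch
from functorch.compile import aot_module_simplified


class MockModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(20, 30)

    def forward(self, x, y):
        z = self.linear(x)
        z = z + y
        z = z.relu()
        return (z,)


def print_stack_traces(gm: torch.fx.GraphModule, example_inputs):
    # Called for both the forward and the backward graph. With this patch,
    # backward nodes also carry a stack_trace pointing at the forward source.
    for node in gm.graph.nodes:
        if node.op in ("placeholder", "output"):
            continue
        print(node.op, node.target, "->", node.stack_trace)
    return gm.forward  # return a python callable


# Trace with stack traces recorded so the forward nodes carry stack_trace.
tracer = torch.fx.Tracer()
tracer.record_stack_traces = True
graph = tracer.trace(MockModule())
mod = torch.fx.GraphModule(tracer.root, graph)

aot_mod = aot_module_simplified(
    mod, fw_compiler=print_stack_traces, bw_compiler=print_stack_traces
)

x = torch.randn(128, 20, requires_grad=True)
y = torch.randn(128, 30, requires_grad=True)
res = aot_mod(x, y)
res[0].sum().backward()  # runs the traced backward; its nodes carry stack traces

The backward traces come from the pre/post hooks registered in
setup_stacktrace_preservation_hooks together with the anomaly-mode metadata
("traceback_") recorded while the forward graph is interpreted under
detect_anomaly(check_nan=False).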