From a7baad04f6f29a97743e98d25c369b21aed18faf Mon Sep 17 00:00:00 2001
From: Sherlock Huang
Date: Thu, 18 Aug 2022 17:54:52 +0000
Subject: [PATCH] Preserve stack trace for backward nodes over AOTAutograd (#83558)

For the following program.

```
def my_relu(a):
    return a.relu()

def func(a, b):
    a = torch.nn.Linear(10, 10)(a)
    d = torch.square(b)
    d = my_relu(d)
    loss = d.sum()
    return loss

with torchdynamo.optimize("aot_nop"):
    x = torch.rand(10, 10, requires_grad=True)
    y = torch.rand(10, 10, requires_grad=True)
    out = func(x, y)
```

It would generate the following fx graph with stack_trace populated in both forward and backward nodes.

```
def forward(self, primals, tangents):
    primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    t_default = torch.ops.aten.t.default(primals_3); primals_3 = None
    addmm_default = torch.ops.aten.addmm.default(primals_4, primals_1, t_default); primals_4 = primals_1 = t_default = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(primals_2, 2)
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar); pow_tensor_scalar = None
    detach_default = torch.ops.aten.detach.default(relu_default)
    sum_default = torch.ops.aten.sum.default(relu_default); relu_default = None
    is_same_size_default = torch.ops.aten.is_same_size.default(sum_default, tangents_1)
    expand_default = torch.ops.aten.expand.default(tangents_1, [10, 10]); tangents_1 = None
    detach_default_1 = torch.ops.aten.detach.default(detach_default); detach_default = None
    threshold_backward_default = torch.ops.aten.threshold_backward.default(expand_default, detach_default_1, 0); expand_default = detach_default_1 = None
    pow_tensor_scalar_1 = torch.ops.aten.pow.Tensor_Scalar(primals_2, 1.0); primals_2 = None
    mul_scalar = torch.ops.aten.mul.Scalar(pow_tensor_scalar_1, 2.0); pow_tensor_scalar_1 = None
    mul_tensor = torch.ops.aten.mul.Tensor(threshold_backward_default, mul_scalar); threshold_backward_default = mul_scalar = None
    return pytree.tree_unflatten([sum_default, None, mul_tensor, None, None], self._out_spec)

====== joint graph =======
primals_1 None
primals_2 None
primals_3 None
primals_4 None
tangents_1 None
t_default
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 12, in func
    def func(a, b):
  File "/fsx/users/bahuang/repos/pytorch_fsx/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)

addmm_default
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 12, in func
    def func(a, b):
  File "/fsx/users/bahuang/repos/pytorch_fsx/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)

pow_tensor_scalar
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

relu_default
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

detach_default
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

sum_default
is_same_size_default
expand_default
detach_default_1
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

threshold_backward_default
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

pow_tensor_scalar_1
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

mul_scalar
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

mul_tensor
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

output None
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83558
Approved by: https://github.com/albanD
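
Editor's note, not part of the patch: a minimal sketch, modeled on the unit test changed below, of how a user-supplied compiler passed to `aot_module_simplified` can observe the preserved stack traces on both the forward and the backward graph. The toy module and the `functorch.compile` import path are assumptions made for illustration.

```
import torch
import torch.fx
from functorch.compile import aot_module_simplified  # assumed import path


class ToyModule(torch.nn.Module):  # hypothetical module, mirrors the unit test below
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(20, 30)

    def forward(self, x, y):
        return ((self.linear(x) + y).relu(),)


def print_compiler(gm, example_inputs):
    # With this patch, backward graphs (not only forward ones) reach the
    # compiler with node.stack_trace pointing back at the user code.
    for node in gm.graph.nodes:
        if node.op == "call_function":
            print(node.name, node.stack_trace)
    return gm.forward  # compilers must return a python callable


# Record stack traces while symbolically tracing, as the unit test does.
tracer = torch.fx.Tracer()
tracer.record_stack_traces = True
mod = ToyModule()
traced = torch.fx.GraphModule(mod, tracer.trace(mod))

aot_mod = aot_module_simplified(traced, fw_compiler=print_compiler, bw_compiler=print_compiler)
x = torch.randn(128, 20, requires_grad=True)
y = torch.randn(128, 30, requires_grad=True)
aot_mod(x, y)[0].sum().backward()
```
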
---
 functorch/functorch/_src/aot_autograd.py    | 85 +++++++++++++++++++--
 functorch/test/test_pythonkey.py            |  8 +-
 torch/csrc/autograd/python_anomaly_mode.cpp |  2 +-
 torch/fx/traceback.py                       | 11 ++-
 4 files changed, 93 insertions(+), 13 deletions(-)

diff --git a/functorch/functorch/_src/aot_autograd.py b/functorch/functorch/_src/aot_autograd.py
index fffc986de3d..743ac584843 100644
--- a/functorch/functorch/_src/aot_autograd.py
+++ b/functorch/functorch/_src/aot_autograd.py
@@ -1,3 +1,4 @@
+import collections
 import dataclasses
 import warnings
 from contextlib import contextmanager, nullcontext
@@ -59,6 +60,70 @@ def preserve_rng_state():
         torch.cuda.set_rng_state(cuda_rng_state)
 
 
+# Set up hooks so that during backward the fx's stack_trace is properly set
+callback_set = False
+
+
+def setup_stacktrace_preservation_hooks(roots: List):
+    def iter_graph(roots):
+        if not roots:
+            return
+        seen = set()
+        q = collections.deque()
+        for node in roots:
+            if node is not None:
+                seen.add(node)
+                q.append(node)
+
+        while q:
+            node = q.popleft()
+            for fn, _idx in node.next_functions:
+                if fn in seen or fn is None:
+                    continue
+                seen.add(fn)
+                q.append(fn)
+
+            yield node
+
+    def get_callback(saved_stack_):
+        def callback():
+            global callback_set
+            fx_traceback.set_stack_trace(saved_stack_)
+            callback_set = False
+
+        return callback
+
+    def get_prehook(stack_):
+        def prehook(grad_output):
+            global callback_set
+
+            if not callback_set:
+                torch.autograd.variable.Variable._execution_engine.queue_callback(
+                    get_callback(fx_traceback.format_stack())
+                )
+                callback_set = True
+
+            fx_traceback.set_stack_trace(stack_)
+
+        return prehook
+
+    def get_posthook(special_stack_):
+        def posthook(grad_input, grad_output):
+            fx_traceback.set_stack_trace(special_stack_)
+
+        return posthook
+
+    for node in iter_graph(roots):
+        forward_node_stack = node.metadata.get("traceback_", [])
+        node.register_prehook(get_prehook(forward_node_stack))
+
+        special_stack = forward_node_stack.copy()
+        special_stack.append(
+            "Gradient addition node due to multiple use of tensor around:"
+        )
+        node.register_hook(get_posthook(special_stack))
+
+
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
         primals: List[Any], tangents: List[Any]
@@ -82,15 +147,19 @@ def create_joint_forward_backward(fn):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
                 needed_tangents.append(tangent)
+
+        setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs])
+
         backward_out = []
         # Call the backwards pass
        if grad_primals:
-            backward_out = torch.autograd.grad(
-                needed_outs,
-                grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
-            )
+            with fx_traceback.override_stack_trace():
+                backward_out = torch.autograd.grad(
+                    needed_outs,
+                    grad_primals,
+                    grad_outputs=needed_tangents,
+                    allow_unused=True,
+                )
         backward_out_iter = iter(backward_out)
         return outs, [
             next(backward_out_iter) if i else None for i in inputs_needs_grads
@@ -735,7 +804,9 @@ def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
             mod, pytree.tree_unflatten(args[:params_len], params_spec)
         ):
             if isinstance(mod, torch.fx.GraphModule):
-                with fx_traceback.override_stack_trace():
+                with fx_traceback.override_stack_trace(), torch.autograd.detect_anomaly(
+                    check_nan=False
+                ):
                     out = Interpreter(mod).run(*args[params_len:], **kwargs)
             else:
                 out = mod(*args[params_len:], **kwargs)
diff --git a/functorch/test/test_pythonkey.py b/functorch/test/test_pythonkey.py
index 1b5c933e835..5deeac1eb27 100644
--- a/functorch/test/test_pythonkey.py
+++ b/functorch/test/test_pythonkey.py
@@ -598,14 +598,16 @@ class TestAOTModuleSimplified(AOTTestCase):
         assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad)
 
     def test_aot_module_simplified_preserves_stack_trace(self):
-
         class MockModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.linear = torch.nn.Linear(20, 30)
 
             def forward(self, x, y):
-                return (self.linear(x) + y, )
+                z = self.linear(x)
+                z = z + y
+                z = z.relu()
+                return (z, )
 
         tracer = torch.fx.Tracer()
         tracer.record_stack_traces = True
@@ -626,7 +628,7 @@ class TestAOTModuleSimplified(AOTTestCase):
             assert 'test_pythonkey.py' in node.stack_trace
             return gm.forward  # return a python callable
 
-        aot_mod = aot_module_simplified(mod, fw_compiler=assert_compiler, bw_compiler=nop)
+        aot_mod = aot_module_simplified(mod, fw_compiler=assert_compiler, bw_compiler=assert_compiler)
 
         x = torch.randn(128, 20, requires_grad=True)
         y = torch.randn(128, 30, requires_grad=True)
diff --git a/torch/csrc/autograd/python_anomaly_mode.cpp b/torch/csrc/autograd/python_anomaly_mode.cpp
index ec5dfe1b099..3c91316c06f 100644
--- a/torch/csrc/autograd/python_anomaly_mode.cpp
+++ b/torch/csrc/autograd/python_anomaly_mode.cpp
@@ -16,7 +16,7 @@ namespace autograd {
 
 void PyAnomalyMetadata::store_stack() {
   pybind11::gil_scoped_acquire gil;
-  THPObjectPtr mod(PyImport_ImportModule("traceback"));
+  THPObjectPtr mod(PyImport_ImportModule("torch.fx.traceback"));
   if (!mod) {
     throw python_error();
   }
diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py
index aac94a85e17..a07b36b997b 100644
--- a/torch/fx/traceback.py
+++ b/torch/fx/traceback.py
@@ -3,7 +3,7 @@ from contextlib import contextmanager
 from typing import Optional, List
 from ._compatibility import compatibility
 
-__all__ = ['override_stack_trace', 'append_stack_trace', 'format_stack', 'is_stack_trace_overridden']
+__all__ = ['override_stack_trace', 'set_stack_trace', 'append_stack_trace', 'format_stack', 'is_stack_trace_overridden']
 
 
 current_stack: List[str] = []
@@ -23,6 +23,13 @@ def override_stack_trace():
         is_overridden = saved_is_overridden
 
 
+@compatibility(is_backward_compatible=False)
+def set_stack_trace(stack : List[str]):
+    global current_stack
+
+    if is_overridden and stack:
+        current_stack = stack
+
 @compatibility(is_backward_compatible=False)
 @contextmanager
 def append_stack_trace(stack : Optional[str]):
@@ -44,7 +51,7 @@ def append_stack_trace(stack : Optional[str]):
 @compatibility(is_backward_compatible=False)
 def format_stack() -> List[str]:
     if is_overridden:
-        return current_stack
+        return current_stack.copy()
     else:
         # fallback to traceback.format_stack()
         return traceback.format_stack()
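
Editor's note, also not part of the patch: a small usage sketch of the `torch.fx.traceback` helpers touched above. It illustrates the contract the hunks rely on: `set_stack_trace` only takes effect while `override_stack_trace` is active, and `format_stack` returns a copy of the overridden stack, which is what lets the backward callback queued in `aot_autograd.py` restore the stack it captured at prehook time. The example stack string is invented.

```
import torch.fx.traceback as fx_traceback

# Outside override_stack_trace(), set_stack_trace() is a no-op and
# format_stack() falls back to traceback.format_stack().
fx_traceback.set_stack_trace(["ignored"])

with fx_traceback.override_stack_trace():
    # Inside the context, set_stack_trace() installs the given stack...
    fake_stack = ['  File "model.py", line 3, in forward\n    return x.relu()\n']
    fx_traceback.set_stack_trace(fake_stack)

    # ...and format_stack() hands back a copy, so a caller (such as the
    # prehooks registered by setup_stacktrace_preservation_hooks) can stash
    # it without later in-place changes to the current stack affecting it.
    saved = fx_traceback.format_stack()
    assert saved == fake_stack and saved is not fake_stack
```
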