Preserve stack trace for backward nodes over AOTAutograd (#83558)

Consider the following program:
```
import torch
import torchdynamo

def my_relu(a):
    return a.relu()

def func(a, b):
    a = torch.nn.Linear(10, 10)(a)
    d = torch.square(b)
    d = my_relu(d)
    loss = d.sum()

    return loss

with torchdynamo.optimize("aot_nop"):
    x = torch.rand(10, 10, requires_grad=True)
    y = torch.rand(10, 10, requires_grad=True)
    out = func(x, y)
```

AOTAutograd now generates the following fx graph with stack_trace populated on both forward and backward nodes; the joint-graph dump below lists each node together with the stack trace recorded for it.
```
def forward(self, primals, tangents):
    primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    t_default = torch.ops.aten.t.default(primals_3);  primals_3 = None
    addmm_default = torch.ops.aten.addmm.default(primals_4, primals_1, t_default);  primals_4 = primals_1 = t_default = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(primals_2, 2)
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar);  pow_tensor_scalar = None
    detach_default = torch.ops.aten.detach.default(relu_default)
    sum_default = torch.ops.aten.sum.default(relu_default);  relu_default = None
    is_same_size_default = torch.ops.aten.is_same_size.default(sum_default, tangents_1)
    expand_default = torch.ops.aten.expand.default(tangents_1, [10, 10]);  tangents_1 = None
    detach_default_1 = torch.ops.aten.detach.default(detach_default);  detach_default = None
    threshold_backward_default = torch.ops.aten.threshold_backward.default(expand_default, detach_default_1, 0);  expand_default = detach_default_1 = None
    pow_tensor_scalar_1 = torch.ops.aten.pow.Tensor_Scalar(primals_2, 1.0);  primals_2 = None
    mul_scalar = torch.ops.aten.mul.Scalar(pow_tensor_scalar_1, 2.0);  pow_tensor_scalar_1 = None
    mul_tensor = torch.ops.aten.mul.Tensor(threshold_backward_default, mul_scalar);  threshold_backward_default = mul_scalar = None
    return pytree.tree_unflatten([sum_default, None, mul_tensor, None, None], self._out_spec)

====== joint graph =======
primals_1 None
primals_2 None
primals_3 None
primals_4 None
tangents_1 None
t_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 12, in func
    def func(a, b):
  File "/fsx/users/bahuang/repos/pytorch_fsx/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)

addmm_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 12, in func
    def func(a, b):
  File "/fsx/users/bahuang/repos/pytorch_fsx/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)

pow_tensor_scalar   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

relu_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

detach_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

sum_default
is_same_size_default
expand_default
detach_default_1   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

threshold_backward_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

pow_tensor_scalar_1   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

mul_scalar   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

mul_tensor   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

output None
```
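
With the traces preserved, a backward compiler can map each backward op back to the user line that produced it. Below is a minimal sketch of such a consumer, assuming functorch's aot_function entry point; the print_stack_traces helper and the printing logic are illustrative, not part of this PR.
```
import torch
from functorch.compile import aot_function

def print_stack_traces(gm: torch.fx.GraphModule, example_inputs):
    # Dump each node's preserved stack_trace, then hand the graph back
    # unchanged (same pattern the test in this PR uses).
    for node in gm.graph.nodes:
        if node.stack_trace:
            print(node.name, "<-", node.stack_trace.strip().splitlines()[-1])
    return gm.forward

# func, x, y as defined in the program above.
aot_func = aot_function(func, fw_compiler=print_stack_traces, bw_compiler=print_stack_traces)
aot_func(x, y).backward()
```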

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83558
Approved by: https://github.com/albanD
Author: Sherlock Huang, 2022-08-18 17:54:52 +00:00 (committed by PyTorch MergeBot)
Commit a7baad04f6 (parent e2e71c1f4c)
4 changed files with 93 additions and 13 deletions

@@ -1,3 +1,4 @@
+import collections
 import dataclasses
 import warnings
 from contextlib import contextmanager, nullcontext
@@ -59,6 +60,70 @@ def preserve_rng_state():
             torch.cuda.set_rng_state(cuda_rng_state)
 
 
+# Set up hooks so that during backward the fx's stack_trace is properly set
+callback_set = False
+
+
+def setup_stacktrace_preservation_hooks(roots: List):
+    def iter_graph(roots):
+        if not roots:
+            return
+        seen = set()
+        q = collections.deque()
+        for node in roots:
+            if node is not None:
+                seen.add(node)
+                q.append(node)
+
+        while q:
+            node = q.popleft()
+            for fn, _idx in node.next_functions:
+                if fn in seen or fn is None:
+                    continue
+                seen.add(fn)
+                q.append(fn)
+
+            yield node
+
+    def get_callback(saved_stack_):
+        def callback():
+            global callback_set
+            fx_traceback.set_stack_trace(saved_stack_)
+            callback_set = False
+
+        return callback
+
+    def get_prehook(stack_):
+        def prehook(grad_output):
+            global callback_set
+
+            if not callback_set:
+                torch.autograd.variable.Variable._execution_engine.queue_callback(
+                    get_callback(fx_traceback.format_stack())
+                )
+                callback_set = True
+
+            fx_traceback.set_stack_trace(stack_)
+
+        return prehook
+
+    def get_posthook(special_stack_):
+        def posthook(grad_input, grad_output):
+            fx_traceback.set_stack_trace(special_stack_)
+
+        return posthook
+
+    for node in iter_graph(roots):
+        forward_node_stack = node.metadata.get("traceback_", [])
+        node.register_prehook(get_prehook(forward_node_stack))
+
+        special_stack = forward_node_stack.copy()
+        special_stack.append(
+            "Gradient addition node due to multiple use of tensor around:"
+        )
+        node.register_hook(get_posthook(special_stack))
+
+
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
         primals: List[Any], tangents: List[Any]
@@ -82,15 +147,19 @@ def create_joint_forward_backward(fn):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
                 needed_tangents.append(tangent)
+
+        setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs])
+
         backward_out = []
         # Call the backwards pass
         if grad_primals:
-            backward_out = torch.autograd.grad(
-                needed_outs,
-                grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
-            )
+            with fx_traceback.override_stack_trace():
+                backward_out = torch.autograd.grad(
+                    needed_outs,
+                    grad_primals,
+                    grad_outputs=needed_tangents,
+                    allow_unused=True,
+                )
         backward_out_iter = iter(backward_out)
         return outs, [
             next(backward_out_iter) if i else None for i in inputs_needs_grads
@@ -735,7 +804,9 @@ def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
         mod, pytree.tree_unflatten(args[:params_len], params_spec)
     ):
         if isinstance(mod, torch.fx.GraphModule):
-            with fx_traceback.override_stack_trace():
+            with fx_traceback.override_stack_trace(), torch.autograd.detect_anomaly(
+                check_nan=False
+            ):
                 out = Interpreter(mod).run(*args[params_len:], **kwargs)
         else:
             out = mod(*args[params_len:], **kwargs)
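
For intuition, the hooks registered above rely on autograd's Python-level node hooks (register_prehook and register_hook on grad_fn). A standalone sketch of that mechanism, separate from AOTAutograd and illustrative only:
```
import torch

x = torch.randn(3, requires_grad=True)
y = x.relu()

node = y.grad_fn  # the autograd Node created by the forward op (ReluBackward0)

# Prehook fires just before the node's backward kernel runs; the patch uses
# this point to restore the forward node's stack trace into torch.fx.traceback.
node.register_prehook(lambda grad_outputs: print("entering", node.name()))

# Hook fires after the kernel; the patch uses it to tag gradient-accumulation
# nodes for tensors that were used more than once.
node.register_hook(lambda grad_inputs, grad_outputs: print("leaving", node.name()))

y.sum().backward()
```
During backward, "entering" and "leaving" print around the ReluBackward0 kernel, which is exactly the window the patch uses to swap the fx stack trace in and out.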


@@ -598,14 +598,16 @@ class TestAOTModuleSimplified(AOTTestCase):
         assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad)
 
     def test_aot_module_simplified_preserves_stack_trace(self):
         class MockModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.linear = torch.nn.Linear(20, 30)
 
             def forward(self, x, y):
-                return (self.linear(x) + y, )
+                z = self.linear(x)
+                z = z + y
+                z = z.relu()
+                return (z, )
 
         tracer = torch.fx.Tracer()
         tracer.record_stack_traces = True
@@ -626,7 +628,7 @@ class TestAOTModuleSimplified(AOTTestCase):
                 assert 'test_pythonkey.py' in node.stack_trace
             return gm.forward  # return a python callable
 
-        aot_mod = aot_module_simplified(mod, fw_compiler=assert_compiler, bw_compiler=nop)
+        aot_mod = aot_module_simplified(mod, fw_compiler=assert_compiler, bw_compiler=assert_compiler)
 
         x = torch.randn(128, 20, requires_grad=True)
         y = torch.randn(128, 30, requires_grad=True)


@@ -16,7 +16,7 @@ namespace autograd {
 void PyAnomalyMetadata::store_stack() {
   pybind11::gil_scoped_acquire gil;
-  THPObjectPtr mod(PyImport_ImportModule("traceback"));
+  THPObjectPtr mod(PyImport_ImportModule("torch.fx.traceback"));
   if (!mod) {
     throw python_error();
   }
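
With the import swapped above, anomaly mode records stacks via torch.fx.traceback.format_stack(), so when aot_module_simplified runs the forward under override_stack_trace() plus detect_anomaly(check_nan=False), the fx stack trace lands in grad_fn.metadata["traceback_"], which is what setup_stacktrace_preservation_hooks reads back. A quick illustration; the claim that the trace is visible this way is an assumption based on the metadata key used in the hunks above:
```
import torch

# Anomaly mode (which aot_module_simplified now enables with check_nan=False)
# stores the formatted forward stack on the autograd node it creates.
with torch.autograd.detect_anomaly(check_nan=False):
    x = torch.randn(3, requires_grad=True)
    y = x.relu()

# This is the metadata the new hooks read via node.metadata.get("traceback_", []).
print(y.grad_fn.metadata.get("traceback_"))
```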


@@ -3,7 +3,7 @@ from contextlib import contextmanager
 from typing import Optional, List
 from ._compatibility import compatibility
 
-__all__ = ['override_stack_trace', 'append_stack_trace', 'format_stack', 'is_stack_trace_overridden']
+__all__ = ['override_stack_trace', 'set_stack_trace', 'append_stack_trace', 'format_stack', 'is_stack_trace_overridden']
 
 current_stack: List[str] = []
@@ -23,6 +23,13 @@ def override_stack_trace():
         is_overridden = saved_is_overridden
 
 
+@compatibility(is_backward_compatible=False)
+def set_stack_trace(stack : List[str]):
+    global current_stack
+
+    if is_overridden and stack:
+        current_stack = stack
+
 @compatibility(is_backward_compatible=False)
 @contextmanager
 def append_stack_trace(stack : Optional[str]):
@@ -44,7 +51,7 @@ def append_stack_trace(stack : Optional[str]):
 @compatibility(is_backward_compatible=False)
 def format_stack() -> List[str]:
     if is_overridden:
-        return current_stack
+        return current_stack.copy()
     else:
         # fallback to traceback.format_stack()
         return traceback.format_stack()
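
Finally, a small usage sketch of the extended torch.fx.traceback helpers; the stack string contents are made up for illustration and are not part of the diff:
```
import torch.fx.traceback as fx_traceback

with fx_traceback.override_stack_trace():
    # set_stack_trace only takes effect while the override is active.
    fx_traceback.set_stack_trace(['File "model.py", line 42, in forward\n    y = x.relu()\n'])
    saved = fx_traceback.format_stack()  # returns a copy of the managed stack

# Outside the context manager, format_stack() falls back to traceback.format_stack().
print(saved)
```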