Preserve stack trace for backward nodes over AOTAutograd (#83558)
For the following program:
```
def my_relu(a):
    return a.relu()

def func(a, b):
    a = torch.nn.Linear(10, 10)(a)
    d = torch.square(b)
    d = my_relu(d)
    loss = d.sum()
    return loss

with torchdynamo.optimize("aot_nop"):
    x = torch.rand(10, 10, requires_grad=True)
    y = torch.rand(10, 10, requires_grad=True)
    out = func(x, y)
```
AOTAutograd now generates the following joint forward-backward FX graph, with `stack_trace` populated on both forward and backward nodes:
```
def forward(self, primals, tangents):
    primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    t_default = torch.ops.aten.t.default(primals_3); primals_3 = None
    addmm_default = torch.ops.aten.addmm.default(primals_4, primals_1, t_default); primals_4 = primals_1 = t_default = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(primals_2, 2)
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar); pow_tensor_scalar = None
    detach_default = torch.ops.aten.detach.default(relu_default)
    sum_default = torch.ops.aten.sum.default(relu_default); relu_default = None
    is_same_size_default = torch.ops.aten.is_same_size.default(sum_default, tangents_1)
    expand_default = torch.ops.aten.expand.default(tangents_1, [10, 10]); tangents_1 = None
    detach_default_1 = torch.ops.aten.detach.default(detach_default); detach_default = None
    threshold_backward_default = torch.ops.aten.threshold_backward.default(expand_default, detach_default_1, 0); expand_default = detach_default_1 = None
    pow_tensor_scalar_1 = torch.ops.aten.pow.Tensor_Scalar(primals_2, 1.0); primals_2 = None
    mul_scalar = torch.ops.aten.mul.Scalar(pow_tensor_scalar_1, 2.0); pow_tensor_scalar_1 = None
    mul_tensor = torch.ops.aten.mul.Tensor(threshold_backward_default, mul_scalar); threshold_backward_default = mul_scalar = None
    return pytree.tree_unflatten([sum_default, None, mul_tensor, None, None], self._out_spec)
====== joint graph =======
primals_1 None
primals_2 None
primals_3 None
primals_4 None
tangents_1 None
t_default  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 12, in func
    def func(a, b):
  File "/fsx/users/bahuang/repos/pytorch_fsx/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
addmm_default  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 12, in func
    def func(a, b):
  File "/fsx/users/bahuang/repos/pytorch_fsx/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
pow_tensor_scalar  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)
relu_default  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()
detach_default  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()
sum_default
is_same_size_default
expand_default
detach_default_1  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()
threshold_backward_default  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()
pow_tensor_scalar_1  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)
mul_scalar  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)
mul_tensor  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)
output None
```
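With the traces preserved, any AOTAutograd compiler callback can read them off the graph. Below is a minimal sketch (not part of this PR) that dumps each node's `stack_trace`; the helper name `print_stack_traces` and the toy function `f` are made up for illustration.

```python
import torch
from functorch.compile import aot_function

def print_stack_traces(gm: torch.fx.GraphModule, example_inputs):
    # Works as both fw_compiler and bw_compiler: backward nodes now carry
    # the stack trace of the forward op that produced them.
    for node in gm.graph.nodes:
        print(node.name, "->", node.stack_trace)
    return gm.forward  # return a python callable

def f(a, b):
    return torch.square(b).relu().sum() + a.sum()

aot_f = aot_function(f, fw_compiler=print_stack_traces, bw_compiler=print_stack_traces)
x = torch.rand(10, 10, requires_grad=True)
y = torch.rand(10, 10, requires_grad=True)
aot_f(x, y).backward()
```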
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83558
Approved by: https://github.com/albanD
Parent: e2e71c1f4c
Commit: a7baad04f6
```diff
@@ -1,3 +1,4 @@
+import collections
 import dataclasses
 import warnings
 from contextlib import contextmanager, nullcontext
```
```diff
@@ -59,6 +60,70 @@ def preserve_rng_state():
         torch.cuda.set_rng_state(cuda_rng_state)


+# Set up hooks so that during backward the fx's stack_trace is properly set
+callback_set = False
+
+
+def setup_stacktrace_preservation_hooks(roots: List):
+    def iter_graph(roots):
+        if not roots:
+            return
+        seen = set()
+        q = collections.deque()
+        for node in roots:
+            if node is not None:
+                seen.add(node)
+                q.append(node)
+
+        while q:
+            node = q.popleft()
+            for fn, _idx in node.next_functions:
+                if fn in seen or fn is None:
+                    continue
+                seen.add(fn)
+                q.append(fn)
+
+            yield node
+
+    def get_callback(saved_stack_):
+        def callback():
+            global callback_set
+            fx_traceback.set_stack_trace(saved_stack_)
+            callback_set = False
+
+        return callback
+
+    def get_prehook(stack_):
+        def prehook(grad_output):
+            global callback_set
+
+            if not callback_set:
+                torch.autograd.variable.Variable._execution_engine.queue_callback(
+                    get_callback(fx_traceback.format_stack())
+                )
+                callback_set = True
+
+            fx_traceback.set_stack_trace(stack_)
+
+        return prehook
+
+    def get_posthook(special_stack_):
+        def posthook(grad_input, grad_output):
+            fx_traceback.set_stack_trace(special_stack_)
+
+        return posthook
+
+    for node in iter_graph(roots):
+        forward_node_stack = node.metadata.get("traceback_", [])
+        node.register_prehook(get_prehook(forward_node_stack))
+
+        special_stack = forward_node_stack.copy()
+        special_stack.append(
+            "Gradient addition node due to multiple use of tensor around:"
+        )
+        node.register_hook(get_posthook(special_stack))
+
+
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
         primals: List[Any], tangents: List[Any]
```
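The hook placement above relies on the Python bindings for autograd graph nodes: `register_prehook` (new around the time of this PR) fires before a node executes during backward, and `register_hook` fires after it. A standalone sketch of just that mechanism, independent of AOTAutograd:

```python
import torch

x = torch.tensor([1.0, -2.0], requires_grad=True)
y = x.relu()

def prehook(grad_outputs):
    # Runs just before ReluBackward0 consumes its incoming gradients.
    print("entering", y.grad_fn.name())

def posthook(grad_inputs, grad_outputs):
    # Runs right after ReluBackward0 has produced its grad_inputs.
    print("leaving", y.grad_fn.name())

y.grad_fn.register_prehook(prehook)
y.grad_fn.register_hook(posthook)
y.sum().backward()
```

Note how the hunk also queues an execution-engine callback the first time any prehook fires; that callback restores the previously active stack once the whole backward pass completes.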
```diff
@@ -82,15 +147,19 @@ def create_joint_forward_backward(fn):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
                 needed_tangents.append(tangent)
+
+        setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs])
+
         backward_out = []
         # Call the backwards pass
         if grad_primals:
-            backward_out = torch.autograd.grad(
-                needed_outs,
-                grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
-            )
+            with fx_traceback.override_stack_trace():
+                backward_out = torch.autograd.grad(
+                    needed_outs,
+                    grad_primals,
+                    grad_outputs=needed_tangents,
+                    allow_unused=True,
+                )
         backward_out_iter = iter(backward_out)
         return outs, [
             next(backward_out_iter) if i else None for i in inputs_needs_grads
```
```diff
@@ -735,7 +804,9 @@ def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
             mod, pytree.tree_unflatten(args[:params_len], params_spec)
         ):
             if isinstance(mod, torch.fx.GraphModule):
-                with fx_traceback.override_stack_trace():
+                with fx_traceback.override_stack_trace(), torch.autograd.detect_anomaly(
+                    check_nan=False
+                ):
                     out = Interpreter(mod).run(*args[params_len:], **kwargs)
             else:
                 out = mod(*args[params_len:], **kwargs)
```
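The `detect_anomaly(check_nan=False)` wrapper is what makes the forward stacks available in the first place: anomaly mode records a traceback into each autograd node's metadata under the `"traceback_"` key, which is exactly what `setup_stacktrace_preservation_hooks` reads, and what the C++ hunk further down redirects to `torch.fx.traceback`. A small sketch of that side channel, assuming a build where `detect_anomaly` accepts `check_nan`:

```python
import torch

# check_nan=False keeps anomaly mode's stack recording while skipping
# the per-op NaN checks that make full anomaly mode expensive.
with torch.autograd.detect_anomaly(check_nan=False):
    z = torch.rand(3, requires_grad=True).sin()

# Anomaly mode stashed the forward stack on the node's metadata.
print("traceback_" in z.grad_fn.metadata)  # expected: True
```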
```diff
@@ -598,14 +598,16 @@ class TestAOTModuleSimplified(AOTTestCase):
         assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad)

     def test_aot_module_simplified_preserves_stack_trace(self):
         class MockModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.linear = torch.nn.Linear(20, 30)

             def forward(self, x, y):
-                return (self.linear(x) + y, )
+                z = self.linear(x)
+                z = z + y
+                z = z.relu()
+                return (z, )

         tracer = torch.fx.Tracer()
         tracer.record_stack_traces = True
```
```diff
@@ -626,7 +628,7 @@ class TestAOTModuleSimplified(AOTTestCase):
             assert 'test_pythonkey.py' in node.stack_trace
             return gm.forward  # return a python callable

-        aot_mod = aot_module_simplified(mod, fw_compiler=assert_compiler, bw_compiler=nop)
+        aot_mod = aot_module_simplified(mod, fw_compiler=assert_compiler, bw_compiler=assert_compiler)

         x = torch.randn(128, 20, requires_grad=True)
         y = torch.randn(128, 30, requires_grad=True)
```
```diff
@@ -16,7 +16,7 @@ namespace autograd {

 void PyAnomalyMetadata::store_stack() {
   pybind11::gil_scoped_acquire gil;
-  THPObjectPtr mod(PyImport_ImportModule("traceback"));
+  THPObjectPtr mod(PyImport_ImportModule("torch.fx.traceback"));
   if (!mod) {
     throw python_error();
   }
```
```diff
@@ -3,7 +3,7 @@ from contextlib import contextmanager
 from typing import Optional, List
 from ._compatibility import compatibility

-__all__ = ['override_stack_trace', 'append_stack_trace', 'format_stack', 'is_stack_trace_overridden']
+__all__ = ['override_stack_trace', 'set_stack_trace', 'append_stack_trace', 'format_stack', 'is_stack_trace_overridden']


 current_stack: List[str] = []
```
```diff
@@ -23,6 +23,13 @@ def override_stack_trace():
         is_overridden = saved_is_overridden


+@compatibility(is_backward_compatible=False)
+def set_stack_trace(stack : List[str]):
+    global current_stack
+
+    if is_overridden and stack:
+        current_stack = stack
+
 @compatibility(is_backward_compatible=False)
 @contextmanager
 def append_stack_trace(stack : Optional[str]):
```
```diff
@@ -44,7 +51,7 @@ def append_stack_trace(stack : Optional[str]):
 @compatibility(is_backward_compatible=False)
 def format_stack() -> List[str]:
     if is_overridden:
-        return current_stack
+        return current_stack.copy()
     else:
         # fallback to traceback.format_stack()
         return traceback.format_stack()
```
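Taken together, the `torch.fx.traceback` helpers compose like this: inside `override_stack_trace()`, `set_stack_trace` installs a stack and `format_stack` returns a copy of it instead of falling back to `traceback.format_stack()`. A quick sketch using only the API shown in the hunks above:

```python
import torch.fx.traceback as fx_traceback

with fx_traceback.override_stack_trace():
    # set_stack_trace only takes effect while the override is active.
    fx_traceback.set_stack_trace(['  File "example.py", line 3, in f\n'])
    print(fx_traceback.format_stack())  # the installed stack (a copy)

print(fx_traceback.is_stack_trace_overridden())  # False outside the context
```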