Preserve stack trace for backward nodes over AOTAutograd (#83558)
For the following program:
```
def my_relu(a):
    return a.relu()

def func(a, b):
    a = torch.nn.Linear(10, 10)(a)
    d = torch.square(b)
    d = my_relu(d)
    loss = d.sum()
    return loss

with torchdynamo.optimize("aot_nop"):
    x = torch.rand(10, 10, requires_grad=True)
    y = torch.rand(10, 10, requires_grad=True)
    out = func(x, y)
```
AOTAutograd generates the following joint FX graph, with `stack_trace` populated on both the forward and the backward nodes.
```
def forward(self, primals, tangents):
    primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    t_default = torch.ops.aten.t.default(primals_3); primals_3 = None
    addmm_default = torch.ops.aten.addmm.default(primals_4, primals_1, t_default); primals_4 = primals_1 = t_default = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(primals_2, 2)
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar); pow_tensor_scalar = None
    detach_default = torch.ops.aten.detach.default(relu_default)
    sum_default = torch.ops.aten.sum.default(relu_default); relu_default = None
    is_same_size_default = torch.ops.aten.is_same_size.default(sum_default, tangents_1)
    expand_default = torch.ops.aten.expand.default(tangents_1, [10, 10]); tangents_1 = None
    detach_default_1 = torch.ops.aten.detach.default(detach_default); detach_default = None
    threshold_backward_default = torch.ops.aten.threshold_backward.default(expand_default, detach_default_1, 0); expand_default = detach_default_1 = None
    pow_tensor_scalar_1 = torch.ops.aten.pow.Tensor_Scalar(primals_2, 1.0); primals_2 = None
    mul_scalar = torch.ops.aten.mul.Scalar(pow_tensor_scalar_1, 2.0); pow_tensor_scalar_1 = None
    mul_tensor = torch.ops.aten.mul.Tensor(threshold_backward_default, mul_scalar); threshold_backward_default = mul_scalar = None
    return pytree.tree_unflatten([sum_default, None, mul_tensor, None, None], self._out_spec)

====== joint graph =======
primals_1 None
primals_2 None
primals_3 None
primals_4 None
tangents_1 None
t_default File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 12, in func
    def func(a, b):
  File "/fsx/users/bahuang/repos/pytorch_fsx/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
addmm_default File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 12, in func
    def func(a, b):
  File "/fsx/users/bahuang/repos/pytorch_fsx/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
pow_tensor_scalar File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)
relu_default File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()
detach_default File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()
sum_default
is_same_size_default
expand_default
detach_default_1 File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()
threshold_backward_default File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()
pow_tensor_scalar_1 File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)
mul_scalar File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)
mul_tensor File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)
output None
```
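To make the effect concrete, here is a minimal sketch (not part of this PR) of how a compiler callback passed to `aot_module_simplified` can read the preserved traces. It mirrors the test updated below; the module, the `print_compiler` name, and the exact import path are illustrative and may differ between versions.

```python
import torch
from functorch.compile import aot_module_simplified  # import path is version-dependent

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return (self.linear(x).relu().sum(),)

def print_compiler(gm: torch.fx.GraphModule, example_inputs):
    # With this change, non-placeholder/output nodes carry a stack_trace
    # in the backward graph as well as the forward one.
    for node in gm.graph.nodes:
        if node.op not in ("placeholder", "output"):
            print(node.name, "->", node.stack_trace)
    return gm.forward  # return a python callable

mod = MyModule()
tracer = torch.fx.Tracer()
tracer.record_stack_traces = True  # populate stack_trace on the traced forward graph
traced = torch.fx.GraphModule(mod, tracer.trace(mod))

aot_mod = aot_module_simplified(traced, fw_compiler=print_compiler, bw_compiler=print_compiler)
loss, = aot_mod(torch.randn(2, 10, requires_grad=True))
loss.backward()
```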
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83558
Approved by: https://github.com/albanD
commit a7baad04f6 (parent e2e71c1f4c)
```diff
@@ -1,3 +1,4 @@
+import collections
 import dataclasses
 import warnings
 from contextlib import contextmanager, nullcontext
@@ -59,6 +60,70 @@ def preserve_rng_state():
             torch.cuda.set_rng_state(cuda_rng_state)


+# Set up hooks so that during backward the fx's stack_trace is properly set
+callback_set = False
+
+
+def setup_stacktrace_preservation_hooks(roots: List):
+    def iter_graph(roots):
+        if not roots:
+            return
+        seen = set()
+        q = collections.deque()
+        for node in roots:
+            if node is not None:
+                seen.add(node)
+                q.append(node)
+
+        while q:
+            node = q.popleft()
+            for fn, _idx in node.next_functions:
+                if fn in seen or fn is None:
+                    continue
+                seen.add(fn)
+                q.append(fn)
+
+            yield node
+
+    def get_callback(saved_stack_):
+        def callback():
+            global callback_set
+            fx_traceback.set_stack_trace(saved_stack_)
+            callback_set = False
+
+        return callback
+
+    def get_prehook(stack_):
+        def prehook(grad_output):
+            global callback_set
+
+            if not callback_set:
+                torch.autograd.variable.Variable._execution_engine.queue_callback(
+                    get_callback(fx_traceback.format_stack())
+                )
+                callback_set = True
+
+            fx_traceback.set_stack_trace(stack_)
+
+        return prehook
+
+    def get_posthook(special_stack_):
+        def posthook(grad_input, grad_output):
+            fx_traceback.set_stack_trace(special_stack_)
+
+        return posthook
+
+    for node in iter_graph(roots):
+        forward_node_stack = node.metadata.get("traceback_", [])
+        node.register_prehook(get_prehook(forward_node_stack))
+
+        special_stack = forward_node_stack.copy()
+        special_stack.append(
+            "Gradient addition node due to mulitple use of tensor around:"
+        )
+        node.register_hook(get_posthook(special_stack))
+
+
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
         primals: List[Any], tangents: List[Any]
@@ -82,15 +147,19 @@ def create_joint_forward_backward(fn):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
                 needed_tangents.append(tangent)
+
+        setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs])
+
         backward_out = []
         # Call the backwards pass
         if grad_primals:
-            backward_out = torch.autograd.grad(
-                needed_outs,
-                grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
-            )
+            with fx_traceback.override_stack_trace():
+                backward_out = torch.autograd.grad(
+                    needed_outs,
+                    grad_primals,
+                    grad_outputs=needed_tangents,
+                    allow_unused=True,
+                )
         backward_out_iter = iter(backward_out)
         return outs, [
             next(backward_out_iter) if i else None for i in inputs_needs_grads
@@ -735,7 +804,9 @@ def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
             mod, pytree.tree_unflatten(args[:params_len], params_spec)
         ):
             if isinstance(mod, torch.fx.GraphModule):
-                with fx_traceback.override_stack_trace():
+                with fx_traceback.override_stack_trace(), torch.autograd.detect_anomaly(
+                    check_nan=False
+                ):
                     out = Interpreter(mod).run(*args[params_len:], **kwargs)
             else:
                 out = mod(*args[params_len:], **kwargs)
```
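In short, `setup_stacktrace_preservation_hooks` walks the autograd graph breadth-first through `grad_fn.next_functions`, registers a prehook on each backward node that installs that node's forward stack trace into `torch.fx.traceback`, registers a posthook for the extra gradient-addition case, and queues an engine callback (which runs when the backward pass finishes) to restore the saved stack. The following standalone sketch, which is not part of the patch and only needs a PyTorch recent enough to have `Node.register_prehook`, shows the underlying hook mechanism on a plain autograd graph; the logging is illustrative.

```python
import collections
import torch

def iter_autograd_graph(roots):
    # Breadth-first walk over backward nodes via next_functions,
    # the same traversal iter_graph() in the patch performs.
    seen = set(n for n in roots if n is not None)
    q = collections.deque(seen)
    while q:
        node = q.popleft()
        for fn, _idx in node.next_functions:
            if fn is not None and fn not in seen:
                seen.add(fn)
                q.append(fn)
        yield node

x = torch.randn(3, requires_grad=True)
loss = (x * 2).relu().sum()

for node in iter_autograd_graph([loss.grad_fn]):
    # A prehook fires just before the node computes its input gradients and a
    # posthook just after; the patch uses them to swap the fx stack trace in
    # and out, while here they only log the node name.
    node.register_prehook(lambda grad_outputs, node=node: print("enter", node.name()))
    node.register_hook(lambda grad_inputs, grad_outputs, node=node: print("exit", node.name()))

loss.backward()
```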
```diff
@@ -598,14 +598,16 @@ class TestAOTModuleSimplified(AOTTestCase):
         assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad)

     def test_aot_module_simplified_preserves_stack_trace(self):
         class MockModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.linear = torch.nn.Linear(20, 30)

             def forward(self, x, y):
-                return (self.linear(x) + y, )
+                z = self.linear(x)
+                z = z + y
+                z = z.relu()
+                return (z, )

         tracer = torch.fx.Tracer()
         tracer.record_stack_traces = True
@@ -626,7 +628,7 @@ class TestAOTModuleSimplified(AOTTestCase):
                 assert 'test_pythonkey.py' in node.stack_trace
             return gm.forward  # return a python callable

-        aot_mod = aot_module_simplified(mod, fw_compiler=assert_compiler, bw_compiler=nop)
+        aot_mod = aot_module_simplified(mod, fw_compiler=assert_compiler, bw_compiler=assert_compiler)

         x = torch.randn(128, 20, requires_grad=True)
         y = torch.randn(128, 30, requires_grad=True)
```
```diff
@@ -16,7 +16,7 @@ namespace autograd {

 void PyAnomalyMetadata::store_stack() {
   pybind11::gil_scoped_acquire gil;
-  THPObjectPtr mod(PyImport_ImportModule("traceback"));
+  THPObjectPtr mod(PyImport_ImportModule("torch.fx.traceback"));
   if (!mod) {
     throw python_error();
   }
```
```diff
@@ -3,7 +3,7 @@ from contextlib import contextmanager
 from typing import Optional, List
 from ._compatibility import compatibility

-__all__ = ['override_stack_trace', 'append_stack_trace', 'format_stack', 'is_stack_trace_overridden']
+__all__ = ['override_stack_trace', 'set_stack_trace', 'append_stack_trace', 'format_stack', 'is_stack_trace_overridden']


 current_stack: List[str] = []
@@ -23,6 +23,13 @@ def override_stack_trace():
         is_overridden = saved_is_overridden


+@compatibility(is_backward_compatible=False)
+def set_stack_trace(stack : List[str]):
+    global current_stack
+
+    if is_overridden and stack:
+        current_stack = stack
+
+
 @compatibility(is_backward_compatible=False)
 @contextmanager
 def append_stack_trace(stack : Optional[str]):
@@ -44,7 +51,7 @@ def append_stack_trace(stack : Optional[str]):
 @compatibility(is_backward_compatible=False)
 def format_stack() -> List[str]:
     if is_overridden:
-        return current_stack
+        return current_stack.copy()
     else:
         # fallback to traceback.format_stack()
         return traceback.format_stack()
```
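The `torch.fx.traceback` side adds a `set_stack_trace` entry point (used by the backward hooks) and makes `format_stack()` hand back a copy of the overridden stack, while the C++ change routes anomaly-mode stack recording through the same module. A small sketch of the resulting behavior, based on the functions in this patch (later PyTorch versions reorganized this module, so treat the API as of this commit); the fake stack string is illustrative:

```python
import torch.fx.traceback as fx_traceback

with fx_traceback.override_stack_trace():
    # While overridden, set_stack_trace() replaces the module-level stack and
    # format_stack() returns a copy of it, so callers cannot mutate the state.
    fake_stack = ['  File "example.py", line 3, in forward\n    z = x.relu()\n']
    fx_traceback.set_stack_trace(fake_stack)
    stack = fx_traceback.format_stack()
    assert stack == fake_stack and stack is not fake_stack

# Outside the override, format_stack() falls back to traceback.format_stack().
print(fx_traceback.format_stack()[-1])
```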