pytorch/torch/fx/traceback.py
Sherlock Huang a7baad04f6 Preserve stack trace for backward nodes over AOTAutograd (#83558)
Consider the following program:
```
def my_relu(a):
    return a.relu()

def func(a, b):
    a = torch.nn.Linear(10, 10)(a)
    d = torch.square(b)
    d = my_relu(d)
    loss = d.sum()

    return loss

with torchdynamo.optimize("aot_nop"):
    x = torch.rand(10, 10, requires_grad=True)
    y = torch.rand(10, 10, requires_grad=True)
    out = func(x, y)
```

It generates the following FX graph, with stack_trace populated for both forward and backward nodes:
```
def forward(self, primals, tangents):
    primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    t_default = torch.ops.aten.t.default(primals_3);  primals_3 = None
    addmm_default = torch.ops.aten.addmm.default(primals_4, primals_1, t_default);  primals_4 = primals_1 = t_default = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(primals_2, 2)
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar);  pow_tensor_scalar = None
    detach_default = torch.ops.aten.detach.default(relu_default)
    sum_default = torch.ops.aten.sum.default(relu_default);  relu_default = None
    is_same_size_default = torch.ops.aten.is_same_size.default(sum_default, tangents_1)
    expand_default = torch.ops.aten.expand.default(tangents_1, [10, 10]);  tangents_1 = None
    detach_default_1 = torch.ops.aten.detach.default(detach_default);  detach_default = None
    threshold_backward_default = torch.ops.aten.threshold_backward.default(expand_default, detach_default_1, 0);  expand_default = detach_default_1 = None
    pow_tensor_scalar_1 = torch.ops.aten.pow.Tensor_Scalar(primals_2, 1.0);  primals_2 = None
    mul_scalar = torch.ops.aten.mul.Scalar(pow_tensor_scalar_1, 2.0);  pow_tensor_scalar_1 = None
    mul_tensor = torch.ops.aten.mul.Tensor(threshold_backward_default, mul_scalar);  threshold_backward_default = mul_scalar = None
    return pytree.tree_unflatten([sum_default, None, mul_tensor, None, None], self._out_spec)

====== joint graph =======
primals_1 None
primals_2 None
primals_3 None
primals_4 None
tangents_1 None
t_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 12, in func
    def func(a, b):
  File "/fsx/users/bahuang/repos/pytorch_fsx/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)

addmm_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 12, in func
    def func(a, b):
  File "/fsx/users/bahuang/repos/pytorch_fsx/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)

pow_tensor_scalar   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

relu_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

detach_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

sum_default
is_same_size_default
expand_default
detach_default_1   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

threshold_backward_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

pow_tensor_scalar_1   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

mul_scalar   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

mul_tensor   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

output None
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83558
Approved by: https://github.com/albanD
2022-08-18 22:13:04 +00:00

63 lines
1.5 KiB
Python

import traceback
from contextlib import contextmanager
from typing import Optional, List
from ._compatibility import compatibility
__all__ = ['override_stack_trace', 'set_stack_trace', 'append_stack_trace', 'format_stack', 'is_stack_trace_overridden']
current_stack: List[str] = []
is_overridden = False
@compatibility(is_backward_compatible=False)
@contextmanager
def override_stack_trace():
    """Enable stack-trace overriding for the duration of the ``with`` body.

    While enabled, format_stack() returns the manually maintained stack
    rather than the real Python call stack, and set_stack_trace() /
    append_stack_trace() take effect. On exit (normal or via exception)
    the previous override state is restored, so nested uses compose.
    """
    global is_overridden
    previous_state = is_overridden
    is_overridden = True
    try:
        yield
    finally:
        # Restore instead of forcing False so an enclosing override survives.
        is_overridden = previous_state
@compatibility(is_backward_compatible=False)
def set_stack_trace(stack : List[str]):
    """Replace the overridden stack trace with ``stack``.

    This is a no-op unless called inside override_stack_trace(), and a
    no-op when ``stack`` is empty or None.

    Args:
        stack: stack-trace strings that format_stack() should report.
    """
    global current_stack
    if is_overridden and stack:
        # Store a copy: append_stack_trace() appends/pops on current_stack in
        # place, which would otherwise mutate the caller's list (and later
        # caller-side edits would silently change the override state).
        current_stack = list(stack)
@compatibility(is_backward_compatible=False)
@contextmanager
def append_stack_trace(stack : Optional[str]):
    """Push ``stack`` onto the overridden stack trace for the ``with`` body.

    ``stack`` is an entire stack trace formatted as a single string. It is
    pushed only when called inside override_stack_trace() and ``stack`` is
    non-empty; otherwise the body runs with the stack left unchanged.
    The pushed entry is removed on exit, even if the body raises.
    """
    if not (is_overridden and stack):
        yield
        return

    # Keep a reference to the list we push onto. If set_stack_trace() rebinds
    # the module-level current_stack inside the body, exit must remove our own
    # entry from this list instead of popping an unrelated entry off the new one.
    target = current_stack
    # Push outside the try so a failed push is never "undone" by the pop.
    target.append(stack)
    try:
        yield
    finally:
        target.pop()
@compatibility(is_backward_compatible=False)
def format_stack() -> List[str]:
    """Return the stack trace to attribute to the current operation.

    Inside override_stack_trace() this is a fresh copy of the manually
    maintained stack; otherwise it is the real Python call stack as
    produced by traceback.format_stack().
    """
    if not is_overridden:
        # No override active: report the genuine call stack.
        return traceback.format_stack()
    # Copy so callers cannot mutate the module-level stack.
    return list(current_stack)
@compatibility(is_backward_compatible=False)
def is_stack_trace_overridden() -> bool:
    """Report whether an override_stack_trace() context is currently active."""
    overridden = is_overridden
    return overridden