mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
For the following program.
```
def my_relu(a):
    return a.relu()

def func(a, b):
    a = torch.nn.Linear(10, 10)(a)
    d = torch.square(b)
    d = my_relu(d)
    loss = d.sum()
    return loss

with torchdynamo.optimize("aot_nop"):
    x = torch.rand(10, 10, requires_grad=True)
    y = torch.rand(10, 10, requires_grad=True)
    out = func(x, y)
```
It generates the following FX graph, with `stack_trace` populated on both the forward and the backward nodes.
```
def forward(self, primals, tangents):
primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
t_default = torch.ops.aten.t.default(primals_3); primals_3 = None
addmm_default = torch.ops.aten.addmm.default(primals_4, primals_1, t_default); primals_4 = primals_1 = t_default = None
pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(primals_2, 2)
relu_default = torch.ops.aten.relu.default(pow_tensor_scalar); pow_tensor_scalar = None
detach_default = torch.ops.aten.detach.default(relu_default)
sum_default = torch.ops.aten.sum.default(relu_default); relu_default = None
is_same_size_default = torch.ops.aten.is_same_size.default(sum_default, tangents_1)
expand_default = torch.ops.aten.expand.default(tangents_1, [10, 10]); tangents_1 = None
detach_default_1 = torch.ops.aten.detach.default(detach_default); detach_default = None
threshold_backward_default = torch.ops.aten.threshold_backward.default(expand_default, detach_default_1, 0); expand_default = detach_default_1 = None
pow_tensor_scalar_1 = torch.ops.aten.pow.Tensor_Scalar(primals_2, 1.0); primals_2 = None
mul_scalar = torch.ops.aten.mul.Scalar(pow_tensor_scalar_1, 2.0); pow_tensor_scalar_1 = None
mul_tensor = torch.ops.aten.mul.Tensor(threshold_backward_default, mul_scalar); threshold_backward_default = mul_scalar = None
return pytree.tree_unflatten([sum_default, None, mul_tensor, None, None], self._out_spec)
====== joint graph =======
primals_1 None
primals_2 None
primals_3 None
primals_4 None
tangents_1 None
t_default File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 12, in func
def func(a, b):
File "/fsx/users/bahuang/repos/pytorch_fsx/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
addmm_default File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 12, in func
def func(a, b):
File "/fsx/users/bahuang/repos/pytorch_fsx/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
pow_tensor_scalar File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
d = torch.square(b)
relu_default File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
d = my_relu(d)
File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
return a.relu()
detach_default File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
d = my_relu(d)
File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
return a.relu()
sum_default
is_same_size_default
expand_default
detach_default_1 File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
d = my_relu(d)
File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
return a.relu()
threshold_backward_default File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
d = my_relu(d)
File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
return a.relu()
pow_tensor_scalar_1 File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
d = torch.square(b)
mul_scalar File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
d = torch.square(b)
mul_tensor File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
d = torch.square(b)
output None
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83558
Approved by: https://github.com/albanD
63 lines
1.5 KiB
Python
63 lines
1.5 KiB
Python
import traceback
|
|
from contextlib import contextmanager
|
|
from typing import Optional, List
|
|
from ._compatibility import compatibility
|
|
|
|
# Public API of this module.
__all__ = [
    "override_stack_trace",
    "set_stack_trace",
    "append_stack_trace",
    "format_stack",
    "is_stack_trace_overridden",
]
|
|
|
|
|
|
# Recorded stack entries; read by format_stack() and updated by
# set_stack_trace() / append_stack_trace() while overriding is active.
current_stack: List[str] = []

# True only while an override_stack_trace() context is active.
is_overridden: bool = False
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
@contextmanager
def override_stack_trace():
    """
    Context manager that enables stack-trace overriding for its duration.

    The module-level ``is_overridden`` flag is set to ``True`` on entry and
    restored to the value it had before entry on exit, whether the body
    finishes normally or raises. Nested uses therefore keep the flag set
    until the outermost context exits.
    """
    global is_overridden

    previous = is_overridden
    is_overridden = True
    try:
        yield
    finally:
        # Restore rather than reset, so an enclosing override stays active.
        is_overridden = previous
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def set_stack_trace(stack: List[str]):
    """
    Replace the recorded stack with ``stack`` while overriding is active.

    The call does nothing when overriding is not active (``is_overridden`` is
    false) or when ``stack`` is empty or ``None``.

    Args:
        stack: Stack entries to install as the current recorded stack.
    """
    global current_stack

    if is_overridden and stack:
        # Store a private copy instead of the caller's list object:
        # append_stack_trace() appends to and pops from ``current_stack`` in
        # place, which would otherwise silently mutate the caller's list.
        current_stack = list(stack)
|
|
|
|
@compatibility(is_backward_compatible=False)
@contextmanager
def append_stack_trace(stack: Optional[str]):
    """
    Context manager that temporarily pushes ``stack`` onto the recorded stack.

    ``stack`` is one entire stack trace, already rendered as a single string.
    It is appended to ``current_stack`` on entry and the last entry of
    ``current_stack`` is popped on exit, even if the body raises.

    When overriding is not active, or ``stack`` is empty or ``None``, the
    recorded stack is left untouched and the body simply runs.
    """
    if not (is_overridden and stack):
        yield
        return

    try:
        current_stack.append(stack)
        yield
    finally:
        # Pops from whatever list ``current_stack`` names at exit time, which
        # may differ from the one appended to if it was replaced in the body.
        current_stack.pop()
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def format_stack() -> List[str]:
    """
    Return the stack trace to use at this point.

    Returns:
        A copy of the recorded ``current_stack`` while overriding is active;
        otherwise the result of ``traceback.format_stack()``.
    """
    if not is_overridden:
        # fallback to traceback.format_stack()
        return traceback.format_stack()

    # Hand out a copy so callers cannot modify the recorded stack.
    return current_stack.copy()
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def is_stack_trace_overridden() -> bool:
    """
    Report whether stack-trace overriding is currently active.

    Returns:
        The current value of the module-level ``is_overridden`` flag.
    """
    return is_overridden
|