pytorch/torch/csrc/autograd/python_anomaly_mode.cpp
Sherlock Huang a7baad04f6 Preserve stack trace for backward nodes over AOTAutograd (#83558)
Consider the following program:
```
def my_relu(a):
    return a.relu()

def func(a, b):
    a = torch.nn.Linear(10, 10)(a)
    d = torch.square(b)
    d = my_relu(d)
    loss = d.sum()

    return loss

with torchdynamo.optimize("aot_nop"):
    x = torch.rand(10, 10, requires_grad=True)
    y = torch.rand(10, 10, requires_grad=True)
    out = func(x, y)
```

Tracing it generates the following fx graph, with `stack_trace` populated on both forward and backward nodes:
```
def forward(self, primals, tangents):
    primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    t_default = torch.ops.aten.t.default(primals_3);  primals_3 = None
    addmm_default = torch.ops.aten.addmm.default(primals_4, primals_1, t_default);  primals_4 = primals_1 = t_default = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(primals_2, 2)
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar);  pow_tensor_scalar = None
    detach_default = torch.ops.aten.detach.default(relu_default)
    sum_default = torch.ops.aten.sum.default(relu_default);  relu_default = None
    is_same_size_default = torch.ops.aten.is_same_size.default(sum_default, tangents_1)
    expand_default = torch.ops.aten.expand.default(tangents_1, [10, 10]);  tangents_1 = None
    detach_default_1 = torch.ops.aten.detach.default(detach_default);  detach_default = None
    threshold_backward_default = torch.ops.aten.threshold_backward.default(expand_default, detach_default_1, 0);  expand_default = detach_default_1 = None
    pow_tensor_scalar_1 = torch.ops.aten.pow.Tensor_Scalar(primals_2, 1.0);  primals_2 = None
    mul_scalar = torch.ops.aten.mul.Scalar(pow_tensor_scalar_1, 2.0);  pow_tensor_scalar_1 = None
    mul_tensor = torch.ops.aten.mul.Tensor(threshold_backward_default, mul_scalar);  threshold_backward_default = mul_scalar = None
    return pytree.tree_unflatten([sum_default, None, mul_tensor, None, None], self._out_spec)

====== joint graph =======
primals_1 None
primals_2 None
primals_3 None
primals_4 None
tangents_1 None
t_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 12, in func
    def func(a, b):
  File "/fsx/users/bahuang/repos/pytorch_fsx/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)

addmm_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 12, in func
    def func(a, b):
  File "/fsx/users/bahuang/repos/pytorch_fsx/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)

pow_tensor_scalar   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

relu_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

detach_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

sum_default
is_same_size_default
expand_default
detach_default_1   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

threshold_backward_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

pow_tensor_scalar_1   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

mul_scalar   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

mul_tensor   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

output None
```
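
For reference, the per-node traces in the joint-graph dump above can be read straight off the traced graph; a minimal sketch, assuming `gm` is the `GraphModule` captured by AOTAutograd for `func`:
```
# `gm` is assumed to be the GraphModule traced above. Each fx.Node
# exposes its recorded forward frames via `stack_trace` (None for
# placeholders and outputs, matching the dump).
for node in gm.graph.nodes:
    print(node.name, node.stack_trace)
```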

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83558
Approved by: https://github.com/albanD
2022-08-18 22:13:04 +00:00

#include <c10/util/Exception.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/auto_gil.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_strings.h>
#include <iostream>

namespace torch {
namespace autograd {

// Capture the current Python-level forward stack via
// torch.fx.traceback.format_stack and store it in this node's metadata
// under ANOMALY_TRACE_KEY.
void PyAnomalyMetadata::store_stack() {
  pybind11::gil_scoped_acquire gil;
  THPObjectPtr mod(PyImport_ImportModule("torch.fx.traceback"));
  if (!mod) {
    throw python_error();
  }

  THPObjectPtr list(PyObject_CallMethod(mod.get(), "format_stack", ""));
  if (!list) {
    throw python_error();
  }

  if (PyDict_SetItemString(dict(), ANOMALY_TRACE_KEY, list.get())) {
    throw python_error();
  }
}
void PyAnomalyMetadata::print_stack(const std::string& current_node_name) {
  pybind11::gil_scoped_acquire gil;
  if (!PyDict_Check(dict())) {
    throw std::runtime_error("Anomaly metadata is not a python dictionary.");
  }
  PyObject* trace_stack = PyDict_GetItemString(dict(), ANOMALY_TRACE_KEY);
  _print_stack(trace_stack, current_node_name, false);
  PyObject* pyparent(PyDict_GetItemString(dict(), ANOMALY_PARENT_KEY));

  // if there is no "parent_" in metadata, this metadata's node is the root,
  // so we stop printing tracebacks
  while (pyparent) {
    THPObjectPtr parent_metadata(PyObject_GetAttrString(pyparent, "metadata"));
    if (!parent_metadata) {
      throw python_error();
    }
    THPObjectPtr parent_name_pyobj(PyObject_CallMethod(pyparent, "name", ""));
    if (!parent_name_pyobj) {
      throw python_error();
    }
    const char* parent_name_char = PyUnicode_AsUTF8(parent_name_pyobj.get());
    if (!parent_name_char) {
      throw python_error();
    }
    const std::string parent_name(parent_name_char);
    PyObject* parent_stack =
        PyDict_GetItemString(parent_metadata.get(), ANOMALY_TRACE_KEY);
    _print_stack(parent_stack, parent_name, true);
    // get the parent of this node; if this node is a root, pyparent is
    // simply null
    pyparent = PyDict_GetItemString(parent_metadata.get(), ANOMALY_PARENT_KEY);
  }
}
void PyAnomalyMetadata::assign_parent(
    const std::shared_ptr<Node>& parent_node) {
  // store the python object for parent_node in metadata["parent_"];
  // if parent_node is nullptr, do nothing (the "parent_" key is simply
  // absent from metadata)
  pybind11::gil_scoped_acquire gil;
  if (!parent_node)
    return;

  THPObjectPtr parent_node_(functionToPyObject(parent_node));
  if (!parent_node_) {
    throw python_error();
  }
  if (PyDict_SetItemString(dict(), ANOMALY_PARENT_KEY, parent_node_.get())) {
    throw python_error();
  }
}
void _print_stack(
    PyObject* stack,
    const std::string& current_node_name,
    bool is_parent) {
  if (!stack) {
    TORCH_WARN(
        "Error detected in ",
        current_node_name,
        ". ",
        "No forward pass information available. Enable detect anomaly "
        "during forward pass for more information.");
    return;
  }

  THPObjectPtr empty_string(PyUnicode_FromString(""));
  if (!empty_string) {
    throw python_error();
  }

  // stack is a list of Python strings ending with newlines. Use join to
  // convert to a single string.
  THPObjectPtr msg(PyUnicode_Join(empty_string, stack));
  if (!msg) {
    throw python_error();
  }

  if (!is_parent) {
    TORCH_WARN(
        "Error detected in ",
        current_node_name,
        ". ",
        "Traceback of forward call that caused the error:\n",
        THPUtils_unpackString(msg.get()));
  } else {
    TORCH_WARN(
        "\n\n",
        "Previous calculation was induced by ",
        current_node_name,
        ". "
        "Traceback of forward call that induced the previous calculation:\n",
        THPUtils_unpackString(msg.get()));
  }
}
} // namespace autograd
} // namespace torch
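
For context, `PyAnomalyMetadata` backs the Python-level anomaly-detection API, which is what triggers `store_stack()` during the forward pass and `print_stack()` when backward fails. A minimal usage sketch (the NaN-producing example is illustrative):
```
import torch

# With anomaly detection enabled, each autograd node records its forward
# stack (store_stack); if backward produces a NaN, the stored traceback
# is replayed (print_stack), following "parent_" links for double-backward.
with torch.autograd.detect_anomaly():
    x = torch.zeros(3, requires_grad=True)
    out = x.norm()   # d||x||/dx at x = 0 is 0/0, i.e. NaN
    out.backward()   # anomaly mode raises and prints the forward traceback
```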