Summary: Fixes https://github.com/pytorch/pytorch/issues/43405.

This pull request adds a feature that prints all tracebacks when `detect_anomaly` mode detects `nan` in nested backward operations. It works by assigning a node as the parent of every node it produces during its backward calculation. If one of the children then produces `nan`, the tracebacks from the parent and grandparents (if any) are printed as well. The parent is stored in the `parent_node_` member of the `Node` class, accessible in C++ via `node->parent()` and in Python via `node.parent_function`. A node has a parent iff: (1) it is created from a backward operation, and (2) it is created while anomaly mode and grad mode are both enabled.

An example of this feature:

    import torch

    def example():
        x = torch.tensor(1.0, requires_grad=True)
        y = torch.tensor(1e-8, requires_grad=True)  # small to induce nan in n-th backward
        a = x * y
        b = x * y
        z1 = a / b  # can produce nan in n-th backward as long as https://github.com/pytorch/pytorch/issues/43414 is unsolved
        z = z1 * z1
        gy , = torch.autograd.grad( z , (y,), create_graph=True)
        gy2, = torch.autograd.grad(gy , (y,), create_graph=True)
        gy3, = torch.autograd.grad(gy2, (y,), create_graph=True)
        gy4, = torch.autograd.grad(gy3, (y,), create_graph=True)
        return gy4

    with torch.autograd.detect_anomaly():
        gy4 = example()

with output:

    example.py:16: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
      with torch.autograd.detect_anomaly():
    /home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py:190: UserWarning: Error detected in DivBackward0. Traceback of forward call that caused the error:
      File "example.py", line 17, in <module>
        gy4 = example()
      File "example.py", line 12, in example
        gy3, = torch.autograd.grad(gy2, (y,), create_graph=True)
      File "/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 190, in grad
        return Variable._execution_engine.run_backward(
     (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:61.)
      return Variable._execution_engine.run_backward(
    /home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py:190: UserWarning: Traceback of forward call that induces the previous calculation:
      File "example.py", line 17, in <module>
        gy4 = example()
      File "example.py", line 11, in example
        gy2, = torch.autograd.grad(gy , (y,), create_graph=True)
      File "/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 190, in grad
        return Variable._execution_engine.run_backward(
     (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:65.)
      return Variable._execution_engine.run_backward(
    /home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py:190: UserWarning: Traceback of forward call that induces the previous calculation:
      File "example.py", line 17, in <module>
        gy4 = example()
      File "example.py", line 8, in example
        z1 = a / b  # can produce nan in n-th backward as long as https://github.com/pytorch/pytorch/issues/43414 is unsolved
     (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:65.)
      return Variable._execution_engine.run_backward(
    Traceback (most recent call last):
      File "example.py", line 17, in <module>
        gy4 = example()
      File "example.py", line 13, in example
        gy4, = torch.autograd.grad(gy3, (y,), create_graph=True)
      File "/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 190, in grad
        return Variable._execution_engine.run_backward(
    RuntimeError: Function 'DivBackward0' returned nan values in its 1th output.

cc & thanks to albanD

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43626
Reviewed By: malfet
Differential Revision: D23397499
Pulled By: albanD
fbshipit-source-id: aa7435ec2a7f0d23a7a02ab7db751c198faf3b7d
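To see the new parent link directly, here is a minimal sketch (assuming this build exposes `node.parent_function` exactly as described above; the tensor values and the printed node types are illustrative):

    import torch

    with torch.autograd.detect_anomaly():
        x = torch.tensor(2.0, requires_grad=True)
        y = x * x
        # create_graph=True makes the backward pass build new graph nodes;
        # under anomaly mode each of them should record the node that produced it.
        gx, = torch.autograd.grad(y, (x,), create_graph=True)
        print(gx.grad_fn)                  # a node created during backward
        print(gx.grad_fn.parent_function)  # the node whose backward produced it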
114 lines
3.7 KiB
C++
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <c10/util/Exception.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/auto_gil.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/python_strings.h>

#include <iostream>

namespace torch { namespace autograd {
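
// Capture the current Python stack (via traceback.format_stack) and store it
// in this node's metadata dict under ANOMALY_TRACE_KEY, so it can be replayed
// if a later backward pass produces nan.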
void PyAnomalyMetadata::store_stack() {
  pybind11::gil_scoped_acquire gil;
  THPObjectPtr mod(PyImport_ImportModule("traceback"));
  if (!mod) {
    throw python_error();
  }

  THPObjectPtr list(PyObject_CallMethod(mod.get(), "format_stack", ""));
  if (!list) {
    throw python_error();
  }

  if (PyDict_SetItemString(dict(), ANOMALY_TRACE_KEY, list.get())) {
    throw python_error();
  }
}
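
// Print the stored forward-pass stack of the failing node, then walk up the
// chain of parents recorded by assign_parent below, printing each ancestor's
// stored stack in turn.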
void PyAnomalyMetadata::print_stack(const std::string& current_node_name) {
  pybind11::gil_scoped_acquire gil;
  if (!PyDict_Check(dict())) {
    throw std::runtime_error("Anomaly metadata is not a python dictionary.");
  }
  PyObject* trace_stack = PyDict_GetItemString(dict(), ANOMALY_TRACE_KEY);
  _print_stack(trace_stack, current_node_name, false);
  PyObject* pyparent(PyDict_GetItemString(dict(), ANOMALY_PARENT_KEY));
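  // Note: PyDict_GetItemString returns a borrowed reference, and it returns
  // null without setting a Python error when the key is absent, so a missing
  // trace or parent entry is not an error condition here.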

  // if there is no "parent_" in metadata, then this metadata's node is the
  // root, and we stop printing the traceback
  while (pyparent) {
    PyObject* parent_metadata(PyObject_GetAttrString(pyparent, "metadata"));
    if (!parent_metadata) {
      throw python_error();
    }
    PyObject* parent_name_pyobj(PyObject_CallMethod(pyparent, "name", ""));
    if (!parent_name_pyobj) {
      throw python_error();
    }
    const char* parent_name_char = PyUnicode_AsUTF8(parent_name_pyobj);
    if (!parent_name_char) {
      throw python_error();
    }
    const std::string parent_name(parent_name_char);
    PyObject* parent_stack = PyDict_GetItemString(parent_metadata, ANOMALY_TRACE_KEY);
    _print_stack(parent_stack, parent_name, true);
    // get the parent of this node; if this node is a root, pyparent is simply null
    pyparent = PyDict_GetItemString(parent_metadata, ANOMALY_PARENT_KEY);
  }
}

void PyAnomalyMetadata::assign_parent(const std::shared_ptr<Node>& parent_node) {
  // Store the Python object for parent_node in metadata["parent_"].
  // If parent_node is nullptr, do nothing (the "parent_" key is simply
  // left out of the metadata).
  pybind11::gil_scoped_acquire gil;
  if (!parent_node) return;

  PyObject* pyobj = functionToPyObject(parent_node);
  if (!pyobj) {
    throw python_error();
  }
  if (PyDict_SetItemString(dict(), ANOMALY_PARENT_KEY, pyobj)) {
    throw python_error();
  }
}
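
// Shared helper: joins the stored list of stack lines into one string and
// emits it via TORCH_WARN (surfaced to Python as a UserWarning, as in the
// sample output above), with different preambles for the failing node itself
// (is_parent == false) and for its ancestors (is_parent == true).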
void _print_stack(PyObject* stack, const std::string& current_node_name, bool is_parent) {
  if (!stack) {
    TORCH_WARN("Error detected in ", current_node_name, ". ",
        "No forward pass information available. Enable detect anomaly "
        "during forward pass for more information.");
    return;
  }

  THPObjectPtr empty_string(PyUnicode_FromString(""));
  if (!empty_string) {
    throw python_error();
  }

  // stack is a list of Python strings ending with newlines. Use join to convert
  // to a single string.
  THPObjectPtr msg(PyUnicode_Join(empty_string, stack));
  if (!msg) {
    throw python_error();
  }

  if (!is_parent) {
    TORCH_WARN("Error detected in ", current_node_name, ". ",
        "Traceback of forward call that caused the error:\n",
        THPUtils_unpackString(msg.get()));
  } else {
    TORCH_WARN("\n\n",
        "Previous calculation was induced by ", current_node_name, ". "
        "Traceback of forward call that induced the previous calculation:\n",
        THPUtils_unpackString(msg.get()));
  }
}

}} // namespace torch::autograd