This is useful for splitting the gradient computation into two parts while preserving the intermediate gradients needed by the second part:
<details>
<summary>
Click to see code
</summary>
```python
import collections
import weakref
from typing import Deque

import torch
from torch.autograd.graph import GradientEdge


def _get_grad_fn_or_grad_acc(t):
    # Leaf tensors have no grad_fn; materialize their AccumulateGrad node by
    # taking a view and walking one step back through the resulting ViewBackward0.
    if t.requires_grad and t.grad_fn is None:
        return t.view_as(t).grad_fn.next_functions[0][0]
    else:
        return t.grad_fn


def reverse_closure(roots, target_nodes):
    # Walk the reverse graph until we reach a target node
    closure = set()
    actual_target_nodes = set()
    q: Deque = collections.deque()
    for node in roots:
        if node is not None and node not in closure:
            closure.add(node)
            q.append(node)
    while q:
        node = q.popleft()
        reverse_edges = node.metadata.get("reverse_edges", [])
        for holder_ref, idx in reverse_edges:
            ref = holder_ref()
            if ref is None:
                raise RuntimeError("Reverse graph is no longer alive")
            fn = ref.node
            if fn in closure or fn is None:
                continue
            if fn in target_nodes:
                actual_target_nodes.add(fn)
                continue
            closure.add(fn)
            q.append(fn)
    return closure, actual_target_nodes


# Enable weak pointer
class Holder:
    def __init__(self, node):
        self.node = node


# TODO: use weak references to avoid reference cycle
def construct_reverse_graph(roots):
    q: Deque = collections.deque()
    root_seen = set()
    reverse_graph_refs = []
    for node in roots:
        if node is not None and node not in root_seen:
            q.append(node)
            root_seen.add(node)
    while q:
        node = q.popleft()
        for fn, idx in node.next_functions:
            if fn is not None:
                # Don't necessarily need to store on the graph
                reverse_edges = fn.metadata.get("reverse_edges", [])
                if len(reverse_edges) == 0:
                    q.append(fn)
                holder = Holder(node)
                holder_ref = weakref.ref(holder)
                reverse_graph_refs.append(holder)
                reverse_edges.append((holder_ref, idx))
                fn.metadata["reverse_edges"] = reverse_edges
    return reverse_graph_refs


def get_param_groups(inputs, params):
    inputs_closure, _ = reverse_closure(inputs, set())
    param_groups = dict()  # keyed on intermediates
    for param in params:
        closure, intersected = reverse_closure([param], inputs_closure)
        param_group = {
            "params": set([param]),
            "intermediates": set(intersected),
        }
        for input_node in intersected:
            existing = param_groups.get(input_node, None)
            if existing is not None:
                existing["params"] = existing["params"].union(param_group["params"])
                existing["intermediates"] = existing["intermediates"].union(
                    param_group["intermediates"]
                )
                param_group = existing
            else:
                param_groups[input_node] = param_group

    # Sanity check: union of all param_groups params should be equal to all params
    union_params = set()
    seen_ids = set()
    unique_param_groups = []
    for param_group in param_groups.values():
        if id(param_group) not in seen_ids:
            seen_ids.add(id(param_group))
            unique_param_groups.append(param_group)
            union_params = union_params.union(param_group["params"])
    assert union_params == set(params)
    return unique_param_groups


def compute_grads_only_inputs2(roots, inps, weights):
    root_grad_fns = list(map(_get_grad_fn_or_grad_acc, roots))
    inp_grad_fns = list(map(_get_grad_fn_or_grad_acc, inps))
    weight_grad_fns = list(map(_get_grad_fn_or_grad_acc, weights))

    reverse_graph_refs = construct_reverse_graph(root_grad_fns)
    param_groups = get_param_groups(inp_grad_fns, weight_grad_fns)
    del reverse_graph_refs

    for param_group in param_groups:
        for i, intermediate in enumerate(param_group["intermediates"]):

            def get_hook(param_group, i):
                def hook(grad_inputs):
                    if param_group.get("grads", None) is None:
                        param_group["grads"] = [None] * len(param_group["intermediates"])
                    param_group["grads"][i] = grad_inputs
                return hook

            # These are always "split" nodes that we need to recompute, so
            # save their inputs.
            intermediate.register_prehook(get_hook(param_group, i))

    # Seed each root with a ones gradient, matching the reference call below.
    dinputs = torch.autograd.grad(
        roots,
        inputs=tuple(inps),
        grad_outputs=tuple(torch.ones_like(r) for r in roots),
        retain_graph=True,
    )
    return dinputs, param_groups


def compute_grads_only_weights2(user_weights, param_groups):
    all_dweights = dict()
    for param_group in param_groups:
        # TODO: Handle case where intermediate can have multiple outputs
        intermediate_edges = tuple(GradientEdge(i, 0) for i in param_group["intermediates"])
        weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])

        assert all(len(g) == 1 for g in param_group["grads"])
        # [NEW!] Able to pass a GradientEdge to autograd.grad as output
        # We do not need to retain_graph because... guarantee no overlap?
        print("trying to execute: ", intermediate_edges, weights_edges)
        dweights = torch.autograd.grad(
            intermediate_edges,
            weights_edges,
            grad_outputs=sum(param_group["grads"], tuple()),
        )
        for w, dw in zip(param_group["params"], dweights):
            all_dweights[w] = dw
    # return grads in the original order weights were provided in
    out = []
    for w in user_weights:
        grad_acc = _get_grad_fn_or_grad_acc(w)
        out.append(all_dweights[grad_acc])
    return tuple(out)
```
</details>
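The key primitive here is that `torch.autograd.grad` now accepts a `GradientEdge` in place of an output tensor, so backward can be driven from an interior node of the graph without holding on to the output tensor itself. Below is a minimal sketch of just that capability, independent of the helpers above (the tensor names and shapes are illustrative only):

```python
import torch
from torch.autograd.graph import GradientEdge

w = torch.randn(3, requires_grad=True)
x = torch.randn(3, requires_grad=True)
y = (w * x).sum()

# Edge pointing at the node that produced `y` (its first and only output).
edge = GradientEdge(y.grad_fn, 0)

# grad_outputs must be supplied explicitly: autograd only sees the edge, not `y`,
# so it cannot materialize a default ones-valued gradient on its own.
(dw,) = torch.autograd.grad((edge,), inputs=(w,), grad_outputs=(torch.ones_like(y),))
print(torch.allclose(dw, x))  # dy/dw == x
```

The end-to-end example below wires this together with the reverse-closure helpers from the collapsed block above: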
```python
import torch
import torch.nn as nn
from torch.utils._python_dispatch import TorchDispatchMode

# Setup
mod1 = nn.Linear(10, 10)
mod2 = nn.Linear(10, 10)

a = torch.rand(10, requires_grad=True)

weights = tuple(mod1.parameters()) + tuple(mod2.parameters())
inps = (a,)

out = mod2(mod1(a))


class LoggingTensorMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        rs = func(*args, **kwargs)
        print(f"{func.__module__}.{func.__name__}")
        return rs


print(" -- SPLIT -- ")
# Compute gradients in two parts
with LoggingTensorMode():
    print("PART 1")
    dinputs, state = compute_grads_only_inputs2((out,), inps, weights)
    print("PART 2")
    dweights = compute_grads_only_weights2(weights, state)

out = mod2(mod1(a))

print(" -- REF -- ")
# Compare with reference
with LoggingTensorMode():
    ref_all_gradients = torch.autograd.grad(
        out, inputs=tuple(inps) + weights, grad_outputs=(torch.ones_like(out),)
    )

for actual, ref in zip(dinputs + dweights, ref_all_gradients):
    print(torch.allclose(actual, ref))
```
<img width="598" alt="image" src="https://github.com/pytorch/pytorch/assets/13428986/3681b8a7-3ab4-4d1d-a836-abef6913e671">
```
PART 1
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.ones_like.default
V0603 10:17:21.590878 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1ee160> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591204 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591578 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x100d7ae50> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
V0603 10:17:21.591747 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a60> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591834 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591922 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a90> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
PART 2
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1e4bb0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a21b130>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b7c0>, output_nr=0))
V0603 10:17:21.592223 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.592421 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a1cad60> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1ee0d0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a1e41c0>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b670>, output_nr=0))
V0603 10:17:21.593481 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.593750 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a21b2b0> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
```
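As an aside, the `_get_grad_fn_or_grad_acc` workaround above (taking a `view_as` of a leaf just to reach its `AccumulateGrad` node) can likely be replaced by the public helper `torch.autograd.graph.get_gradient_edge` where your build provides it; the sketch below assumes that helper is available and is not part of this PR:

```python
# Sketch only: assumes torch.autograd.graph.get_gradient_edge is available.
import torch
from torch.autograd.graph import get_gradient_edge

leaf = torch.rand(10, requires_grad=True)
y = (leaf * 2).sum()

print(get_gradient_edge(y))     # GradientEdge into y.grad_fn
print(get_gradient_edge(leaf))  # GradientEdge into the leaf's AccumulateGrad node
```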
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127766
Approved by: https://github.com/albanD
```cpp
#include <torch/csrc/autograd/python_engine.h>

#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/LegacyVmapMode.h>
#include <c10/util/irange.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>

#ifndef _WIN32
#include <pthread.h>
#endif

#include <memory> // for unique_ptr
#include <utility>

using namespace torch::autograd;

struct THPEngine {
  PyObject_HEAD
};

static bool _reinitialize_engine = false;

namespace torch {
namespace autograd {
namespace python {

PythonEngine::PythonEngine() = default;

Engine& PythonEngine::get_python_engine() {
  static PythonEngine engine;
  // This is "probably" thread-safe because the flag is set in a fork handler
  // before any threads are created, and this function is only called with the
  // GIL held. However, using fork + threads is playing with fire so this is
  // more of a "best effort" thing. For example, if the fork occurs while the
  // backwards threads hold a lock, we'll probably deadlock in the engine
  // destructor.
  if (_reinitialize_engine) {
    engine.release_workers();
    engine.~PythonEngine();
    new (&engine) torch::autograd::python::PythonEngine();
    _reinitialize_engine = false;
  }
  return engine;
}

PythonEngine::~PythonEngine() {
  Engine::stop();
}

#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 9
#define IS_PYTHON_3_9_PLUS
#endif

void PythonEngine::thread_init(
    int device,
    const std::shared_ptr<ReadyQueue>& ready_queue,
    bool should_increment) {
  // Increment thread usage count before acquiring the GIL
  if (should_increment) {
    increment_non_reentrant_thread_count();
  }
  // Create a PyThreadState, but release the GIL. This lets
  // pybind11::gil_scoped_acquire calls inside thread_main acquire the GIL
  // without having to create a new PyThreadState each time.
#if defined(IS_PYTHON_3_9_PLUS)
  auto gil = std::make_unique<pybind11::gil_scoped_acquire>();
#else
  pybind11::gil_scoped_acquire gil;
#endif
  pybind11::gil_scoped_release no_gil;
  Engine::thread_init(device, ready_queue, false);

  if (should_increment) {
    // Decrement the count during shutdown if we incremented earlier.
    decrement_non_reentrant_thread_count();
  }

#if defined(IS_PYTHON_3_9_PLUS)
  // Do not call PyEval_RestoreThread, PyThreadState_[Clear|DeleteCurrent] if
  // runtime is finalizing
  if (!Py_IsInitialized()) {
    no_gil.disarm();
    // TODO: call disarm once PyThreadState_Clear can safely be called from
    // finalize NOTE: deploy.cpp calls `PyInterpreterState_Delete` to destruct
    // PyThreadState, so avoid use-after-free here.
    auto ptr = gil.release();
    operator delete(ptr);
  }
#endif
}

void PythonEngine::thread_on_exception(
    std::shared_ptr<GraphTask> graph_task,
    const std::shared_ptr<Node>& fn,
    std::exception& e) {
  // See Note [ Persisting PyErr state across autograd engine threads ]
  auto python_err = dynamic_cast<python_error*>(&e);
  if (python_err) {
    python_err->persist();
  }
  Engine::thread_on_exception(std::move(graph_task), fn, e);
}

std::unique_ptr<AnomalyMetadata> PythonEngine::make_anomaly_metadata() {
  return std::make_unique<PyAnomalyMetadata>();
}

std::unique_ptr<SavedVariableHooks> PythonEngine::
    get_default_saved_variable_hooks() {
  return PyDefaultSavedVariableHooks::get_hooks();
}

variable_list PythonEngine::execute(
    const edge_list& roots,
    const variable_list& inputs,
    bool keep_graph,
    bool create_graph,
    bool accumulate_grad,
    const edge_list& outputs) {
  TORCH_CHECK(
      !PyGILState_Check(),
      "The autograd engine was called while holding the GIL. If you are using the C++ "
      "API, the autograd engine is an expensive operation that does not require the "
      "GIL to be held so you should release it with 'pybind11::gil_scoped_release no_gil;'"
      ". If you are not using the C++ API, please report a bug to the pytorch team.")
  try {
    return Engine::execute(
        roots, inputs, keep_graph, create_graph, accumulate_grad, outputs);
  } catch (python_error& e) {
    e.restore();
    throw;
  }
}

c10::intrusive_ptr<at::ivalue::Future> PythonEngine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root,
    InputBuffer&& input_buffer) {
  try {
    return Engine::execute_with_graph_task(
        graph_task, std::move(graph_root), std::move(input_buffer));
  } catch (python_error& e) {
    pybind11::gil_scoped_acquire gil;
    if (!PyErr_Occurred()) {
      // Set the error indicator only if it is not set already.
      e.restore();
    }
    throw;
  }
}
} // namespace python
} // namespace autograd
} // namespace torch

PyObject* THPEngineClass = nullptr;

inline static Edge parseGradientEdge(PyObject* obj, int64_t index) {
  PyObject* grad_fn = PyTuple_GetItem(obj, 0);
  auto output_nr = THPUtils_unpackLong(PyTuple_GetItem(obj, 1));
  std::shared_ptr<torch::autograd::Node> grad_fn_sp;
  if (THPFunction_Check(grad_fn)) {
    grad_fn_sp = ((THPFunction*)grad_fn)->cdata.lock();
  } else if (THPCppFunction_Check(grad_fn)) {
    grad_fn_sp = ((THPCppFunction*)grad_fn)->cdata;
  } else {
    TORCH_CHECK(
        false,
        "GradientEdge's first object must be an autograd.graph.Node "
        "but got ",
        THPUtils_typename(grad_fn));
  }
  return Edge(grad_fn_sp, output_nr);
}

// Implementation of torch._C._EngineBase.run_backward
PyObject* THPEngine_run_backward(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  PyObject* tensors = nullptr;
  PyObject* grad_tensors = nullptr;
  unsigned char keep_graph = 0;
  unsigned char create_graph = 0;
  PyObject* inputs = nullptr;
  unsigned char allow_unreachable = 0;
  unsigned char accumulate_grad =
      0; // Indicate whether to accumulate grad into leaf Tensors or capture
  constexpr const char* accepted_kwargs[] = {// NOLINT
      "tensors",
      "grad_tensors",
      "keep_graph",
      "create_graph",
      "inputs",
      "allow_unreachable",
      "accumulate_grad",
      nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "OObb|Obb",
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,-warnings-as-errors)
          const_cast<char**>(accepted_kwargs),
          &tensors,
          &grad_tensors,
          &keep_graph,
          &create_graph,
          &inputs,
          &allow_unreachable,
          &accumulate_grad))
    return nullptr;
  TORCH_CHECK(
      PyTuple_Check(tensors),
      "tensors argument is expected to "
      "be a tuple, but got ",
      THPUtils_typename(tensors));
  TORCH_CHECK(
      PyTuple_Check(grad_tensors),
      "grad_tensors argument is "
      "expected to be a tuple, but got ",
      THPUtils_typename(grad_tensors));

  Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors);
  Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_tensors);
  TORCH_CHECK(
      num_tensors == num_gradients,
      "got ",
      num_tensors,
      " tensors and ",
      num_gradients,
      " gradients");

  // The user either called autograd.backward(...) or autograd.grad(...) to get
  // here
  bool backward_api_called = accumulate_grad;
  TORCH_CHECK(
      !backward_api_called || at::impl::VmapMode::current_vmap_level() == 0,
      "backward() called inside torch.vmap. This is not supported, "
      "please call backward() outside torch.vmap or instead use "
      "torch.autograd.grad inside torch.vmap");

  edge_list roots;
  roots.reserve(num_tensors);
  variable_list grads;
  grads.reserve(num_tensors);
  for (const auto i : c10::irange(num_tensors)) {
    PyObject* _tensor = PyTuple_GET_ITEM(tensors, i);
    Edge gradient_edge; // Temporary variable to hold the gradient edge
    c10::optional<at::Tensor> mb_output;
    if (THPVariable_Check(_tensor)) {
      mb_output = THPVariable_Unpack(_tensor);
      TORCH_CHECK(
          !isBatchedTensor(mb_output.value()),
          "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
          "torch.vmap. We do not support the case where any outputs are ",
          "vmapped tensors (output ",
          i,
          " is being vmapped over). Please "
          "call autograd.grad() outside torch.vmap or file a bug report "
          "with your use case.");
      gradient_edge = torch::autograd::impl::gradient_edge(mb_output.value());
    } else if (PyObject_IsInstance(_tensor, THPGradientEdgeClass)) {
      gradient_edge = parseGradientEdge(_tensor, i);
    } else {
      TORCH_CHECK(
          false,
          "element ",
          i,
          " of tensors tuple is neither a Tensor nor a GradientEdge");
    }
    TORCH_CHECK(
        gradient_edge.function,
        "element ",
        i,
        " of tensors does not require grad and does not have a grad_fn");
    roots.push_back(std::move(gradient_edge));

    PyObject* grad = PyTuple_GET_ITEM(grad_tensors, i);
    if (THPVariable_Check(grad)) {
      const Variable& grad_var = THPVariable_Unpack(grad);
      if (grad_var.has_names()) {
        TORCH_WARN(
            "Autograd was passed a named grad tensor with dims ",
            grad_var.names(),
            ". Autograd does not yet support named tensor semantics, so all names ",
            "will be ignored. In practice all computed gradients will still be correct "
            "according to regular tensor semantics.");
      }
      grads.push_back(grad_var);
    } else {
      TORCH_CHECK(
          grad == Py_None,
          "element ",
          i,
          " of gradients tuple is not a Tensor or None");
      TORCH_CHECK(
          mb_output.has_value(),
          "element ",
          i,
          " of gradients tuple is None, but the corresponding output is a GradientEdge."
          "This is not supported.");
      TORCH_CHECK(
          !mb_output.value().requires_grad(),
          "element ",
          i,
          " of gradients tuple is None, but the corresponding Tensor requires grad");
    }
  }

  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
    TORCH_CHECK(
        PyTuple_CheckExact(inputs), "inputs to run_backward must be a tuple");
    int num_inputs = PyTuple_GET_SIZE(inputs);
    output_edges.reserve(num_inputs);
    for (const auto i : c10::irange(num_inputs)) {
      PyObject* input = PyTuple_GET_ITEM(inputs, i);
      if (THPVariable_Check(input)) {
        const auto& tensor = THPVariable_Unpack(input);
        TORCH_CHECK(
            !isBatchedTensor(tensor),
            "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
            "torch.vmap. We do not support the case where any inputs are ",
            "vmapped tensors (input ",
            i,
            " is being vmapped over). Please "
            "call autograd.grad() outside torch.vmap or file a bug report "
            "with your use case.")
        const auto output_nr = tensor.output_nr();
        auto grad_fn = tensor.grad_fn();
        if (!grad_fn) {
          grad_fn = torch::autograd::impl::try_get_grad_accumulator(tensor);
        }
        if (accumulate_grad) {
          tensor.retain_grad();
        }
        TORCH_CHECK(
            tensor.requires_grad(),
            "One of the differentiated Tensors does not require grad");
        if (!grad_fn) {
          // NOTE [ Autograd Unreachable Input ]
          // Since input has no grad_accumulator, its guaranteed to be
          // unreachable. We initialize an edge pointing to a non-nullptr Node
          // so nodes in the graph (e.g., mul when an operand is scalar) that
          // have edges pointing to nullptr don't get erroneously assigned
          // `needed = True` in exec_info.
          output_edges.emplace_back(std::make_shared<Identity>(), 0);
        } else {
          output_edges.emplace_back(grad_fn, output_nr);
        }
      } else if (PyObject_IsInstance(input, THPGradientEdgeClass)) {
        output_edges.emplace_back(parseGradientEdge(input, i));
      } else {
        TORCH_CHECK(
            false,
            "all inputs have to be Tensors or GradientEdges, but got ",
            THPUtils_typename(input));
      }
    }
  }

  variable_list outputs;
  {
    pybind11::gil_scoped_release no_gil;
    auto& engine = python::PythonEngine::get_python_engine();
    outputs = engine.execute(
        roots, grads, keep_graph, create_graph, accumulate_grad, output_edges);
  }

  if (!backward_api_called && inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    THPObjectPtr py_outputs{PyTuple_New(num_inputs)};
    if (!py_outputs)
      return nullptr;
    for (const auto i : c10::irange(num_inputs)) {
      TORCH_CHECK(
          allow_unreachable || outputs[i].defined(),
          "One of the "
          "differentiated Tensors appears to not have been used "
          "in the graph. Set allow_unused=True if this is the "
          "desired behavior.");
      PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i]));
    }
    return py_outputs.release();
  } else {
    Py_RETURN_NONE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject* THPEngine_queue_callback(PyObject* self, PyObject* _callback) {
  HANDLE_TH_ERRORS
  auto& engine = python::PythonEngine::get_python_engine();
  std::shared_ptr<PyObject> callback(_callback, [](PyObject* obj) {
    pybind11::gil_scoped_acquire gil;
    Py_DECREF(obj);
  });
  Py_INCREF(_callback);
  engine.queue_callback([callback]() {
    pybind11::gil_scoped_acquire gil;
    THPObjectPtr result{PyObject_CallFunctionObjArgs(callback.get(), nullptr)};
    if (!result) {
      // Note [ Persisting PyErr state across autograd engine threads ]
      //
      // Since the autograd engine is multi-threaded, and Python error state is
      // local to each thread, it must preserve the python error from the worker
      // thread and rethrow it as-is in the calling thread. This is done via
      // persisting the error in the two places that can encounter Python
      // errors: (1) evaluate function and (2) queued callbacks.
      //
      // TODO: the engine is not actually responsible for persisting the error
      // in the custom autograd Function case today! See the note above
      // `raise_python_error()` function in python_function.cpp and
      // python_hooks.cpp for more details. Persisting an extra time in the
      // engine is fine because doing so is a no-op when the python_error has
      // already been persisted.
      python_error err;
      err.persist();
      throw std::move(err);
    }
  });
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPEngine_is_checkpoint_valid(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto& engine = python::PythonEngine::get_python_engine();
  if (engine.is_checkpoint_valid()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject* THPEngine_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
  return type->tp_alloc(type, 0);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyMethodDef THPEngine_methods[] = {
    {(char*)"run_backward",
     castPyCFunctionWithKeywords(THPEngine_run_backward),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {(char*)"queue_callback", THPEngine_queue_callback, METH_O, nullptr},
    {(char*)"is_checkpoint_valid",
     THPEngine_is_checkpoint_valid,
     METH_NOARGS,
     nullptr},
    {nullptr}};

PyTypeObject THPEngineType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._EngineBase", /* tp_name */
    sizeof(THPEngine), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    // NOLINTNEXTLINE(misc-redundant-expression)
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPEngine_methods, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPEngine_new /* tp_new */
};

static void child_atfork() {
  _reinitialize_engine = true;
}

bool THPEngine_initModule(PyObject* module) {
#ifndef _WIN32
  if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) {
    throw std::runtime_error("unable to set pthread_atfork handler");
  }
#endif
  if (PyType_Ready(&THPEngineType) < 0)
    return false;
  Py_INCREF(&THPEngineType);
  PyModule_AddObject(module, "_ImperativeEngine", (PyObject*)&THPEngineType);
  set_default_engine_stub(python::PythonEngine::get_python_engine);
  return true;
}
```
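The `parseGradientEdge` and `THPEngine_run_backward` paths above are what allow a `GradientEdge` to appear on either side of a `torch.autograd.grad` call from Python. A small illustrative sketch exercising the `inputs=` side, mirroring what `compute_grads_only_weights2` does with `AccumulateGrad` edges (the names here are hypothetical, not from the PR):

```python
import torch
from torch.autograd.graph import GradientEdge

w = torch.randn(4, requires_grad=True)
y = (w * w).sum()

# Reach the leaf's AccumulateGrad node via the view_as trick from the PR description.
acc = w.view_as(w).grad_fn.next_functions[0][0]

# Ask for the gradient flowing into that edge instead of passing the tensor itself.
(dw,) = torch.autograd.grad(y, inputs=(GradientEdge(acc, 0),))
print(torch.allclose(dw, 2 * w))  # d((w*w).sum())/dw == 2w
```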