#include <torch/csrc/dynamo/python_compiled_autograd.h>

#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/dynamo/compiled_autograd.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pythoncapi_compat.h>

#include <iostream>
#include <vector>

/*
[Note: Compiled Autograd]

Compiled autograd replaces the standard autograd engine by converting
the autograd graph to an FX graph that can be torch.compiled. It caches
this conversion using a shadow graph. We compare the new graph to the
shadow graph by walking the two graphs simultaneously and computing a
CacheKey for each original node to find the next edge in the shadow graph.
Two different graphs might share a common prefix in the shadow graph,
but then diverge at the first difference. Tensors, SavedVariables, and
SymInts stored on the nodes in the autograd graph are lifted to become
inputs to the graph. All other properties (ints, floats, types, etc.) are
specialized using the CacheKey and will result in landing on a different
cache node in the shadow graph if some property differs.

To interact with the hundreds of different autograd::Node types, we use
a visitor pattern that walks each Node structure recursively.

- The first pass, compiled_args/collect, extracts all the inputs to the
  graph and builds a CacheKey for us to specialize on. On a cache hit,
  we stop here and this is the only pass.

- On a cache miss, a second pass kicks in to extract the FX graph using
  apply_with_saved, which uses another visitor pattern. The before()
  visitor swaps out all the Tensors, SavedVariables, and SymInts for
  fake/symbolic versions to allow tracing. We then run the standard apply()
  method, and after() restores things to how we found them.

When we see tensor hooks, we record them directly in the output graph
without tracing into them. We do this to avoid executing unsafe code
at trace time.

Notes:
  - We require hooks not to change the shapes of tensors.
  - We require non-hook autograd nodes to be traceable.
*/
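
/*
Rough sketch of how the raw bindings registered at the bottom of this file
could be driven. This is not the official frontend (a wrapper in torch._dynamo
normally manages this); the import path below is an assumption taken from the
module name in `_module`, and `make_compiler_fn` is a placeholder for any
zero-argument callable:

    mod = torch._C._dynamo.autograd_compiler
    prior = mod.set_autograd_compiler(make_compiler_fn)  # enable
    loss.backward()   # backward passes now route through compiled_autograd()
    mod.set_autograd_compiler(prior)  # restore the previous compiler (or None)
    mod.clear_cache()                 # drop the shadow graph + compiled fns
    assert mod.is_cache_empty()

`make_compiler_fn` is invoked with no arguments on each cache miss to build
the Python compiler object whose begin_capture/end_capture and hook methods
are called during tracing below.
*/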

namespace torch::dynamo::autograd {
using c10::SymInt;
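
// Wrap a list of int64_t sizes as a Python tuple of ints (new reference);
// used to pass the dynamic sizes into the compiled function.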
static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
  PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(inputs.size()));
  for (const auto i : c10::irange(inputs.size())) {
    PyTuple_SET_ITEM(pyinput, i, PyLong_FromSsize_t(inputs[i]));
  }
  return pyinput;
}

static PyObject* convert_hook_list(std::vector<c10::SafePyObject>& inputs) {
  // inplace, consumes the input hooks
  PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(inputs.size()));
  for (const auto i : c10::irange(inputs.size())) {
    PyTuple_SET_ITEM(pyinput, i, inputs[i].release());
  }
  return pyinput;
}
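
// Re-raise a failed CPython call (nullptr / false return) as a C++
// python_error exception.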
static PyObject* check(PyObject* pyresult) {
  if (C10_UNLIKELY(pyresult == nullptr)) {
    // see https://github.com/pytorch/pytorch/pull/34845
    python_error err;
    err.persist();
    throw err;
  }
  return pyresult;
}

static void check(bool result) {
  if (C10_UNLIKELY(!result))
    check(nullptr);
}
struct CacheNode {
  // A node in the shadow graph, we follow next edges until we reach the end of
  // the graph
  static CacheNode* root() {
    static CacheNode _root;
    return &_root;
  }
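
  // Find the child cache node for `key`, creating it (and taking ownership of
  // a copy of the key bytes) on first sight.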
  CacheNode* lookup(const CacheKey& key) {
    auto it = next.find(key);
    if (it == next.end()) {
      // caller's key is in temporary memory, must copy it
      CacheKeyBuffer buffer(key.key, key.key_size);
      CacheKey key_with_storage(key.node_type, buffer.get(), key.key_size);
      it = next.emplace(key_with_storage, std::make_unique<CacheNode>()).first;
      key_storage.emplace_back(std::move(buffer));
    }
    return it->second.get();
  }

  void clear() {
    next.clear();
    key_storage.clear();
    expected_sizes.clear();
    compiled_fn = nullptr;
  }

  bool is_empty() const {
    return next.empty() && !compiled_fn;
  }

  CacheNode() : compiled_fn(nullptr) {}
  ~CacheNode() {
    if (!Py_IsInitialized()) {
      compiled_fn.release(); // leak on shutdown
    }
  }
  CacheNode(CacheNode&&) = delete;
  CacheNode(const CacheNode&) = delete;
  CacheNode& operator=(const CacheNode&) = delete;
  CacheNode& operator=(CacheNode&&) = delete;

  bool check_dynamic_sizes(AutogradCompilerCall& call) {
    /*
    We start off by assuming everything is static, then we mark things
    as dynamic when we see them change. This function:
      1) Checks for a cache hit
      2) Updates expected_sizes to track what is dynamic
      3) Populates call.dyn_size_inputs by filtering call.all_size_inputs
    */
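    // For example: if the first capture records a size as STATIC(3), a later
    // call where that size is 5 is a cache miss; the entry is promoted to
    // DYNAMIC, compiled_fn is dropped so we recompile, and from then on the
    // value is passed in via call.dyn_size_inputs instead of being baked in.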
    bool cache_hit = compiled_fn.get() != nullptr;
    auto len = call.all_size_inputs.size();
    const SizeInput* data = call.all_size_inputs.data();
    if (expected_sizes.empty()) {
      expected_sizes.reserve(len);
      for (const auto i : c10::irange(len)) {
        expected_sizes.emplace_back(data[i]);
      }
    }

    TORCH_INTERNAL_ASSERT(expected_sizes.size() == call.all_size_inputs.size());
    for (const auto i : c10::irange(len)) {
      auto& expected = expected_sizes[i];
      if (expected.dyn_type == SizeInput::DYNAMIC ||
          expected.value != data[i].value) {
        cache_hit = cache_hit && expected.dyn_type == SizeInput::DYNAMIC;
        if (expected.value != data[i].value) {
          expected = SizeInput(SizeInput::DYNAMIC, data[i].value);
        }
        if (call.dyn_size_inputs.empty()) {
          call.dyn_size_inputs.reserve(len);
        }
        call.dyn_size_inputs.emplace_back(data[i].value);
      }
    }

    if (!cache_hit) {
      // we missed cache because static size inputs didn't match; force
      // recompilation with the varying size input as dynamic
      compiled_fn = nullptr;
    }
    return cache_hit;
  }
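
  // Wrap the current values of the DYNAMIC sizes as a Python tuple; these are
  // handed to the compiler's begin_capture() to become symbolic sizes.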
  PyObject* wrap_dynamic_inputs() {
    size_t dynamic_count = 0;
    size_t idx = 0;
    for (const auto& i : expected_sizes) {
      if (i.dyn_type == SizeInput::DYNAMIC) {
        ++dynamic_count;
      }
    }
    PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(dynamic_count));
    for (const auto& i : expected_sizes) {
      if (i.dyn_type == SizeInput::DYNAMIC) {
        PyTuple_SET_ITEM(pyinput, idx++, PyLong_FromSsize_t(i.value));
      }
    }
    TORCH_INTERNAL_ASSERT(idx == dynamic_count);
    return pyinput;
  }
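
  // Inverse of wrap_dynamic_inputs(): scatter the fake/symbolic sizes returned
  // by begin_capture() back into a vector aligned with expected_sizes, leaving
  // c10::nullopt in the STATIC slots.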
  std::vector<c10::optional<SymInt>> unwrap_dynamic_inputs(PyObject* pyresult) {
    TORCH_INTERNAL_ASSERT(PyList_CheckExact(pyresult));
    size_t idx = 0;
    size_t result_len = PyList_GET_SIZE(pyresult);
    std::vector<c10::optional<SymInt>> result;
    result.reserve(expected_sizes.size());
    for (const auto& i : expected_sizes) {
      if (i.dyn_type == SizeInput::DYNAMIC) {
        TORCH_INTERNAL_ASSERT(idx < result_len);
        result.emplace_back(
            py::cast<c10::SymInt>(PyList_GET_ITEM(pyresult, idx++)));
      } else {
        result.emplace_back();
      }
    }
    TORCH_INTERNAL_ASSERT(
        idx == result_len && result.size() == expected_sizes.size());
    return result;
  }

  std::unordered_map<CacheKey, std::unique_ptr<CacheNode>> next;
  std::vector<CacheKeyBuffer> key_storage;
  std::vector<SizeInput> expected_sizes;

  THPObjectPtr compiled_fn;
};
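
// Accumulates gradient inputs per autograd Node while we walk the graph during
// tracing, mirroring the InputBuffer bookkeeping of the eager engine.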
struct InputBuffers : public std::unordered_map<Node*, InputBuffer> {
  InputBuffer& lookup(Node* function) {
    auto it = find(function);
    if (it == end()) {
      it = emplace(function, InputBuffer(function->num_inputs())).first;
    }
    return it->second;
  }
};

static PyObject* the_autograd_compiler = nullptr;
static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args);

static PyObject* clear_cache(PyObject* dummy, PyObject* args) {
  HANDLE_TH_ERRORS;
  CacheNode::root()->clear();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS;
}

static PyObject* is_cache_empty(PyObject* dummy, PyObject* args) {
  HANDLE_TH_ERRORS;
  if (CacheNode::root()->is_empty()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
  END_HANDLE_TH_ERRORS;
}

static PyMethodDef _methods[] = {
    {"set_autograd_compiler", set_autograd_compiler, METH_VARARGS, nullptr},
    {"clear_cache", clear_cache, METH_NOARGS, nullptr},
    {"is_cache_empty", is_cache_empty, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};

static struct PyModuleDef _module = {
    PyModuleDef_HEAD_INIT,
    "torch._C._dynamo.autograd_compiler",
    "Hooks for compiling autograd",
    -1,
    _methods};
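
// Call py_compiler.begin_capture(inputs, sizes): hands the lifted tensor
// inputs and dynamic sizes to the Python compiler, records the returned fake
// tensors as proxies on each TensorArg, and seeds the TraceState with the
// symbolic sizes.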
static TraceState call_begin_capture(
    PyObject* self,
    CacheNode& cache,
    AutogradCompilerCall& compiler_call,
    size_t num_outputs) {
  static PyObject* method_name = PyUnicode_InternFromString("begin_capture");
  THPObjectPtr pyinput(THPVariable_WrapList(compiler_call.tensor_args.inputs));
  THPObjectPtr pysizeinput(cache.wrap_dynamic_inputs());
  THPObjectPtr pyresult(check(PyObject_CallMethodObjArgs(
      self, method_name, pyinput.get(), pysizeinput.get(), nullptr)));

  PyObject *fake_inputs{nullptr}, *fake_sizes{nullptr};
  check(PyArg_ParseTuple(pyresult.get(), "OO", &fake_inputs, &fake_sizes));

  variable_list proxy_inputs = THPVariable_UnpackList(fake_inputs);
  TORCH_INTERNAL_ASSERT(
      proxy_inputs.size() == compiler_call.tensor_args.inputs.size());
  for (const auto i : c10::irange(proxy_inputs.size())) {
    TensorArg& arg =
        compiler_call.tensor_args.lookup(compiler_call.tensor_args.inputs[i]);
    arg.proxy_tensor = proxy_inputs[i];
  }

  return TraceState(cache.unwrap_dynamic_inputs(fake_sizes), num_outputs);
}
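
// Call py_compiler.end_capture(outputs): finalizes the captured graph and
// returns the compiled callable that is stored on the cache node.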
static PyObject* call_end_capture(PyObject* self, const variable_list& inputs) {
  static PyObject* method_name = PyUnicode_InternFromString("end_capture");
  THPObjectPtr pyinput(THPVariable_WrapList(inputs));
  return check(PyObject_CallMethodOneArg(self, method_name, pyinput.get()));
}
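
// An owning PyObject pointer that additionally calls obj.close() when it goes
// out of scope, unless a Python error is already pending.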
struct ClosingTHPObjectPtr : public THPObjectPtr {
  ClosingTHPObjectPtr(PyObject* o) : THPObjectPtr(o) {}
  ~ClosingTHPObjectPtr() {
    if (PyErr_Occurred()) {
      // do nothing, do not attempt to close
      return;
    }
    static PyObject* method_name = PyUnicode_InternFromString("close");
    if (PyObject_CallMethodNoArgs(get(), method_name) == nullptr) {
      PyErr_WriteUnraisable(get());
      PyErr_Clear();
    }
  }
};
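
// Entry point installed into the autograd engine via
// Engine::set_compiled_autograd(). Walks the autograd graph from graph_root,
// looks up (or extends) the corresponding path in the shadow graph, traces an
// FX graph on a cache miss, and finally runs the compiled function on the
// lifted inputs.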
variable_list compiled_autograd(
    const std::shared_ptr<Node>& graph_root,
    GraphTask& graph_task,
    bool accumulate_grad,
    const edge_list& output_edges) {
  TORCH_CHECK(
      output_edges.empty() || !accumulate_grad,
      "specifying inputs= with .backward() not yet implemented for compiled autograd")
  TORCH_CHECK(
      c10::impl::TorchDispatchModeTLS::stack_len() == 0,
      "TorchDispatchMode not yet implemented for compiled autograd")
  static std::mutex lock;
  std::lock_guard<std::mutex> lock_guard(lock);
  pybind11::gil_scoped_acquire gil;
  at::ThreadLocalStateGuard tls_guard(graph_task.thread_locals_);
  std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
  std::vector<std::shared_ptr<Node>> worklist{graph_root};
  AutogradCompilerCall compiler_call;

  for (const auto i : c10::irange(output_edges.size())) {
    compiler_call.node_calls.lookup(output_edges[i].function)
        .mark_output(output_edges[i].input_nr, i);
  }
  const bool check_exec_info = !graph_task.exec_info_.empty();
  CacheNode* cache = CacheNode::root();
  std::vector<NodeCall*> calls;
  calls.reserve(
      check_exec_info ? graph_task.exec_info_.size() : dependencies.size() + 1);
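
  // First pass (see note above): walk the autograd graph in execution order,
  // collecting inputs into compiler_call and extending the shadow-graph path
  // one CacheKey at a time.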
  while (!worklist.empty()) {
    std::shared_ptr<Node> fn = std::move(worklist.back());
    worklist.pop_back();
    NodeCall& call = compiler_call.node_calls.lookup(fn);
    calls.emplace_back(&call);

    { // update cache and gather args into `compiler_call`
      CompiledNodeArgs node_args(compiler_call, call);
      node_args.collect(call);
      if (node_args.cond(call.needed)) {
        fn->compiled_args(node_args);
        node_args.collect(call.node->next_edges());
      }
      cache = cache->lookup(node_args.key());
    }

    for (const auto& edge : fn->next_edges()) {
      if (!edge.is_valid()) {
        continue;
      }
      if (check_exec_info) {
        auto it = graph_task.exec_info_.find(edge.function.get());
        if (it == graph_task.exec_info_.end() || !it->second.should_execute()) {
          continue;
        }
        if (!it->second.needed_) {
          compiler_call.node_calls.lookup(edge.function).needed = false;
        }
      }
      auto it = dependencies.find(edge.function.get());
      TORCH_INTERNAL_ASSERT(it != dependencies.end());
      if (--it->second == 0) {
        dependencies.erase(it);
        worklist.emplace_back(edge.function);
      }
    }
  }

  // TODO(jansel): some dynamic sizes seem to be ints not symints
  if (!cache->check_dynamic_sizes(compiler_call)) {
    // cache miss, need to capture FX graph
    ClosingTHPObjectPtr py_compiler(
        check(PyObject_CallNoArgs((the_autograd_compiler))));
    TraceState state = call_begin_capture(
        py_compiler, *cache, compiler_call, output_edges.size());
    InputBuffers input_buffers;

    for (NodeCall* call_ptr : calls) {
      NodeCall& call = *call_ptr;
      // TODO(jansel): consider adding some of this stuff:
      // guard(local_graph_task); NodeGuard ndguard(task.fn_); const auto
      // opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
      // c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
      // CheckpointValidGuard cpvguard(graph_task);
      // at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION);
      // if (C10_UNLIKELY(step_callbacks.has_value())) { ... }

      variable_list inputs = input_buffers.lookup(call.node.get()).buffer;

      if (!call.tensor_pre_hooks.empty()) {
        THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
        for (const auto& hook : call.tensor_pre_hooks) {
          pyinputs = check(PyObject_CallMethod(
              py_compiler,
              "tensor_pre_hook",
              "Oii",
              pyinputs.get(),
              hook.first,
              hook.second));
        }
        inputs = THPVariable_UnpackList(pyinputs);
      }
      for (const auto& graph_output : call.graph_output) {
        int input_nr = graph_output.first;
        int output_index = graph_output.second;
        TORCH_INTERNAL_ASSERT(
            output_index < static_cast<int>(state.outputs.size()));
        TORCH_INTERNAL_ASSERT(!state.outputs[output_index].defined());
        state.outputs[output_index] = inputs[input_nr];
      }
      if (!call.needed) {
        continue;
      }
      if (!call.pre_hooks.empty()) {
        THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
        for (const auto hook : call.pre_hooks) {
          pyinputs = check(PyObject_CallMethod(
              py_compiler.get(), "pre_hook", "Oi", pyinputs.get(), hook));
        }
        inputs = THPVariable_UnpackList(pyinputs);
      }

      SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call);
      variable_list outputs = call.node->apply_with_saved(inputs, saved);

      saved.debug_asserts();
      saved.before(call.node->next_edges());
      validate_outputs(
          call.node->next_edges(), outputs, [&](const std::string& msg) {
            std::ostringstream ss;
            ss << "[Compiled Autograd Tracing: " << call.node->name() << "] "
               << msg;
            return ss.str();
          });
      saved.after(call.node->next_edges());
      saved.debug_asserts();

      if (!call.post_hooks.empty()) {
        THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
        THPObjectPtr pyoutputs(THPVariable_WrapList(outputs));
        for (const auto hook : call.post_hooks) {
          pyoutputs = check(PyObject_CallMethod(
              py_compiler.get(),
              "post_hook",
              "OOi",
              pyoutputs.get(),
              pyinputs.get(),
              hook));
        }
        outputs = THPVariable_UnpackList(pyoutputs);
      }
      for (const auto i : c10::irange(outputs.size())) {
        auto& output = outputs[i];
        const auto& next = call.node->next_edge(i);
        if (next.is_valid() && output.defined()) {
          input_buffers.lookup(next.function.get())
              .add(
                  next.input_nr, std::move(output), c10::nullopt, c10::nullopt);
        }
      }
    }

    cache->compiled_fn = check(call_end_capture(py_compiler, state.outputs));
    state.debug_asserts();
  } // End cache miss region

  // TODO(jansel): we should release all the variables and then use a
  // boxed calling convention so activation memory can be freed
  // TODO(jansel): clear grads we will overwrite below
  if (!graph_task.keep_graph_) {
    for (auto& call : calls) {
      call->node->release_variables();
    }
  }
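
  // Cache hit (or freshly compiled): run the compiled function on the lifted
  // tensor inputs, dynamic sizes, and hooks, and unpack the resulting grads.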
  THPObjectPtr inputs(THPVariable_WrapList(compiler_call.tensor_args.inputs));
  THPObjectPtr sizes(wrap_int_list(compiler_call.dyn_size_inputs));
  THPObjectPtr hooks(convert_hook_list(compiler_call.hooks));
  THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs(
      cache->compiled_fn.get(), inputs.get(), sizes.get(), hooks.get(), NULL)));
  variable_list outputs = THPVariable_UnpackList(pyresult);
  TORCH_INTERNAL_ASSERT(outputs.size() == output_edges.size());
  return outputs;
}
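
// Python binding: set_autograd_compiler(callable_or_None). Stores the compiler
// factory, toggles Engine::set_compiled_autograd() accordingly, and returns
// the previously installed factory (or None).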
static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args) {
  HANDLE_TH_ERRORS;
  PyObject* obj;
  if (!PyArg_ParseTuple(args, "O", &obj)) {
    return nullptr;
  }

  PyObject* prior = the_autograd_compiler;
  if (obj == Py_None) { // disable
    the_autograd_compiler = nullptr; // decref not needed due to `prior`
    Engine::set_compiled_autograd(nullptr);
  } else { // enable
    Py_INCREF(obj);
    the_autograd_compiler = obj;
    Engine::set_compiled_autograd(&compiled_autograd);
  }

  if (prior == nullptr) {
    Py_RETURN_NONE;
  } else {
    return prior;
  }
  END_HANDLE_TH_ERRORS;
}

PyObject* torch_c_dynamo_compiled_autograd_init() {
  return PyModule_Create(&_module);
}

} // namespace torch::dynamo::autograd