Miscellaneous documentation around autograd. (#1577)

* Miscellaneous documentation around autograd.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Edward Z. Yang 2017-05-17 19:19:24 -04:00 committed by Soumith Chintala
parent b8b7f879c2
commit 1f3ff5ced2
9 changed files with 103 additions and 6 deletions

View File

@ -30,7 +30,7 @@ class _ContextMethodMixin(object):
:func:`forward` **method, and all arguments should be inputs.**
Every tensor that's been modified in-place in a call to :func:`forward`
should be given to this function, to ensure correctness of our checks.
It doesn't matter whether the function is called before or after
modification.
"""

View File

@ -0,0 +1,30 @@
## Autograd
Autograd is a hotspot for PyTorch performance, so most of the heavy lifting is
implemented in C++. This implies that we have to do some shuffling between
Python and C++, and in general we want data to be in a form that is convenient
to manipulate from C++.
Our general model is that for any key data type that autograd manipulates,
there are two implementations: a C++ type and a Python object type. For
example, consider variables in autograd: we have both `Variable` in `variable.h`
(the C++ type) and `THPVariable` in `python_variable.h` (the Python type).
(By the way, THP stands for TorcH Python, not to be confused with THPP, TorcH
C++.) `Variable` contains the payload of a variable, while `THPVariable` just
contains a `shared_ptr` reference to `Variable`, as well as references to other
Python objects which the Python runtime needs to know about. A lot of
data accessor implementations in `python_variable.cpp` simply reach through
to the underlying `Variable` and return the appropriate value.
The most complicated application of this principle is `Function`, which also
supports users implementing custom behavior in Python. We have the following
classes:
* `Function` in `function.h`, the C++ type.
* `THPFunction` in `python_function.h`, the Python object type.
* `PyFunction` in `python_function.h`, a subclass of `Function` which forwards
`apply` to a Python `THPFunction`.
Outside of `PyFunction`, the C++ objects largely avoid referencing Python
objects (the few exceptions are `pyobj` in `Variable` and `PyFunction` itself,
whose whole point is to let C++ call into Python).
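To make the pattern concrete, here is a minimal sketch of the two-type layout. The field names are illustrative only; the real definitions live in `variable.h` and `python_variable.h` and carry much more state.

```cpp
#include <Python.h>
#include <memory>

// C++ type: owns the payload and the autograd metadata.
struct Variable {
  bool requires_grad = false;   // stand-in for the real metadata
  PyObject* pyobj = nullptr;    // optional back-reference to the Python wrapper
  // ... tensor data, grad_fn, version counter, etc.
};

// Python object type: a thin wrapper the CPython runtime can see.
struct THPVariable {
  PyObject_HEAD
  std::shared_ptr<Variable> cdata;  // shared_ptr reference to the C++ payload
  PyObject* data;                   // Python objects the runtime needs to know about
};

// A typical data accessor just reaches through to the underlying Variable.
static PyObject* THPVariable_requires_grad(THPVariable* self, void* /*unused*/) {
  return PyBool_FromLong(self->cdata->requires_grad);
}
```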

View File

@ -52,12 +52,16 @@ struct ReadyQueue {
struct GraphTask {
std::exception_ptr exception;
// Indicates if an error occurred while executing any task. When this is
// true, it signals all threads to stop executing.
std::atomic_bool has_error;
std::atomic<uint64_t> outstanding_tasks;
bool keep_graph;
bool has_any_work;
std::mutex mutex;
// Notified when a task finishes executing. Check outstanding_tasks to see
// if all tasks are done.
std::condition_variable not_done;
const Engine::callback_map& function_callbacks;
std::unordered_map<Function*, InputBuffer> not_ready;
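As a rough illustration of how these fields cooperate (a simplified sketch, not the logic in engine.cpp): worker threads decrement `outstanding_tasks` as they finish work and notify `not_done`, while the launching thread waits on `not_done` until the count reaches zero or `has_error` is set.

```cpp
#include <atomic>
#include <condition_variable>
#include <mutex>

// Simplified stand-in for the synchronization subset of GraphTask.
struct GraphTaskSketch {
  std::atomic<uint64_t> outstanding_tasks{0};
  std::atomic_bool has_error{false};
  std::mutex mutex;
  std::condition_variable not_done;
};

// Worker thread: called after finishing one task.
void mark_task_completed(GraphTaskSketch& task) {
  if (--task.outstanding_tasks == 0) {
    std::lock_guard<std::mutex> lock(task.mutex);
    task.not_done.notify_all();   // wake whoever is waiting for completion
  }
}

// Launching thread: block until every task has run (or something failed).
void wait_for_completion(GraphTaskSketch& task) {
  std::unique_lock<std::mutex> lock(task.mutex);
  task.not_done.wait(lock, [&task] {
    return task.outstanding_tasks.load() == 0 || task.has_error.load();
  });
}
```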

View File

@ -31,7 +31,7 @@ struct Engine {
using callback_map = std::unordered_map<Function*, callback_type>;
// Given a list of (Function, input number) pairs, computes the value of the graph
// by following next_function references.
void execute(
const function_list& roots,

View File

@ -23,8 +23,14 @@ using function_list = std::vector<std::pair<std::shared_ptr<Function>, int>>;
// State used to create "backward" functions
struct FunctionFlags {
// Roughly speaking, is_executable corresponds to requires_grad.
// See http://pytorch.org/docs/notes/autograd.html for more details:
// both is_executable and is_volatile specify whether or not backwards
// gradient computation will be performed for a function, but they differ in
// their precedence.
bool is_executable = false;
bool is_volatile = false;
// What functions take the output of this function as input.
function_list next_functions;
};
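A hedged sketch of how such flags could be derived from a function's inputs (the real computation in the autograd sources may differ in detail): any input requiring grad makes the result executable, any volatile input makes the result volatile, and volatility takes precedence.

```cpp
#include <vector>

// Simplified stand-ins for the real types.
struct InputSketch {
  bool requires_grad = false;
  bool is_volatile = false;
};

struct FlagsSketch {
  bool is_executable = false;
  bool is_volatile = false;
};

FlagsSketch compute_flags(const std::vector<InputSketch>& inputs) {
  FlagsSketch f;
  for (const auto& in : inputs) {
    f.is_executable = f.is_executable || in.requires_grad;  // some input needs grad
    f.is_volatile   = f.is_volatile   || in.is_volatile;    // some input is volatile
  }
  // Volatility wins: a volatile result never gets a backward graph,
  // even if some inputs asked for gradients.
  if (f.is_volatile) {
    f.is_executable = false;
  }
  return f;
}
```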

View File

@ -26,6 +26,11 @@ inline tensor_list as_tensor_list(Args&& ... args) {
}
/**
* Wraps the tensor outputs in variables and, if necessary (i.e., when none of
* the inputs are volatile), uses the function ctr and the inputs to create a
* grad_fn for each of them.
*/
variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
function_constructor ctr);
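Conceptually, the helper builds a single grad_fn via ctr only when no input is volatile, and attaches it to every wrapped output. The following is a self-contained sketch with stand-in types and a simplified signature, not the real implementation:

```cpp
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Stand-in types; the real ones come from torch/csrc/autograd.
struct Tensor {};
struct Function {};
struct Variable {
  Tensor data;
  std::shared_ptr<Function> grad_fn;  // empty for volatile results
};

using tensor_list = std::vector<Tensor>;
using variable_list = std::vector<Variable>;
using function_constructor = std::function<std::shared_ptr<Function>()>;

variable_list wrap_outputs_sketch(bool any_input_volatile,
                                  tensor_list&& outputs,
                                  function_constructor ctr) {
  // Only build a grad_fn when backward could actually run.
  std::shared_ptr<Function> grad_fn;
  if (!any_input_volatile) {
    grad_fn = ctr();
  }
  variable_list result;
  result.reserve(outputs.size());
  for (auto& t : outputs) {
    result.push_back(Variable{std::move(t), grad_fn});  // outputs share one grad_fn
  }
  return result;
}
```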

View File

@ -33,6 +33,10 @@ PyObject *THPStochasticFunctionClass = NULL;
#define THPFunction_assert(condition, ...) \
if (!(condition)) { THPUtils_setError(__VA_ARGS__); throw python_error(); }
/**
* Cast an object into a tuple, if it is not a tuple already. Returns true
* if the original object was not a tuple.
*/
static bool _ensure_tuple(THPObjectPtr& obj)
{
if (PyTuple_Check(obj.get()))
@ -45,6 +49,9 @@ static bool _ensure_tuple(THPObjectPtr& obj)
return true;
}
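The helper presumably implements the usual wrap-in-a-1-tuple pattern; here is a standalone sketch of that pattern, not the exact elided code (which manages ownership with THPObjectPtr):

```cpp
#include <Python.h>

// Returns true iff obj was not already a tuple and had to be wrapped.
static bool ensure_tuple_sketch(PyObject** obj) {
  if (PyTuple_Check(*obj)) {
    return false;                           // already a tuple; nothing to do
  }
  PyObject* tuple = PyTuple_Pack(1, *obj);  // new 1-tuple; takes its own reference
  if (!tuple) {
    return false;                           // leave the Python error set for the caller
  }
  Py_DECREF(*obj);                          // the tuple now holds the reference we need
  *obj = tuple;
  return true;
}
```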
/**
* Call into Python to allocate and zero a tensor as per info.
*/
static PyObject* _allocate_grad_output(output_info_type& info, AutoGPU& gpu_guard)
{
// TODO: no need to do this for non-differentiable outputs
@ -111,7 +118,9 @@ auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list {
});
}
// NOTE: this function is written in a way that assumes it's only called for backward;
// it's used by engine.cpp. This is responsible for forwarding a call from
// C++'s Function::apply to a Python method "apply".
auto PyFunction::apply(const variable_list& inputs) -> variable_list {
AutoGIL gil;
AutoGPU _gpu_guard(-1);
@ -122,6 +131,7 @@ auto PyFunction::apply(const variable_list& inputs) -> variable_list {
return legacy_apply(inputs);
}
// Massage a C++ variable_list into a Python arguments tuple
auto num_inputs = inputs.size();
THPObjectPtr pyInputs = PyTuple_New(num_inputs);
if (!pyInputs) throw python_error();
@ -148,7 +158,8 @@ auto PyFunction::apply(const variable_list& inputs) -> variable_list {
auto& is_variable_input = *py_fn->is_variable_input;
int num_outputs = PyTuple_GET_SIZE(r.get());
int num_forward_inputs = is_variable_input.size();
// Returning too many results is ok, but only as long as they're all None.
// Truncate the result tuple in that case.
if (num_outputs > num_forward_inputs) {
bool all_none = true;
for (int i = num_forward_inputs; i < num_outputs; i++) {
@ -170,6 +181,7 @@ auto PyFunction::apply(const variable_list& inputs) -> variable_list {
throw std::runtime_error(msg);
}
// Massage the Python results tuple back into a C++ variable_list
variable_list results;
results.reserve(num_outputs);
for (int i = 0; i != num_outputs; ++i) {
@ -296,6 +308,8 @@ PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
using t2var_type = std::unordered_map<PyObject *, THPVariable *>;
// Bump the counters of all recorded dirty input tensors, adding each of them
// into dirty_inputs. Also does some sanity checking.
static void _mark_dirty(THPFunction *self, t2var_type &t2var,
std::unordered_set<PyObject *> &dirty_inputs)
{
@ -353,6 +367,16 @@ static void _transplant_var(Variable& var, const std::shared_ptr<Function>& fn,
}
}
// Given a Python tuple of raw output tensors (raw_output), set each of
// the corresponding entries in a different Python tuple (outputs) with
// these tensors wrapped with variables. We save the gradient function (self)
// to the variable if the output is not volatile (is_volatile).
//
// There is a considerable amount of complexity to handle if the operation
// that produced these output tensors is inplace. A mapping of *input*
// tensors to variables (t2var) is used to test if this occurred, and
// the set of dirty tensors (dirty_inputs) is used to figure out what to
// do in this case.
static void _wrap_outputs(THPFunction *self, t2var_type &t2var,
std::unordered_set<PyObject *> &dirty_inputs, PyObject *raw_output,
PyObject *outputs, bool is_volatile)
@ -443,6 +467,7 @@ static void _wrap_outputs(THPFunction *self, t2var_type &t2var,
}
}
// Save any variables that were requested by to_save
static void _save_variables(THPFunction* self, t2var_type &t2var)
{
if (!self->to_save) return;
@ -520,6 +545,7 @@ static void _join_version_counters(THPFunction *self, t2var_type &t2var)
self->shared_pairs = NULL;
}
// Mark requires_grad = 0 on non-differentiable variables (as per non_differentiable)
static void _mark_non_differentiable(THPFunction *self, t2var_type &t2var)
{
if (!self->non_differentiable) return;
@ -680,7 +706,7 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *_inputs)
ctx->needs_input_grad = input_info.needs_input_grad.release();
ctx->is_variable_input = new std::vector<bool>(std::move(input_info.is_variable_input));
// Prepend ctx to tensor_input, in preparation for the static method call
auto num_args = PyTuple_GET_SIZE(_inputs);
THPObjectPtr ctx_tensor_input = PyTuple_New(num_args + 1);
PyTuple_SET_ITEM(ctx_tensor_input.get(), 0, ctx_obj.release());

View File

@ -14,6 +14,8 @@ using output_info_type = std::tuple<PyObject *, int, std::vector<long>>;
namespace torch { namespace autograd {
// A Function which is implemented by a Python object (i.e., a THPFunction).
// Calls to 'apply' are forwarded to the Python method implementation.
struct PyFunction : public Function {
PyFunction(PyObject* obj) : obj(obj) {}
@ -23,6 +25,7 @@ struct PyFunction : public Function {
virtual void releaseVariables() override;
virtual std::string name() override;
// THPFunction this Function is wrapping.
PyObject* obj;
};
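A distilled sketch of the forwarding idea (the real PyFunction::apply in python_function.cpp, excerpted above, additionally takes the GIL, guards the GPU device, handles legacy functions, and validates the results):

```cpp
#include <Python.h>
#include <stdexcept>

// Stand-in for the wrapper; 'obj' plays the role of the THPFunction.
struct PyFunctionSketch {
  PyObject* obj;

  // Forward a call to the Python-side "apply" and hand the raw result back;
  // the caller is responsible for converting it into a C++ variable_list.
  PyObject* apply(PyObject* py_inputs /* Python tuple of inputs */) {
    PyObject* apply_fn = PyObject_GetAttrString(obj, "apply");
    if (!apply_fn) {
      throw std::runtime_error("function has no apply method");
    }
    PyObject* result = PyObject_CallObject(apply_fn, py_inputs);
    Py_DECREF(apply_fn);
    if (!result) {
      throw std::runtime_error("apply raised a Python exception");
    }
    return result;
  }
};
```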
@ -33,17 +36,32 @@ struct THPFunction {
PyObject *needs_input_grad;
// Python tuple of tensors whose variables we should save. Set
// by Python with 'save_for_backward'. If NULL, no tensors were
// saved.
PyObject *to_save;
// Python pairs of distinct tensors which share storage. Set by
// Python with 'mark_shared_storage'. If NULL, no tensors share
// storage.
PyObject *shared_pairs;
// Python tuple of tensors which are not differentiable. Set by
// Python with 'mark_non_differentiable'. If NULL, no tensors were
// non-differentiable.
PyObject *non_differentiable;
// Python tuple of tensors which had inplace updates in the forward()
// pass. Set by Python with 'mark_dirty'. If NULL, no tensors were
// modified inplace.
PyObject *dirty_tensors;
std::vector<output_info_type> *output_info;
std::vector<torch::autograd::SavedVariable> *saved_variables;
// For each input, true if the input is a THPVariable
std::vector<bool> *is_variable_input;
char has_freed_buffers;
// The C++ wrapper for this Python function.
// See a comment in THPFunction_asFunction for details about this field.
// You can use cdata directly if you don't actually need a shared_ptr.
std::weak_ptr<torch::autograd::PyFunction> cdata_ptr;
torch::autograd::PyFunction cdata;
};
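The optional tuple fields above follow a common pattern on the C++ side: NULL means the corresponding Python call never happened; otherwise the helpers iterate the tuple. A generic sketch of that pattern (illustrative only, not the code from python_function.cpp):

```cpp
#include <Python.h>

// Visit each tensor recorded in an optional tuple field such as
// dirty_tensors or non_differentiable; a NULL field means "nothing recorded".
static void visit_optional_tuple_sketch(PyObject* field) {
  if (!field) {
    return;  // e.g. mark_dirty() / mark_non_differentiable() was never called
  }
  Py_ssize_t n = PyTuple_GET_SIZE(field);
  for (Py_ssize_t i = 0; i < n; ++i) {
    PyObject* tensor = PyTuple_GET_ITEM(field, i);  // borrowed reference
    (void)tensor;  // the real helpers would e.g. bump a version counter here
  }
}
```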

View File

@ -5,10 +5,18 @@
#include "torch/csrc/autograd/variable.h"
// Python object that backs torch.autograd.Variable
struct THPVariable {
PyObject_HEAD
// Payload
std::shared_ptr<torch::autograd::Variable> cdata;
// Tensor this wraps (corresponds to Python attr 'data').
// It is assumed that a THPVariable is *uniquely* identified by the
// tensor it wraps.
// Invariant: v->data == v->cdata->data
PyObject* data;
// Hooks to be run on the backward pass (corresponds to Python attr
// '_backwards_hooks', set by 'register_hook')
PyObject* backward_hooks;
};