pytorch/torch/csrc/autograd/python_engine.cpp
Edward Z. Yang a797ab9343 Rewrite AST to a new, more functional representation.
Previously, our AST was a DAG in which shared Nodes indicated that a
computation should be reused.  This commit rewrites the IR into a new
functional representation that expresses sharing explicitly using
variable bindings.

We offer a few justifications for this new style:

1. The new representation is not all that different from the
old one; it is about as easy to construct, and the lack of an
explicit graph doesn't hurt our ability to interpret it, since
we've chosen, as a matter of design, to NOT have the IR
participate in the actual execution of a graph.

2. The new let-binding representation has an implicit ordering,
which we can use to conveniently keep track of the original order
in which operations showed up in the trace.  This automatically
gives us a topological sort, and gives us an easier-to-read
textual representation of our IR:

  %14 = Embedding %11, %0, -1, None, 2, False, False
  %15 = Dropout %14, 0.2, True, False
  %16 = Index %12, 0
  %17 = Index %12, 1
  %18 = Index %13, 0
  %19 = Index %13, 1
  %20 = Index %15, 0
  %21 = Linear %20, %1, %3
  %22 = Linear %16, %2, %4

3. It moves us closer to a Futhark style language
(http://futhark-lang.org/publications/pldi17.pdf).

Major aspects of the diff:

- Node is replaced with Expr and Arg, a pair of mutually recursive
  structures which represent our new language.  In BNF, the language
  looks like this:

    a ::= c | %i
    e ::= %i, ... = e; e
        | PyOp a, ...
        | Ret %i, ...

  Technically, Ret is not actually a return (no control flow is involved);
  it just tuples up a series of tensors (identified by variables).

  One important invariant is that locals are always tensors; they
  are never constants (this is asymmetric with Args).  A rough sketch
  of what these structures might look like is given below.
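
  As a purely hypothetical sketch of what such a pair of structures could
  look like (the names and fields are illustrative, not the actual
  torch/csrc definitions, and the recursion is flattened for brevity):

    // Hypothetical sketch only; the real Expr/Arg classes live in the
    // IR headers and may be organized differently.
    #include <Python.h>
    #include <memory>
    #include <vector>

    // a ::= c | %i -- an argument is either a Python constant or a local
    struct Arg {
      enum class Tag { Constant, Local } tag;
      PyObject* constant = nullptr;  // valid when tag == Constant
      int local = -1;                // valid when tag == Local (the %i)
    };

    // e ::= %i, ... = e; e | PyOp a, ... | Ret %i, ...
    struct Expr {
      enum class Tag { Let, PyOp, Ret } tag;
      // Let: bind the results of rhs to locals, then evaluate body
      std::vector<int> locals;       // locals are always tensors
      std::shared_ptr<Expr> rhs;
      std::shared_ptr<Expr> body;
      // PyOp: apply a Python callable to a list of Args
      PyObject* pyop = nullptr;
      std::vector<Arg> args;
      // Ret: tuple up a series of locals (no control flow involved)
      std::vector<int> returns;
    };

  A tagged layout along these lines is also what makes it possible to
  "manually case on the type tag", as mentioned for the visitors below.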

- Arguments support Python constants.  This is an important piece because
  many operators take extra Python literals like integers and tuples in
  order to specify extra parameters about how an operator operates.  Adding
  this was essential to getting word_language_model to work.

- As both Expr and Arg have multiple variants, there is new infrastructure
  for casing on the variants using ExprVisitor and ArgVisitor.  The
  strategy here is adapted from WebAssembly's visitors, although we have
  generalized it to permit arbitrary argument forwarding, which is necessary
  to support tail-recursive visitor calls.  TCO is important because our
  interpreter may recurse arbitrarily deep into a stack of nested lets.
  If users wish, they can also manually case on the type tag; a sketch of
  the dispatch pattern is shown below.
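
  Again purely as an illustration (reusing the hypothetical Expr from the
  sketch above; the real ExprVisitor/ArgVisitor templates are not shown
  here and may differ), the tag-dispatching, argument-forwarding pattern
  looks roughly like this:

    // Illustrative CRTP-style visitor; not the actual ExprVisitor interface.
    #include <stdexcept>

    template <typename Derived, typename Ret, typename... Args>
    struct ExprVisitorSketch {
      Ret visit(Expr& e, Args... args) {
        // Dispatch on the type tag, forwarding the extra arguments so that
        // a case such as visitLet can end in a tail call back into visit().
        switch (e.tag) {
          case Expr::Tag::Let:  return static_cast<Derived*>(this)->visitLet(e, args...);
          case Expr::Tag::PyOp: return static_cast<Derived*>(this)->visitPyOp(e, args...);
          case Expr::Tag::Ret:  return static_cast<Derived*>(this)->visitRet(e, args...);
        }
        throw std::logic_error("unknown Expr tag");
      }
    };

    // Toy example: count how deeply the lets are nested.  The forwarded
    // argument (depth) is threaded through, and visitLet's recursive call
    // is in tail position, so a compiler can turn it into a loop.
    struct LetDepth : ExprVisitorSketch<LetDepth, int, int> {
      int visitLet(Expr& e, int depth)  { return visit(*e.body, depth + 1); }
      int visitPyOp(Expr& e, int depth) { return depth; }
      int visitRet(Expr& e, int depth)  { return depth; }
    };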

- Tracing is now turned on and off using _tracer_enter/_tracer_exit in
  torch._C.  _tracer_enter accepts a list of variables which are to be
  treated as arguments; _tracer_exit accepts the list of traced variables
  which should be returned when you reexecute the trace, and returns
  the trace expression which can be reexecuted.  GlobalTracingState
  is a global variable which tracks whether we are currently tracing.

- You use run_forward to execute a trace on some set of parameters.

- While tracing, each variable keeps track, via trace_local, of the name
  of its corresponding local in the IR.

Here is a simple runner which leaks memory but can be used to JIT models:

  import torch.autograd.function as F
  import torch._C
  from torch.autograd import Variable

  def jit(model):
      import types
      real_forward = model.forward
      def forward(self, *args):
          def flatten(x):
              return tuple(F._iter_variables(x))
          if not hasattr(self, "saved_trace"):
              torch._C._tracer_enter(tuple(self.parameters()) + flatten(args))
              out = real_forward(*args)
              self.saved_trace = torch._C._tracer_exit(flatten(out))
              self.saved_outs = out
              return out
          else:
              flat_out = Variable._execution_engine.run_forward(
                  self.saved_trace, tuple(self.parameters()) + flatten(args))
              return F._unflatten(flat_out, self.saved_outs)
      # install the tracing forward on the model instance
      model.forward = types.MethodType(forward, model)
      return model
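
The intended flow is that the first call to the patched forward records a
trace between _tracer_enter and _tracer_exit, while every subsequent call
replays that trace through run_forward with the current parameters and
inputs, using F._unflatten to restore the original output structure.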

Major problems:

- Sanity checking is spotty at best, especially when users pass in variables.

- The interpreter leaks tensor memory from the store.  When we add back def-use
  information, we should be able to deallocate tensors as soon as we know they
  are no longer necessary.

- The interpreter needs to reach feature parity with the old execution engine.
  From there, we need to see if backwards can be subsumed as well.

- I still have no confidence that memory is managed correctly everywhere.
  This requires a close look.

- Rather than return an *open* expression as a trace, we should return a
  *lambda* instead, which knows about how many formal parameters it
  requires.

- The IR is not introspectable from Python at the moment, but this is simply a
  matter of implementing all the binding code.

- The tracer is NOT reentrant (you can't trace while you're inside a trace).
  Furthermore, no sanity checking is done if you try to incorrectly reuse
  things from one trace in another.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-09-05 17:48:55 -04:00


#include "torch/csrc/autograd/python_engine.h"
#include "torch/csrc/autograd/engine.h"
#include "torch/csrc/autograd/python_function.h"
#include "torch/csrc/THP.h"
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/utils/auto_gil.h"
#include <unordered_set>
using namespace torch::autograd;

struct THPEngine {
  PyObject_HEAD
};

struct PythonEngine : public Engine {
  virtual void thread_main(std::shared_ptr<ReadyQueue> queue, int device) override {
    // Create a PyThreadState, but release the GIL. This lets AutoGIL calls
    // inside thread_main acquire the GIL without having to create a new
    // PyThreadState each time.
    AutoGIL gil;
    AutoNoGIL no_gil;
    Engine::thread_main(queue, device);
  }

  virtual void thread_on_exception(FunctionTask& task, std::exception& e) override {
    auto python_err = dynamic_cast<python_error*>(&e);
    if (python_err) {
      python_err->persist();
    }
    Engine::thread_on_exception(task, e);
  }
};

static PythonEngine engine;

PyObject *THPEngineClass = NULL;

struct CallbackContext {
  std::string error;
  THPObjectPtr outputs;
  // Used to determine which callback arguments should be used to
  // fill outputs.
  // Function -> ([grad_nr, outputs_idx], is_leaf)
  std::unordered_map<
      std::shared_ptr<Function>,
      std::pair<std::vector<std::pair<int, int>>, bool>> output_map;
};

void compute_partial_exec_callbacks(const function_list& roots,
                                    const CallbackContext& ctx,
                                    Engine::callback_map& map) {
  // This callback is used to suppress the computation of a node
  // if it is not necessary.
  static Engine::callback_type abort_callback(
      [](Function* fn, variable_list &vars) { return false; });

  std::vector<Function*> queue;
  std::unordered_set<Function*> seen;    // for the initial DFS
  std::unordered_set<Function*> needed;  // functions to compute
  std::unordered_map<Function*, std::vector<Function*>> rev_graph;

  // Reverse the next_fn edges
  queue.reserve(roots.size());
  for (auto& root : roots) {
    auto ptr = root.first.get();
    bool unseen;
    std::tie(std::ignore, unseen) = seen.insert(ptr);
    if (unseen) queue.emplace_back(ptr);
  }
  while (!queue.empty()) {
    auto fn = queue.back(); queue.pop_back();
    for (auto& next_fn_pair : fn->next_functions) {
      auto next_fn = next_fn_pair.first.get();
      if (!next_fn) continue;
      rev_graph[next_fn].push_back(fn);
      if (seen.insert(next_fn).second) {
        queue.push_back(next_fn);
      }
    }
  }
  auto all_functions = std::move(seen); // this is cheap and improves readability

  // Find all functions we need to compute
  queue.clear();
  for (auto input_info: ctx.output_map) {
    auto input = input_info.first.get();
    auto& rev_edges = rev_graph[input];
    if (rev_edges.size() == 0) throw std::runtime_error("differentiated input is unreachable");
    queue.emplace_back(input);
    needed.insert(input);
  }
  while (!queue.empty()) {
    auto fn = queue.back(); queue.pop_back();
    for (auto rev_next_fn : rev_graph[fn]) {
      if (needed.insert(rev_next_fn).second) {
        queue.push_back(rev_next_fn);
      }
    }
  }

  // Prevent expansion for functions in {all_vertices} \ {needed}
  for (auto fn : all_functions) {
    if (needed.count(fn) > 0) continue;
    map.emplace(fn, abort_callback);
  }
}

// Implementation of torch._C._EngineBase.run_backward
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
  HANDLE_TH_ERRORS
  PyObject *variables = NULL;
  PyObject *grad_variables = NULL;
  unsigned char keep_graph = 0;
  PyObject *inputs = NULL;
  unsigned char only_inputs = 0;
  const char *accepted_kwargs[] = {"variables", "grad_variables",
      "keep_graph", "inputs", "only_inputs", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOb|Ob", (char**)accepted_kwargs,
        &variables, &grad_variables, &keep_graph, &inputs, &only_inputs))
    return NULL;

  THPUtils_assert(PyTuple_Check(variables), "variables argument is expected to "
      "be a tuple, but got %s", THPUtils_typename(variables));
  THPUtils_assert(PyTuple_Check(grad_variables), "grad_variables argument is "
      "expected to be a tuple, but got %s", THPUtils_typename(grad_variables));

  Py_ssize_t num_variables = PyTuple_GET_SIZE(variables);
  Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_variables);
  THPUtils_assert(num_variables == num_gradients, "got %ld variables and %ld "
      "gradients", num_variables, num_gradients);

  function_list roots(num_variables);
  variable_list grads(num_variables);
  for (int i = 0; i < num_variables; i++) {
    PyObject *_variable = PyTuple_GET_ITEM(variables, i);
    THPUtils_assert(THPVariable_Check(_variable), "element %d of variables "
        "tuple is not a Variable", i);
    auto& variable = ((THPVariable*)_variable)->cdata;
    THPUtils_assert(!variable->is_volatile,
        "element %d of variables tuple is volatile", i);
    // If grad_fn is NULL (as is the case for a leaf node), we instead
    // interpret the gradient function to be a grad accumulator,
    // which will accumulate its inputs into the grad property of the
    // variable. These nodes get suppressed in some situations,
    // see "suppress grad accumulation" below.
    auto grad_fn = variable->grad_fn ? variable->grad_fn : variable->get_grad_accumulator();
    THPUtils_assert(grad_fn, "element %d of variables tuple does not require grad", i);
    int output_nr = variable->grad_fn ? variable->output_nr : 0;
    roots[i] = std::make_pair<>(std::move(grad_fn), output_nr);

    PyObject *grad = PyTuple_GET_ITEM(grad_variables, i);
    if (THPVariable_Check(grad)) {
      grads[i] = ((THPVariable*)grad)->cdata;
    } else {
      THPUtils_assert(grad == Py_None,
          "element %d of gradients tuple is not a Variable or None", i);
      THPUtils_assert(!variable->requires_grad,
          "element %d of gradients tuple is None, but the corresponding Variable requires grad", i);
    }
  }
  Engine::callback_map callbacks;
  CallbackContext ctx;
  if (inputs != NULL) {
    THPUtils_assert(PyTuple_Check(inputs), "inputs argument has to be a tuple");
    int num_inputs = PyTuple_GET_SIZE(inputs);
    ctx.outputs = PyTuple_New(num_inputs);
    // First, find all relevant functions and fill ctx.output_map
    for (int i = 0; i < num_inputs; ++i) {
      PyObject *input = PyTuple_GET_ITEM(inputs, i);
      THPUtils_assert(THPVariable_Check(input),
          "all inputs have to be Variables, but got %s", THPUtils_typename(input));
      THPVariable *input_var = (THPVariable*)input;
      auto grad_fn = input_var->cdata->grad_fn;
      int output_nr = input_var->cdata->output_nr;
      bool is_leaf = !grad_fn;
      if (is_leaf) {
        grad_fn = input_var->cdata->grad_accumulator.lock();
      }
      THPUtils_assert(grad_fn, "One of the differentiated Variables appears to not have "
          "been used in the graph");
      auto& fn_info = ctx.output_map[grad_fn];
      fn_info.first.emplace_back(output_nr, i);
      fn_info.second = is_leaf;
    }
    // Register callbacks that will gather the outputs
    for (auto& entry : ctx.output_map) {
      auto& fn_info = entry.second;
      callbacks.emplace(entry.first.get(), [&ctx, &fn_info](Function* _unused, variable_list& grads) {
        auto& saved_outputs = fn_info.first;
        bool is_leaf = fn_info.second;
        AutoGIL gil;
        for (auto& saved_out : saved_outputs) {
          PyTuple_SET_ITEM(ctx.outputs.get(), saved_out.second,
              THPVariable_Wrap(grads[saved_out.first]));
        }
        // Suppress grad accumulation.
        // If the variable is a leaf, the next function to execute
        // is a grad_accumulator. But when inputs != NULL, we should
        // NOT accumulate, so terminate execution.
        return !is_leaf;
      });
    }
    // Disable execution for all unneeded functions
    if (only_inputs) {
      compute_partial_exec_callbacks(roots, ctx, callbacks);
    }
  }

  try {
    AutoNoGIL no_gil;
    engine.execute(roots, grads, keep_graph, callbacks);
  } catch (python_error &e) {
    e.restore();
    return nullptr;
  }

  if (ctx.outputs) {
    return ctx.outputs.release();
  } else {
    Py_RETURN_NONE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject *THPEngine_run_forward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
  HANDLE_TH_ERRORS;
  PyObject* expr_obj;
  PyObject* input_objs;
  const char *accepted_kwargs[] = {"expr", "inputs", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO", (char**)accepted_kwargs,
        &expr_obj, &input_objs))
    return NULL;
  THPUtils_assert(THPExpr_Check(expr_obj), "expr argument is expected to be an "
      "Expr, but got %s", THPUtils_typename(expr_obj));
  THPUtils_assert(PyTuple_Check(input_objs), "inputs argument is expected to "
      "be a tuple, but got %s", THPUtils_typename(input_objs));
  Py_ssize_t num_inputs = PyTuple_GET_SIZE(input_objs);
  environment env;
  // TODO: skeevy, relies on the invariant that the trace numbering
  // allocates the first N locals to the trace's input parameters
  for (int i = 0; i < num_inputs; i++) {
    PyObject* input_obj = PyTuple_GET_ITEM(input_objs, i);
    THPUtils_assert(THPVariable_Check(input_obj), "element %d of inputs "
        "tuple is not a Variable", i);
    auto& input_var = ((THPVariable*)input_obj)->cdata;
    env.insert({i, input_var});
  }
  variable_list results = interpret(((THPExpr*)expr_obj)->cdata, env);
  int num_outputs = results.size();
  PyObject *result = PyTuple_New(num_outputs);
  for (int i = 0; i < num_outputs; i++) {
    PyTuple_SET_ITEM(result, i, THPVariable_Wrap(results.at(i)));
  }
  return result;
  END_HANDLE_TH_ERRORS;
}

PyObject* THPEngine_queue_callback(PyObject *self, PyObject *_callback) {
  std::shared_ptr<PyObject> callback(_callback, [](PyObject *obj) { AutoGIL gil; Py_DECREF(obj); });
  Py_INCREF(_callback);
  engine.queue_callback([callback]() {
    AutoGIL gil;
    THPObjectPtr result {PyObject_CallFunctionObjArgs(callback.get(), NULL)};
    if (!result) throw python_error();
  });
  Py_RETURN_NONE;
}

PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
  return type->tp_alloc(type, 0);
}

static struct PyMethodDef THPEngine_methods[] = {
  {(char*)"run_backward", (PyCFunction)THPEngine_run_backward, METH_VARARGS | METH_KEYWORDS, NULL},
  {(char*)"run_forward", (PyCFunction)THPEngine_run_forward, METH_VARARGS | METH_KEYWORDS, NULL},
  {(char*)"queue_callback", (PyCFunction)THPEngine_queue_callback, METH_O, NULL},
  {NULL}
};

PyTypeObject THPEngineType = {
  PyVarObject_HEAD_INIT(NULL, 0)
  "torch._C._EngineBase",                   /* tp_name */
  sizeof(THPEngine),                        /* tp_basicsize */
  0,                                        /* tp_itemsize */
  0,                                        /* tp_dealloc */
  0,                                        /* tp_print */
  0,                                        /* tp_getattr */
  0,                                        /* tp_setattr */
  0,                                        /* tp_reserved */
  0,                                        /* tp_repr */
  0,                                        /* tp_as_number */
  0,                                        /* tp_as_sequence */
  0,                                        /* tp_as_mapping */
  0,                                        /* tp_hash */
  0,                                        /* tp_call */
  0,                                        /* tp_str */
  0,                                        /* tp_getattro */
  0,                                        /* tp_setattro */
  0,                                        /* tp_as_buffer */
  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
  NULL,                                     /* tp_doc */
  0,                                        /* tp_traverse */
  0,                                        /* tp_clear */
  0,                                        /* tp_richcompare */
  0,                                        /* tp_weaklistoffset */
  0,                                        /* tp_iter */
  0,                                        /* tp_iternext */
  THPEngine_methods,                        /* tp_methods */
  0,                                        /* tp_members */
  0,                                        /* tp_getset */
  0,                                        /* tp_base */
  0,                                        /* tp_dict */
  0,                                        /* tp_descr_get */
  0,                                        /* tp_descr_set */
  0,                                        /* tp_dictoffset */
  0,                                        /* tp_init */
  0,                                        /* tp_alloc */
  THPEngine_new                             /* tp_new */
};

bool THPEngine_initModule(PyObject *module)
{
  if (PyType_Ready(&THPEngineType) < 0)
    return false;
  Py_INCREF(&THPEngineType);
  PyModule_AddObject(module, "_ImperativeEngine", (PyObject *)&THPEngineType);
  return true;
}