mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Previously, our AST was a DAG, where shared Nodes indicated that a computation should be reused. This commit rewrites the IR into a new functional representation which represents sharing explicitly using variable bindings.

We offer a few justifications for this new style:

1. The new representation is not all that different from the old one; it is about as easy to construct, and the lack of an explicit graph doesn't negatively impact our ability to interpret the graph, since we've chosen, as a matter of design, NOT to have the IR participate in the actual execution of a graph.

2. The new let-binding representation has an implicit ordering, which we can use to conveniently keep track of the order in which operations originally appeared in the trace. This automatically gives us a topological sort, and gives us an easier-to-read textual representation of our IR:

    %14 = Embedding %11, %0, -1, None, 2, False, False
    %15 = Dropout %14, 0.2, True, False
    %16 = Index %12, 0
    %17 = Index %12, 1
    %18 = Index %13, 0
    %19 = Index %13, 1
    %20 = Index %15, 0
    %21 = Linear %20, %1, %3
    %22 = Linear %16, %2, %4

3. It moves us closer to a Futhark-style language (http://futhark-lang.org/publications/pldi17.pdf).

Major aspects of the diff:

- Node is replaced with Expr and Arg, a pair of mutually recursive structures which represent our new language. In BNF, the language looks like this:

    a ::= c | %i
    e ::= %i, ... = e
        | PyOp e, ...
        | Ret %i, ...

  Technically, Ret is not actually a return (no control flow is involved); it just tuples up a series of tensors (identified by variables). One important invariant is that locals are always tensors; they are never constants (this is asymmetric with Args).

- Arguments support Python constants. This is an important piece because many operators take extra Python literals, such as integers and tuples, to specify extra parameters about how the operator operates. Adding this was essential to getting word_language_model to work.
- As both Expr and Arg have multiple variants, there is new infrastructure for doing case analysis on the variants using ExprVisitor and ArgVisitor. The strategy here is adapted from WebAssembly's visitors, although we have generalized it to permit arbitrary argument forwarding, which is necessary to support tail-recursive visitor calls. TCO is important because our interpreter may recurse arbitrarily deep into a stack of nested lets. If users wish, they can also manually case on the type tag.

- Tracing is now turned on and off using _tracer_enter/_tracer_exit in torch._C. _tracer_enter accepts a list of variables which are to be treated as arguments; _tracer_exit accepts the list of traced variables which should be returned when you re-execute the trace, and returns the trace expression, which can be re-executed. GlobalTracingState is a global variable which tracks whether or not we are tracing.

- You use run_forward to execute a trace on some set of parameters.

- When under tracing, variables keep track, via trace_local, of the name of their variable in the IR.

Here is a simple runner which leaks memory but can be used to JIT models:

    import types

    import torch._C
    import torch.autograd.function as F
    from torch.autograd import Variable

    def jit(model):
        real_forward = model.forward

        def forward(self, *args):
            def flatten(x):
                return tuple(F._iter_variables(x))
            if not hasattr(self, "saved_trace"):
                # First call: trace the real forward, treating the parameters
                # and the flattened inputs as the trace's arguments.
                torch._C._tracer_enter(tuple(self.parameters()) + flatten(args))
                out = real_forward(*args)
                self.saved_trace = torch._C._tracer_exit(flatten(out))
                self.saved_outs = out
                return out
            else:
                # Later calls: re-execute the saved trace, then restore the
                # nesting structure of the original outputs.
                flat_out = Variable._execution_engine.run_forward(
                    self.saved_trace, tuple(self.parameters()) + flatten(args))
                return F._unflatten(flat_out, self.saved_outs)

        # Install the wrapper as a bound method so later calls go through it.
        model.forward = types.MethodType(forward, model)
        return model

Major problems:

- Sanity checking is spotty at best, especially when users pass in variables.

- The interpreter leaks tensor memory from the store. When we add back def-use, we should be able to deallocate tensors as soon as we know they are no longer necessary.
- The interpreter needs to reach feature parity with the old execution engine. From there, we need to see if backwards can be subsumed as well.

- I am still not confident that memory is managed correctly everywhere. This requires a close look.

- Rather than return an *open* expression as a trace, we should return a *lambda* instead, which knows how many formal parameters it requires.

- The IR is not introspectable from Python at the moment, but this is simply a matter of implementing all the binding code.

- The tracer is NOT reentrant (you can't trace while you're inside a trace). Furthermore, no sanity checking is done if you try to incorrectly reuse things from one trace in another.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
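To make the let-binding representation concrete, here is a minimal C++ sketch (not the code from this commit) of an IR in the shape described above, plus an interpreter over it. It assumes doubles stand in for tensors and simplifies each binding to a single named operator applied to arguments; `Let`, `Ret`, `mkLet`, `mkRet`, `interpret`, `print`, and the op table are hypothetical names. It illustrates two claims from the message: the binding order is already a topological order, so printing is a straight walk, and an interpreter can step through an arbitrarily deep chain of lets with a loop instead of recursion.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical sketch: doubles stand in for tensors.
// An Arg is either a constant or a reference to a local %i.
struct Local { int unique; };
using Arg = std::variant<double, Local>;

struct Expr;
using ExprPtr = std::shared_ptr<Expr>;

// %b, ... = op args...  followed by the rest of the program (body).
struct Let {
  std::vector<int> binders;
  std::string op;
  std::vector<Arg> args;
  ExprPtr body;
};
// Ret %i, ...  tuples up locals; it is not control flow.
struct Ret { std::vector<int> locals; };
struct Expr { std::variant<Let, Ret> node; };

ExprPtr mkLet(std::vector<int> binders, std::string op, std::vector<Arg> args, ExprPtr body) {
  return std::make_shared<Expr>(
      Expr{Let{std::move(binders), std::move(op), std::move(args), std::move(body)}});
}
ExprPtr mkRet(std::vector<int> locals) {
  return std::make_shared<Expr>(Expr{Ret{std::move(locals)}});
}

using Env = std::unordered_map<int, double>;
// Hypothetical op table: each op maps argument values to one value per binder.
using OpFn = std::function<std::vector<double>(const std::vector<double>&)>;
using OpTable = std::map<std::string, OpFn>;

double lookup(const Arg& a, const Env& env) {
  if (const double* c = std::get_if<double>(&a)) return *c;
  return env.at(std::get<Local>(a).unique);
}

// Walks the let chain with a loop rather than recursion, so an arbitrarily
// deep nest of lets does not grow the C++ call stack.
std::vector<double> interpret(ExprPtr e, Env env, const OpTable& ops) {
  while (const Let* let = std::get_if<Let>(&e->node)) {
    std::vector<double> argv;
    for (const Arg& a : let->args) argv.push_back(lookup(a, env));
    std::vector<double> outs = ops.at(let->op)(argv);
    assert(outs.size() == let->binders.size());
    for (std::size_t i = 0; i < outs.size(); ++i) env[let->binders[i]] = outs[i];
    e = let->body;  // step to the rest of the program
  }
  std::vector<double> result;
  for (int id : std::get<Ret>(e->node).locals) result.push_back(env.at(id));
  return result;
}

std::string showArg(const Arg& a) {
  if (const double* c = std::get_if<double>(&a)) {
    std::ostringstream os;
    os << *c;
    return os.str();
  }
  return "%" + std::to_string(std::get<Local>(a).unique);
}

// Prints one binding per line; the let order is already a topological order.
std::string print(ExprPtr e) {
  std::string out;
  while (const Let* let = std::get_if<Let>(&e->node)) {
    for (std::size_t i = 0; i < let->binders.size(); ++i)
      out += (i ? ", %" : "%") + std::to_string(let->binders[i]);
    out += " = " + let->op;
    for (std::size_t i = 0; i < let->args.size(); ++i)
      out += (i ? ", " : " ") + showArg(let->args[i]);
    out += "\n";
    e = let->body;
  }
  const std::vector<int>& locals = std::get<Ret>(e->node).locals;
  out += "Ret";
  for (std::size_t i = 0; i < locals.size(); ++i)
    out += (i ? ", %" : " %") + std::to_string(locals[i]);
  return out + "\n";
}
```

For example, binding `%1 = Mul %0, 2` and `%2 = Add %1, %0` with `%0 = 3` and returning `%2` yields 9. Taking the environment by value mirrors the `interpret(std::shared_ptr<Expr>, environment)` declaration in the header; like the interpreter described in the commit, this sketch never frees intermediate values.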
83 lines
2.8 KiB
C++
#pragma once

#include <Python.h>
#include <vector>
#include <utility>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>

#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/utils/object_ptr.h"

// (class, gpu id, sizes)
using output_info_type = std::tuple<PyObject *, int, std::vector<int64_t>>;

namespace torch { namespace autograd {

// A Function which is implemented by a Python object (i.e., a THPFunction).
// Calls to 'apply' are forwarded to the Python method implementation.
struct PyFunction : public Function {
  PyFunction(PyObject* obj) : obj(obj) {}

  virtual variable_list apply(const variable_list& inputs) override;
  variable_list legacy_apply(const variable_list& inputs);

  virtual void releaseVariables() override;
  virtual std::string name() override;

  // THPFunction this Function is wrapping.
  PyObject* obj;
};

using environment = std::unordered_map<int, std::shared_ptr<Variable>>;
using output_list = std::vector<Local>;
variable_list interpret(std::shared_ptr<Expr>, environment);

}} // namespace torch::autograd

struct THPFunction {
  PyObject_HEAD

  PyObject *needs_input_grad;

  // Python tuple of tensors whose variables we should save. Set
  // by Python with 'save_for_backward'. If NULL, no tensors were
  // saved.
  PyObject *to_save;
  // Python pairs of distinct tensors which share storage. Set by
  // Python with 'mark_shared_storage'. If NULL, no tensors share
  // storage.
  PyObject *shared_pairs;
  // Python tuple of tensors which are not differentiable. Set by
  // Python with 'mark_non_differentiable'. If NULL, no tensors were
  // non-differentiable.
  PyObject *non_differentiable;
  // Python tuple of tensors which had in-place updates in the forward()
  // pass. Set by Python with 'mark_dirty'. If NULL, no tensors were
  // modified in-place.
  PyObject *dirty_tensors;

  std::vector<output_info_type> *output_info;
  std::vector<torch::autograd::SavedVariable> *saved_variables;
  // For each input, true if the input is a THPVariable
  std::vector<bool> *is_variable_input;
  char has_freed_buffers;

  // The C++ wrapper for this Python function.
  // See a comment in THPFunction_asFunction for details about this field.
  torch::autograd::PyFunction cdata;
};

bool THPFunction_initModule(PyObject *module);
extern PyTypeObject THPFunctionType;
extern PyObject *THPFunctionClass;
extern PyObject *THPStochasticFunctionClass;
extern PyObject *THPBatchNormBackwardBackwardFunction; // Temporarily here until we move it to C++

// XXX: this function requires the GIL (it can have side effects).
std::shared_ptr<torch::autograd::PyFunction> THPFunction_asFunction(THPFunction* self);

inline bool THPFunction_Check(PyObject* obj) {
  return PyObject_IsInstance(obj, (PyObject*)&THPFunctionType);
}
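The header above declares `interpret(std::shared_ptr<Expr>, environment)`, where `environment` maps local ids to variables, and the commit message describes casing on Expr variants through an ExprVisitor that forwards extra arguments so that recursive visits can sit in tail position. The following is a minimal sketch of that visitor pattern, not the actual classes from this commit: ints stand in for `Variable`s, and `Let` (which here only adds a constant to a local), `Ret`, `ExprVisitor`, `Eval`, and `Environment` are hypothetical.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <unordered_map>
#include <utility>

// Hypothetical IR node base with an explicit kind tag, so callers can also
// switch on `kind` manually instead of using a visitor.
struct Expr {
  enum class Kind { Let, Ret };
  explicit Expr(Kind kind) : kind(kind) {}
  virtual ~Expr() = default;
  Kind kind;
};

// %binder = Add %src, addend; body   (ints stand in for tensors)
struct Let : Expr {
  Let(int binder, int src, int addend, std::shared_ptr<Expr> body)
      : Expr(Kind::Let), binder(binder), src(src), addend(addend), body(std::move(body)) {}
  int binder, src, addend;
  std::shared_ptr<Expr> body;
};

// Ret %local
struct Ret : Expr {
  explicit Ret(int local) : Expr(Kind::Ret), local(local) {}
  int local;
};

// CRTP visitor: switches on the tag and forwards any extra arguments,
// unchanged, to the matching visitX method of the subclass.
template <typename SubType, typename ReturnType>
struct ExprVisitor {
  template <typename... Args>
  ReturnType visit(const std::shared_ptr<Expr>& e, Args&&... args) {
    auto* self = static_cast<SubType*>(this);
    switch (e->kind) {
      case Expr::Kind::Let:
        return self->visitLet(static_cast<const Let*>(e.get()), std::forward<Args>(args)...);
      case Expr::Kind::Ret:
        return self->visitRet(static_cast<const Ret*>(e.get()), std::forward<Args>(args)...);
    }
    throw std::logic_error("unknown Expr kind");
  }
};

using Environment = std::unordered_map<int, int>;

// Evaluator: the environment is the forwarded extra argument.
struct Eval : ExprVisitor<Eval, int> {
  int visitLet(const Let* e, Environment& env) {
    env[e->binder] = env.at(e->src) + e->addend;
    // Tail position: eligible for tail-call elimination, though C++ does not guarantee it.
    return visit(e->body, env);
  }
  int visitRet(const Ret* e, Environment& env) {
    return env.at(e->local);
  }
};
```

Evaluating `%1 = Add %0, 10; %2 = Add %1, 1; Ret %2` with `%0 = 5` returns 16. The environment reaches each `visitX` method only because `visit` forwards it; since C++ does not guarantee tail-call elimination, whether the call in `visitLet` becomes a jump depends on the compiler and optimization level.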