Commit Graph

45 Commits

Author SHA1 Message Date
Adam Paszke
9662cffd26 Use std::list in JIT IR 2017-09-05 17:48:55 -04:00
Adam Paszke
3dcbba1f35 Keep Variable mapping as part of TracingState 2017-09-05 17:48:55 -04:00
Adam Paszke
6be47ec907 Minor fixes and improvements 2017-09-05 17:48:55 -04:00
Adam Paszke
ea05ac8f41 Move JIT-related files to jit dir. Remove IR interpreter 2017-09-05 17:48:55 -04:00
Zach DeVito
1325fa511c JIT IR including use-def chains and updated comments. 2017-09-05 17:48:55 -04:00
Zach DeVito
7c083b00f8 refcounting for Node/Value 2017-09-05 17:48:55 -04:00
Zach DeVito
f369f8e80d simplify IR 2017-09-05 17:48:55 -04:00
Edward Z. Yang
4979359800 Add graphs, trace them.
It is not an /expression/ we trace, but it is a /graph/: that is,
a closed expression which knows its parameters.  Knowing the list
of parameters is helpful and helps remove a hack when interpreting.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-09-05 17:48:55 -04:00
Edward Z. Yang
c1dec0663f New stratification: add Operator/Instruction
This prevents nested lets, which are not allowed in ANF.  We
basically have SSA now.

There's some niftiness with the visitor returning a lambda which
then gets fed the actual argument. I like it.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-09-05 17:48:55 -04:00
Edward Z. Yang
7bd4c5a27c Minor sanity check.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-09-05 17:48:55 -04:00
Edward Z. Yang
3055b69f63 Refactor Arg class away.
Although ANF style developments traditionally stratifies syntactic
classes into atomic (Arg) and complex (Expr) expressions, where
atomic expressions could be variables, constants or lambdas, Zach has
successfully convinced me that we should do away with the variant here and
always require arguments to be variables.  There are a few reasons for
this:

1) Tensor constants, not currently supported, could be modeled using a
"Constant" instruction, removing the need for them to be representable
directly inline.  An inline constant is marginally more convenient
for peephole optimizations, but since we have gone full ANF, we are going
to need to be able to see across def-uses in any case, and it is not
too much worse to need to handle constants this way.  By the way,
Swift Intermediate Language also made a similar choice, see
the slide on "Literal Instructions" in
http://llvm.org/devmtg/2015-10/slides/GroffLattner-SILHighLevelIR.pdf

2) Scalar constants, which are quite important for passing non-tensor
arguments to Python operators, are now stored out-of-band as NON
first-class values.  This more closely matches the ToffeeIR design,
and makes it clear what parameters are "first class" (tensors only)
and which ones are not.  However, we need to be able to unswizzle
the separate scalar/tensor lists into a unified list in the correct
format; this is what PyFunctionCConv is for.

Also, Locals got renamed into Tuple.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-09-05 17:48:55 -04:00
Edward Z. Yang
c466b2c1f6 Make an error message better
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-09-05 17:48:55 -04:00
Edward Z. Yang
c4ccae6a89 Document move semantics on PyObject with THPObjectPtr&& constructor.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-09-05 17:48:55 -04:00
Edward Z. Yang
a797ab9343 Rewrite AST to a new, more functional representation.
Previously, our AST was a DAG, where shared Nodes indicated a computation
should be reused.  This commit rewrites the IR into a new functional
representation which represents sharing explicitly using variable
bindings.

We offer a few justifications for this new style:

1. The new representation is not all that different from the
old one; it is about as easy to construct, and the lack of an
explicit graph doesn't negatively impact our ability to interpret
the graph, since we've chosen, as a matter of design, to NOT have
the IR participate in the actual execution of a graph.

2. The new let-binding representation has an implicit ordering,
which we can use to conveniently keep track of the original order
the trace showed up as.  This automatically gives us a topsort,
and gives us an easier to read textual representation of our
IR:

  %14 = Embedding %11, %0, -1, None, 2, False, False
  %15 = Dropout %14, 0.2, True, False
  %16 = Index %12, 0
  %17 = Index %12, 1
  %18 = Index %13, 0
  %19 = Index %13, 1
  %20 = Index %15, 0
  %21 = Linear %20, %1, %3
  %22 = Linear %16, %2, %4

3. It moves us closer to a Futhark style language
(http://futhark-lang.org/publications/pldi17.pdf).

Major aspects of the diff

- Node is replaced with Expr and Arg, a pair of mutually recursive
  structures which represent our new language.  In BNF, the language
  looks like this:

    a ::= c | %i
    e ::= %i, ... = e
        | PyOp e, ...
        | Ret %i, ...

  Technically, Ret is not actually a return (no control flow is involved),
  it just tuples up a series of tensors (identified by variables).

  One important invariant is that locals are always tensors; they
  are never constants (this is asymmetric with Args.)

- Arguments support Python constants.  This is an important piece because
  many operators take extra Python literals like integers and tuples in
  order to specify extra parameters about how an operator operates.  Adding
  this was essential to getting word_language_model to work.

- As both Expr and Arg have multiple variants, there is new infrastructure
  for doing case on the variants using ExprVisitor and ArgVisitor.  The
  strategy here is adapted from WebAssembly's visitors, although we have
  generalized to permit arbitrary argument forwarding, which is necessary
  to support tail-recursive visitor calls.  TCO is important because our
  interpreter may recurse arbitrarily deep into a stack of nested lets.
  If users wish, they can also manually case on the type tag.

- Tracing is now turned on and off using _tracer_enter/_tracer_exit in
  torch._C.  _tracer_enter accepts a list of variables which are to be
  treated as arguments; _tracer_exit accepts the list of traced variables
  which should be returned when you reexecute the trace, and returns
  the trace expression which can be reexecuted.  GlobalTracingState
  is a global variable which tracks whether or not we are tracing or not.

- You use run_forward to execute a trace on some set of parameters.

- When under tracing, variables keep track, via trace_local, what the
  name of their variables in the IR are.

Here is a simple runner which leaks memory but can be used to JIT models:

  import torch.autograd.function as F
  import torch._C

  def jit(model):
      import types
      real_forward = model.forward
      def forward(self, *args):
          def flatten(x):
              return tuple(F._iter_variables(x))
          if not hasattr(self, "saved_trace"):
              torch._C._tracer_enter(tuple(self.parameters()) + flatten(args))
              out = real_forward(*args)
              self.saved_trace = torch._C._tracer_exit(flatten(out))
              self.saved_outs = out
              return out
          else:
              flat_out = Variable._execution_engine.run_forward(self.saved_trace, tuple(self.parameters()) + flatten(args))
              return F._unflatten(flat_out, self.saved_outs)

Major problems:

- Sanity checking is spotty at best, especially when users pass in variables.

- The interpreter leaks tensor memory from the store.  When we add back def-use
  we should be able to deallocate tensors as soon as we know they are no longer
  necessary.

- The interpreter needs to reach feature parity with the old execution engine.
  From there, we need to see if backwards can be subsumed as well.

- I still have no confidence in having memory managed everything correctly.
  This requires a close look.

- Rather than return an *open* expression as a trace, we should return a
  *lambda* instead, which knows about how many formal parameters it
  requires.

- The IR is not introspectable from Python at the moment, but this is simply a
  matter of implementing all the binding code.

- The tracer is NOT reentrant (you can't trace while you're inside a trace.)
  Furthermore, no sanity checking is done if you try to incorrectly reuse
  things from one trace in another.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-09-05 17:48:55 -04:00
Edward Z. Yang
6f9774d7db Minor bugfix.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-09-05 17:48:55 -04:00
Edward Z. Yang
11107190ca Handle legacy correctly.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-09-05 17:48:55 -04:00
Edward Z. Yang
1e8bf12b3a Add an inefficient but working evaluator for forward traces.
Simple test:

  import torch
  from torch.autograd import Variable
  import torch._C as _C

  x = Variable(torch.Tensor([4]), requires_grad=True)
  y = Variable(torch.Tensor([7]), requires_grad=True)
  z = x * y
  z.sum().backward()

  print(x.grad)
  print(y.grad)

  x.data[0] = 2
  y.data[0] = 3

  (z,) = z._execution_engine.run_forward((x, y), (z,))
  z.sum().backward()

  print(x.grad)
  print(y.grad)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-09-05 17:48:55 -04:00
Edward Z. Yang
50b375d9bf Add input nodes to the IR representation.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-09-05 17:48:55 -04:00
Edward Z. Yang
e1b7872fc2 Make it possible to access IR from Python.
Also, add a new trace_fn field to attach forward IR to Variables.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-09-05 17:48:55 -04:00
Edward Z. Yang
b1b20e4097 Remove dead field from UnpackedInput
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-09-05 17:48:55 -04:00
Zachary DeVito
43c944acbd Remove dead THPP code that has been replaced with ATen objects. (#2235)
THPP usage is now isolated in THD.
2017-07-29 08:07:41 +05:30
Trevor Killeen
c304d04fc6 Replace thpp::Tensor with ATen Tensor in autograd csrc (#2170) 2017-07-28 10:18:37 -04:00
gchanan
925208af72 Implement BatchNorm double backwards (#2207)
* Implement BatchNorm double backwards as a python function called directly from C++.

This will be converted to C++ code once ATen is integrated with autograd.

* Some performance improvements via inplace ops and reusing calculations.
2017-07-27 06:00:31 +05:30
albanD
626840aef3 C function wrapper uniqueness (#1912)
* add SharedFunctionMaker to create Function shared in the graph

* Clean shared_ptr usage for only function that will be used in the graph

* make Function binding match Varible one

* remove unnecessary changes

* fix comments

* proper weakref implementation

* add call to clear in dealloc
2017-07-25 13:12:54 +05:30
Sam Gross
eba3dc8561 Fix gc_refs assertion failure (#1705)
* Fix gc_refs assertion failure

Ensure that each THPVariable -> THPFunction reference contributes one
ref count to the THPFunction by creating a new shared_ptr for each ref.

Because multiple shared_ptrs can again manage a single THPFunction, it's
not safe to use std::weak_ptr where it may point to a PyFunction. It's
still safe to use weak_ptr for grad_accumulator since these are never
PyFunctions.

Fixes #1626

* Remove stale comment
2017-06-02 21:08:50 -04:00
Trevor Killeen
05bc877a05 make THPPointer have explicit constructors (#1636) 2017-05-25 15:35:54 -04:00
Sam Gross
036c3f93af Check for released variables in SavedVariable::unpack() (#1648)
Fixes #1288
2017-05-25 00:35:19 -04:00
Edward Z. Yang
1f3ff5ced2 Miscellaneous documentation around autograd. (#1577)
* Miscellaneous documentation around autograd.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-05-17 19:19:24 -04:00
Adam Paszke
35cf380ed1 Improve output wrapping logic in autograd 2017-05-10 16:43:14 +02:00
Adam Paszke
e9d648c5e7 Fix memory leak introduced by 72e8190 (#1464) 2017-05-03 18:38:56 -04:00
Adam Paszke
72e8190994 Use at most one shared_ptr block at a time to manage THPFunctions (#1454)
* Fix failing ln in build_all.sh

* Use at most one shared_ptr block at a time to manage THPFunctions
2017-05-03 08:15:36 -04:00
Adam Paszke
5c7453447f Fix bugs, rename differentiate to grad, make it more flexible 2017-05-01 16:44:56 -04:00
Adam Paszke
20aa5b066f Convert some of the functions to new format
Also, fix a lot of issues that appeared after the previous commits.
2017-05-01 16:44:56 -04:00
Adam Paszke
de9998e198 Add support for the new Function format 2017-05-01 16:44:56 -04:00
Adam Paszke
702a2e3bc5 Make Variables not subclass Function anymore
Because of this Variables can no longer appear in the graph.
Every usage of a leaf Variable will leave an AccumulateGrad
function that has no outputs, but modifies var.grad as a side
effect.
2017-05-01 16:44:56 -04:00
Adam Paszke
2ca787fcf4 Refactor attribute names in autograd 2017-05-01 16:44:56 -04:00
albanD
a7ae04a657 fix precedence problem when building with debug python (#1201) 2017-04-06 10:30:16 -04:00
Sam Gross
5073132837 Implement 'pre' and 'post' hooks at the C++ autograd level 2017-03-06 12:47:53 -08:00
Sam Gross
34ce58c909 Parallelize backwards 2017-03-03 11:26:00 -08:00
Martin Raison
f17cfe4293 sparse tensor operations (#735) 2017-03-03 18:37:03 +01:00
Adam Paszke
502ebed796 Fix one more reference cycle and ensure correct flag propagation (#868) 2017-02-27 18:38:29 -05:00
Adam Paszke
2b23712dc3 Improve autograd memory usage (#859) 2017-02-26 22:37:26 -05:00
Adam Paszke
31941918cf Prevent creation of reference cycles with leaf Variables that don't require grad
Also, raise an error immediately, if a leaf that requiers_grad is
modified in-place. Some comments were updated too.
2017-02-26 20:02:42 +01:00
Sam Gross
dd844f741b Fix previous_functions when it contains Variables 2017-02-17 11:03:46 +05:30
Sam Gross
bd5303010d Refactor autograd package to separate Python dependencies. (#662)
The core autograd Variable, Function, and Engine no longer depend on the
Python API. This let's us implement functions in C++. In the future, we
can also multithread engine and release the GIL for most of the
non-Python backwards.
2017-02-13 16:00:16 -08:00