Commit Graph

35 Commits

Author SHA1 Message Date
Adam Paszke
1f13453b4d Slightly relax the constraints on argument and return types to script functions (#9969)
Summary:
This lays out initial support for taking and returning a richer set
of types than only tensors. Floats and ints are already valid; lists are
straightforward to add, and tuples need some discussion.
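
For illustration, a minimal sketch of the kind of signature this direction enables (hypothetical example, assuming a current `torch.jit.script`; the exact set of supported types is what this PR incrementally expands):

```
import torch

@torch.jit.script
def scale_and_repeat(x, factor: float, repeats: int):
    # Non-tensor arguments (a float and an int) alongside a tensor input.
    y = x * factor
    for i in range(repeats):
        y = y + x
    return y

print(scale_and_repeat(torch.ones(3), 0.5, 2))
```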

Based on top of #9948. Review only the last commit.

zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9969

Reviewed By: zdevito

Differential Revision: D9076973

Pulled By: apaszke

fbshipit-source-id: 5a1fe912ea6b79ab2bfd0dcce265eb05855b5ff0
2018-07-31 14:25:29 -07:00
Elias Ellison
e57cb4a1b2 Add a Constant Propagation Pass to the JIT (#8808)
Summary:
Adding a constant propagation pass to the JIT. I have added examples to the expect files.

There are a couple of special cases which have not been implemented here. IF nodes with constant conditions can be inlined with the correct block. WHILE nodes can be removed if the condition is false. I have added a test for each case in the test_jit.py file as an expected failure.

To be consistent with DCE, python ops & CPP ops are treated as not having side-effects.
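
For illustration, a minimal sketch of the kind of expression such a pass can fold (hypothetical example, not taken from the expect files in this PR):

```
import torch

@torch.jit.script
def fn(x):
    # 2 + 3 involves only constants, so a constant propagation pass can
    # replace it with a single prim::Constant instead of computing it at runtime.
    c = 2 + 3
    return x + c

# Inspecting the graph shows whether the arithmetic has been folded away.
print(fn.graph)
```
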
Pull Request resolved: https://github.com/pytorch/pytorch/pull/8808

Reviewed By: wanchaol

Differential Revision: D8906770

Pulled By: eellison

fbshipit-source-id: 10ad796d89f80b843566c9ddad6a0abd1f3dc74c
2018-07-30 15:54:31 -07:00
Peter Goldsborough
04939a4745 Match parameter names and = default (#9737)
Summary:
More clang tidy cleanups in `torch/csrc`. This time:

1. `hicpp-use-equals-default` recommends `= default` instead of `{}` for constructors/destructors. This is better practice because it expresses the intent more clearly (https://stackoverflow.com/questions/6502828/what-does-default-mean-after-a-class-function-declaration)
2. `readability-inconsistent-declaration-parameter-name` enforces that parameter names in the declaration match parameter names in the definition. This is just generally useful and can prevent confusion and bugs.

Also updated my script a little bit.

apaszke ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9737

Differential Revision: D9069069

Pulled By: goldsborough

fbshipit-source-id: f7b3f3a4eb4c9fadc30425a153566d3b613a41ae
2018-07-30 14:10:00 -07:00
Adam Paszke
8cb1eef7b9 Unify IR operator representation (stop using attributes in the JIT) (#9807)
Summary:
Based on top of #9763 (first 3 commits belong to that PR). The first commits from this PR are "Stop using attributes ..."

I tried to separate the changes into fairly meaningful commits. I can't split them up into smaller PRs, because everything starts working and all tests pass only after the whole sequence, but hopefully this will make reviewing somewhat easier.

Known issues/regressions/future tasks:
- `aten::lerp` and `aten::clamp` are no longer fusable
- `CreateAutodiffSubgraphs` needs a rewrite
  - It is much more strict now, and will miss a lot of opportunities, especially when viewing ops are involved. Our previous approach was "ignore the assumption on shape availability in gradient formulas to determine differentiability, and hope that shape prop will be robust enough to actually deliver them before we differentiate", which obviously doesn't scale well to more complex cases. We should either work on reducing the size dependency of grad formulas (feasible e.g. for `view`/`reshape`, unfeasible for `squeeze`/`unsqueeze`), or make `CreateAutodiffSubgraphs` integrate some kind of "I could integrate this node into an AD subgraph, but will I be able to infer the shape of its input" reasoning (kind of like a limited shape prop, that doesn't infer anything, and only tells if it *could* infer something).
  - It sometimes creates constant-only (or constants + one node) graphs, which is useless
- Broken `aten::add` in auto-batching, because it gained a non-tensor input. I changed the test for pointwise operations to use `aten::mul` instead, but I needed to disable the LSTM cell test. I'm not sure how scalar constants should be implemented in this case, because I don't fully understand our format. cc: ChunliF
- Graph import does some hacks to recover the types of constants. This code should be removed once we gain the ability to export the IR along with value types.
- There's still a fair amount of dead code that can be removed. I didn't want to make this diff any bigger, and removing it is an easy task.
- Graph fuser could be improved to use signature matching (possibly using `OperatorSet`) instead of basing on node kinds.
- Manual constant propagation for the `ListConstruct` node in `torch/onnx/utils.py` should be replaced with a proper constant propagation pass (or we should ensure that the one we have handles at least this case before we remove this code).
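
For illustration, a rough sketch of how the unified representation is visible from Python (an assumption; the exact graph text depends on the PyTorch build): scalar arguments now arrive as `prim::Constant` inputs to a node rather than as node attributes such as `aten::add[alpha={1}]`.

```
import torch

@torch.jit.script
def add_one(x):
    return x + 1  # the scalar 1 should appear as an input, not an attribute

# Expect a prim::Constant feeding aten::add in the printed graph,
# rather than an attributed node like aten::add[alpha={1}](%x).
print(add_one.graph)
```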

zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9807

Reviewed By: ezyang

Differential Revision: D9004285

Pulled By: apaszke

fbshipit-source-id: fe88026a765f6b687354add034c86402362508b7
2018-07-26 22:11:50 -07:00
Adam Paszke
e39c8043dc Make GraphExecutors work on Stacks instead of variable_tensor_lists (#9763)
Summary:
This is blocking the IR operator unification, because I need to be able to pass scalars to backward functions.

zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9763

Reviewed By: zou3519

Differential Revision: D8978457

Pulled By: apaszke

fbshipit-source-id: 570b4c3409322459cb0f2592069730a7d586ab20
2018-07-26 12:00:27 -07:00
James Reed
0b16b03b98 Plumb type annotations through script compilation (new) (#9547)
Summary:
Supersedes https://github.com/pytorch/pytorch/pull/9405
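
For context, a minimal sketch of the kind of annotation being plumbed through (hedged example using the MyPy-style type comment that `torch.jit.script` accepts):

```
import torch

@torch.jit.script
def scale(x, factor):
    # type: (Tensor, float) -> Tensor
    return x * factor

print(scale(torch.ones(2), 3.0))
```
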
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9547

Reviewed By: zdevito

Differential Revision: D8900327

Pulled By: jamesr66a

fbshipit-source-id: a00a94615af4fbaec98ee3ede0cb54bcfd9108dd
2018-07-25 17:10:14 -07:00
Peter Goldsborough
f62bc01dfe Remove TORCH_ASSERT (#9575)
Summary:
I got some tensor->variable conversion exceptions from `torch/csrc/autograd/variable.h`, which used the `TORCH_ASSERTM` macros instead of `AT_CHECK`, so they didn't have backtraces. This was such a substantial loss for debugability that I decided to update the whole codebase to use the backtrace-enabled ATen macros instead of `TORCH_ASSERT` and `JIT_ASSERT`, the latter having been an alias of the former.

ezyang apaszke zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9575

Differential Revision: D8924566

Pulled By: goldsborough

fbshipit-source-id: 7a4013b13eec9dbf024cef94cf49fca72f61d441
2018-07-24 18:10:06 -07:00
Adam Paszke
aa7af94656 Make JIT tracing a thread-local property (#9414)
Summary:
As in the title. Lets us simplify a lot of code.

Depends on #9363, so please review only the last commit.

zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9414

Reviewed By: zdevito

Differential Revision: D8836496

Pulled By: apaszke

fbshipit-source-id: 9b3c3d1f001a9dc522f8478abc005b6b86cfa3e3
2018-07-19 19:09:39 -07:00
Zachary DeVito
9ed2190bdb Add a tagged union type that replaces tensor in the interpreter. (#9368)
Summary:
IValue is short for interpreter value. It is used frequently so a short name is important.
This will allow us to implement more non-tensor types in an efficient way and remove
many hacks from the compiler.

This PR is limited. It only introduces IValue and changes interpreter to use it.
Follow up PRs will:
* Change the way aten_ops consume non-tensor types so that integer lists
  are no longer represented as Tensors.
* Introduce TensorList as a fundamental type and remove all vararg handling in gen_jit_dispatch
* Change the compiler to implement math on primitive numbers rather than converting to tensors.

jamesr66a  apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9368

Reviewed By: ezyang

Differential Revision: D8817598

Pulled By: zdevito

fbshipit-source-id: 29dce80611ce5f6384234de9d12a67861d2b112f
2018-07-16 15:40:22 -07:00
Mary McBreen
483ae8cb5d Replaces const ref with && for apply (#9175)
Summary:
Addresses https://github.com/pytorch/pytorch/issues/5011
Tested with python test/test_autograd.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9175

Reviewed By: zdevito

Differential Revision: D8736377

Pulled By: marymcbreen

fbshipit-source-id: ff86f427f7b2cf0cab5912e7f32812bd0f49a712
2018-07-12 08:31:59 -07:00
Adam Paszke
b9f575fc33 Remove legacy code from the JIT (#9323)
Summary:
In particular, get rid of backward tracing and CppOp.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9323

Reviewed By: ezyang

Differential Revision: D8795935

Pulled By: apaszke

fbshipit-source-id: fb7a7eeee41902da35f2a8efd77262ca60fd6bbe
2018-07-11 10:25:38 -07:00
Zachary DeVito
f74207c99f
Allow autograd to work even when the shape of values cannot be determined (#8641)
This commit implements the solution proposed in https://github.com/pytorch/pytorch/issues/8410
to work around the need to create zero tensors with the same shape as inputs.
It introduces the concept of a LinearBlock, which marks places in the code
where we know that if all the inputs to the node are zero, then the outputs
of the node are also zero. Autodiff introduces LinearBlocks around
backward functions, which have this property. specializeUndef then
propagates Undef nodes using this information.

Notes:
* Since we do not always specialize, we have a pass LowerLinearBlocks
that replaces the block with an if statement that dynamically guards
the Undef case.
* We introduce AutogradAdd, an addition op that still works when
its inputs might be undefined. In cases where we specialize, this will
get removed in favor of a normal add, but there are cases where
gradient graphs do not specialize (e.g. when they are not differentiable,
but a derivative is required), so it is important for this op to be executable.
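
For illustration, a Python-level sketch of the situation these passes handle (illustrative only; LinearBlock, AutogradAdd, and specializeUndef are internal to the JIT):

```
import torch

def f(x):
    return x * 2, x * 3

fn = torch.jit.trace(f, (torch.randn(3, requires_grad=True),))

x = torch.randn(3, requires_grad=True)
a, b = fn(x)
# Only `a` contributes to the loss, so the incoming gradient for `b` is
# undefined; the backward graph must still execute without materializing
# a zero tensor of b's shape for it.
a.sum().backward()
print(x.grad)
```
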
2018-06-25 18:40:04 -07:00
Richard Zou
8489c4cc6e
Better support for literals in jit script (#8687)
Addresses #8177

A design doc can be found here: [gist](https://gist.github.com/zou3519/4b7f13f03cc9f3612bd9363e6405fa0a) version or [quip](https://fb.quip.com/azL1AqUckBdo) version

General approach:
- Add NumberType, FloatType, IntType to represent Python numbers, floats and ints.
- Emit these types for python literals
- Change aten_schema such that Scalars are NumberType, int64_t and bool are IntType.
- Emit aten::type_as, prim::NumToTensor, and prim::TensorToNum nodes for tensor-number math. (see examples below)
- Erase NumberType,  prim::NumToTensor, and prim::TensorToNum for ONNX export

### Tensor/number math
```
import torch
@torch.jit.script
def fn(x):
    return x + 1
```
```
graph(%x : Dynamic) {
  %1 : int = prim::Constant[value={1}]()
  %2 : Dynamic = prim::NumToTensor(%1)
  %3 : Dynamic = aten::type_as(%2, %x)
  %4 : Dynamic = aten::add[alpha={1}](%x, %3)
  return (%4);
}
```

### Number/Number Math
```
import torch
@torch.jit.script
def fn(zero):
    c = 1 + 1
    return zero + c
```
```
graph(%zero : Dynamic) {
  %1 : int = prim::Constant[value={1}]()
  %2 : int = prim::Constant[value={1}]()
  %3 : Dynamic = prim::num_to_tensor(%1)
  %4 : Dynamic = prim::num_to_tensor(%2)
  %5 : Dynamic = aten::add[alpha={1}](%3, %4)
  %c : int = prim::TensorToNum(%5)  # this is the result of the addition
  ...
  return (%13);
}
```

List of squashed commits:

* Introduce Python Number types

Added: IntType, FloatType, NumberType with
IntType <: NumberType
FloatType <: NumberType

Changed aten_schema so arguments have corresponding types

* Emit a NumberType for python literals.

Also emit a NumberType for Scalar default values.

* Add prim::NumToTensor and prim::TensorToNum

* Add DynamicType -> NumberType implicit cast for bc

* Better ensureTensor error message

* Add ensureTensorOrNumber. Allow passing Number to some functions

Like the range() construct and slices

* Patch IntList to work.

IntList is still a DynamicType in the frontend: a tensor gets built from
a List[int].

Also, IntList[1] is a "union between int and IntList" the way it is
implemented. If the frontend sees an int being passed for an IntList[1]
arg, it converts it to a tensor as well.

* Enforce some order on schemas to avoid overload ambiguity

add(Tensor, Tensor) should appear earlier than add(Tensor, Scalar). This
matches the order in which python_arg_parser parses its arguments.

* Disable std_dim and var_dim tests.

With the new schema information, std(input, keepdim) and std(input, dim)
are ambiguous. This will need to be fixed at a later date.

* Add NumberType erasure pass.

This is used for ONNX export and to ensure that NumberType information
doesn't reach the interpreter

* Add support for mixed tensor/number math ops.

* Tests for new functionality.

Includes:
- Tensor/number math
- number/number math
- EraseNumberTypes pass test

* Patch tests

Update expect tests for:
- decompose_addmm
- loop unrolling tests

Because python numbers are now NumberType, they cannot be returned by
functions anymore. Work around this by using "torch.full", or by adding
a tensor([0]) (taken from FIXME_zerol()). Both approaches are used
because torch.full is more readable, but it is broken in some cases.

* Add erase_number_types to torch/CMakeLists.txt

* Move math back to emitSimpleExpr from emitSugaredExpr

* Remove some dead lines

* Re-enable some excluded script/trace tests that are now fixed.

* Move some tests to expected failure

* Address some comments (more addressing to come)

* Erase relevant aten::type_as nodes in EraseNumberTypes

I also changed it so that EraseNumberTypes is only called for ONNX
export. It is no longer used to prevent
prim::NumToTensor/prim::TensorToNum from reaching shape_analysis or
interpreter.cpp.

shape_analysis infers the type of the output of these nodes to be the
same as their input.

intepreter.cpp treats both of these nodes as no-ops.

* Add reminder to fix std/var

* Call EraseNumberTypes only when exporting a script module

* Update expects after rebase
2018-06-21 15:43:38 -04:00
ngimel
8d674c0d51 add comparison operators to jit (#8058)
* add comparison operators to jit

* try to fix CI

* address review comments

* fix type of comparison ops result

* address review comments

* fix indentation

* add comments

* require type_as to have non-dynamic tensor arg

* Typo (should check if template argument of type_as, inputs()[1], is tensor)

* Use .at() instead of []

* Use .at() again
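
For illustration, a minimal sketch of the script-side effect (hypothetical example, assuming a current `torch.jit.script`):

```
import torch

@torch.jit.script
def mask_above(x, threshold: float):
    # Comparison operators such as > now lower to JIT ops (e.g. aten::gt).
    return x * (x > threshold).float()

print(mask_above(torch.randn(4), 0.0))
```
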
2018-06-14 09:30:25 -04:00
Adam Paszke
f45a3d5558
Add a loop unrolling pass to PyTorch JIT (#7672) 2018-06-06 09:36:12 +02:00
Adam Paszke
9232afeffa
Add code for TensorBoard visualization of JIT GraphExecutors (#8050) 2018-06-02 20:55:25 +02:00
James Reed
1f94a6eab3 [JIT] Fission and fusion passes for addmm (#7938)
* Addmm decomposition pass

* Addmm peephole pass

* Fix handling of output shape in fusion pass

* Add DCE to the peephole passes

* add comments

* maybe bugfix?

* Fix GPU tests

* fix py2/3 test issue
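
For context, a sketch of the algebraic identity the decomposition ("fission") side relies on (illustrative; not the pass itself):

```
import torch

bias = torch.randn(3, 4)
m1 = torch.randn(3, 5)
m2 = torch.randn(5, 4)

fused = torch.addmm(bias, m1, m2)    # a single addmm
split = bias + torch.mm(m1, m2)      # decomposed into mm followed by add
print(torch.allclose(fused, split, atol=1e-6))
```
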
2018-05-30 18:06:58 -04:00
Adam Paszke
b45f2ff1ae
Remove CompiledFunction + clean up JIT tests (#7421) 2018-05-16 20:03:04 +02:00
Zachary DeVito
38bc732b2d
[jit] Change interpreter/fuser to work on Variables only (#7489)
* this removes the flag controlling whether the interpreter works on variables.
* now the interpreter _always_ works on variables
* constants in the IR are still _always_ non-variables, and an assert was added to ensure this.
* as_tensor was split into as_variable and as_tensor since it is sometimes used
  to construct constants in the IR
* I tried changing the IR to also always use variables, but that change was much more
  cross-cutting and fragile, and I never got it working
2018-05-11 13:33:47 -07:00
Zachary DeVito
93eb50c103
Mark expand nodes as implicit/explicit in trace (#7303)
When tracing we record expand nodes. This is useful in some cases because
it makes it clear a broadcast happened. However, in future runs
the broadcast may be different or not needed. This change adds an
attribute to expand to track if it was implicitly added. This
takes the form of an unused input to expand with a default value.

The execution engine then removes implicit expands before execution.
Note that shape_analysis will re-add expands when it can prove by
shape analysis that they will exist; this is useful for the fuser,
so this change should not affect fusion passes.
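
For illustration, a hedged sketch of where such expand nodes come from (the exact trace output depends on the PyTorch version):

```
import torch

def f(x, y):
    return x + y  # y is broadcast against x

traced = torch.jit.trace(f, (torch.randn(3, 4), torch.randn(4)))
# Depending on the version, the graph may contain an aten::expand that was
# inserted implicitly for the broadcast rather than written by the user.
print(traced.graph)
```
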
2018-05-10 10:47:43 -07:00
Peter Goldsborough
7b09bc72a5
[WIP] Enable WERROR in tests (#6539)
* Enable WERROR in tests

* Also set WERROR=1 for cpp_build in CI

* Enable Werror after the compiler checks

* Remove -DWERROR because its picked up from the env var

* Had to fix some errors in aten/contrib/data

* Allow an uninitialized variable in ReduceOpsKernel.cpp

* Use CUDNN_DATA_UINT8 in cuDNN type string conversion

* Fixes and use target_compile_options

* Fix uninitialized variables in THNN

* Include Python.h earlier in tensor_types.cpp

* Use CUDNN_VERSION 7100 instead of 7000?

* More Python.h includes

* Make switch case in common_subexpression_elimination.cpp exhaustive

* Build with WERROR=0 just to see all the warnings

* Remove some Python includes

* Enable WERROR=1 again

* Bring back switch case default
2018-04-28 01:51:16 +01:00
Zachary DeVito
b2581c0289 Workaround in onnx to get transposes into init_nets (#6924)
* Workaround in onnx to get transposes into init_nets

This adds a pass to ONNX so that it can speculate Transpose
operators so that ONNX's split pass can put them into an init_net

Also fixes a potential bug in onnx peephole where an optimization
across blocks might move a Value and violate scoping.

* Perform shape propagation when embedding a program into a trace.

This ensures the trace still has type information specific to that trace, which will help onnx export succeed in more cases.
2018-04-26 11:04:17 -04:00
Zachary DeVito
f656301526 Allow traces to call @script functions (#6642)
This adds the ability to trace script functions while preserving their
control flow. When the trace encounters a script function it inlines
the graph of the function into the trace rather than tracing the
function itself.
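
For illustration, a brief sketch of the pattern this enables (hedged example, assuming the current tracing API):

```
import torch

@torch.jit.script
def flip_negative(x):
    # Data-dependent control flow survives because the script graph,
    # including this if, is inlined into the trace.
    if bool(x.sum() > 0):
        return x
    else:
        return -x

def wrapper(x):
    return flip_negative(x) * 2

traced = torch.jit.trace(wrapper, torch.randn(3))
print(traced.graph)
```
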
2018-04-17 15:19:16 -04:00
Zachary DeVito
8995ddda05
[jit][script] Check that each builtin returns the right number of values. (#6492)
* Fixes to the way script handles multiple values, and other minor fixes.

This commit improves our handling of operators that return multiple values.
Builtins are now checked so that they return the right number of values,
and support for TupleValue is extended to all things that can return
multiple values.

This resolves issues where the compiler accepted things like:

  a, b = c + c

This would cause the interpreter to crash. Now each operator knows
how many results it will produce and can check it against the number
of values requested.
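
For illustration, a hedged sketch of the check from the user's point of view (hypothetical snippet; the exact error type and message depend on the version):

```
import torch

def bad(x):
    a, b = x + x  # a single-output op cannot be unpacked into two values
    return a

try:
    torch.jit.script(bad)
except Exception as e:
    # The compiler is expected to reject this instead of letting the
    # interpreter crash at runtime.
    print("rejected at compile time:", e)
```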

Notes:
* Allow True/False literals in constant expressions
* make handling of keyword constants more consistent to support True/False
* make parsing constants match the way we construct constants from python
* improve the error messages when accessing bad graph attributes.
* switch findTensorOp to return an optional.
* check that attribute types are correct in findTensorOp
* Check the correct number of outputs for builtins

This also changes emitExpr to return a single SugaredValue

Rather than possibly returning multiple values, emitExpr now
always returns a single value, which _might_ be a tuple. This approach
more closely follows python making the code easier to follow.

Checks for returning the right number of values are now located in
the assignment operator, and occur when unpacking the tuple.

We still pass `n_binders` to function calls so that calls into python
know how many values they should return.
2018-04-12 10:32:49 -07:00
Zachary DeVito
5ab30eedf3
Add __constants__ to Script modules (#6092)
Like `__slots__` the `__constants__` property changes the set/getattr behavior of a script module for the keys listed so they behave as constants.
This enables script methods to use them in ways that are otherwise not allowed.

* Python numbers/bools can be inlined as constants in script code.
* List of numbers can be iterated over using for loops
* nn.ModuleLists can be used in for loops as well, unrolling their content.
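
For illustration, a small sketch of the usage (hedged; the module and attribute names are made up, and the ScriptModule/script_method API is assumed):

```
import torch

class Repeat(torch.jit.ScriptModule):
    __constants__ = ['depth']

    def __init__(self, depth):
        super().__init__()
        self.depth = depth  # treated as a constant by the script compiler

    @torch.jit.script_method
    def forward(self, x):
        # Because self.depth is a constant, this loop can be unrolled.
        for i in range(self.depth):
            x = x + 1
        return x

print(Repeat(3)(torch.zeros(2)))
```
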
2018-04-05 11:31:43 -07:00
Edward Z. Yang
acc409396b
Namespaced symbols (#5820)
* Namespaced symbols

- Our interned strings now have structure, "ns::symname" rather than just
  "symname" before.  We support efficient namespace testing for uniques
  by encoding the namespace in one byte in the Symbol internal representation.
  See torch/csrc/jit/interned_strings.h for a more in-depth implementation
  discussion.

- All uses of ksymbol are now attr::symbol (or some appropriate namespace).
  The valid namespaces are prim, attr, onnx and aten.

- Symbol is bound in Python as a qualified string "attr::symbol", EXCEPT for the
  attribute setting/getting API, whose symbols must always be attr
  symbols; they get special cased to assume strings are passed.
  There's a little bit of naughtiness in the implementation, maybe you know
  how to solve it.

- However, the g.op() convenience function assumes that you're generating
  ONNX operators, unless you explicitly qualify.

- All ATen operators and nodes have built-in interned strings generated
  for them, so you should never have to write a string literal ever again.
  The tracing code is adjusted to use it.

- ONNX exporter now properly tests to see that all operators are in
  onnx namespace before accepting the export.  This is way more
  robust than the previous exporter, which would be willing to
  export capitalized operators which were not actually ONNX operators.

- A slight organizational change for symbolic.py; this module now ONLY
  contains aten operators.  In particular, the exporter for Constant
  has moved into utils.py (along with Undefined, from the C++ side),
  since primitive ops get "special treatment."

- The un-inplacing logic in recording is more robust, so that we don't
  delete a trailing underscore from __and__.  This never affected us
  before because we didn't have any tests for it.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2018-03-16 13:36:11 -04:00
Luca Antiga
396637cdd6 Python-free build of autograd + jit (#5356)
This PR adds the ability to build the C++ parts of autograd and the JIT with no dependency on Python.
The goal is to allow taking a PyTorch IR representation (a tree s-expr) and running it with provided inputs.

Prerequisite: build PyTorch so that codegen runs once.
Instructions:

    cd tools/cpp_build
    bash build_all.sh

This will build libtorchjit and torchjit_test in tools/cpp_build/build/torchjit-build. The latter basically runs the code in test_jit.cpp for now.

While writing the PR, it turned out that a few of the Python.h includes were redundant. They were removed here (PyTorch tests still pass on my machine; we'll see CI).

* Introduce Python-free builds of autograd and jit

* Remove NO_PYTHON ifdef in functions/special
2018-03-08 15:13:10 -05:00
James Reed
55c64e5243 Add Python function calls to JIT script (#5445)
* Add Python function calls to script
* Script compiler gains a `Resolver` object that runs when it does not understand a function call. This decouples the python resolution from the conversion to IR.
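
For illustration, a sketch of what the resolution looks like from the user side (hedged example; with the current API the call may be compiled recursively rather than resolved back into Python):

```
import torch

def python_helper(x):
    # An ordinary Python function; the script compiler resolves this call
    # instead of failing on an unknown name.
    return x * 2

@torch.jit.script
def fn(x):
    return python_helper(x) + 1

print(fn(torch.ones(3)))
```
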
2018-02-28 19:45:04 -08:00
Zachary DeVito
8904616028
add control flow to interpreter (#5293)
* Use stacks in the interpreter/aten_dispatch

Rather than have separate input/output lists,
the interpreter now works using a single stack.
Operators in the interpreter push/pop from the stack.
This allows ownership of tensors to transfer directly to an operator,
and an operator can drop the reference to a tensors as soon as it is
no longer needed. This is important for the GraphExecutor op,
which recursively runs the interpreter.

Once autograd is updated to pass variables to Function by value,
we will be able to ensure that we release ownership as soon as possible.

This commit also switches the interpreter to use a fake
tensor 'ContainerTensor' rather than at::Retainable to hold non-tensor
data in the interpreter. This allows us to use std::vector<at::Tensor>
for all registers, which is significantly less confusing than the
OwnedRetainables struct it was replacing.

* Add If and Loop to interpreter

* Preprocess loop to calculate where references to tensor should be dropped
* Add control instructions JumpZ/JumpNZ/Jump
* Switch from explicitly having stage structs to having a single list
  of instructions with Store/Load instructions to take values off the
  initial stack
* Make the interpreter tests executable rather than use expect files
* add a flag to interpreter code so that constants are variables
  if the interpreter is running on variables.

* Add tensor_as to its own file
2018-02-22 19:56:15 -08:00
Peter Goldsborough
702a7f3864 Improve Function interface (#5221)
* Improve Function interface

* Undo tracer changes

* Fix bug in VariableType.set_history

* Rename function_counter and sequence_number to sequence_nr

* Clarify Function documentation

* Replace swap_next_edges with next_edges() getter

* Bring back set_gradient_edge

* Simplify special.cpp

* add_gradient_edge -> create_gradient_edge

* Add mutable getters for pre/post hooks

* Use make_variable with Edge

* Remove remove_gradient_edge in favor of detach_

* Fix documentation and remove create_gradient_edge friend method

* Canonicalize some includes
2018-02-21 16:37:52 -05:00
Adam Paszke
8910dd5a81
Fix GraphExecutor and add more AD formulas (#5215) 2018-02-14 16:59:48 +01:00
Peter Goldsborough
2d5fbe6e0d Improve Variable interface (#5127)
* Improve Variable interface

* Address comments from @apaszke and @colesbury

* string ::operator= is not noexcept

* Remove ir.h from tracer_state.h to improve build times

* Make Variable a struct and pack SavedVariable fields

* Implement as_variable_ref

* grad_fn_ptr() -> grad_fn_unsafe()

* Reduce hackiness of set_type hack

* Include variable.h and edge.h in tracer_state.h because it uses them

* class Variable -> struct Variable because Windows can't even

* Make Variable::output_nr uint32_t instead of int

* Add comment about tracing state

* Replaced more static_cast<Variable&> and improve docs

* Remove SavedVariable destructor and construct members in init list

* Clarify docs for Variable

* Variable::set_version -> set_version_counter
2018-02-12 23:26:26 -05:00
Peter Goldsborough
25e946bf78 Replace edge_type with Edge and create Variable::gradient_edge() (#5030) 2018-02-07 10:50:42 -08:00
Zachary DeVito
b044c95129 Use blocks machinery to simplify bookkeeping in autodiff (#5036)
* Remove addValues and use WithInsertPoint

* Use blocks to simplify differentiate

Using @ezyang's suggestion, this change uses a block rather than
staging annotations to represent the reverse pass. This allows us
to reuse the machinery to copy graphs/blocks to extract the
reverse pass concisely.

This also changes the input order of Gradient's df to:
   [output vjps][temporary vjps][captures]

In addition to being simpler to generate in this order, it also
will allow ExecutionPlan to append the captures onto the already-
existing input list of vjps that are given by the autograd,
rather than have to prepend them, which should be slightly cheaper.

* Enforce that input captures are before outputs

This changes the Gradient struct to enforce that input
captures appear before output captures in the capture list,
which makes it easier to use in ExecutionPlan.
2018-02-05 10:43:50 -05:00
Zachary DeVito
c308e03f3e
Initial GraphExecutor Implementation. (#4982)
This adds the initial implementation of graph executor for the new JIT design. It includes a few python tests ensuring that nograd, backward, and double-backward cases work for simple examples and some corner cases. More work needs to be done to optimize performance, as there are many extra copies and places where we hold onto variables longer than we should. These are noted in the comments.
2018-02-02 17:45:59 -08:00