Commit Graph

62 Commits

Author SHA1 Message Date
vishwakftw
86c64440c9 Make PyTorch Python 3.8 compatible (#29302)
Summary:
PEP 590 modifies the `tp_print` offset to `tp_vectorcall_offset` - which requires a Py_ssize_t object.
Passing a nullptr caused compatibility issues for Python 3.8.

Changelog:
- Modify all occurrences of `nullptr  /* tp_print */` to 0  /* tp_vectorcall_offset */
- Minor formatting changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29302

Test Plan:
- Local fresh build with Python 3.8 completed successfully.

Fixes https://github.com/pytorch/pytorch/issues/28060.
Fixes https://github.com/pytorch/pytorch/issues/29162.

Supersedes https://github.com/pytorch/pytorch/pull/28364

Differential Revision: D18372022

Pulled By: ezyang

fbshipit-source-id: 8e9a15b0d0f72101ccc69bd489f5efa216b880bb
2019-11-07 09:20:19 -08:00
Pritam Damania
e8e7d93293 Additional autograd unit tests for Python UDFs. (#29041)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29041

1) Enhanced autograd unit tests to test the
torch.distributed.autograd.backward() API more thoroughly on Python UDFs.
2) Enhanced `python_error` to override `what` such that it returns an
appropriate error string if we call `what()` on this error. This ensures we can
propagate exceptions over the wire during RPCs (since we get the error string
by calling what() on the exception)
ghstack-source-id: 93098679
ghstack-source-id: 93098679

Test Plan: waitforbuildbot

Reviewed By: mrshenli

Differential Revision: D18273041

fbshipit-source-id: 85d3932fed6337668a812367fdfce233c1b3ff8e
2019-11-01 18:30:09 -07:00
Edward Yang
08860721ad Revert D18195584: Additional autograd unit tests for Python UDFs.
Test Plan: revert-hammer

Differential Revision:
D18195584

Original commit changeset: b795daf644ba

fbshipit-source-id: 413dac34f1a28e0a591893f43e116f006fd3f2be
2019-11-01 06:59:54 -07:00
Pritam Damania
3bba751cd6 Additional autograd unit tests for Python UDFs. (#28824)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28824

1) Enhanced autograd unit tests to test the
torch.distributed.autograd.backward() API more thoroughly on Python UDFs.
2) Enhanced `python_error` to override `what` such that it returns an
appropriate error string if we call `what()` on this error. This ensures we can
propagate exceptions over the wire during RPCs (since we get the error string
by calling what() on the exception)
ghstack-source-id: 92972494

Test Plan: waitforbuildbot

Differential Revision: D18195584

fbshipit-source-id: b795daf644ba1816fdec484545192ab55a2f71e7
2019-10-31 14:03:00 -07:00
Pritam Damania
1322daa506 Improve error handling for distributed autograd engine. (#27940)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27940

1) If we receive an error for outstanding rpcs, we enqueue an appropriate error
on the local autograd engine.
2) Add an `exit_on_error` mode for the local autograd engine, where the
computation stops if we see an error.
ghstack-source-id: 92603377

Test Plan: Added unit tests to test failures.

Differential Revision: D17916844

fbshipit-source-id: 199a7832f1033c36a9bbcc1e80d86576c04965d0
2019-10-25 12:07:27 -07:00
Richard Zou
caed485873 Turn on BUILD_NAMEDTENSOR permanently (#26060)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26060

This PR enables BUILD_NAMEDTENSOR by default. This is done via including
a header, `c10/core/EnableNamedTensor`, that sets `BUILD_NAMEDTENSOR`.
In the future, the plan is to get rid of the flag entirely: we can
incrementally delete usages after this PR goes in.

This PR also maintains the namedtensor ci vs regular ci distinction.
`test/test_namedtensor.py` only runs if TEST_NAMEDTENSOR=1 is specified.
TEST_NAMEDTENSOR=1 is set on the namedtensor ci. I'll remove this
distinction later and send out an announcement about it; devs will be
responsible for named tensor failures after that.

The initial reason why we had the BUILD_NAMEDTENSOR flag was so that we
could quickly prototype named tensor features without worrying about
adding overhead to the framework. The overheads can be categorized as
memory overhead and performance overhead.

Memory overhead: named tensors adds 1 additional word per Tensor. This
is because TensorImpl stores a `unique_ptr<NamedTensorMetaInterface>`
field. This is not a lot of overhead.

Performance overhead: At all entry points to name inference, we check
if inputs to an op are named. If inputs are not named, we short-circuit
and don't do name inference. These calls should therefore be as
efficient as error-checking code and not take up a lot of time.

My plan is to benchmark a few functions and then post the results in a
comment to this PR.

Test Plan: - [namedtensor ci]

Differential Revision: D17331635

Pulled By: zou3519

fbshipit-source-id: deed901347448ae2c26066c1fa432e3dc0cadb92
2019-09-17 08:25:00 -07:00
Ralf Gommers
1b4951d3a5 Fix remaining invalid function cast warnings that show up with GCC 8/9 (#26104)
Summary:
Follow-up to gh-25483, more of the same fixes for warnings like:

```
../torch/csrc/autograd/python_variable.cpp:503:31: warning: cast between incompatible function types from ‘PyObject* (*)(THPVariable*)’ {aka ‘_object* (*)(THPVariable*)’} to ‘getter’ {aka ‘_object* (*)(_object*, void*)’} [-Wcast-function-type]
  503 |   {"_backward_hooks", (getter)THPVariable_get_backwards_hooks, (setter)THPVariable_set_backwards_hooks, nullptr, nullptr},
      |                               ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
```

This takes the build log output for a full rebuild with GCC 9.1 from ~10,000 to ~7,000 lines.

`clang-tidy` is going to complain, no way around that - see discussion at the end of gh-25483.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26104

Differential Revision: D17396831

Pulled By: ezyang

fbshipit-source-id: d71696bfe4dbe25519e4bcb7753151c118bd39f7
2019-09-17 07:43:37 -07:00
Richard Zou
47cee2dd22 Implement initial version of autograd with named tensors (#25604)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25604

In this initial version:
- autograd ignores all names.
- tensor.grad is unnamed, unless the user manually assigns to it.
- if a grad tensor has any names, perhaps the user was hoping for some
alignment-checking behavior that named tensor offers for other ops. We
raise a warning in this case.

Future: do some more extensive checking to see if this actually works in
all cases.

Test Plan:
- [namedtensor ci]
- Check a warning is raised if a grad tensor has names.
- Check tensor.grad field is unnamed.
- Check that we can perform backward on an op that doesn't explictly
support names in backward. `sigmoid` is one such op.

Differential Revision: D17171788

Pulled By: zou3519

fbshipit-source-id: 64837fde94d8269610b6d3539ac025516dbe1df4
2019-09-04 06:36:54 -07:00
mal
e7a9b0d62f Rename torch::autograd::Function to torch::autograd::Node
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23269

Test Plan: Imported from OSS

Differential Revision: D16454878

fbshipit-source-id: b1e840fc2d3901955280d141e5ad6efd5e9d66af
2019-07-23 20:52:22 -07:00
Mikhail Zolotukhin
6ca38d9840 Cleanup includes in torch/csrc/autograd/* (#19923)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19923
ghimport-source-id: 54debdd21ca0f4230b1915905673de274807a2e5

Differential Revision: D15125016

Pulled By: ZolotukhinM

fbshipit-source-id: 8d54f436e4508067089a1d05ce192093220aa1bb
2019-05-06 13:48:42 -07:00
Edward Yang
517c7c9861 Canonicalize all includes in PyTorch. (#14849)
Summary:
Anywhere we used #include "foo.h", we now say #include <foo.h>
Paths are adjusted to be rooted out of aten/src, torch/lib, or
the root level directory.

I modified CMakeLists.txt by hand to remove TH and THC from
the include paths.

I used the following script to do the canonicalization:

```
  import subprocess
  import re
  import os.path

  files = subprocess.check_output(['git', 'ls-files']).decode('utf-8').rstrip().split('\n')
  for fn in files:
      if not any(fn.endswith(suff) for suff in ['.cu', '.cpp', '.in', '.h', '.hpp', '.cu', '.cuh', '.cc']):
          continue
      if not any(fn.startswith(pref) for pref in ["aten/", "torch/"]):
          continue
      with open(fn, 'r') as f:
          c = f.read()
      def fmt(p):
          return "#include <{}>".format(p)
      def repl(m):
          p = m.group(1)
          if p in ["dlfcn.h", "unistd.h", "nvrtc.h", "cuda.h", "cuda_runtime.h", "cstdint", "cudnn.h", "Python.h", "cusparse.h", "cuda_runtime_api.h", "cuda_fp16.h", "cublas_v2.h", "stdint.h", "curand_kernel.h"]:
              return fmt(p)
          if any(p.startswith(pref) for pref in ["torch/csrc", "c10/", "ATen/", "caffe2/", "TH/", "THC/", "Eigen/", "gtest/", "zdl/", "gloo/", "onnx/", "miopen/"]):
              return fmt(p)
          for root in ["aten/src", "torch/lib", ""]:
              for bad_root in [os.path.dirname(fn), "aten/src/TH", "aten/src/THC", "torch/csrc"]:
                  new_p = os.path.relpath(os.path.join(bad_root, p), root)
                  if not new_p.startswith("../") and (os.path.exists(os.path.join(root, new_p)) or os.path.exists(os.path.join(root, new_p + ".in"))):
                      return fmt(new_p)
          print("ERROR: ", fn, p)
          return m.group(0)
      new_c = re.sub(r'#include "([^"]+)"', repl, c)
      if new_c != c:
          print(fn)
          with open(fn, 'w') as f:
              f.write(new_c)
```

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14849

Reviewed By: dzhulgakov

Differential Revision: D13363445

Pulled By: ezyang

fbshipit-source-id: 52361f878a672785f9306c9e9ab2513128092b68
2018-12-08 19:38:30 -08:00
Peter Goldsborough
d6c53328f9 Large scale fix of python-related files in torch/csrc/
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14515

Differential Revision: D13247966

Pulled By: goldsborough

fbshipit-source-id: 7a127c508fc576a7a92626dd6b729f660162d628
2018-12-07 13:04:46 -08:00
albanD
78e3259bbe Add autograd automatic anomaly detection (#7677)
* add autograd automatic anomaly detection

* python 3 string support

* Fix non python build

* fix typo in doc

* better test and naming fix

* fix no python build and python object handling

* fix missing checks

* clean NO_PYTHON build

* Remove unwanted changes
2018-06-11 21:26:17 -04:00
Sam Gross
12229afd00
Record shape and type in autograd to validate gradients (#8168)
The check that the gradient is defined is currently disabled because
TestJit.test_ge_optimized will trigger the error.
2018-06-06 18:09:53 -04:00
Zachary DeVito
23dd033b51 Factor python dependency out of interpreter (#7970)
* Factor python dependency out of interpreter

* Remove NO_PYTHON for the autograd engine

If there is no python bindings, then a default Engine is constructed
the first time it is requested.

If the python libraries are loaded, then they override the default
accessor and the default engine becomes a python Engine.

Note: it is possible for two engines to be generated if a non-python
one gets created before the python bindings are loaded. This case
is rare, and just results in additional threads being spawned.

* Fixing AlexNet test which is skipped in CI
2018-06-01 16:07:21 -04:00
Peter Goldsborough
28b1a3852c
Add backward() to Tensor and Variable (#7774)
* Add backward() to Tensor and Variable

* Add at:: in front of Tensor

* Trying to not move optional to appease windows?

* Move implementation into cpp file

* Undo some formatting changes
2018-05-24 17:31:41 -07:00
Will Feng
60745b3380 Revert #7750 and #7762 to fix Windows CI on master (#7772)
* Revert "Add missing brace (#7762)"

This reverts commit ea27c5af50.

* Revert "[C++ API] Add backward() to Tensor and Variable  (#7750)"

This reverts commit 1e2762796f.
2018-05-22 15:42:52 -07:00
Peter Goldsborough
1e2762796f
[C++ API] Add backward() to Tensor and Variable (#7750)
* Add backward() to Tensor and Variable

* Added a couple tests
2018-05-22 10:43:04 -07:00
Tongzhou Wang
1c01eabd3c
Codemod to update our codebase to 0.4 standard (#6641)
* Codemod to update our codebase to 0.4 standard

* Update some of the test scri[ts

* remove Variable in test_clip_grad_value

* fix _symbolic_override_wrapper_maker
2018-04-17 22:06:54 -04:00
Tongzhou Wang
e01569afd7 Restore allow_unused functionality (#6553) 2018-04-12 21:30:42 +02:00
Priya Goyal
e3196e0ea8
[Re-checkpointing] Autograd container for trading compute for memory (#6467)
* Autograd container for trading compute for memory

* add a unit test for checkpoint

* address comments

* address review comments

* adding some docs for the checkpoint api

* more comments

* more comments

* repro bug

* Fix a subtle bug/apply some review comments

* Update checkpoint.py

* Run everything in grad mode

* fix flake and chunk=1

* use imperative backward as per discussion

* remove Variable and also add models and test for models

* Add a simple thread local variable to check for autograd grad mode

* remove models and models test after debugging

* address review comments

* address more comments

* address more comments
2018-04-10 15:26:24 -04:00
Edward Z. Yang
49f2bb7e0b
Extra comment about backward vs. grad in engine. (#6005)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2018-03-27 10:54:06 -04:00
Luca Antiga
396637cdd6 Python-free build of autograd + jit (#5356)
This PR adds the possibility to build the C++ parts of autograd and jit, with no dependency on Python.
The goal is to allow taking a PyTorch IR representation (a tree s-expr) and running it with provided inputs.

Prerequisite: build PyTorch so that codegen runs once.
Instructions:

cd tools/cpp_build
bash build_all.sh
This will build libtorchjit and torchjit_test in tools/cpp_build/build/torchjit-build. The latter basically runs the code in test_jit.cpp for now.

While writing the PR, it turned out that a few of Python.h includes were redundant. They were removed here (PyTorch tests still pass on my machine, we'll see CI).

* Introduce Python-free builds of autograd and jit

* Remove NO_PYTHON ifdef in functions/special
2018-03-08 15:13:10 -05:00
Peter Goldsborough
702a7f3864 Improve Function interface (#5221)
* Improve Function interface

* Undo tracer changes

* Fix bug in VariableType.set_history

* Rename function_counter and sequence_number to sequence_nr

* Clarify Function documentation

* Replace swap_next_edges with next_edges() getter

* Bring back set_gradient_edge

* Simplify special.cpp

* add_gradient_edge -> create_gradient_edge

* Add mutable getters for pre/post hooks

* Use make_variable with Edge

* Remove remove_gradient_edge in favor of detach_

* Fix documentation and remove create_gradient_edge friend method

* Canonicalize some includes
2018-02-21 16:37:52 -05:00
Peter Goldsborough
2d5fbe6e0d Improve Variable interface (#5127)
* Improve Variable interface

* Address comments from @apaszke and @colesbury

* string ::operator= is not noexcept

* Remove ir.h from tracer_state.h to improve build times

* Make Variable a struct and pack SavedVariable fields

* Implement as_variable_ref

* grad_fn_ptr() -> grad_fn_unsafe()

* Reduce hackiness of set_type hack

* Include variable.h and edge.h in tracer_state.h because it uses them

* class Variable -> struct Variable because Windows cant even

* Make Variable::output_nr uint32_t instead of int

* Add comment about tracing state

* Replaced more static_cast<Variable&> and improve docs

* Remove SavedVariable destructor and construct members in init list

* Clarify docs for Variable

* Variable::set_version -> set_version_counter
2018-02-12 23:26:26 -05:00
Peter Goldsborough
f38b6f611e Replace NULL with nullptr in autograd (#5162) 2018-02-12 12:01:52 -08:00
Peter Goldsborough
25e946bf78 Replace edge_type with Edge and create Variable::gradient_edge() (#5030) 2018-02-07 10:50:42 -08:00
Adam Paszke
79d15c52cb
Improve the engine support for functional graph execution (#4690)
Previously the side-effect free grad calculation was performed
using callbacks that could also override the decision to run a
function. However this had a few problems e.g. it forced us to iterate
over pretty much all functions in the graph and drop their buffers.

This patch improves the mechanism, by adding explicit support for this
kind of evaluation in execute(). It's safer, and the algorithm used to
decide which nodes have to be evaluated was replaced with a faster one.
2018-01-18 11:20:30 +01:00
Sam Gross
d605058212
Replace Variable.volatile with torch.no_grad() (#3970)
This removes volatile from Variable. The functionality is mostly
replaced by a global (thread-local) flag, which is controlled by
torch.set_grad_enabled() and the context manager torch.no_grad().

In C++, the flag is exposed through GradMode::is_enabled() and GradMode::set_enabled()

Fixes #3627
2017-12-18 15:46:13 -05:00
Sam Gross
b79d74aa81 Re-initialize autograd engine in child processes (#4158)
* Re-initialize autograd engine in child processes

The autograd engine uses threads for backwards. These don't exist after
forks and they were not being re-initialized because the
Engine::start_threads_flag was already set. This re-initializes the
engine in child processes, which will cause it to re-create threads when
backwards() is called in the child process.

Note that we only attempt to handle the common case where fork() is
called while the backwards threads are idle.

Fixes #3966

* Avoid non-async-signal-safe functions in fork handler
2017-12-18 01:51:27 -05:00
Sam Gross
0a434ff685 Remove Function::is_executable (#3907)
* Remove Function::is_executable

Ensure that grad_fn is null if requires_grad is false.

* Assert that grad_fn implies requires_grad=True
2017-11-28 18:29:27 -08:00
Adam Paszke
3128218397 Allow specifying unused inputs to torch.autograd.grad (#2859) 2017-09-25 14:42:33 -04:00
Sam Gross
f4169260f8 Fix crash when calling backwards on leaf variable which does not require grad (#2788) 2017-09-20 09:43:20 -04:00
Sam Gross
1290e586fb Use at::Tensor based autograd Variable (#2676)
Variable is now a subclass of at::Tensor backed by a VariableImpl* pImpl. The implementation of the ATen functions is defined in the auto-generated VariableType.h/cpp file.

Currently, only functions which fall through to the base type, such as sizes() and isCuda() are implemented. Differentiable ops like add() and mul() will be added in a subsequent PR.
2017-09-12 11:36:01 -04:00
Adam Paszke
ea888c1905 Check input flags in Traceable 2017-09-06 21:35:50 -04:00
Adam Paszke
9f0c4c9f9a Make autograd engine reentrant without creating new threads 2017-09-05 17:48:55 -04:00
Adam Paszke
f83c4fad7b Fix exception propagation from recursive Engine calls 2017-09-05 17:48:55 -04:00
Adam Paszke
d8e2ab632e Add support for Constant nodes in AutogradClosureFactory 2017-09-05 17:48:55 -04:00
Adam Paszke
594f98ce16 Support multi-stage AutogradClosures 2017-09-05 17:48:55 -04:00
Adam Paszke
fa308b3183 Improve backward tracing 2017-09-05 17:48:55 -04:00
Adam Paszke
3e0f1608fe Capture Variables that are not inputs as constants 2017-09-05 17:48:55 -04:00
Adam Paszke
f270973937 Add JIT IR -> Autograd IR converter 2017-09-05 17:48:55 -04:00
Adam Paszke
ea05ac8f41 Move JIT-related files to jit dir. Remove IR interpreter 2017-09-05 17:48:55 -04:00
Zach DeVito
1325fa511c JIT IR including use-def chains and updated comments. 2017-09-05 17:48:55 -04:00
Zach DeVito
f369f8e80d simplify IR 2017-09-05 17:48:55 -04:00
Edward Z. Yang
4979359800 Add graphs, trace them.
It is not an /expression/ we trace, but it is a /graph/: that is,
a closed expression which knows its parameters.  Knowing the list
of parameters is helpful and helps remove a hack when interpreting.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-09-05 17:48:55 -04:00
Edward Z. Yang
a797ab9343 Rewrite AST to a new, more functional representation.
Previously, our AST was a DAG, where shared Nodes indicated a computation
should be reused.  This commit rewrites the IR into a new functional
representation which represents sharing explicitly using variable
bindings.

We offer a few justifications for this new style:

1. The new representation is not all that different from the
old one; it is about as easy to construct, and the lack of an
explicit graph doesn't negatively impact our ability to interpret
the graph, since we've chosen, as a matter of design, to NOT have
the IR participate in the actual execution of a graph.

2. The new let-binding representation has an implicit ordering,
which we can use to conveniently keep track of the original order
the trace showed up as.  This automatically gives us a topsort,
and gives us an easier to read textual representation of our
IR:

  %14 = Embedding %11, %0, -1, None, 2, False, False
  %15 = Dropout %14, 0.2, True, False
  %16 = Index %12, 0
  %17 = Index %12, 1
  %18 = Index %13, 0
  %19 = Index %13, 1
  %20 = Index %15, 0
  %21 = Linear %20, %1, %3
  %22 = Linear %16, %2, %4

3. It moves us closer to a Futhark style language
(http://futhark-lang.org/publications/pldi17.pdf).

Major aspects of the diff

- Node is replaced with Expr and Arg, a pair of mutually recursive
  structures which represent our new language.  In BNF, the language
  looks like this:

    a ::= c | %i
    e ::= %i, ... = e
        | PyOp e, ...
        | Ret %i, ...

  Technically, Ret is not actually a return (no control flow is involved),
  it just tuples up a series of tensors (identified by variables).

  One important invariant is that locals are always tensors; they
  are never constants (this is asymmetric with Args.)

- Arguments support Python constants.  This is an important piece because
  many operators take extra Python literals like integers and tuples in
  order to specify extra parameters about how an operator operates.  Adding
  this was essential to getting word_language_model to work.

- As both Expr and Arg have multiple variants, there is new infrastructure
  for doing case on the variants using ExprVisitor and ArgVisitor.  The
  strategy here is adapted from WebAssembly's visitors, although we have
  generalized to permit arbitrary argument forwarding, which is necessary
  to support tail-recursive visitor calls.  TCO is important because our
  interpreter may recurse arbitrarily deep into a stack of nested lets.
  If users wish, they can also manually case on the type tag.

- Tracing is now turned on and off using _tracer_enter/_tracer_exit in
  torch._C.  _tracer_enter accepts a list of variables which are to be
  treated as arguments; _tracer_exit accepts the list of traced variables
  which should be returned when you reexecute the trace, and returns
  the trace expression which can be reexecuted.  GlobalTracingState
  is a global variable which tracks whether or not we are tracing or not.

- You use run_forward to execute a trace on some set of parameters.

- When under tracing, variables keep track, via trace_local, what the
  name of their variables in the IR are.

Here is a simple runner which leaks memory but can be used to JIT models:

  import torch.autograd.function as F
  import torch._C

  def jit(model):
      import types
      real_forward = model.forward
      def forward(self, *args):
          def flatten(x):
              return tuple(F._iter_variables(x))
          if not hasattr(self, "saved_trace"):
              torch._C._tracer_enter(tuple(self.parameters()) + flatten(args))
              out = real_forward(*args)
              self.saved_trace = torch._C._tracer_exit(flatten(out))
              self.saved_outs = out
              return out
          else:
              flat_out = Variable._execution_engine.run_forward(self.saved_trace, tuple(self.parameters()) + flatten(args))
              return F._unflatten(flat_out, self.saved_outs)

Major problems:

- Sanity checking is spotty at best, especially when users pass in variables.

- The interpreter leaks tensor memory from the store.  When we add back def-use
  we should be able to deallocate tensors as soon as we know they are no longer
  necessary.

- The interpreter needs to reach feature parity with the old execution engine.
  From there, we need to see if backwards can be subsumed as well.

- I still have no confidence in having memory managed everything correctly.
  This requires a close look.

- Rather than return an *open* expression as a trace, we should return a
  *lambda* instead, which knows about how many formal parameters it
  requires.

- The IR is not introspectable from Python at the moment, but this is simply a
  matter of implementing all the binding code.

- The tracer is NOT reentrant (you can't trace while you're inside a trace.)
  Furthermore, no sanity checking is done if you try to incorrectly reuse
  things from one trace in another.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-09-05 17:48:55 -04:00
Edward Z. Yang
1e8bf12b3a Add an inefficient but working evaluator for forward traces.
Simple test:

  import torch
  from torch.autograd import Variable
  import torch._C as _C

  x = Variable(torch.Tensor([4]), requires_grad=True)
  y = Variable(torch.Tensor([7]), requires_grad=True)
  z = x * y
  z.sum().backward()

  print(x.grad)
  print(y.grad)

  x.data[0] = 2
  y.data[0] = 3

  (z,) = z._execution_engine.run_forward((x, y), (z,))
  z.sum().backward()

  print(x.grad)
  print(y.grad)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-09-05 17:48:55 -04:00
Sam Gross
42485d87c2 Set the current device in each engine's thread (#2081)
Fixes #2017
2017-07-13 16:24:38 -04:00
Adam Paszke
86a065e45b Add end callbacks to the engine 2017-06-12 21:58:38 -04:00