Commit Graph

259 Commits

Richard Zou
03f2ad9029 Add check for python build deps to setup.py (#5618)
* Add check for python build deps to setup.py

* Address comments

* Remove install_requires line
2018-03-09 23:49:18 -05:00
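
A minimal sketch of what such a check could look like (illustrative only; the helper and module names below are assumptions, not the actual setup.py code from this commit):

    import importlib.util
    import sys

    def check_build_deps(modules=("yaml", "setuptools")):
        # Illustrative helper: fail early with a readable message instead of a
        # confusing error later in the build. The real setup.py check may differ.
        missing = [m for m in modules if importlib.util.find_spec(m) is None]
        if missing:
            sys.exit("Missing Python build dependencies: " + ", ".join(missing))

    check_build_deps()
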
Peter Goldsborough
7391dae709 Fix Variable conversion on the way to/from Python (#5581)
* PyObject* <--> at::Tensor conversion no longer unwraps variables; instead we expect end users to always work with variable types, and we will only unwrap the variables when we optimize.
* Add torch::CPU, torch::CUDA and torch::getType
* at::CPU -> torch::CPU in extensions
2018-03-09 14:31:05 -08:00
Sam Gross
5dedc648bb Compile DataLoader.cpp separately (#5507)
Don't #include DataLoader.cpp in Module.cpp
2018-03-02 05:54:33 -05:00
Peter Goldsborough
b10fcca5f0 Install cuda headers in ATen build (#5474) 2018-02-28 19:36:41 -08:00
peterjc123
377d896969 better solution for the linking error related to lazy_init for MSVC (#5375)
* Revert "Fix wrong argument name (#5366)"

This reverts commit cc9d3b265d.

* Fix wrong argument naming

* Revert "Wrap torch::cuda::lazy_init with WITH_CUDA flag"

This reverts commit a8fa37f8fac5aef09eb7fe54d84de6126618c262.

* Revert "Solves the linking error related to lazy_init for MSVC"

This reverts commit 63913a102f274865a76e7c40ffdf6b40c277d5ff.

* better solution for the linking error related to lazy_init for MSVC

* Naming changes

* Namespace changes and further comment

* Rebasing onto current master

* Remove code that is useless

* Fix linting

* Remove rebasing bugs
2018-02-28 17:34:34 -05:00
Sam Gross
48a3349c29
Delete dead Tensor code paths (#5417)
This deletes most of the dead Tensor code paths, including the TensorMethods cwrap and generic/Tensor.cpp.

This also moves the THNN.cwrap/.cpp generation to generate_code which can use ninja if installed.
2018-02-27 17:58:09 -05:00
gchanan
d5038309a1
Remove WITH_SCALARS, as it's enabled by default now. (#5437) 2018-02-27 14:51:11 -05:00
Soumith Chintala
d2f71cbdeb
make CuDNN finders respect library major version (#5399) 2018-02-24 19:37:00 -05:00
Sam Gross
30ec06c140
Merge Variable and Tensor classes (#5225)
This replaces the torch.Tensor constructors with factories that produce
Variables. Similarly, functions on the torch module (e.g. torch.randn)
now return Variables.

To keep the PR to a reasonable size, I've left most of the unused tensor
code in place. Subsequent PRs will remove the dead code, clean up calls to
torch.autograd.Variable, and rename Variable to Tensor everywhere.

There are some breaking changes because Variables and Tensors had
slightly different semantics. There's a list of those changes here:

 https://github.com/pytorch/pytorch/wiki/Breaking-Changes-from-Variable-and-Tensor-merge
2018-02-23 18:03:31 -05:00
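
A small sketch of the user-visible effect, written against the post-merge API as it behaves today (not code from the PR itself):

    import torch

    x = torch.randn(3, requires_grad=True)   # the factory returns a Variable (now simply a Tensor)
    y = (x * 2).sum()
    y.backward()                             # autograd works without wrapping in torch.autograd.Variable
    print(x.grad)
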
peterjc123
6c587e9e67 Solves the linking error related to lazy_init for MSVC (#5368)
* Revert "Fix wrong argument name (#5366)"

This reverts commit cc9d3b265d.

* Solves the linking error related to lazy_init for MSVC

* Fix wrong argument naming

* Wrap torch::cuda::lazy_init with WITH_CUDA flag
2018-02-23 11:08:20 -05:00
Peter Goldsborough
008ba18c5b Improve CUDA extension support (#5324)
* Also pass torch includes to nvcc build

* Export ATen/cuda headers with install

* Refactor flags common to C++ and CUDA

* Improve tests for C++/CUDA extensions

* Export .cuh files under THC

* Refactor and clean cpp_extension.py slightly

* Include ATen in cuda extension test

* Clarifying comment in cuda_extension.cu

* Replace cuda_extension.cu with cuda_extension_kernel.cu in setup.py

* Copy compile args in C++ extension and add second kernel

* Conditionally add -std=c++11 to cuda_flags

* Also export cuDNN headers

* Add comment about deepcopy
2018-02-23 10:15:30 -05:00
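
For context, this is roughly how a CUDA extension is declared through torch.utils.cpp_extension after these changes (a sketch; the package and file names are illustrative):

    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    setup(
        name="cuda_extension",
        ext_modules=[
            # .cpp sources go to the C++ compiler and .cu sources to nvcc; the
            # torch/ATen/THC include paths and common flags are filled in by the helpers.
            CUDAExtension(
                name="cuda_extension",
                sources=["cuda_extension.cpp", "cuda_extension_kernel.cu"],
            ),
        ],
        cmdclass={"build_ext": BuildExtension},
    )
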
peterjc123
cc9d3b265d Fix wrong argument name (#5366) 2018-02-23 00:37:02 -05:00
peterjc123
013ed5b88f Add lazy_init.h into build for Windows and refactor code (#5365)
* Add lazy_init.h into build for Windows and refactor code

* Remove minor bugs
2018-02-23 00:05:43 -05:00
Soumith Chintala
9388d35293
prioritize cudnn library dir in library_dirs order (#5345) 2018-02-21 22:51:04 -05:00
gchanan
0878c6d4d7
Various dtype improvements. (#5321)
* Various dtype improvements.

1) Add dtypes to the new data-based constructors: Variable.new_tensor and torch.autograd.variable.
2) In the python signatures, use Type instead of Dtype to match the C++ signatures; the error messages still print as dtype.
3) Handle the case where a dtype is used that ATen was not compiled with (e.g. cuda types), and add a better error message for it.
4) Move cuda_lazy_init to its own file.

A later commit will add support to the legacy constructors as well.

* Move implementation of lazy_init to cpp.

* Fix parsed_arg size.
2018-02-21 17:37:59 -05:00
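
A short sketch of point 1 using today's spelling of the API (not code from the commit; Variable.new_tensor is now simply Tensor.new_tensor):

    import torch

    x = torch.ones(2, 3)                                 # float32 by default
    y = x.new_tensor([[1, 2, 3]], dtype=torch.float64)   # data-based constructor with an explicit dtype
    print(y.dtype)                                       # torch.float64 rather than x's type
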
Edward Z. Yang
031412a14b
setup.py and cmake improvements (#5269)
* Document env vars and properly propagate MAX_JOBS down.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Apply CFLAGS and LDFLAGS environment variables to cmake builds.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Test that running the built program works; fixes #5151.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* CMake CR.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2018-02-20 16:55:57 -05:00
gchanan
5edf6b2037
Add numpy-style dtypes to Variable factories. (#5245)
* Add numpy-style dtypes to Variable factories.

1) Add numpy-style dtypes corresponding to torch tensor types.  These are:
torch.float16, torch.float32, torch.float64, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64
as well as torch.cuda, torch.sparse, and torch.cuda.sparse equivalents.

2) Adds "legacy" names for the above dtypes that correspond more closely to existing tensor names.  These are:
torch.half, torch.float, torch.double, torch.short, torch.int, torch.long.
torch.byte and torch.char don't exist because they either don't match numpy semantics or differ on different architectures.

3) Adds a "dtype" parameter to Variable factories (e.g. zeros, ones) that allows the user to specify the type without changing the default tensor type.

4) Adds a "dtype" getter to Variables that returns the canonical dtype from 1).

This PR is missing the following useful features that should be added in the future:
A) We only add the "dtype" parameter to auto-generated factories; hand-written factories like in tensor_new.cpp don't support this yet.

B) We don't allow type conversions to use dtypes; that should be added to type(param) or a new function.

C) We don't yet have a "device" parameter for these factories; right now, they will only create Variables on the default device.

* backend_to_string can be private.

* Define python binding argument indexes in a more simple way.

* add all_declared_types, still need to hook it up to THPDType.

* Fix all_declared_types for missing types (it's Sparse + Half).

* Ensure cuda dtypes are created even if compiled with NO_CUDA=1.

* Fix case where dtype is provided but dispatch is via namespace.

This happens in ones_like, empty_like, randn_like.

There is some question if we should do:
1) at::ones_like(tensor).toType(dtype)
2) at::ones_like(tensor.toType(dtype))

I did the former because it matches the numpy documentation, i.e.
"Overrides the data type of the result.", and it's easier to implement.

Note that the above causes an extra copy, either of the input or output.
Here's a better implementation:
1) Make zeros_like, ones_like native functions that take an optional type (named dtype?).
2) Match the type argument with the dtype, so we don't have two different parameters.
3) Call at::zeros_like(input, type) -> at::native::zeros_like(input, type) -> type.zeros(input.sizes())

* Don't return from maybe_initialize_cuda.

* Don't leak DType name.

* Address cpp review comments.

* Share code between sparse and non-sparse test_dtypes.

* Rewrite _like functions as native function with explicit type parameter.

* Use type 'Type' instead of 'dtype' for consistency.

* Address review comments.

* Handle arg_idx when there is requires_grad but no dtype in python_binding_arguments.
2018-02-20 11:04:14 -05:00
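
A brief sketch of how these dtypes surface in the Python API (based on the behavior that shipped, not the PR's own test code):

    import torch

    a = torch.zeros(2, 3, dtype=torch.float64)    # dtype parameter on a Variable factory
    print(a.dtype)                                # dtype getter: torch.float64
    print(torch.long is torch.int64)              # legacy name aliases the canonical dtype: True
    b = torch.ones_like(a, dtype=torch.int32)     # _like factories accept an explicit type override
    print(b.dtype)                                # torch.int32
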
Adam Paszke
cb2fd39fdd
Add Python frontend to the JIT (#5190) 2018-02-15 22:53:19 +01:00
Peter Goldsborough
2d5fbe6e0d Improve Variable interface (#5127)
* Improve Variable interface

* Address comments from @apaszke and @colesbury

* string ::operator= is not noexcept

* Remove ir.h from tracer_state.h to improve build times

* Make Variable a struct and pack SavedVariable fields

* Implement as_variable_ref

* grad_fn_ptr() -> grad_fn_unsafe()

* Reduce hackiness of set_type hack

* Include variable.h and edge.h in tracer_state.h because it uses them

* class Variable -> struct Variable because Windows can't even

* Make Variable::output_nr uint32_t instead of int

* Add comment about tracing state

* Replaced more static_cast<Variable&> and improve docs

* Remove SavedVariable destructor and construct members in init list

* Clarify docs for Variable

* Variable::set_version -> set_version_counter
2018-02-12 23:26:26 -05:00
gchanan
4b8bf73729
Enable scalars. (#5158)
* Enable scalars.

* Avoid variable name shadowing in list comprehension, because it rebinds in python2, but not python3.
2018-02-09 15:45:41 -05:00
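
For reference, a sketch of how 0-dimensional tensors behave once scalars are enabled (this shows the resulting behavior, not code from the PR):

    import torch

    s = torch.randn(4).sum()    # a reduction now yields a 0-dimensional tensor
    print(s.dim(), s.shape)     # 0 torch.Size([])
    print(s.item())             # extract the Python number
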
bddppq
3e85613751 Experimental jit script (#5074) 2018-02-07 20:43:45 +01:00
Zachary DeVito
c308e03f3e
Initial GraphExecutor Implementation. (#4982)
This adds the initial implementation of the graph executor for the new JIT design. It includes a few Python tests ensuring that the no-grad, backward, and double-backward cases work for simple examples and some corner cases. More work needs to be done to optimize performance, as there are many extra copies and places where we hold onto variables longer than we should. These are noted in the comments.
2018-02-02 17:45:59 -08:00
Peter Goldsborough
1475895c1d Use distutils.copy_tree/copy_file instead of shutil 2018-02-01 16:19:03 -08:00
Peter Goldsborough
1262fba8e7 [cpp extensions] Create torch.h and update setup.py 2018-02-01 16:19:03 -08:00
Zach DeVito
2d829d15af [JIT] Add simple shape analysis
This quick and dirty shape analysis just makes up fake tensors,
and runs them through ATen to do shape propagation.
2018-01-28 22:55:36 -08:00
Edward Z. Yang
b8ab7bee26
Use variadic templates instead of initializer lists and overloads. (#4772)
Suppose you are given a list of arguments, each of which may be Tensor or
TensorList.  How can you write a function that can treat these arguments
uniformly as a list of tensors?  This patch solves the problem using
variadic templates.

Why variadic templates?  Use of variadic templates means anyone working
with this code has to understand universal references, perfect
forwarding, parameter packs and some idioms of C++ template design.
However, I argue that variadic templates are the *right* tool for
supporting the implementation of functions which must take an
arbitrarily heterogeneous set of inputs.  We were able to limp by
in old code because, for the most part, tensor inputs were homogeneous,
but this is no longer the case for some non-primitively differentiable
functions; and with the upcoming cuDNN RNN in ATen PR, will no longer be
the case for primitively differentiable functions too.

There are two parts to the PR.

First, we add torch/csrc/utils/variadic.h, which defines a mix-in
IterArgs that takes any class which supports operator(), and augments
with a new variadic function apply() which calls operator() on each
argument passed to it.  In an original draft of the patch, I wrote the
recursion for each parameter pack from scratch for each function;
however, it turns out there are no fewer than seven instances where we
need this idiom, and the mix-in reduces the lines of code, and also
helps centralize the most important (and easy to forget) boilerplate
for perfect forwarding.

To verify that IterArgs is compiled away into an unrolled form per
call site, I inspected the assembly on some synthetic examples.

Next, we modify the following functions to make use of IterArgs:

  - compute_requires_grad
  - Function::flags (Variable and Tensor variants)
  - flatten
  - isTracing
  - count_tensors / count_variables

Finally, the tuple packer is rewritten to be variadic, although we
cannot make use of IterArgs (since we are given a tuple).  It might
make sense to refactor the code into a generic piece which invokes
a function with the arguments specified by a tuple, and then an
appropriate IterArgs, but we leave this for future work.

One thing to note: we cannot write a function with overloads for both
Tensor and Variable, because both ArrayRef<Variable> and Tensor have
implicit conversions from Variable, making such an overload ambiguous.
It may be interesting to remove the implicit conversion from ArrayRef.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2018-01-26 15:56:39 -05:00
Soumith Chintala
bb3bc969ca
fix binary version scheme to be PEP compliant (#4847) 2018-01-25 11:16:02 -05:00
Teng Li
1b3d6ab864 Enabling Infiniband support for Gloo data channel with auto IB detection (#4795) 2018-01-24 23:18:24 +01:00
Zachary DeVito
0ae5498079 [JIT] add create_autodiff_subgraphs (#4822)
This pass splits differentiable subgraphs into their own Node,
similar to a fusion group.

This initial implementation does not create optimal subgraphs, but
it works well in the case where most things are differentiable,
and has the building blocks (`mergeNodes`) to extend to the
better implementation.
2018-01-23 23:46:54 -05:00
gchanan
9bb6d33d35
Enable scalars if compiled with WITH_SCALAR environment variable. (#4806)
* Enable scalars if compiled with WITH_SCALAR environment variable.

We are pretty close to enabling scalars (0-dimensional arrays); this allows turning them on
for development purposes and writing code that works both with and without scalars enabled.

WITH_SCALARS is currently broken with distributions, but should work for test_torch, test_autograd, test_nn.

* Fix unsqueeze.

* Fix wrap dim, wrapping with Scalar.
2018-01-23 15:44:11 -05:00
Adam Paszke
ad2edd8613 Check submodules only in build_deps (#4770) 2018-01-21 20:24:05 -08:00
Adam Paszke
816d5d8ff7 Scaffolding for source-to-source AD in the JIT 2018-01-20 17:34:08 +01:00
Adam Paszke
1061d7970d Move broadcast and broadcast_coalesced to C++ 2018-01-18 11:16:45 +01:00
Adam Paszke
de5f7b725e Base for pure C++ NCCL interface 2018-01-18 11:16:45 +01:00
Sam Gross
57549b7e44
Bind functions with out= arguments in VariableType (#4565)
This adds overrides in VariableType for the xxx_out ATen functions and
implements Python bindings. There is no support for automatic
differentiation. If any of the inputs (or outputs) requires grad, then the
function will throw an exception unless it's running in "no-grad" mode.

The bindings for calling torch.xxx functions on Variables are moved to a
different object. Previously, they were static methods on VariableBase.
This change prevents users from accidentally calling static methods as if
they were instance methods.
2018-01-17 18:27:42 -05:00
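
A small sketch of the resulting behavior described above (assuming today's spelling of the API):

    import torch

    a, b = torch.randn(3), torch.randn(3)
    out = torch.empty(3)
    torch.add(a, b, out=out)                 # out= variant now dispatches through VariableType

    w = torch.randn(3, requires_grad=True)
    # torch.add(w, b, out=out) would raise here, since out= ops do not support autograd.
    with torch.no_grad():                    # allowed in "no-grad" mode, as described above
        torch.add(w, b, out=out)
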
Adam Paszke
1a02d3ae86
Implement MM fusion (MM with add reduction tree) (#4615)
Implement MM fusion (MM with add reduction tree)

A tree where leaves are matrix multiplies and inner
vertices are adds can be computed as a single mm.
Such subgraphs often appear in the backward pass if a single weight
is reused multiple times (e.g. in RNNs).

NOTE: this seems to be slightly slower on the GPU than the
naive implementation, but it's a huge win on the CPU
(think 100x lower overhead)
2018-01-17 21:36:21 +01:00
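
The rewrite relies on a block-matrix identity; below is a small numerical check of that identity (an illustration of why the fusion is valid, not the pass itself):

    import torch

    W = torch.randn(4, 5)                                  # one weight reused several times
    x1, x2, x3 = (torch.randn(3, 4) for _ in range(3))

    tree_of_adds = x1 @ W + x2 @ W + x3 @ W                # three mms feeding an add tree
    single_mm = torch.cat([x1, x2, x3], dim=1) @ torch.cat([W, W, W], dim=0)
    print(torch.allclose(tree_of_adds, single_mm, atol=1e-6))   # True
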
Jon Crall
94f439c07c Fixed setup.py to handle CUDNN_LIBRARY envvar with aten (#4597)
* Fixed setup.py to handle CUDNN_LIBRARY envvar with aten

* undo changes

* Added CUDNN_LIBRARY to bat file
2018-01-11 07:24:17 -05:00
Edward Z. Yang
dc76db349e Delete a pile of dead code (#4295)
* Delete obsolete basic ops.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* More deletion.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Delete some unused utilities.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Delete dead apply_fn

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Delete CppFunction symbolic support.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Delete ForwardFunction

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Batchnorm is 'working'

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2018-01-04 09:21:54 -05:00
peterjc123
b78a37a058 Enable ninja during python build process for MSVC (#3993) 2017-12-30 12:58:32 +01:00
Edward Z. Yang
8c9a22a88e Support NO_NNPACK environment variable (#4401)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-12-29 16:33:01 +09:00
Edward Z. Yang
5b8fe5cbb5
Batchnorm in ATen (#4285)
* Batchnorm in ATen

This commit moves BatchNorm derivatives into ATen, eliminating
torch/csrc/autograd/functions/batch_normalization.cpp

Some refactoring along the way:

- Functions got renamed to remove _forward from their names
- CuDNN batchnorm forward was modified to return save_mean/save_std instead of
  taking them as parameters. To avoid returning undefined Variables, these return
  (small) uninitialized tensors when they are not used.
- THNN batch normalization takes care of resizing save_mean and save_std on
  forward.
- There are some shenanigans regarding batchnorm backwards in eval mode. I'm tracking
  that in #4284
- I decided not to introduce buffers as a proper concept in ATen, which means
  that tensors like running_mean/running_var are variables in ATen.  This meant
  there needed to be some adjustments to how we *trace* such variables; the
  new strategy is that if we can't find a Value for a variable, we look and see
  if we have a Value for the buffer pointed to by the variable, before
  finally falling back on a constant.
- This PR finally reliably triggered OOM on Travis builds; I fixed this by reducing
  the number of parallel jobs.
- Stop using std::string when it's not necessary.
- Remove training parameter from cudnn_batch_norm_backward, because it
  doesn't make sense; cuDNN doesn't implement the math for evaluation mode
  batchnorm backwards.
- batchnorm_double_backward is now in an anonymous namespace, as it
  no longer needs to be called from torch/csrc

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-12-21 11:38:31 -05:00
Edward Z. Yang
a88a8ec827
Convolution derivatives in ATen (#4116)
* Convolution derivatives in ATen

This PR introduces ATen implementation of convolution, which dispatches to
THNN/CuDNN/nnpack based on input parameters. The general strategy is to compose
this function out of the various forward-backward pairs of specific
implementations, rather than write a monolithic function with backwards (which
is what we did before because the boilerplate of doing it otherwise would have
been very high). The new API provides the following functions:

  - _convolution, which is a fully generic, native convolution implementation
    that dispatches to various other convolution implementations depending on
    input characteristics. This is prefixed with an underscore because it
    explicitly takes benchmark, deterministic and cudnn_enabled which are
    implementation details for CuDNN. The intent is to eventually provide a
    convolution that reads these parameters out of the context using #4104.
  - _convolution_nogroup is a convolution implementation for non-CuDNN
    algorithms which don't support group convolution natively.
  - _convolution_double_backward is the generic double-backwards implementation
    for convolution.

In more detail:

- Most functionality from torch/csrc/autograd/functions/convolution.cpp has been
  moved into aten/src/ATen/native/Convolution.cpp
- We continue to make use of ConvParams, but we now construct the parameters
  upon entry to a function from the function signature (which does not use
  ConvParams; having convolution take ConvParams directly would require teaching
  the code generator how to accept these as parameters, complicating ATen's API
  model) and destruct them when making subprocedure calls.
- I introduce a new idiom, input_r, which represents a const Tensor& reference,
  which will subsequently be assigned to a local Tensor input. This is helpful
  because a lot of the existing algorithms relied on being able to assign to
  locals, which is not permitted with a const reference.
- The native argument parser now supports std::array<bool,2> inputs (NB: there
  MUST NOT be a space; this is the same hack as is applied to derivatives.yaml)
- Native parser now supports Tensor? arguments, which indicates a nullable
  tensor. Previously this function was only used by NN methods.
- Documentation updates on THNN library
- I added an extra fgradInput argument to VolumetricConvolutionMM_updateOutput
  and VolumetricConvolutionMM_accGradParameters so that its buffer list lines up
  with the backward argument list. This makes it possible to write derivative
  for conv3d which previously was not supported (commented out in
  derivatives.yaml)
- Extra double_backward declarations for all convolution backwards functions was
  added.
- You can now use the syntax Tensor? in native_functions.yaml to indicate that a
  tensor argument is nullable.  There are adjustments to propagate this to the
  Python argument parser.
- NNPACK was ported to ATen, and ATen now builds and links against NNPACK if
  possible. New AT_NNPACK_ENABLED macro.  The nnpack functions are
  nnpack_spatial_convolution.
- Some modest CuDNN convolution refactoring to remove _forward from names.
- There's a new cudnn_convolution_backward function to deal with the fact that
  CuDNN convolution double backward requires you to have computed all gradients
  in one go.
- Variable set_flags now checks if the tensor is undefined, fixing a silent memory
  corruption.
- checkSameType updated to not raise an exception if called with Variable arguments
- The "no ATen declaration found for" error message is improved to say what the available declarations are
- make_variable now accepts undefined tensors, and returns an undefined tensor in this case.
2017-12-20 14:19:27 -05:00
peterjc123
77ea2f26d8 Add build support for Python 2.7 using MSVC (#4226) 2017-12-20 15:07:25 +01:00
Sam Gross
d605058212
Replace Variable.volatile with torch.no_grad() (#3970)
This removes volatile from Variable. The functionality is mostly
replaced by a global (thread-local) flag, which is controlled by
torch.set_grad_enabled() and the context manager torch.no_grad().

In C++, the flag is exposed through GradMode::is_enabled() and GradMode::set_enabled()

Fixes #3627
2017-12-18 15:46:13 -05:00
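
A quick sketch of the replacement API described above:

    import torch

    x = torch.ones(3, requires_grad=True)

    with torch.no_grad():            # replaces volatile=True on individual Variables
        y = x * 2
    print(y.requires_grad)           # False: no graph was recorded

    torch.set_grad_enabled(False)    # the global (thread-local) flag
    print((x * 2).requires_grad)     # False
    torch.set_grad_enabled(True)
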
peterjc123
02317d9336 Enable ext build for Windows (#3935)
* Enable ext build for Windows

* Include the static libs to make compiling the extension easier
2017-12-18 02:23:34 -05:00
Sam Gross
bec0349280 Implement Variable.cuda and Variable.type using ATen (#4139)
* Implement Variable.cuda using ATen

This adds an optional async flag to Tensor::copy_, which attempts to do
a non-blocking copy if one of the tensors is in pinned memory and
the other is a CUDA tensor.

* Perform cross-device copy in CopyBackwards

Also call torch.cuda._lazy_init() from Variable.cuda()

* Implement Variable.type via ATen

* Changes from review:

 - remove copy_out
 - remove unnecessary include
 - fix default device for .cuda()

* Combine if statements in dispatch_type
2017-12-18 01:54:35 -05:00
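
A sketch of the Python-level behavior this enables, written with today's non_blocking spelling (the flag was named async at the time of this commit); it is guarded so it only runs when a CUDA device is present:

    import torch

    if torch.cuda.is_available():
        host = torch.randn(1024, 1024).pin_memory()   # pinned memory allows a non-blocking H2D copy
        dev = host.cuda(non_blocking=True)            # Variable.cuda now routes through ATen
        out = torch.empty_like(dev)
        out.copy_(dev, non_blocking=True)             # Tensor.copy_ takes the same flag
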
Edward Z. Yang
6d72c82985
Trace ATen native functions as themselves, not their implementations. (#4127)
* Trace ATen non-primitive functions as themselves, not their implementations.

Previously, if I invoked an ATen non-primitive function foo, which in turn
called subfoo, I would always see 'subfoo' in the trace (i.e., tracing
'inlines' all of these operations).  Such inlining is bad for ONNX
(and can be bad for optimization) as it prevents high-level
optimizations from taking advantage of the structure.  It might
be right to inline, but give the optimizer a chance to work before
inlining happens!

The implementation here is surprisingly simple, because it uses
the "DCE trick".  Essentially, it doesn't matter if the constituent
calls perform tracing, because you can always trace it again, and
override the trace nodes associated with the returned variables.
The original trace becomes dead and can be DCE'd.

While implementing this, I also refactored how 'isTracing' and
'trace_outputs' works:

- isTracing was previously a single function with overloads for
  both Tensor and Variable arguments.  Unfortunately, such overloads
  are not safe, because of how C++ implicit conversions work.  You
  would think that C++ should never confuse an overload for
  Variable with ArrayRef<Tensor>, but this is exactly what can
  happen: Tensor is convertible to both Variable and ArrayRef<Tensor>,
  thus it's ambiguous and C++ doesn't like it.  The last time I ran
  into this problem, I applied initializer lists to everything and
  called it a day.  A more robust fix is to separate out the
  Variable and Tensor overloads, which I have done in this patch.

- trace_outputs was fed as an initializer list, which doesn't work
  when you have heterogeneous inputs.  So instead we first feed
  everything through 'flatten', which has overloads for each of the
  argument patterns in ATen, which then goes on to the recordTrace
  (which takes an ArrayRef).  This is *no less efficient*, because
  we were allocating a vector anyway (to do the conversion from
  vector of Tensor to vector of Variable).

These fixes mean that 'index' can properly be traced... although the
JIT still does not support it.  A failing test case has been added to
this effect.

Some knock-on effects:

- The fuser now knows about chunk as well as split.  They're pretty
  similar so there is no problem.

- There is a new 'canonicalize' pass in the JIT which renumbers a graph
  so that all structurally equivalent graphs render the same.

- We run DCE before the fuser tests, to make sure dead nodes don't
  block fusion.

- There are new ONNX exports for the newly introduced higher level ATen
  operations.  This includes type_as (no-op case only), chunk, select.

Zach didn't like the extra use of 'native' in the new codegen, so
we've introduced a new concept, 'abstract'.  An abstract function
is one that is implemented in derived types (e.g., CPUDoubleType),
whereas a concrete one is implemented in the base type (Type).

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-12-15 13:50:32 -05:00
Will Feng
db446d69ca Fix issues with Windows 7 & 10 CPU build (#4065) 2017-12-15 10:14:43 +01:00
Sam Gross
aeb7a3668d
Implement Variable.new (#4080) 2017-12-11 15:45:43 -05:00
Sam Gross
60c03bc09c
Implement apply_, map_, and map2_ in Variable (#4057) 2017-12-07 14:48:56 -05:00