Commit Graph

71 Commits

Yangqing Jia
713e706618 Move exception to C10 (#12354)
Summary:
There is still some work to be done:

- Move logging and unify AT_WARN with LOG(ERROR).
- A few header files are still being plumbed through, need cleaning.
- caffe2::EnforceNotMet aliasing is not done yet.
- need to unify the macros. See c10/util/Exception.h

This is mainly a codemod and does not cause functional changes. If you find your job failing and trace it back to this diff, it can usually be fixed by one of the following approaches:

(1) add //caffe2/c10:c10 to your dependencies (or transitive dependencies).
(2) change objects such as at::Error and at::Optional to the c10 namespace.
(3) change functions to the c10 namespace. In particular, caffe2::MakeString is not overridden by the unified c10::str function. Nothing else changes.
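As a rough illustration of approach (2), here is a minimal sketch (not part of the original diff; exact headers and aliasing status are assumptions): code that referred to `at::Error` can refer to `c10::Error` instead.

```
// Minimal sketch of approach (2): catch c10::Error where at::Error was used.
#include <torch/torch.h>
#include <iostream>

int main() {
  try {
    auto a = torch::ones({2, 3});
    auto b = torch::ones({4, 5});
    auto c = a + b;                 // shape mismatch -> throws
  } catch (const c10::Error& e) {   // previously spelled at::Error
    std::cerr << "caught c10::Error: " << e.what() << std::endl;
  }
  return 0;
}
```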

Please kindly consider not reverting this diff - it involves multiple rounds of rebasing and the fix is usually simple. Contact jiayq@ or AI Platform Dev for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/12354

Reviewed By: orionr

Differential Revision: D10238910

Pulled By: Yangqing

fbshipit-source-id: 7794d5bf2797ab0ca6ebaccaa2f7ebbd50ff8f32
2018-10-15 13:33:18 -07:00
Peter Goldsborough
033e95765c Diff against master and enable bugprone-* checks (#12378)
Summary:
This PR:

1. Makes clang-tidy diff against `master` instead of `HEAD~1` in CI, which makes much more sense
2. Enables all checks in the `bugprone-*` category (see https://clang.llvm.org/extra/clang-tidy/checks/list.html) except the one about parentheses in macros, because it doesn't always apply well to our codebase.

Fixed some code smells along the way.
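For context (illustrative only, not from the PR), this is the class of bug that the disabled `bugprone-macro-parentheses` check guards against, and why it can be noisy for macro-heavy code:

```
// Illustrative sketch of the macro-parentheses pitfall.
#include <cstdio>

// Unparenthesized parameter: SQUARE_BAD(1 + 2) expands to 1 + 2 * 1 + 2 == 5.
#define SQUARE_BAD(x) x * x
// Parenthesized version expands to ((1 + 2) * (1 + 2)) == 9.
#define SQUARE_GOOD(x) ((x) * (x))

int main() {
  std::printf("%d %d\n", SQUARE_BAD(1 + 2), SQUARE_GOOD(1 + 2));  // prints "5 9"
  return 0;
}
```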

ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12378

Differential Revision: D10247972

Pulled By: goldsborough

fbshipit-source-id: 97dc9e262effa6874d2854584bf41a86684eb8bd
2018-10-10 07:23:57 -07:00
Zachary DeVito
b937cbb776 Fix a bug that would resize tensor storage on export
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12377

Differential Revision: D10219213

Pulled By: zdevito

fbshipit-source-id: 85cfa4467c672ff5a718e58cfae7e8c8b1cfc532
2018-10-05 16:24:54 -07:00
David Riazati
d1ac1eba3b Add bool type to IR (#11834)
Summary:
This PR adds a bool type to `IValue` and puts it into place.

* changes conds for `prim::If` and `prim::Loop` to use `bool` type
* changes operators that take `bool`s to match their native ops
* fixes ambiguous `aten` ops `aten::std` and `aten::var`
	* fixes tests in `test_jit.py TestJitGenerated`
		```
		'test_std_dim',
		'test_std_dim_1d',
		'test_std_dim_1d_neg0',
		'test_std_dim_neg0',
		'test_var_dim',
		'test_var_dim_1d',
		'test_var_dim_1d_neg0',
		'test_var_dim_neg0'
		```
* adds `prim::BoolToTensor` and `prim::TensorToBool`

apaszke zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11834

Differential Revision: D9928570

Pulled By: driazati

fbshipit-source-id: 373c53df2f1a8ffa9e33d9a517002fbeef25f3eb
2018-10-03 12:40:03 -07:00
Luca Antiga
5be0baefa2 Use streams in JIT serialization, allow JIT serialization to/from buffer (#11932)
Summary:
This PR replaces the use of `std::FILE` with `istream`/`ostream` for JIT serialization.
It uses this mechanism to add the ability to serialize to/from binary buffers, in addition to files, both in `libtorch` and from Python.

`getExportImportCopy` in `test_jit.py` has been updated so that both file and buffer codepaths are exercised during tests.
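A minimal sketch of the libtorch side of this (using the modern API spelling; "model.pt" is a placeholder path, and the exact overloads are assumptions): a ScriptModule can be round-tripped through an in-memory stream instead of a file.

```
// Sketch: serialize/deserialize a ScriptModule via a std::stringstream.
#include <torch/script.h>
#include <sstream>

int main() {
  // Load a previously exported ScriptModule from disk (placeholder path).
  torch::jit::Module m = torch::jit::load("model.pt");

  // Round-trip through an in-memory buffer instead of a file.
  std::stringstream buffer;
  m.save(buffer);                                          // serialize to std::ostream
  torch::jit::Module reloaded = torch::jit::load(buffer);  // deserialize from std::istream
  return 0;
}
```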
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11932

Differential Revision: D10084303

Pulled By: apaszke

fbshipit-source-id: b850801b3932922fa1dbac6fdaed5063d58bc20d
2018-09-28 07:54:27 -07:00
Michael Suo
7f35e92af2 mutable lists (#10700)
Summary:
This PR implements the design that we discussed. Changes:
- Added a World token IValue and type. The IValue is basically a dummy struct for now; in the future we may extend it (say, to add thread-local state).
- Effectful ops explicitly declare they are mutable by having World tokens as inputs and outputs in their schema.
- Purely functional ops that use mutable values will get "fenced", and the world token will be threaded through the fences.
- An AnnotateEffects pass wires up all the world tokens together.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10700

Reviewed By: eellison

Differential Revision: D9547881

Pulled By: michaelsuo

fbshipit-source-id: ebbd786c31f15bf45e2ddb0c188438ff2f5f3c88
2018-09-27 19:25:13 -07:00
Zachary DeVito
478803a75f Introduce type variables to implement generic list operators (#12040)
Summary:
We generate specialized list operations for int, float, and Tensor lists so that small lists of integers like the arguments to conv do not involve tons of boxing code.

This PR adds a fallback GenericList for List types that contain any other type. It does so by adding type variables to `jit::Type`, and machinery for matching/replacing the type variables during `tryMatchSchema` and operator lookup.

It also modifies the builtin list ops to include a fallback that works on a GenericList object that simply holds IValues. This is distinguished from IValue's tuple type so that conversion to/from Python still happens losslessly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12040

Differential Revision: D10037098

Pulled By: zdevito

fbshipit-source-id: 0c5f2864d12e7d33554bf34cc29e5fb700dde150
2018-09-26 17:02:51 -07:00
Gregory Chanan
0947712e5d Move Factory functions from Type to TypeExtendedInterface. (#12025)
Summary:
This makes a few changes wrt Type, with the ultimate goal of removing Type from the public Methods/Functions.  In particular:
1) Removes factory functions from Type, into TypeExtendedInterface.
2) sparse_coo_tensor is now a first class at:: namespace function, with TensorOptions overloads.
3) We move from Type-based sparse_coo_tensor dispatch to function-based.

Note we still require a number of changes to get rid of Type in the public interface; in particular, TensorOptions needs to support CUDA vs. non-CUDA dispatch. That is coming in a future patch.
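A small sketch of point (2) (not from the PR; shown with the `torch::` C++ frontend spelling, which forwards to the `at::` factory, and assuming the `TensorOptions` overload shape):

```
// Sketch: sparse_coo_tensor as a free factory function taking TensorOptions.
#include <torch/torch.h>
#include <iostream>

int main() {
  // Two nonzeros, at (0, 0) and (1, 2), in a 3x4 matrix; indices have shape [2, nnz].
  auto indices = torch::tensor({0, 1, 0, 2}, torch::kLong).reshape({2, 2});
  auto values  = torch::tensor({3.0, 4.0}, torch::kFloat);
  auto sparse  = torch::sparse_coo_tensor(
      indices, values, {3, 4},
      torch::TensorOptions().dtype(torch::kFloat));
  std::cout << sparse.to_dense() << std::endl;
  return 0;
}
```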
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12025

Reviewed By: ezyang

Differential Revision: D10017205

Pulled By: gchanan

fbshipit-source-id: 00807a37b09ed33f0656aaa165bb925abb026320
2018-09-25 09:40:17 -07:00
James Reed
7d25fa3c72 Emit Undefined type for value when it is Dynamic type (#11810)
Summary:
For example, outputs of control blocks often have Dynamic type, and when we try to export them to ONNX we get an invalid proto, since `elem_type` is not populated on the TypeInfoProto. This change lets us at least get past the checker, since having a dynamically typed output from a control block should still be semantically valid.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11810

Differential Revision: D9922754

Pulled By: jamesr66a

fbshipit-source-id: 5c66113cc302a9d9b8b9f5a8605473d3c6ad5af1
2018-09-18 13:55:36 -07:00
David Riazati
7671f4ab1c Add math to scope when using inf in tests (#11302)
Summary:
This fixes #8515, which was mostly caused by issues in the tests themselves. As long
as `math` is imported in the scope in which the script runs, `inf` correctly
resolves to a `prim::Constant`. This PR adds the import to
the `test_jit.py` tests involving `inf` and adds a test to demonstrate
`inf` in a non-generated test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11302

Differential Revision: D9684336

Pulled By: driazati

fbshipit-source-id: 73df2848dfdb45ab50690a7c88df8fda269a64eb
2018-09-17 14:08:32 -07:00
Roy Li
a861573e36 fix tensor export bug in IR export (#11613)
Differential Revision: D9811094

Pulled By: li-roy

fbshipit-source-id: 012792dbedc70bd3fa242fdf2e39da0b21ce158d
2018-09-13 11:10:35 -07:00
Orion Reblitz-Richardson
0a6931cfee Only reference ONNX through onnx_pb.h (#11609)
Summary:
I think this is needed to land https://github.com/onnx/onnx/pull/1407 without CI errors.

cc mingzhe09088 houseroad
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11609

Reviewed By: houseroad

Differential Revision: D9803490

Pulled By: orionr

fbshipit-source-id: 26193f38ab0a2eef9ad7d0da9a0310dc40ef0f2d
2018-09-12 18:25:58 -07:00
Elias Ellison
2158f4a9c8 add export import test to TestJitGenerated (#10982)
Summary:
Checks assertExportImport for all of the generated JIT tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10982

Differential Revision: D9636935

Pulled By: eellison

fbshipit-source-id: f3f1ce77d454848098f2ac7e0fa18bf8564890be
2018-09-10 11:37:05 -07:00
James Reed
4ae16c9ad9 Recursive descent for validation + convert expands in ATen fallback (#11356)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11356

Differential Revision: D9721002

Pulled By: jamesr66a

fbshipit-source-id: eeb50b56f8a72e929860c5e459a5ab50ac624814
2018-09-07 16:39:36 -07:00
Roy Li
9fc22cb772 Add import export step to end to end tests
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10717

Differential Revision: D9562888

Pulled By: li-roy

fbshipit-source-id: 8f5d62fd0a44aca0a41dc10438e7bb91cc2a972a
2018-09-05 09:39:47 -07:00
Peter Goldsborough
fe15aedacc Store schema in serialized modules and check arguments in function call (#10872)
Summary:
This PR adds argument checking for script method invocation from C++. For this I had to:
1. Serialize the schema: the schema of a method is currently not stored in script modules, so we now store the function schema in the `doc_string` field of the ONNX proto. Upon loading a serialized script module, we parse the schema into its structured C++ form and assign it to the loaded method.
2. Verify arguments: inside `Method::operator()`, we now check the number and types of arguments.
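For reference, a sketch of what an invocation from C++ looks like (modern libtorch spelling; "model.pt" is a placeholder): with the schema available, a call with the wrong number or types of arguments fails with a schema error rather than deep inside the interpreter.

```
// Sketch: invoking a serialized script method from C++.
#include <torch/script.h>
#include <iostream>

int main() {
  torch::jit::Module module = torch::jit::load("model.pt");

  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({2, 3}));
  try {
    auto out = module.forward(inputs).toTensor();  // checked against the schema
    std::cout << out << std::endl;
  } catch (const c10::Error& e) {
    std::cout << "argument check failed: " << e.what() << std::endl;
  }
  return 0;
}
```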


zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10872

Differential Revision: D9521219

Pulled By: goldsborough

fbshipit-source-id: 5cb3d710af6f500e7579dad176652c9b11a0487d
2018-08-28 20:11:39 -07:00
Adam Paszke
c8b246abf3 Prevent JIT from overspecializing to every single size configuration (#10844)
Summary:
Please review the expects carefully to make sure there are no regressions. I tried to go over them one by one when they changed, but it's sometimes easy to miss finer details.

Summary of changes:

- Renamed `TensorType` to `CompleteTensorType`. Added a new `TensorType` which records only the scalar type, number of dimensions, and device of a value. The reasoning behind the rename is to encourage people to use `CompleteTensorType` less, as most passes will only have limited information available. To make the transition easier, `complete_type->cast<TensorType>()` works, which lets our passes handle both kinds of specialization if they don't need the extra detail.
- Renamed `ArgumentSpec` to `CompleteArgumentSpec`. Added a new `ArgumentSpec`, which matches argument only at the level of the new `TensorType`.
- Shape analysis can process graphs with both `CompleteTensorType` and `TensorType`.
- The fuser heavily relied on full shape information being available. Now we simply try to fuse the largest possible graphs and do run-time checks to make sure they match the code we generate. If they don't, we fall back to regular interpretation. The shape checks are implemented using an optimized method that exploits algebraic properties of shapes under broadcasting and the way broadcasting interacts with pointwise ops. A full written proof of correctness of the shape-checking algorithm is included in a comment in `graph_fuser.cpp`.

zdevito ezyang mruberry ngimel csarofeen
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10844

Differential Revision: D9498705

Pulled By: apaszke

fbshipit-source-id: 0c53c2fcebd871cc2a29c260f8d012276479cc61
2018-08-26 09:54:48 -07:00
Edward Yang
19031c68dc Use intrusive_ptr in Storage; replace unique_ptr<Storage> with Storage (#10488)
Summary:
```
Use intrusive_ptr in Storage; replace unique_ptr<Storage> with Storage

This patch does two major changes:

- It replaces the use of Retainable in Storage with a new implementation
  based on intrusive_ptr.  This will be necessary because Caffe2 will
  be using this class to implement intrusive_ptrs, and we need to
  line these up for the merge.  One good thing about the new implementation is
  that the default copy/move constructors/assignment operators and destructor
  work automatically, instead of needing to be hardcoded into Storage/Tensor.

- It replaces all places where we returned std::unique_ptr<Storage> with
  Storage, collapsing an unnecessary double indirection that is no longer
  necessary now that we have correctly working copy/move constructors.

I didn't initially want to do step (2), but it was very important to
eliminate all bare uses of new Storage and new StorageImpl, and making
this API change was the most straightforward way to do so.

HOW TO FIX YOUR CODE IN THE NEW API

- You no longer need to dereference the result of tensor.storage() to pass
  it to set.  So, instead of:

      x.set_(*y.storage());

  just write:

      x.set_(y.storage());

- If you were accessing methods on StorageImpl via the pImpl() method, you
  must use the dot operator to run pImpl().  Even better; just drop pImpl,
  we now have method forwarding.  So, instead of:

      storage->pImpl()->data();

  just do:

      storage->data();
      // storage.pImpl()->data() works too but is not as recommended

- storage->getDevice() is no more; instead use storage->device().index()

MISC CODE UPDATES

- retain, release, weak_retain, weak_release and weak_lock are now
  reimplemented using the "blessed API", and renamed to make it
  clearer that their use is discouraged.

- nvcc OS X and general OS X portability improvements to intrusive_ptr

- A new comment in intrusive_ptr describing how stack allocated
  intrusive_ptr_targets work differently than heap allocated ones
  from c10::make_intrusive

CAVEAT EMPTOR

- THStorage_weakRetain used to work on strong pointers, but it NO LONGER
  works with intrusive_ptr.  You must reclaim the strong pointer into a
  real strong pointer, construct a weak pointer from it, and then release
  the strong and weak pointers.  See StorageSharing.cpp for an example.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10488

Reviewed By: gchanan

Differential Revision: D9306134

Pulled By: ezyang

fbshipit-source-id: 02d58ef62dab8e4da6131e1a24834a65c21048e2
2018-08-21 21:39:55 -07:00
Edward Yang
6bdbad93b9 Refactor Device to not depend on Backend. (#10478)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10478

- Removed Backend constructor from Device, and fixed all
  use-sites to use DeviceType::CPU instead of kCPU, or
  use a new function backendToDeviceType to perform
  the conversion.
- New method device_type() on Type; it gives you the
  underlying device type, e.g., CPU for SparseCPU.
- We add backward compatibility for kCPU/kCUDA uses,
  by introducing a new special type which is implicitly
  convertible to both DeviceType and Backend.  As long as
  you don't define a function that's overloaded on both
  DeviceType and Backend (but not on BackendOrDeviceType),
  the implicit conversions will ensure that uses
  of at::Device(at::kCPU) keep working. We fixed use-sites in
  the library, but did NOT fix sites in the test code, so that
  we can exercise this BC code.
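A tiny sketch of the backward-compatible spellings described above (illustrative only, not from the diff; only Device construction is shown):

```
// Sketch: kCPU and DeviceType::CPU both construct the same at::Device.
#include <torch/torch.h>
#include <iostream>

int main() {
  at::Device via_shorthand(at::kCPU);               // kCPU implicitly converts
  at::Device via_device_type(at::DeviceType::CPU);  // explicit DeviceType

  std::cout << via_shorthand << " == " << via_device_type << ": "
            << (via_shorthand == via_device_type) << std::endl;

  // Factories take the device through TensorOptions either way.
  auto t = torch::ones({2, 2}, torch::TensorOptions().device(via_device_type));
  std::cout << t.device() << std::endl;
  return 0;
}
```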

Reviewed By: Yangqing

Differential Revision: D9301861

fbshipit-source-id: 9a9d88620500715c7b37e655b4fd761f6dd72716
2018-08-18 17:39:14 -07:00
Lu Fang
bdb11e716a Split the dependence of ONNX from test_operators.py (#10151)
Summary:
Now, when running `python test/onnx/test_operators.py --no-onnx`, we don't introduce any ONNX Python dependency (no onnx/protobuf Python packages need to be installed).

The major changes:
- output pbtxt from the C++ exporter directly, so the floating-point format may differ slightly. (This should be fine, since it's just to guard ONNX exporting.)
- ONNX python packages are only imported if we run the ONNX related checks. Those checks are disabled when using `--no-onnx` flag.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10151

Reviewed By: jamesr66a

Differential Revision: D9130706

Pulled By: houseroad

fbshipit-source-id: ea28cf5db8399929179698ee535137f209e9ce6f
2018-08-14 12:54:44 -07:00
Roy Li
e9ad74357e Use serialization container in ir import export (#10394)
Summary:
Copy of #10191 because these changes didn't land with the diff.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10394

Differential Revision: D9260816

Pulled By: li-roy

fbshipit-source-id: 7dc16919cfab6221fda1d44e98c5b900cfb40558
2018-08-10 00:09:30 -07:00
Sebastian Messmer
f51f15bb27 Update include paths for ATen/core (#10130)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10130

Update some include paths to make them internally consistent

Reviewed By: ezyang

Differential Revision: D9119906

fbshipit-source-id: b44e5cab8e8e795ee18afe9ffc6caf1f2b413467
2018-08-03 11:57:02 -07:00
Roy Li
0e9c6898cb Export modules in ir with google protobuf
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9746

Differential Revision: D9110006

Pulled By: li-roy

fbshipit-source-id: 8b9744c042f822fdfe959a7a7fef3d0baff4f639
2018-08-02 15:54:51 -07:00
Roy Li
f908b2b919 Use google protobuf in pytorch onnx import/export
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/8469

Reviewed By: houseroad

Differential Revision: D9102041

Pulled By: li-roy

fbshipit-source-id: 805c473745d181b71c7deebf0b9afd0f0849ba4f
2018-08-01 12:54:41 -07:00
Wanchao Liang
b7b61a8eb4 Change expect, cast on Type to return shared pointers, make isSubtypeOf accept TypePtr (#9786)
Summary:
Follow up task of #9584.

Commit 1:

- change expect/cast to return shared pointers instead of raw pointers
- isSubtypeOf accepts a TypePtr instead. Use `x->isSubtypeOf(NumberType::get())` rather than `x->isSubtypeOf(*NumberType::get())`

Commit 2:

- to address enable_shared_from_this pitfalls, we make the constructor private and expose a factory method, so that users can only create the type via that factory method.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9786

Reviewed By: zdevito

Differential Revision: D8980441

Pulled By: wanchaol

fbshipit-source-id: e5c923fc57a701014310e77cf29985b43bb25364
2018-07-26 18:09:45 -07:00
Peter Goldsborough
f62bc01dfe Remove TORCH_ASSERT (#9575)
Summary:
I got some tensor->variable conversion exceptions from `torch/csrc/autograd/variable.h`, which used the `TORCH_ASSERTM` macros instead of `AT_CHECK`, so they didn't have backtraces. This was such a substantial loss for debuggability that I decided to update the whole codebase to use the backtrace-enabled ATen macros instead of `TORCH_ASSERT` and `JIT_ASSERT`, the latter having been an alias of the former.
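A short sketch of the check-macro style this PR moves toward (not from the PR; shown with the current spelling `TORCH_CHECK`, which is the later name of `AT_CHECK`):

```
// Sketch: a backtrace-carrying check macro instead of TORCH_ASSERT/JIT_ASSERT.
#include <torch/torch.h>
#include <iostream>

void require_2d(const torch::Tensor& t) {
  TORCH_CHECK(t.dim() == 2, "expected a 2-D tensor, got ", t.dim(), " dims");
}

int main() {
  try {
    require_2d(torch::ones({3}));   // wrong rank -> throws c10::Error
  } catch (const c10::Error& e) {
    std::cout << e.what() << std::endl;  // message plus backtrace info
  }
  return 0;
}
```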

ezyang apaszke zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9575

Differential Revision: D8924566

Pulled By: goldsborough

fbshipit-source-id: 7a4013b13eec9dbf024cef94cf49fca72f61d441
2018-07-24 18:10:06 -07:00
Adam Paszke
b9f575fc33 Remove legacy code from the JIT (#9323)
Summary:
In particular, get rid of backward tracing and CppOp.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9323

Reviewed By: ezyang

Differential Revision: D8795935

Pulled By: apaszke

fbshipit-source-id: fb7a7eeee41902da35f2a8efd77262ca60fd6bbe
2018-07-11 10:25:38 -07:00
James Reed
04503962ff
[ONNX] Add an ATen fallback pathway for ONNX export (#8273)
* ATen fallback for ONNX export

* Move to enum

* Fix model test

* Add comment

* Address comments

BC interface
2018-06-12 22:59:45 -07:00
zrphercule
afa75fa6b2 Remove NO_PYTHON macros from Exceptions.h/cpp (#8007)
Removes cases where NO_PYTHON was unnecessary in Exception.h/cpp
2018-06-01 22:37:18 -07:00
Luca Antiga
5d3c3c53aa
Add raw IR serialization/deserialization (#6392) 2018-05-01 20:21:29 +02:00
James Reed
4667983f0f
Fixes for interpreter and ONNX export for translation (#7044)
Fixes for interpreter and ONNX export for translation

Address comments
2018-04-27 22:23:57 -07:00
Zachary DeVito
b2581c0289 Workaround in onnx to get transposes into init_nets (#6924)
* Workaround in onnx to get transposes into init_nets

This adds a pass to ONNX so that it can speculate Transpose
operators, letting ONNX's split pass put them into an init_net.

Also fixes a potential bug in onnx peephole where an optimization
across blocks might move a Value and violate scoping.

* Perform shape propagation when embedding a program into a trace.

This ensures the trace still has type information specific to that trace, which will help onnx export succeed in more cases.
2018-04-26 11:04:17 -04:00
Zachary DeVito
d985cf46f1
Add workaround to fix include warnings in Python 2 builds. (#6716) 2018-04-24 12:30:19 -07:00
James Reed
ef76e24f60
[JIT][script][ONNX] ScriptModule ONNX export + ONNX export for control flow nodes (#6608)
* ScriptModule ONNX export

* ScriptModule ONNX export

* Export for control flow nodes

* Add pretty-print capability for ONNX export testing

* Update tests and handling of multiple GraphProto names

* Maybe bugfix?

* factor out code from export and pretty print
2018-04-19 23:45:03 -07:00
anderspapitto
12bfa47ddd
Onnx RNN export: remove Constant default hidden state (#6199)
When no explicit hidden state is provided, a default is created by
constructing a new Variable filled with zeros. This gets traced as a
Constant operator, which hardcodes the batch size.

To fix this, we remove such constant operators in an 'optimization'
pass. We could have also fixed it by causing the code to not generate
a Constant in the first place, but this is the least invasive fix from
the perspective of the pure pytorch codebase.
2018-04-04 19:22:38 -07:00
James Reed
5fe3c406f2 Experimental support for different ONNX export types (#6016)
Allows you to export an ONNX model as:

Protobuf file (this is what we have now)
Uncompressed zip archive
Compressed zip archive
Directory

* Experimental support for different ONNX export types

* Remove a copy

* Add comment

* Add test cases

* lint

* fix bug

* address comments
2018-03-30 15:30:38 -04:00
James Reed
641fb21bdd Update no_unions flag for nanopb gen and update ONNX proto files (#5972) 2018-03-23 18:52:33 -04:00
Edward Z. Yang
acc409396b
Namespaced symbols (#5820)
* Namespaced symbols

- Our interned strings now have structure, "ns::symname" rather than just
  "symname" before.  We support efficient namespace testing for uniques
  by encoding the namespace in one byte in the Symbol internal representation.
  See torch/csrc/jit/interned_strings.h for a more in-depth implementation
  discussion.

- All uses of ksymbol are now attr::symbol (or some appropriate namespace).
  The valid namespaces are prim, attr, onnx and aten.

- Symbol is bound in Python as a qualified string "attr::symbol", EXCEPT for the
  attribute setting/getting API, whose symbols must always be attr
  symbols; they get special cased to assume strings are passed.
  There's a little bit of naughtiness in the implementation, maybe you know
  how to solve it.

- However, the g.op() convenience function assumes that you're generating
  ONNX operators, unless you explicitly qualify.

- All ATen operators and nodes have built-in interned strings generated
  for them, so you should never have to write a string literal ever again.
  The tracing code is adjusted to use it.

- ONNX exporter now properly tests to see that all operators are in
  onnx namespace before accepting the export.  This is way more
  robust than the previous exporter, which would be willing to
  export capitalized operators which were not actually ONNX operators.

- A slight organizational change for symbolic.py; this module now ONLY
  contains aten operators.  In particular, the exporter for Constant
  has moved into utils.py (along with Undefined, from the C++ side),
  since primitive ops get "special treatment."

- The un-inplacing logic in recording is more robust, so that we don't
  delete a trailing underscore from __and__.  This never affected us
  before because we didn't have any tests for it.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2018-03-16 13:36:11 -04:00
James Reed
59d1d17775 Print source location when ONNX export fails for a node (#5652) 2018-03-09 15:31:28 -08:00
Peter Goldsborough
7391dae709 Fix Variable conversion on the way to/from Python (#5581)
* PyObject* <--> at::Tensor conversion no longer unwraps variables; instead, we expect end users to always work with variable types, and we will only unwrap the variables when we optimize.
* Add torch::CPU, torch::CUDA and torch::getType
* at::CPU -> torch::CPU in extensions
2018-03-09 14:31:05 -08:00
Zachary DeVito
39608b0180
Add source information to IR nodes (#5449)
* Add source information to IR nodes

SourceRange information from the script is now propagated to IR nodes.
This information is only used in two places now: the interpreter
wraps errors that occur when an instruction executes, and shape
propagation now reports errors on the line where it fails:

    Traceback (most recent call last):
      File "test/test_jit.py", line 1655, in test_script_error
        bar(Variable(torch.rand(10), requires_grad=True), Variable(torch.rand(9), requires_grad=True))
    RuntimeError:
    The size of tensor a (10) must match the size of tensor b (9) at non-singleton dimension 0:
    @torch.jit.script
    def bar(c, b):
        return c / b
               ~~~~~ <--- HERE

In the future, shape propagation should really not report any size
errors and instead just not propagate shapes and let the actual
execution fail. However, this is hard to accomplish while we still
depend on running the op to do shape propagation.
2018-02-28 17:06:18 -08:00
Zachary DeVito
05269b582b
[JIT] Support shape propagation with control-flow (#5391)
Support shape propagation with control-flow

* This allows us to enable optimization in the GraphExecutor for most
  script tests.
* Changes Type to always be present (non-null) on a Value, removing `hasType()`
  and `typeOption()`. A new type kind 'DynamicType' now represents when
  a specific type has not been determined.
* If/Loop nodes propagate shapes/types in the simple cases where types of
  outputs do not change depending on where control flows. In other
  cases, we propagate DynamicType to indicate we do not know what
  the shape will be.
* Remove the `cond` input to the body of Loop to simplify handling in
  interpreter and shape propagation.
* Bugfix for zero-dim contiguousStridesOf
2018-02-26 15:24:05 -08:00
anderspapitto
b2cfd961d3 Handle sequence lengths correctly when exporting RNNs to ONNX (#4695)
* PackedSequence: store batch_sizes as tensor

rather than converting to a list of python integers. This maintains
the invariant that module's inputs/outputs are collections of
Variables.

In particular, this causes the JIT to no longer choke when flattening
and unflattening arguments.

* Handle sequence lengths correctly when exporting RNNs to ONNX

- when uniform sequence lengths are provided, correctly omit the
  argument when constructing the ONNX graph, so as to not fix the
  graph to the batch size.

- handle PackedSequences by floating them through the graph and
  eliminating them in an optimization pass. ONNX does not have packed
  sequences, but operates on a representation equivalent to
  PaddedSequence, so we hide the representation-switching from ONNX

- as a preliminary step towards handling PackedSequences, not directly
  tied to ONNX export, change batch_sizes from being an argument to
  the RNN operators into being an argument to the forward() function
  of those RNN operators. This more closely models the reality that
  batch_sizes are effectively part of the input sequences.
2018-02-06 21:40:27 -05:00
Zach DeVito
0d748fac96 Add nested Blocks in IR
This commit is getting the IR ready for representing ONNX control flow.
It adds nested blocks to the IR.

* Each node now has blocks(), addBlock(), and eraseBlock() similar to a node's
  output list.
* Blocks are a property of every node rather than an attribute, both to make it
  easier to manage the lifetime of the containing nodes and because the behavior
  of cloning Blocks will likely differ from the way we clone other
  attributes.
* A block itself has a list of nodes, as well as inputs and outputs.
  The meaning of the nested input/output nodes are specific to the particular
  node kind containing the block. It is safe to assume inputs to a block will be
  in scope in the block.
* Each Block has an owningNode() and each node has an owningBlock().
  The owningNode of the top-most block is null.
* Values are lexically scoped: nested blocks can use values from outer blocks
  that have been defined in previous nodes. Lint has been updated with these
  new scoping rules.
* This change preserves almost all of the pre-Block API. No attempt has been made
  to make optimizations aware of Blocks. This will need to be done on a case-by-case
  basis as we make optimizations capable of handling Blocks.
2018-02-02 23:24:49 -08:00
Zachary DeVito
2da43bf6f1 Make Symbol a true struct (#4717)
Previously, Symbol was just a uint32_t and we converted with symbolToString and
stringToSymbol. Now Symbol is a struct with a toString method, and
constructors from either BuiltinSymbols enums (e.g. kParam) or strings.

Symbol is convertible to a uint32_t to ensure it can still be used in
switch statement BuiltinSymbol case branches.
2018-01-17 21:49:28 -08:00
Edward Z. Yang
a88a8ec827
Convolution derivatives in ATen (#4116)
* Convolution derivatives in ATen

This PR introduces ATen implementation of convolution, which dispatches to
THNN/CuDNN/nnpack based on input parameters. The general strategy is to compose
this function out of the various forward-backward pairs of specific
implementations, rather than write a monolithic function with backwards (which
is what we did before because the boilerplate of doing it otherwise would have
been very high.) The new API provides the following functions:

  - _convolution, which is a fully generic, native convolution implementation
    that dispatches to various other convolution implementations depending on
    input characteristics. This is prefixed with an underscore because it
    explicitly takes benchmark, deterministic and cudnn_enabled which are
    implementation details for CuDNN. The intent is to eventually provide a
    convolution that reads these parameters out of the context using #4104.
  - _convolution_nogroup is a convolution implementation for non-CuDNN
    algorithms which don't support group convolution natively.
  - _convolution_double_backward is the generic double-backwards implementation
    for convolution.

In more detail:

- Most functionality from torch/csrc/autograd/functions/convolution.cpp has been
  moved into aten/src/ATen/native/Convolution.cpp
- We continue to make use of ConvParams, but we now construct the parameters
  upon entry to a function from the function signature (which does not use
  ConvParams; having convolution take ConvParams directly would require teaching
  the code generator how to accept these as parameters, complicating ATen's API
  model) and destruct them when making subprocedure calls.
- I introduce a new idiom, input_r, which represents a const Tensor& reference,
  which will subsequently be assigned to a local Tensor input. This is helpful
  because a lot of the existing algorithms relied on being able to assign to
  locals, which is not permitted with a const reference.
- The native argument parser now supports std::array<bool,2> inputs (NB: there
  MUST NOT be a space; this is the same hack as is applied to derivatives.yaml)
- Native parser now supports Tensor? arguments, which indicates a nullable
  tensor. Previously this function was only used by NN methods.
- Documentation updates on THNN library
- I added an extra fgradInput argument to VolumetricConvolutionMM_updateOutput
  and VolumetricConvolutionMM_accGradParameters so that its buffer list lines up
  with the backward argument list. This makes it possible to write derivative
  for conv3d which previously was not supported (commented out in
  derivatives.yaml)
- Extra double_backward declarations for all convolution backwards functions was
  added.
- You can now use the syntax Tensor? in native_functions.yaml to indicate that a
  tensor argument is nullable.  There are adjustments to propagate this to the
  Python argument parser.
- NNPACK was ported to ATen, and ATen now builds and links against NNPACK if
  possible. New AT_NNPACK_ENABLED macro.  The nnpack functions are
  nnpack_spatial_convolution.
- Some modest CuDNN convolution refactoring to remove _forward from names.
- There's a new cudnn_convolution_backward function to deal with the fact that
  CuDNN convolution double backward requires you to have computed all gradients
  in one go.
- Variable set_flags now checks if the tensor is undefined, fixing a silent memory
  corruption.
- checkSameType updated to not raise an exception if called with Variable arguments
- "no ATen declaration found for" error message is improved to say what available declarations are
- make_variable now accepts undefined tensors, and returns an undefined tensor in this case.
2017-12-20 14:19:27 -05:00
Edward Z. Yang
de00aab720 PyTorch now uses operator versioning.
Also move some of the exporter info out of the ModelProto constructor.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-11-30 23:09:45 -05:00
Dmytro Dzhulgakov
4726d0cb7a Fix exporting HalfTensor 2017-11-28 08:57:17 +01:00
Zach DeVito
ef4b19f767 Refactor ir.h to distinguish Nodes and Values
This commit adds a Value type similar to the one @ezyang suggested a while
ago for handling multi-return nodes.

Previously if we had a graph like:

  a = op1(b)
  c, d = op2(a)

Then its in-memory format would look like:

  %0 = op1(b)
  %1 = op2(%0)
  %2 = select(%1, 0)
  %3 = select(%1, 1)

Select nodes were used only to handle the multi-output case. In the
single-output case ops referred directly to their uses.

This required special handling for the single- and multi- output cases,
and was confusing when used with ONNX which distinguishes values (the
inputs/outputs of a node) from the nodes themselves (e.g. a Conv).

This commit adds the Node/Value distinction to the IR. In the example
above, `a`, `b`, `c`, and `d` are now Value objects, while `op1` and
`op2` are now Node objects. Inputs/Outputs to the graph are values.

* Nodes now always have multiple outputs, accessible through their `output()`
  method.
* Methods exist for adding/removing outputs from a node.
* Nodes own their output Values; destroying a node destroys its outputs, and it
is only valid to destroy a node when no uses of its outputs remain.
* Unlike select, Values do not appear in the nodes list.
* The method `node()` on `Value` retrieves its defining node. Calling it
is always valid. For inputs, its kind is "Param". Like "Return" there is a single Param
node representing all inputs.
* For single-output Nodes, the method `output()` retrieves the single
output Value, asserting that the node is in-fact single output.
* Functions are the same, but some functions like `type()` have moved to
Value.
* `replaceAllUsesWith` is now sanely defined for both Values and Nodes.
In the case of Nodes, it replaces all outputs of the node with the outputs
of the replacement node.
* stage is defined on both Node and Value, because inputs require a stage.
* Apart from changing data types from Node->Value most passes remain the same.
  Things that previously assumed single-output nodes now have to call output()
  to get the node.
* This removes the uses = [...] field from the printed output because it was
getting confusing even before this commit: uses refer to nodes,
but we print the names of Values. The lint pass validates the use list,
so printing it out seems less necessary.
2017-11-15 11:47:18 -08:00
James Reed
c7cb6a795e Record stack traces during JIT tracing (#3607)
* Update comments and size logic

* Record stack traces during JIT tracing

* Use string helper functions and AutoGIL

* Use SourceLocation object instead of storing in debugName

* Address zdevito comments

* Address comments
2017-11-13 10:18:55 -08:00