Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44861
We were redefining things like ASSERT_EQ to take a _VA_ARGS_ parameter, so compiling these files with gtest (instead of pytorch's custom python-based cpp test infra) fails.
Test Plan: buck build //caffe2/test/cpp/tensorexpr
Reviewed By: asuhan
Differential Revision: D23711293
fbshipit-source-id: 8af14fa7c1f1e8169d14bb64515771f7bc3089e5
Summary:
Unifies a number of partial solutions to the thread and block dimension extent masking, including the NoThreadIdxWriter and my last fix https://github.com/pytorch/pytorch/issues/44325. The NoThreadIdxWriter is gone in favour of tracking the current loop extents and masking any statements that have a lower rank than the launch parameters in any Block or Thread dimension, which handles both the "no" and "smaller" axis binding cases.
For example it will transform the following:
```
for i in 0..10 // blockIdx.x
for j in 0..10 // threadIdx.x
do thing(i, j);
for k in 0..5 // threadIdx.x
do other thing(i, k);
```
Into:
```
do thing(blockIdx.x, threadIdx.x);
if (threadIdx.x < 5) {
do other thing(blockIdx.x, threadIdx.x);
}
```
And handle the case where statements are not bound by any axis, eg.
```
do outer thing;
for i in 0..10 // blockIdx.x
for j in 0..10 // threadIdx.x
do thing(i, j);
do other thing(i);
```
will become:
```
if (blockIdx.x < 1) {
if (threadIdx.x < 1) {
do outer thing;
}
}
syncthreads();
do thing(blockIdx.x, threadIdx.x);
syncthreads();
if (threadIdx.x < 1) {
do other thing(blockIdx.x);
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44733
Reviewed By: mruberry
Differential Revision: D23736878
Pulled By: nickgg
fbshipit-source-id: 52d08626ae8043d53eb937843466874d479a6768
Summary:
Fix an issue where loops of different sizes are bound to the same Cuda dimension / metavar.
Coming soon more info and tests...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44325
Reviewed By: colesbury
Differential Revision: D23628859
Pulled By: nickgg
fbshipit-source-id: 3621850a4cc38a790b62ad168d32e7a0e2462fad
Summary:
Fixes a bug in the NNC registerizer for Cuda where it would hoist reads out of a conditional context when trying to cache them. As a quick fix, prevent scalar replacement if a usage is within a condition.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44223
Reviewed By: gchanan
Differential Revision: D23551247
Pulled By: nickgg
fbshipit-source-id: 17a7bf2be4c8c3dd8a9ab7997dce9aea200c3685
Summary:
Fixes a bug where FP16 values could be incorrectly cast to a half type that doesn't have a cast operator by inserting the cuda specific cast to float during handling of the Cast node, not as a wrapper around printing Loads and Stores. Two main changes: the HalfChecker now inserts the casts to float explicitly in the IR, and the PrioritizeLoad mutator now consumes both Loads and a Cast which immediately preceded a load.
Tested with test_jit_fuser_te.py and test_tensorexpr.py, plus C++ tests obv.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44209
Reviewed By: izdeby
Differential Revision: D23575577
Pulled By: nickgg
fbshipit-source-id: 808605aeb2af812758f96f9fdc11b07e08053b46
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42567
Before this change we didn't expand arguments, and thus in an expr
`sigmoid(sigmoid(x))` only the outer call was expanded.
Test Plan: Imported from OSS
Reviewed By: gmagogsfm
Differential Revision: D22936177
Pulled By: ZolotukhinM
fbshipit-source-id: 9c05dc96561225bab9a90a407d7bcf9a89b078a1
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36611
Currently Buf represents underlying storage but it didn't have dtype.
That resulted in specifying dtypes in different places and there was no
mechanism to enforce its consistency: e.g. one could've created a kFloat
expression and use a kInt buffer to store its result. Now we're
centralizing where the logic regarding the storage is located and we can
start enforcing semantics rules.
Follow-ups: we can merge Buffer and BufHandle classes as the former is
now a mere wrapper over the latter.
Test Plan: Imported from OSS
Differential Revision: D21027356
Pulled By: ZolotukhinM
fbshipit-source-id: c06aa2c4077fdcde3bb4ca622d324aece79b5a9c
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35800
This PR includes the following changes:
* Introduce a new `Expr` type `Buf`: it plays a similar to `Var` role, but also has dimensions.
* Use the new `Buf` class in `Store` and `Load` instead of `Var` for specifying where to store to or load from. `Buf` contains the dimensions info of the buffer we're loading/storing to and hence we are able to keep N-d indexes without flattening them into a 1-d index ([x,y] vs [x+y*W]).
* Flattening of the indexes is now a separate pass that is executed in `LoopNest::prepareForCodegen` - backends still expect indexes to be flattened, and this PR preserves that.
* `Tensor` now contains a `Buf` instead of `Var`, and thus Tensor now has the dimensions info (previously it was a property of a `Function`, not a `Tensor`). This brings us closer to Tensor being a combination of Buffer + Function, where Buffer specifies iteration domain and the Function defines a computation.
TODOs:
* Consider merging `Buffer` with `Buf` or `BufHandle`. It seems that we don't need all of them.
* Harden the logic of how we create buffers in fuser pass. Currently it seems that sometimes we don't set dimensions.
* Use `Buf` in `Allocate` and `Free`.
* Make it clearer that `Function` doesn't "own" dimensions info and that dimensions are a property of a Tensor, not a Function.
Differential Revision: D20789005
Test Plan: Imported from OSS
Reviewed By: zheng-xq
Pulled By: ZolotukhinM
fbshipit-source-id: e04188d1d297f195f1c46669c614557d6bb6cde4
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34842
This PR (hopefully the last one of such kind) is merging changes from a
side branch where tensor expessions based fuser work has been done so
far. This PR is is a squashed version of changes in the side branch,
which is available here: https://github.com/bertmaher/pytorch
Differential Revision: D20478208
Test Plan: Imported from OSS
Pulled By: ZolotukhinM
fbshipit-source-id: 21556e009f1fd88099944732edba72ac40e9b9c0
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34227
This PR adds a CUDA support to tensor expressions.
Differential Revision: D20251836
Test Plan: Imported from OSS
Pulled By: ZolotukhinM
fbshipit-source-id: ab36a55834cceff30c8371fef6cca1054a32f017