Summary:
I found a number of spelling and grammatical mistakes in the repository. I had previously submitted these fixes individually, but a single-word change was apparently too small for a PR to be merged. Hopefully this new PR contains a sufficient number of changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48592
Reviewed By: ejguan
Differential Revision: D25224216
Pulled By: mrshenli
fbshipit-source-id: 2af3db2aee486563efd0dffc4e8f777306a73e44
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48160
We no longer use the custom C++ test infra anyway, so move these tests to pure gtest.
Fixes #45703
ghstack-source-id: 116977283
Test Plan: `buck test //caffe2/test/cpp/tensorexpr`
Reviewed By: navahgar, nickgg
Differential Revision: D25046618
fbshipit-source-id: da34183d87465f410379048148c28e1623618553
Summary:
Fixes an internally reported issue in the tensorexpr fuser when using FP16 on Cuda. The HalfChecker analysis, which determines whether we need to define the Half type, searches the IR for expressions that use Half. If one of the parameters is of type Half but neither it nor any other Half expr is used in the IR, we return a false negative. Fix this by adding the parameter list to the HalfChecker.
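As a rough, self-contained sketch of the idea (the types and function below are simplified stand-ins for illustration, not the actual NNC classes):
```
// Simplified sketch of the check; the real HalfChecker is an IR visitor.
#include <vector>

enum class Dtype { Float, Half };

struct Param { Dtype dtype; };
struct Expr  { Dtype dtype; };

// Returns true if the kernel needs the Half type defined. Scanning only the
// body misses a Half parameter that is never loaded or stored in the IR, so
// the parameter list must be included in the scan.
bool needsHalfSupport(const std::vector<Param>& params,
                      const std::vector<Expr>& bodyExprs) {
  for (const auto& p : params) {
    if (p.dtype == Dtype::Half) {
      return true;
    }
  }
  for (const auto& e : bodyExprs) {
    if (e.dtype == Dtype::Half) {
      return true;
    }
  }
  return false;
}
```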
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48068
Reviewed By: ZolotukhinM
Differential Revision: D25009680
Pulled By: nickgg
fbshipit-source-id: 24fddef06821f130db3d3f45d6d041c7f34a6ab0
Summary:
This is a rewrite of the Registerizer, supporting scalar replacement in *vastly* more situations. As a refresher, the registerizer does this:
Before:
```
A[0] = 0;
for (int x = 0; x < 10; x++) {
A[0] = (A[0]) + x;
}
```
After:
```
int A_ = 0;
for (int x = 0; x < 10; x++) {
A_ = x + A_;
}
A[0] = A_;
```
This can greatly reduce the number of accesses to main memory in a kernel. There are cases where doing this gets complicated, and the existing implementation bails out whenever it encounters multiple partial overlaps of the same buffer, or conditional accesses under any circumstances. This makes it much less useful for complex (i.e. real-world, not example) kernels. This new version should work optimally in almost all cases (I have a few minor follow-ups).
I tested this version extensively and found quite a few bugs in the original implementation that I'd prefer not to backport fixes for, so I'm in favor of landing this even if we don't immediately see a perf win. I believe the killer app for this kind of optimization is fused reductions, and we haven't enabled many examples of that yet.
It is safe to move two accesses of the same Tensor element to a local scalar Var if, between all usages of the element, there are no other Loads or Stores that may refer to it. In the comments I refer to such an interfering access as overlapping the access, or "cutting" the existing AccessInfo. In the case where a candidate for registerization is cut, it may be possible to finalize the access early by writing it back to the Tensor and then creating a new scalar variable after the overlapping access is complete. We will attempt to do this when it saves memory accesses.
There are a few cases that make this more challenging:
- For: Loops multiply the number of real usages of a buffer by the loop extent, but only if we can pull the definition and finalization of the scalar variable out of the loop body. For loops often create accesses which are conditional on a loop var and which overlap large ranges of elements.
E.g. Before:
```
A[0] = 2;
for (int x1 = 0; x1 < 10; x1++) {
A[0] = (A[0]) + x1;
}
for (int x2 = 1; x2 < 10; x2++) {
A[x2] = A[x2 - 1];
}
for (int x3 = 0; x3 < 10; x3++) {
A[0] = (A[0]) + x3;
}
```
After:
```
int A_1 = 2;
for (int x1 = 0; x1 < 10; x1++) {
A_1 = A_1 + x1;
}
A[0] = A_1;
for (int x2 = 1; x2 < 10; x2++) {
A[x2] = A[x2 - 1];
}
int A_2 = A[0];
for (int x3 = 0; x3 < 10; x3++) {
A_2 = A_2 + x3;
}
A[0] = A_2;
```
- Cond: Conditions complicate lifting scalars out of internal scopes. Generally we cannot lift an access outside of a conditional scope unless there is already a reference to that same access at the higher scope, since we don't know if the condition was guarding an array access not safe at the higher scope. In the comments I refer to this as the condition "hiding" the access, and the outer access "unhiding" it.
E.g. this example:
```
if (x<5 ? 1 : 0) {
A[x] = (A[x]) + 1;
}
A[x] = (A[x]) + 1;
if (x>5 ? 1 : 0) {
A[x] = (A[x]) + 1;
}
```
The A[x] access can be registerized due to the unconditional access between the two conditions:
```
int A_1 = A[x];
if (x<5 ? 1 : 0) {
A_1 = A_1 + 1;
}
A_1 = A_1 + 1;
if (x>5 ? 1 : 0) {
A_1 = A_1 + 1;
}
A[x] = A_1;
```
But this example has no accesses that can be registerized:
```
if (x<5 ? 1 : 0) {
A[x] = (A[x]) + 1;
}
if (x>5 ? 1 : 0) {
A[x] = (A[x]) + 1;
}
```
- IfThenElse: Same situation as Cond, except that since IfThenElse is an Expr rather than a Stmt we cannot insert the scalar definition or finalizer within the conditional scope. Accesses inside an IfThenElse can be safely combined with external accesses, but an access that exists completely within the IfThenElse cannot be registerized.
E.g. in this example `B[x]` cannot be registerized, as there is no safe place to define the scalar.
```
A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]);
```
But the equivalent kernel using Cond can be registerized:
```
if (x<3 ? 1 : 0) {
float B_1 = B[x];
A[x] = B_1 + B_1;
} else {
A[x] = B[x];
}
```
- Let: Accesses dependent on local variables via Let Stmts, or loop vars, cannot be raised outside of the scope of the dependent var.
E.g. no accesses in this example can be registerized:
```
for (int x = 0; x < 10; x++) {
int y = 30;
A[y] = x + (A[y]);
}
```
But they can in this example:
```
int y = 30;
for (int x = 0; x < 10; x++) {
A[y] = x + (A[y]);
}
```
**Testing**
The majority of this PR is tests, over 3k lines of them, because there are many different rules to consider and they can interact together more or less arbitrarily. I'd greatly appreciate any ideas for situations we could encounter that are not covered by the tests.
**Performance**
Still working on it, will update. In many FastRRNS sub-kernels this diff reduces the total number of calls to Store or Load by 4x, but since those kernels use Concat very heavily (meaning a lot of branches) the actual number encountered by any particular thread on the GPU is reduced only slightly. Overall perf improved by a very small amount.
Reductions are where this optimization should really shine; in particular, the more complex the kernel gets (with extra fusions, etc.), the better this version of the registerizer should do compared to the existing version.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45574
Reviewed By: albanD
Differential Revision: D24151517
Pulled By: nickgg
fbshipit-source-id: 9f0b2d98cc213eeea3fda16fee3d144d49fd79ae
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45520
With this change `Load`s and `Store`s no longer accept `Placeholder`s in
their constructors and `::make` functions and can only be built with
`Buf`.
`Placeholder` gets its own `store`, `load`, `storeWithMask`, and
`loadWithMask` methods for more convenient construction.
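A rough usage sketch of the new helpers (the method names come from the summary above; the exact constructor and method signatures are approximate, not copied from the PR):
```
// Approximate sketch only; exact signatures may differ.
Placeholder a("A", kFloat, {64});
VarHandle i("i", kInt);

// Build the access through the Placeholder helpers rather than passing the
// Placeholder into Load/Store directly, which now require a Buf.
ExprHandle val = a.load(i);
Stmt* s = a.store({i}, val + 1.f);
```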
Test Plan: Imported from OSS
Reviewed By: glaringlee
Differential Revision: D23998789
Pulled By: ZolotukhinM
fbshipit-source-id: 3fe018e00c1529a563553b2b215f403b34aea912
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45388
Classes defined in these files are closely related, so it is reasonable
to have them all in one file. The change is purely a code move.
Differential Revision: D23952867
Test Plan: Imported from OSS
Reviewed By: nickgg
Pulled By: ZolotukhinM
fbshipit-source-id: 12cfaa968bdfc4dff00509e34310a497c7b59155
Summary:
The Cuda HalfChecker casts up all loads and stores of Half to Float, so we do math in Float on the device. It didn't cast up HalfImmediates (i.e. constants), so they could introduce mixed-size ops. The fix is to cast those up as well.
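Roughly the kind of mixed-width expression this could produce (an illustrative sketch, not actual CudaCodeGen output):
```
out[i] = float(A[i]) + half_const;          // before: the Half immediate is left as-is
out[i] = float(A[i]) + float(half_const);   // after: the immediate is cast up too
```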
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45213
Reviewed By: ezyang
Differential Revision: D23885287
Pulled By: nickgg
fbshipit-source-id: 912991d85cc06ebb282625cfa5080d7525c8eba9
Summary:
A previous fix for masking Cuda dimensions (https://github.com/pytorch/pytorch/issues/44733) changed the behaviour of inserting thread synchronization barriers in the Cuda CodeGen, causing the CudaSharedMemReduce_1 test to become flaky and ultimately be disabled.
The issue is working out where these barriers must be inserted. Solving this optimally is very hard, and I think not possible without dependency analysis we don't have, so I've changed our logic to be quite pessimistic: we'll insert barriers before and after any blocks that have thread dimensions masked (even between blocks that have no data dependencies). This should be correct, but it's an area where we could improve performance. To address this somewhat I've added a simplifier pass that removes obviously unnecessary syncThreads.
To avoid this test being flaky again, I've added a check against the generated code to ensure there is a syncThread in the right place.
Also fixed a couple of non-functional clarity issues in the generated code: added the missing newline after Stores in the CudaPrinter, and prevented the PrioritizeLoad mutator from pulling out loads contained within simple Let statements (such as those produced by the Registerizer).
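For illustration, in the same pseudocode style as the examples below, the pessimistic rule brackets any masked block with barriers even when there is no data dependency, and the new simplifier pass then strips barriers that are obviously redundant (for example, two in a row with nothing between them):
```
do thing(blockIdx.x, threadIdx.x);
syncthreads();
if (threadIdx.x < 1) {
  do other thing(blockIdx.x);
}
syncthreads();
do third thing(blockIdx.x, threadIdx.x);
```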
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44909
Reviewed By: agolynski
Differential Revision: D23800565
Pulled By: nickgg
fbshipit-source-id: bddef1f40d8d461da965685f01d00b468d8a2c2f
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44861
We were redefining things like `ASSERT_EQ` to take a `__VA_ARGS__` parameter, so compiling these files with gtest (instead of pytorch's custom python-based cpp test infra) fails.
Test Plan: buck build //caffe2/test/cpp/tensorexpr
Reviewed By: asuhan
Differential Revision: D23711293
fbshipit-source-id: 8af14fa7c1f1e8169d14bb64515771f7bc3089e5
Summary:
Unifies a number of partial solutions to the thread and block dimension extent masking, including the NoThreadIdxWriter and my last fix https://github.com/pytorch/pytorch/issues/44325. The NoThreadIdxWriter is gone in favour of tracking the current loop extents and masking any statements that have a lower rank than the launch parameters in any Block or Thread dimension, which handles both the "no" and "smaller" axis binding cases.
For example it will transform the following:
```
for i in 0..10 // blockIdx.x
  for j in 0..10 // threadIdx.x
    do thing(i, j);
  for k in 0..5 // threadIdx.x
    do other thing(i, k);
```
Into:
```
do thing(blockIdx.x, threadIdx.x);
if (threadIdx.x < 5) {
  do other thing(blockIdx.x, threadIdx.x);
}
```
It also handles the case where statements are not bound by any axis, e.g.
```
do outer thing;
for i in 0..10 // blockIdx.x
  for j in 0..10 // threadIdx.x
    do thing(i, j);
  do other thing(i);
```
will become:
```
if (blockIdx.x < 1) {
  if (threadIdx.x < 1) {
    do outer thing;
  }
}
syncthreads();
do thing(blockIdx.x, threadIdx.x);
syncthreads();
if (threadIdx.x < 1) {
  do other thing(blockIdx.x);
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44733
Reviewed By: mruberry
Differential Revision: D23736878
Pulled By: nickgg
fbshipit-source-id: 52d08626ae8043d53eb937843466874d479a6768
Summary:
Fix an issue where loops of different sizes are bound to the same Cuda dimension / metavar.
More info and tests coming soon...
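For illustration, in the pseudocode style used in the notes above, the problematic pattern is two loops of different extents bound to the same metavar:
```
for i in 0..10 // threadIdx.x
  do thing(i);
for j in 0..5  // threadIdx.x
  do other thing(j);
```
Without accounting for the differing extents, threads 5..9 would also execute the second loop's body.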
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44325
Reviewed By: colesbury
Differential Revision: D23628859
Pulled By: nickgg
fbshipit-source-id: 3621850a4cc38a790b62ad168d32e7a0e2462fad
Summary:
Fixes a bug in the NNC registerizer for Cuda where it would hoist reads out of a conditional context when trying to cache them. As a quick fix, prevent scalar replacement if a usage is within a condition.
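For illustration (not the exact kernel from the report), a conditional access such as:
```
if (x<5 ? 1 : 0) {
  A[x] = (A[x]) + 1;
}
```
could be cached as a scalar whose initializing load and final store land outside the guard:
```
float A_1 = A[x];
if (x<5 ? 1 : 0) {
  A_1 = A_1 + 1;
}
A[x] = A_1;
```
which is unsafe if the condition was guarding the access itself.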
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44223
Reviewed By: gchanan
Differential Revision: D23551247
Pulled By: nickgg
fbshipit-source-id: 17a7bf2be4c8c3dd8a9ab7997dce9aea200c3685
Summary:
Fixes a bug where FP16 values could be incorrectly cast to a half type that doesn't have a cast operator, by inserting the Cuda-specific cast to float while handling the Cast node rather than as a wrapper around printing Loads and Stores. Two main changes: the HalfChecker now inserts the casts to float explicitly in the IR, and the PrioritizeLoad mutator now consumes both a Load and a Cast which immediately precedes it.
Tested with test_jit_fuser_te.py and test_tensorexpr.py, plus the C++ tests, obviously.
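Roughly what the PrioritizeLoad change means for the emitted code (an illustrative sketch, not actual CudaCodeGen output). Previously only the Load was hoisted, leaving a temporary of the problematic half type:
```
half v = A[i];
float w = float(v);
```
After the change the Cast and the Load it wraps are pulled out together:
```
float w = float(A[i]);
```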
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44209
Reviewed By: izdeby
Differential Revision: D23575577
Pulled By: nickgg
fbshipit-source-id: 808605aeb2af812758f96f9fdc11b07e08053b46
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42567
Before this change we didn't expand arguments, and thus in an expr
`sigmoid(sigmoid(x))` only the outer call was expanded.
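For example, assuming `sigmoid(v)` expands to the usual `1 / (1 + exp(-v))`:
```
// before: only the outer call expanded
sigmoid(sigmoid(x))  ->  1 / (1 + exp(-sigmoid(x)))

// after: arguments expanded as well
sigmoid(sigmoid(x))  ->  1 / (1 + exp(-(1 / (1 + exp(-x)))))
```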
Test Plan: Imported from OSS
Reviewed By: gmagogsfm
Differential Revision: D22936177
Pulled By: ZolotukhinM
fbshipit-source-id: 9c05dc96561225bab9a90a407d7bcf9a89b078a1
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36611
`Buf` represents the underlying storage, but previously it didn't have a
dtype. That resulted in dtypes being specified in different places with no
mechanism to enforce consistency: e.g. one could have created a kFloat
expression and used a kInt buffer to store its result. Now we're
centralizing the logic regarding storage, and we can start enforcing
semantic rules.
Follow-ups: we can merge Buffer and BufHandle classes as the former is
now a mere wrapper over the latter.
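A conceptual sketch of the kind of inconsistency this closes off (the constructor and `::make` signatures below are approximate, not copied from the PR):
```
// Approximate sketch only; exact signatures may differ.
VarHandle i("i", kInt);
ExprHandle val = FloatImm::make(1.0f);      // a kFloat expression

// With dtype attached to Buf, the storage type is known where the Store is
// built, so storing a kFloat value into kInt storage can now be diagnosed.
BufHandle a("A", {ExprHandle(64)}, kInt);   // storage declared as kInt
Stmt* s = Store::make(a, {i}, val);         // kFloat value into a kInt buffer
```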
Test Plan: Imported from OSS
Differential Revision: D21027356
Pulled By: ZolotukhinM
fbshipit-source-id: c06aa2c4077fdcde3bb4ca622d324aece79b5a9c
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35800
This PR includes the following changes:
* Introduce a new `Expr` type `Buf`: it plays a role similar to `Var`, but also has dimensions.
* Use the new `Buf` class in `Store` and `Load` instead of `Var` for specifying where to store to or load from. `Buf` contains the dimensions info of the buffer we're loading/storing to and hence we are able to keep N-d indexes without flattening them into a 1-d index ([x,y] vs [x+y*W]).
* Flattening of the indexes is now a separate pass that is executed in `LoopNest::prepareForCodegen` - backends still expect indexes to be flattened, and this PR preserves that.
* `Tensor` now contains a `Buf` instead of `Var`, and thus Tensor now has the dimensions info (previously it was a property of a `Function`, not a `Tensor`). This brings us closer to Tensor being a combination of Buffer + Function, where Buffer specifies iteration domain and the Function defines a computation.
TODOs:
* Consider merging `Buffer` with `Buf` or `BufHandle`. It seems that we don't need all of them.
* Harden the logic of how we create buffers in the fuser pass. Currently it seems that sometimes we don't set dimensions.
* Use `Buf` in `Allocate` and `Free`.
* Make it clearer that `Function` doesn't "own" dimensions info and that dimensions are a property of a Tensor, not a Function.
Differential Revision: D20789005
Test Plan: Imported from OSS
Reviewed By: zheng-xq
Pulled By: ZolotukhinM
fbshipit-source-id: e04188d1d297f195f1c46669c614557d6bb6cde4
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34842
This PR (hopefully the last one of its kind) merges changes from a
side branch where the tensor-expressions-based fuser work has been done so
far. This PR is a squashed version of the changes in the side branch,
which is available here: https://github.com/bertmaher/pytorch
Differential Revision: D20478208
Test Plan: Imported from OSS
Pulled By: ZolotukhinM
fbshipit-source-id: 21556e009f1fd88099944732edba72ac40e9b9c0
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34227
This PR adds CUDA support to tensor expressions.
Differential Revision: D20251836
Test Plan: Imported from OSS
Pulled By: ZolotukhinM
fbshipit-source-id: ab36a55834cceff30c8371fef6cca1054a32f017