Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Summary:
This is a rewrite of the Registerizer, supporting scalar replacement in *vastly* more situations. As a refresher, the registerizer does this:
Before:
```
A[0] = 0;
for (int x = 0; x < 10; x++) {
  A[0] = (A[0]) + x;
}
```
After:
```
int A_ = 0;
for (int x = 0; x < 10; x++) {
  A_ = x + A_;
}
A[0] = A_;
```
This can greatly reduce the number of accesses to main memory in a kernel. Doing this gets complicated in some cases, and the existing implementation bails out whenever it encounters multiple partial overlaps of the same buffer, or any conditional access at all. This makes it much less useful for complex (i.e. real-world rather than toy) kernels. This new version should work optimally in almost all cases (I have a few minor follow-ups).
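The sketch below makes the saving concrete for the example above. It is a minimal sketch in plain C++ (not tensorexpr IR). `CountingBuffer`, `kernelBefore` and `kernelAfter` are hypothetical names introduced only to count how many loads and stores each version performs.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical instrumented buffer (not part of the tensorexpr API) that
// counts every load and store so the two kernels can be compared.
struct CountingBuffer {
  std::array<int, 1> data{};
  std::size_t loads = 0;
  std::size_t stores = 0;

  int load(std::size_t i) {
    ++loads;
    return data[i];
  }
  void store(std::size_t i, int v) {
    ++stores;
    data[i] = v;
  }
};

// Before registerization: every iteration loads and stores A[0].
void kernelBefore(CountingBuffer& A) {
  A.store(0, 0);
  for (int x = 0; x < 10; x++) {
    A.store(0, A.load(0) + x);
  }
}

// After registerization: the loop updates the local scalar A_, and A[0] is
// written back once at the end.
void kernelAfter(CountingBuffer& A) {
  int A_ = 0;
  for (int x = 0; x < 10; x++) {
    A_ = x + A_;
  }
  A.store(0, A_);
}
```

Both versions leave `A[0] == 45`. The first performs 10 loads and 11 stores, and the second performs a single store.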
I tested this version extensively and found quite a few bugs in the original implementation that I'd prefer not to backport fixes for, so I'm in favor of landing this even if we don't immediately see a perf win. I believe the killer app for this kind of optimization is fused reductions, and we haven't enabled many examples of that yet.
It is safe to move two accesses of the same Tensor element to a local scalar Var if, between all usages of the element, there are no other Loads or Stores that may refer to it. In the comments I refer to this as overlapping the access, or "cutting" the existing AccessInfo. When a candidate for registerization is cut, it may be possible to finalize the access early by writing it back to the Tensor, and then create a new scalar variable after the overlapping access is complete. We will attempt to do this when it saves memory accesses.
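The "cutting" rule can be illustrated with a deliberately simplified model. This is not the registerizer's actual analysis. Each access to one buffer is reduced to an index that is either a known constant or unknown (for example, an index that depends on a variable, which may therefore alias any element). `Access` and `groupSizesFor` are hypothetical names.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical, simplified model of one buffer access: the index is either a
// known constant or std::nullopt, meaning it may refer to any element.
struct Access {
  std::optional<int> index;
};

// Walks the accesses in program order and splits the uses of element `k` into
// groups. An access that may overlap k without provably being k "cuts" the
// candidate: the current group is finalized and a new one starts afterwards.
std::vector<int> groupSizesFor(const std::vector<Access>& accesses, int k) {
  std::vector<int> groups;
  int current = 0;
  for (const Access& a : accesses) {
    if (a.index && *a.index == k) {
      ++current;  // Same element: extend the current group.
    } else if (!a.index) {
      // May alias k: close the current group, if any.
      if (current > 0) groups.push_back(current);
      current = 0;
    }
    // A different known constant index cannot overlap k, so it is ignored.
  }
  if (current > 0) groups.push_back(current);
  return groups;
}
```

For the sequence `A[0], A[0], A[y], A[0], A[0]`, the candidate `A[0]` is cut into two groups of two accesses each. The real analysis additionally has to decide whether finalizing one group and initializing the next actually saves memory accesses.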
There are a few cases that make this more challenging:
- For: Loops change the number of real usages of a buffer by the loop extent, but only if we can pull the definition and finalization of the scalar variable out of the loop block. For loops often create accesses which are conditional on a loop var and will overlap large ranges of elements.
E.g. Before:
```
A[0] = 2;
for (int x1 = 0; x1 < 10; x1++) {
  A[0] = (A[0]) + x1;
}
for (int x2 = 1; x2 < 10; x2++) {
  A[x2] = A[x2 - 1];
}
for (int x3 = 0; x3 < 10; x3++) {
  A[0] = (A[0]) + x3;
}
```
After:
```
int A_1 = 2;
for (int x1 = 0; x1 < 10; x1++) {
  A_1 = A_1 + x1;
}
A[0] = A_1;
for (int x2 = 1; x2 < 10; x2++) {
  A[x2] = A[x2 - 1];
}
int A_2 = A[0];
for (int x3 = 0; x3 < 10; x3++) {
  A_2 = A_2 + x3;
}
A[0] = A_2;
```
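As a quick sanity check, the Before and After kernels can be run side by side. The check confirms that finalizing `A[0]` before the middle loop (which may read `A[0]` through `A[x2 - 1]`) and reloading it into a new scalar afterwards preserves the result. This is plain C++ with a `std::array` standing in for the buffer, an assumption of this sketch. `forBefore` and `forAfter` are hypothetical names.

```cpp
#include <array>
#include <cassert>

using Buf = std::array<int, 10>;

// The "Before" kernel from the For example, written as plain C++.
Buf forBefore(Buf A) {
  A[0] = 2;
  for (int x1 = 0; x1 < 10; x1++) A[0] = A[0] + x1;
  for (int x2 = 1; x2 < 10; x2++) A[x2] = A[x2 - 1];
  for (int x3 = 0; x3 < 10; x3++) A[0] = A[0] + x3;
  return A;
}

// The "After" kernel: A[0] is written back before the middle loop, which may
// read it, and reloaded into a second scalar once that loop is done.
Buf forAfter(Buf A) {
  int A_1 = 2;
  for (int x1 = 0; x1 < 10; x1++) A_1 = A_1 + x1;
  A[0] = A_1;
  for (int x2 = 1; x2 < 10; x2++) A[x2] = A[x2 - 1];
  int A_2 = A[0];
  for (int x3 = 0; x3 < 10; x3++) A_2 = A_2 + x3;
  A[0] = A_2;
  return A;
}
```

Starting from a zeroed buffer, both versions end with `A[0] == 92` and `A[1]` through `A[9]` equal to 47.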
- Cond: Conditions complicate lifting scalars out of internal scopes. Generally we cannot lift an access outside of a conditional scope unless there is already a reference to that same access at the higher scope, since we don't know if the condition was guarding an array access not safe at the higher scope. In the comments I refer to this as the condition "hiding" the access, and the outer access "unhiding" it.
E.g. this example:
```
if (x<5 ? 1 : 0) {
  A[x] = (A[x]) + 1;
}
A[x] = (A[x]) + 1;
if (x>5 ? 1 : 0) {
  A[x] = (A[x]) + 1;
}
```
The A[x] access can be registerized due to the unconditional access between the two conditions:
```
int A_1 = A[x];
if (x<5 ? 1 : 0) {
  A_1 = A_1 + 1;
}
A_1 = A_1 + 1;
if (x>5 ? 1 : 0) {
  A_1 = A_1 + 1;
}
A[x] = A_1;
```
But this example has no accesses that can be registerized:
```
if (x<5 ? 1 : 0) {
  A[x] = (A[x]) + 1;
}
if (x>5 ? 1 : 0) {
  A[x] = (A[x]) + 1;
}
```
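The first Cond example and its registerized form can be checked for equivalence with a small plain-C++ sketch. `condBefore` and `condAfter` are hypothetical names, and a 10-element `std::array` stands in for `A`. The check runs every in-range `x` through both versions.

```cpp
#include <array>
#include <cassert>

using Buf = std::array<int, 10>;

// Original kernel: two conditional updates around an unconditional one.
Buf condBefore(Buf A, int x) {
  if (x < 5) A[x] = A[x] + 1;
  A[x] = A[x] + 1;
  if (x > 5) A[x] = A[x] + 1;
  return A;
}

// Registerized kernel: the unconditional access "unhides" A[x], so the load
// and the final store can be placed in the enclosing block.
Buf condAfter(Buf A, int x) {
  int A_1 = A[x];
  if (x < 5) A_1 = A_1 + 1;
  A_1 = A_1 + 1;
  if (x > 5) A_1 = A_1 + 1;
  A[x] = A_1;
  return A;
}
```

The transformation is only valid because the unconditional update guarantees `A[x]` is accessed on every path. In the second example, hoisting the load out of the branches would introduce an access to `A[x]` on the path where `x == 5`, which the original kernel never performs.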
- IfThenElse: Same situation as Cond, except since IfThenElse is an Expr rather than a Stmt we cannot insert the scalar definition or finalizer within the conditional scope. Accesses inside an IfThenElse can be safely combined with external accesses but cannot exist completely within.
E.g. in this example the `B[x]` cannot be registerized, as there is no safe place to define it.
```
A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]);
```
But the equivalent kernel using Cond can be registerized:
```
if (x<3 ? 1 : 0) {
  float B_1 = B[x];
  A[x] = B_1 + B_1;
} else {
  A[x] = B[x];
}
```
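The following sketch shows the load that the Cond form saves. It is written in plain C++, with the conditional operator standing in for IfThenElse under the assumption that only the selected branch is evaluated. It also uses a hypothetical load counter, and neither piece is tensorexpr code.

```cpp
#include <array>
#include <cassert>

// Hypothetical buffer that counts loads; not part of the tensorexpr API.
struct LoadCounter {
  std::array<int, 5> data{};
  int loads = 0;
  int load(int i) {
    ++loads;
    return data[i];
  }
};

// Expression form: there is no statement position inside the expression where
// a scalar could be declared, so B[x] is loaded twice on the true branch.
int exprForm(LoadCounter& B, int x) {
  return x < 3 ? B.load(x) + B.load(x) : B.load(x);
}

// Statement form: the Cond body is a block, so B[x] can be loaded once into a
// local scalar and reused.
int stmtForm(LoadCounter& B, int x) {
  if (x < 3) {
    int B_1 = B.load(x);
    return B_1 + B_1;
  }
  return B.load(x);
}
```

With `x == 1` both forms compute the same value, but the expression form performs two loads where the statement form performs one.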
- Let: Accesses dependent on local variables via Let Stmts, or loop vars, cannot be raised outside of the scope of the dependent var.
E.g. no accesses in this example can be registerized:
```
for (int x = 0; x < 10; x++) {
  int y = 30;
  A[y] = x + (A[y]);
}
```
But they can in this example:
```
int y = 30;
for (int x = 0; x < 10; x++) {
  A[y] = x + (A[y]);
}
```
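The scoping rule can be sketched as a depth computation. This is a hypothetical simplification, not the registerizer's implementation. Each variable is tagged with the nesting depth of the scope that defines it (0 for the outermost block, 1 for a loop body). An access can then be hoisted no higher than the deepest scope that defines a variable its index depends on. `hoistDepth` is an illustrative name.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <set>
#include <string>

// Hypothetical helper: returns the outermost scope depth an access can be
// hoisted to, given where each variable is defined and which variables the
// access's index depends on. An index with no dependencies returns 0.
int hoistDepth(const std::map<std::string, int>& definedAt,
               const std::set<std::string>& deps) {
  int depth = 0;
  for (const std::string& v : deps) {
    // The access cannot leave the scope that defines any of its variables.
    depth = std::max(depth, definedAt.at(v));
  }
  return depth;
}
```

In the first example, `y` is defined inside the loop body (depth 1), so `A[y]` cannot leave the loop. In the second, `y` is defined at depth 0 and the access can be hoisted around the loop. This sketch only covers scoping; the access must still avoid being cut, as described above.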
**Testing**
The majority of this PR is tests, over 3k lines of them, because there are many different rules to consider and they can interact together more or less arbitrarily. I'd greatly appreciate any ideas for situations we could encounter that are not covered by the tests.
**Performance**
Still working on it, will update. In many FastRNNs sub-kernels this diff reduces the total number of calls to Store or Load by 4x, but since those kernels use Concat very heavily (meaning a lot of branches), the actual number encountered by any particular thread on GPU is reduced only slightly. Overall perf improved by a very small amount.
Reductions are where this optimization should really shine, and in particular, the more complex the kernel gets (with extra fusions, etc.), the better this version of the registerizer should do compared to the existing version.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45574
Reviewed By: albanD
Differential Revision: D24151517
Pulled By: nickgg
fbshipit-source-id: 9f0b2d98cc213eeea3fda16fee3d144d49fd79ae
3814 lines
90 KiB
C++
#include "test/cpp/tensorexpr/test_base.h"

#include "test/cpp/tensorexpr/test_utils.h"
#include "torch/csrc/jit/tensorexpr/ir_simplifier.h"
#include "torch/csrc/jit/tensorexpr/registerizer.h"

#include <iostream>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

// Can replace a simple scalar access with a local variable.
void testRegisterizerSimple() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt =
      Block::make({Store::make(a, {0}, 0, 1),
                   For::make(
                       x,
                       0,
                       10,
                       Block::make({Store::make(
                           a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)}))});

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   A_1 = x + A_1;
   * }
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
        # CHECK: int A_1 = 0;
        # CHECK: for (int x = 0; x < 10; x++)
        # CHECK-NOT: A[
        # CHECK: A_1 =
        # CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Won't do replacement of a loop access.
void testRegisterizerLoop() {
  KernelScope kernel_scope;
  BufHandle a("A", {10}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt =
      Block::make({Store::make(a, {0}, 0, 1),
                   For::make(
                       x,
                       0,
                       10,
                       Block::make({Store::make(
                           a, {x}, Add::make(Load::make(a, {x}, 1), x), 1)}))});

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[x] = (A[x]) + x;
   * }
   */

  // No change.
  stmt = registerize(stmt);

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[x] = (A[x]) + x;
   * }
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
        # CHECK-NOT: int
        # CHECK: A[0] = 0;
        # CHECK: for (int x = 0; x < 10; x++)
        # CHECK-NOT: A_
        # CHECK: A[x] =
        # CHECK-NOT: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Won't replace even if the load is a fixed scalar, since the store could
// invalidate it.
void testRegisterizerLoopFixedLoad() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt =
      Block::make({Store::make(a, {0}, 0, 1),
                   For::make(
                       x,
                       0,
                       10,
                       Block::make({Store::make(
                           a, {x}, Add::make(Load::make(a, {0}, 1), x), 1)}))});

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[x] = (A[0]) + x;
   * }
   */

  // No change.
  stmt = registerize(stmt);

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[x] = (A[0]) + x;
   * }
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
        # CHECK-NOT: int
        # CHECK: A[0] = 0;
        # CHECK: for (int x = 0; x < 10; x++)
        # CHECK-NOT: A_
        # CHECK: A[x] =
        # CHECK-NOT: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// We can registerize accesses that occur entirely within inner scopes, even if
// they depend on the loop var.
void testRegisterizerLoopInternal() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make({For::make(
      x,
      0,
      10,
      Block::make(
          {Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), x), 1),
           Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), x), 1)}))});

  /*
   * for (int x = 0; x < 10; x++) {
   *   A[x] = (A[x]) + x;
   *   A[x] = (A[x]) + x;
   * }
   */

  stmt = registerize(stmt);

  /*
   * for (int x = 0; x < 10; x++) {
   *   int A_1 = A[x];
   *   A_1 = A_1 + x;
   *   A_1 = A_1 + x;
   *   A[x] = A_1;
   * }
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
        # CHECK: for (int x = 0; x < 10; x++)
        # CHECK: int A_1 = A[x];
        # CHECK: A_1 = A_1 + x;
        # CHECK: A_1 = A_1 + x;
        # CHECK: A[x] = A_1;
        # CHECK: })IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// An access can be overlapped by another read in the same Expr. In this case
// B[z] and B[y] overlap and prevent registerization of both accesses.
void testRegisterizerLoopInternalLoadOverlap() {
  KernelScope kernel_scope;
  BufHandle a("A", {10}, kInt);
  BufHandle b("B", {10}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  VarHandle z("z", kInt);
  Stmt* stmt = Block::make({For::make(
      x,
      0,
      10,
      Store::make(
          a,
          {x},
          Add::make(Load::make(b, {y}, 1), Load::make(b, {z}, 1)),
          1))});

  /*
   * for (int x = 0; x < 10; x++) {
   *   A[x] = (B[y]) + (B[z]);
   * }
   */

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}

// Fixed-index accesses in consecutive loops can be hoisted out of both loops,
// with each scalar shared between the two loops.
void testRegisterizerLoopInternalRepeated() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make(
      {For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(Load::make(a, {1}, 1), x), 1),
                Store::make(a, {0}, Add::make(Load::make(a, {1}, 1), x), 1)})),
       For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(Load::make(a, {1}, 1), x), 1),
                Store::make(
                    a, {0}, Add::make(Load::make(a, {1}, 1), x), 1)}))});

  /*
   * for (int x = 0; x < 10; x++) {
   *   A[0] = x + (A[1]);
   *   A[0] = x + (A[1]);
   * }
   * for (int x = 0; x < 10; x++) {
   *   A[0] = x + (A[1]);
   *   A[0] = x + (A[1]);
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[1];
   * int A_2 = A[0];
   * for (int x = 0; x < 10; x++) {
   *   A_2 = x + A_1;
   *   A_2 = x + A_1;
   * }
   * for (int x = 0; x < 10; x++) {
   *   A_2 = x + A_1;
   *   A_2 = x + A_1;
   * }
   * A[0] = A_2;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
        # CHECK: int A_1 = A[1];
        # CHECK: int A_2 = A[0];
        # CHECK: for (int x = 0; x < 10; x++)
        # CHECK: A_2 = x + A_1;
        # CHECK: A_2 = x + A_1;
        # CHECK: }
        # CHECK: for (int x = 0; x < 10; x++)
        # CHECK: A_2 = x + A_1;
        # CHECK: A_2 = x + A_1;
        # CHECK: }
        # CHECK-NOT: A[1]
        # CHECK: A[0] = A_2;
        # CHECK-NOT: A[1]
        # CHECK: })IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Won't registerize A[0], since the loop-var-dependent load A[x] may overlap
// it.
void testRegisterizerLoopInternalRepeatedOverlapLoopVar() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make(
      {For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(Load::make(a, {x}, 1), x), 1),
                Store::make(a, {0}, Add::make(Load::make(a, {x}, 1), x), 1)})),
       For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(Load::make(a, {x}, 1), x), 1),
                Store::make(
                    a, {0}, Add::make(Load::make(a, {x}, 1), x), 1)}))});

  /*
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[x]) + x;
   *   A[0] = (A[x]) + x;
   * }
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[x]) + x;
   *   A[0] = (A[x]) + x;
   * }
   */

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}

// Won't registerize A[0], since the load A[y] at an unknown index may overlap
// it.
void testRegisterizerLoopInternalRepeatedOverlapOther() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  Stmt* stmt = Block::make(
      {For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(x, Load::make(a, {y}, 1)), 1),
                Store::make(a, {0}, Add::make(x, Load::make(a, {y}, 1)), 1)})),
       For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(x, Load::make(a, {y}, 1)), 1),
                Store::make(
                    a, {0}, Add::make(x, Load::make(a, {y}, 1)), 1)}))});

  /*
   * for (int x = 0; x < 10; x++) {
   *   A[0] = x + (A[y]);
   *   A[0] = x + (A[y]);
   * }
   * for (int x = 0; x < 10; x++) {
   *   A[0] = x + (A[y]);
   *   A[0] = x + (A[y]);
   * }
   */

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}

// Will registerize multiple accesses of different items of the same buffer.
void testRegisterizerMultiVar() {
  KernelScope kernel_scope;
  BufHandle a("A", {2}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make({
      Store::make(a, {0}, 0, 1),
      Store::make(a, {1}, 0, 1),
      For::make(
          x,
          0,
          10,
          Block::make(
              {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1),
               Store::make(a, {1}, Sub::make(Load::make(a, {1}, 1), x), 1)})),
  });

  /*
   * A[0] = 0;
   * A[1] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   *   A[1] = (A[1]) - x;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 0;
   * int A_2 = 0;
   * for (int x = 0; x < 10; x++) {
   *   A_2 = x + A_2;
   *   A_1 = A_1 - x;
   * }
   * A[1] = A_2;
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
        # CHECK: int A_1 = 0;
        # CHECK: int A_2 = 0;
        # CHECK: for (int x = 0; x < 10; x++)
        # CHECK-NOT: A[
        # CHECK: A_1 =
        # CHECK: A_2 =
        # CHECK: A[1] = A_2
        # CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Will registerize the valid accesses while skipping invalid replacements.
void testRegisterizerVariableLoad() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  BufHandle b("B", {10}, kInt);
  VarHandle x("x", kInt);
  VarHandle x2("x", kInt);
  Stmt* stmt = Block::make(
      {Store::make(a, {0}, 0, 1),
       For::make(x, 0, 10, Store::make(b, {x}, x, 1)),
       For::make(
           x2,
           0,
           10,
           Block::make({Store::make(
               a,
               {0},
               Add::make(Load::make(a, {0}, 1), Load::make(b, {x2}, 1)),
               1)}))});

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   B[x] = x;
   * }
   * for (int x_1 = 0; x_1 < 10; x_1++) {
   *   A[0] = (A[0]) + (B[x_1]);
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   B[x] = x;
   * }
   * for (int x_1 = 0; x_1 < 10; x_1++) {
   *   A_1 = A_1 + (B[x_1]);
   * }
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
        # CHECK: int A_1 = 0;
        # CHECK: for (int x = 0; x < 10; x++)
        # CHECK: B[x] = x
        # CHECK: for (int x_1 = 0; x_1 < 10; x_1++)
        # CHECK-NOT: A[
        # CHECK: A_1 =
        # CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Can registerize variable accesses so long as the variable does not change.
void testRegisterizerSymbolicIndices() {
  KernelScope kernel_scope;
  VarHandle i("i", kInt);
  VarHandle N("N", kInt);
  BufHandle a("A", {N}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt =
      Block::make({Store::make(a, {i}, 0, 1),
                   For::make(
                       x,
                       0,
                       10,
                       Block::make({Store::make(
                           a, {i}, Add::make(Load::make(a, {i}, 1), x), 1)}))});

  /*
   * A[i] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[i] = (A[i]) + x;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   A_1 = x + A_1;
   * }
   * A[i] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
        # CHECK: int A_1 = 0;
        # CHECK: for (int x = 0; x < 10; x++)
        # CHECK-NOT: A[
        # CHECK: A_1 =
        # CHECK: A[i] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Can registerize accesses dependent on multiple loop vars.
void testRegisterizerMultiLoop() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  Stmt* stmt = Block::make(
      {Store::make(a, {0}, 0, 1),
       For::make(
           x,
           0,
           10,
           For::make(
               y,
               0,
               10,
               Block::make({Store::make(
                   a,
                   {0},
                   Mul::make(Add::make(Load::make(a, {0}, 1), x), y),
                   1)})))});

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   for (int y = 0; y < 10; y++) {
   *     A[0] = x * y + (A[0]) * y;
   *   }
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   for (int y = 0; y < 10; y++) {
   *     A_1 = x * y + y * A_1;
   *   }
   * }
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
        # CHECK: int A_1 = 0;
        # CHECK: for (int x = 0; x < 10; x++)
        # CHECK: for (int y = 0; y < 10; y++)
        # CHECK-NOT: A[
        # CHECK: A_1 =
        # CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Can registerize correctly if scalars already exist in the program.
void testRegisterizerRepeated() {
  KernelScope kernel_scope;
  BufHandle a("A", {2}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make({
      Store::make(a, {0}, 0, 1),
      Store::make(a, {1}, 0, 1),
      For::make(
          x,
          0,
          10,
          Block::make(
              {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1),
               Store::make(a, {1}, Sub::make(Load::make(a, {1}, 1), x), 1)})),
  });

  // Registerize manually to make sure we only replace a single target.
  {
    registerizer::RegisterizerAnalysis analysis;
    stmt->accept(&analysis);
    auto candidates = analysis.getCandidates();
    ASSERT_EQ(candidates.size(), 2);

    candidates.pop_back();
    registerizer::RegisterizerReplacer replacer(candidates);
    stmt = stmt->accept_mutator(&replacer);
  }

  // Re-analyze and replace the second target.
  {
    registerizer::RegisterizerAnalysis analysis;
    stmt->accept(&analysis);
    auto candidates = analysis.getCandidates();
    ASSERT_EQ(candidates.size(), 1);

    registerizer::RegisterizerReplacer replacer(candidates);
    stmt = stmt->accept_mutator(&replacer);
  }

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
        # CHECK: int A_1 = 0;
        # CHECK: int A_1_1 = 0;
        # CHECK: for (int x = 0; x < 10; x++)
        # CHECK-NOT: A[
        # CHECK: A_1 =
        # CHECK: A_1_1 =
        # CHECK: A[1] = A_1_1;
        # CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Can registerize the stores to A even when there are no loads.
void testRegisterizerNoLoads() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make(
      {Store::make(a, {0}, 0, 1),
       For::make(
           x, 0, 10, Block::make({Store::make(a, {0}, Add::make(x, 1), 1)}))});

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = x + 1;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   A_1 = x + 1;
   * }
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
        # CHECK: int A_1 = 0;
        # CHECK: for (int x = 0; x < 10; x++)
        # CHECK-NOT: A[
        # CHECK: A_1 =
        # CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Can registerize the load of A but not the store of B.
void testRegisterizerNoRepeatedStores() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  BufHandle b("B", {10}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt =
      Block::make({Store::make(a, {0}, 0, 1),
                   For::make(
                       x,
                       0,
                       10,
                       Block::make({Store::make(
                           b, {x}, Add::make(Load::make(a, {0}, 1), x), 1)}))});

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   B[x] = (A[0]) + x;
   * }
   */

  stmt = registerize(stmt);

  // TODO: it's unnecessary to reorder the initializer of A[0], but it's not
  // actually worse so let's not worry for now.

  /*
   * int A_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   B[x] = x + A_1;
   * }
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
        # CHECK: int A_1 = 0;
        # CHECK: for (int x = 0; x < 10; x++)
        # CHECK-NOT: A_
        # CHECK: B[x] =
        # CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Won't registerize if there are multiple accesses which may overlap.
void testRegisterizerMultiVarOverlap() {
  KernelScope kernel_scope;
  BufHandle a("A", {2}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make({
      Store::make(a, {0}, 0, 1),
      Store::make(a, {1}, 0, 1),
      For::make(
          x,
          0,
          10,
          Block::make(
              {Store::make(a, {x}, Add::make(Load::make(a, {0}, 1), x), 1),
               Store::make(
                   a, {x + 1}, Sub::make(Load::make(a, {1}, 1), x), 1)})),
  });

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}

// Can registerize accesses around Allocate and Free, including the load of
// C[0] used as the allocation size.
void testRegisterizerAllocs() {
  KernelScope kernel_scope;

  BufHandle a("A", {2}, kInt);
  BufHandle b("B", {1}, kInt);
  BufHandle c("C", {1}, kInt);
  VarHandle x("x", kInt);

  VarHandle b_(b.node()->base_handle());

  Stmt* stmt = Block::make(
      {Allocate::make(b_, kInt, {Load::make(c, {0}, 1)}),
       Store::make(a, {0}, Load::make(c, {0}, 1), 1),
       Store::make(b, {0}, 0, 1),
       For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(b, {0}, Add::make(Load::make(b, {0}, 1), x), 1),
                Store::make(a, {0}, Load::make(c, {0}, 1), 1)})),
       Free::make(b_)});

  /*
   * Allocate(B, int, {C[0]});
   * A[0] = C[0];
   * B[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   B[0] = (B[0]) + x;
   *   A[0] = C[0];
   * }
   * Free(B);
   */

  stmt = registerize(stmt);

  /*
   * int C_1 = C[0];
   * Allocate(B, int, {C_1});
   * int A_1 = C_1;
   * int B_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   B_1 = B_1 + x;
   *   A_1 = C_1;
   * }
   * B[0] = B_1;
   * A[0] = A_1;
   * Free(B);
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
        # CHECK: int C_1 = C[0];
        # CHECK: Allocate(B
        # CHECK: int A_1 = C_1;
        # CHECK: int B_1 = 0;
        # CHECK: for (int x = 0; x < 10; x++)
        # CHECK: B_1 =
        # CHECK: A_1 = C_
        # CHECK: B[0] = B_1;
        # CHECK: A[0] = A_1;
        # CHECK: Free(B)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Without an initializing store before the loop, the scalar is initialized by
// loading from the buffer.
void testRegisterizerNoInitializer() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make({For::make(
      x,
      0,
      10,
      Block::make(
          {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)}))});

  /*
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[0];
   * for (int x = 0; x < 10; x++) {
   *   A_1 = x + A_1;
   * }
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
        # CHECK: int A_1 = A[0];
        # CHECK: for (int x = 0; x < 10; x++)
        # CHECK-NOT: A[
        # CHECK: A_1 =
        # CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Won't registerize: the access depends on the loop var, so it cannot be
// hoisted out of the loop, and each iteration only loads and stores A[x] once.
void testRegisterizerNoInitializerLoopVar() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make({For::make(
      x,
      0,
      10,
      Block::make(
          {Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), x), 1)}))});

  /*
   * for (int x = 0; x < 10; x++) {
   *   A[x] = (A[x]) + x;
   * }
   */

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}

// Can registerize two buffers whose loads and stores are interleaved inside
// the loop.
void testRegisterizerLoadThenStore() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  BufHandle b("B", {1}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make({For::make(
      x,
      0,
      10,
      Block::make({Store::make(b, {0}, Add::make(Load::make(a, {0}, 1), x), 1),
                   Store::make(a, {0}, Load::make(b, {0}, 1), 1)}))});

  /*
   * for (int x = 0; x < 10; x++) {
   *   B[0] = (A[0]) + x;
   *   A[0] = B[0];
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[0];
   * int B_1 = B[0];
   * for (int x = 0; x < 10; x++) {
   *   B_1 = x + A_1;
   *   A_1 = B_1;
   * }
   * B[0] = B_1;
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
        # CHECK: int A_1 = A[0];
        # CHECK: int B_1 = B[0];
        # CHECK: for (int x = 0; x < 10; x++)
        # CHECK-NOT: B[
        # CHECK: B_1 =
        # CHECK-NOT: A[
        # CHECK: A_1 = B_
        # CHECK: B[0] = B_
        # CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Registerization must run after parallel loop axes are flattened, so a loop
// bound to a GPU block index is rejected.
void testRegisterizerParallelized() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  LoopOptions loopOpts;
  loopOpts.set_gpu_block_index(0);
  Stmt* stmt =
      Block::make({Store::make(a, {0}, 0, 1),
                   For::make(
                       x,
                       0,
                       10,
                       Block::make({Store::make(
                           a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)}),
                       loopOpts)});

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   */

  ASSERT_THROWS_WITH(
      registerize(stmt),
      "Registerization must occur after parallelism flattening");
}

// Should be able to registerize this since the scalar would exist before the
// branch.
void testRegisterizerConditionAfter() {
  KernelScope kernel_scope;
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  Stmt* stmt = Block::make(
      {Store::make(a, {x}, Load::make(b, {x}, 1), 1),
       Store::make(c, {x}, Load::make(a, {x}, 1), 1),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1),
           nullptr)});

  /*
   * A[x] = B[x];
   * C[x] = A[x];
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = B[x];
   * C[x] = A_1;
   * if (x<5 ? 1 : 0) {
   *   A_1 = A_1 + 1;
   * }
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
        # CHECK: int A_1 = B[x];
        # CHECK: C[x] = A_1;
        # CHECK: if (
        # CHECK: A_1 = A_1 + 1;
        # CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Should be able to registerize this since the scalar exists in the same form
// after the branch and there is no overlap.
void testRegisterizerConditionBefore() {
  KernelScope kernel_scope;
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  Stmt* stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1),
           nullptr),
       Store::make(a, {x}, Load::make(b, {x}, 1), 1),
       Store::make(c, {x}, Load::make(a, {x}, 1), 1)});

  /*
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   * A[x] = B[x];
   * C[x] = A[x];
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * if (x<5 ? 1 : 0) {
   *   A_1 = A_1 + 1;
   * }
   * A_1 = B[x];
   * C[x] = A_1;
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
        # CHECK: int A_1 = A[x];
        # CHECK: if (
        # CHECK: A_1 = A_1 + 1;
        # CHECK: }
        # CHECK: A_1 = B[x];
        # CHECK: C[x] = A_1;
        # CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Should be able to registerize this as the combination of the two above rules.
|
|
void testRegisterizerConditionInside() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
|
|
Stmt* stmt = Block::make(
|
|
{Store::make(a, {x}, Load::make(b, {x}, 1), 1),
|
|
Store::make(c, {x}, Load::make(a, {x}, 1), 1),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1),
|
|
nullptr),
|
|
Store::make(b, {x}, Load::make(a, {x}, 1), 1),
|
|
Store::make(a, {x}, Load::make(c, {x}, 1), 1)});
|
|
|
|
/*
|
|
* A[x] = B[x];
|
|
* C[x] = A[x];
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* }
|
|
* B[x] = A[x];
|
|
* A[x] = C[x];
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = B[x];
|
|
* C[x] = A_1;
|
|
* if (x<5 ? 1 : 0) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* B[x] = A_1;
|
|
* A_1 = C[x];
|
|
* A[x] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = B[x];
|
|
# CHECK: C[x] = A_1;
|
|
# CHECK: if (
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: B[x] = A_1;
|
|
# CHECK: A_1 = C[x];
|
|
# CHECK: A[x] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}

// An example where an access is cut by an overlapping access inside a
// condition, and both sides are large enough to be registerized but cannot be
// because there is no safe place to put the initializer or finalizer.
void testRegisterizerConditionInsideOverlap1() {
  KernelScope kernel_scope;
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  Stmt* stmt = Block::make(
      {Store::make(a, {x}, Load::make(b, {x}, 1), 1),
       Store::make(c, {x}, Load::make(a, {x}, 1), 1),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({
               Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1),
               Store::make(a, {0}, 3, 1),
               Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1),
           }),
           nullptr),
       Store::make(b, {x}, Load::make(a, {x}, 1), 1),
       Store::make(a, {x}, Load::make(c, {x}, 1), 1)});

  /*
   * A[x] = B[x];
   * C[x] = A[x];
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   *   A[0] = 3;
   *   A[x] = (A[x]) + 1;
   * }
   * B[x] = A[x];
   * A[x] = C[x];
   */

  // The A[0] store overlaps A[x], cutting the region that can be registerized
  // into two groups.
  // Each group has 2 loads and 2 stores, so we could registerize it, but the
  // first group would need to be finalized inside the condition block and the
  // second would need to be initialized inside the condition block. There's no
  // safe place to put these that's visible to the other uses in the group, so
  // neither registerization is possible.

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}

// Same as the above, but the access groups before and after the condition are
// large enough to be registerized without needing the accesses inside the
// condition. Registerization occurs but does not include any accesses in the
// condition; the first group must be finalized before the Cond and the second
// initialized after it.
void testRegisterizerConditionInsideOverlap2() {
  KernelScope kernel_scope;
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  Stmt* stmt = Block::make(
      {Store::make(a, {x}, Load::make(b, {x}, 1), 1),
       Store::make(a, {x}, Load::make(b, {x + 1}, 1), 1),
       Store::make(c, {x}, Load::make(a, {x}, 1), 1),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({
               Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1),
               Store::make(a, {0}, 3, 1),
               Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1),
           }),
           nullptr),
       Store::make(b, {x}, Load::make(a, {x}, 1), 1),
       Store::make(b, {x + 1}, Load::make(a, {x}, 1), 1),
       Store::make(a, {x}, Load::make(c, {x}, 1), 1)});

  /*
   * A[x] = B[x];
   * A[x] = B[x + 1];
   * C[x] = A[x];
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   *   A[0] = 3;
   *   A[x] = (A[x]) + 1;
   * }
   * B[x] = A[x];
   * B[x + 1] = A[x];
   * A[x] = C[x];
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = B[x];   // A_1 initializer
   * A_1 = B[x + 1];   //
   * C[x] = A_1;       //
   * A[x] = A_1;       // A_1 finalizer
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   *   A[0] = 3;
   *   A[x] = (A[x]) + 1;
   * }
   * int A_2 = A[x];   // A_2 initializer
   * B[x] = A_2;       //
   * B[x + 1] = A_2;   //
   * A_2 = C[x];       //
   * A[x] = A_2;       // A_2 finalizer
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = B[x];
# CHECK: A_1 = B[x + 1];
# CHECK: C[x] = A_1;
# CHECK: A[x] = A_1;
# CHECK: if (
# CHECK-NOT: A_1 = A_1 + 1;
# CHECK: A[x] = (A[x]
# CHECK: A[0] =
# CHECK: A[x] = (A[x]
# CHECK: }
# CHECK: int A_2 = A[x];
# CHECK: B[x] = A_2;
# CHECK: B[x + 1] = A_2;
# CHECK: A_2 = C[x];
# CHECK: A[x] = A_2;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// When accesses are within conditional blocks they are not visible to the wider
// program, because we don't know whether the branch will be taken, and if it
// isn't, the accesses in it don't need to be valid (think size checks on the
// index). In this case the accesses cannot be registerized.
void testRegisterizerConditionHidden() {
  KernelScope kernel_scope;
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  Stmt* stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1),
           nullptr),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kGT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1),
           nullptr)});

  /*
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   * if (x>5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   */

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}
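The hiding rule above can be illustrated outside the IR with plain C++. This is a minimal sketch, assuming a 5-element `std::vector<int>` stands in for buffer `A` and using bounds-checked `std::vector::at` to model an invalid access; `guarded_original`, `guarded_hoisted` and `throws_out_of_range` are hypothetical helpers, not part of the tensorexpr API. Hoisting the load and store out of the guard turns a harmless out-of-range `x` into an out-of-bounds access.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Original form: A[x] is only touched when the guard x < 5 holds.
void guarded_original(std::vector<int>& A, int x) {
  if (x < 5) {
    A.at(x) = A.at(x) + 1;
  }
}

// Hypothetical unsafe registerization: the initializer and finalizer are
// hoisted out of the condition, so A[x] is loaded and stored even when x >= 5.
void guarded_hoisted(std::vector<int>& A, int x) {
  int A_1 = A.at(x);  // unconditional load, may be out of bounds
  if (x < 5) {
    A_1 = A_1 + 1;
  }
  A.at(x) = A_1;  // unconditional store
}

// Returns true if f(A, x) throws std::out_of_range. A is taken by value so the
// caller's buffer is not modified.
bool throws_out_of_range(
    void (*f)(std::vector<int>&, int),
    std::vector<int> A,
    int x) {
  try {
    f(A, x);
  } catch (const std::out_of_range&) {
    return true;
  }
  return false;
}
```

For in-range indices the two functions agree. That is why an unconditional access to the same element elsewhere (the "unhiding" case) lets the pass treat the index as valid and hoist the access.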

// But if the same access is found in a non-conditional scope, that means the
// access is valid in the higher scope (or at least, if it's not, that's the
// user's fault). It "unhides" the conditional accesses, allowing
// registerization to occur.
void testRegisterizerConditionUnhidden() {
  KernelScope kernel_scope;
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  Stmt* stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1),
           nullptr),
       Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kGT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1),
           nullptr)});

  /*
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   * A[x] = (A[x]) + 1;            <-- this is doing the unhiding.
   * if (x>5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * if (x<5 ? 1 : 0) {
   *   A_1 = A_1 + 1;
   * }
   * A_1 = A_1 + 1;
   * if (x>5 ? 1 : 0) {
   *   A_1 = A_1 + 1;
   * }
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: if (x<5
# CHECK: A_1 = A_1 + 1;
# CHECK: }
# CHECK: A_1 = A_1 + 1;
# CHECK: if (x>5
# CHECK: A_1 = A_1 + 1;
# CHECK: }
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Can registerize a load that occurs in the condition of a Cond.
void testRegisterizerCondCondition() {
  KernelScope kernel_scope;
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  Stmt* stmt = Block::make(
      {Store::make(a, {x}, Load::make(b, {x}, 1), 1),
       Store::make(c, {x}, Load::make(a, {x}, 1), 1),
       Cond::make(
           CompareSelect::make(
               Load::make(a, {x}, 1), 5, CompareSelectOperation::kLT),
           Store::make(c, {x}, Add::make(Load::make(c, {x}, 1), 1), 1),
           nullptr)});

  /*
   * A[x] = B[x];
   * C[x] = A[x];
   * if ((A[x])<5 ? 1 : 0) {
   *   C[x] = (C[x]) + 1;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = B[x];
   * int C_1 = A_1;
   * if (A_1<5 ? 1 : 0) {
   *   C_1 = C_1 + 1;
   * }
   * C[x] = C_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = B[x];
# CHECK: int C_1 = A_1;
# CHECK: if (A_1<5
# CHECK: C_1 = C_1 + 1;
# CHECK: C[x] = C_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Appearing in the condition of a Cond makes it visible to the enclosing scope,
// and so we can registerize internal usages.
void testRegisterizerCondConditionUnhidden() {
  KernelScope kernel_scope;
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  Stmt* stmt = Block::make({Cond::make(
      CompareSelect::make(
          Load::make(a, {x}, 1), 5, CompareSelectOperation::kLT),
      Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1),
      Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 10), 1))});

  /*
   * if ((A[x])<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * } else {
   *   A[x] = (A[x]) + 10;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * if (A_1<5 ? 1 : 0) {
   *   A_1 = A_1 + 1;
   * } else {
   *   A_1 = A_1 + 10;
   * }
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: if (A_1<5
# CHECK: A_1 = A_1 + 1;
# CHECK: } else {
# CHECK: A_1 = A_1 + 10;
# CHECK: }
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Conditional hiding also works for IfThenElse exprs.
void testRegisterizerIfThenElseHidden() {
  KernelScope kernel_scope;
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  Stmt* stmt = Block::make(
      {Store::make(
           b,
           {y},
           IfThenElse::make(
               CompareSelect::make(x, 5, CompareSelectOperation::kLT),
               Add::make(Load::make(a, {x}, 1), 1),
               Add::make(Load::make(a, {x + 1}, 1), 2)),
           1),
       Store::make(
           b,
           {y + 1},
           IfThenElse::make(
               CompareSelect::make(x, 5, CompareSelectOperation::kLT),
               Add::make(Load::make(a, {x}, 1), 1),
               Add::make(Load::make(a, {x + 1}, 1), 2)),
           1)});

  /*
   * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
   * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
   */

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}

// Conditional unhiding also works for IfThenElse exprs.
void testRegisterizerIfThenElseUnhidden() {
  KernelScope kernel_scope;
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  Stmt* stmt = Block::make({
      Store::make(a, {x}, 0, 1),
      Store::make(
          b,
          {y},
          IfThenElse::make(
              CompareSelect::make(x, 5, CompareSelectOperation::kLT),
              Add::make(Load::make(a, {x}, 1), 1),
              Add::make(Load::make(a, {x + 1}, 1), 2)),
          1),
      Store::make(
          b,
          {y + 1},
          IfThenElse::make(
              CompareSelect::make(x, 5, CompareSelectOperation::kLT),
              Add::make(Load::make(a, {x}, 1), 1),
              Add::make(Load::make(a, {x + 1}, 1), 2)),
          1),
  });

  /*
   * A[x] = 0;
   * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
   * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 0;
   * B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
   * B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
# CHECK: B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Nested IfThenElse exprs can't promote to higher level scopes.
void testRegisterizerIfThenElseNested() {
  KernelScope kernel_scope;
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  BufHandle d("D", {5}, kInt);
  VarHandle x("x", kInt);

  Stmt* stmt = Block::make({Store::make(
      a,
      {x},
      IfThenElse::make(
          CompareSelect::make(x, 3, CompareSelectOperation::kLT),
          IfThenElse::make(
              CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
              Load::make(d, {x}, 1),
              Load::make(b, {x}, 1)),
          IfThenElse::make(
              CompareSelect::make(x, 5, CompareSelectOperation::kEQ),
              Load::make(c, {x}, 1),
              Load::make(d, {x}, 1))),
      1)});

  /*
   * A[x] = IfThenElse(x<3 ? 1 : 0,
   *          IfThenElse(x==2 ? 1 : 0, D[x], B[x]),
   *          IfThenElse(x==5 ? 1 : 0, C[x], D[x]));
   */

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}

// Cannot registerize an access completely contained within an IfThenElse
// branch, since it is not a Stmt and cannot hold variable definitions. We need
// to check that we don't promote the initializer/finalizer to the enclosing
// Block.
void testRegisterizerIfThenElseInternal() {
  KernelScope kernel_scope;
  // Making these floats so they don't get simplified to a single access.
  BufHandle a("A", {5}, kFloat);
  BufHandle b("B", {5}, kFloat);
  VarHandle x("x", kInt);

  Stmt* stmt = Block::make({Store::make(
      a,
      {x},
      IfThenElse::make(
          CompareSelect::make(x, 3, CompareSelectOperation::kLT),
          Add::make(Load::make(b, {x}, 1), Load::make(b, {x}, 1)),
          Load::make(b, {x}, 1)),
      1)});

  /*
   * A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]);
   */

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());

  // If this were a Cond instead of an IfThenElse then we could registerize the
  // two accesses to B[x] in the True branch.

  // Let's verify that.

  stmt = Block::make({Cond::make(
      CompareSelect::make(x, 3, CompareSelectOperation::kLT),
      Store::make(
          a, {x}, Add::make(Load::make(b, {x}, 1), Load::make(b, {x}, 1)), 1),
      Store::make(a, {x}, Load::make(b, {x}, 1), 1))});

  /*
   * if (x<3 ? 1 : 0) {
   *   A[x] = (B[x]) + (B[x]);
   * } else {
   *   A[x] = B[x];
   * }
   */

  stmt = registerize(stmt);

  /*
   * if (x<3 ? 1 : 0) {
   *   float B_1 = B[x];
   *   A[x] = B_1 + B_1;
   * } else {
   *   A[x] = B[x];
   * }
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK-NOT: int
# CHECK-NOT: float
# CHECK: if (x<3
# CHECK: float B_1 =
# CHECK: A[x] = B_1 + B_1
# CHECK: } else {
# CHECK: A[x] = B[x]
# CHECK: }
# CHECK-NOT: A[x]
# CHECK-NOT: B[x])IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Can registerize a load that occurs in the condition of an IfThenElse.
void testRegisterizerIfThenElseCondition() {
  KernelScope kernel_scope;
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  Stmt* stmt = Block::make(
      {Store::make(a, {x}, Load::make(a, {x}, 1), 1),
       Store::make(
           a,
           {x},
           IfThenElse::make(
               CompareSelect::make(
                   Load::make(a, {x}, 1), 5, CompareSelectOperation::kLT),
               Load::make(b, {0}, 1),
               Load::make(c, {0}, 1)),
           1)});

  /*
   * A[x] = A[x];       <---- just here so there are enough accesses to combine.
   * A[x] = IfThenElse((A[x])<5 ? 1 : 0, B[0], C[0]);
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * A_1 = A_1;
   * A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Appearing in the condition of an IfThenElse makes it visible to the enclosing
// scope, and so we can registerize internal usages.
void testRegisterizerIfThenElseConditionUnhidden() {
  KernelScope kernel_scope;
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  Stmt* stmt = Block::make({Store::make(
      b,
      {x},
      IfThenElse::make(
          CompareSelect::make(
              Load::make(a, {x}, 1), 5, CompareSelectOperation::kLT),
          Add::make(Load::make(a, {x}, 1), 1),
          Add::make(Load::make(a, {x}, 1), 10)),
      1)});

  /*
   * B[x] = IfThenElse((A[x])<5 ? 1 : 0, (A[x]) + 1, (A[x]) + 10);
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Cannot promote accesses internal to IfThenElse branches even if the enclosing
// scope is conditional.
void testRegisterizerConditionBranchOnly() {
  KernelScope kernel_scope;
  BufHandle a("A", {5}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make({For::make(
      x,
      0,
      10,
      Block::make({
          Cond::make(
              CompareSelect::make(x, 5, CompareSelectOperation::kLT),
              Store::make(
                  a,
                  {x},
                  IfThenElse::make(
                      CompareSelect::make(x, 5, CompareSelectOperation::kLT),
                      Add::make(Load::make(a, {x}, 1), x),
                      Add::make(Load::make(a, {x - 5}, 1), x)),
                  1),
              Store::make(
                  a,
                  {x - 5},
                  IfThenElse::make(
                      CompareSelect::make(x, 5, CompareSelectOperation::kLT),
                      Add::make(Load::make(a, {x}, 1), x),
                      Add::make(Load::make(a, {x - 5}, 1), x)),
                  1)),
      }))});

  std::ostringstream before;
  before << *stmt;

  /* for (int x = 0; x < 10; x++) {
   *   if (x<5 ? 1 : 0) {
   *     A[x] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
   *   } else {
   *     A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
   *   }
   * }
   */

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}

// We can registerize an IfThenElse that appears in the condition of a Cond.
// This is a weird but valid thing to do.
void testRegisterizerCondIfThenElse() {
  KernelScope kernel_scope;
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  Stmt* stmt = Block::make({Cond::make(
      CompareSelect::make(
          IfThenElse::make(
              CompareSelect::make(
                  Load::make(a, {x}, 1), 5, CompareSelectOperation::kLT),
              Load::make(a, {x}, 1),
              Load::make(b, {x}, 1)),
          x,
          CompareSelectOperation::kEQ),
      Store::make(c, {x}, Add::make(Load::make(c, {x}, 1), 1), 1),
      nullptr)});

  /*
   * if ((IfThenElse((A[x])<5 ? 1 : 0, A[x], B[x]))==x ? 1 : 0) {
   *   C[x] = (C[x]) + 1;
   * }
   */

  stmt = registerize(stmt);

  // Access to A can be registerized, but not B or C.

  /*
   * int A_1 = A[x];
   * if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]))==x ? 1 : 0) {
   *   C[x] = (C[x]) + 1;
   * }
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]
# CHECK: C[x] = (C[x]) + 1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Can registerize a conditional access in the RHS of a store unhidden by its
// LHS, and hoist it out of a loop.
void testRegisterizerIfThenElseLoop() {
  KernelScope kernel_scope;
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  Stmt* stmt = For::make(
      y,
      0,
      10,
      Store::make(
          a,
          {x},
          IfThenElse::make(
              CompareSelect::make(x, 3, CompareSelectOperation::kLT),
              Load::make(a, {x}, 1),
              Load::make(b, {y}, 1)),
          1));

  /*
   * for (int y = 0; y < 10; y++) {
   *   A[x] = IfThenElse(x<3 ? 1 : 0, A[x], B[y]);
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * for (int y = 0; y < 10; y++) {
   *   A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]);
   * }
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: for (
# CHECK: A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]);
# CHECK: }
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
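As a sanity check of the loop hoist, the sketch below writes the before and after forms as ordinary C++. It models `A` and `B` as plain `std::array<int, 10>` buffers, a size chosen so every index the loop touches is in bounds. `loop_original` and `loop_registerized` are hypothetical names. Since `B` is assumed to be a distinct buffer, nothing inside the loop can alias `A[x]`, so the two functions should leave the same contents in `A`.

```cpp
#include <array>
#include <cassert>

using Buf = std::array<int, 10>;

// Original loop: A[x] is loaded and stored on every iteration of y.
void loop_original(Buf& A, const Buf& B, int x) {
  for (int y = 0; y < 10; y++) {
    A[x] = (x < 3) ? A[x] : B[y];
  }
}

// Registerized loop: one load of A[x] before the loop, one store after it.
// Valid here because B is a separate buffer and cannot alias A[x].
void loop_registerized(Buf& A, const Buf& B, int x) {
  int A_1 = A[x];
  for (int y = 0; y < 10; y++) {
    A_1 = (x < 3) ? A_1 : B[y];
  }
  A[x] = A_1;
}
```

The transformed version performs 2 accesses to `A` instead of 20, which is the saving the registerizer is after.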

// Cannot registerize if the RHS overlaps the access creating visibility.
void testRegisterizerIfThenElseLoopCut() {
  KernelScope kernel_scope;
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  Stmt* stmt = Block::make({For::make(
      y,
      0,
      10,
      Store::make(
          a,
          {x},
          IfThenElse::make(
              CompareSelect::make(x, 3, CompareSelectOperation::kLT),
              Load::make(a, {x}, 1),
              Load::make(a, {y}, 1)),
          1))});

  /*
   * for (int y = 0; y < 10; y++) {
   *   A[x] = IfThenElse(x<3 ? 1 : 0, A[x], A[y]);
   * }
   */

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}
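To see why the overlap blocks registerization here, the sketch below compares the original loop with the scalar-replaced version the pass correctly refuses to emit. It is plain C++ and assumes `A` is a `std::array<int, 10>` initialized to `0..9`; `cut_original` and `cut_unsafe` are hypothetical helpers. With `x = 9`, the last iteration loads `A[y]` with `y == x`, and the two versions diverge.

```cpp
#include <array>
#include <cassert>

using Buf = std::array<int, 10>;

// Original loop: when y == x the load A[y] reads the value just written to A[x].
void cut_original(Buf& A, int x) {
  for (int y = 0; y < 10; y++) {
    A[x] = (x < 3) ? A[x] : A[y];
  }
}

// The transformation the pass must NOT perform: A[x] lives in A_1, but the
// load of A[y] still reads memory, which is stale when y == x.
void cut_unsafe(Buf& A, int x) {
  int A_1 = A[x];
  for (int y = 0; y < 10; y++) {
    A_1 = (x < 3) ? A_1 : A[y];
  }
  A[x] = A_1;
}
```

Starting from `A = {0, 1, ..., 9}`, `cut_original(A, 9)` leaves `A[9] == 8`, while `cut_unsafe(A, 9)` leaves `A[9] == 9`.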

// Simple case where an access is cut by an overlapping access later in the
// program; we can registerize up until the overlap.
void testRegisterizerPartialAfter() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make(
      {Store::make(a, {0}, 0, 1),
       For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)})),
       For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}, 1), 1))});

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   * for (int x = 1; x < 10; x++) {
   *   A[x] = A[x - 1];
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   A_1 = A_1 + x;
   * }
   * A[0] = A_1;
   * for (int x = 1; x < 10; x++) {
   *   A[x] = A[x - 1];
   * }
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: for (
# CHECK: A_1 = A_1 + x;
# CHECK: }
# CHECK: A[0] = A_1;
# CHECK: for (
# CHECK: A[x] = A[x - 1];
# CHECK: }
# CHECK-NOT: A)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// We can registerize an access which overlaps a previous access; the
// initializer must be inserted after the previous access.
void testRegisterizerPartialBefore() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make(
      {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}, 1), 1)),
       Store::make(a, {0}, 0, 1),
       For::make(
           x,
           0,
           10,
           Block::make({Store::make(
               a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)}))});

  /*
   * for (int x = 1; x < 10; x++) {
   *   A[x] = A[x - 1];
   * }
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   */

  stmt = registerize(stmt);

  /*
   * for (int x = 1; x < 10; x++) {
   *   A[x] = A[x - 1];
   * }
   * int A_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   A_1 = A_1 + x;
   * }
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK-NOT: int
# CHECK: for (
# CHECK: A[x] = A[x - 1];
# CHECK: }
# CHECK: int A_1 = 0;
# CHECK: for (
# CHECK: A_1 = A_1 + x;
# CHECK: }
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// The combination of the previous two tests: an access is cut by an
// overlapping access in both directions.
void testRegisterizerPartialInside() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x1("x1", kInt);
  VarHandle x2("x2", kInt);
  VarHandle x3("x3", kInt);
  Stmt* stmt = Block::make(
      {Store::make(a, {0}, 2, 1),
       For::make(
           x1,
           0,
           10,
           Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x1), 1)),
       For::make(
           x2, 1, 10, Store::make(a, {x2}, Load::make(a, {x2 - 1}, 1), 1)),
       For::make(
           x3,
           0,
           10,
           Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x3), 1))});

  /*
   * A[0] = 2;
   * for (int x1 = 0; x1 < 10; x1++) {
   *   A[0] = (A[0]) + x1;
   * }
   * for (int x2 = 1; x2 < 10; x2++) {
   *   A[x2] = A[x2 - 1];
   * }
   * for (int x3 = 0; x3 < 10; x3++) {
   *   A[0] = (A[0]) + x3;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 2;
   * for (int x1 = 0; x1 < 10; x1++) {
   *   A_1 = A_1 + x1;
   * }
   * A[0] = A_1;
   * for (int x2 = 1; x2 < 10; x2++) {
   *   A[x2] = A[x2 - 1];
   * }
   * int A_2 = A[0];
   * for (int x3 = 0; x3 < 10; x3++) {
   *   A_2 = A_2 + x3;
   * }
   * A[0] = A_2;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 2;
# CHECK: for (
# CHECK: A_1 = A_1 + x1;
# CHECK: }
# CHECK: A[0] = A_1;
# CHECK: for (
# CHECK: A[x2] =
# CHECK: }
# CHECK: int A_2 = A[0];
# CHECK: for (
# CHECK: A_2 = A_2 + x3;
# CHECK: }
# CHECK: A[0] = A_2;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
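The reason the access must be split into two scalars can be checked with an ordinary C++ sketch. It assumes `A` is a zero-initialized `std::array<int, 10>`, so every index the middle loop touches is in bounds. `inside_original`, `inside_registerized` and `inside_single_scalar` are hypothetical names. The two-scalar form matches the original. Keeping one scalar across the cut lets the middle loop read a stale `A[0]`.

```cpp
#include <array>
#include <cassert>

using Buf = std::array<int, 10>;

void inside_original(Buf& A) {
  A[0] = 2;
  for (int x1 = 0; x1 < 10; x1++) {
    A[0] = A[0] + x1;
  }
  for (int x2 = 1; x2 < 10; x2++) {
    A[x2] = A[x2 - 1];  // reads A[0] when x2 == 1
  }
  for (int x3 = 0; x3 < 10; x3++) {
    A[0] = A[0] + x3;
  }
}

// Registerized form: A_1 is finalized before the overlapping loop and A_2 is
// initialized from memory after it.
void inside_registerized(Buf& A) {
  int A_1 = 2;
  for (int x1 = 0; x1 < 10; x1++) {
    A_1 = A_1 + x1;
  }
  A[0] = A_1;
  for (int x2 = 1; x2 < 10; x2++) {
    A[x2] = A[x2 - 1];
  }
  int A_2 = A[0];
  for (int x3 = 0; x3 < 10; x3++) {
    A_2 = A_2 + x3;
  }
  A[0] = A_2;
}

// Incorrect variant: a single scalar spans the cut, so the middle loop reads
// a stale A[0] from memory.
void inside_single_scalar(Buf& A) {
  int A_1 = 2;
  for (int x1 = 0; x1 < 10; x1++) {
    A_1 = A_1 + x1;
  }
  for (int x2 = 1; x2 < 10; x2++) {
    A[x2] = A[x2 - 1];
  }
  for (int x3 = 0; x3 < 10; x3++) {
    A_1 = A_1 + x3;
  }
  A[0] = A_1;
}
```

Starting from all zeros, the original and two-scalar forms both end with `A[0] == 92` and `A[1..9] == 47`. The single-scalar variant also ends with `A[0] == 92`, but leaves `A[1..9] == 0`.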

// An element could be registerized program wide but is cut by a conditional
// access; we should break this into two scalars and write back to the buffer
// before the condition.
void testRegisterizerPartialCondition() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make(
      {Store::make(a, {0}, 2, 1),
       For::make(
           x,
           0,
           10,
           Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Store::make(a, {x}, Load::make(a, {x - 1}, 1), 1),
           nullptr),
       For::make(
           x,
           0,
           10,
           Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1))});

  /*
   * A[0] = 2;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   * if (x<5 ? 1 : 0) {
   *   A[x] = A[x - 1];
   * }
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 2;
   * for (int x = 0; x < 10; x++) {
   *   A_1 = A_1 + x;
   * }
   * A[0] = A_1;
   * if (x<5 ? 1 : 0) {
   *   A[x] = A[x - 1];
   * }
   * int A_2 = A[0];
   * for (int x = 0; x < 10; x++) {
   *   A_2 = A_2 + x;
   * }
   * A[0] = A_2;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 2;
# CHECK: for (
# CHECK: A_1 = A_1 + x;
# CHECK: }
# CHECK: A[0] = A_1;
# CHECK: if (
# CHECK: A[x] =
# CHECK: }
# CHECK: int A_2 = A[0];
# CHECK: for (
# CHECK: A_2 = A_2 + x;
# CHECK: }
# CHECK: A[0] = A_2;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Tests the case where an access is cut by an internal conditional access which
// itself is registerized.
void testRegisterizerPartialConditionInternalCut() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make(
      {Store::make(a, {0}, 1, 1),
       Store::make(a, {0}, 3, 1),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({Store::make(a, {x}, 1, 1), Store::make(a, {x}, 3, 1)}),
           nullptr),
       Store::make(a, {0}, 4, 1),
       Store::make(a, {0}, 6, 1)});

  /*
   * A[0] = 1;
   * A[0] = 3;
   * if (x<5 ? 1 : 0) {
   *   A[x] = 1;
   *   A[x] = 3;
   * }
   * A[0] = 4;
   * A[0] = 6;
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 1;
   * A_1 = 3;
   * A[0] = A_1;
   * if (x<5 ? 1 : 0) {
   *   int A_2 = 1;
   *   A_2 = 3;
   *   A[x] = A_2;
   * }
   * int A_3 = 4;
   * A_3 = 6;
   * A[0] = A_3;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 1;
# CHECK: A_1 = 3
# CHECK: A[0] = A_1;
# CHECK: if (
# CHECK: int A_2 = 1;
# CHECK: A_2 = 3;
# CHECK: A[x] = A_2;
# CHECK: }
# CHECK: int A_3 = 4;
# CHECK: A_3 = 6;
# CHECK: A[0] = A_3;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// The first statement in the condition closes the outer access, but can be
// registerized with later statements.
void testRegisterizerPartialConditionInternalStart() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make(
      {Store::make(a, {0}, 1, 1),
       Store::make(a, {0}, 3, 1),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({Store::make(a, {x}, 1, 1), Store::make(a, {x}, 3, 1)}),
           nullptr),
       Store::make(a, {x}, 4, 1),
       Store::make(a, {x}, 6, 1)});

  /*
   * A[0] = 1;
   * A[0] = 3;
   * if (x<5 ? 1 : 0) {
   *   A[x] = 1;
   *   A[x] = 3;
   * }
   * A[x] = 4;
   * A[x] = 6;
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 1;
   * A_1 = 3;
   * A[0] = A_1;
   * int A_2 = A[x];    <--- must read from the input here.
   * if (x<5 ? 1 : 0) {
   *   A_2 = 1;
   *   A_2 = 3;
   * }
   * A_2 = 4;
   * A_2 = 6;
   * A[x] = A_2;
   */

  // TODO: I suppose we could refactor with a conditional initializer?

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 1;
# CHECK: A_1 = 3
# CHECK: A[0] = A_1;
# CHECK: int A_2 = A[x];
# CHECK: if (
# CHECK: A_2 = 1;
# CHECK: A_2 = 3;
# CHECK: }
# CHECK: A_2 = 4;
# CHECK: A_2 = 6;
# CHECK: A[x] = A_2;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// An access cuts two open overlaps and creates four scalar variables.
|
|
void testRegisterizerPartialOverlapsTwo() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
Stmt* stmt = Block::make({Store::make(a, {1}, Load::make(a, {0}, 1), 1),
|
|
Store::make(a, {0}, Load::make(a, {1}, 1), 1),
|
|
Store::make(a, {0}, Load::make(a, {1}, 1), 1),
|
|
For::make(x, 1, 10, Store::make(a, {x}, x, 1)),
|
|
Store::make(a, {1}, Load::make(a, {0}, 1), 1),
|
|
Store::make(a, {0}, Load::make(a, {1}, 1), 1),
|
|
Store::make(a, {0}, Load::make(a, {1}, 1), 1)});
|
|
|
|
/*
|
|
* A[1] = A[0];
|
|
* A[0] = A[1];
|
|
* A[0] = A[1];
|
|
* for (int x = 1; x < 10; x++) {
|
|
* A[x] = x;
|
|
* }
|
|
* A[1] = A[0];
|
|
* A[0] = A[1];
|
|
* A[0] = A[1];
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = A[0];
|
|
* int A_2 = A_1;
|
|
* A_1 = A_2;
|
|
* A_1 = A_2;
|
|
* A[1] = A_2;
|
|
* A[0] = A_1;
|
|
* for (int x = 1; x < 10; x++) {
|
|
* A[x] = x;
|
|
* }
|
|
* int A_3 = A[0];
|
|
* int A_4 = A_3;
|
|
* A_3 = A_4;
|
|
* A_3 = A_4;
|
|
* A[1] = A_4;
|
|
* A[0] = A_3;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: int A_2 = A_1;
|
|
# CHECK: A_1 = A_2;
|
|
# CHECK: A_1 = A_2;
|
|
# CHECK: A[1] = A_2;
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: for (
|
|
# CHECK: A[x] = x;
|
|
# CHECK: }
|
|
# CHECK: int A_3 = A[0];
|
|
# CHECK: int A_4 = A_3;
|
|
# CHECK: A_3 = A_4;
|
|
# CHECK: A_3 = A_4;
|
|
# CHECK: A[1] = A_4;
|
|
# CHECK: A[0] = A_3;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
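The four-variable result above can be reproduced with a toy model of the "cut" rule from the commit summary. This is a sketch under stated assumptions, not the registerizer's implementation: `segmentsFor` and `kAny` are hypothetical names, each access is reduced to the single element index it touches, and the loop store `A[x]` is treated as possibly aliasing every element (which matches the test output, where both `A[0]` and `A[1]` are finalized before the loop).

```cpp
#include <cassert>
#include <vector>

// Hypothetical marker for an access whose index is unknown at compile time
// (for example A[x] inside a loop), so it may alias any element.
constexpr int kAny = -1;

// Given the element index touched by each Load/Store in program order, return
// the length of each run of accesses to `elem` that is not interrupted by a
// possibly-overlapping access. Each run is a candidate for its own scalar.
std::vector<int> segmentsFor(const std::vector<int>& accesses, int elem) {
  std::vector<int> segments;
  int current = 0;
  for (int idx : accesses) {
    if (idx == elem) {
      ++current;  // definitely the same element: extend the current run
    } else if (idx == kAny) {
      if (current > 0) {
        segments.push_back(current);  // may overlap: cut the run here
      }
      current = 0;
    }
    // Any other constant index is provably distinct and does not cut the run.
  }
  if (current > 0) {
    segments.push_back(current);
  }
  return segments;
}
```

Encoding the test body as `{0, 1, 1, 0, 1, 0, kAny, 0, 1, 1, 0, 1, 0}` (load then store for each statement) gives two runs of three accesses for each of `A[0]` and `A[1]`, so two scalars per element and four in total.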
|
|
|
|
// Nested blocks are automatically flattened and do not prevent
|
|
// registerization of enclosed accesses.
|
|
void testRegisterizerNestedBlocks() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
Stmt* stmt = Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1),
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 2), 1)}),
|
|
Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 3), 1),
|
|
Block::make({Store::make(
|
|
a, {0}, Add::make(Load::make(a, {0}, 1), 4), 1)})})});
|
|
|
|
/*
|
|
* A[0] = (A[0]) + 1;
|
|
* {
|
|
* A[0] = (A[0]) + 2;
|
|
* }
|
|
* {
|
|
* A[0] = (A[0]) + 3;
|
|
* {
|
|
* A[0] = (A[0]) + 4;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = A[0];
|
|
* A_1 = A_1 + 1;
|
|
* A_1 = A_1 + 2;
|
|
* A_1 = A_1 + 3;
|
|
* A_1 = A_1 + 4;
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: A_1 = A_1 + 2;
|
|
# CHECK: A_1 = A_1 + 3;
|
|
# CHECK: A_1 = A_1 + 4;
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// The access can be registerized internally to a condition, but must ensure
|
|
// that both initializer and finalizer are within the same condition.
|
|
void testRegisterizerNestedConditions() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
Stmt* stmt = Block::make({Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1),
|
|
Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1),
|
|
nullptr)}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* if (x==2 ? 1 : 0) {
|
|
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* int A_1 = A[0];
|
|
* A_1 = A_1 + 1;
|
|
* if (x==2 ? 1 : 0) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* A[0] = A_1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: if (x<5
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: if (x==2
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
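To see why keeping the initializer and finalizer inside the same condition preserves behavior, here is a plain C++ sketch of the kernel from the test, hand-translated from the IR comments before and after registerization. The function names are illustrative. Because both the load and the write-back sit inside `if (x < 5)`, the registerized form touches `A` only when the original would have.

```cpp
#include <cassert>

// Original form: both updates read and write memory directly.
void originalKernel(int* A, int x) {
  if (x < 5) {
    A[0] = A[0] + 1;
    if (x == 2) {
      A[0] = A[0] + 1;
    }
  }
}

// Registerized form: the scalar lives entirely inside the outer condition.
void registerizedKernel(int* A, int x) {
  if (x < 5) {
    int A_1 = A[0];  // initializer inside the condition
    A_1 = A_1 + 1;
    if (x == 2) {
      A_1 = A_1 + 1;
    }
    A[0] = A_1;  // finalizer inside the same condition
  }
}

// Returns true if both forms leave A[0] with the same value for every x in
// [lo, hi), starting from the same initial value.
bool sameForAll(int lo, int hi, int init) {
  for (int x = lo; x < hi; ++x) {
    int a = init;
    int b = init;
    originalKernel(&a, x);
    registerizedKernel(&b, x);
    if (a != b) {
      return false;
    }
  }
  return true;
}
```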
|
|
|
|
// If an access exists outside the scope of the condition then we can lift
|
|
// nested conditional usages into the same scalar.
|
|
void testRegisterizerNestedConditionsUnhidden() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
Stmt* stmt = Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Block::make(
|
|
{Store::make(a, {1}, 1, 1),
|
|
Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1),
|
|
nullptr)}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* A[0] = (A[0]) + 1;
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[1] = 1;
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = A[0];
|
|
* A_1 = A_1 + 1;
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[1] = 1;
|
|
* if (x==2 ? 1 : 0) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* }
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: if (x<5
|
|
# CHECK: A[1] = 1;
|
|
# CHECK: if (x==2
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
void testRegisterizerNestedConditionsHiddenFirst() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
Stmt* stmt = Block::make(
|
|
{Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1),
|
|
nullptr),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Block::make({Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1),
|
|
nullptr)}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* if (x<5 ? 1 : 0) {
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
|
|
stmt = registerize(stmt);
|
|
}
|
|
|
|
void testRegisterizerNestedConditionsHiddenSecond() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
Stmt* stmt = Block::make(
|
|
{Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Block::make({Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1),
|
|
nullptr)}),
|
|
nullptr),
|
|
Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
|
|
stmt = registerize(stmt);
|
|
}
|
|
|
|
// If an access is cut by another access internal to a condition block, it still
|
|
// cuts the access.
|
|
void testRegisterizerNestedConditionsCut() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
Stmt* stmt = Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Block::make(
|
|
{Store::make(a, {x}, 1, 1),
|
|
Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1),
|
|
nullptr)}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* A[0] = (A[0]) + 1;
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = 1;
|
|
* if (x==2 ? 1 : 0) {
|
|
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
void testRegisterizerNestedConditionLoopHidden() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
Stmt* stmt = Block::make(
|
|
{Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1),
|
|
nullptr),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(b, {x}, 0, 1),
|
|
Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1),
|
|
nullptr)}))});
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = 0; <-- this is only here to prevent Loop/Cond reordering.
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// Three nested conditions and four element regions, three of which should be registerized
|
|
// at different levels of the IR.
|
|
void testRegisterizerNestedConditionThreeDeep() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
Stmt* stmt = Block::make(
|
|
{Store::make(a, {4}, 0, 1),
|
|
Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kGT),
|
|
Cond::make(
|
|
CompareSelect::make(x, 3, CompareSelectOperation::kGT),
|
|
Block::make({
|
|
Cond::make(
|
|
CompareSelect::make(x, 4, CompareSelectOperation::kGT),
|
|
Block::make({
|
|
Store::make(
|
|
a, {1}, Add::make(Load::make(a, {1}, 1), 1), 1),
|
|
Store::make(
|
|
a, {2}, Add::make(Load::make(a, {2}, 1), 1), 1),
|
|
Store::make(
|
|
a, {3}, Add::make(Load::make(a, {3}, 1), 1), 1),
|
|
Store::make(
|
|
a, {4}, Add::make(Load::make(a, {4}, 1), 1), 1),
|
|
Store::make(
|
|
a, {1}, Add::make(Load::make(a, {1}, 1), 1), 1),
|
|
}),
|
|
nullptr),
|
|
Store::make(a, {2}, Add::make(Load::make(a, {2}, 1), 1), 1),
|
|
}),
|
|
nullptr),
|
|
nullptr)});
|
|
|
|
/*
|
|
* A[4] = 0;
|
|
* if (x>2 ? 1 : 0) {
|
|
* if (x>3 ? 1 : 0) {
|
|
* if (x>4 ? 1 : 0) {
|
|
* A[1] = (A[1]) + 1;
|
|
* A[2] = (A[2]) + 1;
|
|
* A[3] = (A[3]) + 1;
|
|
* A[4] = (A[4]) + 1;
|
|
* A[1] = (A[1]) + 1;
|
|
* }
|
|
* A[2] = (A[2]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* if (x>2 ? 1 : 0) {
|
|
* if (x>3 ? 1 : 0) {
|
|
* int A_3 = A[2];
|
|
* if (x>4 ? 1 : 0) {
|
|
* int A_2 = A[1];
|
|
* A_2 = A_2 + 1;
|
|
* A_3 = A_3 + 1;
|
|
* A[3] = (A[3]) + 1;
|
|
* A_1 = A_1 + 1;
|
|
* A_2 = A_2 + 1;
|
|
* A[1] = A_2;
|
|
* }
|
|
* A_3 = A_3 + 1;
|
|
* A[2] = A_3;
|
|
* }
|
|
* }
|
|
* A[4] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: if (x>2 ? 1 : 0) {
|
|
# CHECK: if (x>3 ? 1 : 0) {
|
|
# CHECK: int A_3 = A[2];
|
|
# CHECK: if (x>4 ? 1 : 0) {
|
|
# CHECK: int A_2 = A[1];
|
|
# CHECK: A_2 = A_2 + 1;
|
|
# CHECK: A_3 = A_3 + 1;
|
|
# CHECK: A[3] = (A[3]) + 1;
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: A_2 = A_2 + 1;
|
|
# CHECK: A[1] = A_2;
|
|
# CHECK: }
|
|
# CHECK: A_3 = A_3 + 1;
|
|
# CHECK: A[2] = A_3;
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: A[4] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Can replace a simple scalar access with a local variable even when the
// access's index is an outer loop var.
|
|
void testRegisterizerNestedLoopSimple() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
Stmt* stmt = Block::make({For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a, {y}, Add::make(Load::make(a, {y}, 1), x), 1)})))});
|
|
|
|
/*
|
|
* for (int y = 0; y < 10; y++) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[y] = (A[y]) + x;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* for (int y = 0; y < 10; y++) {
|
|
* int A_1 = A[y];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = x + A_1;
|
|
* }
|
|
* A[y] = A_1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int y
|
|
# CHECK: int A_1 = A[y];
|
|
# CHECK: for (int x
|
|
# CHECK: A_1 = x + A_1;
|
|
# CHECK: }
|
|
# CHECK: A[y] = A_1;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
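The same transformation on ordinary C++ storage, hand-translated from the IR comments above (function names are illustrative). Because `A[y]` depends only on the outer loop variable, one load and one store per `y` replace the ten of each that the inner loop performed, so 100 loads and 100 stores drop to 10 of each.

```cpp
#include <cassert>
#include <vector>

// Original: every inner iteration loads and stores A[y] through memory.
void accumulateOriginal(std::vector<int>& A) {
  for (int y = 0; y < 10; y++) {
    for (int x = 0; x < 10; x++) {
      A[y] = A[y] + x;
    }
  }
}

// Scalar-replaced: A[y] is loaded once before the inner loop and written back
// once after it; the inner loop only touches the local A_1.
void accumulateRegisterized(std::vector<int>& A) {
  for (int y = 0; y < 10; y++) {
    int A_1 = A[y];
    for (int x = 0; x < 10; x++) {
      A_1 = x + A_1;
    }
    A[y] = A_1;
  }
}
```

Starting from all elements equal to 3, both versions leave every element equal to 3 + (0 + 1 + ... + 9) = 48.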
|
|
|
|
// Test the positive case of the hiddenAccess split, where an internal
|
|
// conditional access can be hoisted up through a loop to match an existing
|
|
// access in a higher scope and the two can be registerized.
|
|
void testRegisterizerHiddenAccessYes() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
Stmt* stmt = Block::make({Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Block::make(
|
|
{Store::make(a, {0}, 0, 1),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(b, {x}, 0, 1),
|
|
Cond::make(
|
|
CompareSelect::make(x, 3, CompareSelectOperation::kEQ),
|
|
For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
Store::make(
|
|
a,
|
|
{0},
|
|
Add::make(Load::make(a, {0}, 1), 1),
|
|
1)),
|
|
nullptr)}))}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = 0;
|
|
* if (x==3 ? 1 : 0) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = 0;
|
|
* if (x==3 ? 1 : 0) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* }
|
|
* }
|
|
* A[0] = A_1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: if (x==2
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x
|
|
# CHECK: B[x] = 0;
|
|
# CHECK: if (x==3
|
|
# CHECK: for (int y
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Test the negative case of the hiddenAccess split, where the hoisted access is
|
|
// never unhidden at a higher scope and registerization occurs at the lower
|
|
// scope.
|
|
void testRegisterizerHiddenAccessNo() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
Stmt* stmt = Block::make({Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Block::make({For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(b, {x}, 0, 1),
|
|
Cond::make(
|
|
CompareSelect::make(x, 3, CompareSelectOperation::kEQ),
|
|
For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
Store::make(
|
|
a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1)),
|
|
nullptr)}))}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = 0;
|
|
* if (x==3 ? 1 : 0) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = 0;
|
|
* if (x==3 ? 1 : 0) {
|
|
* int A_1 = A[0];
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* A[0] = A_1;
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: if (x==2
|
|
# CHECK: for (int x
|
|
# CHECK: B[x] = 0;
|
|
# CHECK: if (x==3
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: for (int y
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// In this case the conditional access must be hoisted by two loops. There are
// two accesses here: one is unhidden and the other isn't. A[0] can be
// registerized but B[0] cannot.
|
|
void testRegisterizerHiddenAccessMultiLoop() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
Stmt* stmt = Block::make({Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Block::make(
|
|
{Store::make(a, {0}, 0, 1),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
Block::make({Cond::make(
|
|
CompareSelect::make(y, 3, CompareSelectOperation::kEQ),
|
|
Block::make(
|
|
{Store::make(
|
|
a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1),
|
|
Store::make(
|
|
b,
|
|
{0},
|
|
Add::make(Load::make(b, {0}, 1), 1),
|
|
1)}),
|
|
nullptr)})))}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* if (y==3 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* B[0] = (B[0]) + 1;
|
|
* }
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* if (y==3 ? 1 : 0) {
|
|
* A_1 = A_1 + 1;
|
|
* B[0] = (B[0]) + 1;
|
|
* }
|
|
* }
|
|
* }
|
|
* A[0] = A_1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: if (x==2
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x
|
|
# CHECK: for (int y
|
|
# CHECK: if (y==3
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: B[0] = (B[0]) + 1;
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Accesses are registerized inside two conditions, but the immediate parent is
|
|
// not a condition.
|
|
void testRegisterizerTwoConditionalLoops() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
Stmt* stmt = Block::make(
|
|
{Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1)),
|
|
nullptr),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kGT),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1)),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
* if (x>5 ? 1 : 0) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* int A_1 = A[0];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* A[0] = A_1;
|
|
* }
|
|
* if (x>5 ? 1 : 0) {
|
|
* int A_2 = A[0];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_2 = A_2 + 1;
|
|
* }
|
|
* A[0] = A_2;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: if (x<5
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: for (int x
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: }
|
|
# CHECK: if (x>5
|
|
# CHECK: int A_2 = A[0];
|
|
# CHECK: for (int x
|
|
# CHECK: A_2 = A_2 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_2;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Accesses are registerized inside two conditions, cut in the middle.
|
|
void testRegisterizerTwoConditionalLoopsCut() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
Stmt* stmt = Block::make(
|
|
{Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1)),
|
|
nullptr),
|
|
For::make(x, 0, 10, Store::make(a, {x}, 1, 1)),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kGT),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1)),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = 1;
|
|
* }
|
|
* if (x>5 ? 1 : 0) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* int A_1 = A[0];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* A[0] = A_1;
|
|
* }
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = 1;
|
|
* }
|
|
* if (x>5 ? 1 : 0) {
|
|
* int A_2 = A[0];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_2 = A_2 + 1;
|
|
* }
|
|
* A[0] = A_2;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: if (x<5
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: for (int x
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: }
|
|
# CHECK: for (int x
|
|
# CHECK: A[x] = 1;
|
|
# CHECK: if (x>5
|
|
# CHECK: int A_2 = A[0];
|
|
# CHECK: for (int x
|
|
# CHECK: A_2 = A_2 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_2;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// The access references a Let var in a local scope, which cannot be hoisted
// out of the loop.
|
|
void testRegisterizerLoopLetVar() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
Stmt* stmt = Block::make({For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Let::make(y, 30),
|
|
Store::make(a, {y}, Add::make(x, Load::make(a, {y}, 1)), 1)}))});
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* int y = 30;
|
|
* A[y] = x + (A[y]);
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// The access references a Let var in an outer scope, which does not prevent
// hoisting the initializer.
|
|
void testRegisterizerLoopLetVarOuter() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
Stmt* stmt =
|
|
Block::make({Let::make(y, 30),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a, {y}, Add::make(x, Load::make(a, {y}, 1)), 1)}))});
|
|
|
|
/*
|
|
* int y = 30;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[y] = x + (A[y]);
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int y = 30;
|
|
* int A_1 = A[y];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = x + A_1;
|
|
* }
|
|
* A[y] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int y = 30;
|
|
# CHECK: int A_1 = A[y];
|
|
# CHECK: for (int x
|
|
# CHECK: A_1 = x + A_1;
|
|
# CHECK: A[y] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
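One way to phrase the rule behind the two Let-var tests: an access cannot be hoisted above any loop whose body binds a variable used in its index. The sketch below assumes a hypothetical helper `requiredDepth` and a simple map from variable name to the loop depth at which it is bound; it illustrates the rule and is not the registerizer's code.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>

// bindingDepth maps each variable to the loop depth at which it is bound:
// 0 = outside every loop, 1 = bound by (or inside) the outermost loop, etc.
// Variables missing from the map (e.g. kernel inputs) are treated as depth 0.
// Returns how many enclosing loops the scalar's initializer must stay inside.
int requiredDepth(const std::map<std::string, int>& bindingDepth,
                  const std::set<std::string>& indexVars) {
  int depth = 0;
  for (const auto& v : indexVars) {
    auto it = bindingDepth.find(v);
    if (it != bindingDepth.end() && it->second > depth) {
      depth = it->second;
    }
  }
  return depth;
}
```

In `testRegisterizerLoopLetVar`, `y` is bound inside the `x` loop, so the scalar would have to live inside that loop, where one load/store pair per iteration saves nothing; this is consistent with the test expecting no change. In `testRegisterizerLoopLetVarOuter`, `y` is bound before the loop, so the initializer can be placed ahead of it.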
|
|
|
|
// The registerizer generally runs after index flattening, but just in case,
// test registerization of multi-dimensional indices.
|
|
void testRegisterizerMultiDim() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {3, 4, 5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
Stmt* stmt = Block::make(
|
|
{Store::make(a, {0, 1, 2}, 0, 1),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a, {0, 1, 2}, Add::make(Load::make(a, {0, 1, 2}, 1), x), 1)}))});
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0, 1, 2] = (A[0, 1, 2]) + x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = x + A_1;
|
|
* }
|
|
* A[0, 1, 2] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A[0, 1, 2] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Won't registerize if only some dims match, but will still registerize distinct
|
|
// elements.
|
|
void testRegisterizerMultiDimPartial() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {3, 4, 5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
Stmt* stmt = Block::make(
|
|
{Store::make(a, {0, 1, 2}, 0, 1),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a, {0, 2, 2}, Add::make(Load::make(a, {0, 1, 4}, 1), x), 1)}))});
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0, 2, 2] = (A[0, 1, 4]) + x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0;
|
|
* int A_1 = A[0, 1, 4];
|
|
* int A_2 = A[0, 2, 2];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_2 = x + A_1;
|
|
* }
|
|
* A[0, 2, 2] = A_2;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: A[0, 1, 2] = 0;
|
|
# CHECK: int A_1 = A[0, 1, 4];
|
|
# CHECK: int A_2 = A[0, 2, 2];
|
|
# CHECK: for (
|
|
# CHECK: A_2 = x + A_1;
|
|
# CHECK: A[0, 2, 2] = A_2;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// If the accesses could overlap in every dimension, we cannot registerize.
|
|
void testRegisterizerMultiDimOverlap() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {3, 4, 5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
Stmt* stmt = Block::make(
|
|
{Store::make(a, {0, 1, 2}, 0, 1),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 2}, 1), x), 1)}))});
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0, x, 2] = (A[y, 2, 2]) + x;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// But, if one dimension is known to be distinct they do not overlap.
|
|
void testRegisterizerMultiDimPartialOverlap() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {3, 4, 5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
Stmt* stmt = Block::make(
|
|
{Store::make(a, {0, 1, 2}, 0, 1),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 4}, 1), x), 1)}))});
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0; <---- 2nd dim overlaps with store.
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0, x, 2] = (A[y, 2, 4]) + x; <---- 3rd dim has constant diff.
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0;
|
|
* int A_1 = A[y, 2, 4];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0, x, 2] = A_1 + x;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: A[0, 1, 2] = 0;
|
|
# CHECK: int A_1 = A[y, 2, 4];
|
|
# CHECK: for (
|
|
# CHECK: A[0, x, 2] = A_1 + x;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
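The overlap reasoning in the multi-dimensional tests can be sketched as a conservative per-dimension check: two accesses to the same buffer are known to be distinct only if some dimension holds two different constants. `Dim`, `constDim`, `varDim`, and `mayOverlap` are hypothetical names, and the real analysis may reason about index bounds more precisely than this.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// A dimension of an index is either a known constant or a symbolic variable
// whose value is unknown at compile time.
struct Dim {
  bool isConst;
  int value;        // meaningful when isConst is true
  std::string var;  // meaningful when isConst is false (kept for readability)
};

Dim constDim(int v) { return Dim{true, v, ""}; }
Dim varDim(const std::string& name) { return Dim{false, 0, name}; }

// Conservative check: assume overlap unless at least one dimension is
// provably different (two unequal constants).
bool mayOverlap(const std::vector<Dim>& a, const std::vector<Dim>& b) {
  if (a.size() != b.size()) {
    return true;  // different rank: no proof of distinctness, assume overlap
  }
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (a[i].isConst && b[i].isConst && a[i].value != b[i].value) {
      return false;  // this dimension alone proves the elements differ
    }
  }
  return true;
}
```

With this check, `A[y, 2, 4]` is distinct from both `A[0, 1, 2]` and `A[0, x, 2]`, so it can be hoisted, while `A[0, x, 2]` and `A[y, 2, 2]` from `testRegisterizerMultiDimOverlap` may overlap.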
|
|
|
|
// A 3D reduction with different input dimensionality.
|
|
void testRegisterizerMultiDim3DReduction1() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10, 10}, kInt);
|
|
BufHandle c("C", {10, 10, 10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
VarHandle z("z", kInt);
|
|
Stmt* stmt = For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
For::make(
|
|
z,
|
|
0,
|
|
10,
|
|
Store::make(
|
|
c,
|
|
{x, y, z},
|
|
Add::make(
|
|
Load::make(c, {x, y, z}, 1),
|
|
Mul::make(
|
|
Load::make(b, {x, y}, 1), Load::make(a, {x}, 1))),
|
|
1))));
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* for (int z = 0; z < 10; z++) {
|
|
* C[x, y, z] = (C[x, y, z]) + (B[x, y]) * (A[x]);
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
// We can registerize the A and B accesses since they can be hoisted before
|
|
// hitting a dependent loop var.
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* int A_1 = A[x];
|
|
* for (int y = 0; y < 10; y++) {
|
|
* int B_1 = B[x, y];
|
|
* for (int z = 0; z < 10; z++) {
|
|
* C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int x
|
|
# CHECK: int A_1 = A[x];
|
|
# CHECK: for (int y
|
|
# CHECK: int B_1 = B[x, y];
|
|
# CHECK: for (int z
|
|
# CHECK: C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
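A back-of-the-envelope count of why hoisting pays off in this reduction: a statement nested in the first `depth` loops of a perfect loop nest runs once per iteration of those loops. With trip counts of 10, `A[x]` drops from 1000 loads to 10 and `B[x, y]` from 1000 to 100, while `C[x, y, z]` still needs 1000 loads and 1000 stores. `executions` is a hypothetical helper for this arithmetic.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Number of times a statement executes when it sits inside the first `depth`
// loops of a perfectly nested loop nest with the given trip counts.
long executions(const std::vector<long>& tripCounts, std::size_t depth) {
  long n = 1;
  for (std::size_t i = 0; i < depth && i < tripCounts.size(); ++i) {
    n *= tripCounts[i];
  }
  return n;
}
```

For example, `executions({10, 10, 10}, 3)` is 1000 (the original loads of `A` and `B`), `executions({10, 10, 10}, 1)` is 10 (`A_1 = A[x]`), and `executions({10, 10, 10}, 2)` is 100 (`B_1 = B[x, y]`).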
|
|
|
|
// A 3D reduction with the same smaller dimensionality using different loop
|
|
// vars.
|
|
void testRegisterizerMultiDim3DReduction2() {
|
|
KernelScope kernel_scope;
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
BufHandle c("C", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
VarHandle z("z", kInt);
|
|
Stmt* stmt = For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
For::make(
|
|
z,
|
|
0,
|
|
10,
|
|
Store::make(
|
|
c,
|
|
{x},
|
|
Add::make(
|
|
Load::make(c, {x}, 1),
|
|
Mul::make(Load::make(b, {y}, 1), Load::make(a, {x}, 1))),
|
|
1))));
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* for (int z = 0; z < 10; z++) {
|
|
* C[x] = (C[x]) + (B[y]) * (A[x]);
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
// We can registerize all accesses: the A and C accesses can be hoisted to the
// outer loop since they depend only on its loop var, while B can only be
// raised to the loop over y.
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* int A_1 = A[x];
|
|
* int C_1 = C[x];
|
|
* for (int y = 0; y < 10; y++) {
|
|
* int B_1 = B[y];
|
|
* for (int z = 0; z < 10; z++) {
|
|
* C_1 = B_1 * A_1 + C_1;
|
|
* }
|
|
* }
|
|
* C[x] = C_1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int x
|
|
# CHECK: int A_1 = A[x];
|
|
# CHECK: int C_1 = C[x];
|
|
# CHECK: for (int y
|
|
# CHECK: int B_1 = B[y];
|
|
# CHECK: for (int z
|
|
# CHECK: C_1 = B_1 * A_1 + C_1;
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: C[x] = C_1;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|