mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157637 Approved by: https://github.com/yewentao256, https://github.com/albanD ghstack dependencies: #156605
3703 lines
87 KiB
C++
3703 lines
87 KiB
C++
#include <gtest/gtest.h>
|
|
#include "test/cpp/tensorexpr/test_base.h"
|
|
|
|
#include "test/cpp/tensorexpr/test_utils.h"
|
|
#include "torch/csrc/jit/tensorexpr/ir_simplifier.h"
|
|
#include "torch/csrc/jit/tensorexpr/registerizer.h"
|
|
|
|
#include <iostream>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
using namespace torch::jit::tensorexpr;
|
|
|
|
// Can replace a simple scalar access with a local variable.
|
|
TEST(Registerizer, RegisterizerSimple) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = x + A_1;
|
|
* }
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Won't do replacement of a loop access.
|
|
TEST(Registerizer, RegisterizerLoop) {
|
|
BufHandle a("A", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = (A[x]) + x;
|
|
* }
|
|
*/
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = (A[x]) + x;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK-NOT: int
|
|
# CHECK: A[0] = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A_
|
|
# CHECK: A[x] =
|
|
# CHECK-NOT: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Won't replace even if the load is a fixed scalar, since the store could
|
|
// invalidate it.
|
|
TEST(Registerizer, RegisterizerLoopFixedLoad) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {x}, Add::make(Load::make(a, {0}), x))}))});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = (A[0]) + x;
|
|
* }
|
|
*/
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = (A[0]) + x;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK-NOT: int
|
|
# CHECK: A[0] = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A_
|
|
# CHECK: A[x] =
|
|
# CHECK-NOT: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// We can registerize accesses that occur entirely within inner scopes, even if
|
|
// they depend on the loop var.
|
|
TEST(Registerizer, RegisterizerLoopInternal) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make({For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {x}, Add::make(Load::make(a, {x}), x)),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))});
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = (A[x]) + x;
|
|
* A[x] = (A[x]) + x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
// TODO: the order of terms in addition changes and in general depends on
|
|
// some hash value. This results in unpredictable swaps of the operands from
|
|
// random changes, which is not great. Ideally, we should ensure some
|
|
// specific order (ideally, the original one).
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* int A_1 = A[x];
|
|
* A_1 = x + A_1;
|
|
* A_1 = x + A_1;
|
|
* A[x] = A_1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK: int A_1 = A[x];
|
|
# CHECK: A_1 = A_1 + x;
|
|
# CHECK: A_1 = A_1 + x;
|
|
# CHECK: A[x] = A_1;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// An access can be overlapped by another read in the same Expr. In this case
|
|
// B[z] and B[y] overlap and prevent registerization of both accesses.
|
|
TEST(Registerizer, RegisterizerLoopInternalLoadOverlap) {
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
VarHandle z("z", kInt);
|
|
StmtPtr stmt = Block::make({For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Store::make(a, {x}, Add::make(Load::make(b, {y}), Load::make(b, {z}))))});
|
|
stmt = IRSimplifier::simplify(stmt);
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = (B[y]) + (B[z]);
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
TEST(Registerizer, RegisterizerLoopInternalRepeated) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {1}), x)),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {1}), x))})),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {1}), x)),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {1}), x))}))
|
|
|
|
});
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = x + (A[1]);
|
|
* A[0] = x + (A[1]);
|
|
* }
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = x + (A[1]);
|
|
* A[0] = x + (A[1]);
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = A[1];
|
|
* int A_2 = A[0];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_2 = A_1 + x;
|
|
* A_2 = A_1 + x;
|
|
* }
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_2 = A_1 + x;
|
|
* A_2 = A_1 + x;
|
|
* }
|
|
* A[0] = A_2;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = A[1];
|
|
# CHECK: int A_2 = A[0];
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK: A_2 = A_1 + x;
|
|
# CHECK: A_2 = A_1 + x;
|
|
# CHECK: }
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK: A_2 = A_1 + x;
|
|
# CHECK: A_2 = A_1 + x;
|
|
# CHECK: }
|
|
# CHECK-NOT: A[1]
|
|
# CHECK: A[0] = A_2;
|
|
# CHECK-NOT: A[1]
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapLoopVar) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {x}), x)),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {x}), x))})),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {x}), x)),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {x}), x))}))
|
|
|
|
});
|
|
stmt = IRSimplifier::simplify(stmt);
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[x]) + x;
|
|
* A[0] = (A[x]) + x;
|
|
* }
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[x]) + x;
|
|
* A[0] = (A[x]) + x;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = IRSimplifier::simplify(Block::make(
|
|
{For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(x, Load::make(a, {y}))),
|
|
Store::make(a, {0}, Add::make(x, Load::make(a, {y})))})),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(x, Load::make(a, {y}))),
|
|
Store::make(a, {0}, Add::make(x, Load::make(a, {y})))}))
|
|
|
|
}));
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[x]) + x;
|
|
* A[0] = (A[x]) + x;
|
|
* }
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[x]) + x;
|
|
* A[0] = (A[x]) + x;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// Will registerize multiple accesses of different items of the same buffer.
|
|
TEST(Registerizer, RegisterizerMultiVar) {
|
|
BufHandle a("A", {2}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make({
|
|
Store::make(a, {0}, 0),
|
|
Store::make(a, {1}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}), x)),
|
|
Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})),
|
|
});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* A[1] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + x;
|
|
* A[1] = (A[1]) - x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* int A_2 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_2 = x + A_2;
|
|
* A_1 = A_1 - x;
|
|
* }
|
|
* A[1] = A_2;
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: int A_2 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A_2 =
|
|
# CHECK: A[1] = A_2
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Will registerize the valid accesses while skipping invalid replacements.
|
|
TEST(Registerizer, RegisterizerVariableLoad) {
|
|
BufHandle a("A", {1}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle x2("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(x, 0, 10, Store::make(b, {x}, x)),
|
|
For::make(
|
|
x2,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a, {0}, Add::make(Load::make(a, {0}), Load::make(b, {x2})))}))});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = x;
|
|
* }
|
|
* for (int x_1 = 0; x_1 < 10; x_1++) {
|
|
* A[0] = (A[0]) + (B[x_1]);
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = x;
|
|
* }
|
|
* for (int x_1 = 0; x_1 < 10; x_1++) {
|
|
* A_1 = A_1 + (B[x_1]);
|
|
* }
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK: B[x] = x
|
|
# CHECK: for (int x_1 = 0; x_1 < 10; x_1++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Can registerize variable accesses so long as the variable does not change.
|
|
TEST(Registerizer, RegisterizerSymbolicIndices) {
|
|
VarHandle i("i", kInt);
|
|
VarHandle N("N", kInt);
|
|
BufHandle a("A", {N}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {i}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {i}, Add::make(Load::make(a, {i}), x))}))});
|
|
|
|
/*
|
|
* A[i] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[i] = (A[i]) + x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = x + A_1;
|
|
* }
|
|
* A[i] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A[i] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Can registerize accesses dependent on multiple loop vars.
|
|
TEST(Registerizer, RegisterizerMultiLoop) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a,
|
|
{0},
|
|
Mul::make(Add::make(Load::make(a, {0}), x), y))})))});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A[0] = x * y + (A[0]) * y;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A_1 = x * y + y * A_1;
|
|
* }
|
|
* }
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK: for (int y = 0; y < 10; y++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Can registerize correctly if scalars already exist in the program.
|
|
TEST(Registerizer, RegisterizerRepeated) {
|
|
BufHandle a("A", {2}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make({
|
|
Store::make(a, {0}, 0),
|
|
Store::make(a, {1}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}), x)),
|
|
Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})),
|
|
});
|
|
|
|
// Registerize manually to make sure we only replace a single target.
|
|
{
|
|
registerizer::RegisterizerAnalysis analysis;
|
|
stmt->accept(&analysis);
|
|
auto candidates = analysis.getCandidates();
|
|
ASSERT_EQ(candidates.size(), 2);
|
|
|
|
candidates.pop_back();
|
|
registerizer::RegisterizerReplacer replacer(candidates);
|
|
stmt = stmt->accept_mutator(&replacer);
|
|
}
|
|
|
|
// Re-analyze and replace the second target.
|
|
{
|
|
registerizer::RegisterizerAnalysis analysis;
|
|
stmt->accept(&analysis);
|
|
auto candidates = analysis.getCandidates();
|
|
ASSERT_EQ(candidates.size(), 1);
|
|
|
|
registerizer::RegisterizerReplacer replacer(candidates);
|
|
stmt = stmt->accept_mutator(&replacer);
|
|
}
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: int A_1_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A_1_1 =
|
|
# CHECK: A[1] = A_1_1;
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Can registerize the load of A.
|
|
TEST(Registerizer, RegisterizerNoLoads) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x, 0, 10, Block::make({Store::make(a, {0}, Add::make(x, 1))}))});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = x + 1;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = x + 1;
|
|
* }
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Can registerize the load of A but not the store of B.
|
|
TEST(Registerizer, RegisterizerNoRepeatedStores) {
|
|
BufHandle a("A", {1}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(b, {x}, Add::make(Load::make(a, {0}), x))}))});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = (A[0]) + x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
// TODO: its unnecessary to reorder the initializer of A[0], but it's not
|
|
// actually worse so lets not worry for now.
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = x + A_1;
|
|
* }
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A_
|
|
# CHECK: B[x] =
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Won't registerize if there are multiple accesses which may overlap.
|
|
TEST(Registerizer, RegisterizerMultiVarOverlap) {
|
|
BufHandle a("A", {2}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make({
|
|
Store::make(a, {0}, 0),
|
|
Store::make(a, {1}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {x}, Add::make(Load::make(a, {0}), x)),
|
|
Store::make(a, {x + 1}, Sub::make(Load::make(a, {1}), x))})),
|
|
});
|
|
stmt = IRSimplifier::simplify(stmt);
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
TEST(Registerizer, RegisterizerAllocs) {
|
|
BufHandle a("A", {2}, kInt);
|
|
BufHandle c("C", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
|
|
BufHandle b("B", {Load::make(c, {0})}, kInt);
|
|
|
|
StmtPtr stmt = Block::make(
|
|
{Allocate::make(b),
|
|
Store::make(a, {0}, Load::make(c, {0})),
|
|
Store::make(b, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(b, {0}, Add::make(Load::make(b, {0}), x)),
|
|
Store::make(a, {0}, Load::make(c, {0}))})),
|
|
Free::make(b)});
|
|
|
|
/*
|
|
* Allocate(B, int, {C[0]});
|
|
* A[0] = C[0];
|
|
* B[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[0] = (B[0]) + x;
|
|
* A[0] = C[0];
|
|
* }
|
|
* Free(B);
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int C_1 = C[0];
|
|
* Allocate(B, int, {C_});
|
|
* int A_1 = C_1;
|
|
* int B_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B_1 = B_1 + x;
|
|
* A_1 = C_1;
|
|
* }
|
|
* B[0] = B_1;
|
|
* A[0] = A_1;
|
|
* Free(B);
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int C_1 = C[0];
|
|
# CHECK: Allocate(B
|
|
# CHECK: int A_1 = C_1;
|
|
# CHECK: int B_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK: B_1 =
|
|
# CHECK: A_1 = C_
|
|
# CHECK: B[0] = B_1;
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: Free(B)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
TEST(Registerizer, RegisterizerNoInitializer) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make({For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))});
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = A[0];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = x + A_1;
|
|
* }
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
TEST(Registerizer, RegisterizerNoInitializerLoopVar) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make({For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))});
|
|
stmt = IRSimplifier::simplify(stmt);
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = (A[x]) + x;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
TEST(Registerizer, RegisterizerLoadThenStore) {
|
|
BufHandle a("A", {1}, kInt);
|
|
BufHandle b("B", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make({For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(b, {0}, Add::make(Load::make(a, {0}), x)),
|
|
Store::make(a, {0}, Load::make(b, {0}))}))});
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[0] = (A[0]) + x;
|
|
* A[0] = B[0];
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = A[0];
|
|
* int B_1 = B[0];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B_1 = x + A_1;
|
|
* A_1 = B_1;
|
|
* }
|
|
* B[0] = B_1;
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: int B_1 = B[0];
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: B[
|
|
# CHECK: B_1 =
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 = B_
|
|
# CHECK: B[0] = B_
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
TEST(Registerizer, RegisterizerParallelized) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
LoopOptions loopOpts;
|
|
loopOpts.set_gpu_block_index(0);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}),
|
|
loopOpts)});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + x;
|
|
* }
|
|
*/
|
|
|
|
ASSERT_THROWS_WITH(
|
|
registerize(stmt),
|
|
"Registerization must occur after parallelism flattening");
|
|
}
|
|
|
|
// Should be able to registerize this since the scalar would exist before the
|
|
// branch.
|
|
TEST(Registerizer, RegisterizerConditionAfter) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {x}, Load::make(b, {x})),
|
|
Store::make(c, {x}, Load::make(a, {x})),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
nullptr)});
|
|
|
|
/*
|
|
* A[x] = B[x];
|
|
* C[x] = A[x];
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = B[x];
|
|
* C[x] = A_1;
|
|
* if (x<5 ? 1 : 0) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* A[x] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = B[x];
|
|
# CHECK: C[x] = A_1;
|
|
# CHECK: if (
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: A[x] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Should be able to registerize this since the scalar exists in the same form
|
|
// after the branch and there is no overlap.
|
|
TEST(Registerizer, RegisterizerConditionBefore) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
|
|
StmtPtr stmt = Block::make(
|
|
{Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
nullptr),
|
|
Store::make(a, {x}, Load::make(b, {x})),
|
|
Store::make(c, {x}, Load::make(a, {x}))});
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* }
|
|
* A[x] = B[x];
|
|
* C[x] = A[x];
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_ 1 = A[x];
|
|
* if (x<5 ? 1 : 0) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* A_1 = B[x];
|
|
* C[x] = A_1;
|
|
* A[x] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = A[x];
|
|
# CHECK: if (
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: A_1 = B[x];
|
|
# CHECK: C[x] = A_1;
|
|
# CHECK: A[x] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Should be able to registerize this as the combination of the two above rules.
|
|
TEST(Registerizer, RegisterizerConditionInside) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {x}, Load::make(b, {x})),
|
|
Store::make(c, {x}, Load::make(a, {x})),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
nullptr),
|
|
Store::make(b, {x}, Load::make(a, {x})),
|
|
Store::make(a, {x}, Load::make(c, {x}))});
|
|
|
|
/*
|
|
* A[x] = B[x];
|
|
* C[x] = A[x];
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* }
|
|
* B[x] = A[x];
|
|
* A[x] = C[x];
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = B[x];
|
|
* C[x] = A_1;
|
|
* if (x<5 ? 1 : 0) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* B[x] = A_1;
|
|
* A_1 = C[x];
|
|
* A[x] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = B[x];
|
|
# CHECK: C[x] = A_1;
|
|
# CHECK: if (
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: B[x] = A_1;
|
|
# CHECK: A_1 = C[x];
|
|
# CHECK: A[x] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// An example where an access is cut by an overlapping access inside a
|
|
// condition, and both sides are large enough to be registerized but cannot be
|
|
// because there is no safe place to put the initializer or finalizer.
|
|
TEST(Registerizer, RegisterizerConditionInsideOverlap1) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
|
|
StmtPtr stmt = Block::make(
|
|
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
|
|
{Store::make(a, {x}, Load::make(b, {x})),
|
|
Store::make(c, {x}, Load::make(a, {x})),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Block::make({
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
Store::make(a, {0}, 3),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
}),
|
|
nullptr),
|
|
Store::make(b, {x}, Load::make(a, {x})),
|
|
Store::make(a, {x}, Load::make(c, {x}))});
|
|
|
|
/*
|
|
* A[x] = B[x];
|
|
* C[x] = A[x];
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* A[0] = 3;
|
|
* A[x] = (A[x]) + 1;
|
|
* }
|
|
* B[x] = A[x];
|
|
* A[x] = C[x];
|
|
*/
|
|
|
|
// The A[0] store overlaps, A[x] cutting the region that can be registerized
|
|
// into two groups.
|
|
// Each group has 2 loads and 2 stores however, so we could registerize it,
|
|
// but the first group would need to be finalized inside the condition block,
|
|
// the second would need to be initialized inside the condition block. There's
|
|
// no safe place to put these that's visible to the other uses in the group
|
|
// and so neither registerization is possible.
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// Same as the above, but the access group before the condition (and after the
|
|
// condition) are large enough to be registerized without needing the access
|
|
// from the loop. Registerization occurs but does not include any accesses in
|
|
// the condition, and the first group must be finalized before the Cond, the
|
|
// second initialized after it.
|
|
TEST(Registerizer, RegisterizerConditionInsideOverlap2) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
|
|
StmtPtr stmt = Block::make(
|
|
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
|
|
{Store::make(a, {x}, Load::make(b, {x})),
|
|
Store::make(a, {x}, Load::make(b, {x + 1})),
|
|
Store::make(c, {x}, Load::make(a, {x})),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Block::make({
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
Store::make(a, {0}, 3),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
}),
|
|
nullptr),
|
|
Store::make(b, {x}, Load::make(a, {x})),
|
|
Store::make(b, {x + 1}, Load::make(a, {x})),
|
|
Store::make(a, {x}, Load::make(c, {x}))});
|
|
|
|
/*
|
|
* A[x] = B[x];
|
|
* A[x] = B[x + 1];
|
|
* C[x] = A[x];
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* A[0] = 3;
|
|
* A[x] = (A[x]) + 1;
|
|
* }
|
|
* B[x] = A[x];
|
|
* B[x + 1] = A[x];
|
|
* A[x] = C[x];
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = B[x]; // A_1 initializer
|
|
* A_1 = B[x + 1]; //
|
|
* C[x] = A_1; //
|
|
* A[x] = A_1; // A_1 finalizer
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* A[0] = 3;
|
|
* A[x] = (A[x]) + 1;
|
|
* }
|
|
* int A_2 = A[x]; // A_2 initializer
|
|
* B[x] = A_2; //
|
|
* B[x + 1] = A_2; //
|
|
* A_2 = C[x]; //
|
|
* A[x] = A_2; // A_2 finalizer
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = B[x];
|
|
# CHECK: A_1 = B[x + 1];
|
|
# CHECK: C[x] = A_1;
|
|
# CHECK: A[x] = A_1;
|
|
# CHECK: if (
|
|
# CHECK-NOT: A_1 = A_1 + 1;
|
|
# CHECK: A[x] = (A[x]
|
|
# CHECK: A[0] =
|
|
# CHECK: A[x] = (A[x]
|
|
# CHECK: }
|
|
# CHECK: int A_2 = A[x];
|
|
# CHECK: B[x] = A_2;
|
|
# CHECK: B[x + 1] = A_2;
|
|
# CHECK: A_2 = C[x];
|
|
# CHECK: A[x] = A_2;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// When accesses are within conditional blocks they are not visible to the wider
|
|
// program, because we don't know if the branch would be taken and if it isn't
|
|
// the accesses in it don't need to be valid (think size checks on the index).
|
|
// In this case the accesses cannot be registerized.
|
|
TEST(Registerizer, RegisterizerConditionHidden) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
|
|
StmtPtr stmt = Block::make(
|
|
{Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
nullptr),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kGT),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* }
|
|
* if (x>5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// But... if the same access is found in a non conditional scope, that means
|
|
// that that access is valid in the higher scope (or at least if its not it's
|
|
// the user's fault). It "unhides" the conditional accesses, allowing
|
|
// registerization to occur.
|
|
TEST(Registerizer, RegisterizerConditionUnhidden) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
|
|
StmtPtr stmt = Block::make(
|
|
{Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
nullptr),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kGT),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* }
|
|
* A[x] = (A[x]) + 1; <-- this is doing the unhiding.
|
|
* if (x>5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = A[x];
|
|
* if (x<5 ? 1 : 0) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* A_1 = A_1 + 1;
|
|
* if (x>5 ? 1 : 0) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* A[x] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = A[x];
|
|
# CHECK: if (x<5
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: if (x>5
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[x] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Can registerize a load that occurs in the condition of a Cond.
|
|
TEST(Registerizer, RegisterizerCondCondition) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {x}, Load::make(b, {x})),
|
|
Store::make(c, {x}, Load::make(a, {x})),
|
|
Cond::make(
|
|
CompareSelect::make(
|
|
Load::make(a, {x}), 5, CompareSelectOperation::kLT),
|
|
Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)),
|
|
nullptr)});
|
|
|
|
/*
|
|
* A[x] = B[x];
|
|
* C[x] = A[x];
|
|
* if ((A[x])<5 ? 1 : 0) {
|
|
* C[x] = (C[x]) + 1;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = B[x];
|
|
* int C_1 = A_1;
|
|
* if (A_1<5 ? 1 : 0) {
|
|
* C_1 = C_1 + 1;
|
|
* }
|
|
* C[x] = C_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = B[x];
|
|
# CHECK: int C_1 = A_1;
|
|
# CHECK: if (A_1<5
|
|
# CHECK: C_1 = C_1 + 1;
|
|
# CHECK: C[x] = C_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Appearing in the condition of a Cond makes it visible to the enclosing scope,
|
|
// and so we can registerize internal usages.
|
|
TEST(Registerizer, RegisterizerCondConditionUnhidden) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
|
|
StmtPtr stmt = Block::make({Cond::make(
|
|
CompareSelect::make(Load::make(a, {x}), 5, CompareSelectOperation::kLT),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 10)))});
|
|
|
|
/*
|
|
* if ((A[x])<5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* } else {
|
|
* A[x] = (A[x]) + 10;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = A[x];
|
|
* if (A_1<5 ? 1 : 0) {
|
|
* A_1 = A_1 + 1;
|
|
* } else {
|
|
* A_1 = A_1 + 10;
|
|
* }
|
|
* A[x] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = A[x];
|
|
# CHECK: if (A_1<5
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: } else {
|
|
# CHECK: A_1 = A_1 + 10;
|
|
# CHECK: }
|
|
# CHECK: A[x] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Conditional hiding also works for IfThenElse exprs.
|
|
TEST(Registerizer, RegisterizerIfThenElseHidden) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(
|
|
b,
|
|
{y},
|
|
IfThenElse::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Add::make(Load::make(a, {x}), 1),
|
|
Add::make(Load::make(a, {x + 1}), 2))),
|
|
Store::make(
|
|
b,
|
|
{y + 1},
|
|
IfThenElse::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Add::make(Load::make(a, {x}), 1),
|
|
Add::make(Load::make(a, {x + 1}), 2)))});
|
|
|
|
/*
|
|
* B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
|
|
* B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// Conditional unhiding also works for IfThenElse exprs.
|
|
TEST(Registerizer, RegisterizerIfThenElseUnhidden) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
|
|
StmtPtr stmt = Block::make({
|
|
Store::make(a, {x}, 0),
|
|
Store::make(
|
|
b,
|
|
{y},
|
|
IfThenElse::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Add::make(Load::make(a, {x}), 1),
|
|
Add::make(Load::make(a, {x + 1}), 2))),
|
|
Store::make(
|
|
b,
|
|
{y + 1},
|
|
IfThenElse::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Add::make(Load::make(a, {x}), 1),
|
|
Add::make(Load::make(a, {x + 1}), 2))),
|
|
});
|
|
|
|
/*
|
|
* A[x] = 0;
|
|
* B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
|
|
* B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
|
|
* B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
|
|
* A[x] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
|
|
# CHECK: B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
|
|
# CHECK: A[x] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Nested IfThenElse exprs can't promote to higher level scopes.
|
|
TEST(Registerizer, RegisterizerIfThenElseNested) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
BufHandle d("D", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
|
|
StmtPtr stmt = Block::make({Store::make(
|
|
a,
|
|
{x},
|
|
IfThenElse::make(
|
|
CompareSelect::make(x, 3, CompareSelectOperation::kLT),
|
|
IfThenElse::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Load::make(d, {x}),
|
|
Load::make(b, {x})),
|
|
IfThenElse::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kEQ),
|
|
Load::make(c, {x}),
|
|
Load::make(d, {x}))))});
|
|
|
|
/*
|
|
* A[x] = IfThenElse(x<3 ? 1 : 0,
|
|
* IfThenElse(x==2 ? 1 : 0, D[x], B[x]),
|
|
* IfThenElse(x==5 ? 1 : 0, C[x], D[x]));
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// Cannot registerize an access completely contained within an IfThenElse
|
|
// branch, since it is not a Stmt and cannot hold variable definitions. We need
|
|
// to check that we don't promote the initializer/finalizer to the enclosing
|
|
// Block.
|
|
TEST(Registerizer, RegisterizerIfThenElseInternal) {
|
|
// Making these floats so they don't get simplified to a single access.
|
|
BufHandle a("A", {5}, kFloat);
|
|
BufHandle b("B", {5}, kFloat);
|
|
VarHandle x("x", kInt);
|
|
|
|
StmtPtr stmt = Block::make({Store::make(
|
|
a,
|
|
{x},
|
|
IfThenElse::make(
|
|
CompareSelect::make(x, 3, CompareSelectOperation::kLT),
|
|
Add::make(Load::make(b, {x}), Load::make(b, {x})),
|
|
Load::make(b, {x})))});
|
|
|
|
/*
|
|
* A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]);
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
|
|
// If this was a Cond instead of an IfThenElse then we could registerize the
|
|
// two accesses to B[x] in the True branch.
|
|
|
|
// Actually lets verify that.
|
|
|
|
stmt = Block::make({Cond::make(
|
|
CompareSelect::make(x, 3, CompareSelectOperation::kLT),
|
|
Store::make(a, {x}, Add::make(Load::make(b, {x}), Load::make(b, {x}))),
|
|
Store::make(a, {x}, Load::make(b, {x})))});
|
|
|
|
/*
|
|
* if (x<3 ? 1 : 0) {
|
|
* A[x] = (B[x]) + (B[x]);
|
|
* } else {
|
|
* A[x] = B[x];
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x<3 ? 1 : 0) {
|
|
* float B_1 = B[x];
|
|
* A[x] = B_1 + B_1;
|
|
* } else {
|
|
* A[x] = B[x];
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK-NOT: int
|
|
# CHECK-NOT: float
|
|
# CHECK: if (x<3
|
|
# CHECK: float B_1 =
|
|
# CHECK: A[x] = B_1 + B_1
|
|
# CHECK: } else {
|
|
# CHECK: A[x] = B[x]
|
|
# CHECK: }
|
|
# CHECK-NOT: A[x]
|
|
# CHECK-NOT: B[x])IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Can registerize a load that occurs in the condition of an IfThenElse;
|
|
TEST(Registerizer, RegisterizerIfThenElseCondition) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {x}, Load::make(a, {x})),
|
|
Store::make(
|
|
a,
|
|
{x},
|
|
IfThenElse::make(
|
|
CompareSelect::make(
|
|
Load::make(a, {x}), 5, CompareSelectOperation::kLT),
|
|
Load::make(b, {0}),
|
|
Load::make(c, {0})))});
|
|
|
|
/*
|
|
* A[x] = A[x]; <---- just here so there are enough accesses to combine.
|
|
* A[x] = IfThenElse((A[x])<5 ? 1 : 0, B[0], C[0]);
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = A[x];
|
|
* A_1 = A_1;
|
|
* A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
|
|
* A[x] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = A[x];
|
|
# CHECK: A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
|
|
# CHECK: A[x] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Appearing in the condition of a Cond makes it visible to the enclosing scope,
|
|
// and so we can registerize internal usages.
|
|
TEST(Registerizer, RegisterizerIfThenElseConditionUnhidden) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
|
|
StmtPtr stmt = Block::make({Store::make(
|
|
b,
|
|
{x},
|
|
IfThenElse::make(
|
|
CompareSelect::make(
|
|
Load::make(a, {x}), 5, CompareSelectOperation::kLT),
|
|
Add::make(Load::make(a, {x}), 1),
|
|
Add::make(Load::make(a, {x}), 10)))});
|
|
|
|
/*
|
|
* B[x] = IfThenElse((A[x])<5 ? 1 : 0, (A[x]) + 1, (A[x]) + 10);
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = A[x];
|
|
* B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = A[x];
|
|
# CHECK: B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Cannot promote accesses internal to IfThenElse branches even if the enclosing
|
|
// scope if conditional.
|
|
TEST(Registerizer, RegisterizerConditionBranchOnly) {
|
|
BufHandle a("A", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make({For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Store::make(
|
|
a,
|
|
{x},
|
|
IfThenElse::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Add::make(Load::make(a, {x}), x),
|
|
Add::make(Load::make(a, {x - 5}), x))),
|
|
Store::make(
|
|
a,
|
|
{x - 5},
|
|
IfThenElse::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Add::make(Load::make(a, {x}), x),
|
|
Add::make(Load::make(a, {x - 5}), x)))),
|
|
}))});
|
|
stmt = IRSimplifier::simplify(stmt);
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
/* for (int x = 0; x < 10; x++) {
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
|
|
* } else {
|
|
* A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// We can registerize an IfThenElse that appears in the condition branch of a
|
|
// Cond. This is a weird but valid thing to do.
|
|
TEST(Registerizer, RegisterizerCondIfThenElse) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
|
|
StmtPtr stmt = Block::make({Cond::make(
|
|
CompareSelect::make(
|
|
IfThenElse::make(
|
|
CompareSelect::make(
|
|
Load::make(a, {x}), 5, CompareSelectOperation::kLT),
|
|
Load::make(a, {x}),
|
|
Load::make(b, {x})),
|
|
x,
|
|
CompareSelectOperation::kEQ),
|
|
Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if ((IfThenElse((A[x])<5 ? 1 : 0, A[x], B[x]))==x ? 1 : 0) {
|
|
* C[x] = (C[x]) + 1;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
// access to A can be registerized, but not B or C
|
|
|
|
/*
|
|
* int A_1 = A[x];
|
|
* if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]))==x ? 1 : 0) {
|
|
* C[x] = (C[x]) + 1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = A[x];
|
|
# CHECK: if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]
|
|
# CHECK: C[x] = (C[x]) + 1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Can registerize a conditional access in the RHS of a store unhidden by it's
|
|
// LHS, and hoist it out of a loop.
|
|
TEST(Registerizer, RegisterizerIfThenElseLoop) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
|
|
StmtPtr stmt = For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
Store::make(
|
|
a,
|
|
{x},
|
|
IfThenElse::make(
|
|
CompareSelect::make(x, 3, CompareSelectOperation::kLT),
|
|
Load::make(a, {x}),
|
|
Load::make(b, {y}))));
|
|
|
|
/*
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A[x] = IfThenElse(x<3 ? 1 : 0, A[x], B[y]);
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = A[x];
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]);
|
|
* }
|
|
* A[x] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = A[x];
|
|
# CHECK: for (
|
|
# CHECK: A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]);
|
|
# CHECK: }
|
|
# CHECK: A[x] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Cannot registerize if the RHS overlaps the access creating visibility.
|
|
TEST(Registerizer, RegisterizerIfThenElseLoopCut) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
|
|
StmtPtr stmt = Block::make({For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
Store::make(
|
|
a,
|
|
{x},
|
|
IfThenElse::make(
|
|
CompareSelect::make(x, 3, CompareSelectOperation::kLT),
|
|
Load::make(a, {x}),
|
|
Load::make(a, {y}))))});
|
|
|
|
/*
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A[x] = IfThenElse(x<3 ? 1 : 0, A[x], A[y]);
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// Simple case where an access is cut by an overlapping access later in the
|
|
// program, we can registerize up until the overlap.
|
|
TEST(Registerizer, RegisterizerPartialAfter) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})),
|
|
For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})))});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + x;
|
|
* }
|
|
* for (int x = 1; x < 10; x++) {
|
|
* A[x] = A[x - 1];
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = A_1 + x;
|
|
* }
|
|
* A[0] = A_1;
|
|
* for (int x = 1; x < 10; x++) {
|
|
* A[x] = A[x - 1];
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (
|
|
# CHECK: A_1 = A_1 + x;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: for (
|
|
# CHECK: A[x] = A[x - 1];
|
|
# CHECK: }
|
|
# CHECK-NOT: A)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// We can registerize an access which overlaps a previous access, the
|
|
// initializer must be inserted after the previous access.
|
|
TEST(Registerizer, RegisterizerPartialBefore) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))),
|
|
Store::make(a, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))});
|
|
|
|
/*
|
|
* for (int x = 1; x < 10; x++) {
|
|
* A[x] = A[x - 1];
|
|
* }
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* for (int x = 1; x < 10; x++) {
|
|
* A[x] = A[x - 1];
|
|
* }
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = A_1 + x;
|
|
* }
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK-NOT: int
|
|
# CHECK: for (
|
|
# CHECK: A[x] = A[x - 1];
|
|
# CHECK: }
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (
|
|
# CHECK: A_1 = A_1 + x;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// The combination of the previous two tests, an access is cut by an overlapping
|
|
// access in both directions.
|
|
TEST(Registerizer, RegisterizerPartialInside) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x1("x1", kInt);
|
|
VarHandle x2("x2", kInt);
|
|
VarHandle x3("x3", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 2),
|
|
For::make(
|
|
x1, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x1))),
|
|
For::make(x2, 1, 10, Store::make(a, {x2}, Load::make(a, {x2 - 1}))),
|
|
For::make(
|
|
x3, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x3)))});
|
|
|
|
/*
|
|
* A[0] = 2;
|
|
* for (int x1 = 0; x1 < 10; x1++) {
|
|
* A[0] = (A[0]) + x1;
|
|
* }
|
|
* for (int x2 = 1; x2 < 10; x2++) {
|
|
* A[x2] = A[x2 - 1];
|
|
* }
|
|
* for (int x3 = 0; x3 < 10; x3++) {
|
|
* A[0] = (A[0]) + x3;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 2;
|
|
* for (int x1 = 0; x1 < 10; x1++) {
|
|
* A_1 = A_1 + x1;
|
|
* }
|
|
* A[0] = A_1;
|
|
* for (int x2 = 1; x2 < 10; x2++) {
|
|
* A[x2] = A[x2 - 1];
|
|
* }
|
|
* int A_2 = A[0];
|
|
* for (int x3 = 0; x3 < 10; x3++) {
|
|
* A_2 = A_2 + x3;
|
|
* }
|
|
* A[0] = A_2;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 2;
|
|
# CHECK: for (
|
|
# CHECK: A_1 = A_1 + x1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: for (
|
|
# CHECK: A[x2] =
|
|
# CHECK: }
|
|
# CHECK: int A_2 = A[0];
|
|
# CHECK: for (
|
|
# CHECK: A_2 = A_2 + x3;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_2;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// An element could be registerized program wide but is cut by a conditional
|
|
// access, we should break this into two scalars and write back to the buffer
|
|
// before the condition.
|
|
TEST(Registerizer, RegisterizerPartialCondition) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 2),
|
|
For::make(
|
|
x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x))),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Store::make(a, {x}, Load::make(a, {x - 1})),
|
|
nullptr),
|
|
For::make(
|
|
x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x)))});
|
|
|
|
/*
|
|
* A[0] = 2;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + x;
|
|
* }
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = A[x - 1];
|
|
* }
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 2;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = A_1 + x;
|
|
* }
|
|
* A[0] = A_1;
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = A[x - 1];
|
|
* }
|
|
* int A_2 = A[0];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_2 = A_2 + x;
|
|
* }
|
|
* A[0] = A_2;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 2;
|
|
# CHECK: for (
|
|
# CHECK: A_1 = A_1 + x;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: if (
|
|
# CHECK: A[x] =
|
|
# CHECK: }
|
|
# CHECK: int A_2 = A[0];
|
|
# CHECK: for (
|
|
# CHECK: A_2 = A_2 + x;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_2;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Tests case where an access is cut by an internal conditional access which
|
|
// itself is registerized.
|
|
TEST(Registerizer, RegisterizerPartialConditionInternalCut) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 1),
|
|
Store::make(a, {0}, 3),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}),
|
|
nullptr),
|
|
Store::make(a, {0}, 4),
|
|
Store::make(a, {0}, 6)});
|
|
|
|
/*
|
|
* A[0] = 1;
|
|
* A[0] = 3;
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = 1;
|
|
* A[x] = 3;
|
|
* }
|
|
* A[0] = 4;
|
|
* A[0] = 6;
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 1;
|
|
* A_1 = 3;
|
|
* A[0] = A_1;
|
|
* if (x<5 ? 1 : 0) {
|
|
* int A_2 = 1;
|
|
* A_2 = 3;
|
|
* A[x] = A_2;
|
|
* }
|
|
* int A_3 = 4;
|
|
* A_3 = 6;
|
|
* A[0] = A_3;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 1;
|
|
# CHECK: A_1 = 3
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: if (
|
|
# CHECK: int A_2 = 1;
|
|
# CHECK: A_2 = 3;
|
|
# CHECK: A[x] = A_2;
|
|
# CHECK: }
|
|
# CHECK: int A_3 = 4;
|
|
# CHECK: A_3 = 6;
|
|
# CHECK: A[0] = A_3;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// First statement in condition closes outer access, but can be registerized
|
|
// with later statements.
|
|
TEST(Registerizer, RegisterizerPartialConditionInternalStart) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 1),
|
|
Store::make(a, {0}, 3),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}),
|
|
nullptr),
|
|
Store::make(a, {x}, 4),
|
|
Store::make(a, {x}, 6)});
|
|
|
|
/*
|
|
* A[0] = 1;
|
|
* A[0] = 3;
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = 1;
|
|
* A[x] = 3;
|
|
* }
|
|
* A[x] = 4;
|
|
* A[x] = 6;
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 1;
|
|
* A_1 = 3;
|
|
* A[0] = A_1;
|
|
* int A_2 = A[x]; <--- must read from the input here.
|
|
* if (x<5 ? 1 : 0) {
|
|
* A_2 = 1;
|
|
* A_2 = 3;
|
|
* }
|
|
* A_2 = 4;
|
|
* A_2 = 6;
|
|
* A[x] = A_2;
|
|
*/
|
|
|
|
// TODO: I suppose we could refactor with a conditional initializer?
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 1;
|
|
# CHECK: A_1 = 3
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: int A_2 = A[x];
|
|
# CHECK: if (
|
|
# CHECK: A_2 = 1;
|
|
# CHECK: A_2 = 3;
|
|
# CHECK: }
|
|
# CHECK: A_2 = 4;
|
|
# CHECK: A_2 = 6;
|
|
# CHECK: A[x] = A_2;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// An access cuts two open overlaps and creates four scalar variables.
|
|
TEST(Registerizer, RegisterizerPartialOverlapsTwo) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {1}, Load::make(a, {0})),
|
|
Store::make(a, {0}, Load::make(a, {1})),
|
|
Store::make(a, {0}, Load::make(a, {1})),
|
|
For::make(x, 1, 10, Store::make(a, {x}, x)),
|
|
Store::make(a, {1}, Load::make(a, {0})),
|
|
Store::make(a, {0}, Load::make(a, {1})),
|
|
Store::make(a, {0}, Load::make(a, {1}))});
|
|
|
|
/*
|
|
* A[1] = A[0];
|
|
* A[0] = A[1];
|
|
* A[0] = A[1];
|
|
* for (int x = 1; x < 10; x++) {
|
|
* A[x] = x;
|
|
* }
|
|
* A[1] = A[0];
|
|
* A[0] = A[1];
|
|
* A[0] = A[1];
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = A[0];
|
|
* int A_2 = A_1;
|
|
* A_1 = A_2;
|
|
* A_1 = A_2;
|
|
* A[1] = A_2;
|
|
* A[0] = A_1;
|
|
* for (int x = 1; x < 10; x++) {
|
|
* A[x] = x;
|
|
* }
|
|
* int A_3 = A[0];
|
|
* int A_4 = A_3;
|
|
* A_3 = A_4;
|
|
* A_3 = A_4;
|
|
* A[1] = A_4;
|
|
* A[0] = A_3;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: int A_2 = A_1;
|
|
# CHECK: A_1 = A_2;
|
|
# CHECK: A_1 = A_2;
|
|
# CHECK: A[1] = A_2;
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: for (
|
|
# CHECK: A[x] = x;
|
|
# CHECK: }
|
|
# CHECK: int A_3 = A[0];
|
|
# CHECK: int A_4 = A_3;
|
|
# CHECK: A_3 = A_4;
|
|
# CHECK: A_3 = A_4;
|
|
# CHECK: A[1] = A_4;
|
|
# CHECK: A[0] = A_3;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Nested blocks will automatically be flattened and do not provent
|
|
// registerization of enclosed accesses.
|
|
TEST(Registerizer, RegisterizerNestedBlocks) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
|
|
Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), 2))}),
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}), 3)),
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}), 4))})})});
|
|
|
|
/*
|
|
* A[0] = (A[0]) + 1;
|
|
* {
|
|
* A[0] = (A[0]) + 2;
|
|
* }
|
|
* {
|
|
* A[0] = (A[0]) + 3;
|
|
* {
|
|
* A[0] = (A[0]) + 4;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = A[0];
|
|
* A_1 = A_1 + 1;
|
|
* A_1 = A_1 + 2;
|
|
* A_1 = A_1 + 3;
|
|
* A_1 = A_1 + 4;
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: A_1 = A_1 + 2;
|
|
# CHECK: A_1 = A_1 + 3;
|
|
# CHECK: A_1 = A_1 + 4;
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// The access can be registerized internally to a condition, but must ensure
|
|
// that both initializer and finalizer are within the same condition.
|
|
TEST(Registerizer, RegisterizerNestedConditions) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make({Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
|
|
Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
|
|
nullptr)}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* if (x==2 ? 1 : 0) {
|
|
*
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* int A_1 = A[0];
|
|
* A_1 = A_1 + 1;
|
|
* if (x==2 ? 1 : 0) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* A[0] = A_1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: if (x<5
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: if (x==2
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// If an access exists outside the scope of the condition then we can lift
|
|
// nested conditional usages into the same scalar.
|
|
TEST(Registerizer, RegisterizerNestedConditionsUnhidden) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Block::make(
|
|
{Store::make(a, {1}, 1),
|
|
Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
|
|
nullptr)}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* A[0] = (A[0]) + 1;
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[1] = 1;
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = A[0];
|
|
* A_1 = A_1 + 1;
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[1] = 1;
|
|
* if (x==2 ? 1 : 0) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* }
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: if (x<5
|
|
# CHECK: A[1] = 1;
|
|
# CHECK: if (x==2
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
TEST(Registerizer, RegisterizerNestedConditionsHiddenFirst) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
|
|
nullptr),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Block::make({Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
|
|
nullptr)}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* if (x<5 ? 1 : 0) {
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
|
|
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
|
|
stmt = registerize(stmt);
|
|
}
|
|
|
|
TEST(Registerizer, RegisterizerNestedConditionsHiddenSecond) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Block::make({Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
|
|
nullptr)}),
|
|
nullptr),
|
|
Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
|
|
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
|
|
stmt = registerize(stmt);
|
|
}
|
|
|
|
// If an access is cut by another access internal to a condition block, it still
|
|
// cuts the access.
|
|
TEST(Registerizer, RegisterizerNestedConditionsCut) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Block::make(
|
|
{Store::make(a, {x}, 1),
|
|
Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
|
|
nullptr)}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* A[0] = (A[0]) + 1;
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = 1;
|
|
* if (x==2 ? 1 : 0) {
|
|
*
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
TEST(Registerizer, RegisterizerNestedConditionLoopHidden) {
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
|
|
nullptr),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(b, {x}, 0),
|
|
Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
|
|
nullptr)}))});
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = 0; <-- this is only here to prevent Loop/Cond reordering.
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// Three loops and four element regions, three of which should be registerized
|
|
// at different levels of the IR.
|
|
TEST(Registerizer, RegisterizerNestedConditionThreeDeep) {
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {4}, 0),
|
|
Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kGT),
|
|
Cond::make(
|
|
CompareSelect::make(x, 3, CompareSelectOperation::kGT),
|
|
Block::make({
|
|
Cond::make(
|
|
CompareSelect::make(x, 4, CompareSelectOperation::kGT),
|
|
Block::make({
|
|
Store::make(
|
|
a, {1}, Add::make(Load::make(a, {1}), 1)),
|
|
Store::make(
|
|
a, {2}, Add::make(Load::make(a, {2}), 1)),
|
|
Store::make(
|
|
a, {3}, Add::make(Load::make(a, {3}), 1)),
|
|
Store::make(
|
|
a, {4}, Add::make(Load::make(a, {4}), 1)),
|
|
Store::make(
|
|
a, {1}, Add::make(Load::make(a, {1}), 1)),
|
|
}),
|
|
nullptr),
|
|
Store::make(a, {2}, Add::make(Load::make(a, {2}), 1)),
|
|
}),
|
|
nullptr),
|
|
nullptr)});
|
|
|
|
/*
|
|
* A[4] = 0;
|
|
* if (x>2 ? 1 : 0) {
|
|
* if (x>3 ? 1 : 0) {
|
|
* if (x>4 ? 1 : 0) {
|
|
* A[1] = (A[1]) + 1;
|
|
* A[2] = (A[2]) + 1;
|
|
* A[3] = (A[3]) + 1;
|
|
* A[4] = (A[4]) + 1;
|
|
* A[1] = (A[1]) + 1;
|
|
* }
|
|
* A[2] = (A[2]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* if (x>2 ? 1 : 0) {
|
|
* if (x>3 ? 1 : 0) {
|
|
* int A_3 = A[2];
|
|
* if (x>4 ? 1 : 0) {
|
|
* int A_2 = A[1];
|
|
* A_2 = A_2 + 1;
|
|
* A_3 = A_3 + 1;
|
|
* A[3] = (A[3]) + 1;
|
|
* A_1 = A_1 + 1;
|
|
* A_2 = A_2 + 1;
|
|
* A[1] = A_2;
|
|
* }
|
|
* A_3 = A_3 + 1;
|
|
* A[2] = A_3;
|
|
* }
|
|
* }
|
|
* A[4] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: if (x>2 ? 1 : 0) {
|
|
# CHECK: if (x>3 ? 1 : 0) {
|
|
# CHECK: int A_3 = A[2];
|
|
# CHECK: if (x>4 ? 1 : 0) {
|
|
# CHECK: int A_2 = A[1];
|
|
# CHECK: A_2 = A_2 + 1;
|
|
# CHECK: A_3 = A_3 + 1;
|
|
# CHECK: A[3] = (A[3]) + 1;
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: A_2 = A_2 + 1;
|
|
# CHECK: A[1] = A_2;
|
|
# CHECK: }
|
|
# CHECK: A_3 = A_3 + 1;
|
|
# CHECK: A[2] = A_3;
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: A[4] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Can replace a simple scalar access with a local variable even when that
|
|
// variable is an outer loop var.
|
|
TEST(Registerizer, RegisterizerNestedLoopSimple) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = Block::make({For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})))});
|
|
|
|
/*
|
|
* for (int y = 0; y < 10; y++) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[y] = (A[y]) + x;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* for (int y = 0; y < 10; y++) {
|
|
* int A_1 = A[y];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = A_1 + x;
|
|
* }
|
|
* A[y] = A_1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int y
|
|
# CHECK: int A_1 = A[y];
|
|
# CHECK: for (int x
|
|
# CHECK: A_1 = A_1 + x;
|
|
# CHECK: }
|
|
# CHECK: A[y] = A_1;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Test the positive case of the hiddenAccess split, where an internal
|
|
// conditional access can be hoisted up through a loop to match an existing
|
|
// access in a higher scope and the two can be registerized.
|
|
TEST(Registerizer, RegisterizerHiddenAccessYes) {
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = Block::make({Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(b, {x}, 0),
|
|
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
|
|
Cond::make(
|
|
CompareSelect::make(x, 3, CompareSelectOperation::kEQ),
|
|
For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
Store::make(
|
|
a, {0}, Add::make(Load::make(a, {0}), 1))),
|
|
nullptr)}))}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = 0;
|
|
* if (x==3 ? 1 : 0) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = 0;
|
|
* if (x==3 ? 1 : 0) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* }
|
|
* }
|
|
* A[0] = A_1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: if (x==2
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x
|
|
# CHECK: B[x] = 0;
|
|
# CHECK: if (x==3
|
|
# CHECK: for (int y
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Test the negative case of the hiddenAccess split, where the hoisted access is
|
|
// never unhidden at a higher scope and registerization occurs at the lower
|
|
// scope.
|
|
TEST(Registerizer, RegisterizerHiddenAccessNo) {
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = Block::make({Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Block::make({For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(b, {x}, 0),
|
|
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
|
|
Cond::make(
|
|
CompareSelect::make(x, 3, CompareSelectOperation::kEQ),
|
|
For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
|
|
nullptr)}))}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = 0;
|
|
* if (x==3 ? 1 : 0) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = 0;
|
|
* if (x==3 ? 1 : 0) {
|
|
* int A_1 = A[0];
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* A[0] = A_1;
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: if (x==2
|
|
# CHECK: for (int x
|
|
# CHECK: B[x] = 0;
|
|
# CHECK: if (x==3
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: for (int y
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// In this case the conditional access must be hoisted by two loops, there are
|
|
// two accesses here one is unhidden and the other isn't. A[0] can be
|
|
// registerized but B[0] cannot.
|
|
TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) {
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = Block::make({Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
Block::make({Cond::make(
|
|
CompareSelect::make(y, 3, CompareSelectOperation::kEQ),
|
|
Block::make(
|
|
{Store::make(
|
|
a, {0}, Add::make(Load::make(a, {0}), 1)),
|
|
Store::make(
|
|
b, {0}, Add::make(Load::make(b, {0}), 1))}),
|
|
nullptr)})))}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* if (y==3 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* B[0] = (B[0]) + 1;
|
|
* }
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* if (y==3 ? 1 : 0) {
|
|
* A_1 = A_1 + 1;
|
|
* B[0] = (B[0]) + 1;
|
|
* }
|
|
* }
|
|
* }
|
|
* A[0] = A_1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: if (x==2
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x
|
|
# CHECK: for (int y
|
|
# CHECK: if (y==3
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: B[0] = (B[0]) + 1;
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Accesses are registerized inside two conditions, but the immediate parent is
|
|
// not a condition.
|
|
TEST(Registerizer, RegisterizerTwoConditionalLoops) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
For::make(
|
|
x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
|
|
nullptr),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kGT),
|
|
For::make(
|
|
x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
* if (x>5 ? 1 : 0) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* int A_1 = A[0];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* A[0] = A_1;
|
|
* }
|
|
* if (x>5 ? 1 : 0) {
|
|
* int A_2 = A[0];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_2 = A_2 + 1;
|
|
* }
|
|
* A[0] = A_2;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: if (x<5
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: for (int x
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: }
|
|
# CHECK: if (x>5
|
|
# CHECK: int A_2 = A[0];
|
|
# CHECK: for (int x
|
|
# CHECK: A_2 = A_2 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_2;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Accesses are registerized inside two conditions, cut in the middle.
|
|
TEST(Registerizer, RegisterizerTwoConditionalLoopsCut) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
For::make(
|
|
x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
|
|
nullptr),
|
|
For::make(x, 0, 10, Store::make(a, {x}, 1)),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kGT),
|
|
For::make(
|
|
x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = 1;
|
|
* }
|
|
* if (x>5 ? 1 : 0) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* int A_1 = A[0];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* A[0] = A_1;
|
|
* }
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = 1;
|
|
* }
|
|
* if (x>5 ? 1 : 0) {
|
|
* int A_2 = A[0];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_2 = A_2 + 1;
|
|
* }
|
|
* A[0] = A_2;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: if (x<5
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: for (int x
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: }
|
|
# CHECK: for (int x
|
|
# CHECK: A[x] = 1;
|
|
# CHECK: if (x>5
|
|
# CHECK: int A_2 = A[0];
|
|
# CHECK: for (int x
|
|
# CHECK: A_2 = A_2 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_2;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// references a Let var in a local scope which cannot be hoisted out of the
|
|
// loop.
|
|
TEST(Registerizer, RegisterizerLoopLetVar) {
|
|
BufHandle a("A", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = IRSimplifier::simplify(Block::make({For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Let::make(y, 30),
|
|
Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))}));
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* int y = 30;
|
|
* A[y] = x + (A[y]);
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// references a Let var in an outer scope that does not prevent hoisting the
|
|
// initializer.
|
|
TEST(Registerizer, RegisterizerLoopLetVarOuter) {
|
|
BufHandle a("A", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Let::make(y, 30),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))});
|
|
|
|
/*
|
|
* int y = 30;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[y] = x + (A[y]);
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int y = 30;
|
|
* int A_1 = A[y];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = A_1 + x;
|
|
* }
|
|
* A[y] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int y = 30;
|
|
# CHECK: int A_1 = A[y];
|
|
# CHECK: for (int x
|
|
# CHECK: A_1 = A_1 + x;
|
|
# CHECK: A[y] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Okay so the registerizer generally goes after index flattening, but just in
|
|
// case. Test multi index registerization.
|
|
TEST(Registerizer, RegisterizerMultiDim) {
|
|
BufHandle a("A", {3, 4, 5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0, 1, 2}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a, {0, 1, 2}, Add::make(Load::make(a, {0, 1, 2}), x))}))});
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0, 1, 2] = (A[0, 1, 2]) + x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = x + A_1;
|
|
* }
|
|
* A[0, 1, 2] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A[0, 1, 2] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Won't registerize if only some dims match, but will still registerize
|
|
// distinct elements.
|
|
TEST(Registerizer, RegisterizerMultiDimPartial) {
|
|
BufHandle a("A", {3, 4, 5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0, 1, 2}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a, {0, 2, 2}, Add::make(Load::make(a, {0, 1, 4}), x))}))});
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0, 2, 2] = (A[0, 1, 4]) + x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0;
|
|
* int A_1 = A[0, 1, 4];
|
|
* int A_2 = A[0, 2, 2];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_2 = A_1 + x;
|
|
* }
|
|
* A[0, 2, 2] = A_2;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: A[0, 1, 2] = 0;
|
|
# CHECK: int A_1 = A[0, 1, 4];
|
|
# CHECK: int A_2 = A[0, 2, 2];
|
|
# CHECK: for (
|
|
# CHECK: A_2 = A_1 + x;
|
|
# CHECK: A[0, 2, 2] = A_2;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// If they could overlap across all dimensions we cannot registerize.
|
|
TEST(Registerizer, RegisterizerMultiDimOverlap) {
|
|
BufHandle a("A", {3, 4, 5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0, 1, 2}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 2}), x))}))});
|
|
stmt = IRSimplifier::simplify(stmt);
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0, x, 2] = (A[y, 2, 2]) + x;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// But, if one dimension is known to be distinct they do not overlap.
|
|
TEST(Registerizer, RegisterizerMultiDimPartialOverlap) {
|
|
BufHandle a("A", {3, 4, 5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0, 1, 2}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 4}), x))}))});
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0; <---- 2nd dim overlaps with store.
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0, x, 2] = (A[y, 2, 4]) + x; <---- 3rd dim has constant diff.
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0;
|
|
* int A_1 = A[y, 2, 4];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0, x, 2] = A_1 + x;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: A[0, 1, 2] = 0;
|
|
# CHECK: int A_1 = A[y, 2, 4];
|
|
# CHECK: for (
|
|
# CHECK: A[0, x, 2] = A_1 + x;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// A 3D reduction with different input dimensionality.
|
|
TEST(Registerizer, RegisterizerMultiDim3DReduction1) {
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10, 10}, kInt);
|
|
BufHandle c("C", {10, 10, 10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
VarHandle z("z", kInt);
|
|
StmtPtr stmt = For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
For::make(
|
|
z,
|
|
0,
|
|
10,
|
|
Store::make(
|
|
c,
|
|
{x, y, z},
|
|
Add::make(
|
|
Load::make(c, {x, y, z}),
|
|
Mul::make(Load::make(b, {x, y}), Load::make(a, {x})))))));
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* for (int z = 0; z < 10; z++) {
|
|
* C[x, y, z] = (C[x, y, z]) + (B[x, y]) * (A[x]);
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
// We can registerize the A and B access since they can be hoisted before
|
|
// hitting a dependent loop var.
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* int A_1 = A[x];
|
|
* for (int y = 0; y < 10; y++) {
|
|
* int B_1 = B[x, y];
|
|
* for (int z = 0; z < 10; z++) {
|
|
* C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int x
|
|
# CHECK: int A_1 = A[x];
|
|
# CHECK: for (int y
|
|
# CHECK: int B_1 = B[x, y];
|
|
# CHECK: for (int z
|
|
# CHECK: C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// A 3D reduction with the same smaller dimensionality using different loop
|
|
// vars.
|
|
TEST(Registerizer, RegisterizerMultiDim3DReduction2) {
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
BufHandle c("C", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
VarHandle z("z", kInt);
|
|
StmtPtr stmt = For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
|
|
For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
For::make(
|
|
z,
|
|
0,
|
|
10,
|
|
Store::make(
|
|
c,
|
|
{x},
|
|
Add::make(
|
|
Load::make(c, {x}),
|
|
Mul::make(Load::make(b, {y}), Load::make(a, {x})))))));
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* for (int z = 0; z < 10; z++) {
|
|
* C[x] = (C[x]) + (B[y]) * (A[x]);
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
// We can registerize all accesses, the A and C access can be hoisted to the
|
|
// outer loop since they depend only on it's loop var while the B can only be
|
|
// raised to the loop of y.
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* int A_1 = A[x];
|
|
* int C_1 = C[x];
|
|
* for (int y = 0; y < 10; y++) {
|
|
* int B_1 = B[y];
|
|
* for (int z = 0; z < 10; z++) {
|
|
* C_1 = A_1 * B_1 + C_1;
|
|
* }
|
|
* }
|
|
* C[x] = C_1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int x
|
|
# CHECK: int A_1 = A[x];
|
|
# CHECK: int C_1 = C[x];
|
|
# CHECK: for (int y
|
|
# CHECK: int B_1 = B[y];
|
|
# CHECK: for (int z
|
|
# CHECK: C_1 = A_1 * B_1 + C_1;
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: C[x] = C_1;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|