[nnc] Add an API to unroll loops by a given factor (#72071)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72071

Reviewed By: ngimel

Differential Revision: D33946250

Pulled By: navahgar

fbshipit-source-id: 3f3f92054174620025a9d71154d006f1738953e2
This commit is contained in:
Raghavan Raman 2022-02-03 10:36:27 -08:00 committed by Facebook GitHub Bot
parent 2af7cfcf4e
commit d8b53598e9
7 changed files with 163 additions and 22 deletions

View File

@ -25,7 +25,7 @@ void optimizePointwise(tensorexpr::LoopNest* ln, tensorexpr::Tensor target) {
ln->vectorize(inner);
ln->splitWithTail(outer, 8, &inner, &tail);
StmtPtr unrolled;
LoopNest::unroll(inner, &unrolled);
LoopNest::fullUnroll(inner, &unrolled);
}
static void relu_nnc(benchmark::State& state) {

View File

@ -230,7 +230,7 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16VecUnroll)(benchmark::State& state) {
te::ForPtr ni = loops[4];
te::StmtPtr unrolled;
loop.vectorize(ni);
loop.unroll(mi, &unrolled);
loop.fullUnroll(mi, &unrolled);
}
loop.prepareForCodegen();

View File

@ -2929,7 +2929,7 @@ std::string constantUpperBoundLoopIR(int upper_bound_val) {
LoopNest l({A});
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
StmtPtr unrolled = nullptr;
LoopNest::unroll(loops[0], &unrolled);
LoopNest::fullUnroll(loops[0], &unrolled);
std::ostringstream oss;
oss << *unrolled;
return oss.str();
@ -2958,7 +2958,7 @@ TEST(LoopNest, UnrollOuter) {
LoopNest l({A});
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
StmtPtr unrolled = nullptr;
LoopNest::unroll(loops[0], &unrolled);
LoopNest::fullUnroll(loops[0], &unrolled);
checkIR(unrolled, R"IR(
# CHECK: for (int y = 0; y < 4; y++) {
# CHECK: A[0, y] = y;
@ -2981,7 +2981,7 @@ TEST(LoopNest, UnrollInner) {
LoopNest l({A});
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
StmtPtr unrolled = nullptr;
LoopNest::unroll(
LoopNest::fullUnroll(
static_to<For>(loops[0]->body()->stmts().front()), &unrolled);
checkIR(loops[0], R"IR(
# CHECK: for (int x = 0; x < 3; x++) {
@ -3007,7 +3007,7 @@ TEST(LoopNest, UnrollMultipleStatements) {
Store::make(b_buf, {x}, Load::make(a_buf, {x}))}));
auto parent_block = Block::make({f});
StmtPtr unrolled = nullptr;
LoopNest::unroll(f, &unrolled);
LoopNest::fullUnroll(f, &unrolled);
checkIR(unrolled, R"IR(
# CHECK: A[0] = 0;
# CHECK: B[0] = A[0];
@ -3039,7 +3039,7 @@ TEST(LoopNest, UnrollNonLiteralConstantBounds) {
std::vector<ForPtr> loops = {outer_for, inner_for};
StmtPtr unrolled = nullptr;
LoopNest::unroll(loops[0], &unrolled);
LoopNest::fullUnroll(loops[0], &unrolled);
checkIR(unrolled, R"IR(
# CHECK: for (int j = 0; j < 4; j++) {
# CHECK: A[1, j] = j;
@ -3052,6 +3052,117 @@ TEST(LoopNest, UnrollNonLiteralConstantBounds) {
# CHECK: })IR");
}
TEST(LoopNest, UnrollNonConstantBounds) {
  // Partial unrolling must work when the trip count is symbolic.
  //
  // Input IR:
  //   for (int i = 0; i < M; i++) {
  //     for (int j = 0; j < N; j++) {
  //       A[i, j] = i * j;
  //     }
  //   }
  VarHandle m_extent("M", kInt);
  VarHandle n_extent("N", kInt);
  BufHandle out("A", {m_extent, n_extent}, kInt);
  VarHandle row("i", kInt);
  VarHandle col("j", kInt);
  auto store = Store::make(out, {row, col}, row * col);
  ForPtr inner_loop = For::make(col, 0, n_extent, Block::make({store}));
  ForPtr outer_loop = For::make(row, 0, m_extent, inner_loop);
  LoopNest nest(Block::make({outer_loop}), {out.node()});

  // Unroll the inner (j) loop by 8 and simplify the index arithmetic before
  // comparing: the result is a j_outer loop with 8 stores plus a j_tail loop.
  constexpr int kFactor = 8;
  LoopNest::unroll(inner_loop, kFactor);
  nest.simplify();

  const char* const expected_ir = R"IR(
# CHECK: for (int i = 0; i < M; i++) {
# CHECK: for (int j_outer = 0; j_outer < N / 8; j_outer++) {
# CHECK: A[i, 8 * j_outer] =
# CHECK: A[i, 8 * j_outer + 1] =
# CHECK: A[i, 2 * (4 * j_outer + 1)] =
# CHECK: A[i, 8 * j_outer + 3] =
# CHECK: A[i, 4 * (2 * j_outer + 1)] =
# CHECK: A[i, 8 * j_outer + 5] =
# CHECK: A[i, 8 * j_outer + 6] =
# CHECK: A[i, 8 * j_outer + 7] =
# CHECK: }
# CHECK: for (int j_tail = 0; j_tail < N % 8; j_tail++) {
# CHECK: A[i, 8 * (N / 8) + j_tail] =
# CHECK: }
# CHECK: }
)IR";
  checkIR(nest.root_stmt(), expected_ir);
}
TEST(LoopNest, UnrollByFactorsLessThan2) {
  // Factors of 1, 0 and negative values must leave the loop nest unchanged.
  //
  // Input IR:
  //   for (int i = 0; i < M; i++) {
  //     for (int j = 0; j < N; j++) {
  //       A[i, j] = i * j;
  //     }
  //   }
  VarHandle m_extent("M", kInt);
  VarHandle n_extent("N", kInt);
  BufHandle out("A", {m_extent, n_extent}, kInt);
  VarHandle row("i", kInt);
  VarHandle col("j", kInt);
  auto store = Store::make(out, {row, col}, row * col);
  ForPtr inner_loop = For::make(col, 0, n_extent, Block::make({store}));
  ForPtr outer_loop = For::make(row, 0, m_extent, inner_loop);
  LoopNest nest(Block::make({outer_loop}), {out.node()});

  // Requests an unroll of the inner loop by `factor` and verifies that the
  // nest still prints as the original two-level loop.
  const auto expect_untouched = [&](int factor) {
    LoopNest::unroll(inner_loop, factor);
    checkIR(nest.root_stmt(), R"IR(
# CHECK: for (int i = 0; i < M; i++) {
# CHECK: for (int j = 0; j < N; j++) {
# CHECK: A[i, j] =
# CHECK: }
# CHECK: }
)IR");
  };

  // Unrolling by factor = 1 should do nothing.
  expect_untouched(1);
  // Unrolling by factor = 0 should do nothing.
  expect_untouched(0);
  // Unrolling by negative factor should do nothing.
  expect_untouched(-2);
}
TEST(LoopNest, UnrollByFactorEqualToIters) {
  // Unrolling by a factor equal to the trip count.
  //
  // Input IR:
  //   for (int i = 0; i < 5; i++) {
  //     A[i] = i * i;
  //   }
  BufHandle out("A", {5}, kInt);
  VarHandle idx("i", kInt);
  auto store = Store::make(out, {idx}, idx * idx);
  ForPtr loop = For::make(idx, 0, 5, Block::make({store}));
  LoopNest nest(Block::make({loop}), {out.node()});

  LoopNest::unroll(loop, 5);

  // No simplify() is run here, so the outer bound is still printed as the
  // unsimplified expression "(5 - 0) / 5".
  const char* const expected_ir = R"IR(
# CHECK: for (int i_outer = 0; i_outer < (5 - 0) / 5; i_outer++)
# CHECK: A[5 * i_outer]
# CHECK: A[5 * i_outer + 1]
# CHECK: A[5 * i_outer + 2]
# CHECK: A[5 * i_outer + 3]
# CHECK: A[5 * i_outer + 4]
)IR";
  checkIR(nest.root_stmt(), expected_ir);
}
TEST(LoopNest, UnrollEmpty) {
const std::string actual = constantUpperBoundLoopIR(0);
const std::string& verification_pattern = R"IR(
@ -3069,7 +3180,7 @@ TEST(LoopNest, NoUnroll) {
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
StmtPtr unrolled = nullptr;
ASSERT_THROWS_WITH(
LoopNest::unroll(loops[0], &unrolled), "non-constant loop");
LoopNest::fullUnroll(loops[0], &unrolled), "non-constant loop");
}
TEST(LoopNest, UnrollWithLet) {
@ -3089,7 +3200,7 @@ TEST(LoopNest, UnrollWithLet) {
Store::make(b_buf, {x}, e + 1)}));
auto parent_block = Block::make({f});
StmtPtr unrolled = nullptr;
LoopNest::unroll(f, &unrolled);
LoopNest::fullUnroll(f, &unrolled);
std::ostringstream oss;
oss << *unrolled;
const std::string& verification_pattern =

View File

@ -2309,7 +2309,7 @@ bool LoopNest::areLoopsPerfectlyNested(const std::vector<ForPtr>& loops) {
return true;
}
void LoopNest::unroll(ForPtr f, StmtPtr* unrolled) {
void LoopNest::fullUnroll(ForPtr f, StmtPtr* unrolled) {
BlockPtr p = to<Block>(f->get_parent());
if (!f) {
throw malformed_input("unroll attempted on null loop");
@ -2341,10 +2341,26 @@ void LoopNest::unroll(ForPtr f, StmtPtr* unrolled) {
p->replace_stmt(f, *unrolled);
}
void LoopNest::unroll(ForPtr f) {
void LoopNest::fullUnroll(ForPtr f) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
StmtPtr unrolled;
unroll(f, &unrolled);
fullUnroll(f, &unrolled);
}
// Partially unrolls `f` by `factor`.
//
// The loop is split with `splitWithTail(f, factor, &inner, tail)` and the
// resulting `inner` loop is then fully unrolled, which is why the bounds of
// `f` itself do not have to be constant.
//
// `tail` is an out-parameter that is forwarded unchanged to `splitWithTail`.
// A `factor` below 2 (1, 0 or negative) leaves the loop untouched; in that
// case `*tail` is reset to nullptr so that the caller never reads a stale or
// uninitialized value.
void LoopNest::unroll(ForPtr f, int factor, ForPtr* tail) {
  if (factor < 2) {
    // Nothing to unroll, but still give the out-parameter a defined value.
    if (tail != nullptr) {
      *tail = nullptr;
    }
    return;
  }
  ForPtr inner = nullptr;
  splitWithTail(f, factor, &inner, tail);
  fullUnroll(inner);
}
// Overload of unroll() for callers that do not need the tail loop: the tail
// out-parameter is bound to a local that is dropped on return.
void LoopNest::unroll(ForPtr f, int factor) {
  ForPtr discarded_tail = nullptr;
  unroll(f, factor, &discarded_tail);
}
bool LoopNest::isNormalized(ForPtr f) {

View File

@ -418,8 +418,15 @@ class TORCH_API LoopNest {
// Returns true if the given loop has a loop-carried dependence.
static bool hasLoopCarriedDependence(ForPtr loop);
static void unroll(ForPtr f, StmtPtr* unrolled);
static void unroll(ForPtr f);
// Unrolls all the iterations of the given loop.
// Requires that the loop bounds are constant; otherwise an exception is
// thrown (its message contains "non-constant loop").
// `unrolled` is an out-parameter for the statement that takes the place of
// `f` in its parent block.
static void fullUnroll(ForPtr f, StmtPtr* unrolled);
// Convenience overload that discards the unrolled statement.
static void fullUnroll(ForPtr f);
// Unrolls the given loop by the specified factor.
// This does not require constant bounds for the loop being unrolled.
// The loop is split with `splitWithTail` using `factor`, and the resulting
// inner loop is then fully unrolled. `tail` is passed through to
// `splitWithTail` as its tail out-parameter.
// A factor less than 2 is a no-op: the loop is left unchanged.
static void unroll(ForPtr f, int factor, ForPtr* tail);
// Same as above, for callers that do not need the tail loop.
static void unroll(ForPtr f, int factor);
static bool normalize(ForPtr f);
static bool isNormalized(ForPtr f);

View File

@ -158,7 +158,7 @@ void loopnestRandomization(int64_t seed, LoopNest& l) {
// static void reorderAxis(ForPtr a, ForPtr b);
// static std::vector<ForPtr> reorder(const std::vector<ForPtr>& loops, const std::vector<size_t>& permutation);
// ForPtr tile(ForPtr x, ForPtr y, int x_factor, int y_factor);
// static void unroll(ForPtr f);
// static void fullUnroll(ForPtr f);
// static bool normalize(ForPtr f);
// static bool flatten(const std::vector<ForPtr>& f, ForPtr* flattened);
// static void compressBuffer(BufPtr buf, StmtPtr stmt);
@ -191,7 +191,7 @@ void loopnestRandomization(int64_t seed, LoopNest& l) {
REORDER_AXIS,
REORDER,
TILE,
UNROLL,
FULL_UNROLL,
NORMALIZE,
FLATTEN,
COMPRESS_BUFFER,
@ -512,7 +512,7 @@ void loopnestRandomization(int64_t seed, LoopNest& l) {
break;
}
case UNROLL: {
case FULL_UNROLL: {
auto loops = NodeFinder<For>::find(l.root_stmt());
if (loops.size() == 0) {
break;
@ -520,9 +520,9 @@ void loopnestRandomization(int64_t seed, LoopNest& l) {
int loop_n = std::rand() % (int)loops.size();
auto loop = loops[loop_n];
message = "unroll(loops[" + std::to_string(loop_n) + "]);\n";
message = "fullUnroll(loops[" + std::to_string(loop_n) + "]);\n";
randomization_helper::printHistory(n_transform, message);
l.unroll(loop);
LoopNest::fullUnroll(loop);
break;
}

View File

@ -607,13 +607,20 @@ void initTensorExprBindings(PyObject* module) {
},
py::return_value_policy::reference)
.def(
"unroll",
[](const LoopNest& self, ForPtr f) {
"fullUnroll",
[](ForPtr f) {
StmtPtr unrolled = nullptr;
self.unroll(f, &unrolled);
LoopNest::fullUnroll(f, &unrolled);
return unrolled;
},
py::return_value_policy::reference)
.def(
"unroll",
[](ForPtr f, int factor) {
LoopNest::unroll(f, factor);
return f;
},
py::return_value_policy::reference)
.def(
"vectorize",
[](ForPtr f) { LoopNest::vectorize(f); },