mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[nnc] Add an API to unroll loops by a given factor (#72071)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72071 Reviewed By: ngimel Differential Revision: D33946250 Pulled By: navahgar fbshipit-source-id: 3f3f92054174620025a9d71154d006f1738953e2
This commit is contained in:
parent
2af7cfcf4e
commit
d8b53598e9
|
|
@ -25,7 +25,7 @@ void optimizePointwise(tensorexpr::LoopNest* ln, tensorexpr::Tensor target) {
|
|||
ln->vectorize(inner);
|
||||
ln->splitWithTail(outer, 8, &inner, &tail);
|
||||
StmtPtr unrolled;
|
||||
LoopNest::unroll(inner, &unrolled);
|
||||
LoopNest::fullUnroll(inner, &unrolled);
|
||||
}
|
||||
|
||||
static void relu_nnc(benchmark::State& state) {
|
||||
|
|
|
|||
|
|
@ -230,7 +230,7 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16VecUnroll)(benchmark::State& state) {
|
|||
te::ForPtr ni = loops[4];
|
||||
te::StmtPtr unrolled;
|
||||
loop.vectorize(ni);
|
||||
loop.unroll(mi, &unrolled);
|
||||
loop.fullUnroll(mi, &unrolled);
|
||||
}
|
||||
|
||||
loop.prepareForCodegen();
|
||||
|
|
|
|||
|
|
@ -2929,7 +2929,7 @@ std::string constantUpperBoundLoopIR(int upper_bound_val) {
|
|||
LoopNest l({A});
|
||||
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
|
||||
StmtPtr unrolled = nullptr;
|
||||
LoopNest::unroll(loops[0], &unrolled);
|
||||
LoopNest::fullUnroll(loops[0], &unrolled);
|
||||
std::ostringstream oss;
|
||||
oss << *unrolled;
|
||||
return oss.str();
|
||||
|
|
@ -2958,7 +2958,7 @@ TEST(LoopNest, UnrollOuter) {
|
|||
LoopNest l({A});
|
||||
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
|
||||
StmtPtr unrolled = nullptr;
|
||||
LoopNest::unroll(loops[0], &unrolled);
|
||||
LoopNest::fullUnroll(loops[0], &unrolled);
|
||||
checkIR(unrolled, R"IR(
|
||||
# CHECK: for (int y = 0; y < 4; y++) {
|
||||
# CHECK: A[0, y] = y;
|
||||
|
|
@ -2981,7 +2981,7 @@ TEST(LoopNest, UnrollInner) {
|
|||
LoopNest l({A});
|
||||
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
|
||||
StmtPtr unrolled = nullptr;
|
||||
LoopNest::unroll(
|
||||
LoopNest::fullUnroll(
|
||||
static_to<For>(loops[0]->body()->stmts().front()), &unrolled);
|
||||
checkIR(loops[0], R"IR(
|
||||
# CHECK: for (int x = 0; x < 3; x++) {
|
||||
|
|
@ -3007,7 +3007,7 @@ TEST(LoopNest, UnrollMultipleStatements) {
|
|||
Store::make(b_buf, {x}, Load::make(a_buf, {x}))}));
|
||||
auto parent_block = Block::make({f});
|
||||
StmtPtr unrolled = nullptr;
|
||||
LoopNest::unroll(f, &unrolled);
|
||||
LoopNest::fullUnroll(f, &unrolled);
|
||||
checkIR(unrolled, R"IR(
|
||||
# CHECK: A[0] = 0;
|
||||
# CHECK: B[0] = A[0];
|
||||
|
|
@ -3039,7 +3039,7 @@ TEST(LoopNest, UnrollNonLiteralConstantBounds) {
|
|||
|
||||
std::vector<ForPtr> loops = {outer_for, inner_for};
|
||||
StmtPtr unrolled = nullptr;
|
||||
LoopNest::unroll(loops[0], &unrolled);
|
||||
LoopNest::fullUnroll(loops[0], &unrolled);
|
||||
checkIR(unrolled, R"IR(
|
||||
# CHECK: for (int j = 0; j < 4; j++) {
|
||||
# CHECK: A[1, j] = j;
|
||||
|
|
@ -3052,6 +3052,117 @@ TEST(LoopNest, UnrollNonLiteralConstantBounds) {
|
|||
# CHECK: })IR");
|
||||
}
|
||||
|
||||
// Checks LoopNest::unroll(loop, factor) on a loop whose bound is the symbolic
// variable N: unrolling the inner `j` loop by 8 is expected to produce a
// `j_outer` loop (bound N / 8) holding eight stores, followed by a `j_tail`
// loop (bound N % 8) for the remaining iterations.
TEST(LoopNest, UnrollNonConstantBounds) {
|
||||
// Input IR:
|
||||
// for (int i = 0; i < M; i++) {
|
||||
// for (int j = 0; j < N; j++) {
|
||||
// A[i, j] = i * j;
|
||||
// }
|
||||
// }
|
||||
VarHandle M("M", kInt);
|
||||
VarHandle N("N", kInt);
|
||||
BufHandle a_buf("A", {M, N}, kInt);
|
||||
VarHandle i("i", kInt);
|
||||
VarHandle j("j", kInt);
|
||||
auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
|
||||
auto inner_for = For::make(j, 0, N, for_body);
|
||||
auto outer_for = For::make(i, 0, M, inner_for);
|
||||
auto block = Block::make({outer_for});
|
||||
LoopNest l(block, {a_buf.node()});
|
||||
|
||||
// Unroll only the inner loop; the outer `i` loop is expected to stay as is.
LoopNest::unroll(inner_for, 8);
|
||||
// The expected pattern below is matched against the IR after this simplify().
l.simplify();
|
||||
checkIR(l.root_stmt(), R"IR(
|
||||
# CHECK: for (int i = 0; i < M; i++) {
|
||||
# CHECK: for (int j_outer = 0; j_outer < N / 8; j_outer++) {
|
||||
# CHECK: A[i, 8 * j_outer] =
|
||||
# CHECK: A[i, 8 * j_outer + 1] =
|
||||
# CHECK: A[i, 2 * (4 * j_outer + 1)] =
|
||||
# CHECK: A[i, 8 * j_outer + 3] =
|
||||
# CHECK: A[i, 4 * (2 * j_outer + 1)] =
|
||||
# CHECK: A[i, 8 * j_outer + 5] =
|
||||
# CHECK: A[i, 8 * j_outer + 6] =
|
||||
# CHECK: A[i, 8 * j_outer + 7] =
|
||||
# CHECK: }
|
||||
# CHECK: for (int j_tail = 0; j_tail < N % 8; j_tail++) {
|
||||
# CHECK: A[i, 8 * (N / 8) + j_tail] =
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
)IR");
|
||||
}
|
||||
|
||||
// Checks that LoopNest::unroll(loop, factor) leaves the loop nest unchanged
// for factors below 2. The same nest is unrolled with factor 1, 0 and -2 in
// turn, and after each call the root statement is expected to still print as
// the original two-level loop.
TEST(LoopNest, UnrollByFactorsLessThan2) {
|
||||
// Input IR:
|
||||
// for (int i = 0; i < M; i++) {
|
||||
// for (int j = 0; j < N; j++) {
|
||||
// A[i, j] = i * j;
|
||||
// }
|
||||
// }
|
||||
VarHandle M("M", kInt);
|
||||
VarHandle N("N", kInt);
|
||||
BufHandle a_buf("A", {M, N}, kInt);
|
||||
VarHandle i("i", kInt);
|
||||
VarHandle j("j", kInt);
|
||||
auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
|
||||
auto inner_for = For::make(j, 0, N, for_body);
|
||||
auto outer_for = For::make(i, 0, M, inner_for);
|
||||
auto block = Block::make({outer_for});
|
||||
LoopNest l(block, {a_buf.node()});
|
||||
|
||||
// Unrolling by factor = 1 should do nothing.
|
||||
LoopNest::unroll(inner_for, 1);
|
||||
checkIR(l.root_stmt(), R"IR(
|
||||
# CHECK: for (int i = 0; i < M; i++) {
|
||||
# CHECK: for (int j = 0; j < N; j++) {
|
||||
# CHECK: A[i, j] =
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
)IR");
|
||||
|
||||
// Unrolling by factor = 0 should do nothing.
|
||||
LoopNest::unroll(inner_for, 0);
|
||||
checkIR(l.root_stmt(), R"IR(
|
||||
# CHECK: for (int i = 0; i < M; i++) {
|
||||
# CHECK: for (int j = 0; j < N; j++) {
|
||||
# CHECK: A[i, j] =
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
)IR");
|
||||
|
||||
// Unrolling by negative factor should do nothing.
|
||||
LoopNest::unroll(inner_for, -2);
|
||||
checkIR(l.root_stmt(), R"IR(
|
||||
# CHECK: for (int i = 0; i < M; i++) {
|
||||
# CHECK: for (int j = 0; j < N; j++) {
|
||||
# CHECK: A[i, j] =
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
)IR");
|
||||
}
|
||||
|
||||
// Checks LoopNest::unroll(loop, factor) when the factor equals the loop's
// constant trip count (5). The expected IR still contains an `i_outer` loop,
// with bound `(5 - 0) / 5`, wrapping five stores; no simplify() is run before
// the check, so the bound expression is matched unsimplified.
TEST(LoopNest, UnrollByFactorEqualToIters) {
|
||||
// Input IR:
|
||||
// for (int i = 0; i < 5; i++) {
|
||||
// A[i] = i * i;
|
||||
// }
|
||||
BufHandle a_buf("A", {5}, kInt);
|
||||
VarHandle i("i", kInt);
|
||||
auto for_body = Block::make({Store::make(a_buf, {i}, i * i)});
|
||||
auto for_loop = For::make(i, 0, 5, for_body);
|
||||
auto block = Block::make({for_loop});
|
||||
LoopNest l(block, {a_buf.node()});
|
||||
|
||||
LoopNest::unroll(for_loop, 5);
|
||||
checkIR(l.root_stmt(), R"IR(
|
||||
# CHECK: for (int i_outer = 0; i_outer < (5 - 0) / 5; i_outer++)
|
||||
# CHECK: A[5 * i_outer]
|
||||
# CHECK: A[5 * i_outer + 1]
|
||||
# CHECK: A[5 * i_outer + 2]
|
||||
# CHECK: A[5 * i_outer + 3]
|
||||
# CHECK: A[5 * i_outer + 4]
|
||||
)IR");
|
||||
}
|
||||
|
||||
TEST(LoopNest, UnrollEmpty) {
|
||||
const std::string actual = constantUpperBoundLoopIR(0);
|
||||
const std::string& verification_pattern = R"IR(
|
||||
|
|
@ -3069,7 +3180,7 @@ TEST(LoopNest, NoUnroll) {
|
|||
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
|
||||
StmtPtr unrolled = nullptr;
|
||||
ASSERT_THROWS_WITH(
|
||||
LoopNest::unroll(loops[0], &unrolled), "non-constant loop");
|
||||
LoopNest::fullUnroll(loops[0], &unrolled), "non-constant loop");
|
||||
}
|
||||
|
||||
TEST(LoopNest, UnrollWithLet) {
|
||||
|
|
@ -3089,7 +3200,7 @@ TEST(LoopNest, UnrollWithLet) {
|
|||
Store::make(b_buf, {x}, e + 1)}));
|
||||
auto parent_block = Block::make({f});
|
||||
StmtPtr unrolled = nullptr;
|
||||
LoopNest::unroll(f, &unrolled);
|
||||
LoopNest::fullUnroll(f, &unrolled);
|
||||
std::ostringstream oss;
|
||||
oss << *unrolled;
|
||||
const std::string& verification_pattern =
|
||||
|
|
|
|||
|
|
@ -2309,7 +2309,7 @@ bool LoopNest::areLoopsPerfectlyNested(const std::vector<ForPtr>& loops) {
|
|||
return true;
|
||||
}
|
||||
|
||||
void LoopNest::unroll(ForPtr f, StmtPtr* unrolled) {
|
||||
void LoopNest::fullUnroll(ForPtr f, StmtPtr* unrolled) {
|
||||
BlockPtr p = to<Block>(f->get_parent());
|
||||
if (!f) {
|
||||
throw malformed_input("unroll attempted on null loop");
|
||||
|
|
@ -2341,10 +2341,26 @@ void LoopNest::unroll(ForPtr f, StmtPtr* unrolled) {
|
|||
p->replace_stmt(f, *unrolled);
|
||||
}
|
||||
|
||||
void LoopNest::unroll(ForPtr f) {
|
||||
void LoopNest::fullUnroll(ForPtr f) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
StmtPtr unrolled;
|
||||
unroll(f, &unrolled);
|
||||
fullUnroll(f, &unrolled);
|
||||
}
|
||||
|
||||
// Unrolls loop `f` by `factor`.
//
// Implemented as a split followed by a full unroll: `f` is split with
// splitWithTail() using `factor`, and the resulting inner loop is then passed
// to fullUnroll().
//
// f:      the loop to unroll.
// factor: the unroll factor; values below 2 make this call a no-op.
// tail:   forwarded to splitWithTail() as its tail output argument. It is not
//         written by this function when `factor < 2`.
void LoopNest::unroll(ForPtr f, int factor, ForPtr* tail) {
|
||||
// Nothing to unroll for factors of 1, 0 or negative values.
if (factor < 2) {
|
||||
return;
|
||||
}
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
ForPtr inner;
|
||||
splitWithTail(f, factor, &inner, tail);
|
||||
// Only the inner loop produced by the split is unrolled.
fullUnroll(inner);
|
||||
}
|
||||
|
||||
// Unrolls loop `f` by `factor` for callers that do not need the tail loop.
// Delegates to unroll(f, factor, &tail) with a local `tail` that is discarded
// on return.
void LoopNest::unroll(ForPtr f, int factor) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
ForPtr tail;
|
||||
unroll(f, factor, &tail);
|
||||
}
|
||||
|
||||
bool LoopNest::isNormalized(ForPtr f) {
|
||||
|
|
|
|||
|
|
@ -418,8 +418,15 @@ class TORCH_API LoopNest {
|
|||
// Returns true if the given loop has a loop-carried dependence.
|
||||
static bool hasLoopCarriedDependence(ForPtr loop);
|
||||
|
||||
static void unroll(ForPtr f, StmtPtr* unrolled);
|
||||
static void unroll(ForPtr f);
|
||||
// Unrolls all the iterations of the given loop.
|
||||
// Requires that the loop bounds are constant.
|
||||
static void fullUnroll(ForPtr f, StmtPtr* unrolled);
|
||||
static void fullUnroll(ForPtr f);
|
||||
|
||||
// Unrolls the given loop for the specified factor.
|
||||
// This does not require constant bounds for the loop being unrolled.
|
||||
static void unroll(ForPtr f, int factor, ForPtr* tail);
|
||||
static void unroll(ForPtr f, int factor);
|
||||
|
||||
static bool normalize(ForPtr f);
|
||||
static bool isNormalized(ForPtr f);
|
||||
|
|
|
|||
|
|
@ -158,7 +158,7 @@ void loopnestRandomization(int64_t seed, LoopNest& l) {
|
|||
// static void reorderAxis(ForPtr a, ForPtr b);
|
||||
// static std::vector<ForPtr> reorder(const std::vector<ForPtr>& loops, const std::vector<size_t>& permutation);
|
||||
// ForPtr tile(ForPtr x, ForPtr y, int x_factor, int y_factor);
|
||||
// static void unroll(ForPtr f);
|
||||
// static void fullUnroll(ForPtr f);
|
||||
// static bool normalize(ForPtr f);
|
||||
// static bool flatten(const std::vector<ForPtr>& f, ForPtr* flattened);
|
||||
// static void compressBuffer(BufPtr buf, StmtPtr stmt);
|
||||
|
|
@ -191,7 +191,7 @@ void loopnestRandomization(int64_t seed, LoopNest& l) {
|
|||
REORDER_AXIS,
|
||||
REORDER,
|
||||
TILE,
|
||||
UNROLL,
|
||||
FULL_UNROLL,
|
||||
NORMALIZE,
|
||||
FLATTEN,
|
||||
COMPRESS_BUFFER,
|
||||
|
|
@ -512,7 +512,7 @@ void loopnestRandomization(int64_t seed, LoopNest& l) {
|
|||
break;
|
||||
}
|
||||
|
||||
case UNROLL: {
|
||||
case FULL_UNROLL: {
|
||||
auto loops = NodeFinder<For>::find(l.root_stmt());
|
||||
if (loops.size() == 0) {
|
||||
break;
|
||||
|
|
@ -520,9 +520,9 @@ void loopnestRandomization(int64_t seed, LoopNest& l) {
|
|||
int loop_n = std::rand() % (int)loops.size();
|
||||
auto loop = loops[loop_n];
|
||||
|
||||
message = "unroll(loops[" + std::to_string(loop_n) + "]);\n";
|
||||
message = "fullUnroll(loops[" + std::to_string(loop_n) + "]);\n";
|
||||
randomization_helper::printHistory(n_transform, message);
|
||||
l.unroll(loop);
|
||||
LoopNest::fullUnroll(loop);
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -607,13 +607,20 @@ void initTensorExprBindings(PyObject* module) {
|
|||
},
|
||||
py::return_value_policy::reference)
|
||||
.def(
|
||||
"unroll",
|
||||
[](const LoopNest& self, ForPtr f) {
|
||||
"fullUnroll",
|
||||
[](ForPtr f) {
|
||||
StmtPtr unrolled = nullptr;
|
||||
self.unroll(f, &unrolled);
|
||||
LoopNest::fullUnroll(f, &unrolled);
|
||||
return unrolled;
|
||||
},
|
||||
py::return_value_policy::reference)
|
||||
.def(
|
||||
"unroll",
|
||||
[](ForPtr f, int factor) {
|
||||
LoopNest::unroll(f, factor);
|
||||
return f;
|
||||
},
|
||||
py::return_value_policy::reference)
|
||||
.def(
|
||||
"vectorize",
|
||||
[](ForPtr f) { LoopNest::vectorize(f); },
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user