[NNC] Fix loopnest.cache_accesses for reduce ops (fixed #59002) (#59136)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59136

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D28768598

Pulled By: huiguoo

fbshipit-source-id: 99ab8430bc0ba395e2a041b03a7761de335ddda5
Author: Hui Guo, 2021-06-03 21:03:02 -07:00 (committed by Facebook GitHub Bot)
Commit: 7c4ac9e3ee (parent: d9d7d5e24a)
2 changed files with 152 additions and 7 deletions
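For context, here is a condensed sketch of the scenario the updated tests exercise, assembled from the test code in this diff (2021-era NNC API; the strategy comments are an editorial gloss of the expected IR in the tests below, not part of the commit):

#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
using namespace torch::jit::tensorexpr;

void cacheReductionSketch() {
  KernelScope kernel_scope;
  const int L = 4, N = 3, M = 2;
  Placeholder a(BufHandle("a", {L, N, M}, kFloat));
  Placeholder b(BufHandle("b", {L, N, M}, kFloat));
  // scale[l, n, m] = b[l, n, m] * a[l, n, m]
  Tensor* c = Compute(
      "scale",
      {{L, "l2"}, {N, "n1"}, {M, "m1"}},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  // sum[l1] = reduce(+, scale[l1, n1, m1]) over n1, m1
  Tensor* d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}});
  Tensor* e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d->load(l);
  });
  LoopNest l({e}, {c, d, e});

  // Which loop the cache is pinned to determines how "d_local" must be
  // initialized and written back (see the expected IR in the tests below):
  //   loops[0] = l1 (operator axis): d_local covers all L sums; init to 0,
  //       then a plain copy-out loop (sum[i] = d_local[i]).
  //   loops[1] = n1: both reduce loops live inside the cache, so the cached
  //       value is final; load sum[l1] in and store it straight back.
  //   loops[2] = m1: the n1 reduce loop stays outside the cache, so each n1
  //       iteration must re-init d_local and fold it into sum[l1] through a
  //       fresh ReduceOp.
  Stmt* d_loop = l.getLoopStmtsFor(d)[0];
  l.cacheAccesses(d->buf(), "d_local", d_loop);
  l.prepareForCodegen();
}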

test/cpp/tensorexpr/test_reductions.cpp

@@ -1308,7 +1308,7 @@ TEST(Reductions, ReduceInlineReducerInternal) {
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-TEST(Reductions, ReductionCacheAccessesOuter) {
+TEST(Reductions, ReductionCacheAccessesOperatorAxis) {
KernelScope kernel_scope;
int L = 4;
@@ -1331,12 +1331,94 @@ TEST(Reductions, ReductionCacheAccessesOuter) {
});
LoopNest l({e}, {c, d, e});
LoopNest l_before(l);
l_before.prepareForCodegen();
SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
Stmt* d_loop = l.getLoopStmtsFor(d)[0];
l.cacheAccesses(d->buf(), "d_local", d_loop);
l.prepareForCodegen();
Stmt* result = IRSimplifier::simplify(l.root_stmt());
SimpleIREvaluator cg_after(result, {a, b, e});
std::ostringstream oss;
oss << *result;
const std::string& expected_ir =
R"IR(
#CHECK: Allocate(d_local); // dtype=float, dims=[4]
#CHECK: for (int l1
#CHECK: d_local[l1] = 0.f
#CHECK: for (int n1
#CHECK: for (int m1
#CHECK: d_local[l1] = (d_local[l1]) + (scale[
#CHECK: }
#CHECK: }
#CHECK: }
#CHECK: for (int i
#CHECK: sum[i] = d_local[i]
#CHECK: Free(d_local);
#CHECK-NOT: d_local
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
PaddedBuffer<float> a_v(L, M, N, "a");
PaddedBuffer<float> b_v(L, M, N, "b");
PaddedBuffer<float> c_v(L, M, N, "c");
PaddedBuffer<float> d_v(L, "d");
PaddedBuffer<float> e_before(L, "e_before");
PaddedBuffer<float> e_after(L, "e_after");
for (int l = 0; l < L; l++) {
for (int m = 0; m < M; m++) {
for (int n = 0; n < N; n++) {
a_v(l, m, n) = at::randn({1}).item().to<float>();
b_v(l, m, n) = at::randn({1}).item().to<float>();
}
}
}
cg_before.call({a_v, b_v, e_before});
cg_after.call({a_v, b_v, e_after});
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
ExpectAllNear(e_before, e_after, 1e-5);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) {
KernelScope kernel_scope;
int L = 4;
int N = 3;
int M = 2;
Placeholder a(BufHandle("a", {L, N, M}, kFloat));
Placeholder b(BufHandle("b", {L, N, M}, kFloat));
Tensor* c = Compute(
"scale",
{{L, "l2"}, {N, "n1"}, {M, "m1"}},
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
return b.load(l, n, m) * a.load(l, n, m);
});
Tensor* d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}});
Tensor* e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) {
return b.load(0, 0, l) * d->load(l);
});
LoopNest l({e}, {c, d, e});
LoopNest l_before(l);
l_before.prepareForCodegen();
SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
Stmt* d_loop = l.getLoopStmtsFor(d)[1];
l.cacheAccesses(d->buf(), "d_local", d_loop);
l.prepareForCodegen();
Stmt* result = IRSimplifier::simplify(l.root_stmt());
SimpleIREvaluator cg_after(result, {a, b, e});
std::ostringstream oss;
oss << *result;
@@ -1344,21 +1426,43 @@ TEST(Reductions, ReductionCacheAccessesOuter) {
R"IR(
#CHECK: Allocate(d_local); // dtype=float, dims=[1]
#CHECK: sum[l1] = 0
-#CHECK: d_local[0] = 0
+#CHECK: d_local[0] = sum[l1]
#CHECK: for (int n1
#CHECK: for (int m1
#CHECK: d_local[0] = (d_local[0]) + (scale[
#CHECK: }
#CHECK: }
-#CHECK: sum[l1] = (sum[l1]) + (d_local[0])
+#CHECK: sum[l1] = d_local[0]
#CHECK: Free(d_local);
#CHECK-NOT: d_local
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
PaddedBuffer<float> a_v(L, M, N, "a");
PaddedBuffer<float> b_v(L, M, N, "b");
PaddedBuffer<float> c_v(L, M, N, "c");
PaddedBuffer<float> d_v(L, "d");
PaddedBuffer<float> e_before(L, "e_before");
PaddedBuffer<float> e_after(L, "e_after");
for (int l = 0; l < L; l++) {
for (int m = 0; m < M; m++) {
for (int n = 0; n < N; n++) {
a_v(l, m, n) = at::randn({1}).item().to<float>();
b_v(l, m, n) = at::randn({1}).item().to<float>();
}
}
}
cg_before.call({a_v, b_v, e_before});
cg_after.call({a_v, b_v, e_after});
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
ExpectAllNear(e_before, e_after, 1e-5);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-TEST(Reductions, ReductionCacheAccessesInner) {
+TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) {
KernelScope kernel_scope;
int L = 4;
@@ -1381,12 +1485,16 @@ TEST(Reductions, ReductionCacheAccessesInner) {
});
LoopNest l({e}, {c, d, e});
LoopNest l_before(l);
l_before.prepareForCodegen();
SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
Stmt* d_loop = l.getLoopStmtsFor(d)[2];
l.cacheAccesses(d->buf(), "d_local", d_loop);
l.prepareForCodegen();
Stmt* result = IRSimplifier::simplify(l.root_stmt());
SimpleIREvaluator cg_after(result, {a, b, e});
std::ostringstream oss;
oss << *result;
@@ -1405,6 +1513,28 @@ TEST(Reductions, ReductionCacheAccessesInner) {
#CHECK-NOT: d_local
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
PaddedBuffer<float> a_v(L, M, N, "a");
PaddedBuffer<float> b_v(L, M, N, "b");
PaddedBuffer<float> c_v(L, M, N, "c");
PaddedBuffer<float> d_v(L, "d");
PaddedBuffer<float> e_before(L, "e_before");
PaddedBuffer<float> e_after(L, "e_after");
for (int l = 0; l < L; l++) {
for (int m = 0; m < M; m++) {
for (int n = 0; n < N; n++) {
a_v(l, m, n) = at::randn({1}).item().to<float>();
b_v(l, m, n) = at::randn({1}).item().to<float>();
}
}
}
cg_before.call({a_v, b_v, e_before});
cg_after.call({a_v, b_v, e_after});
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
ExpectAllNear(e_before, e_after, 1e-5);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)

torch/csrc/jit/tensorexpr/loopnest.cpp

@@ -2520,10 +2520,25 @@ LoopNest::AccessResult LoopNest::cacheAccesses(
consumer_block->replace_stmt(consumer, new_consumer);
}
-// If there's a reduction we can't just write the result straight back to the
-// original buffer, since after parallelism the writes will race. Instead we
-// need to create a new ReduceOp.
+// If there's a reduction and we are operating on the reduce axis, we need to
+// initialize the cache with 0s. Also, we can't just write the result straight
+// back to the original buffer, since after parallelism the writes will race.
+// Instead we need to create a new ReduceOp.
bool on_reduce_axis = false;
if (reduceOp) {
std::set<const Var*> reduce_args(
reduceOp->reduce_args().begin(), reduceOp->reduce_args().end());
std::set<const Var*> enclosing_vars;
for (auto enclosing_for_stmt : NodeFinder<For>::find(consumer)) {
enclosing_vars.insert(enclosing_for_stmt->var());
}
for (auto reduce_arg : reduce_args) {
if (enclosing_vars.find(reduce_arg) == enclosing_vars.end()) {
on_reduce_axis = true;
}
}
}
if (reduceOp && on_reduce_axis) {
// reduceOp means we had both loads and stores.
// Init cache to 0.
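In IR terms, the two write-back strategies the new check selects between look like this (a summarizing gloss transcribed from the expected IR in the updated tests above, not part of the commit; sum, d_local, and scale are the tests' buffer names):

// on_reduce_axis == true (a reduce loop encloses the cached stmt; the
// InnerReduceAxis test, caching at m1): re-init the cache each iteration and
// combine the partial sums through a new ReduceOp, so that later
// parallelization can still handle the racing writes to sum:
//   d_local[0] = 0.f
//   for (int m1 ...) d_local[0] = (d_local[0]) + (scale[...])
//   sum[l1] = (sum[l1]) + (d_local[0])
//
// on_reduce_axis == false (all reduce loops sit inside the cached stmt; the
// OuterReduceAxis test, caching at n1): the cached value is the finished
// reduction, so a plain load/store pair suffices:
//   d_local[0] = sum[l1]
//   for (int n1 ...) for (int m1 ...) d_local[0] = (d_local[0]) + (scale[...])
//   sum[l1] = d_local[0]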