From 7c4ac9e3ee1e58c4902a30c89e1ca15b05319dba Mon Sep 17 00:00:00 2001
From: Hui Guo
Date: Thu, 3 Jun 2021 21:03:02 -0700
Subject: [PATCH] [NNC] Fix loopnest.cache_accesses for reduce ops (fixed
 #59002) (#59136)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59136

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D28768598

Pulled By: huiguoo

fbshipit-source-id: 99ab8430bc0ba395e2a041b03a7761de335ddda5
---
 test/cpp/tensorexpr/test_reductions.cpp | 138 +++++++++++++++++++++++-
 torch/csrc/jit/tensorexpr/loopnest.cpp  |  21 +++-
 2 files changed, 152 insertions(+), 7 deletions(-)

diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp
index de28871bd0a..18f455e5430 100644
--- a/test/cpp/tensorexpr/test_reductions.cpp
+++ b/test/cpp/tensorexpr/test_reductions.cpp
@@ -1308,7 +1308,7 @@ TEST(Reductions, ReduceInlineReducerInternal) {
 }
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-TEST(Reductions, ReductionCacheAccessesOuter) {
+TEST(Reductions, ReductionCacheAccessesOperatorAxis) {
   KernelScope kernel_scope;
 
   int L = 4;
@@ -1331,12 +1331,94 @@ TEST(Reductions, ReductionCacheAccessesOuter) {
   });
 
   LoopNest l({e}, {c, d, e});
+  LoopNest l_before(l);
+  l_before.prepareForCodegen();
+  SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
+
+  Stmt* d_loop = l.getLoopStmtsFor(d)[0];
+  l.cacheAccesses(d->buf(), "d_local", d_loop);
+  l.prepareForCodegen();
+
+  Stmt* result = IRSimplifier::simplify(l.root_stmt());
+  SimpleIREvaluator cg_after(result, {a, b, e});
+
+  std::ostringstream oss;
+  oss << *result;
+  const std::string& expected_ir =
+      R"IR(
+#CHECK: Allocate(d_local); // dtype=float, dims=[4]
+#CHECK: for (int l1
+#CHECK: d_local[l1] = 0.f
+#CHECK: for (int n1
+#CHECK: for (int m1
+#CHECK: d_local[l1] = (d_local[l1]) + (scale[
+#CHECK: }
+#CHECK: }
+#CHECK: }
+#CHECK: for (int i
+#CHECK: sum[i] = d_local[i]
+#CHECK: Free(d_local);
+#CHECK-NOT: d_local
+      )IR";
+  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
+
+  PaddedBuffer<float> a_v(L, M, N, "a");
+  PaddedBuffer<float> b_v(L, M, N, "b");
+  PaddedBuffer<float> c_v(L, M, N, "c");
+  PaddedBuffer<float> d_v(L, "d");
+  PaddedBuffer<float> e_before(L, "e_before");
+  PaddedBuffer<float> e_after(L, "e_after");
+
+  for (int l = 0; l < L; l++) {
+    for (int m = 0; m < M; m++) {
+      for (int n = 0; n < N; n++) {
+        a_v(l, m, n) = at::randn({1}).item().to<float>();
+        b_v(l, m, n) = at::randn({1}).item().to<float>();
+      }
+    }
+  }
+
+  cg_before.call({a_v, b_v, e_before});
+  cg_after.call({a_v, b_v, e_after});
+
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+  ExpectAllNear(e_before, e_after, 1e-5);
+}
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) {
+  KernelScope kernel_scope;
+
+  int L = 4;
+  int N = 3;
+  int M = 2;
+
+  Placeholder a(BufHandle("a", {L, N, M}, kFloat));
+  Placeholder b(BufHandle("b", {L, N, M}, kFloat));
+
+  Tensor* c = Compute(
+      "scale",
+      {{L, "l2"}, {N, "n1"}, {M, "m1"}},
+      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
+        return b.load(l, n, m) * a.load(l, n, m);
+      });
+  Tensor* d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}});
+
+  Tensor* e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) {
+    return b.load(0, 0, l) * d->load(l);
+  });
+
+  LoopNest l({e}, {c, d, e});
+  LoopNest l_before(l);
+  l_before.prepareForCodegen();
+  SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
 
   Stmt* d_loop = l.getLoopStmtsFor(d)[1];
   l.cacheAccesses(d->buf(), "d_local", d_loop);
   l.prepareForCodegen();
 
   Stmt* result = IRSimplifier::simplify(l.root_stmt());
+  SimpleIREvaluator cg_after(result, {a, b, e});
 
   std::ostringstream oss;
   oss << *result;
@@ -1344,21 +1426,43 @@ TEST(Reductions, ReductionCacheAccessesOuter) {
       R"IR(
 #CHECK: Allocate(d_local); // dtype=float, dims=[1]
 #CHECK: sum[l1] = 0
-#CHECK: d_local[0] = 0
+#CHECK: d_local[0] = sum[l1]
 #CHECK: for (int n1
 #CHECK: for (int m1
 #CHECK: d_local[0] = (d_local[0]) + (scale[
 #CHECK: }
 #CHECK: }
-#CHECK: sum[l1] = (sum[l1]) + (d_local[0])
+#CHECK: sum[l1] = d_local[0]
 #CHECK: Free(d_local);
 #CHECK-NOT: d_local
       )IR";
   torch::jit::testing::FileCheck().run(expected_ir, oss.str());
+
+  PaddedBuffer<float> a_v(L, M, N, "a");
+  PaddedBuffer<float> b_v(L, M, N, "b");
+  PaddedBuffer<float> c_v(L, M, N, "c");
+  PaddedBuffer<float> d_v(L, "d");
+  PaddedBuffer<float> e_before(L, "e_before");
+  PaddedBuffer<float> e_after(L, "e_after");
+
+  for (int l = 0; l < L; l++) {
+    for (int m = 0; m < M; m++) {
+      for (int n = 0; n < N; n++) {
+        a_v(l, m, n) = at::randn({1}).item().to<float>();
+        b_v(l, m, n) = at::randn({1}).item().to<float>();
+      }
+    }
+  }
+
+  cg_before.call({a_v, b_v, e_before});
+  cg_after.call({a_v, b_v, e_after});
+
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+  ExpectAllNear(e_before, e_after, 1e-5);
 }
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-TEST(Reductions, ReductionCacheAccessesInner) {
+TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) {
   KernelScope kernel_scope;
 
   int L = 4;
@@ -1381,12 +1485,16 @@ TEST(Reductions, ReductionCacheAccessesInner) {
   });
 
   LoopNest l({e}, {c, d, e});
+  LoopNest l_before(l);
+  l_before.prepareForCodegen();
+  SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
 
   Stmt* d_loop = l.getLoopStmtsFor(d)[2];
   l.cacheAccesses(d->buf(), "d_local", d_loop);
   l.prepareForCodegen();
 
   Stmt* result = IRSimplifier::simplify(l.root_stmt());
+  SimpleIREvaluator cg_after(result, {a, b, e});
 
   std::ostringstream oss;
   oss << *result;
@@ -1405,6 +1513,28 @@ TEST(Reductions, ReductionCacheAccessesInner) {
 #CHECK-NOT: d_local
       )IR";
   torch::jit::testing::FileCheck().run(expected_ir, oss.str());
+
+  PaddedBuffer<float> a_v(L, M, N, "a");
+  PaddedBuffer<float> b_v(L, M, N, "b");
+  PaddedBuffer<float> c_v(L, M, N, "c");
+  PaddedBuffer<float> d_v(L, "d");
+  PaddedBuffer<float> e_before(L, "e_before");
+  PaddedBuffer<float> e_after(L, "e_after");
+
+  for (int l = 0; l < L; l++) {
+    for (int m = 0; m < M; m++) {
+      for (int n = 0; n < N; n++) {
+        a_v(l, m, n) = at::randn({1}).item().to<float>();
+        b_v(l, m, n) = at::randn({1}).item().to<float>();
+      }
+    }
+  }
+
+  cg_before.call({a_v, b_v, e_before});
+  cg_after.call({a_v, b_v, e_after});
+
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+  ExpectAllNear(e_before, e_after, 1e-5);
 }
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index c6a6cb3e43b..bfa11d374dd 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -2520,10 +2520,25 @@ LoopNest::AccessResult LoopNest::cacheAccesses(
     consumer_block->replace_stmt(consumer, new_consumer);
   }
 
-  // If there's a reduction we can't just write the result straight back to the
-  // original buffer, since after parallelism the writes will race. Instead we
-  // need to create a new ReduceOp.
+  // If there's a reduction and we are operating on the reduce axis, we need to
+  // initialize the cache with 0s. Also, we can't just write the result straight
+  // back to the original buffer, since after parallelism the writes will race.
+  // Instead we need to create a new ReduceOp.
+  bool on_reduce_axis = false;
   if (reduceOp) {
+    std::set<const Var*> reduce_args(
+        reduceOp->reduce_args().begin(), reduceOp->reduce_args().end());
+    std::set<const Var*> enclosing_vars;
+    for (auto enclosing_for_stmt : NodeFinder<For>::find(consumer)) {
+      enclosing_vars.insert(enclosing_for_stmt->var());
+    }
+    for (auto reduce_arg : reduce_args) {
+      if (enclosing_vars.find(reduce_arg) == enclosing_vars.end()) {
+        on_reduce_axis = true;
+      }
+    }
+  }
+  if (reduceOp && on_reduce_axis) {
     // reduceOp means we had both loads and stores.
 
     // Init cache to 0.
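
For reference, the scenario the new tests pin down can be reproduced outside the test harness using only the API that already appears in the diff above. The sketch below mirrors ReductionCacheAccessesOuterReduceAxis; the function name is invented for illustration, and it assumes the includes and using-declarations of test_reductions.cpp (the 2021-era TensorExpr API with KernelScope). It is not part of the patch.

// Illustrative sketch only -- assumes the headers/namespace of
// test/cpp/tensorexpr/test_reductions.cpp.
static void cacheOnReduceAxisSketch() {
  KernelScope kernel_scope;
  int L = 4, N = 3, M = 2;

  Placeholder a(BufHandle("a", {L, N, M}, kFloat));
  Placeholder b(BufHandle("b", {L, N, M}, kFloat));

  // Elementwise producer feeding a reduction over the two inner axes.
  Tensor* c = Compute(
      "scale",
      {{L, "l2"}, {N, "n1"}, {M, "m1"}},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor* d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}});
  Tensor* e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d->load(l);
  });

  LoopNest nest({e}, {c, d, e});

  // Loop [0] of the reduction is the non-reduce (operator) axis l1; loops [1]
  // and [2] are the reduce axes n1 and m1. Caching at [1] or [2] is the case
  // this patch makes correct: the cache now gets a reduce-aware init and
  // write-back instead of plain stores.
  Stmt* d_loop = nest.getLoopStmtsFor(d)[1];
  nest.cacheAccesses(d->buf(), "d_local", d_loop);

  nest.prepareForCodegen();
  Stmt* s = IRSimplifier::simplify(nest.root_stmt());
  SimpleIREvaluator cg(s, {a, b, e});
  // cg.call({...}) with float buffers for a, b, e, as in the tests above.
}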