mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59136 Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D28768598 Pulled By: huiguoo fbshipit-source-id: 99ab8430bc0ba395e2a041b03a7761de335ddda5
This commit is contained in:
parent
d9d7d5e24a
commit
7c4ac9e3ee
|
|
@ -1308,7 +1308,7 @@ TEST(Reductions, ReduceInlineReducerInternal) {
|
|||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(Reductions, ReductionCacheAccessesOuter) {
|
||||
TEST(Reductions, ReductionCacheAccessesOperatorAxis) {
|
||||
KernelScope kernel_scope;
|
||||
|
||||
int L = 4;
|
||||
|
|
@ -1331,12 +1331,94 @@ TEST(Reductions, ReductionCacheAccessesOuter) {
|
|||
});
|
||||
|
||||
LoopNest l({e}, {c, d, e});
|
||||
LoopNest l_before(l);
|
||||
l_before.prepareForCodegen();
|
||||
SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
|
||||
|
||||
Stmt* d_loop = l.getLoopStmtsFor(d)[0];
|
||||
l.cacheAccesses(d->buf(), "d_local", d_loop);
|
||||
l.prepareForCodegen();
|
||||
|
||||
Stmt* result = IRSimplifier::simplify(l.root_stmt());
|
||||
SimpleIREvaluator cg_after(result, {a, b, e});
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *result;
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: Allocate(d_local); // dtype=float, dims=[4]
|
||||
#CHECK: for (int l1
|
||||
#CHECK: d_local[l1] = 0.f
|
||||
#CHECK: for (int n1
|
||||
#CHECK: for (int m1
|
||||
#CHECK: d_local[l1] = (d_local[l1]) + (scale[
|
||||
#CHECK: }
|
||||
#CHECK: }
|
||||
#CHECK: }
|
||||
#CHECK: for (int i
|
||||
#CHECK: sum[i] = d_local[i]
|
||||
#CHECK: Free(d_local);
|
||||
#CHECK-NOT: d_local
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
|
||||
PaddedBuffer<float> a_v(L, M, N, "a");
|
||||
PaddedBuffer<float> b_v(L, M, N, "b");
|
||||
PaddedBuffer<float> c_v(L, M, N, "c");
|
||||
PaddedBuffer<float> d_v(L, "d");
|
||||
PaddedBuffer<float> e_before(L, "e_before");
|
||||
PaddedBuffer<float> e_after(L, "e_after");
|
||||
|
||||
for (int l = 0; l < L; l++) {
|
||||
for (int m = 0; m < M; m++) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
a_v(l, m, n) = at::randn({1}).item().to<float>();
|
||||
b_v(l, m, n) = at::randn({1}).item().to<float>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cg_before.call({a_v, b_v, e_before});
|
||||
cg_after.call({a_v, b_v, e_after});
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
ExpectAllNear(e_before, e_after, 1e-5);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) {
|
||||
KernelScope kernel_scope;
|
||||
|
||||
int L = 4;
|
||||
int N = 3;
|
||||
int M = 2;
|
||||
|
||||
Placeholder a(BufHandle("a", {L, N, M}, kFloat));
|
||||
Placeholder b(BufHandle("b", {L, N, M}, kFloat));
|
||||
|
||||
Tensor* c = Compute(
|
||||
"scale",
|
||||
{{L, "l2"}, {N, "n1"}, {M, "m1"}},
|
||||
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
|
||||
return b.load(l, n, m) * a.load(l, n, m);
|
||||
});
|
||||
Tensor* d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}});
|
||||
|
||||
Tensor* e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) {
|
||||
return b.load(0, 0, l) * d->load(l);
|
||||
});
|
||||
|
||||
LoopNest l({e}, {c, d, e});
|
||||
LoopNest l_before(l);
|
||||
l_before.prepareForCodegen();
|
||||
SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
|
||||
|
||||
Stmt* d_loop = l.getLoopStmtsFor(d)[1];
|
||||
l.cacheAccesses(d->buf(), "d_local", d_loop);
|
||||
l.prepareForCodegen();
|
||||
|
||||
Stmt* result = IRSimplifier::simplify(l.root_stmt());
|
||||
SimpleIREvaluator cg_after(result, {a, b, e});
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *result;
|
||||
|
|
@ -1344,21 +1426,43 @@ TEST(Reductions, ReductionCacheAccessesOuter) {
|
|||
R"IR(
|
||||
#CHECK: Allocate(d_local); // dtype=float, dims=[1]
|
||||
#CHECK: sum[l1] = 0
|
||||
#CHECK: d_local[0] = 0
|
||||
#CHECK: d_local[0] = sum[l1]
|
||||
#CHECK: for (int n1
|
||||
#CHECK: for (int m1
|
||||
#CHECK: d_local[0] = (d_local[0]) + (scale[
|
||||
#CHECK: }
|
||||
#CHECK: }
|
||||
#CHECK: sum[l1] = (sum[l1]) + (d_local[0])
|
||||
#CHECK: sum[l1] = d_local[0]
|
||||
#CHECK: Free(d_local);
|
||||
#CHECK-NOT: d_local
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
|
||||
PaddedBuffer<float> a_v(L, M, N, "a");
|
||||
PaddedBuffer<float> b_v(L, M, N, "b");
|
||||
PaddedBuffer<float> c_v(L, M, N, "c");
|
||||
PaddedBuffer<float> d_v(L, "d");
|
||||
PaddedBuffer<float> e_before(L, "e_before");
|
||||
PaddedBuffer<float> e_after(L, "e_after");
|
||||
|
||||
for (int l = 0; l < L; l++) {
|
||||
for (int m = 0; m < M; m++) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
a_v(l, m, n) = at::randn({1}).item().to<float>();
|
||||
b_v(l, m, n) = at::randn({1}).item().to<float>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cg_before.call({a_v, b_v, e_before});
|
||||
cg_after.call({a_v, b_v, e_after});
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
ExpectAllNear(e_before, e_after, 1e-5);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(Reductions, ReductionCacheAccessesInner) {
|
||||
TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) {
|
||||
KernelScope kernel_scope;
|
||||
|
||||
int L = 4;
|
||||
|
|
@ -1381,12 +1485,16 @@ TEST(Reductions, ReductionCacheAccessesInner) {
|
|||
});
|
||||
|
||||
LoopNest l({e}, {c, d, e});
|
||||
LoopNest l_before(l);
|
||||
l_before.prepareForCodegen();
|
||||
SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
|
||||
|
||||
Stmt* d_loop = l.getLoopStmtsFor(d)[2];
|
||||
l.cacheAccesses(d->buf(), "d_local", d_loop);
|
||||
l.prepareForCodegen();
|
||||
|
||||
Stmt* result = IRSimplifier::simplify(l.root_stmt());
|
||||
SimpleIREvaluator cg_after(result, {a, b, e});
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *result;
|
||||
|
|
@ -1405,6 +1513,28 @@ TEST(Reductions, ReductionCacheAccessesInner) {
|
|||
#CHECK-NOT: d_local
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
|
||||
PaddedBuffer<float> a_v(L, M, N, "a");
|
||||
PaddedBuffer<float> b_v(L, M, N, "b");
|
||||
PaddedBuffer<float> c_v(L, M, N, "c");
|
||||
PaddedBuffer<float> d_v(L, "d");
|
||||
PaddedBuffer<float> e_before(L, "e_before");
|
||||
PaddedBuffer<float> e_after(L, "e_after");
|
||||
|
||||
for (int l = 0; l < L; l++) {
|
||||
for (int m = 0; m < M; m++) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
a_v(l, m, n) = at::randn({1}).item().to<float>();
|
||||
b_v(l, m, n) = at::randn({1}).item().to<float>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cg_before.call({a_v, b_v, e_before});
|
||||
cg_after.call({a_v, b_v, e_after});
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
ExpectAllNear(e_before, e_after, 1e-5);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
|
|
|
|||
|
|
@ -2520,10 +2520,25 @@ LoopNest::AccessResult LoopNest::cacheAccesses(
|
|||
consumer_block->replace_stmt(consumer, new_consumer);
|
||||
}
|
||||
|
||||
// If there's a reduction we can't just write the result straight back to the
|
||||
// original buffer, since after parallelism the writes will race. Instead we
|
||||
// need to create a new ReduceOp.
|
||||
// If there's a reduction and we are operating on the reduce axis, we need to
|
||||
// initialize the cache with 0s. Also, we can't just write the result straight
|
||||
// back to the original buffer, since after parallelism the writes will race.
|
||||
// Instead we need to create a new ReduceOp.
|
||||
bool on_reduce_axis = false;
|
||||
if (reduceOp) {
|
||||
std::set<const Var*> reduce_args(
|
||||
reduceOp->reduce_args().begin(), reduceOp->reduce_args().end());
|
||||
std::set<const Var*> enclosing_vars;
|
||||
for (auto enclosing_for_stmt : NodeFinder<For>::find(consumer)) {
|
||||
enclosing_vars.insert(enclosing_for_stmt->var());
|
||||
}
|
||||
for (auto reduce_arg : reduce_args) {
|
||||
if (enclosing_vars.find(reduce_arg) == enclosing_vars.end()) {
|
||||
on_reduce_axis = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (reduceOp && on_reduce_axis) {
|
||||
// reduceOp means we had both loads and stores.
|
||||
|
||||
// Init cache to 0.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user