mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59136 Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D28768598 Pulled By: huiguoo fbshipit-source-id: 99ab8430bc0ba395e2a041b03a7761de335ddda5
This commit is contained in:
parent
d9d7d5e24a
commit
7c4ac9e3ee
|
|
@ -1308,7 +1308,7 @@ TEST(Reductions, ReduceInlineReducerInternal) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
TEST(Reductions, ReductionCacheAccessesOuter) {
|
TEST(Reductions, ReductionCacheAccessesOperatorAxis) {
|
||||||
KernelScope kernel_scope;
|
KernelScope kernel_scope;
|
||||||
|
|
||||||
int L = 4;
|
int L = 4;
|
||||||
|
|
@ -1331,12 +1331,94 @@ TEST(Reductions, ReductionCacheAccessesOuter) {
|
||||||
});
|
});
|
||||||
|
|
||||||
LoopNest l({e}, {c, d, e});
|
LoopNest l({e}, {c, d, e});
|
||||||
|
LoopNest l_before(l);
|
||||||
|
l_before.prepareForCodegen();
|
||||||
|
SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
|
||||||
|
|
||||||
|
Stmt* d_loop = l.getLoopStmtsFor(d)[0];
|
||||||
|
l.cacheAccesses(d->buf(), "d_local", d_loop);
|
||||||
|
l.prepareForCodegen();
|
||||||
|
|
||||||
|
Stmt* result = IRSimplifier::simplify(l.root_stmt());
|
||||||
|
SimpleIREvaluator cg_after(result, {a, b, e});
|
||||||
|
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << *result;
|
||||||
|
const std::string& expected_ir =
|
||||||
|
R"IR(
|
||||||
|
#CHECK: Allocate(d_local); // dtype=float, dims=[4]
|
||||||
|
#CHECK: for (int l1
|
||||||
|
#CHECK: d_local[l1] = 0.f
|
||||||
|
#CHECK: for (int n1
|
||||||
|
#CHECK: for (int m1
|
||||||
|
#CHECK: d_local[l1] = (d_local[l1]) + (scale[
|
||||||
|
#CHECK: }
|
||||||
|
#CHECK: }
|
||||||
|
#CHECK: }
|
||||||
|
#CHECK: for (int i
|
||||||
|
#CHECK: sum[i] = d_local[i]
|
||||||
|
#CHECK: Free(d_local);
|
||||||
|
#CHECK-NOT: d_local
|
||||||
|
)IR";
|
||||||
|
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||||
|
|
||||||
|
PaddedBuffer<float> a_v(L, M, N, "a");
|
||||||
|
PaddedBuffer<float> b_v(L, M, N, "b");
|
||||||
|
PaddedBuffer<float> c_v(L, M, N, "c");
|
||||||
|
PaddedBuffer<float> d_v(L, "d");
|
||||||
|
PaddedBuffer<float> e_before(L, "e_before");
|
||||||
|
PaddedBuffer<float> e_after(L, "e_after");
|
||||||
|
|
||||||
|
for (int l = 0; l < L; l++) {
|
||||||
|
for (int m = 0; m < M; m++) {
|
||||||
|
for (int n = 0; n < N; n++) {
|
||||||
|
a_v(l, m, n) = at::randn({1}).item().to<float>();
|
||||||
|
b_v(l, m, n) = at::randn({1}).item().to<float>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cg_before.call({a_v, b_v, e_before});
|
||||||
|
cg_after.call({a_v, b_v, e_after});
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||||
|
ExpectAllNear(e_before, e_after, 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
|
TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) {
|
||||||
|
KernelScope kernel_scope;
|
||||||
|
|
||||||
|
int L = 4;
|
||||||
|
int N = 3;
|
||||||
|
int M = 2;
|
||||||
|
|
||||||
|
Placeholder a(BufHandle("a", {L, N, M}, kFloat));
|
||||||
|
Placeholder b(BufHandle("b", {L, N, M}, kFloat));
|
||||||
|
|
||||||
|
Tensor* c = Compute(
|
||||||
|
"scale",
|
||||||
|
{{L, "l2"}, {N, "n1"}, {M, "m1"}},
|
||||||
|
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
|
||||||
|
return b.load(l, n, m) * a.load(l, n, m);
|
||||||
|
});
|
||||||
|
Tensor* d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}});
|
||||||
|
|
||||||
|
Tensor* e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) {
|
||||||
|
return b.load(0, 0, l) * d->load(l);
|
||||||
|
});
|
||||||
|
|
||||||
|
LoopNest l({e}, {c, d, e});
|
||||||
|
LoopNest l_before(l);
|
||||||
|
l_before.prepareForCodegen();
|
||||||
|
SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
|
||||||
|
|
||||||
Stmt* d_loop = l.getLoopStmtsFor(d)[1];
|
Stmt* d_loop = l.getLoopStmtsFor(d)[1];
|
||||||
l.cacheAccesses(d->buf(), "d_local", d_loop);
|
l.cacheAccesses(d->buf(), "d_local", d_loop);
|
||||||
l.prepareForCodegen();
|
l.prepareForCodegen();
|
||||||
|
|
||||||
Stmt* result = IRSimplifier::simplify(l.root_stmt());
|
Stmt* result = IRSimplifier::simplify(l.root_stmt());
|
||||||
|
SimpleIREvaluator cg_after(result, {a, b, e});
|
||||||
|
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
oss << *result;
|
oss << *result;
|
||||||
|
|
@ -1344,21 +1426,43 @@ TEST(Reductions, ReductionCacheAccessesOuter) {
|
||||||
R"IR(
|
R"IR(
|
||||||
#CHECK: Allocate(d_local); // dtype=float, dims=[1]
|
#CHECK: Allocate(d_local); // dtype=float, dims=[1]
|
||||||
#CHECK: sum[l1] = 0
|
#CHECK: sum[l1] = 0
|
||||||
#CHECK: d_local[0] = 0
|
#CHECK: d_local[0] = sum[l1]
|
||||||
#CHECK: for (int n1
|
#CHECK: for (int n1
|
||||||
#CHECK: for (int m1
|
#CHECK: for (int m1
|
||||||
#CHECK: d_local[0] = (d_local[0]) + (scale[
|
#CHECK: d_local[0] = (d_local[0]) + (scale[
|
||||||
#CHECK: }
|
#CHECK: }
|
||||||
#CHECK: }
|
#CHECK: }
|
||||||
#CHECK: sum[l1] = (sum[l1]) + (d_local[0])
|
#CHECK: sum[l1] = d_local[0]
|
||||||
#CHECK: Free(d_local);
|
#CHECK: Free(d_local);
|
||||||
#CHECK-NOT: d_local
|
#CHECK-NOT: d_local
|
||||||
)IR";
|
)IR";
|
||||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||||
|
|
||||||
|
PaddedBuffer<float> a_v(L, M, N, "a");
|
||||||
|
PaddedBuffer<float> b_v(L, M, N, "b");
|
||||||
|
PaddedBuffer<float> c_v(L, M, N, "c");
|
||||||
|
PaddedBuffer<float> d_v(L, "d");
|
||||||
|
PaddedBuffer<float> e_before(L, "e_before");
|
||||||
|
PaddedBuffer<float> e_after(L, "e_after");
|
||||||
|
|
||||||
|
for (int l = 0; l < L; l++) {
|
||||||
|
for (int m = 0; m < M; m++) {
|
||||||
|
for (int n = 0; n < N; n++) {
|
||||||
|
a_v(l, m, n) = at::randn({1}).item().to<float>();
|
||||||
|
b_v(l, m, n) = at::randn({1}).item().to<float>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cg_before.call({a_v, b_v, e_before});
|
||||||
|
cg_after.call({a_v, b_v, e_after});
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||||
|
ExpectAllNear(e_before, e_after, 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
TEST(Reductions, ReductionCacheAccessesInner) {
|
TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) {
|
||||||
KernelScope kernel_scope;
|
KernelScope kernel_scope;
|
||||||
|
|
||||||
int L = 4;
|
int L = 4;
|
||||||
|
|
@ -1381,12 +1485,16 @@ TEST(Reductions, ReductionCacheAccessesInner) {
|
||||||
});
|
});
|
||||||
|
|
||||||
LoopNest l({e}, {c, d, e});
|
LoopNest l({e}, {c, d, e});
|
||||||
|
LoopNest l_before(l);
|
||||||
|
l_before.prepareForCodegen();
|
||||||
|
SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
|
||||||
|
|
||||||
Stmt* d_loop = l.getLoopStmtsFor(d)[2];
|
Stmt* d_loop = l.getLoopStmtsFor(d)[2];
|
||||||
l.cacheAccesses(d->buf(), "d_local", d_loop);
|
l.cacheAccesses(d->buf(), "d_local", d_loop);
|
||||||
l.prepareForCodegen();
|
l.prepareForCodegen();
|
||||||
|
|
||||||
Stmt* result = IRSimplifier::simplify(l.root_stmt());
|
Stmt* result = IRSimplifier::simplify(l.root_stmt());
|
||||||
|
SimpleIREvaluator cg_after(result, {a, b, e});
|
||||||
|
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
oss << *result;
|
oss << *result;
|
||||||
|
|
@ -1405,6 +1513,28 @@ TEST(Reductions, ReductionCacheAccessesInner) {
|
||||||
#CHECK-NOT: d_local
|
#CHECK-NOT: d_local
|
||||||
)IR";
|
)IR";
|
||||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||||
|
|
||||||
|
PaddedBuffer<float> a_v(L, M, N, "a");
|
||||||
|
PaddedBuffer<float> b_v(L, M, N, "b");
|
||||||
|
PaddedBuffer<float> c_v(L, M, N, "c");
|
||||||
|
PaddedBuffer<float> d_v(L, "d");
|
||||||
|
PaddedBuffer<float> e_before(L, "e_before");
|
||||||
|
PaddedBuffer<float> e_after(L, "e_after");
|
||||||
|
|
||||||
|
for (int l = 0; l < L; l++) {
|
||||||
|
for (int m = 0; m < M; m++) {
|
||||||
|
for (int n = 0; n < N; n++) {
|
||||||
|
a_v(l, m, n) = at::randn({1}).item().to<float>();
|
||||||
|
b_v(l, m, n) = at::randn({1}).item().to<float>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cg_before.call({a_v, b_v, e_before});
|
||||||
|
cg_after.call({a_v, b_v, e_after});
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||||
|
ExpectAllNear(e_before, e_after, 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
|
|
|
||||||
|
|
@ -2520,10 +2520,25 @@ LoopNest::AccessResult LoopNest::cacheAccesses(
|
||||||
consumer_block->replace_stmt(consumer, new_consumer);
|
consumer_block->replace_stmt(consumer, new_consumer);
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there's a reduction we can't just write the result straight back to the
|
// If there's a reduction and we are operating on the reduce axis, we need to
|
||||||
// original buffer, since after parallelism the writes will race. Instead we
|
// initialize the cache with 0s. Also, we can't just write the result straight
|
||||||
// need to create a new ReduceOp.
|
// back to the original buffer, since after parallelism the writes will race.
|
||||||
|
// Instead we need to create a new ReduceOp.
|
||||||
|
bool on_reduce_axis = false;
|
||||||
if (reduceOp) {
|
if (reduceOp) {
|
||||||
|
std::set<const Var*> reduce_args(
|
||||||
|
reduceOp->reduce_args().begin(), reduceOp->reduce_args().end());
|
||||||
|
std::set<const Var*> enclosing_vars;
|
||||||
|
for (auto enclosing_for_stmt : NodeFinder<For>::find(consumer)) {
|
||||||
|
enclosing_vars.insert(enclosing_for_stmt->var());
|
||||||
|
}
|
||||||
|
for (auto reduce_arg : reduce_args) {
|
||||||
|
if (enclosing_vars.find(reduce_arg) == enclosing_vars.end()) {
|
||||||
|
on_reduce_axis = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (reduceOp && on_reduce_axis) {
|
||||||
// reduceOp means we had both loads and stores.
|
// reduceOp means we had both loads and stores.
|
||||||
|
|
||||||
// Init cache to 0.
|
// Init cache to 0.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user