[nnc] Tests for proposed feature: loop bounds conditional simplification (#54121)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54121

It would be nice to do range analysis to determine if a condition
cannot be satisfied.  These are some tests that we should be able to turn on
once we have this feature.
ghstack-source-id: 124116847

Test Plan: Simplify.*LoopBounds

Reviewed By: ZolotukhinM

Differential Revision: D27107956

fbshipit-source-id: bb27e3d3bc803f0101c416e4a351ba2278684980
This commit is contained in:
Bert Maher 2021-03-17 10:56:54 -07:00 committed by Facebook GitHub Bot
parent a852fdb6b5
commit 7367bca066

View File

@ -4633,5 +4633,131 @@ TEST(Simplify, SimplifyBroadcastTermExpander) {
}
}
TEST(Simplify, DISABLED_CompareSelectCondAlwaysInLoopBounds) {
// Before:
// for (int n = 1; n < N; n++) {
// b[n] = n < 1 ? 0.f : 1.f;
// }
// After:
// for (int n = 1; n < N; n++) {
// b[n] = 1.f;
// }
KernelScope kernel_scope;
constexpr int N = 8;
Placeholder b("b", kFloat, {N});
VarHandle n("n", kInt);
Stmt* s = For::make(
n, 1, N, b.store({n}, CompareSelect::make(n, 1, 0.f, 1.0f, kLT)));
s = IRSimplifier::simplify(s);
std::ostringstream oss;
oss << *s;
torch::jit::testing::FileCheck().run(
R"IR(
# CHECK: b[n] = 1.f;
)IR",
oss.str());
}
TEST(Simplify, DISABLED_IfThenCondAlwaysInLoopBounds) {
// Before:
// for (int n = 1; n < N; n++) {
// b[n] = IfThenElse(n < 1 ? 1 : 0, 0.f, 1.f);
// }
// After:
// for (int n = 1; n < N; n++) {
// b[n] = 1.f;
// }
KernelScope kernel_scope;
constexpr int N = 8;
Placeholder b("b", kFloat, {N});
VarHandle n("n", kInt);
Stmt* s =
For::make(n, 1, N, b.store({n}, IfThenElse::make(n < 1, 0.f, 1.0f)));
s = IRSimplifier::simplify(s);
std::ostringstream oss;
oss << *s;
torch::jit::testing::FileCheck().run(
R"IR(
# CHECK: b[n] = 1.f;
)IR",
oss.str());
}
TEST(Simplify, DISABLED_MultiClauseCondAlwaysInLoopBounds) {
// This test mimics the unpadded region of a conv2d. We want to remove any
// conditional that is provably satisfied (or unsatisfied) by the entire loop
// range.
// Before:
// for (int i = 1; i < 7; i++) {
// for (int j = 1; j < 7; j++) {
// b[i, j] = IfThenElse(
// j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, 1.f);
// After:
// for (int i = 1; i < 7; i++) {
// for (int j = 1; j < 7; j++) {
// b[i, j] = 1.f;
KernelScope kernel_scope;
constexpr int N = 8;
Placeholder b("b", kFloat, {N, N});
VarHandle i("i", kInt);
VarHandle j("j", kInt);
auto csel = CompareSelect::make(i, 1, kLT);
csel = CompareSelect::make(j, 1, 1, csel, kLT);
csel = CompareSelect::make(i, N - 1, 1, csel, kGE);
csel = CompareSelect::make(j, N - 1, 1, csel, kGE);
Stmt* s = b.store({i, j}, IfThenElse::make(csel, 0.f, 1.0f));
s = For::make(j, 1, N - 1, s);
s = For::make(i, 1, N - 1, s);
s = IRSimplifier::simplify(s);
std::ostringstream oss;
oss << *s;
torch::jit::testing::FileCheck().run(
R"IR(
# CHECK: b[n] = 1.f;
)IR",
oss.str());
}
TEST(Simplify, DISABLED_SimplifyLoopBounds) {
// This test mimics the padded region of a conv2d. We want to adjust the
// loop bounds such that the condition will be always met. Note that this
// could be solved by peeling, and applying the range-based conditional
// simplification in the previous tests.
// Before:
// for (int i = 0; i < 3; i++) {
// for (int j = 0; j < 3; j++) {
// b[i, j] = (b[i, j]) + (IfThenElse(
// j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, a[i, j]));
// After:
// for (int i = 1; i < 3; i++) {
// for (int j = 1; j < 3; j++) {
// b[i, j] = (b[i, j]) + 1.f;
KernelScope kernel_scope;
constexpr int N = 8;
constexpr int K = 3;
Placeholder a("a", kFloat, {N, N});
Placeholder b("b", kFloat, {N, N});
VarHandle i("i", kInt);
VarHandle j("j", kInt);
auto csel = CompareSelect::make(i, 1, kLT);
csel = CompareSelect::make(j, 1, 1, csel, kLT);
csel = CompareSelect::make(i, N - 1, 1, csel, kGE);
csel = CompareSelect::make(j, N - 1, 1, csel, kGE);
Stmt* s = b.store(
{i, j}, b.load({i, j}) + IfThenElse::make(csel, 0.f, a.load({i, j})));
s = For::make(j, 0, K, s);
s = For::make(i, 0, K, s);
s = IRSimplifier::simplify(s);
std::ostringstream oss;
oss << *s;
torch::jit::testing::FileCheck().run(
R"IR(
# CHECK: for (int i = 1; i < 3; i++) {
# CHECK: for (int j = 1; j < 3; j++) {
# CHECK-NOT: IfThenElse
)IR",
oss.str());
}
} // namespace jit
} // namespace torch