mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[nnc] Tests for proposed feature: loop bounds conditional simplification (#54121)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54121 It would be nice to do range analysis to determine if a condition cannot be satisfied. These are some tests that we should be able to turn on once we have this feature. ghstack-source-id: 124116847 Test Plan: Simplify.*LoopBounds Reviewed By: ZolotukhinM Differential Revision: D27107956 fbshipit-source-id: bb27e3d3bc803f0101c416e4a351ba2278684980
This commit is contained in:
parent
a852fdb6b5
commit
7367bca066
|
|
@ -4633,5 +4633,131 @@ TEST(Simplify, SimplifyBroadcastTermExpander) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(Simplify, DISABLED_CompareSelectCondAlwaysInLoopBounds) {
|
||||
// Before:
|
||||
// for (int n = 1; n < N; n++) {
|
||||
// b[n] = n < 1 ? 0.f : 1.f;
|
||||
// }
|
||||
// After:
|
||||
// for (int n = 1; n < N; n++) {
|
||||
// b[n] = 1.f;
|
||||
// }
|
||||
KernelScope kernel_scope;
|
||||
constexpr int N = 8;
|
||||
Placeholder b("b", kFloat, {N});
|
||||
VarHandle n("n", kInt);
|
||||
Stmt* s = For::make(
|
||||
n, 1, N, b.store({n}, CompareSelect::make(n, 1, 0.f, 1.0f, kLT)));
|
||||
s = IRSimplifier::simplify(s);
|
||||
std::ostringstream oss;
|
||||
oss << *s;
|
||||
torch::jit::testing::FileCheck().run(
|
||||
R"IR(
|
||||
# CHECK: b[n] = 1.f;
|
||||
)IR",
|
||||
oss.str());
|
||||
}
|
||||
|
||||
TEST(Simplify, DISABLED_IfThenCondAlwaysInLoopBounds) {
|
||||
// Before:
|
||||
// for (int n = 1; n < N; n++) {
|
||||
// b[n] = IfThenElse(n < 1 ? 1 : 0, 0.f, 1.f);
|
||||
// }
|
||||
// After:
|
||||
// for (int n = 1; n < N; n++) {
|
||||
// b[n] = 1.f;
|
||||
// }
|
||||
KernelScope kernel_scope;
|
||||
constexpr int N = 8;
|
||||
Placeholder b("b", kFloat, {N});
|
||||
VarHandle n("n", kInt);
|
||||
Stmt* s =
|
||||
For::make(n, 1, N, b.store({n}, IfThenElse::make(n < 1, 0.f, 1.0f)));
|
||||
s = IRSimplifier::simplify(s);
|
||||
std::ostringstream oss;
|
||||
oss << *s;
|
||||
torch::jit::testing::FileCheck().run(
|
||||
R"IR(
|
||||
# CHECK: b[n] = 1.f;
|
||||
)IR",
|
||||
oss.str());
|
||||
}
|
||||
|
||||
TEST(Simplify, DISABLED_MultiClauseCondAlwaysInLoopBounds) {
|
||||
// This test mimics the unpadded region of a conv2d. We want to remove any
|
||||
// conditional that is provably satisfied (or unsatisfied) by the entire loop
|
||||
// range.
|
||||
// Before:
|
||||
// for (int i = 1; i < 7; i++) {
|
||||
// for (int j = 1; j < 7; j++) {
|
||||
// b[i, j] = IfThenElse(
|
||||
// j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, 1.f);
|
||||
// After:
|
||||
// for (int i = 1; i < 7; i++) {
|
||||
// for (int j = 1; j < 7; j++) {
|
||||
// b[i, j] = 1.f;
|
||||
KernelScope kernel_scope;
|
||||
constexpr int N = 8;
|
||||
Placeholder b("b", kFloat, {N, N});
|
||||
VarHandle i("i", kInt);
|
||||
VarHandle j("j", kInt);
|
||||
auto csel = CompareSelect::make(i, 1, kLT);
|
||||
csel = CompareSelect::make(j, 1, 1, csel, kLT);
|
||||
csel = CompareSelect::make(i, N - 1, 1, csel, kGE);
|
||||
csel = CompareSelect::make(j, N - 1, 1, csel, kGE);
|
||||
Stmt* s = b.store({i, j}, IfThenElse::make(csel, 0.f, 1.0f));
|
||||
s = For::make(j, 1, N - 1, s);
|
||||
s = For::make(i, 1, N - 1, s);
|
||||
s = IRSimplifier::simplify(s);
|
||||
std::ostringstream oss;
|
||||
oss << *s;
|
||||
torch::jit::testing::FileCheck().run(
|
||||
R"IR(
|
||||
# CHECK: b[n] = 1.f;
|
||||
)IR",
|
||||
oss.str());
|
||||
}
|
||||
|
||||
TEST(Simplify, DISABLED_SimplifyLoopBounds) {
|
||||
// This test mimics the padded region of a conv2d. We want to adjust the
|
||||
// loop bounds such that the condition will be always met. Note that this
|
||||
// could be solved by peeling, and applying the range-based conditional
|
||||
// simplification in the previous tests.
|
||||
// Before:
|
||||
// for (int i = 0; i < 3; i++) {
|
||||
// for (int j = 0; j < 3; j++) {
|
||||
// b[i, j] = (b[i, j]) + (IfThenElse(
|
||||
// j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, a[i, j]));
|
||||
// After:
|
||||
// for (int i = 1; i < 3; i++) {
|
||||
// for (int j = 1; j < 3; j++) {
|
||||
// b[i, j] = (b[i, j]) + 1.f;
|
||||
KernelScope kernel_scope;
|
||||
constexpr int N = 8;
|
||||
constexpr int K = 3;
|
||||
Placeholder a("a", kFloat, {N, N});
|
||||
Placeholder b("b", kFloat, {N, N});
|
||||
VarHandle i("i", kInt);
|
||||
VarHandle j("j", kInt);
|
||||
auto csel = CompareSelect::make(i, 1, kLT);
|
||||
csel = CompareSelect::make(j, 1, 1, csel, kLT);
|
||||
csel = CompareSelect::make(i, N - 1, 1, csel, kGE);
|
||||
csel = CompareSelect::make(j, N - 1, 1, csel, kGE);
|
||||
Stmt* s = b.store(
|
||||
{i, j}, b.load({i, j}) + IfThenElse::make(csel, 0.f, a.load({i, j})));
|
||||
s = For::make(j, 0, K, s);
|
||||
s = For::make(i, 0, K, s);
|
||||
s = IRSimplifier::simplify(s);
|
||||
std::ostringstream oss;
|
||||
oss << *s;
|
||||
torch::jit::testing::FileCheck().run(
|
||||
R"IR(
|
||||
# CHECK: for (int i = 1; i < 3; i++) {
|
||||
# CHECK: for (int j = 1; j < 3; j++) {
|
||||
# CHECK-NOT: IfThenElse
|
||||
)IR",
|
||||
oss.str());
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user