mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[TensorExpr] IRSimplifier: sort terms in polynomials, terms, minterms, maxterms. (#63197)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63197 This solves non-determinism from using hash values in sort methods. Changes in tests are mostly mechanical. Test Plan: Imported from OSS Reviewed By: navahgar Differential Revision: D30292776 Pulled By: ZolotukhinM fbshipit-source-id: 74f57b53c3afc9d4be45715fd74781271373e055
This commit is contained in:
parent
8bdd542417
commit
7fdba4564a
|
|
@ -1575,10 +1575,10 @@ TEST(Cuda, MaskMultiDim_CUDA) {
|
|||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
# CHECK-NOT: if (
|
||||
# CHECK: C[100 * blockIdx.x + threadIdx.x] =
|
||||
# CHECK: C[threadIdx.x + 100 * blockIdx.x] =
|
||||
# CHECK: __syncthreads();
|
||||
# CHECK: if (threadIdx.x<50
|
||||
# CHECK: D[50 * blockIdx.x + threadIdx.x] =)IR";
|
||||
# CHECK: D[threadIdx.x + 50 * blockIdx.x] =)IR";
|
||||
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
|
||||
|
|
@ -1705,10 +1705,10 @@ TEST(Cuda, MaskMultiDimSymbolic_CUDA) {
|
|||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
# CHECK: if (threadIdx.x<A_SIZE
|
||||
# CHECK: C[threadIdx.x + A_SIZE * blockIdx.x] =
|
||||
# CHECK: C[A_SIZE * blockIdx.x + threadIdx.x] =
|
||||
# CHECK: __syncthreads();
|
||||
# CHECK: if (threadIdx.x<B_SIZE
|
||||
# CHECK: D[threadIdx.x + B_SIZE * blockIdx.x] =)IR";
|
||||
# CHECK: D[B_SIZE * blockIdx.x + threadIdx.x] =)IR";
|
||||
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
|
||||
|
|
@ -1852,10 +1852,10 @@ TEST(Cuda, MaskCompoundInnerLoop_CUDA) {
|
|||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
# CHECK-NOT: if (
|
||||
# CHECK: c[100 * blockIdx.x + threadIdx.x] =
|
||||
# CHECK: c[threadIdx.x + 100 * blockIdx.x] =
|
||||
# CHECK: __syncthreads();
|
||||
# CHECK: if (threadIdx.x<50
|
||||
# CHECK: d[50 * blockIdx.x + threadIdx.x] =)IR";
|
||||
# CHECK: d[threadIdx.x + 50 * blockIdx.x] =)IR";
|
||||
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
|
||||
|
|
@ -1991,10 +1991,10 @@ TEST(Cuda, MaskInnerLoopOneBlock_CUDA) {
|
|||
R"IR(
|
||||
# CHECK: for (int i = 0; i < 10
|
||||
# CHECK-NOT: if (
|
||||
# CHECK: c[100 * i + threadIdx.x] =
|
||||
# CHECK: c[threadIdx.x + 100 * i] =
|
||||
# CHECK: __syncthreads();
|
||||
# CHECK: if (threadIdx.x<50
|
||||
# CHECK: d[50 * i + threadIdx.x] =)IR";
|
||||
# CHECK: d[threadIdx.x + 50 * i] =)IR";
|
||||
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
|
||||
|
|
@ -2119,7 +2119,7 @@ TEST(Cuda, MaskMultiDimMultiAxis_CUDA) {
|
|||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
# CHECK: if (threadIdx.y<1
|
||||
# CHECK: C[30 * blockIdx.x + threadIdx.x] =
|
||||
# CHECK: C[threadIdx.x + 30 * blockIdx.x] =
|
||||
# CHECK: __syncthreads();
|
||||
# CHECK: if (threadIdx.x<1
|
||||
# CHECK: D[threadIdx.y + 15 * blockIdx.x] =)IR";
|
||||
|
|
@ -2250,7 +2250,7 @@ TEST(Cuda, MaskMultiDimMultiLevel_CUDA) {
|
|||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
# CHECK-NOT: if (
|
||||
# CHECK: C[30 * blockIdx.x + threadIdx.x] =
|
||||
# CHECK: C[threadIdx.x + 30 * blockIdx.x] =
|
||||
# CHECK: __syncthreads();
|
||||
# CHECK: if (blockIdx.x<5
|
||||
# CHECK: if (threadIdx.x<15
|
||||
|
|
|
|||
|
|
@ -29,6 +29,17 @@ void checkIR(StmtPtr s, const std::string& pattern) {
|
|||
torch::jit::testing::FileCheck().run(pattern, oss.str());
|
||||
}
|
||||
|
||||
void checkExprIR(ExprPtr e, const std::string& pattern) {
|
||||
std::string prefixed_pattern = "# CHECK: " + pattern + "\n";
|
||||
std::ostringstream oss;
|
||||
oss << *e << "\n";
|
||||
torch::jit::testing::FileCheck().run(prefixed_pattern, oss.str());
|
||||
}
|
||||
|
||||
void checkExprIR(const ExprHandle& e, const std::string& pattern) {
|
||||
checkExprIR(e.node(), pattern);
|
||||
}
|
||||
|
||||
TEST(LoopNest, ExprSimple01) {
|
||||
KernelScope kernel_scope;
|
||||
Tensor* tensor = Compute(
|
||||
|
|
@ -1305,7 +1316,7 @@ TEST(LoopNest, ScheduleInlineRandomUnrelated) {
|
|||
# CHECK: for (int m2 = 0; m2 < 4; m2++)
|
||||
# CHECK: for (int n2 = 0; n2 < 5; n2++)
|
||||
# CHECK: for (int k2 = 0; k2 < 6; k2++)
|
||||
# CHECK: y[m2, n2, k2] = ((n2 * m2) * k2 + (rand())) + (rand());)IR");
|
||||
# CHECK: y[m2, n2, k2] = ((k2 * m2) * n2 + (rand())) + (rand());)IR");
|
||||
}
|
||||
|
||||
// Make sure we generate the right number of random values == the dimensionality
|
||||
|
|
@ -1710,11 +1721,11 @@ TEST(LoopNest, ScheduleInlineOutputTensors) {
|
|||
# CHECK: for (int m1 = 0; m1 < 4; m1++)
|
||||
# CHECK: for (int n1 = 0; n1 < 5; n1++)
|
||||
# CHECK: for (int k1 = 0; k1 < 6; k1++)
|
||||
# CHECK: x[m1, n1, k1] = (n1 * m1) * k1;
|
||||
# CHECK: x[m1, n1, k1] = (k1 * m1) * n1;
|
||||
# CHECK: for (int m2 = 0; m2 < 4; m2++)
|
||||
# CHECK: for (int n2 = 0; n2 < 5; n2++)
|
||||
# CHECK: for (int k2 = 0; k2 < 6; k2++)
|
||||
# CHECK: y[m2, n2, k2] = (n2 * m2) * k2 + m2;)IR");
|
||||
# CHECK: y[m2, n2, k2] = (k2 * m2) * n2 + m2;)IR");
|
||||
}
|
||||
|
||||
TEST(LoopNest, ScheduleFuserStyle) {
|
||||
|
|
@ -2130,7 +2141,7 @@ TEST(LoopNest, Reduce2dComputeAt) {
|
|||
# CHECK: cons[(0 + cy * (1 * W)) + cx * 1] = int(0);
|
||||
# CHECK: for (int r = 0; r < 2; r++) {
|
||||
# CHECK: for (int s = 0; s < 2; s++) {
|
||||
# CHECK: cons[(0 + cy * (1 * W)) + cx * 1] = (cons[(0 + cy * (1 * W)) + cx * 1]) + (temp[(0 + r * (1 * (W + 1))) + (s + cx) * 1]);
|
||||
# CHECK: cons[(0 + cy * (1 * W)) + cx * 1] = (cons[(0 + cy * (1 * W)) + cx * 1]) + (temp[(0 + r * (1 * (W + 1))) + (cx + s) * 1]);
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
|
|
@ -3225,7 +3236,7 @@ TEST(LoopNest, NormalizeStartVariable) {
|
|||
{Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})),
|
||||
Store::make(b_buf, {x}, x * 2)});
|
||||
auto for_stmt = For::make(x, y, 100, for_body);
|
||||
Block::make({for_stmt});
|
||||
auto parent_block = Block::make({for_stmt});
|
||||
|
||||
LoopNest::normalize(for_stmt);
|
||||
|
||||
|
|
@ -3235,8 +3246,8 @@ TEST(LoopNest, NormalizeStartVariable) {
|
|||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
# CHECK: for (int x = 0; x < 100 - y; x++) {
|
||||
# CHECK: A[y + x] = B[y + x];
|
||||
# CHECK: B[y + x] = 2 * (y + x);
|
||||
# CHECK: A[x + y] = B[x + y];
|
||||
# CHECK: B[x + y] = 2 * (x + y);
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
}
|
||||
|
|
@ -3304,7 +3315,7 @@ TEST(LoopNest, NormalizeOnNestedInnerLoop) {
|
|||
R"IR(
|
||||
# CHECK: for (int x = 50; x < 100; x++) {
|
||||
# CHECK: for (int y = 0; y < 90; y++) {
|
||||
# CHECK: A[x] = (((B[y + 10]) + 2 * y) + (A[x])) + 20;
|
||||
# CHECK: A[x] = (((A[x]) + (B[y + 10])) + 2 * y) + 20;
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
}
|
||||
|
|
@ -3327,7 +3338,7 @@ TEST(LoopNest, NormalizeAndSplitWithTail) {
|
|||
BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
|
||||
VarHandle x("x", kInt);
|
||||
auto for_stmt = For::make(x, 5, 10, Store::make(a_buf, {x}, x * 2));
|
||||
Block::make({for_stmt});
|
||||
auto parent_block = Block::make({for_stmt});
|
||||
|
||||
LoopNest::normalize(for_stmt);
|
||||
|
||||
|
|
@ -3373,7 +3384,7 @@ TEST(LoopNest, FlattenSimpleLoopNest2D) {
|
|||
auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
|
||||
auto inner_for = For::make(j, 0, 5, for_body);
|
||||
auto outer_for = For::make(i, 0, 10, inner_for);
|
||||
Block::make({outer_for});
|
||||
auto parent_block = Block::make({outer_for});
|
||||
|
||||
std::vector<ForPtr> loops = {outer_for, inner_for};
|
||||
ForPtr flattened = nullptr;
|
||||
|
|
@ -3420,7 +3431,7 @@ TEST(LoopNest, FlattenSimpleLoopNest3D) {
|
|||
auto for1 = For::make(k, 0, 7, for_body);
|
||||
auto for2 = For::make(j, 0, 5, for1);
|
||||
auto for3 = For::make(i, 0, 10, for2);
|
||||
Block::make({for3});
|
||||
auto parent_block = Block::make({for3});
|
||||
|
||||
std::vector<ForPtr> loops = {for3, for2, for1};
|
||||
ForPtr flattened = nullptr;
|
||||
|
|
@ -3463,7 +3474,7 @@ TEST(LoopNest, FlattenLoopNestAfterNormalize) {
|
|||
auto for_body = Block::make({Store::make(a_buf, {i - 2, j - 3}, i * j)});
|
||||
auto inner_for = For::make(j, 3, 15, for_body);
|
||||
auto outer_for = For::make(i, 2, 10, inner_for);
|
||||
Block::make({outer_for});
|
||||
auto parent_block = Block::make({outer_for});
|
||||
|
||||
std::vector<ForPtr> loops = {outer_for, inner_for};
|
||||
ForPtr flattened = nullptr;
|
||||
|
|
@ -3712,7 +3723,7 @@ TEST(LoopNest, CacheReadsSimple) {
|
|||
#CHECK: A_local[j_1] = A[
|
||||
#CHECK: }
|
||||
#CHECK: for (int j_2
|
||||
#CHECK: B[10 * i_1 + j_2] = A_local[j_2];
|
||||
#CHECK: B[j_2 + 10 * i_1] = A_local[j_2];
|
||||
#CHECK: }
|
||||
#CHECK: }
|
||||
#CHECK: for (int i_2
|
||||
|
|
@ -3769,7 +3780,7 @@ TEST(LoopNest, CacheReadsOuter) {
|
|||
checkIR(result, R"IR(
|
||||
#CHECK: Allocate(A_local); // dtype=int, dims=[21, 11]
|
||||
#CHECK: A_local[j_1 + 11 * i_1] =
|
||||
#CHECK: B[10 * i_2 + j_2] = (A_local[(j_2 + 11 * i_2) + 12]) + (A_local[j_2 + 11 * i_2]);
|
||||
#CHECK: B[j_2 + 10 * i_2] = (A_local[j_2 + 11 * i_2]) + (A_local[(j_2 + 11 * i_2) + 12]);
|
||||
)IR");
|
||||
|
||||
std::vector<int> b_data(200, 0);
|
||||
|
|
@ -3816,7 +3827,7 @@ TEST(LoopNest, CacheReadsInternal) {
|
|||
checkIR(result, R"IR(
|
||||
#CHECK: Allocate(A_local); // dtype=int, dims=[2, 11]
|
||||
#CHECK: A_local[j_1 + 11 * i_2] =
|
||||
#CHECK: B[10 * i_1 + j_2] = (A_local[j_2 + 12]) + (A_local[j_2]);
|
||||
#CHECK: B[j_2 + 10 * i_1] = (A_local[j_2 + 12]) + (A_local[j_2]);
|
||||
)IR");
|
||||
|
||||
std::vector<int> b_data(200, 0);
|
||||
|
|
@ -3863,8 +3874,8 @@ TEST(LoopNest, CacheReadsInner) {
|
|||
|
||||
checkIR(result, R"IR(
|
||||
#CHECK: Allocate(A_local); // dtype=int, dims=[5, 2]
|
||||
#CHECK: A_local[2 * i_2 + j_2] =
|
||||
#CHECK: B[10 * i_1 + j_1] = (A_local[1]) + (A_local[8]);
|
||||
#CHECK: A_local[j_2 + 2 * i_2] =
|
||||
#CHECK: B[j_1 + 10 * i_1] = (A_local[1]) + (A_local[8]);
|
||||
)IR");
|
||||
|
||||
std::vector<int> b_data(200, 0);
|
||||
|
|
@ -3914,7 +3925,7 @@ TEST(LoopNest, CacheWritesSimple) {
|
|||
#CHECK: for (int j = 0; j < 64
|
||||
#CHECK: A_local[j] = i * j;
|
||||
#CHECK: for (int j_1 = 0; j_1 < 64
|
||||
#CHECK: A[64 * i + j_1] = A_local[
|
||||
#CHECK: A[j_1 + 64 * i] = A_local[
|
||||
#CHECK: Free(A_local);
|
||||
#CHECK-NOT: A_local
|
||||
)IR");
|
||||
|
|
|
|||
|
|
@ -1578,8 +1578,8 @@ TEST(Reductions, ReductionCacheBodyAccess) {
|
|||
#CHECK: Allocate(scale_local); // dtype=float, dims=[1, 32, 12]
|
||||
#CHECK: for (int j = 0; j < 32; j++) {
|
||||
#CHECK: for (int k = 0; k < 12; k++) {
|
||||
#CHECK: scale_local[k + 12 * j] = scale[(k + 384 * l1) + 12 * j];
|
||||
#CHECK: sum[l1] = (sum[l1]) + (scale_local[12 * n1_1 + m1_1]);
|
||||
#CHECK: scale_local[k + 12 * j] = scale[(k + 12 * j) + 384 * l1];
|
||||
#CHECK: sum[l1] = (sum[l1]) + (scale_local[m1_1 + 12 * n1_1]);
|
||||
#CHECK: scale_1[l] = (b[l]) * (sum[l]);
|
||||
#CHECK: Free(scale_local);
|
||||
)IR";
|
||||
|
|
@ -1667,7 +1667,7 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) {
|
|||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: Allocate(sum_local); // dtype=float, dims=[4]
|
||||
#CHECK: sum[l1_inner + 4 * l1_outer] = (sum[l1_inner + 4 * l1_outer]) + (scale[((12 * n1_1 + 384 * l1_inner) + m1_1) + 1536 * l1_outer]);
|
||||
#CHECK: sum[l1_inner + 4 * l1_outer] = (sum[l1_inner + 4 * l1_outer]) + (scale[((m1_1 + 12 * n1_1) + 1536 * l1_outer) + 384 * l1_inner]);
|
||||
#CHECK: for (int i = 0; i < 4
|
||||
#CHECK: sum_local[i] = sum[i + 4 * l_outer];
|
||||
#CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]);
|
||||
|
|
@ -1716,7 +1716,7 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) {
|
|||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: Allocate(sum_local); // dtype=float, dims=[4]
|
||||
#CHECK: sum[l1] = (sum[l1]) + (scale[(12 * n1_1 + m1_1) + 384 * l1]);
|
||||
#CHECK: sum[l1] = (sum[l1]) + (scale[(m1_1 + 12 * n1_1) + 384 * l1]);
|
||||
#CHECK: for (int i = 0; i < 4
|
||||
#CHECK: sum_local[i] = sum[i + 4 * l_outer];
|
||||
#CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]);
|
||||
|
|
|
|||
|
|
@ -192,8 +192,8 @@ TEST(Registerizer, RegisterizerLoopInternal) {
|
|||
R"IR(
|
||||
# CHECK: for (int x = 0; x < 10; x++)
|
||||
# CHECK: int A_1 = A[x];
|
||||
# CHECK: A_1 = x + A_1;
|
||||
# CHECK: A_1 = x + A_1;
|
||||
# CHECK: A_1 = A_1 + x;
|
||||
# CHECK: A_1 = A_1 + x;
|
||||
# CHECK: A[x] = A_1;
|
||||
# CHECK: })IR";
|
||||
|
||||
|
|
@ -273,12 +273,12 @@ TEST(Registerizer, RegisterizerLoopInternalRepeated) {
|
|||
* int A_1 = A[1];
|
||||
* int A_2 = A[0];
|
||||
* for (int x = 0; x < 10; x++) {
|
||||
* A_2 = x + A_1;
|
||||
* A_2 = x + A_1;
|
||||
* A_2 = A_1 + x;
|
||||
* A_2 = A_1 + x;
|
||||
* }
|
||||
* for (int x = 0; x < 10; x++) {
|
||||
* A_2 = x + A_1;
|
||||
* A_2 = x + A_1;
|
||||
* A_2 = A_1 + x;
|
||||
* A_2 = A_1 + x;
|
||||
* }
|
||||
* A[0] = A_2;
|
||||
*/
|
||||
|
|
@ -291,12 +291,12 @@ TEST(Registerizer, RegisterizerLoopInternalRepeated) {
|
|||
# CHECK: int A_1 = A[1];
|
||||
# CHECK: int A_2 = A[0];
|
||||
# CHECK: for (int x = 0; x < 10; x++)
|
||||
# CHECK: A_2 = x + A_1;
|
||||
# CHECK: A_2 = x + A_1;
|
||||
# CHECK: A_2 = A_1 + x;
|
||||
# CHECK: A_2 = A_1 + x;
|
||||
# CHECK: }
|
||||
# CHECK: for (int x = 0; x < 10; x++)
|
||||
# CHECK: A_2 = x + A_1;
|
||||
# CHECK: A_2 = x + A_1;
|
||||
# CHECK: A_2 = A_1 + x;
|
||||
# CHECK: A_2 = A_1 + x;
|
||||
# CHECK: }
|
||||
# CHECK-NOT: A[1]
|
||||
# CHECK: A[0] = A_2;
|
||||
|
|
@ -357,7 +357,7 @@ TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) {
|
|||
BufHandle a("A", {1}, kInt);
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
StmtPtr stmt = Block::make(
|
||||
StmtPtr stmt = IRSimplifier::simplify(Block::make(
|
||||
{For::make(
|
||||
x,
|
||||
0,
|
||||
|
|
@ -373,7 +373,7 @@ TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) {
|
|||
{Store::make(a, {0}, Add::make(x, Load::make(a, {y}))),
|
||||
Store::make(a, {0}, Add::make(x, Load::make(a, {y})))}))
|
||||
|
||||
});
|
||||
}));
|
||||
|
||||
/*
|
||||
* for (int x = 0; x < 10; x++) {
|
||||
|
|
@ -2044,7 +2044,7 @@ TEST(Registerizer, RegisterizerPartialAfter) {
|
|||
/*
|
||||
* int A_1 = 0;
|
||||
* for (int x = 0; x < 10; x++) {
|
||||
* A_1 = x + A_1;
|
||||
* A_1 = A_1 + x;
|
||||
* }
|
||||
* A[0] = A_1;
|
||||
* for (int x = 1; x < 10; x++) {
|
||||
|
|
@ -2059,7 +2059,7 @@ TEST(Registerizer, RegisterizerPartialAfter) {
|
|||
R"IR(
|
||||
# CHECK: int A_1 = 0;
|
||||
# CHECK: for (
|
||||
# CHECK: A_1 = x + A_1;
|
||||
# CHECK: A_1 = A_1 + x;
|
||||
# CHECK: }
|
||||
# CHECK: A[0] = A_1;
|
||||
# CHECK: for (
|
||||
|
|
@ -2104,7 +2104,7 @@ TEST(Registerizer, RegisterizerPartialBefore) {
|
|||
* }
|
||||
* int A_1 = 0;
|
||||
* for (int x = 0; x < 10; x++) {
|
||||
* A_1 = x + A_1;
|
||||
* A_1 = A_1 + x;
|
||||
* }
|
||||
* A[0] = A_1;
|
||||
*/
|
||||
|
|
@ -2120,7 +2120,7 @@ TEST(Registerizer, RegisterizerPartialBefore) {
|
|||
# CHECK: }
|
||||
# CHECK: int A_1 = 0;
|
||||
# CHECK: for (
|
||||
# CHECK: A_1 = x + A_1;
|
||||
# CHECK: A_1 = A_1 + x;
|
||||
# CHECK: }
|
||||
# CHECK: A[0] = A_1;)IR";
|
||||
|
||||
|
|
@ -2161,7 +2161,7 @@ TEST(Registerizer, RegisterizerPartialInside) {
|
|||
/*
|
||||
* int A_1 = 2;
|
||||
* for (int x1 = 0; x1 < 10; x1++) {
|
||||
* A_1 = x1 + A_1;
|
||||
* A_1 = A_1 + x1;
|
||||
* }
|
||||
* A[0] = A_1;
|
||||
* for (int x2 = 1; x2 < 10; x2++) {
|
||||
|
|
@ -2169,7 +2169,7 @@ TEST(Registerizer, RegisterizerPartialInside) {
|
|||
* }
|
||||
* int A_2 = A[0];
|
||||
* for (int x3 = 0; x3 < 10; x3++) {
|
||||
* A_2 = x3 + A_2;
|
||||
* A_2 = A_2 + x3;
|
||||
* }
|
||||
* A[0] = A_2;
|
||||
*/
|
||||
|
|
@ -2181,7 +2181,7 @@ TEST(Registerizer, RegisterizerPartialInside) {
|
|||
R"IR(
|
||||
# CHECK: int A_1 = 2;
|
||||
# CHECK: for (
|
||||
# CHECK: A_1 = x1 + A_1;
|
||||
# CHECK: A_1 = A_1 + x1;
|
||||
# CHECK: }
|
||||
# CHECK: A[0] = A_1;
|
||||
# CHECK: for (
|
||||
|
|
@ -2189,7 +2189,7 @@ TEST(Registerizer, RegisterizerPartialInside) {
|
|||
# CHECK: }
|
||||
# CHECK: int A_2 = A[0];
|
||||
# CHECK: for (
|
||||
# CHECK: A_2 = x3 + A_2;
|
||||
# CHECK: A_2 = A_2 + x3;
|
||||
# CHECK: }
|
||||
# CHECK: A[0] = A_2;)IR";
|
||||
|
||||
|
|
@ -2232,7 +2232,7 @@ TEST(Registerizer, RegisterizerPartialCondition) {
|
|||
/*
|
||||
* int A_1 = 2;
|
||||
* for (int x = 0; x < 10; x++) {
|
||||
* A_1 = x + A_1;
|
||||
* A_1 = A_1 + x;
|
||||
* }
|
||||
* A[0] = A_1;
|
||||
* if (x<5 ? 1 : 0) {
|
||||
|
|
@ -2240,7 +2240,7 @@ TEST(Registerizer, RegisterizerPartialCondition) {
|
|||
* }
|
||||
* int A_2 = A[0];
|
||||
* for (int x = 0; x < 10; x++) {
|
||||
* A_2 = x + A_2;
|
||||
* A_2 = A_2 + x;
|
||||
* }
|
||||
* A[0] = A_2;
|
||||
*/
|
||||
|
|
@ -2252,7 +2252,7 @@ TEST(Registerizer, RegisterizerPartialCondition) {
|
|||
R"IR(
|
||||
# CHECK: int A_1 = 2;
|
||||
# CHECK: for (
|
||||
# CHECK: A_1 = x + A_1;
|
||||
# CHECK: A_1 = A_1 + x;
|
||||
# CHECK: }
|
||||
# CHECK: A[0] = A_1;
|
||||
# CHECK: if (
|
||||
|
|
@ -2260,7 +2260,7 @@ TEST(Registerizer, RegisterizerPartialCondition) {
|
|||
# CHECK: }
|
||||
# CHECK: int A_2 = A[0];
|
||||
# CHECK: for (
|
||||
# CHECK: A_2 = x + A_2;
|
||||
# CHECK: A_2 = A_2 + x;
|
||||
# CHECK: }
|
||||
# CHECK: A[0] = A_2;)IR";
|
||||
|
||||
|
|
@ -2937,7 +2937,7 @@ TEST(Registerizer, RegisterizerNestedLoopSimple) {
|
|||
* for (int y = 0; y < 10; y++) {
|
||||
* int A_1 = A[y];
|
||||
* for (int x = 0; x < 10; x++) {
|
||||
* A_1 = x + A_1;
|
||||
* A_1 = A_1 + x;
|
||||
* }
|
||||
* A[y] = A_1;
|
||||
* }
|
||||
|
|
@ -2951,7 +2951,7 @@ TEST(Registerizer, RegisterizerNestedLoopSimple) {
|
|||
# CHECK: for (int y
|
||||
# CHECK: int A_1 = A[y];
|
||||
# CHECK: for (int x
|
||||
# CHECK: A_1 = x + A_1;
|
||||
# CHECK: A_1 = A_1 + x;
|
||||
# CHECK: }
|
||||
# CHECK: A[y] = A_1;
|
||||
# CHECK: })IR";
|
||||
|
|
@ -3366,13 +3366,13 @@ TEST(Registerizer, RegisterizerLoopLetVar) {
|
|||
BufHandle a("A", {10}, kInt);
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
StmtPtr stmt = Block::make({For::make(
|
||||
StmtPtr stmt = IRSimplifier::simplify(Block::make({For::make(
|
||||
x,
|
||||
0,
|
||||
10,
|
||||
Block::make(
|
||||
{Let::make(y, 30),
|
||||
Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))});
|
||||
Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))}));
|
||||
|
||||
/*
|
||||
* for (int x = 0; x < 10; x++) {
|
||||
|
|
@ -3422,7 +3422,7 @@ TEST(Registerizer, RegisterizerLoopLetVarOuter) {
|
|||
* int y = 30;
|
||||
* int A_1 = A[y];
|
||||
* for (int x = 0; x < 10; x++) {
|
||||
* A_1 = x + A_1;
|
||||
* A_1 = A_1 + x;
|
||||
* }
|
||||
* A[y] = A_1;
|
||||
*/
|
||||
|
|
@ -3435,7 +3435,7 @@ TEST(Registerizer, RegisterizerLoopLetVarOuter) {
|
|||
# CHECK: int y = 30;
|
||||
# CHECK: int A_1 = A[y];
|
||||
# CHECK: for (int x
|
||||
# CHECK: A_1 = x + A_1;
|
||||
# CHECK: A_1 = A_1 + x;
|
||||
# CHECK: A[y] = A_1;)IR";
|
||||
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
|
|
@ -3516,7 +3516,7 @@ TEST(Registerizer, RegisterizerMultiDimPartial) {
|
|||
* int A_1 = A[0, 1, 4];
|
||||
* int A_2 = A[0, 2, 2];
|
||||
* for (int x = 0; x < 10; x++) {
|
||||
* A_2 = x + A_1;
|
||||
* A_2 = A_1 + x;
|
||||
* }
|
||||
* A[0, 2, 2] = A_2;
|
||||
*/
|
||||
|
|
@ -3530,7 +3530,7 @@ TEST(Registerizer, RegisterizerMultiDimPartial) {
|
|||
# CHECK: int A_1 = A[0, 1, 4];
|
||||
# CHECK: int A_2 = A[0, 2, 2];
|
||||
# CHECK: for (
|
||||
# CHECK: A_2 = x + A_1;
|
||||
# CHECK: A_2 = A_1 + x;
|
||||
# CHECK: A[0, 2, 2] = A_2;)IR";
|
||||
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
|
|
@ -3599,7 +3599,7 @@ TEST(Registerizer, RegisterizerMultiDimPartialOverlap) {
|
|||
* A[0, 1, 2] = 0;
|
||||
* int A_1 = A[y, 2, 4];
|
||||
* for (int x = 0; x < 10; x++) {
|
||||
* A[0, x, 2] = x + A_1;
|
||||
* A[0, x, 2] = A_1 + x;
|
||||
* }
|
||||
*/
|
||||
|
||||
|
|
@ -3611,7 +3611,7 @@ TEST(Registerizer, RegisterizerMultiDimPartialOverlap) {
|
|||
# CHECK: A[0, 1, 2] = 0;
|
||||
# CHECK: int A_1 = A[y, 2, 4];
|
||||
# CHECK: for (
|
||||
# CHECK: A[0, x, 2] = x + A_1;
|
||||
# CHECK: A[0, x, 2] = A_1 + x;
|
||||
# CHECK: })IR";
|
||||
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
|
|
@ -3736,12 +3736,12 @@ TEST(Registerizer, RegisterizerMultiDim3DReduction2) {
|
|||
|
||||
/*
|
||||
* for (int x = 0; x < 10; x++) {
|
||||
* int C_1 = C[x];
|
||||
* int A_1 = A[x];
|
||||
* int C_1 = C[x];
|
||||
* for (int y = 0; y < 10; y++) {
|
||||
* int B_1 = B[y];
|
||||
* for (int z = 0; z < 10; z++) {
|
||||
* C_1 = C_1 + A_1 * B_1;
|
||||
* C_1 = A_1 * B_1 + C_1;
|
||||
* }
|
||||
* }
|
||||
* C[x] = C_1;
|
||||
|
|
@ -3754,12 +3754,12 @@ TEST(Registerizer, RegisterizerMultiDim3DReduction2) {
|
|||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
# CHECK: for (int x
|
||||
# CHECK: int C_1 = C[x];
|
||||
# CHECK: int A_1 = A[x];
|
||||
# CHECK: int C_1 = C[x];
|
||||
# CHECK: for (int y
|
||||
# CHECK: int B_1 = B[y];
|
||||
# CHECK: for (int z
|
||||
# CHECK: C_1 = C_1 + A_1 * B_1;
|
||||
# CHECK: C_1 = A_1 * B_1 + C_1;
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
# CHECK: C[x] = C_1;
|
||||
|
|
|
|||
|
|
@ -649,12 +649,12 @@ TEST(Simplify, SimplifyMultiVar) {
|
|||
ASSERT_NE(lhs, nullptr);
|
||||
VarPtr varX = to<Var>(lhs->rhs());
|
||||
ASSERT_NE(varX, nullptr);
|
||||
ASSERT_EQ(varX->name_hint(), "y");
|
||||
ASSERT_EQ(varX->name_hint(), "x");
|
||||
MulPtr rhs = to<Mul>(root->rhs());
|
||||
ASSERT_NE(rhs, nullptr);
|
||||
VarPtr varY = to<Var>(rhs->rhs());
|
||||
ASSERT_NE(varY, nullptr);
|
||||
ASSERT_EQ(varY->name_hint(), "x");
|
||||
ASSERT_EQ(varY->name_hint(), "y");
|
||||
}
|
||||
|
||||
// x + 2 + y => x + y + 2
|
||||
|
|
@ -698,8 +698,8 @@ TEST(Simplify, SimplifyAdds) {
|
|||
IS_NODE_WITH_NAME(Mul, simplified.node(), root);
|
||||
IS_IMM_WITH_VAL(Int, root->lhs(), 2);
|
||||
IS_NODE_WITH_NAME(Add, root->rhs(), add);
|
||||
IS_VAR_WITH_NAME(add->lhs(), "y");
|
||||
IS_VAR_WITH_NAME(add->rhs(), "x");
|
||||
IS_VAR_WITH_NAME(add->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(add->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -770,11 +770,11 @@ TEST(Simplify, SimplifyMuls) {
|
|||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
IS_NODE_WITH_NAME(Add, mul->lhs(), lhs);
|
||||
IS_VAR_WITH_NAME(lhs->lhs(), "y");
|
||||
IS_VAR_WITH_NAME(lhs->rhs(), "x");
|
||||
IS_VAR_WITH_NAME(lhs->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(lhs->rhs(), "y");
|
||||
IS_NODE_WITH_NAME(Add, mul->rhs(), rhs);
|
||||
IS_VAR_WITH_NAME(rhs->lhs(), "y");
|
||||
IS_VAR_WITH_NAME(rhs->rhs(), "x");
|
||||
IS_VAR_WITH_NAME(rhs->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(rhs->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -867,8 +867,8 @@ TEST(Simplify, SimplifyMuls) {
|
|||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
IS_NODE_WITH_NAME(Add, mul->lhs(), lhs);
|
||||
IS_VAR_WITH_NAME(lhs->lhs(), "y");
|
||||
IS_VAR_WITH_NAME(lhs->rhs(), "x");
|
||||
IS_VAR_WITH_NAME(lhs->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(lhs->rhs(), "y");
|
||||
IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
|
||||
IS_VAR_WITH_NAME(rhs->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(rhs->rhs(), "y");
|
||||
|
|
@ -1654,14 +1654,14 @@ TEST(Simplify, SimplifyMultiOp) {
|
|||
}
|
||||
|
||||
{
|
||||
// (x + y) - (x * y) => x + y - (x * y)
|
||||
ExprHandle body = (x + y) - (x * y);
|
||||
// (x + y) - x * y => (x + y) - x * y
|
||||
ExprHandle body = (x + y) - x * y;
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
|
||||
IS_NODE_WITH_NAME(Add, sub->lhs(), add);
|
||||
IS_NODE_WITH_NAME(Mul, sub->rhs(), mul);
|
||||
IS_VAR_WITH_NAME(add->lhs(), "y");
|
||||
IS_VAR_WITH_NAME(add->rhs(), "x");
|
||||
IS_VAR_WITH_NAME(add->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(add->rhs(), "y");
|
||||
IS_VAR_WITH_NAME(mul->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(mul->rhs(), "y");
|
||||
}
|
||||
|
|
@ -1709,19 +1709,19 @@ TEST(Simplify, SimplifyManyOps) {
|
|||
VarHandle y("y", kInt);
|
||||
|
||||
{
|
||||
// x + y + x + x + y + y + x + y + x = 5 * x + 4 * y
|
||||
// x + y + x + x + y + y + x + y + x = 4 * y + 5 * x
|
||||
ExprHandle body = x + y + x + x + y + y + x + y + x;
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Add, simplified.node(), add);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
|
||||
IS_IMM_WITH_VAL(Int, lhs->lhs(), 5);
|
||||
IS_VAR_WITH_NAME(lhs->rhs(), "x");
|
||||
IS_IMM_WITH_VAL(Int, lhs->lhs(), 4);
|
||||
IS_VAR_WITH_NAME(lhs->rhs(), "y");
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
|
||||
IS_IMM_WITH_VAL(Int, rhs->lhs(), 4);
|
||||
IS_VAR_WITH_NAME(rhs->rhs(), "y");
|
||||
IS_IMM_WITH_VAL(Int, rhs->lhs(), 5);
|
||||
IS_VAR_WITH_NAME(rhs->rhs(), "x");
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -1765,8 +1765,8 @@ TEST(Simplify, SimplifyFactorization) {
|
|||
IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
|
||||
|
||||
IS_NODE_WITH_NAME(Add, mul->rhs(), add);
|
||||
IS_VAR_WITH_NAME(add->lhs(), "y");
|
||||
IS_VAR_WITH_NAME(add->rhs(), "x");
|
||||
IS_VAR_WITH_NAME(add->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(add->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -1794,12 +1794,12 @@ TEST(Simplify, SimplifyFactorization) {
|
|||
IS_NODE_WITH_NAME(Add, simplified.node(), add);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
|
||||
IS_IMM_WITH_VAL(Int, lhs->lhs(), 5);
|
||||
IS_VAR_WITH_NAME(lhs->rhs(), "y");
|
||||
IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
|
||||
IS_VAR_WITH_NAME(lhs->rhs(), "x");
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
|
||||
IS_IMM_WITH_VAL(Int, rhs->lhs(), 2);
|
||||
IS_VAR_WITH_NAME(rhs->rhs(), "x");
|
||||
IS_IMM_WITH_VAL(Int, rhs->lhs(), 5);
|
||||
IS_VAR_WITH_NAME(rhs->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -1813,8 +1813,8 @@ TEST(Simplify, SimplifyFactorization) {
|
|||
IS_IMM_WITH_VAL(Int, mul->lhs(), 10);
|
||||
|
||||
IS_NODE_WITH_NAME(Add, mul->rhs(), add);
|
||||
IS_VAR_WITH_NAME(add->lhs(), "y");
|
||||
IS_VAR_WITH_NAME(add->rhs(), "x");
|
||||
IS_VAR_WITH_NAME(add->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(add->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -1863,18 +1863,12 @@ TEST(Simplify, SimplifyFactorization) {
|
|||
VarHandle g("g", kInt);
|
||||
VarHandle h("h", kInt);
|
||||
|
||||
ExprHandle body = ExprHandle(0) + (ExprHandle(1024) * a) +
|
||||
(ExprHandle(-1) * b) + (ExprHandle(-1) * c) + (ExprHandle(1) * d) +
|
||||
(ExprHandle(1) * e) + (ExprHandle(32) * f) + (ExprHandle(-1024) * g) +
|
||||
(ExprHandle(-32) * h);
|
||||
ExprHandle body = a * 1024 + 0 + b * (-1) + c * (-1) + d * 1 + e * 1 +
|
||||
f * 32 + g * (-1024) + h * (-32);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
// We only check for the top level nodes here, since the main purpose
|
||||
// here is ensure that this simplification completes.
|
||||
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
|
||||
IS_NODE_WITH_NAME(Mul, sub->rhs(), mul);
|
||||
IS_IMM_WITH_VAL(Int, mul->lhs(), 1024);
|
||||
IS_VAR_WITH_NAME(mul->rhs(), "g");
|
||||
checkExprIR(
|
||||
simplified,
|
||||
"((((((d + e) + 1024 * a) + 32 * f) - b) - c) - 1024 * g) - 32 * h");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1904,7 +1898,7 @@ TEST(Simplify, SimplifyFactorizeUneven) {
|
|||
IS_VAR_WITH_NAME(zmul->rhs(), "z");
|
||||
}
|
||||
|
||||
// (x * y) + (2 * x) * (x + y) => 3 * (x * y) + 2 * (x * x)
|
||||
// (x * y) + (2 * x) * (x + y) => 2 * (x * x) + 3 * (x * y)
|
||||
// This is kind of a placeholder test for variable factorization.
|
||||
TEST(Simplify, SimplifyDeeperTerms) {
|
||||
KernelScope kernel_scope;
|
||||
|
|
@ -1916,16 +1910,16 @@ TEST(Simplify, SimplifyDeeperTerms) {
|
|||
IS_NODE_WITH_NAME(Add, simplified.node(), add);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
|
||||
IS_IMM_WITH_VAL(Int, lhs->lhs(), 3);
|
||||
IS_NODE_WITH_NAME(Mul, lhs->rhs(), xyTerm);
|
||||
IS_VAR_WITH_NAME(xyTerm->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(xyTerm->rhs(), "y");
|
||||
IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
|
||||
IS_NODE_WITH_NAME(Mul, lhs->rhs(), xxTerm);
|
||||
IS_VAR_WITH_NAME(xxTerm->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(xxTerm->rhs(), "x");
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
|
||||
IS_IMM_WITH_VAL(Int, rhs->lhs(), 2);
|
||||
IS_NODE_WITH_NAME(Mul, rhs->rhs(), xxTerm);
|
||||
IS_VAR_WITH_NAME(xxTerm->rhs(), "x");
|
||||
IS_VAR_WITH_NAME(xxTerm->rhs(), "x");
|
||||
IS_IMM_WITH_VAL(Int, rhs->lhs(), 3);
|
||||
IS_NODE_WITH_NAME(Mul, rhs->rhs(), xyTerm);
|
||||
IS_VAR_WITH_NAME(xyTerm->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(xyTerm->rhs(), "y");
|
||||
}
|
||||
|
||||
// Tests the difference between two less trivial expressions.
|
||||
|
|
@ -1987,15 +1981,15 @@ TEST(Simplify, SimplifyOpaqueTerms) {
|
|||
VarHandle y("y", kInt);
|
||||
|
||||
{
|
||||
// 2 * x/y * x - x/y * y => y * x/y
|
||||
// 2 * x/y * y - x/y * y => x/y * y
|
||||
ExprHandle body = ((ExprHandle(2)) * (x / y) * y) - ((x / y) * y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
IS_VAR_WITH_NAME(mul->lhs(), "y");
|
||||
IS_NODE_WITH_NAME(Div, mul->rhs(), div);
|
||||
IS_NODE_WITH_NAME(Div, mul->lhs(), div);
|
||||
IS_VAR_WITH_NAME(div->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(div->rhs(), "y");
|
||||
IS_VAR_WITH_NAME(mul->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -2055,46 +2049,46 @@ TEST(Simplify, SimplifyNestedMax) {
|
|||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
|
||||
IS_BINOP_W_VARS(Add, simplified.node(), add, "y", "x");
|
||||
IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y");
|
||||
}
|
||||
|
||||
{
|
||||
// Max(x + y, Max(x + y, z)) => Max(y + x, z)
|
||||
// Max(x + y, Max(x + y, z)) => Max(x + y, z)
|
||||
ExprHandle body = Max::make(x + y, Max::make(x + y, z, true), true);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Max, simplified.node(), max);
|
||||
IS_BINOP_W_VARS(Add, max->lhs(), add, "y", "x");
|
||||
IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y");
|
||||
IS_VAR_WITH_NAME(max->rhs(), "z");
|
||||
}
|
||||
|
||||
{
|
||||
// Max(x + y, Max(z, x + y)) => Max(y + x, z)
|
||||
// Max(x + y, Max(z, x + y)) => Max(x + y, z)
|
||||
ExprHandle body = Max::make(x + y, Max::make(z, x + y, true), true);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Max, simplified.node(), max);
|
||||
IS_BINOP_W_VARS(Add, max->lhs(), add, "y", "x");
|
||||
IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y");
|
||||
IS_VAR_WITH_NAME(max->rhs(), "z");
|
||||
}
|
||||
|
||||
{
|
||||
// Max(Max(x + y, z), x + y) => Max(y + x, z)
|
||||
// Max(Max(x + y, z), x + y) => Max(x + y, z)
|
||||
ExprHandle body = Max::make(Max::make(x + y, z, true), x + y, true);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Max, simplified.node(), max);
|
||||
IS_BINOP_W_VARS(Add, max->lhs(), add, "y", "x");
|
||||
IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y");
|
||||
IS_VAR_WITH_NAME(max->rhs(), "z");
|
||||
}
|
||||
|
||||
{
|
||||
// Max(Max(z, x + y), x + y) => Max(y + x, z)
|
||||
// Max(Max(z, x + y), x + y) => Max(x + y, z)
|
||||
ExprHandle body = Max::make(Max::make(z, x + y, true), x + y, true);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Max, simplified.node(), max);
|
||||
IS_BINOP_W_VARS(Add, max->lhs(), add, "y", "x");
|
||||
IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y");
|
||||
IS_VAR_WITH_NAME(max->rhs(), "z");
|
||||
}
|
||||
|
||||
|
|
@ -2112,55 +2106,39 @@ TEST(Simplify, SimplifyNestedMax) {
|
|||
}
|
||||
|
||||
{
|
||||
// Max(Min(x, y), Min(x, z)) => Min(x, Max(y, z))
|
||||
// Max(Min(x, y), Min(x, z)) => Min(Max(y, z), x)
|
||||
ExprHandle body =
|
||||
Max::make(Min::make(x, y, true), Min::make(x, z, true), true);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Min, simplified.node(), min);
|
||||
IS_VAR_WITH_NAME(min->lhs(), "x");
|
||||
IS_BINOP_W_VARS(Max, min->rhs(), max, "y", "z");
|
||||
ASSERT_TRUE(max->propagate_nans());
|
||||
checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)");
|
||||
}
|
||||
|
||||
{
|
||||
// Max(Min(x, y), Min(z, x)) => Min(x, Max(y, z))
|
||||
// Max(Min(x, y), Min(z, x)) => Min(Max(y, z), x)
|
||||
ExprHandle body =
|
||||
Max::make(Min::make(x, y, true), Min::make(z, x, true), true);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Min, simplified.node(), min);
|
||||
IS_VAR_WITH_NAME(min->lhs(), "x");
|
||||
IS_BINOP_W_VARS(Max, min->rhs(), max, "y", "z");
|
||||
ASSERT_TRUE(max->propagate_nans());
|
||||
checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)");
|
||||
}
|
||||
|
||||
{
|
||||
// Max(Min(y, x), Min(x, z)) => Min(x, Max(y, z))
|
||||
// Max(Min(y, x), Min(x, z)) => Min(Max(y, z), x)
|
||||
ExprHandle body =
|
||||
Max::make(Min::make(y, x, true), Min::make(x, z, true), true);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Min, simplified.node(), min);
|
||||
IS_VAR_WITH_NAME(min->lhs(), "x");
|
||||
IS_BINOP_W_VARS(Max, min->rhs(), max, "y", "z");
|
||||
ASSERT_TRUE(max->propagate_nans());
|
||||
checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)");
|
||||
}
|
||||
|
||||
{
|
||||
// Max(Min(y, x), Min(z, x)) => Min(x, Max(y, z))
|
||||
// Max(Min(y, x), Min(z, x)) => Min(Max(y, z), x)
|
||||
ExprHandle body =
|
||||
Max::make(Min::make(y, x, true), Min::make(z, x, true), true);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Min, simplified.node(), min);
|
||||
IS_VAR_WITH_NAME(min->lhs(), "x");
|
||||
IS_BINOP_W_VARS(Max, min->rhs(), max, "y", "z");
|
||||
ASSERT_TRUE(max->propagate_nans());
|
||||
checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)");
|
||||
}
|
||||
|
||||
{
|
||||
// Max(Min(y, x), Min(z, x)) => Max(Min(x, z), Min(x, y))
|
||||
// Max(Min(y, x), Min(z, x)) => Max(Min(x, y), Min(x, z))
|
||||
// When all the ops in the pattern do not have the same propagate_nans,
|
||||
// it should not be simplified.
|
||||
ExprHandle body =
|
||||
|
|
@ -2168,10 +2146,10 @@ TEST(Simplify, SimplifyNestedMax) {
|
|||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Max, simplified.node(), max);
|
||||
IS_BINOP_W_VARS(Min, max->lhs(), min1, "x", "z");
|
||||
ASSERT_FALSE(min1->propagate_nans());
|
||||
IS_BINOP_W_VARS(Min, max->rhs(), min2, "x", "y");
|
||||
ASSERT_TRUE(min2->propagate_nans());
|
||||
IS_BINOP_W_VARS(Min, max->lhs(), min1, "x", "y");
|
||||
ASSERT_TRUE(min1->propagate_nans());
|
||||
IS_BINOP_W_VARS(Min, max->rhs(), min2, "x", "z");
|
||||
ASSERT_FALSE(min2->propagate_nans());
|
||||
ASSERT_TRUE(max->propagate_nans());
|
||||
}
|
||||
|
||||
|
|
@ -2304,18 +2282,7 @@ TEST(Simplify, SimplifyNestedMax) {
|
|||
8,
|
||||
false);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Max, simplified.node(), max1);
|
||||
IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
|
||||
IS_VAR_WITH_NAME(max2->lhs(), "x");
|
||||
IS_NODE_WITH_NAME(Max, max2->rhs(), max3);
|
||||
IS_BINOP_W_CONST(Max, max3->lhs(), max4, "z", 5);
|
||||
ASSERT_TRUE(max4->propagate_nans());
|
||||
IS_VAR_WITH_NAME(max3->rhs(), "y");
|
||||
ASSERT_FALSE(max3->propagate_nans());
|
||||
ASSERT_TRUE(max2->propagate_nans());
|
||||
IS_IMM_WITH_VAL(Int, max1->rhs(), 8);
|
||||
ASSERT_FALSE(max1->propagate_nans());
|
||||
checkExprIR(simplified, "Max(Max(Max(Max(z, 5, 1), y, 0), x, 1), 8, 0)");
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -2359,46 +2326,46 @@ TEST(Simplify, SimplifyNestedMin) {
|
|||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
|
||||
IS_BINOP_W_VARS(Add, simplified.node(), add, "y", "x");
|
||||
IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y");
|
||||
}
|
||||
|
||||
{
|
||||
// Min(x + y, Min(x + y, z)) => Min(y + x, z)
|
||||
// Min(x + y, Min(x + y, z)) => Min(x + y, z)
|
||||
ExprHandle body = Min::make(x + y, Min::make(x + y, z, true), true);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Min, simplified.node(), min);
|
||||
IS_BINOP_W_VARS(Add, min->lhs(), add, "y", "x");
|
||||
IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y");
|
||||
IS_VAR_WITH_NAME(min->rhs(), "z");
|
||||
}
|
||||
|
||||
{
|
||||
// Min(x + y, Min(z, x + y)) => Min(y + x, z)
|
||||
// Min(x + y, Min(z, x + y)) => Min(x + y, z)
|
||||
ExprHandle body = Min::make(x + y, Min::make(z, x + y, true), true);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Min, simplified.node(), min);
|
||||
IS_BINOP_W_VARS(Add, min->lhs(), add, "y", "x");
|
||||
IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y");
|
||||
IS_VAR_WITH_NAME(min->rhs(), "z");
|
||||
}
|
||||
|
||||
{
|
||||
// Min(Min(x + y, z), x + y) => Min(y + x, z)
|
||||
// Min(Min(x + y, z), x + y) => Min(x + y, z)
|
||||
ExprHandle body = Min::make(Min::make(x + y, z, true), x + y, true);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Min, simplified.node(), min);
|
||||
IS_BINOP_W_VARS(Add, min->lhs(), add, "y", "x");
|
||||
IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y");
|
||||
IS_VAR_WITH_NAME(min->rhs(), "z");
|
||||
}
|
||||
|
||||
{
|
||||
// Min(Min(z, x + y), x + y) => Min(y + x, z)
|
||||
// Min(Min(z, x + y), x + y) => Min(x + y, z)
|
||||
ExprHandle body = Min::make(Min::make(z, x + y, true), x + y, true);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Min, simplified.node(), min);
|
||||
IS_BINOP_W_VARS(Add, min->lhs(), add, "y", "x");
|
||||
IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y");
|
||||
IS_VAR_WITH_NAME(min->rhs(), "z");
|
||||
}
|
||||
|
||||
|
|
@ -2416,55 +2383,39 @@ TEST(Simplify, SimplifyNestedMin) {
|
|||
}
|
||||
|
||||
{
|
||||
// Min(Max(x, y), Max(x, z)) => Max(x, Min(y, z))
|
||||
// Min(Max(x, y), Max(x, z)) => Max(Min(y, z), x)
|
||||
ExprHandle body =
|
||||
Min::make(Max::make(x, y, true), Max::make(x, z, true), true);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Max, simplified.node(), max);
|
||||
IS_VAR_WITH_NAME(max->lhs(), "x");
|
||||
IS_BINOP_W_VARS(Min, max->rhs(), min, "y", "z");
|
||||
ASSERT_TRUE(min->propagate_nans());
|
||||
checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
|
||||
}
|
||||
|
||||
{
|
||||
// Min(Max(x, y), Max(z, x)) => Max(x, Min(y, z))
|
||||
// Min(Max(x, y), Max(z, x)) => Max(Min(y, z), x)
|
||||
ExprHandle body =
|
||||
Min::make(Max::make(x, y, true), Max::make(z, x, true), true);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Max, simplified.node(), max);
|
||||
IS_VAR_WITH_NAME(max->lhs(), "x");
|
||||
IS_BINOP_W_VARS(Min, max->rhs(), min, "y", "z");
|
||||
ASSERT_TRUE(min->propagate_nans());
|
||||
checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
|
||||
}
|
||||
|
||||
{
|
||||
// Min(Max(y, x), Max(x, z)) => Max(x, Min(y, z))
|
||||
// Min(Max(y, x), Max(x, z)) => Max(Min(y, z), x)
|
||||
ExprHandle body =
|
||||
Min::make(Max::make(y, x, true), Max::make(x, z, true), true);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Max, simplified.node(), max);
|
||||
IS_VAR_WITH_NAME(max->lhs(), "x");
|
||||
IS_BINOP_W_VARS(Min, max->rhs(), min, "y", "z");
|
||||
ASSERT_TRUE(min->propagate_nans());
|
||||
checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
|
||||
}
|
||||
|
||||
{
|
||||
// Min(Max(y, x), Max(z, x)) => Max(x, Min(y, z))
|
||||
// Min(Max(y, x), Max(z, x)) => Max(Min(y, z), x)
|
||||
ExprHandle body =
|
||||
Min::make(Max::make(y, x, true), Max::make(z, x, true), true);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Max, simplified.node(), max);
|
||||
IS_VAR_WITH_NAME(max->lhs(), "x");
|
||||
IS_BINOP_W_VARS(Min, max->rhs(), min, "y", "z");
|
||||
ASSERT_TRUE(min->propagate_nans());
|
||||
checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
|
||||
}
|
||||
|
||||
{
|
||||
// Min(Max(y, x), Max(z, x)) => Min(Max(x, z), Max(x, y))
|
||||
// Min(Max(y, x), Max(z, x)) => Min(Max(x, y), Max(x, z))
|
||||
// When all the ops in the pattern do not have the same propagate_nans,
|
||||
// it should not be simplified.
|
||||
ExprHandle body =
|
||||
|
|
@ -2472,10 +2423,10 @@ TEST(Simplify, SimplifyNestedMin) {
|
|||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Min, simplified.node(), min);
|
||||
IS_BINOP_W_VARS(Max, min->lhs(), max1, "x", "z");
|
||||
ASSERT_FALSE(max1->propagate_nans());
|
||||
IS_BINOP_W_VARS(Max, min->rhs(), max2, "x", "y");
|
||||
ASSERT_TRUE(max2->propagate_nans());
|
||||
IS_BINOP_W_VARS(Max, min->lhs(), max1, "x", "y");
|
||||
ASSERT_TRUE(max1->propagate_nans());
|
||||
IS_BINOP_W_VARS(Max, min->rhs(), max2, "x", "z");
|
||||
ASSERT_FALSE(max2->propagate_nans());
|
||||
ASSERT_TRUE(min->propagate_nans());
|
||||
}
|
||||
|
||||
|
|
@ -2600,7 +2551,7 @@ TEST(Simplify, SimplifyNestedMin) {
|
|||
}
|
||||
|
||||
{
|
||||
// Min(Min(Min(Min(z, 5), y), x), 8) => Min(Min(x, Min(Min(z, 5), y)), 8)
|
||||
// Min(Min(Min(Min(z, 5), y), x), 8) => Min(Min(Min(Min(z, 5), y), x), 8)
|
||||
// Do not simplify when all the Min ops do not have the same
|
||||
// propagate_nans.
|
||||
ExprHandle body = Min::make(
|
||||
|
|
@ -2608,18 +2559,7 @@ TEST(Simplify, SimplifyNestedMin) {
|
|||
8,
|
||||
false);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Min, simplified.node(), min1);
|
||||
IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
|
||||
IS_VAR_WITH_NAME(min2->lhs(), "x");
|
||||
IS_NODE_WITH_NAME(Min, min2->rhs(), min3);
|
||||
IS_BINOP_W_CONST(Min, min3->lhs(), min4, "z", 5);
|
||||
ASSERT_TRUE(min4->propagate_nans());
|
||||
IS_VAR_WITH_NAME(min3->rhs(), "y");
|
||||
ASSERT_FALSE(min3->propagate_nans());
|
||||
ASSERT_TRUE(min2->propagate_nans());
|
||||
IS_IMM_WITH_VAL(Int, min1->rhs(), 8);
|
||||
ASSERT_FALSE(min1->propagate_nans());
|
||||
checkExprIR(simplified, "Min(Min(Min(Min(z, 5, 1), y, 0), x, 1), 8, 0)");
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -2922,16 +2862,7 @@ TEST(Simplify, SimplifyRoundModPattern) {
|
|||
VarHandle z("z", kInt);
|
||||
ExprHandle body = ((x / y) * y) + (x % z);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Add, simplified.node(), add);
|
||||
IS_NODE_WITH_NAME(Mul, add->lhs(), roundMul);
|
||||
IS_VAR_WITH_NAME(roundMul->lhs(), "y");
|
||||
IS_NODE_WITH_NAME(Div, roundMul->rhs(), roundDiv);
|
||||
IS_VAR_WITH_NAME(roundDiv->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(roundDiv->rhs(), "y");
|
||||
IS_NODE_WITH_NAME(Mod, add->rhs(), mod);
|
||||
IS_VAR_WITH_NAME(mod->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(mod->rhs(), "z");
|
||||
checkExprIR(simplified, "(x / y) * y + x % z");
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -2941,15 +2872,7 @@ TEST(Simplify, SimplifyRoundModPattern) {
|
|||
VarHandle z("z", kInt);
|
||||
ExprHandle body = (y * (x / z)) + (x % y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_NODE_WITH_NAME(Add, simplified.node(), add);
|
||||
IS_NODE_WITH_NAME(Mul, add->lhs(), roundMul);
|
||||
IS_VAR_WITH_NAME(roundMul->lhs(), "y");
|
||||
IS_NODE_WITH_NAME(Div, roundMul->rhs(), roundDiv);
|
||||
IS_VAR_WITH_NAME(roundDiv->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(roundDiv->rhs(), "z");
|
||||
IS_NODE_WITH_NAME(Mod, add->rhs(), mod);
|
||||
IS_VAR_WITH_NAME(mod->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(mod->rhs(), "y");
|
||||
checkExprIR(simplified, "x % y + (x / z) * y");
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -2959,15 +2882,7 @@ TEST(Simplify, SimplifyRoundModPattern) {
|
|||
VarHandle z("z", kInt);
|
||||
ExprHandle body = ((x / y) * z) + (x % y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_NODE_WITH_NAME(Add, simplified.node(), add);
|
||||
IS_NODE_WITH_NAME(Mul, add->lhs(), roundMul);
|
||||
IS_VAR_WITH_NAME(roundMul->lhs(), "z");
|
||||
IS_NODE_WITH_NAME(Div, roundMul->rhs(), roundDiv);
|
||||
IS_VAR_WITH_NAME(roundDiv->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(roundDiv->rhs(), "y");
|
||||
IS_NODE_WITH_NAME(Mod, add->rhs(), mod);
|
||||
IS_VAR_WITH_NAME(mod->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(mod->rhs(), "y");
|
||||
checkExprIR(simplified, "x % y + (x / y) * z");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3036,20 +2951,20 @@ TEST(Simplify, SimplifyRoundModPatternMultivar) {
|
|||
|
||||
{
|
||||
// Multivar.
|
||||
// (x/8) * 8 + (y/5)*5 + x%8 + y%5 => y + x.
|
||||
// (x/8) * 8 + (y/5)*5 + x%8 + y%5 => x + y.
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
ExprHandle body = (x / ExprHandle(8) * ExprHandle(8)) +
|
||||
(y / ExprHandle(5) * ExprHandle(5)) + (x % 8) + (y % 5);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_NODE_WITH_NAME(Add, simplified.node(), add);
|
||||
IS_VAR_WITH_NAME(add->lhs(), "y");
|
||||
IS_VAR_WITH_NAME(add->rhs(), "x");
|
||||
IS_VAR_WITH_NAME(add->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(add->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
// Find the right var.
|
||||
// (y/8) * 8 x%8 + y%8 + z%8 => z%8 + x%8 + y
|
||||
// (y/8) * 8 x%8 + y%8 + z%8 => x%8 + y + z%8
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
VarHandle z("z", kInt);
|
||||
|
|
@ -3075,16 +2990,9 @@ TEST(Simplify, SimplifyRoundModPatternMultivar) {
|
|||
VarHandle y("y", kInt);
|
||||
VarHandle z("z", kInt);
|
||||
|
||||
ExprHandle body = x + (z + ExprHandle(512) * y) % ExprHandle(16) +
|
||||
ExprHandle(16) * ((z + ExprHandle(512) * y) / ExprHandle(16));
|
||||
ExprHandle body = x + (z + y * 512) % 16 + ((z + y * 512) / 16 * 16);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_NODE_WITH_NAME(Add, simplified.node(), add);
|
||||
IS_VAR_WITH_NAME(add->rhs(), "x");
|
||||
IS_NODE_WITH_NAME(Add, add->lhs(), add2);
|
||||
IS_VAR_WITH_NAME(add2->lhs(), "z");
|
||||
IS_NODE_WITH_NAME(Mul, add2->rhs(), mul);
|
||||
IS_IMM_WITH_VAL(Int, mul->lhs(), 512);
|
||||
IS_VAR_WITH_NAME(mul->rhs(), "y");
|
||||
checkExprIR(simplified, "x + (z + 512 * y)");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3135,13 +3043,7 @@ TEST(Simplify, SimplifyModRoundModPattern) {
|
|||
VarHandle k("k", kInt);
|
||||
ExprHandle body = (k * t / x % y) * x + k * t % x;
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
|
||||
IS_NODE_WITH_NAME(Mul, mod->lhs(), mul1);
|
||||
IS_VAR_WITH_NAME(mul1->lhs(), "t");
|
||||
IS_VAR_WITH_NAME(mul1->rhs(), "k");
|
||||
IS_NODE_WITH_NAME(Mul, mod->rhs(), mul2);
|
||||
IS_VAR_WITH_NAME(mul2->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(mul2->rhs(), "y");
|
||||
checkExprIR(simplified, "(k * t) % (x * y)");
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -3259,11 +3161,7 @@ TEST(Simplify, SimplifyModRoundModPatternMultivar) {
|
|||
VarHandle t("t", kInt);
|
||||
ExprHandle body = (t / 7 % 9) * 7 + t % 7 + t;
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_NODE_WITH_NAME(Add, simplified.node(), add);
|
||||
IS_NODE_WITH_NAME(Mod, add->rhs(), mod);
|
||||
IS_VAR_WITH_NAME(mod->lhs(), "t");
|
||||
IS_IMM_WITH_VAL(Int, mod->rhs(), 63);
|
||||
IS_VAR_WITH_NAME(add->lhs(), "t");
|
||||
checkExprIR(simplified, "t % 63 + t");
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -3306,19 +3204,7 @@ TEST(Simplify, SimplifyModRoundModPatternMultivar) {
|
|||
VarHandle k("k", kInt);
|
||||
ExprHandle body = (t / x % y) * x + t % x + (t / k / x % y) * x + t / k % x;
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_NODE_WITH_NAME(Add, simplified.node(), add);
|
||||
IS_NODE_WITH_NAME(Mod, add->lhs(), mod);
|
||||
IS_VAR_WITH_NAME(mod->lhs(), "t");
|
||||
IS_NODE_WITH_NAME(Mul, mod->rhs(), mul);
|
||||
IS_VAR_WITH_NAME(mul->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(mul->rhs(), "y");
|
||||
IS_NODE_WITH_NAME(Mod, add->rhs(), mod2);
|
||||
IS_NODE_WITH_NAME(Div, mod2->lhs(), div);
|
||||
IS_VAR_WITH_NAME(div->lhs(), "t");
|
||||
IS_VAR_WITH_NAME(div->rhs(), "k");
|
||||
IS_NODE_WITH_NAME(Mul, mod2->rhs(), mul2);
|
||||
IS_VAR_WITH_NAME(mul2->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(mul2->rhs(), "y");
|
||||
checkExprIR(simplified, "(t / k) % (x * y) + t % (x * y)");
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -3971,7 +3857,7 @@ TEST(Simplify, SimplifyForWontLoseLoopOptions) {
|
|||
BufHandle c("C", {4}, kInt);
|
||||
VarHandle i("i", kInt);
|
||||
LoopOptions options;
|
||||
options.set_gpu_block_index(12);
|
||||
options.set_gpu_block_index(LoopOptions::IDX_W);
|
||||
auto body =
|
||||
For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})), options);
|
||||
StmtPtr simplified = IRSimplifier::simplify(body);
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include <test/cpp/tensorexpr/test_base.h>
|
||||
#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
|
||||
#include <torch/csrc/jit/testing/file_check.h>
|
||||
|
||||
namespace torch {
|
||||
|
|
@ -69,5 +70,9 @@ using namespace torch::jit::tensorexpr;
|
|||
ASSERT_EQ(node_->op_type(), kRand); \
|
||||
}
|
||||
|
||||
void checkIR(StmtPtr s, const std::string& pattern);
|
||||
void checkExprIR(ExprPtr e, const std::string& pattern);
|
||||
void checkExprIR(const ExprHandle& e, const std::string& pattern);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -35,8 +35,15 @@ void Term::sort() {
|
|||
if (dtype().is_floating_point()) {
|
||||
throw std::logic_error("reordering FP ops");
|
||||
}
|
||||
std::unordered_map<ExprPtr, std::string> str_repr_cache;
|
||||
std::sort(variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) {
|
||||
return hasher_.hash(a) < hasher_.hash(b);
|
||||
if (!str_repr_cache.count(a)) {
|
||||
str_repr_cache[a] = std::to_string(a);
|
||||
}
|
||||
if (!str_repr_cache.count(b)) {
|
||||
str_repr_cache[b] = std::to_string(b);
|
||||
}
|
||||
return str_repr_cache.at(a) < str_repr_cache.at(b);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -52,8 +59,15 @@ void Polynomial::sort() {
|
|||
if (dtype().is_floating_point()) {
|
||||
throw std::logic_error("reordering FP ops");
|
||||
}
|
||||
std::unordered_map<ExprPtr, std::string> str_repr_cache;
|
||||
std::sort(variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) {
|
||||
return hasher_.hash(a) < hasher_.hash(b);
|
||||
if (!str_repr_cache.count(a)) {
|
||||
str_repr_cache[a] = std::to_string(a);
|
||||
}
|
||||
if (!str_repr_cache.count(b)) {
|
||||
str_repr_cache[b] = std::to_string(b);
|
||||
}
|
||||
return str_repr_cache.at(a) < str_repr_cache.at(b);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -66,6 +80,18 @@ void MaxTerm::uniquefy() {
|
|||
return hasher_.hash(a) == hasher_.hash(b);
|
||||
});
|
||||
variables_.resize(std::distance(variables_.begin(), it));
|
||||
|
||||
// Once we removed duplicates, sort terms alphabetically for stability.
|
||||
std::unordered_map<ExprPtr, std::string> str_repr_cache;
|
||||
std::sort(variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) {
|
||||
if (!str_repr_cache.count(a)) {
|
||||
str_repr_cache[a] = std::to_string(a);
|
||||
}
|
||||
if (!str_repr_cache.count(b)) {
|
||||
str_repr_cache[b] = std::to_string(b);
|
||||
}
|
||||
return str_repr_cache.at(a) < str_repr_cache.at(b);
|
||||
});
|
||||
}
|
||||
|
||||
void MinTerm::uniquefy() {
|
||||
|
|
@ -77,6 +103,18 @@ void MinTerm::uniquefy() {
|
|||
return hasher_.hash(a) == hasher_.hash(b);
|
||||
});
|
||||
variables_.resize(std::distance(variables_.begin(), it));
|
||||
|
||||
// Once we removed duplicates, sort terms alphabetically for stability.
|
||||
std::unordered_map<ExprPtr, std::string> str_repr_cache;
|
||||
std::sort(variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) {
|
||||
if (!str_repr_cache.count(a)) {
|
||||
str_repr_cache[a] = std::to_string(a);
|
||||
}
|
||||
if (!str_repr_cache.count(b)) {
|
||||
str_repr_cache[b] = std::to_string(b);
|
||||
}
|
||||
return str_repr_cache.at(a) < str_repr_cache.at(b);
|
||||
});
|
||||
}
|
||||
|
||||
// Handles optimization cases for Broadcast/Ramp +/- Broadcast/Ramp
|
||||
|
|
@ -2076,8 +2114,20 @@ ExprPtr TermExpander::mutate(PolynomialPtr v) {
|
|||
std::vector<TermPtr> addTerms;
|
||||
std::vector<TermPtr> subTerms;
|
||||
|
||||
auto vars = v->variables();
|
||||
std::unordered_map<ExprPtr, std::string> str_repr_cache;
|
||||
std::sort(vars.begin(), vars.end(), [&](ExprPtr a, ExprPtr b) {
|
||||
if (!str_repr_cache.count(a)) {
|
||||
str_repr_cache[a] = std::to_string(a);
|
||||
}
|
||||
if (!str_repr_cache.count(b)) {
|
||||
str_repr_cache[b] = std::to_string(b);
|
||||
}
|
||||
return str_repr_cache.at(a) < str_repr_cache.at(b);
|
||||
});
|
||||
|
||||
// partition the terms into a list to add and list to subtract.
|
||||
for (auto node : v->variables()) {
|
||||
for (auto node : vars) {
|
||||
if (immediateIsNegative(node->scalar())) {
|
||||
subTerms.push_back(node);
|
||||
} else if (!immediateEquals(node->scalar(), 0)) {
|
||||
|
|
@ -2822,6 +2872,49 @@ bool exprEquals(ExprPtr A, ExprPtr B) {
|
|||
}
|
||||
}
|
||||
|
||||
ExprPtr IRSimplifier::simplify(ExprPtr e) {
|
||||
GRAPH_DEBUG("(Simplifier) Original: ", std::to_string(e));
|
||||
SimplifierUnderContext ctxsimplifier;
|
||||
e = e->accept_mutator(&ctxsimplifier);
|
||||
|
||||
PolynomialTransformer simplifier;
|
||||
e = e->accept_mutator(&simplifier);
|
||||
|
||||
// There may be terms left in the IR, expand them.
|
||||
TermExpander expander(&simplifier);
|
||||
e = e->accept_mutator(&expander);
|
||||
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
|
||||
if (!expander.check_safe()) {
|
||||
throw malformed_input("eliminated null Allocation without free");
|
||||
}
|
||||
|
||||
GRAPH_DEBUG("(Simplifier) Simplified: ", std::to_string(e));
|
||||
return e;
|
||||
}
|
||||
|
||||
StmtPtr IRSimplifier::simplify(StmtPtr s) {
|
||||
GRAPH_DEBUG("(Simplifier) Original: ", std::to_string(s));
|
||||
SimplifierUnderContext ctxsimplifier;
|
||||
s = s->accept_mutator(&ctxsimplifier);
|
||||
|
||||
PolynomialTransformer simplifier;
|
||||
s = s->accept_mutator(&simplifier);
|
||||
if (s == nullptr) {
|
||||
GRAPH_DEBUG("(Simplifier) Simplified: NULL");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// There may be terms left in the IR, expand them.
|
||||
TermExpander expander(&simplifier);
|
||||
s = s->accept_mutator(&expander);
|
||||
if (!expander.check_safe()) {
|
||||
throw malformed_input("eliminated null Allocation without free");
|
||||
}
|
||||
|
||||
GRAPH_DEBUG("(Simplifier) Simplified: ", std::to_string(s));
|
||||
return s;
|
||||
}
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -596,47 +596,11 @@ class TORCH_API TermExpander : public PolynomialBase {
|
|||
|
||||
class TORCH_API IRSimplifier {
|
||||
public:
|
||||
static ExprPtr simplify(ExprPtr e) {
|
||||
SimplifierUnderContext ctxsimplifier;
|
||||
e = e->accept_mutator(&ctxsimplifier);
|
||||
|
||||
PolynomialTransformer simplifier;
|
||||
e = e->accept_mutator(&simplifier);
|
||||
|
||||
// There may be terms left in the IR, expand them.
|
||||
TermExpander expander(&simplifier);
|
||||
e = e->accept_mutator(&expander);
|
||||
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
|
||||
if (!expander.check_safe()) {
|
||||
throw malformed_input("eliminated null Allocation without free");
|
||||
}
|
||||
|
||||
return e;
|
||||
}
|
||||
|
||||
static StmtPtr simplify(StmtPtr s);
|
||||
static ExprPtr simplify(ExprPtr e);
|
||||
static ExprHandle simplify(const ExprHandle& e) {
|
||||
return ExprHandle(simplify(e.node()));
|
||||
}
|
||||
|
||||
static StmtPtr simplify(StmtPtr s) {
|
||||
SimplifierUnderContext ctxsimplifier;
|
||||
s = s->accept_mutator(&ctxsimplifier);
|
||||
|
||||
PolynomialTransformer simplifier;
|
||||
s = s->accept_mutator(&simplifier);
|
||||
if (s == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// There may be terms left in the IR, expand them.
|
||||
TermExpander expander(&simplifier);
|
||||
s = s->accept_mutator(&expander);
|
||||
if (!expander.check_safe()) {
|
||||
throw malformed_input("eliminated null Allocation without free");
|
||||
}
|
||||
|
||||
return s;
|
||||
}
|
||||
};
|
||||
|
||||
// Flattens the buf and performs the simplifier on the flattened dims.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user