[StableHLO Optim] Add CompareOp patterns and dont fold large converts.

PiperOrigin-RevId: 825236477
This commit is contained in:
Kevin Gleason 2025-10-28 15:42:38 -07:00 committed by TensorFlower Gardener
parent 1e0c214c2f
commit 03f4c66dd1

View File

@ -248,12 +248,10 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.ml
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
--- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
+++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
@@ -132,6 +132,35 @@
// CHECK-NEXT: return [[R0]], [[R5]]
@@ -134,6 +134,35 @@
return %0, %5 : tensor<1x3x6xi32>, tensor<3x6x1xi32>
+}
+
}
+// CHECK-LABEL: func.func @broadcast_in_dim_prefer_nested_reshape
+// CHECK-SAME: ([[ARG0:%[^ ]+]]: tensor<3x4xi32>)
+func.func @broadcast_in_dim_prefer_nested_reshape(%arg0: tensor<3x4xi32>) -> (tensor<2x3x4x3xi32>, tensor<2x3x4x3xi32>) {
@ -281,10 +279,31 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplific
+
+ // CHECK-DAG: return [[BROADCAST_OF_RESHAPE]], [[MERGED_BROADCAST]]
+ return %1, %3 : tensor<2x3x4x3xi32>, tensor<2x3x4x3xi32>
+}
+
// CHECK-LABEL: func.func @broadcast_in_dim_not_identity_broadcasts
func.func @broadcast_in_dim_not_identity_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> {
// CHECK: stablehlo.broadcast_in_dim
@@ -208,6 +237,18 @@
// CHECK-NEXT: return [[C1]], [[C0]], [[C1]], [[C0]], [[R0]], [[R1]], [[R2]], [[R3]]
return %0, %1, %2, %3, %4, %5, %6, %7 :
tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>
+}
+
+// CHECK-LABEL: func.func @compare_op_bool_simplify
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<i1>)
+func.func @compare_op_bool_simplify(%arg0: tensor<i1>) -> (tensor<i1>, tensor<i1>) {
+ %false = stablehlo.constant dense<false> : tensor<i1>
+ %true = stablehlo.constant dense<true> : tensor<i1>
+ // CHECK-NOT: stablehlo.compare
+ %0 = stablehlo.compare NE, %arg0, %false, UNSIGNED : (tensor<i1>, tensor<i1>) -> tensor<i1>
+ %1 = stablehlo.compare EQ, %arg0, %true, UNSIGNED : (tensor<i1>, tensor<i1>) -> tensor<i1>
+ // CHECK: return [[ARG0]], [[ARG0]]
+ func.return %0, %1 : tensor<i1>, tensor<i1>
}
// CHECK-LABEL: func.func @broadcast_in_dim_not_identity_broadcasts
@@ -1021,6 +1050,18 @@
// -----
@@ -1021,6 +1062,18 @@
// CHECK-NOT: stablehlo.pad
%1 = stablehlo.pad %arg0, %0, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<256x1024xbf16>, tensor<bf16>) -> tensor<256x1024xbf16>
return %1 : tensor<256x1024xbf16>
@ -303,7 +322,7 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplific
}
// -----
@@ -1810,6 +1851,15 @@
@@ -1810,6 +1863,15 @@
return %0 : tensor<2x4x1x5xf32>
}
@ -392,7 +411,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
return failure();
SplatElementsAttr cstAttr;
@@ -1104,7 +1110,7 @@
@@ -825,7 +831,8 @@
RankedTensorType resultType = op.getType();
if (failed(validateStaticShapeResult(rewriter, op, resultType)) ||
- failed(validateShapeFoldDtype(rewriter, op, resultType)))
+ failed(validateShapeFoldDtype(rewriter, op, resultType)) ||
+ failed(validateElementCountForFold(rewriter, op, resultType)))
return failure();
auto operandElemType = getElementTypeOrSelf(operand.getType());
@@ -1104,7 +1111,7 @@
failed(validateShapeFoldDtype(rewriter, op, resultType)))
return failure();
@ -404,7 +433,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
@@ -1309,6 +1309,17 @@
@@ -1309,10 +1309,20 @@
// TransposeOp
/////////////////////////////////
@ -422,6 +451,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimp
// Pattern: transpose(X, [no_mem_layout_change...]) -> reshape(X)
struct TransposeIsReshape final : SimplifyOpRewritePattern<TransposeOp> {
using SimplifyOpRewritePattern::SimplifyOpRewritePattern;
-
LogicalResult matchAndRewrite(TransposeOp op,
PatternRewriter& rewriter) const override {
auto input = op.getOperand();
diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
@ -488,7 +521,44 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimp
// Pattern: broadcast_in_dim(X, [dims...]) -> transpose(X, [dims...])
// [if same numel & rank]
@@ -424,9 +443,9 @@
@@ -197,6 +216,36 @@
: Pat<(StableHLO_BroadcastInDimOp:$op $operand, $dims),
(StableHLO_TransposeOp $operand, (InvertBroadcastDims $dims)),
[(NumberOfElementsEqual $op, $operand), (RankEqual $op, $operand)]>;
+
+////////
+// CompareOp
+
+// The canonical form has the constant operand as the RHS.
+class StableHLO_ComparisonDirectionValue<string enumStr> :
+ ConstantAttr<StableHLO_ComparisonDirectionAttr, "::mlir::stablehlo::ComparisonDirection::" # enumStr>;
+
+// Pattern: compare(NE, X, False) : i1 -> X
+def CompareOp_NeBooleanFalse
+ : Pat<(StableHLO_CompareOp
+ $lhs,
+ (StableHLO_ConstantOp:$cst IntZero:$value),
+ StableHLO_ComparisonDirectionValue<"NE">,
+ $type),
+ (replaceWithValue $lhs),
+ [(HLO_PredTensor $cst)]>;
+
+// Pattern: compare(EQ, X, True) : i1 -> X
+def CompareOp_EqBooleanTrue
+ : Pat<(StableHLO_CompareOp
+ $lhs,
+ (StableHLO_ConstantOp:$cst IntOne:$value),
+ StableHLO_ComparisonDirectionValue<"EQ">,
+ $type),
+ (replaceWithValue $lhs),
+ [(HLO_PredTensor $cst)]>;
+
+// TODO: compare(EQ, X, False) : i1 -> not(X)
+// TODO: compare(NE, X, True) : i1 -> not(X)
////////
// ConvertOp
@@ -424,9 +473,9 @@
: Pat<(StableHLO_PadOp:$pad
$operand,
$padding_value,
@ -501,7 +571,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimp
(replaceWithValue $operand),
[(TypesEqual $pad, $operand)]>;
@@ -539,6 +558,12 @@
@@ -539,6 +588,12 @@
: Pat<(StableHLO_TransposeOp $lhs, IotaDims:$dims),
(replaceWithValue $lhs)>;