mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[StableHLO Optim] Add CompareOp patterns and dont fold large converts.
PiperOrigin-RevId: 825236477
This commit is contained in:
parent
1e0c214c2f
commit
03f4c66dd1
|
|
@ -248,12 +248,10 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.ml
|
|||
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
|
||||
--- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
|
||||
+++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
|
||||
@@ -132,6 +132,35 @@
|
||||
|
||||
// CHECK-NEXT: return [[R0]], [[R5]]
|
||||
@@ -134,6 +134,35 @@
|
||||
return %0, %5 : tensor<1x3x6xi32>, tensor<3x6x1xi32>
|
||||
+}
|
||||
+
|
||||
}
|
||||
|
||||
+// CHECK-LABEL: func.func @broadcast_in_dim_prefer_nested_reshape
|
||||
+// CHECK-SAME: ([[ARG0:%[^ ]+]]: tensor<3x4xi32>)
|
||||
+func.func @broadcast_in_dim_prefer_nested_reshape(%arg0: tensor<3x4xi32>) -> (tensor<2x3x4x3xi32>, tensor<2x3x4x3xi32>) {
|
||||
|
|
@ -281,10 +279,31 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplific
|
|||
+
|
||||
+ // CHECK-DAG: return [[BROADCAST_OF_RESHAPE]], [[MERGED_BROADCAST]]
|
||||
+ return %1, %3 : tensor<2x3x4x3xi32>, tensor<2x3x4x3xi32>
|
||||
+}
|
||||
+
|
||||
// CHECK-LABEL: func.func @broadcast_in_dim_not_identity_broadcasts
|
||||
func.func @broadcast_in_dim_not_identity_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: stablehlo.broadcast_in_dim
|
||||
@@ -208,6 +237,18 @@
|
||||
// CHECK-NEXT: return [[C1]], [[C0]], [[C1]], [[C0]], [[R0]], [[R1]], [[R2]], [[R3]]
|
||||
return %0, %1, %2, %3, %4, %5, %6, %7 :
|
||||
tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>
|
||||
+}
|
||||
+
|
||||
+// CHECK-LABEL: func.func @compare_op_bool_simplify
|
||||
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<i1>)
|
||||
+func.func @compare_op_bool_simplify(%arg0: tensor<i1>) -> (tensor<i1>, tensor<i1>) {
|
||||
+ %false = stablehlo.constant dense<false> : tensor<i1>
|
||||
+ %true = stablehlo.constant dense<true> : tensor<i1>
|
||||
+ // CHECK-NOT: stablehlo.compare
|
||||
+ %0 = stablehlo.compare NE, %arg0, %false, UNSIGNED : (tensor<i1>, tensor<i1>) -> tensor<i1>
|
||||
+ %1 = stablehlo.compare EQ, %arg0, %true, UNSIGNED : (tensor<i1>, tensor<i1>) -> tensor<i1>
|
||||
+ // CHECK: return [[ARG0]], [[ARG0]]
|
||||
+ func.return %0, %1 : tensor<i1>, tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @broadcast_in_dim_not_identity_broadcasts
|
||||
@@ -1021,6 +1050,18 @@
|
||||
// -----
|
||||
@@ -1021,6 +1062,18 @@
|
||||
// CHECK-NOT: stablehlo.pad
|
||||
%1 = stablehlo.pad %arg0, %0, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<256x1024xbf16>, tensor<bf16>) -> tensor<256x1024xbf16>
|
||||
return %1 : tensor<256x1024xbf16>
|
||||
|
|
@ -303,7 +322,7 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplific
|
|||
}
|
||||
|
||||
// -----
|
||||
@@ -1810,6 +1851,15 @@
|
||||
@@ -1810,6 +1863,15 @@
|
||||
return %0 : tensor<2x4x1x5xf32>
|
||||
}
|
||||
|
||||
|
|
@ -392,7 +411,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
|||
return failure();
|
||||
|
||||
SplatElementsAttr cstAttr;
|
||||
@@ -1104,7 +1110,7 @@
|
||||
@@ -825,7 +831,8 @@
|
||||
RankedTensorType resultType = op.getType();
|
||||
|
||||
if (failed(validateStaticShapeResult(rewriter, op, resultType)) ||
|
||||
- failed(validateShapeFoldDtype(rewriter, op, resultType)))
|
||||
+ failed(validateShapeFoldDtype(rewriter, op, resultType)) ||
|
||||
+ failed(validateElementCountForFold(rewriter, op, resultType)))
|
||||
return failure();
|
||||
|
||||
auto operandElemType = getElementTypeOrSelf(operand.getType());
|
||||
@@ -1104,7 +1111,7 @@
|
||||
failed(validateShapeFoldDtype(rewriter, op, resultType)))
|
||||
return failure();
|
||||
|
||||
|
|
@ -404,7 +433,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
|||
diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
|
||||
--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
|
||||
+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
|
||||
@@ -1309,6 +1309,17 @@
|
||||
@@ -1309,10 +1309,20 @@
|
||||
// TransposeOp
|
||||
/////////////////////////////////
|
||||
|
||||
|
|
@ -422,6 +451,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimp
|
|||
// Pattern: transpose(X, [no_mem_layout_change...]) -> reshape(X)
|
||||
struct TransposeIsReshape final : SimplifyOpRewritePattern<TransposeOp> {
|
||||
using SimplifyOpRewritePattern::SimplifyOpRewritePattern;
|
||||
-
|
||||
LogicalResult matchAndRewrite(TransposeOp op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
auto input = op.getOperand();
|
||||
diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
|
||||
--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
|
||||
+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
|
||||
|
|
@ -488,7 +521,44 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimp
|
|||
|
||||
// Pattern: broadcast_in_dim(X, [dims...]) -> transpose(X, [dims...])
|
||||
// [if same numel & rank]
|
||||
@@ -424,9 +443,9 @@
|
||||
@@ -197,6 +216,36 @@
|
||||
: Pat<(StableHLO_BroadcastInDimOp:$op $operand, $dims),
|
||||
(StableHLO_TransposeOp $operand, (InvertBroadcastDims $dims)),
|
||||
[(NumberOfElementsEqual $op, $operand), (RankEqual $op, $operand)]>;
|
||||
+
|
||||
+////////
|
||||
+// CompareOp
|
||||
+
|
||||
+// The canonical form has the constant operand as the RHS.
|
||||
+class StableHLO_ComparisonDirectionValue<string enumStr> :
|
||||
+ ConstantAttr<StableHLO_ComparisonDirectionAttr, "::mlir::stablehlo::ComparisonDirection::" # enumStr>;
|
||||
+
|
||||
+// Pattern: compare(NE, X, False) : i1 -> X
|
||||
+def CompareOp_NeBooleanFalse
|
||||
+ : Pat<(StableHLO_CompareOp
|
||||
+ $lhs,
|
||||
+ (StableHLO_ConstantOp:$cst IntZero:$value),
|
||||
+ StableHLO_ComparisonDirectionValue<"NE">,
|
||||
+ $type),
|
||||
+ (replaceWithValue $lhs),
|
||||
+ [(HLO_PredTensor $cst)]>;
|
||||
+
|
||||
+// Pattern: compare(EQ, X, True) : i1 -> X
|
||||
+def CompareOp_EqBooleanTrue
|
||||
+ : Pat<(StableHLO_CompareOp
|
||||
+ $lhs,
|
||||
+ (StableHLO_ConstantOp:$cst IntOne:$value),
|
||||
+ StableHLO_ComparisonDirectionValue<"EQ">,
|
||||
+ $type),
|
||||
+ (replaceWithValue $lhs),
|
||||
+ [(HLO_PredTensor $cst)]>;
|
||||
+
|
||||
+// TODO: compare(EQ, X, False) : i1 -> not(X)
|
||||
+// TODO: compare(NE, X, True) : i1 -> not(X)
|
||||
|
||||
////////
|
||||
// ConvertOp
|
||||
@@ -424,9 +473,9 @@
|
||||
: Pat<(StableHLO_PadOp:$pad
|
||||
$operand,
|
||||
$padding_value,
|
||||
|
|
@ -501,7 +571,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimp
|
|||
(replaceWithValue $operand),
|
||||
[(TypesEqual $pad, $operand)]>;
|
||||
|
||||
@@ -539,6 +558,12 @@
|
||||
@@ -539,6 +588,12 @@
|
||||
: Pat<(StableHLO_TransposeOp $lhs, IotaDims:$dims),
|
||||
(replaceWithValue $lhs)>;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user