[XLA:CPU] Only add reassoc flag to reductions with a single floating point op.

PiperOrigin-RevId: 822598746
This commit is contained in:
Will Froom 2025-10-22 08:23:34 -07:00 committed by TensorFlower Gardener
parent bbea04967a
commit 3353eeeab7
2 changed files with 35 additions and 1 deletions

View File

@ -55,6 +55,13 @@ struct RewriteCallPattern
return rewriter.notifyMatchFailure(call_op, "Could not resolve callee.");
}
// Adding reassoc flags to reductions with more than one fast math op
// can result in unexpected behaviour as they can reassociate between
// themselves.
if (FastMathOpCount(callee) > 1) {
return rewriter.notifyMatchFailure(call_op, "Too many fast math ops.");
}
callee->walk([&rewriter](mlir::Operation* op) {
if (auto fm_op =
mlir::dyn_cast_or_null<mlir::arith::ArithFastMathInterface>(op)) {
@ -74,6 +81,13 @@ struct RewriteCallPattern
return mlir::success();
}
private:
static int FastMathOpCount(mlir::func::FuncOp callee) {
int count = 0;
callee.walk([&](mlir::arith::ArithFastMathInterface op) { count++; });
return count;
}
};
class AddReductionFastMathFlagsPass

View File

@ -14,4 +14,24 @@ func.func @reducer(%x: f32, %y: f32) -> f32
// CHECK-LABEL: func.func @caller
// CHECK-LABEL: func.func @reducer
// CHECK arith.addf {{.*}} fastmath<reassoc> : f32
// CHECK: arith.addf {{.*}} fastmath<reassoc> : f32
// -----
func.func @caller(%x: f32, %y: f32) -> f32
{
%z = func.call @reducer(%x, %y) { xla.is_reduction }: (f32, f32) -> f32
func.return %z : f32
}
func.func @reducer(%x: f32, %y: f32) -> f32
{
%w = arith.addf %x, %y : f32
%z = arith.mulf %w, %y : f32
func.return %z : f32
}
// CHECK-LABEL: func.func @caller
// CHECK-LABEL: func.func @reducer
// CHECK-NOT: fastmath