[XLA:CPU] Fix UnrollExtractLoops to stop false-positive unrolling when only the vector itself depends on the for loop.

PiperOrigin-RevId: 825427562
This commit is contained in:
Will Froom 2025-10-29 01:32:42 -07:00 committed by TensorFlower Gardener
parent f19413ab39
commit a8356f63df
2 changed files with 77 additions and 14 deletions

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <cassert>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include "llvm/ADT/STLExtras.h"
@ -39,6 +40,7 @@ limitations under the License.
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/WalkResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "xla/backends/cpu/codegen/tiled/transforms/passes.h"
@ -50,6 +52,38 @@ namespace xla::cpu {
namespace {
// Returns true if any dynamic position operand of `extract_op` is `value`
// itself, or if any operation in that operand's backward slice takes `value`
// as an operand. Returns false otherwise (including when the extract has no
// dynamic positions).
bool ExtractDependsOnValue(mlir::vector::ExtractOp extract_op,
                           mlir::Value value) {
  mlir::BackwardSliceOptions slice_options;
  slice_options.inclusive = true;
  slice_options.omitUsesFromAbove = false;

  for (mlir::Value position : extract_op.getDynamicPosition()) {
    // The slice computation only starts from the defining operation, so a
    // position that *is* `value` must be detected by direct comparison.
    if (position == value) {
      return true;
    }

    llvm::SetVector<mlir::Operation*> producers;
    // A failed slice computation is not treated as a dependency; just move on
    // to the next position operand.
    if (mlir::succeeded(
            mlir::getBackwardSlice(position, &producers, slice_options))) {
      for (mlir::Operation* producer : producers) {
        if (llvm::is_contained(producer->getOperands(), value)) {
          return true;
        }
      }
    }
  }

  return false;
}
struct FoldExtractIntoTransferRead
: mlir::OpRewritePattern<mlir::vector::ExtractOp> {
using OpRewritePattern::OpRewritePattern;
@ -197,23 +231,27 @@ struct UnrollExtractLoops : mlir::OpRewritePattern<mlir::scf::ForOp> {
using OpRewritePattern::OpRewritePattern;
mlir::LogicalResult matchAndRewrite(
mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const override {
if (op.getRegion().getOps<mlir::vector::ExtractOp>().empty()) {
return rewriter.notifyMatchFailure(op,
"loop does not contain an extract");
}
llvm::SetVector<mlir::Operation*> slices;
mlir::getForwardSlice(op.getInductionVar(), &slices);
for (auto slice : slices) {
if (mlir::isa<mlir::vector::ExtractOp>(slice)) {
return mlir::loopUnrollFull(op);
mlir::scf::ForOp for_op, mlir::PatternRewriter& rewriter) const override {
std::optional<mlir::LogicalResult> unroll_result;
// Walk the body of the loop and unroll if we have a dependent extract.
for_op.getBody()->walk([&](mlir::vector::ExtractOp extract) {
if (!ExtractDependsOnValue(extract, for_op.getInductionVar())) {
return mlir::WalkResult::advance();
}
unroll_result = mlir::loopUnrollFull(for_op);
return mlir::WalkResult::interrupt();
});
if (!unroll_result.has_value()) {
return rewriter.notifyMatchFailure(
for_op, "loop does not contain a dependent extract");
}
return rewriter.notifyMatchFailure(
op, "loop does not contain a dependent extract");
if (mlir::failed(*unroll_result)) {
return rewriter.notifyMatchFailure(for_op, "failed to unroll loop");
}
return mlir::success();
}
};

View File

@ -102,3 +102,28 @@ func.func @does_not_unroll_independent_vector_extract(%input: vector<8x2xf32>, %
// CHECK-LABEL: func.func @does_not_unroll_independent_vector_extract(
// CHECK: scf.for
// CHECK-COUNT-1: vector.extract
// -----
// Ensure that we don't unroll loops in which the extracted vector depends on
// the loop induction variable but the extract index itself does not.
func.func @does_not_unroll_only_dependent_vector(
    %input: memref<8x2xf32>, %arg_index: index) -> vector<2xf32> {
  // Loop runs from 0 to 8 in steps of 2; the accumulator starts at zero.
  %zero_idx = arith.constant 0 : index
  %stride = arith.constant 2 : index
  %limit = arith.constant 8 : index
  %zero_f32 = arith.constant 0. : f32
  %acc_init = vector.broadcast %zero_f32 : f32 to vector<2xf32>
  %sum = scf.for %iv = %zero_idx to %limit step %stride iter_args(%acc = %acc_init) -> vector<2xf32> {
    // The vector being extracted from is read at an offset given by the
    // induction variable, so the vector depends on the loop...
    %tile = vector.transfer_read %input[%iv, %zero_idx], %zero_f32 : memref<8x2xf32>, vector<2x2xf32>
    // ...but the extract position is the static index 0, which does not.
    %row = vector.extract %tile[0] : vector<2xf32> from vector<2x2xf32>
    %next = arith.addf %acc, %row : vector<2xf32>
    scf.yield %next : vector<2xf32>
  }
  return %sum : vector<2xf32>
}
// CHECK-LABEL: func.func @does_not_unroll_only_dependent_vector(
// CHECK: scf.for
// CHECK-COUNT-1: vector.extract