[XLA:CPU][XTile] Implement pass to rewrite dynamic vector extracts to static.
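
The pass folds vector.extract ops that have dynamic positions into their
producing vector.transfer_read / vector.create_mask ops, and fully unrolls
scf.for loops whose induction variable feeds an extract position, so that the
remaining extracts use static indices. A minimal sketch of the transfer_read
fold (types abbreviated, adapted from the included test):

  %v = vector.transfer_read %buf[%c0, %c0], %pad
      : memref<8x2xf32>, vector<8x2xf32>
  %e = vector.extract %v[%i] : vector<2xf32> from vector<8x2xf32>

becomes

  %j = arith.addi %c0, %i : index
  %e = vector.transfer_read %buf[%j, %c0], %pad
      : memref<8x2xf32>, vector<2xf32>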

PiperOrigin-RevId: 824427163
Author: Will Froom
Date: 2025-10-27 02:46:04 -07:00
Committed by: TensorFlower Gardener
Parent: 9add8b7e61
Commit: dfeccf211b

5 changed files with 382 additions and 0 deletions

xla/backends/cpu/codegen/tiled/transforms/BUILD

@@ -48,6 +48,7 @@ cc_library(
srcs = [
"elemental_tensor_to_vector.cc",
"lower_xtile_entry.cc",
"rewrite_dynamic_vector_extract.cc",
"shlo_to_vector.cc",
"tensor_ops_to_vector.cc",
"xtile_to_vector.cc",
@@ -63,6 +64,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ArithOpsIncGen",
"@llvm-project//mlir:DataLayoutInterfaces",
@@ -75,6 +77,7 @@ cc_library(
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFUtils",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",

xla/backends/cpu/codegen/tiled/transforms/passes.h

@@ -38,6 +38,7 @@ std::unique_ptr<mlir::Pass> CreateLowerXTileEntryPass();
std::unique_ptr<mlir::Pass> CreateShloToVectorPass();
std::unique_ptr<mlir::Pass> CreateXTileToVectorPass();
std::unique_ptr<mlir::Pass> CreateTensorOpsToVectorPass();
std::unique_ptr<mlir::Pass> CreateRewriteDynamicVectorExtractPass();
#define GEN_PASS_REGISTRATION
#include "xla/backends/cpu/codegen/tiled/transforms/passes.h.inc"

xla/backends/cpu/codegen/tiled/transforms/passes.td

@@ -76,3 +76,15 @@ def TensorOpsToVectorPass : Pass<"xtile-cpu-tensor-ops-to-vector",
"mlir::vector::VectorDialect",
];
}
def RewriteDynamicVectorExtractPass : Pass<"xtile-cpu-rewrite-dynamic-vector-extract",
"mlir::ModuleOp"> {
  let summary = "Rewrite vector.extract ops with dynamic indices into static extracts.";
let constructor = "CreateRewriteDynamicVectorExtractPass()";
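  let description = [{
    Folds vector.extract ops that have dynamic positions into their producing
    vector.transfer_read or vector.create_mask ops, and fully unrolls scf.for
    loops whose induction variable feeds an extract position, so that the
    remaining extracts use static indices.
  }];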
let dependentDialects = [
"::mlir::vector::VectorDialect",
"::mlir::memref::MemRefDialect",
];
}

xla/backends/cpu/codegen/tiled/transforms/rewrite_dynamic_vector_extract.cc

@@ -0,0 +1,262 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "xla/backends/cpu/codegen/tiled/transforms/passes.h"
namespace xla::cpu {
#define GEN_PASS_DECL_REWRITEDYNAMICVECTOREXTRACTPASS
#define GEN_PASS_DEF_REWRITEDYNAMICVECTOREXTRACTPASS
#include "xla/backends/cpu/codegen/tiled/transforms/passes.h.inc"
namespace {
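// Folds a vector.extract with a dynamic position into its defining
// vector.transfer_read by adding the extract position to the read offsets and
// reading the narrower result type directly. Illustrative sketch (types
// abbreviated):
//
//   %v = vector.transfer_read %buf[%o0, %o1], %pad
//       : memref<8x2xf32>, vector<8x2xf32>
//   %e = vector.extract %v[%i] : vector<2xf32> from vector<8x2xf32>
//
// becomes
//
//   %j = arith.addi %o0, %i : index
//   %e = vector.transfer_read %buf[%j, %o1], %pad
//       : memref<8x2xf32>, vector<2xf32>
//
// If the original read is masked, the mask is narrowed with a matching
// vector.extract, which FoldExtractIntoCreateMask rewrites below.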
struct FoldExtractIntoTransferRead
: mlir::OpRewritePattern<mlir::vector::ExtractOp> {
using OpRewritePattern::OpRewritePattern;
mlir::LogicalResult matchAndRewrite(
mlir::vector::ExtractOp op,
mlir::PatternRewriter& rewriter) const override {
if (!op.hasDynamicPosition()) {
return rewriter.notifyMatchFailure(
op, "extract does not have dynamic position");
}
auto transfer_read_op =
op.getSource().getDefiningOp<mlir::vector::TransferReadOp>();
if (!transfer_read_op) {
return rewriter.notifyMatchFailure(op,
"source is not a transfer_read op");
}
auto vector_type = mlir::dyn_cast<mlir::VectorType>(op.getType());
if (!vector_type) {
      // TODO(willfroom): Support scalar types.
      return rewriter.notifyMatchFailure(op, "output is not a vector type");
}
mlir::ValueRange transfer_read_indices = transfer_read_op.getIndices();
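    // Pad the extract position with zero offsets up to the rank of the read,
    // since vector.extract may extract an inner sub-vector.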
llvm::SmallVector<mlir::OpFoldResult> extended_positions(
op.getMixedPosition());
for (int64_t idx = extended_positions.size();
idx < transfer_read_indices.size(); ++idx) {
extended_positions.push_back(rewriter.getIndexAttr(0));
}
llvm::SmallVector<mlir::Value> new_offsets;
new_offsets.reserve(transfer_read_indices.size());
for (auto [tile_offset, extract_offset] :
llvm::zip(transfer_read_indices, extended_positions)) {
if (auto static_position =
mlir::dyn_cast<mlir::Attribute>(extract_offset)) {
new_offsets.push_back(mlir::arith::AddIOp::create(
rewriter, op.getLoc(), rewriter.getIndexType(), tile_offset,
mlir::arith::ConstantIndexOp::create(
rewriter, op.getLoc(),
mlir::cast<mlir::IntegerAttr>(static_position).getInt())));
} else {
auto dynamic_position = mlir::dyn_cast<mlir::Value>(extract_offset);
new_offsets.push_back(mlir::arith::AddIOp::create(
rewriter, op.getLoc(), rewriter.getIndexType(), tile_offset,
dynamic_position));
}
}
mlir::Value submask;
if (auto mask = transfer_read_op.getMask()) {
submask = mlir::vector::ExtractOp::create(rewriter, op.getLoc(), mask,
op.getMixedPosition());
}
int64_t rank = transfer_read_op.getBase().getType().getRank();
    // Drop the major dimensions, mirroring the behaviour of vector::ExtractOp.
int64_t num_dropped_dims = rank - vector_type.getRank();
mlir::AffineMap new_permutation_map =
mlir::AffineMap::getFilteredIdentityMap(
rewriter.getContext(), rank, [&](mlir::AffineDimExpr expr) {
return expr.getPosition() >= num_dropped_dims;
});
llvm::SmallVector<mlir::Attribute> in_bounds(
transfer_read_op.getInBounds().begin() + num_dropped_dims,
transfer_read_op.getInBounds().end());
rewriter.replaceOpWithNewOp<mlir::vector::TransferReadOp>(
op, vector_type, transfer_read_op.getBase(), new_offsets,
new_permutation_map, transfer_read_op.getPadding(), submask,
rewriter.getArrayAttr(in_bounds));
return mlir::success();
}
};
// FoldExtractIntoTransferRead creates its own dynamic extract of the mask
// when one is present, so we also need to fold extracts into create_mask.
// We do this by shifting the mask bounds by the extract offsets and then
// extracting at a static all-zero position.
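// Illustrative sketch for a 2-D mask; row 0 of the shifted mask equals row %i
// of the original, since element j is set iff %i < %a and j < %b in both:
//
//   %m = vector.create_mask %a, %b : vector<8x2xi1>
//   %e = vector.extract %m[%i] : vector<2xi1> from vector<8x2xi1>
//
// becomes
//
//   %a2 = arith.subi %a, %i : index
//   %m2 = vector.create_mask %a2, %b : vector<8x2xi1>
//   %e = vector.extract %m2[0] : vector<2xi1> from vector<8x2xi1>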
struct FoldExtractIntoCreateMask
: mlir::OpRewritePattern<mlir::vector::ExtractOp> {
using OpRewritePattern::OpRewritePattern;
mlir::LogicalResult matchAndRewrite(
mlir::vector::ExtractOp op,
mlir::PatternRewriter& rewriter) const override {
if (!op.hasDynamicPosition()) {
return rewriter.notifyMatchFailure(
op, "extract does not have dynamic position");
}
auto mask_op = op.getSource().getDefiningOp<mlir::vector::CreateMaskOp>();
if (!mask_op) {
return rewriter.notifyMatchFailure(op, "source is not a create_mask op");
}
mlir::ValueRange mask_operands = mask_op.getOperands();
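    // Pad the extract position with zero offsets up to the mask rank.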
llvm::SmallVector<mlir::OpFoldResult> extended_positions(
op.getMixedPosition());
for (int64_t idx = extended_positions.size(); idx < mask_operands.size();
++idx) {
extended_positions.push_back(rewriter.getIndexAttr(0));
}
llvm::SmallVector<mlir::Value> new_bounds;
new_bounds.reserve(mask_operands.size());
for (auto [mask_bound, extract_offset] :
llvm::zip(mask_operands, extended_positions)) {
if (auto static_position =
mlir::dyn_cast<mlir::Attribute>(extract_offset)) {
new_bounds.push_back(mlir::arith::SubIOp::create(
rewriter, op.getLoc(), rewriter.getIndexType(), mask_bound,
mlir::arith::ConstantIndexOp::create(
rewriter, op.getLoc(),
mlir::cast<mlir::IntegerAttr>(static_position).getInt())));
} else {
auto dynamic_position = mlir::dyn_cast<mlir::Value>(extract_offset);
new_bounds.push_back(mlir::arith::SubIOp::create(
rewriter, op.getLoc(), rewriter.getIndexType(), mask_bound,
dynamic_position));
}
}
auto shifted_mask = mlir::vector::CreateMaskOp::create(
rewriter, op.getLoc(), mask_op.getType(), new_bounds);
llvm::SmallVector<int64_t> zero_index(op.getMixedPosition().size(), 0);
rewriter.replaceOpWithNewOp<mlir::vector::ExtractOp>(op, shifted_mask,
zero_index);
return mlir::success();
}
};
// Unroll loops that contain a vector.extract whose position depends on the
// loop induction variable. Full unrolling makes the induction variable a
// constant in each iteration, so the extract positions fold to static
// indices during canonicalization.
struct UnrollExtractLoops : mlir::OpRewritePattern<mlir::scf::ForOp> {
using OpRewritePattern::OpRewritePattern;
mlir::LogicalResult matchAndRewrite(
mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const override {
if (op.getRegion().getOps<mlir::vector::ExtractOp>().empty()) {
return rewriter.notifyMatchFailure(op,
"loop does not contain an extract");
}
llvm::SetVector<mlir::Operation*> slices;
mlir::getForwardSlice(op.getInductionVar(), &slices);
for (auto slice : slices) {
if (mlir::isa<mlir::vector::ExtractOp>(slice)) {
return mlir::loopUnrollFull(op);
}
}
return rewriter.notifyMatchFailure(
op, "loop does not contain a dependent extract");
}
};
class RewriteDynamicVectorExtractPass
: public impl::RewriteDynamicVectorExtractPassBase<
RewriteDynamicVectorExtractPass> {
public:
using RewriteDynamicVectorExtractPassBase::
RewriteDynamicVectorExtractPassBase;
void runOnOperation() override {
mlir::ModuleOp module = getOperation();
mlir::MLIRContext* context = &getContext();
{
mlir::RewritePatternSet patterns(context);
patterns.add<FoldExtractIntoTransferRead, FoldExtractIntoCreateMask>(
context);
if (mlir::failed(
mlir::applyPatternsGreedily(module, std::move(patterns)))) {
signalPassFailure();
return;
}
}
    // As a final sledgehammer, fully unroll any loops whose induction
    // variable still feeds a dependent extract.
{
mlir::RewritePatternSet patterns(context);
patterns.add<UnrollExtractLoops>(context);
if (mlir::failed(
mlir::applyPatternsGreedily(module, std::move(patterns)))) {
signalPassFailure();
return;
}
}
}
};
} // namespace
std::unique_ptr<mlir::Pass> CreateRewriteDynamicVectorExtractPass() {
return std::make_unique<RewriteDynamicVectorExtractPass>();
}
} // namespace xla::cpu

xla/backends/cpu/codegen/tiled/transforms/tests/rewrite_dynamic_vector_extract.mlir

@@ -0,0 +1,104 @@
// RUN: emitters_opt %s \
// RUN: -xtile-cpu-rewrite-dynamic-vector-extract -canonicalize \
// RUN: -split-input-file | FileCheck %s
func.func @fold_vector_extract_into_transfer_read(
%buffer: memref<8x4x2xf32>,
%idx0: index,
%idx1: index) -> vector<2xf32> {
%c0 = arith.constant 0 : index
%c0_f32 = arith.constant 0.0 : f32
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
%c7 = arith.constant 7 : index
%mask = vector.create_mask %c7, %c3, %c1 : vector<8x4x2xi1>
%original_vector = vector.transfer_read %buffer[%c0, %c0, %c0],
%c0_f32, %mask : memref<8x4x2xf32>, vector<8x4x2xf32>
%subvector = vector.extract %original_vector[%idx0, %idx1]
: vector<2xf32> from vector<8x4x2xf32>
return %subvector : vector<2xf32>
}
// CHECK: func.func @fold_vector_extract_into_transfer_read(
// CHECK-SAME: %[[BUFFER:.*]]: memref<8x4x2xf32>,
// CHECK-SAME: %[[IDX0:.*]]: index,
// CHECK-SAME: %[[IDX1:.*]]: index) -> vector<2xf32> {
// CHECK-DAG: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
// CHECK: %[[SHIFT_IDX0:.*]] = arith.subi %[[C7]], %[[IDX0]] : index
// CHECK: %[[SHIFT_IDX1:.*]] = arith.subi %[[C3]], %[[IDX1]] : index
// CHECK: %[[SHIFT_MASK:.*]] = vector.create_mask
// CHECK-SAME: %[[SHIFT_IDX0]], %[[SHIFT_IDX1]], %[[C1]] : vector<8x4x2xi1>
// CHECK: %[[SUBMASK:.*]] = vector.extract %[[SHIFT_MASK]][0, 0]
// CHECK-SAME: : vector<2xi1> from vector<8x4x2xi1>
// CHECK: %[[SUBVECTOR:.*]] = vector.transfer_read
// CHECK-SAME: %[[BUFFER]][%[[IDX0]], %[[IDX1]], %[[C0]]], %[[PAD]], %[[SUBMASK]]
// CHECK-SAME: {in_bounds = [true]} : memref<8x4x2xf32>, vector<2xf32>
// CHECK: return %[[SUBVECTOR]] : vector<2xf32>
// CHECK: }
// -----
func.func @unroll_dependent_vector_extract(%input: vector<8x2xf32>) -> vector<2xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c8 = arith.constant 8 : index
%c0_f32 = arith.constant 0. : f32
%init = vector.broadcast %c0_f32 : f32 to vector<2xf32>
%result = scf.for %index = %c0 to %c8 step %c1 iter_args(%carry = %init) -> vector<2xf32> {
%extract = vector.extract %input[%index] : vector<2xf32> from vector<8x2xf32>
%add = arith.addf %carry, %extract : vector<2xf32>
scf.yield %add : vector<2xf32>
}
return %result : vector<2xf32>
}
// CHECK-LABEL: func.func @unroll_dependent_vector_extract(
// CHECK-NOT: scf.for
// CHECK-COUNT-8: vector.extract
// -----
func.func @unroll_indirect_dependent_vector_extract(%input: vector<8x2xf32>) -> vector<2xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%c0_f32 = arith.constant 0. : f32
%init = vector.broadcast %c0_f32 : f32 to vector<2xf32>
%result = scf.for %index = %c0 to %c4 step %c1 iter_args(%carry = %init) -> vector<2xf32> {
%strided_index = arith.muli %index, %c2 : index
%extract = vector.extract %input[%strided_index] : vector<2xf32> from vector<8x2xf32>
%add = arith.addf %carry, %extract : vector<2xf32>
scf.yield %add : vector<2xf32>
}
return %result : vector<2xf32>
}
// CHECK-LABEL: func.func @unroll_indirect_dependent_vector_extract(
// CHECK-NOT: scf.for
// CHECK-COUNT-4: vector.extract
// -----
func.func @does_not_unroll_independent_vector_extract(%input: vector<8x2xf32>, %arg_index: index) -> vector<2xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c8 = arith.constant 8 : index
%c0_f32 = arith.constant 0. : f32
%init = vector.broadcast %c0_f32 : f32 to vector<2xf32>
%result = scf.for %index = %c0 to %c8 step %c1 iter_args(%carry = %init) -> vector<2xf32> {
%extract = vector.extract %input[%arg_index] : vector<2xf32> from vector<8x2xf32>
%add = arith.addf %carry, %extract : vector<2xf32>
scf.yield %add : vector<2xf32>
}
return %result : vector<2xf32>
}
// CHECK-LABEL: func.func @does_not_unroll_independent_vector_extract(
// CHECK: scf.for
// CHECK-COUNT-1: vector.extract