commit dfeccf211b (parent 9add8b7e61)

[XLA:CPU][XTile] Implement pass to rewrite dynamic vector extracts to static.

PiperOrigin-RevId: 824427163
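In sketch form, the core rewrite folds a dynamically indexed vector.extract
into its defining vector.transfer_read by adding the extract position to the
read offsets. A minimal unmasked sketch (hypothetical IR, simplified from the
new test file; the real patterns also handle masks and multi-index positions):

  // Before: read the whole tile, then extract row %i at a dynamic position.
  func.func @before(%buf: memref<8x2xf32>, %i: index) -> vector<2xf32> {
    %c0 = arith.constant 0 : index
    %pad = arith.constant 0.0 : f32
    %tile = vector.transfer_read %buf[%c0, %c0], %pad
        : memref<8x2xf32>, vector<8x2xf32>
    %row = vector.extract %tile[%i] : vector<2xf32> from vector<8x2xf32>
    return %row : vector<2xf32>
  }

  // After: the dynamic index becomes a read offset and the extract is gone.
  func.func @after(%buf: memref<8x2xf32>, %i: index) -> vector<2xf32> {
    %c0 = arith.constant 0 : index
    %pad = arith.constant 0.0 : f32
    %row = vector.transfer_read %buf[%i, %c0], %pad
        : memref<8x2xf32>, vector<2xf32>
    return %row : vector<2xf32>
  }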
third_party/xla/xla/backends/cpu/codegen/tiled/transforms/BUILD

@@ -48,6 +48,7 @@ cc_library(
     srcs = [
         "elemental_tensor_to_vector.cc",
         "lower_xtile_entry.cc",
+        "rewrite_dynamic_vector_extract.cc",
         "shlo_to_vector.cc",
         "tensor_ops_to_vector.cc",
         "xtile_to_vector.cc",
@@ -63,6 +64,7 @@ cc_library(
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:string_view",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:ArithOpsIncGen",
         "@llvm-project//mlir:DataLayoutInterfaces",
@@ -75,6 +77,7 @@ cc_library(
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:SCFUtils",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:TransformUtils",
third_party/xla/xla/backends/cpu/codegen/tiled/transforms/passes.h

@@ -38,6 +38,7 @@ std::unique_ptr<mlir::Pass> CreateLowerXTileEntryPass();
 std::unique_ptr<mlir::Pass> CreateShloToVectorPass();
 std::unique_ptr<mlir::Pass> CreateXTileToVectorPass();
 std::unique_ptr<mlir::Pass> CreateTensorOpsToVectorPass();
+std::unique_ptr<mlir::Pass> CreateRewriteDynamicVectorExtractPass();

 #define GEN_PASS_REGISTRATION
 #include "xla/backends/cpu/codegen/tiled/transforms/passes.h.inc"
third_party/xla/xla/backends/cpu/codegen/tiled/transforms/passes.td

@@ -76,3 +76,15 @@ def TensorOpsToVectorPass : Pass<"xtile-cpu-tensor-ops-to-vector",
     "mlir::vector::VectorDialect",
   ];
 }
+
+def RewriteDynamicVectorExtractPass : Pass<"xtile-cpu-rewrite-dynamic-vector-extract",
+                                           "mlir::ModuleOp"> {
+  let summary = "Rewrite vector.extracts with dynamic indices.";
+
+  let constructor = "CreateRewriteDynamicVectorExtractPass()";
+
+  let dependentDialects = [
+    "::mlir::vector::VectorDialect",
+    "::mlir::memref::MemRefDialect",
+  ];
+}
third_party/xla/xla/backends/cpu/codegen/tiled/transforms/rewrite_dynamic_vector_extract.cc (new file, 262 lines)

@@ -0,0 +1,262 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "xla/backends/cpu/codegen/tiled/transforms/passes.h"

namespace xla::cpu {

#define GEN_PASS_DECL_REWRITEDYNAMICVECTOREXTRACTPASS
#define GEN_PASS_DEF_REWRITEDYNAMICVECTOREXTRACTPASS
#include "xla/backends/cpu/codegen/tiled/transforms/passes.h.inc"

namespace {

struct FoldExtractIntoTransferRead
    : mlir::OpRewritePattern<mlir::vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::vector::ExtractOp op,
      mlir::PatternRewriter& rewriter) const override {
    if (!op.hasDynamicPosition()) {
      return rewriter.notifyMatchFailure(
          op, "extract does not have dynamic position");
    }

    auto transfer_read_op =
        op.getSource().getDefiningOp<mlir::vector::TransferReadOp>();

    if (!transfer_read_op) {
      return rewriter.notifyMatchFailure(op,
                                         "source is not a transfer_read op");
    }
    auto vector_type = mlir::dyn_cast<mlir::VectorType>(op.getType());
    if (!vector_type) {
      // TODO(willfroom): Support scalar types.
      return rewriter.notifyMatchFailure(op, "Output is not a vector type");
    }

    mlir::ValueRange transfer_read_indices = transfer_read_op.getIndices();

    llvm::SmallVector<mlir::OpFoldResult> extended_positions(
        op.getMixedPosition());
    for (int64_t idx = extended_positions.size();
         idx < transfer_read_indices.size(); ++idx) {
      extended_positions.push_back(rewriter.getIndexAttr(0));
    }

    llvm::SmallVector<mlir::Value> new_offsets;
    new_offsets.reserve(transfer_read_indices.size());
    for (auto [tile_offset, extract_offset] :
         llvm::zip(transfer_read_indices, extended_positions)) {
      if (auto static_position =
              mlir::dyn_cast<mlir::Attribute>(extract_offset)) {
        new_offsets.push_back(mlir::arith::AddIOp::create(
            rewriter, op.getLoc(), rewriter.getIndexType(), tile_offset,
            mlir::arith::ConstantIndexOp::create(
                rewriter, op.getLoc(),
                mlir::cast<mlir::IntegerAttr>(static_position).getInt())));
      } else {
        auto dynamic_position = mlir::dyn_cast<mlir::Value>(extract_offset);
        new_offsets.push_back(mlir::arith::AddIOp::create(
            rewriter, op.getLoc(), rewriter.getIndexType(), tile_offset,
            dynamic_position));
      }
    }

    mlir::Value submask;
    if (auto mask = transfer_read_op.getMask()) {
      submask = mlir::vector::ExtractOp::create(rewriter, op.getLoc(), mask,
                                                op.getMixedPosition());
    }

    int64_t rank = transfer_read_op.getBase().getType().getRank();

    // Drop major dimensions, which reflects the behaviour of
    // vector::ExtractOp.
    int64_t num_dropped_dims = rank - vector_type.getRank();
    mlir::AffineMap new_permutation_map =
        mlir::AffineMap::getFilteredIdentityMap(
            rewriter.getContext(), rank, [&](mlir::AffineDimExpr expr) {
              return expr.getPosition() >= num_dropped_dims;
            });

    llvm::SmallVector<mlir::Attribute> in_bounds(
        transfer_read_op.getInBounds().begin() + num_dropped_dims,
        transfer_read_op.getInBounds().end());

    rewriter.replaceOpWithNewOp<mlir::vector::TransferReadOp>(
        op, vector_type, transfer_read_op.getBase(), new_offsets,
        new_permutation_map, transfer_read_op.getPadding(), submask,
        rewriter.getArrayAttr(in_bounds));

    return mlir::success();
  }
};

// FoldExtractIntoTransferRead creates its own dynamic extracts if a mask is
// present, so we need to fold these as well. We do this by shifting the mask
// bounds by the extract offset and then extracting with static zero indices.
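// For illustration (hypothetical IR): extracting row %p from
//   %m = vector.create_mask %b0, %b1 : vector<8x2xi1>
// becomes
//   %s  = arith.subi %b0, %p : index
//   %m2 = vector.create_mask %s, %b1 : vector<8x2xi1>
//   %e  = vector.extract %m2[0] : vector<2xi1> from vector<8x2xi1>
// Row %p of %m is active iff %p < %b0, which matches row 0 of %m2 being
// active iff 0 < %b0 - %p.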
struct FoldExtractIntoCreateMask
    : mlir::OpRewritePattern<mlir::vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::vector::ExtractOp op,
      mlir::PatternRewriter& rewriter) const override {
    if (!op.hasDynamicPosition()) {
      return rewriter.notifyMatchFailure(
          op, "extract does not have dynamic position");
    }
    auto mask_op = op.getSource().getDefiningOp<mlir::vector::CreateMaskOp>();
    if (!mask_op) {
      return rewriter.notifyMatchFailure(op, "source is not a create_mask op");
    }

    mlir::ValueRange mask_operands = mask_op.getOperands();

    llvm::SmallVector<mlir::OpFoldResult> extended_positions(
        op.getMixedPosition());
    for (int64_t idx = extended_positions.size(); idx < mask_operands.size();
         ++idx) {
      extended_positions.push_back(rewriter.getIndexAttr(0));
    }

    llvm::SmallVector<mlir::Value> new_bounds;
    new_bounds.reserve(mask_operands.size());
    for (auto [mask_bound, extract_offset] :
         llvm::zip(mask_operands, extended_positions)) {
      if (auto static_position =
              mlir::dyn_cast<mlir::Attribute>(extract_offset)) {
        new_bounds.push_back(mlir::arith::SubIOp::create(
            rewriter, op.getLoc(), rewriter.getIndexType(), mask_bound,
            mlir::arith::ConstantIndexOp::create(
                rewriter, op.getLoc(),
                mlir::cast<mlir::IntegerAttr>(static_position).getInt())));
      } else {
        auto dynamic_position = mlir::dyn_cast<mlir::Value>(extract_offset);
        new_bounds.push_back(mlir::arith::SubIOp::create(
            rewriter, op.getLoc(), rewriter.getIndexType(), mask_bound,
            dynamic_position));
      }
    }

    auto shifted_mask = mlir::vector::CreateMaskOp::create(
        rewriter, op.getLoc(), mask_op.getType(), new_bounds);

    llvm::SmallVector<int64_t> zero_index(op.getMixedPosition().size(), 0);

    rewriter.replaceOpWithNewOp<mlir::vector::ExtractOp>(op, shifted_mask,
                                                         zero_index);

    return mlir::success();
  }
};

// Unroll loops that contain a vector.extract that depends on the loop
// induction variable.
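// For illustration (hypothetical IR): in
//   scf.for %i = %c0 to %c2 step %c1 { ... vector.extract %v[%i] ... }
// full unrolling leaves one body copy per iteration, so each extract ends up
// with a constant index (%v[0], %v[1]) that canonicalization can then fold
// into a static position.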
struct UnrollExtractLoops : mlir::OpRewritePattern<mlir::scf::ForOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const override {
    if (op.getRegion().getOps<mlir::vector::ExtractOp>().empty()) {
      return rewriter.notifyMatchFailure(op,
                                         "loop does not contain an extract");
    }

    llvm::SetVector<mlir::Operation*> slices;
    mlir::getForwardSlice(op.getInductionVar(), &slices);

    for (auto slice : slices) {
      if (mlir::isa<mlir::vector::ExtractOp>(slice)) {
        return mlir::loopUnrollFull(op);
      }
    }

    return rewriter.notifyMatchFailure(
        op, "loop does not contain a dependent extract");
  }
};

class RewriteDynamicVectorExtractPass
    : public impl::RewriteDynamicVectorExtractPassBase<
          RewriteDynamicVectorExtractPass> {
 public:
  using RewriteDynamicVectorExtractPassBase::
      RewriteDynamicVectorExtractPassBase;

  void runOnOperation() override {
    mlir::ModuleOp module = getOperation();
    mlir::MLIRContext* context = &getContext();

    {
      mlir::RewritePatternSet patterns(context);
      patterns.add<FoldExtractIntoTransferRead, FoldExtractIntoCreateMask>(
          context);
      if (mlir::failed(
              mlir::applyPatternsGreedily(module, std::move(patterns)))) {
        signalPassFailure();
        return;
      }
    }

    // As a final sledgehammer, unroll any loops that still contain dependent
    // extracts.
    {
      mlir::RewritePatternSet patterns(context);
      patterns.add<UnrollExtractLoops>(context);
      if (mlir::failed(
              mlir::applyPatternsGreedily(module, std::move(patterns)))) {
        signalPassFailure();
        return;
      }
    }
  }
};

}  // namespace

std::unique_ptr<mlir::Pass> CreateRewriteDynamicVectorExtractPass() {
  return std::make_unique<RewriteDynamicVectorExtractPass>();
}

}  // namespace xla::cpu
third_party/xla/xla/backends/cpu/codegen/tiled/transforms/tests/rewrite_dynamic_vector_extract.mlir (new file, 104 lines)

@@ -0,0 +1,104 @@
// RUN: emitters_opt %s \
// RUN:   -xtile-cpu-rewrite-dynamic-vector-extract -canonicalize \
// RUN:   -split-input-file | FileCheck %s

func.func @fold_vector_extract_into_transfer_read(
    %buffer: memref<8x4x2xf32>,
    %idx0: index,
    %idx1: index) -> vector<2xf32> {
  %c0 = arith.constant 0 : index
  %c0_f32 = arith.constant 0.0 : f32
  %c1 = arith.constant 1 : index
  %c3 = arith.constant 3 : index
  %c7 = arith.constant 7 : index
  %mask = vector.create_mask %c7, %c3, %c1 : vector<8x4x2xi1>
  %original_vector = vector.transfer_read %buffer[%c0, %c0, %c0],
      %c0_f32, %mask : memref<8x4x2xf32>, vector<8x4x2xf32>
  %subvector = vector.extract %original_vector[%idx0, %idx1]
      : vector<2xf32> from vector<8x4x2xf32>
  return %subvector : vector<2xf32>
}

// CHECK:      func.func @fold_vector_extract_into_transfer_read(
// CHECK-SAME:     %[[BUFFER:.*]]: memref<8x4x2xf32>,
// CHECK-SAME:     %[[IDX0:.*]]: index,
// CHECK-SAME:     %[[IDX1:.*]]: index) -> vector<2xf32> {
// CHECK-DAG:    %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:    %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:    %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:    %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG:    %[[C7:.*]] = arith.constant 7 : index
// CHECK:        %[[SHIFT_IDX0:.*]] = arith.subi %[[C7]], %[[IDX0]] : index
// CHECK:        %[[SHIFT_SUBIDX1:.*]] = arith.subi %[[C3]], %[[IDX1]] : index
// CHECK:        %[[SHIFT_MASK:.*]] = vector.create_mask
// CHECK-SAME:       %[[SHIFT_IDX0]], %[[SHIFT_SUBIDX1]], %[[C1]] : vector<8x4x2xi1>
// CHECK:        %[[SUBMASK:.*]] = vector.extract %[[SHIFT_MASK]][0, 0]
// CHECK-SAME:       : vector<2xi1> from vector<8x4x2xi1>
// CHECK:        %[[SUBVECTOR:.*]] = vector.transfer_read
// CHECK-SAME:       %[[BUFFER]][%[[IDX0]], %[[IDX1]], %[[C0]]], %[[PAD]], %[[SUBMASK]]
// CHECK-SAME:       {in_bounds = [true]} : memref<8x4x2xf32>, vector<2xf32>
// CHECK:        return %[[SUBVECTOR]] : vector<2xf32>
// CHECK:      }

// -----

func.func @unroll_dependent_vector_extract(%input: vector<8x2xf32>) -> vector<2xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c8 = arith.constant 8 : index
  %c0_f32 = arith.constant 0. : f32
  %init = vector.broadcast %c0_f32 : f32 to vector<2xf32>
  %result = scf.for %index = %c0 to %c8 step %c1 iter_args(%carry = %init) -> vector<2xf32> {
    %extract = vector.extract %input[%index] : vector<2xf32> from vector<8x2xf32>
    %add = arith.addf %carry, %extract : vector<2xf32>
    scf.yield %add : vector<2xf32>
  }
  return %result : vector<2xf32>
}

// CHECK-LABEL: func.func @unroll_dependent_vector_extract(
// CHECK-NOT:     scf.for
// CHECK-COUNT-8: vector.extract

// -----

func.func @unroll_indirect_dependent_vector_extract(%input: vector<8x2xf32>) -> vector<2xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c4 = arith.constant 4 : index
  %c0_f32 = arith.constant 0. : f32
  %init = vector.broadcast %c0_f32 : f32 to vector<2xf32>
  %result = scf.for %index = %c0 to %c4 step %c1 iter_args(%carry = %init) -> vector<2xf32> {
    %strided_index = arith.muli %index, %c2 : index
    %extract = vector.extract %input[%strided_index] : vector<2xf32> from vector<8x2xf32>
    %add = arith.addf %carry, %extract : vector<2xf32>
    scf.yield %add : vector<2xf32>
  }
  return %result : vector<2xf32>
}

// CHECK-LABEL: func.func @unroll_indirect_dependent_vector_extract(
// CHECK-NOT:     scf.for
// CHECK-COUNT-4: vector.extract

// -----

func.func @does_not_unroll_independent_vector_extract(%input: vector<8x2xf32>, %arg_index: index) -> vector<2xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c8 = arith.constant 8 : index
  %c0_f32 = arith.constant 0. : f32
  %init = vector.broadcast %c0_f32 : f32 to vector<2xf32>
  %result = scf.for %index = %c0 to %c8 step %c1 iter_args(%carry = %init) -> vector<2xf32> {
    %extract = vector.extract %input[%arg_index] : vector<2xf32> from vector<8x2xf32>
    %add = arith.addf %carry, %extract : vector<2xf32>
    scf.yield %add : vector<2xf32>
  }
  return %result : vector<2xf32>
}

// CHECK-LABEL: func.func @does_not_unroll_independent_vector_extract(
// CHECK:         scf.for
// CHECK-COUNT-1: vector.extract