[XLA:CPU][XTile] Add support for strided extract/insert tile.

PiperOrigin-RevId: 822035319
Will Froom, 2025-10-21 03:22:48 -07:00, committed by TensorFlower Gardener
parent e756c21611, commit 373abf8de1
5 changed files with 194 additions and 42 deletions
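As a quick reference for the semantics being added: a strided extract reads element i of the tile from source[offset + i * stride] along each dimension, and a strided insert is the inverse scatter-style write. A minimal 1D numpy sketch of those semantics (extract_tile/insert_tile are illustrative names, not part of the xtile API):

    import numpy as np

    # Hypothetical 1D reference semantics for strided tile extract/insert:
    # tile[i] = source[offset + i * stride]; insert is the inverse write.
    def extract_tile(source, offset, size, stride):
        return source[offset : offset + size * stride : stride].copy()

    def insert_tile(tile, dest, offset, size, stride):
        dest[offset : offset + size * stride : stride] = tile

    src = np.arange(16, dtype=np.float32)
    tile = extract_tile(src, offset=1, size=8, stride=2)  # elements 1, 3, ..., 15
    out = np.zeros(16, dtype=np.float32)
    insert_tile(tile, out, offset=1, size=8, stride=2)
    assert (out[1::2] == src[1::2]).all()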


@@ -255,6 +255,7 @@ static void AddTiledLoweringPasses(mlir::OpPassManager& pm) {
   pm.addPass(mlir::createConvertVectorToLLVMPass());
   pm.addPass(mlir::createConvertComplexToStandardPass());
+  pm.addPass(mlir::memref::createExpandStridedMetadataPass());
   AddGenericLoweringPasses(pm);
 }


@@ -98,6 +98,30 @@ class XtileLoweringTest(absltest.TestCase):
         lambda arg: arg.transpose(),
     )

+  def test_strided(self):
+    ir = """
+      module @tiled_slice {
+        xtile.entry_func @tiled_slice(
+            %input: memref<64x64xf32>,
+            %output: memref<4x32xf32>,
+            %tile_id: index) attributes {xtile.tiling_info = #xtile.tiling_info<tile_count:1, tiles_per_workgroup:1>} {
+          %input_tile = xtile.extract %input[%tile_id, %tile_id][4, 32][21, 2] : memref<64x64xf32> -> tensor<4x32xf32>
+          xtile.insert %input_tile into %output[%tile_id, %tile_id][4, 32][1, 1] : tensor<4x32xf32> -> memref<4x32xf32>
+          xtile.return
+        }
+      }
+    """
+    compare_kernel(
+        ir,
+        "tiled_slice",
+        1,
+        [(64, 64)],
+        (4, 32),
+        np.float32,
+        lambda arg: arg[::21, ::2],
+    )
+
   def test_transpose(self):
     ir = """
       module @tiled_transpose {
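A sanity check on the tile shape in test_strided above: a stride of 21 over 64 rows visits indices 0, 21, 42, 63 (four rows, since ceil(64 / 21) = 4), and a stride of 2 over 64 columns visits 32 columns, matching the 4x32 tile. A quick numpy sketch of the reference slice the test compares against:

    import numpy as np

    arg = np.arange(64 * 64, dtype=np.float32).reshape(64, 64)
    expected = arg[::21, ::2]  # same reference slice as the test's lambda
    assert expected.shape == (4, 32)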


@@ -72,11 +72,13 @@ cc_library(
         "@llvm-project//mlir:LLVMDialect",
         "@llvm-project//mlir:MathDialect",
         "@llvm-project//mlir:MathOpsIncGen",
+        "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:UBDialect",
         "@llvm-project//mlir:VectorDialect",
         "@stablehlo//:stablehlo_ops",
     ],


@@ -1,30 +1,58 @@
-// RUN: emitters_opt %s --xtile-cpu-xtile-to-vector -split-input-file | FileCheck %s
+// RUN: emitters_opt %s --xtile-cpu-xtile-to-vector -cse -split-input-file | FileCheck %s

 // CHECK-LABEL: @simple_insert_extract
 // CHECK-SAME: (%[[INPUT:.*]]: memref<1024xf32>, %[[OUTPUT:.*]]: memref<1024xf32>, %[[TILE_ID:.*]]: index)
 xtile.entry_func @simple_insert_extract(%input: memref<1024xf32>, %output: memref<1024xf32>, %tile_id: index) {
   // CHECK-DAG: %[[POISON:.*]] = ub.poison : f32
-  // CHECK: %[[EXTRACT:.*]] = vector.transfer_read %[[INPUT]][%[[TILE_ID]]], %[[POISON]] : memref<1024xf32>, vector<1xf32>
+  // CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IN_SUBVIEW:.*]] = memref.subview %[[INPUT]][%[[TILE_ID]]] [1] [1]
+  // CHECK-SAME: memref<1024xf32> to memref<1xf32, strided<[1], offset: ?>>
+  // CHECK: %[[MASK:.*]] = vector.create_mask
+  // CHECK: %[[EXTRACT:.*]] = vector.transfer_read %[[IN_SUBVIEW]][%[[C_0]]], %[[POISON]], %[[MASK]]
   %tile = xtile.extract %input[%tile_id][1][1] : memref<1024xf32> -> tensor<1xf32>
-  // CHECK: vector.transfer_write %[[EXTRACT]], %[[OUTPUT]][%[[TILE_ID]]] : vector<1xf32>, memref<1024xf32>
+  // CHECK: %[[OUT_SUBVIEW:.*]] = memref.subview %[[OUTPUT]][%[[TILE_ID]]] [1] [1]
+  // CHECK-SAME: memref<1024xf32> to memref<1xf32, strided<[1], offset: ?>>
+  // CHECK: vector.transfer_write %[[EXTRACT]], %[[OUT_SUBVIEW]][%[[C_0]]], %[[MASK]]
   xtile.insert %tile into %output[%tile_id][1][1] : tensor<1xf32> -> memref<1024xf32>
   xtile.return
 }

 // -----

-// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0)>
 // CHECK: @reduce_dimension(%[[INPUT:.*]]: memref<16x1024xf32>, %[[OUTPUT:.*]]: memref<16x1024xf32>, %[[TILE_ID:.*]]: index)
 xtile.entry_func @reduce_dimension(%input: memref<16x1024xf32>, %output: memref<16x1024xf32>, %tile_id: index) {
-  // CHECK: %[[OFFSET:.*]] = arith.constant 0 : index
+  // CHECK: %[[C_0:.*]] = arith.constant 0 : index
   %offset = arith.constant 0 : index
-  // CHECK: %[[POISON:.*]] = ub.poison : f32
-  // CHECK: %[[EXTRACT:.*]] = vector.transfer_read %[[INPUT]][%[[OFFSET]], %[[TILE_ID]]], %[[POISON]] {in_bounds = [true], permutation_map = #[[MAP]]} : memref<16x1024xf32>, vector<10xf32>
-  // CHECK: vector.transfer_write %[[EXTRACT]], %[[OUTPUT]][%[[OFFSET]], %[[TILE_ID]]] {in_bounds = [true], permutation_map = #[[MAP]]} : vector<10xf32>, memref<16x1024xf32>
+  // CHECK: memref.subview %[[INPUT]][%[[C_0]], %[[TILE_ID]]] [10, 1] [1, 1]
+  // CHECK-SAME: memref<16x1024xf32> to memref<10xf32, strided<[1024], offset: ?>>
   %tile = xtile.extract %input[%offset, %tile_id][10, 1][1, 1] : memref<16x1024xf32> -> tensor<10xf32>
+  // CHECK: memref.subview %[[OUTPUT]][%[[C_0]], %[[TILE_ID]]] [10, 1] [1, 1]
+  // CHECK-SAME: memref<16x1024xf32> to memref<10xf32, strided<[1024], offset: ?>>
   xtile.insert %tile into %output[%offset, %tile_id][10, 1][1, 1] : tensor<10xf32> -> memref<16x1024xf32>
   xtile.return
 }

 // -----

+// CHECK: @extract_strided(%[[SOURCE:.*]]: memref<16xf32>, %[[TILE_ID:.*]]: index)
+func.func @extract_strided(%source: memref<16xf32>, %tile_id: index) -> tensor<8xf32> {
+  // CHECK: memref.subview %[[SOURCE]][%[[TILE_ID]]] [8] [2] :
+  // CHECK-SAME: memref<16xf32> to memref<8xf32, strided<[2], offset: ?>>
+  %tile = xtile.extract %source[%tile_id][8][2] : memref<16xf32> -> tensor<8xf32>
+  return %tile : tensor<8xf32>
+}
+
+// -----
+
+// CHECK: @insert_strided(
+// CHECK-SAME: %[[SOURCE:.*]]: tensor<8xf32>,
+// CHECK-SAME: %[[DESTINATION:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[TILE_ID:.*]]: index)
+func.func @insert_strided(%source: tensor<8xf32>, %destination: memref<16xf32>, %tile_id: index) {
+  // CHECK: memref.subview %[[DESTINATION]][%[[TILE_ID]]] [8] [2] :
+  // CHECK-SAME: memref<16xf32> to memref<8xf32, strided<[2], offset: ?>>
+  xtile.insert %source into %destination[%tile_id][8][2] : tensor<8xf32> -> memref<16xf32>
+  return
+}
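In the strided layouts checked above, element i of the subview lives at underlying index offset + i * stride, which is why the lowered transfer ops can index the subview at zero. A numpy sketch of that addressing for the @extract_strided case, using as_strided purely for illustration:

    import numpy as np
    from numpy.lib.stride_tricks import as_strided

    base = np.arange(16, dtype=np.float32)
    tile_id = 1
    # Emulate memref.subview %source[%tile_id] [8] [2]: 8 elements with
    # stride 2 starting at tile_id, i.e. subview[i] == base[tile_id + 2 * i].
    itemsize = base.itemsize
    subview = as_strided(base[tile_id:], shape=(8,), strides=(2 * itemsize,))
    assert (subview == base[tile_id : tile_id + 16 : 2]).all()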


@@ -14,25 +14,38 @@ limitations under the License.
 ==============================================================================*/

 #include <cassert>
+#include <cstdint>
 #include <memory>
 #include <optional>
 #include <utility>

+#include "absl/algorithm/container.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // IWYU pragma: keep
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/backends/cpu/codegen/tiled/transforms/lowering_utils.h"
 #include "xla/backends/cpu/codegen/tiled/transforms/passes.h"
 #include "xla/codegen/xtile/ir/xtile_ops.h"
@@ -44,12 +57,89 @@ namespace xla::cpu {

 namespace {

-mlir::AffineMap GetFilteredDims(mlir::MLIRContext* context, unsigned rank,
-                                llvm::SmallDenseSet<unsigned> reduced_dims) {
-  return mlir::AffineMap::getFilteredIdentityMap(
-      context, rank, [&reduced_dims](mlir::AffineDimExpr dim) {
-        return !reduced_dims.contains(dim.getPosition());
-      });
+// Dims are dropped in the subview so we use the identity map.
+mlir::AffineMapAttr GetIdentityMap(xtile::TiledBufferInterface op) {
+  int64_t rank = op.getTile().getType().getRank();
+  return mlir::AffineMapAttr::get(
+      mlir::AffineMap::getMultiDimIdentityMap(rank, op.getContext()));
+}
+
+mlir::TypedValue<mlir::MemRefType> GetSubView(
+    mlir::ImplicitLocOpBuilder& builder, xtile::TiledBufferInterface op) {
+  auto get_static_fold_result = [&](llvm::ArrayRef<int64_t> input) {
+    return llvm::map_to_vector(input, [&builder](int64_t value) {
+      return mlir::OpFoldResult(builder.getIndexAttr(value));
+    });
+  };
+
+  auto offsets = llvm::SmallVector<mlir::OpFoldResult>(op.getOffsets());
+  auto full_tile_shape = get_static_fold_result(op.getFullTileShape());
+  auto strides = get_static_fold_result(op.getStrides());
+
+  mlir::MemRefType subview_type =
+      mlir::memref::SubViewOp::inferRankReducedResultType(
+          op.getTile().getType().getShape(), op.getBuffer().getType(),
+          offsets, full_tile_shape, get_static_fold_result(op.getStrides()));
+
+  return builder.create<mlir::memref::SubViewOp>(
+      subview_type, op.getBuffer(), offsets, full_tile_shape, strides);
+}
+
+llvm::SmallVector<mlir::Value> GetZeroIndexVector(
+    mlir::ImplicitLocOpBuilder& builder, int64_t rank) {
+  return llvm::SmallVector<mlir::Value>(
+      rank, builder.create<mlir::arith::ConstantIndexOp>(0));
+}
+
+mlir::ArrayAttr GetInBoundsAttr(mlir::ImplicitLocOpBuilder& builder,
+                                int64_t rank) {
+  // TODO(willfroom): Add proper support for inBounds attr.
+  llvm::SmallVector<mlir::Attribute> in_bounds(rank,
+                                               builder.getBoolAttr(false));
+  return builder.getArrayAttr(in_bounds);
+}
+
+// Get the mask for the given transfer_<read/write> op on a subview of the
+// original memref.
+// The inequality we need to satisfy in 1D is:
+//   1. offset + subview_idx * stride <= size - 1
+//   2. subview_idx * stride <= size - 1 - offset
+//   3. subview_idx <= (size - 1 - offset) / stride
+//   4. subview_idx < ((size - 1 - offset) / stride) + 1
+//   5. subview_idx < (size + stride - 1 - offset) / stride
+mlir::Value GetMask(mlir::ImplicitLocOpBuilder& builder,
+                    xtile::TiledBufferInterface op) {
+  mlir::RankedTensorType tile_tensor_type = op.getTile().getType();
+
+  auto get_const_index_op = [&](int64_t value) {
+    return builder.create<mlir::arith::ConstantIndexOp>(value);
+  };
+
+  if (tile_tensor_type.getRank() == 0) {
+    // Vector transfer read/write currently don't support 0D masks.
+    auto mask_0D_type = mlir::VectorType::get({1}, builder.getI1Type());
+    return builder.create<mlir::vector::CreateMaskOp>(
+        mask_0D_type, mlir::OpFoldResult(builder.getIndexAttr(1)));
+  }
+
+  llvm::SmallDenseSet<unsigned> reduced_dims = op.getReducedDimensions();
+  llvm::SmallVector<mlir::Value> upper_bounds;
+  int64_t idx = 0;
+  for (auto [offset, size, stride] :
+       llvm::zip(op.getOffsets(), op.getBuffer().getType().getShape(),
+                 op.getStrides())) {
+    if (reduced_dims.contains(idx++)) {
+      continue;
+    }
+    upper_bounds.push_back(builder.create<mlir::arith::DivSIOp>(
+        builder.create<mlir::arith::SubIOp>(
+            get_const_index_op(size + stride - 1), offset),
+        get_const_index_op(stride)));
+  }
+
+  auto mask_type = mlir::VectorType::get(op.getTile().getType().getShape(),
+                                         builder.getI1Type());
+  return builder.create<mlir::vector::CreateMaskOp>(mask_type, upper_bounds);
 }

 struct LowerExtractTile : mlir::OpRewritePattern<xtile::ExtractTileOp> {
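To make the bound in GetMask concrete: for a dimension of size elements, a subview starting at offset with the given stride has (size + stride - 1 - offset) / stride in-bounds lanes, which is exactly step 5 of the derivation above (GetMask uses signed division; for the non-negative operands here, Python's floor division agrees). A small sketch checking the formula, not the pass itself:

    # Lane i touches element offset + i * stride; it is in bounds iff
    # offset + i * stride <= size - 1, so the count of valid lanes is
    # (size + stride - 1 - offset) // stride (step 5 above).
    def valid_lanes(size: int, offset: int, stride: int) -> int:
        return (size + stride - 1 - offset) // stride

    assert valid_lanes(size=16, offset=1, stride=2) == 8    # indices 1, 3, ..., 15
    assert valid_lanes(size=64, offset=0, stride=21) == 4   # indices 0, 21, 42, 63
    assert valid_lanes(size=64, offset=63, stride=21) == 1  # only index 63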
@@ -57,21 +147,25 @@ struct LowerExtractTile : mlir::OpRewritePattern<xtile::ExtractTileOp> {
   mlir::LogicalResult matchAndRewrite(
       xtile::ExtractTileOp op, mlir::PatternRewriter& rewriter) const override {
-    mlir::RankedTensorType dest_tensor_type = op.getResult().getType();
-    auto vector_type = mlir::VectorType::get(dest_tensor_type.getShape(),
-                                             dest_tensor_type.getElementType());
+    mlir::ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    auto vector_type = GetVectorType(op.getResult().getType());
+
+    mlir::TypedValue<mlir::MemRefType> buffer_subview = GetSubView(builder, op);
+
+    int64_t reduced_rank = vector_type.getRank();
+    // The subview is already offset so the read has zero offsets.
+    auto zero_index = GetZeroIndexVector(builder, reduced_rank);
+    mlir::Value padding =
+        builder.create<mlir::ub::PoisonOp>(vector_type.getElementType());
+    mlir::Value mask = GetMask(builder, op);
+    auto in_bounds = GetInBoundsAttr(builder, reduced_rank);

-    // TODO(willfroom): Add support for inBounds attr.
     mlir::Value vector_value = rewriter.create<mlir::vector::TransferReadOp>(
-        op->getLoc(), vector_type, op.getSource(), op.getOffsets(),
-        /*padding=*/std::nullopt,
-        GetFilteredDims(rewriter.getContext(),
-                        op.getSource().getType().getRank(),
-                        op.getReducedDimensions()));
-    mlir::UnrealizedConversionCastOp cast =
-        rewriter.create<mlir::UnrealizedConversionCastOp>(
-            op->getLoc(), op.getResult().getType(), vector_value);
-    rewriter.replaceOp(op, cast);
+        op->getLoc(), vector_type, buffer_subview, zero_index,
+        GetIdentityMap(op), padding, mask, in_bounds);
+
+    rewriter.replaceOp(op, CastToTensor(builder, vector_value));
+
     return mlir::success();
   }
 };
@@ -81,20 +175,23 @@ struct LowerInsertTile : mlir::OpRewritePattern<xtile::InsertTileOp> {
   mlir::LogicalResult matchAndRewrite(
       xtile::InsertTileOp op, mlir::PatternRewriter& rewriter) const override {
-    mlir::RankedTensorType source_tensor_type = op.getSource().getType();
-    auto vector_type = mlir::VectorType::get(
-        source_tensor_type.getShape(), source_tensor_type.getElementType());
-    mlir::Value cast = rewriter
-                           .create<mlir::UnrealizedConversionCastOp>(
-                               op->getLoc(), vector_type, op.getSource())
-                           .getResult(0);
+    mlir::ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    mlir::TypedValue<mlir::VectorType> vector_tile =
+        CastToVector(builder, op.getSource());
+
+    mlir::TypedValue<mlir::MemRefType> buffer_subview = GetSubView(builder, op);
+
+    int64_t reduced_rank = vector_tile.getType().getRank();
+    // The subview is already offset so the write has zero offsets.
+    auto zero_index = GetZeroIndexVector(builder, reduced_rank);
+    mlir::Value mask = GetMask(builder, op);
+    auto in_bounds = GetInBoundsAttr(builder, reduced_rank);

-    // TODO(willfroom): Add support for inBounds attr.
     mlir::vector::TransferWriteOp transfer_write =
-        rewriter.create<mlir::vector::TransferWriteOp>(
-            op->getLoc(), cast, op.getDestination(), op.getOffsets(),
-            GetFilteredDims(rewriter.getContext(),
-                            op.getDestination().getType().getRank(),
-                            op.getReducedDimensions()));
+        builder.create<mlir::vector::TransferWriteOp>(
+            vector_tile, buffer_subview, zero_index, GetIdentityMap(op), mask,
+            in_bounds);
+
     rewriter.replaceOp(op, transfer_write);
     return mlir::success();
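Putting the pieces together, the new extract lowering is roughly: take a rank-reduced strided subview at the tile offsets, build a mask from the bound above, then issue a masked transfer_read at zero indices with poison padding. A hedged 1D numpy emulation of that recipe (lowered_extract is an illustrative name; nan stands in for the poison padding):

    import numpy as np

    def lowered_extract(buffer, offset, size, stride, pad=np.nan):
        # memref.subview %buffer[offset][size][stride] (may run past the end).
        # vector.create_mask: keep lanes whose source index stays in bounds.
        lanes = np.arange(size)
        mask = offset + lanes * stride <= buffer.shape[0] - 1
        # Masked transfer_read at zero indices: padding where the mask is off.
        out = np.full(size, pad, dtype=buffer.dtype)
        out[mask] = buffer[offset + lanes[mask] * stride]
        return out

    buf = np.arange(16, dtype=np.float32)
    print(lowered_extract(buf, offset=6, size=8, stride=2))
    # -> [ 6.  8. 10. 12. 14. nan nan nan]   (5 valid lanes: (16+2-1-6)//2 == 5)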