diff --git a/third_party/xla/xla/backends/cpu/codegen/fusion_compiler.cc b/third_party/xla/xla/backends/cpu/codegen/fusion_compiler.cc
index 766d9169dc4..fd79949e60e 100644
--- a/third_party/xla/xla/backends/cpu/codegen/fusion_compiler.cc
+++ b/third_party/xla/xla/backends/cpu/codegen/fusion_compiler.cc
@@ -255,6 +255,7 @@ static void AddTiledLoweringPasses(mlir::OpPassManager& pm) {
   pm.addPass(mlir::createConvertVectorToLLVMPass());
   pm.addPass(mlir::createConvertComplexToStandardPass());
+  pm.addPass(mlir::memref::createExpandStridedMetadataPass());
 
   AddGenericLoweringPasses(pm);
 }
diff --git a/third_party/xla/xla/backends/cpu/codegen/tiled/tiled_kernel_test.py b/third_party/xla/xla/backends/cpu/codegen/tiled/tiled_kernel_test.py
index ac5995353fa..4fbb7a0eee3 100644
--- a/third_party/xla/xla/backends/cpu/codegen/tiled/tiled_kernel_test.py
+++ b/third_party/xla/xla/backends/cpu/codegen/tiled/tiled_kernel_test.py
@@ -98,6 +98,30 @@ class XtileLoweringTest(absltest.TestCase):
         lambda arg: arg.transpose(),
     )
 
+  def test_strided(self):
+    ir = """
+      module @tiled_slice {
+        xtile.entry_func @tiled_slice(
+            %input: memref<64x64xf32>,
+            %output: memref<4x32xf32>,
+            %tile_id: index) attributes {xtile.tiling_info = #xtile.tiling_info} {
+          %input_tile = xtile.extract %input[%tile_id, %tile_id][4, 32][21, 2] : memref<64x64xf32> -> tensor<4x32xf32>
+          xtile.insert %input_tile into %output[%tile_id, %tile_id][4, 32][1, 1] : tensor<4x32xf32> -> memref<4x32xf32>
+          xtile.return
+        }
+      }
+    """
+
+    compare_kernel(
+        ir,
+        "tiled_slice",
+        1,
+        [(64, 64)],
+        (4, 32),
+        np.float32,
+        lambda arg: arg[::21, ::2],
+    )
+
   def test_transpose(self):
     ir = """
       module @tiled_transpose {
diff --git a/third_party/xla/xla/backends/cpu/codegen/tiled/transforms/BUILD b/third_party/xla/xla/backends/cpu/codegen/tiled/transforms/BUILD
index 81491115e24..aa1f510ba7a 100644
--- a/third_party/xla/xla/backends/cpu/codegen/tiled/transforms/BUILD
+++ b/third_party/xla/xla/backends/cpu/codegen/tiled/transforms/BUILD
@@ -72,11 +72,13 @@ cc_library(
         "@llvm-project//mlir:LLVMDialect",
         "@llvm-project//mlir:MathDialect",
         "@llvm-project//mlir:MathOpsIncGen",
+        "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:UBDialect",
        "@llvm-project//mlir:VectorDialect",
         "@stablehlo//:stablehlo_ops",
     ],
diff --git a/third_party/xla/xla/backends/cpu/codegen/tiled/transforms/tests/xtile_to_vector.mlir b/third_party/xla/xla/backends/cpu/codegen/tiled/transforms/tests/xtile_to_vector.mlir
index 6c16cdee563..d4ba7a6a991 100644
--- a/third_party/xla/xla/backends/cpu/codegen/tiled/transforms/tests/xtile_to_vector.mlir
+++ b/third_party/xla/xla/backends/cpu/codegen/tiled/transforms/tests/xtile_to_vector.mlir
@@ -1,30 +1,58 @@
-// RUN: emitters_opt %s --xtile-cpu-xtile-to-vector -split-input-file | FileCheck %s
+// RUN: emitters_opt %s --xtile-cpu-xtile-to-vector -cse -split-input-file | FileCheck %s
 
 // CHECK-LABEL: @simple_insert_extract
 // CHECK-SAME: (%[[INPUT:.*]]: memref<1024xf32>, %[[OUTPUT:.*]]: memref<1024xf32>, %[[TILE_ID:.*]]: index)
 xtile.entry_func @simple_insert_extract(%input: memref<1024xf32>, %output: memref<1024xf32>, %tile_id: index) {
   // CHECK-DAG: %[[POISON:.*]] = ub.poison : f32
-  // CHECK: %[[EXTRACT:.*]] = vector.transfer_read %[[INPUT]][%[[TILE_ID]]], %[[POISON]] : memref<1024xf32>, vector<1xf32>
+  // CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IN_SUBVIEW:.*]] = memref.subview %[[INPUT]][%[[TILE_ID]]] [1] [1]
+  // CHECK-SAME: memref<1024xf32> to memref<1xf32, strided<[1], offset: ?>>
+  // CHECK: %[[MASK:.*]] = vector.create_mask
+  // CHECK: %[[EXTRACT:.*]] = vector.transfer_read %[[IN_SUBVIEW]][%[[C_0]]], %[[POISON]], %[[MASK]]
   %tile = xtile.extract %input[%tile_id][1][1] : memref<1024xf32> -> tensor<1xf32>
-  // CHECK: vector.transfer_write %[[EXTRACT]], %[[OUTPUT]][%[[TILE_ID]]] : vector<1xf32>, memref<1024xf32>
+  // CHECK: %[[OUT_SUBVIEW:.*]] = memref.subview %[[OUTPUT]][%[[TILE_ID]]] [1] [1]
+  // CHECK-SAME: memref<1024xf32> to memref<1xf32, strided<[1], offset: ?>>
+  // CHECK: vector.transfer_write %[[EXTRACT]], %[[OUT_SUBVIEW]][%[[C_0]]], %[[MASK]]
   xtile.insert %tile into %output[%tile_id][1][1] : tensor<1xf32> -> memref<1024xf32>
   xtile.return
 }
-
 // -----
 
-// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0)>
 // CHECK: @reduce_dimension(%[[INPUT:.*]]: memref<16x1024xf32>, %[[OUTPUT:.*]]: memref<16x1024xf32>, %[[TILE_ID:.*]]: index)
 xtile.entry_func @reduce_dimension(%input: memref<16x1024xf32>, %output: memref<16x1024xf32>, %tile_id: index) {
-  // CHECK: %[[OFFSET:.*]] = arith.constant 0 : index
+  // CHECK: %[[C_0:.*]] = arith.constant 0 : index
   %offset = arith.constant 0 : index
-  // CHECK: %[[POISON:.*]] = ub.poison : f32
-  // CHECK: %[[EXTRACT:.*]] = vector.transfer_read %[[INPUT]][%[[OFFSET]], %[[TILE_ID]]], %[[POISON]] {in_bounds = [true], permutation_map = #[[MAP]]} : memref<16x1024xf32>, vector<10xf32>
-  // CHECK: vector.transfer_write %[[EXTRACT]], %[[OUTPUT]][%[[OFFSET]], %[[TILE_ID]]] {in_bounds = [true], permutation_map = #[[MAP]]} : vector<10xf32>, memref<16x1024xf32>
+  // CHECK: memref.subview %[[INPUT]][%[[C_0]], %[[TILE_ID]]] [10, 1] [1, 1]
+  // CHECK-SAME: memref<16x1024xf32> to memref<10xf32, strided<[1024], offset: ?>>
   %tile = xtile.extract %input[%offset, %tile_id][10, 1][1, 1] : memref<16x1024xf32> -> tensor<10xf32>
+  // CHECK: memref.subview %[[OUTPUT]][%[[C_0]], %[[TILE_ID]]] [10, 1] [1, 1]
+  // CHECK-SAME: memref<16x1024xf32> to memref<10xf32, strided<[1024], offset: ?>>
   xtile.insert %tile into %output[%offset, %tile_id][10, 1][1, 1] : tensor<10xf32> -> memref<16x1024xf32>
   xtile.return
 }
 
 // -----
+
+// CHECK: @extract_strided(%[[SOURCE:.*]]: memref<16xf32>, %[[TILE_ID:.*]]: index)
+func.func @extract_strided(%source: memref<16xf32>, %tile_id: index) -> tensor<8xf32> {
+  // CHECK: memref.subview %[[SOURCE]][%[[TILE_ID]]] [8] [2] :
+  // CHECK-SAME: memref<16xf32> to memref<8xf32, strided<[2], offset: ?>>
+  %tile = xtile.extract %source[%tile_id][8][2] : memref<16xf32> -> tensor<8xf32>
+  return %tile : tensor<8xf32>
+}
+
+// -----
+
+// CHECK: @insert_strided(
+// CHECK-SAME: %[[SOURCE:.*]]: tensor<8xf32>,
+// CHECK-SAME: %[[DESTINATION:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[TILE_ID:.*]]: index)
+func.func @insert_strided(%source: tensor<8xf32>, %destination: memref<16xf32>, %tile_id: index) {
+  // CHECK: memref.subview %[[DESTINATION]][%[[TILE_ID]]] [8] [2] :
+  // CHECK-SAME: memref<16xf32> to memref<8xf32, strided<[2], offset: ?>>
+  xtile.insert %source into %destination[%tile_id][8][2] : tensor<8xf32> -> memref<16xf32>
+  return
+}
+
+
diff --git a/third_party/xla/xla/backends/cpu/codegen/tiled/transforms/xtile_to_vector.cc b/third_party/xla/xla/backends/cpu/codegen/tiled/transforms/xtile_to_vector.cc
index a177bb72251..f7c6f2fc8f3 100644
--- a/third_party/xla/xla/backends/cpu/codegen/tiled/transforms/xtile_to_vector.cc
+++ b/third_party/xla/xla/backends/cpu/codegen/tiled/transforms/xtile_to_vector.cc
@@ -14,25 +14,38 @@ limitations under the License.
 ==============================================================================*/
 
 #include
+#include
 #include
 #include
 #include
 
+#include "absl/algorithm/container.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // IWYU pragma: keep
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/backends/cpu/codegen/tiled/transforms/lowering_utils.h"
 #include "xla/backends/cpu/codegen/tiled/transforms/passes.h"
 #include "xla/codegen/xtile/ir/xtile_ops.h"
 
@@ -44,12 +57,89 @@ namespace xla::cpu {
 
 namespace {
 
-mlir::AffineMap GetFilteredDims(mlir::MLIRContext* context, unsigned rank,
-                                llvm::SmallDenseSet<int64_t> reduced_dims) {
-  return mlir::AffineMap::getFilteredIdentityMap(
-      context, rank, [&reduced_dims](mlir::AffineDimExpr dim) {
-        return !reduced_dims.contains(dim.getPosition());
-      });
+// Dims are dropped in the subview so we use the identity map.
+mlir::AffineMapAttr GetIdentityMap(xtile::TiledBufferInterface op) {
+  int64_t rank = op.getTile().getType().getRank();
+  return mlir::AffineMapAttr::get(
+      mlir::AffineMap::getMultiDimIdentityMap(rank, op.getContext()));
+}
+
+mlir::TypedValue<mlir::MemRefType> GetSubView(
+    mlir::ImplicitLocOpBuilder& builder, xtile::TiledBufferInterface op) {
+  auto get_static_fold_result = [&](llvm::ArrayRef<int64_t> input) {
+    return llvm::map_to_vector(input, [&builder](int64_t value) {
+      return mlir::OpFoldResult(builder.getIndexAttr(value));
+    });
+  };
+
+  auto offsets = llvm::SmallVector<mlir::OpFoldResult>(op.getOffsets());
+  auto full_tile_shape = get_static_fold_result(op.getFullTileShape());
+  auto strides = get_static_fold_result(op.getStrides());
+
+  mlir::MemRefType subview_type =
+      mlir::memref::SubViewOp::inferRankReducedResultType(
+          op.getTile().getType().getShape(), op.getBuffer().getType(), offsets,
+          full_tile_shape, get_static_fold_result(op.getStrides()));
+
+  return builder.create<mlir::memref::SubViewOp>(
+      subview_type, op.getBuffer(), offsets, full_tile_shape, strides);
+}
+
+llvm::SmallVector<mlir::Value> GetZeroIndexVector(
+    mlir::ImplicitLocOpBuilder& builder, int64_t rank) {
+  return llvm::SmallVector<mlir::Value>(
+      rank, builder.create<mlir::arith::ConstantIndexOp>(0));
+}
+
+mlir::ArrayAttr GetInBoundsAttr(mlir::ImplicitLocOpBuilder& builder,
+                                int64_t rank) {
+  // TODO(willfroom): Add proper support for inBounds attr.
+  llvm::SmallVector<mlir::Attribute> in_bounds(rank,
+                                               builder.getBoolAttr(false));
+  return builder.getArrayAttr(in_bounds);
+}
+
+// Get the mask for the given transfer read/write op on a subview of the
+// original memref.
+// The inequality we need to satisfy in 1D is:
+// 1. offset + subview_idx * stride <= size - 1
+// 2. subview_idx * stride <= size - 1 - offset
+// 3. subview_idx <= (size - 1 - offset) / stride
+// 4. subview_idx < ((size - 1 - offset) / stride) + 1
+// 5. subview_idx < (size + stride - 1 - offset) / stride
+mlir::Value GetMask(mlir::ImplicitLocOpBuilder& builder,
+                    xtile::TiledBufferInterface op) {
+  mlir::RankedTensorType tile_tensor_type = op.getTile().getType();
+
+  auto get_const_index_op = [&](int64_t value) {
+    return builder.create<mlir::arith::ConstantIndexOp>(value);
+  };
+
+  if (tile_tensor_type.getRank() == 0) {
+    // Vector transfer read/write currently don't support 0D masks.
+    auto mask_0D_type = mlir::VectorType::get({1}, builder.getI1Type());
+    return builder.create<mlir::vector::CreateMaskOp>(
+        mask_0D_type, mlir::OpFoldResult(builder.getIndexAttr(1)));
+  }
+
+  llvm::SmallDenseSet<int64_t> reduced_dims = op.getReducedDimensions();
+  llvm::SmallVector<mlir::Value> upper_bounds;
+  int64_t idx = 0;
+  for (auto [offset, size, stride] :
+       llvm::zip(op.getOffsets(), op.getBuffer().getType().getShape(),
+                 op.getStrides())) {
+    if (reduced_dims.contains(idx++)) {
+      continue;
+    }
+    upper_bounds.push_back(builder.create<mlir::arith::DivUIOp>(
+        builder.create<mlir::arith::SubIOp>(
+            get_const_index_op(size + stride - 1), offset),
+        get_const_index_op(stride)));
+  }
+
+  auto mask_type = mlir::VectorType::get(op.getTile().getType().getShape(),
+                                         builder.getI1Type());
+  return builder.create<mlir::vector::CreateMaskOp>(mask_type, upper_bounds);
 }
 
 struct LowerExtractTile : mlir::OpRewritePattern<xtile::ExtractTileOp> {
   mlir::LogicalResult matchAndRewrite(
       xtile::ExtractTileOp op, mlir::PatternRewriter& rewriter) const override {
-    mlir::RankedTensorType dest_tensor_type = op.getResult().getType();
-    auto vector_type = mlir::VectorType::get(dest_tensor_type.getShape(),
-                                             dest_tensor_type.getElementType());
+    mlir::ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    auto vector_type = GetVectorType(op.getResult().getType());
+
+    mlir::TypedValue<mlir::MemRefType> buffer_subview = GetSubView(builder, op);
+
+    int64_t reduced_rank = vector_type.getRank();
+
+    // The subview is already offset so the read has zero offsets.
+    auto zero_index = GetZeroIndexVector(builder, reduced_rank);
+    mlir::Value padding =
+        builder.create<mlir::ub::PoisonOp>(vector_type.getElementType());
+    mlir::Value mask = GetMask(builder, op);
+    auto in_bounds = GetInBoundsAttr(builder, reduced_rank);
 
-    // TODO(willfroom): Add support for inBounds attr.
     mlir::Value vector_value = rewriter.create<mlir::vector::TransferReadOp>(
-        op->getLoc(), vector_type, op.getSource(), op.getOffsets(),
-        /*padding=*/std::nullopt,
-        GetFilteredDims(rewriter.getContext(),
-                        op.getSource().getType().getRank(),
-                        op.getReducedDimensions()));
-    mlir::UnrealizedConversionCastOp cast =
-        rewriter.create<mlir::UnrealizedConversionCastOp>(
-            op->getLoc(), op.getResult().getType(), vector_value);
-    rewriter.replaceOp(op, cast);
+        op->getLoc(), vector_type, buffer_subview, zero_index,
+        GetIdentityMap(op), padding, mask, in_bounds);
+
+    rewriter.replaceOp(op, CastToTensor(builder, vector_value));
     return mlir::success();
   }
 };
@@ -81,20 +175,23 @@ struct LowerInsertTile : mlir::OpRewritePattern<xtile::InsertTileOp> {
   mlir::LogicalResult matchAndRewrite(
       xtile::InsertTileOp op, mlir::PatternRewriter& rewriter) const override {
-    mlir::RankedTensorType source_tensor_type = op.getSource().getType();
-    auto vector_type = mlir::VectorType::get(
-        source_tensor_type.getShape(), source_tensor_type.getElementType());
-    mlir::Value cast = rewriter
-                           .create<mlir::UnrealizedConversionCastOp>(
-                               op->getLoc(), vector_type, op.getSource())
-                           .getResult(0);
-    // TODO(willfroom): Add support for inBounds attr.
+    mlir::ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    mlir::TypedValue<mlir::VectorType> vector_tile =
+        CastToVector(builder, op.getSource());
+
+    mlir::TypedValue<mlir::MemRefType> buffer_subview = GetSubView(builder, op);
+
+    int64_t reduced_rank = vector_tile.getType().getRank();
+
+    // The subview is already offset so the write has zero offsets.
+    auto zero_index = GetZeroIndexVector(builder, reduced_rank);
+    mlir::Value mask = GetMask(builder, op);
+    auto in_bounds = GetInBoundsAttr(builder, reduced_rank);
+
     mlir::vector::TransferWriteOp transfer_write =
-        rewriter.create<mlir::vector::TransferWriteOp>(
-            op->getLoc(), cast, op.getDestination(), op.getOffsets(),
-            GetFilteredDims(rewriter.getContext(),
-                            op.getDestination().getType().getRank(),
-                            op.getReducedDimensions()));
+        builder.create<mlir::vector::TransferWriteOp>(
+            vector_tile, buffer_subview, zero_index, GetIdentityMap(op), mask,
+            in_bounds);
 
     rewriter.replaceOp(op, transfer_write);
     return mlir::success();
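Side note on the mask bound documented in GetMask above: (size + stride - 1 - offset) / stride is the closed form of ceil((size - offset) / stride), i.e. the number of strided subview indices that stay inside the buffer. A small illustrative check in plain Python (not part of the patch; names are ad hoc):

  # For a 1D buffer of `size` elements read at `offset` with step `stride`,
  # the in-bounds subview indices are exactly range(offset, size, stride).
  def num_valid(size: int, offset: int, stride: int) -> int:
    # Closed form used by the pass for the vector.create_mask upper bound.
    return (size + stride - 1 - offset) // stride

  for size in range(1, 65):
    for stride in range(1, 8):
      for offset in range(size):
        assert num_valid(size, offset, stride) == len(range(offset, size, stride))

For the test_strided case above (size 64, stride 21, offset 0) this gives 4, matching the 4x32 tile shape.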