[XLA:CPU][XTile] Add support for strided extract/insert tile.

PiperOrigin-RevId: 822035319

parent e756c21611, commit 373abf8de1
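For context on the semantics this change adds: `xtile.extract %buf[offsets][sizes][strides]` reads `sizes[d]` elements along each dimension `d`, starting at `offsets[d]` and stepping by `strides[d]`; `xtile.insert` writes such a tile back. A minimal NumPy sketch of that behavior, mirroring the new `test_strided` case below (the helper name `strided_extract` is hypothetical, for illustration only):

import numpy as np

def strided_extract(buffer, offsets, sizes, strides):
    # Model of xtile.extract: per dimension, take `sizes[d]` elements
    # starting at `offsets[d]`, stepping by `strides[d]`.
    index = tuple(
        slice(o, o + s * st, st) for o, s, st in zip(offsets, sizes, strides))
    return buffer[index]

# A [4, 32] tile with strides [21, 2] from a 64x64 buffer touches
# rows 0, 21, 42, 63 and every second column.
buf = np.arange(64 * 64, dtype=np.float32).reshape(64, 64)
tile = strided_extract(buf, offsets=(0, 0), sizes=(4, 32), strides=(21, 2))
assert tile.shape == (4, 32)
assert np.array_equal(tile, buf[::21, ::2])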
@@ -255,6 +255,7 @@ static void AddTiledLoweringPasses(mlir::OpPassManager& pm) {
   pm.addPass(mlir::createConvertVectorToLLVMPass());
   pm.addPass(mlir::createConvertComplexToStandardPass());
+  pm.addPass(mlir::memref::createExpandStridedMetadataPass());
   AddGenericLoweringPasses(pm);
 }
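Presumably the new ExpandStridedMetadataPass is needed because the lowering below now emits memref.subview ops, whose offset/stride metadata must be expanded into explicit index arithmetic before the final memref-to-LLVM conversion.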
@@ -98,6 +98,30 @@ class XtileLoweringTest(absltest.TestCase):
         lambda arg: arg.transpose(),
     )

+  def test_strided(self):
+    ir = """
+      module @tiled_slice {
+        xtile.entry_func @tiled_slice(
+            %input: memref<64x64xf32>,
+            %output: memref<4x32xf32>,
+            %tile_id: index) attributes {xtile.tiling_info = #xtile.tiling_info<tile_count:1, tiles_per_workgroup:1>} {
+          %input_tile = xtile.extract %input[%tile_id, %tile_id][4, 32][21, 2] : memref<64x64xf32> -> tensor<4x32xf32>
+          xtile.insert %input_tile into %output[%tile_id, %tile_id][4, 32][1, 1] : tensor<4x32xf32> -> memref<4x32xf32>
+          xtile.return
+        }
+      }
+    """
+
+    compare_kernel(
+        ir,
+        "tiled_slice",
+        1,
+        [(64, 64)],
+        (4, 32),
+        np.float32,
+        lambda arg: arg[::21, ::2],
+    )
+
   def test_transpose(self):
     ir = """
       module @tiled_transpose {
@@ -72,11 +72,13 @@ cc_library(
        "@llvm-project//mlir:LLVMDialect",
        "@llvm-project//mlir:MathDialect",
        "@llvm-project//mlir:MathOpsIncGen",
        "@llvm-project//mlir:MemRefDialect",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:TransformUtils",
        "@llvm-project//mlir:UBDialect",
        "@llvm-project//mlir:VectorDialect",
        "@stablehlo//:stablehlo_ops",
    ],
@@ -1,30 +1,58 @@
-// RUN: emitters_opt %s --xtile-cpu-xtile-to-vector -split-input-file | FileCheck %s
+// RUN: emitters_opt %s --xtile-cpu-xtile-to-vector -cse -split-input-file | FileCheck %s

// CHECK-LABEL: @simple_insert_extract
// CHECK-SAME: (%[[INPUT:.*]]: memref<1024xf32>, %[[OUTPUT:.*]]: memref<1024xf32>, %[[TILE_ID:.*]]: index)
xtile.entry_func @simple_insert_extract(%input: memref<1024xf32>, %output: memref<1024xf32>, %tile_id: index) {
  // CHECK-DAG: %[[POISON:.*]] = ub.poison : f32
-  // CHECK: %[[EXTRACT:.*]] = vector.transfer_read %[[INPUT]][%[[TILE_ID]]], %[[POISON]] : memref<1024xf32>, vector<1xf32>
+  // CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IN_SUBVIEW:.*]] = memref.subview %[[INPUT]][%[[TILE_ID]]] [1] [1]
+  // CHECK-SAME: memref<1024xf32> to memref<1xf32, strided<[1], offset: ?>>
+  // CHECK: %[[MASK:.*]] = vector.create_mask
+  // CHECK: %[[EXTRACT:.*]] = vector.transfer_read %[[IN_SUBVIEW]][%[[C_0]]], %[[POISON]], %[[MASK]]
  %tile = xtile.extract %input[%tile_id][1][1] : memref<1024xf32> -> tensor<1xf32>
-  // CHECK: vector.transfer_write %[[EXTRACT]], %[[OUTPUT]][%[[TILE_ID]]] : vector<1xf32>, memref<1024xf32>
+  // CHECK: %[[OUT_SUBVIEW:.*]] = memref.subview %[[OUTPUT]][%[[TILE_ID]]] [1] [1]
+  // CHECK-SAME: memref<1024xf32> to memref<1xf32, strided<[1], offset: ?>>
+  // CHECK: vector.transfer_write %[[EXTRACT]], %[[OUT_SUBVIEW]][%[[C_0]]], %[[MASK]]
  xtile.insert %tile into %output[%tile_id][1][1] : tensor<1xf32> -> memref<1024xf32>
  xtile.return
}

// -----

-// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK: @reduce_dimension(%[[INPUT:.*]]: memref<16x1024xf32>, %[[OUTPUT:.*]]: memref<16x1024xf32>, %[[TILE_ID:.*]]: index)
xtile.entry_func @reduce_dimension(%input: memref<16x1024xf32>, %output: memref<16x1024xf32>, %tile_id: index) {
-  // CHECK: %[[OFFSET:.*]] = arith.constant 0 : index
+  // CHECK: %[[C_0:.*]] = arith.constant 0 : index
  %offset = arith.constant 0 : index
-  // CHECK: %[[POISON:.*]] = ub.poison : f32
-  // CHECK: %[[EXTRACT:.*]] = vector.transfer_read %[[INPUT]][%[[OFFSET]], %[[TILE_ID]]], %[[POISON]] {in_bounds = [true], permutation_map = #[[MAP]]} : memref<16x1024xf32>, vector<10xf32>
-  // CHECK: vector.transfer_write %[[EXTRACT]], %[[OUTPUT]][%[[OFFSET]], %[[TILE_ID]]] {in_bounds = [true], permutation_map = #[[MAP]]} : vector<10xf32>, memref<16x1024xf32>
+  // CHECK: memref.subview %[[INPUT]][%[[C_0]], %[[TILE_ID]]] [10, 1] [1, 1]
+  // CHECK-SAME: memref<16x1024xf32> to memref<10xf32, strided<[1024], offset: ?>>
  %tile = xtile.extract %input[%offset, %tile_id][10, 1][1, 1] : memref<16x1024xf32> -> tensor<10xf32>
+  // CHECK: memref.subview %[[OUTPUT]][%[[C_0]], %[[TILE_ID]]] [10, 1] [1, 1]
+  // CHECK-SAME: memref<16x1024xf32> to memref<10xf32, strided<[1024], offset: ?>>
  xtile.insert %tile into %output[%offset, %tile_id][10, 1][1, 1] : tensor<10xf32> -> memref<16x1024xf32>
  xtile.return
}

// -----

+// CHECK: @extract_strided(%[[SOURCE:.*]]: memref<16xf32>, %[[TILE_ID:.*]]: index)
+func.func @extract_strided(%source: memref<16xf32>, %tile_id: index) -> tensor<8xf32> {
+  // CHECK: memref.subview %[[SOURCE]][%[[TILE_ID]]] [8] [2] :
+  // CHECK-SAME: memref<16xf32> to memref<8xf32, strided<[2], offset: ?>>
+  %tile = xtile.extract %source[%tile_id][8][2] : memref<16xf32> -> tensor<8xf32>
+  return %tile : tensor<8xf32>
+}
+
+// -----
+
+// CHECK: @insert_strided(
+// CHECK-SAME: %[[SOURCE:.*]]: tensor<8xf32>,
+// CHECK-SAME: %[[DESTINATION:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[TILE_ID:.*]]: index)
+func.func @insert_strided(%source: tensor<8xf32>, %destination: memref<16xf32>, %tile_id: index) {
+  // CHECK: memref.subview %[[DESTINATION]][%[[TILE_ID]]] [8] [2] :
+  // CHECK-SAME: memref<16xf32> to memref<8xf32, strided<[2], offset: ?>>
+  xtile.insert %source into %destination[%tile_id][8][2] : tensor<8xf32> -> memref<16xf32>
+  return
+}
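A quick sanity check on the @extract_strided test above: a size-8, stride-2 subview of a memref<16xf32> at offset %tile_id covers elements tile_id, tile_id + 2, ..., tile_id + 14, so all eight lanes are in bounds only for tile_id <= 1; for larger offsets the trailing lanes are masked off by the vector.create_mask the pass emits. A NumPy sketch of that shape arithmetic (the sample tile_id values are assumptions for illustration):

import numpy as np

source = np.arange(16, dtype=np.float32)
for tile_id in (0, 1, 4):
    # memref.subview %source[tile_id] [8] [2]: up to 8 elements, stride 2.
    view = source[tile_id : tile_id + 8 * 2 : 2]
    print(tile_id, view.size)  # 0 -> 8, 1 -> 8, 4 -> 6 valid lanes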
@@ -14,25 +14,38 @@ limitations under the License.
==============================================================================*/

#include <cassert>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>

#include "absl/algorithm/container.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // IWYU pragma: keep
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "xla/backends/cpu/codegen/tiled/transforms/lowering_utils.h"
#include "xla/backends/cpu/codegen/tiled/transforms/passes.h"
#include "xla/codegen/xtile/ir/xtile_ops.h"
@@ -44,12 +57,89 @@ namespace xla::cpu {

namespace {

-mlir::AffineMap GetFilteredDims(mlir::MLIRContext* context, unsigned rank,
-                                llvm::SmallDenseSet<unsigned> reduced_dims) {
-  return mlir::AffineMap::getFilteredIdentityMap(
-      context, rank, [&reduced_dims](mlir::AffineDimExpr dim) {
-        return !reduced_dims.contains(dim.getPosition());
-      });
-}
+// Dims are dropped in the subview, so we use the identity map.
+mlir::AffineMapAttr GetIdentityMap(xtile::TiledBufferInterface op) {
+  int64_t rank = op.getTile().getType().getRank();
+  return mlir::AffineMapAttr::get(
+      mlir::AffineMap::getMultiDimIdentityMap(rank, op.getContext()));
+}
+
+mlir::TypedValue<mlir::MemRefType> GetSubView(
+    mlir::ImplicitLocOpBuilder& builder, xtile::TiledBufferInterface op) {
+  auto get_static_fold_result = [&](llvm::ArrayRef<int64_t> input) {
+    return llvm::map_to_vector(input, [&builder](int64_t value) {
+      return mlir::OpFoldResult(builder.getIndexAttr(value));
+    });
+  };
+
+  auto offsets = llvm::SmallVector<mlir::OpFoldResult>(op.getOffsets());
+  auto full_tile_shape = get_static_fold_result(op.getFullTileShape());
+  auto strides = get_static_fold_result(op.getStrides());
+
+  mlir::MemRefType subview_type =
+      mlir::memref::SubViewOp::inferRankReducedResultType(
+          op.getTile().getType().getShape(), op.getBuffer().getType(), offsets,
+          full_tile_shape, get_static_fold_result(op.getStrides()));
+
+  return builder.create<mlir::memref::SubViewOp>(
+      subview_type, op.getBuffer(), offsets, full_tile_shape, strides);
+}
+
+llvm::SmallVector<mlir::Value> GetZeroIndexVector(
+    mlir::ImplicitLocOpBuilder& builder, int64_t rank) {
+  return llvm::SmallVector<mlir::Value>(
+      rank, builder.create<mlir::arith::ConstantIndexOp>(0));
+}
+
+mlir::ArrayAttr GetInBoundsAttr(mlir::ImplicitLocOpBuilder& builder,
+                                int64_t rank) {
+  // TODO(willfroom): Add proper support for inBounds attr.
+  llvm::SmallVector<mlir::Attribute> in_bounds(rank,
+                                               builder.getBoolAttr(false));
+  return builder.getArrayAttr(in_bounds);
+}
+
+// Get the mask for the given transfer_<read/write> op on a subview of the
+// original memref.
+// The inequality we need to satisfy in 1D is:
+// 1. offset + subview_idx * stride <= size - 1
+// 2. subview_idx * stride <= size - 1 - offset
+// 3. subview_idx <= (size - 1 - offset) / stride
+// 4. subview_idx < ((size - 1 - offset) / stride) + 1
+// 5. subview_idx < (size + stride - 1 - offset) / stride
+mlir::Value GetMask(mlir::ImplicitLocOpBuilder& builder,
+                    xtile::TiledBufferInterface op) {
+  mlir::RankedTensorType tile_tensor_type = op.getTile().getType();
+
+  auto get_const_index_op = [&](int64_t value) {
+    return builder.create<mlir::arith::ConstantIndexOp>(value);
+  };
+
+  if (tile_tensor_type.getRank() == 0) {
+    // Vector transfer read/write currently don't support 0D masks.
+    auto mask_0D_type = mlir::VectorType::get({1}, builder.getI1Type());
+    return builder.create<mlir::vector::CreateMaskOp>(
+        mask_0D_type, mlir::OpFoldResult(builder.getIndexAttr(1)));
+  }
+
+  llvm::SmallDenseSet<unsigned> reduced_dims = op.getReducedDimensions();
+  llvm::SmallVector<mlir::Value> upper_bounds;
+  int64_t idx = 0;
+  for (auto [offset, size, stride] :
+       llvm::zip(op.getOffsets(), op.getBuffer().getType().getShape(),
+                 op.getStrides())) {
+    if (reduced_dims.contains(idx++)) {
+      continue;
+    }
+    upper_bounds.push_back(builder.create<mlir::arith::DivSIOp>(
+        builder.create<mlir::arith::SubIOp>(
+            get_const_index_op(size + stride - 1), offset),
+        get_const_index_op(stride)));
+  }
+
+  auto mask_type = mlir::VectorType::get(op.getTile().getType().getShape(),
+                                         builder.getI1Type());
+  return builder.create<mlir::vector::CreateMaskOp>(mask_type, upper_bounds);
+}
+
 struct LowerExtractTile : mlir::OpRewritePattern<xtile::ExtractTileOp> {
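To sanity-check the upper-bound derivation in GetMask above with concrete numbers (the size/stride values come from the strided test; the offset-10 case and the helper name mask_upper_bound are illustrative assumptions):

# upper_bound = (size + stride - 1 - offset) // stride, i.e. step 5 above.
def mask_upper_bound(size, stride, offset):
    return (size + stride - 1 - offset) // stride

assert mask_upper_bound(64, 21, 0) == 4   # valid indices: 0, 21, 42, 63
assert mask_upper_bound(64, 21, 10) == 3  # valid indices: 10, 31, 52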
@@ -57,21 +147,25 @@ struct LowerExtractTile : mlir::OpRewritePattern<xtile::ExtractTileOp> {

  mlir::LogicalResult matchAndRewrite(
      xtile::ExtractTileOp op, mlir::PatternRewriter& rewriter) const override {
-    mlir::RankedTensorType dest_tensor_type = op.getResult().getType();
-    auto vector_type = mlir::VectorType::get(dest_tensor_type.getShape(),
-                                             dest_tensor_type.getElementType());
+    mlir::ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    auto vector_type = GetVectorType(op.getResult().getType());
+
+    mlir::TypedValue<mlir::MemRefType> buffer_subview = GetSubView(builder, op);
+
+    int64_t reduced_rank = vector_type.getRank();
+
+    // The subview is already offset so the read has zero offsets.
+    auto zero_index = GetZeroIndexVector(builder, reduced_rank);
+    mlir::Value padding =
+        builder.create<mlir::ub::PoisonOp>(vector_type.getElementType());
+    mlir::Value mask = GetMask(builder, op);
+    auto in_bounds = GetInBoundsAttr(builder, reduced_rank);

-    // TODO(willfroom): Add support for inBounds attr.
    mlir::Value vector_value = rewriter.create<mlir::vector::TransferReadOp>(
-        op->getLoc(), vector_type, op.getSource(), op.getOffsets(),
-        /*padding=*/std::nullopt,
-        GetFilteredDims(rewriter.getContext(),
-                        op.getSource().getType().getRank(),
-                        op.getReducedDimensions()));
-    mlir::UnrealizedConversionCastOp cast =
-        rewriter.create<mlir::UnrealizedConversionCastOp>(
-            op->getLoc(), op.getResult().getType(), vector_value);
-    rewriter.replaceOp(op, cast);
+        op->getLoc(), vector_type, buffer_subview, zero_index,
+        GetIdentityMap(op), padding, mask, in_bounds);
+
+    rewriter.replaceOp(op, CastToTensor(builder, vector_value));
    return mlir::success();
  }
};
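Net effect of the rewrite above: instead of a vector.transfer_read issued directly on the source at the op's offsets with a filtered permutation map, the lowering now materializes a (possibly rank-reduced) memref.subview that carries the offsets and strides, then performs a masked, identity-map read at zero offsets. GetVectorType and CastToTensor/CastToVector appear to be helpers from lowering_utils.h.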
|
@ -81,20 +175,23 @@ struct LowerInsertTile : mlir::OpRewritePattern<xtile::InsertTileOp> {
|
|||
|
||||
mlir::LogicalResult matchAndRewrite(
|
||||
xtile::InsertTileOp op, mlir::PatternRewriter& rewriter) const override {
|
||||
mlir::RankedTensorType source_tensor_type = op.getSource().getType();
|
||||
auto vector_type = mlir::VectorType::get(
|
||||
source_tensor_type.getShape(), source_tensor_type.getElementType());
|
||||
mlir::Value cast = rewriter
|
||||
.create<mlir::UnrealizedConversionCastOp>(
|
||||
op->getLoc(), vector_type, op.getSource())
|
||||
.getResult(0);
|
||||
// TODO(willfroom): Add support for inBounds attr.
|
||||
mlir::ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
|
||||
mlir::TypedValue<mlir::VectorType> vector_tile =
|
||||
CastToVector(builder, op.getSource());
|
||||
|
||||
mlir::TypedValue<mlir::MemRefType> buffer_subview = GetSubView(builder, op);
|
||||
|
||||
int64_t reduced_rank = vector_tile.getType().getRank();
|
||||
|
||||
// The subview is already offset so the write has zero offsets.
|
||||
auto zero_index = GetZeroIndexVector(builder, reduced_rank);
|
||||
mlir::Value mask = GetMask(builder, op);
|
||||
auto in_bounds = GetInBoundsAttr(builder, reduced_rank);
|
||||
|
||||
mlir::vector::TransferWriteOp transfer_write =
|
||||
rewriter.create<mlir::vector::TransferWriteOp>(
|
||||
op->getLoc(), cast, op.getDestination(), op.getOffsets(),
|
||||
GetFilteredDims(rewriter.getContext(),
|
||||
op.getDestination().getType().getRank(),
|
||||
op.getReducedDimensions()));
|
||||
builder.create<mlir::vector::TransferWriteOp>(
|
||||
vector_tile, buffer_subview, zero_index, GetIdentityMap(op), mask,
|
||||
in_bounds);
|
||||
|
||||
rewriter.replaceOp(op, transfer_write);
|
||||
return mlir::success();
|
||||
|
|
|
|||