[XLA:CPU][XTile] Implement vectorized reduce.

PiperOrigin-RevId: 825027697
Will Froom 2025-10-28 07:23:05 -07:00 committed by TensorFlower Gardener
parent 11c00ca2db
commit de7a63363c
12 changed files with 766 additions and 13 deletions

View File

@@ -198,6 +198,7 @@ cc_library(
"@llvm-project//mlir:VectorDialect",
"@llvm-project//mlir:VectorToLLVM",
"@llvm-project//mlir:VectorToSCF",
"@llvm-project//mlir:VectorTransforms",
"@local_tsl//tsl/profiler/lib:traceme",
"@local_tsl//tsl/profiler/lib:traceme_encode",
"@stablehlo//:stablehlo_passes",

View File

@@ -61,6 +61,8 @@ limitations under the License.
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
@@ -244,15 +246,23 @@ static void AddTiledOptimizationPasses(mlir::OpPassManager& pm) {
// The input IR is from the xtile dialect which uses tensors that are converted
// first to the vector dialect and then to LLVM.
static void AddTiledLoweringPasses(mlir::OpPassManager& pm) {
pm.addPass(CreateXTileToVectorPass());
pm.addPass(CreateElementalTensorToVectorPass());
pm.addPass(CreateShloToVectorPass());
pm.addPass(CreateXTileToVectorPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(CreateRewriteDynamicVectorExtractPass());
pm.addPass(CreateElementalTensorToVectorPass());
pm.addPass(CreateLowerXTileEntryPass());
pm.addNestedPass<mlir::func::FuncOp>(
mlir::vector::createLowerVectorMultiReductionPass(
mlir::vector::VectorMultiReductionLowering::InnerParallel));
pm.addPass(CreateTensorOpsToVectorPass());
pm.addPass(cpu::createLowerToLLVMPass());
pm.addPass(mlir::createConvertVectorToSCFPass(
mlir::VectorTransferToSCFOptions().enableFullUnroll(false)));
pm.addPass(mlir::createConvertVectorToLLVMPass());
mlir::ConvertVectorToLLVMPassOptions options;
options.vectorTransposeLowering =
mlir::vector::VectorTransposeLowering::Shuffle1D;
pm.addPass(mlir::createConvertVectorToLLVMPass(options));
pm.addPass(mlir::createConvertComplexToStandardPass());
pm.addPass(mlir::memref::createExpandStridedMetadataPass());

View File

@@ -269,6 +269,144 @@ class XtileLoweringTest(absltest.TestCase):
maxulp=5,
)
def test_reduction_add_inner(self):
ir = """
module @reduction_add_inner {
xtile.entry_func @reduction_add_inner(
%input: memref<1024x32xf32>,
%init: memref<f32>,
%output: memref<1024xf32>,
%tile_id: index) attributes {xtile.tiling_info = #xtile.tiling_info<tile_count:128, tiles_per_workgroup:32>} {
%c_0 = arith.constant 0 : index
%c_8 = arith.constant 8 : index
%init_tile = xtile.extract %init[][][] : memref<f32> -> tensor<f32>
%index = arith.muli %tile_id, %c_8 : index
%input_tile = xtile.extract %input[%index, %c_0][8, 32][1, 1] : memref<1024x32xf32> -> tensor<8x32xf32>
%result = stablehlo.reduce(%input_tile init: %init_tile)
across dimensions = [1]
: (tensor<8x32xf32>, tensor<f32>) -> tensor<8xf32>
reducer(%arg0: tensor<f32>, %arg1: tensor<f32>) {
%add = arith.addf %arg0, %arg1 : tensor<f32>
stablehlo.return %add : tensor<f32>
}
xtile.insert %result into %output[%index][8][1] : tensor<8xf32> -> memref<1024xf32>
xtile.return
}
}
"""
compare_kernel(
ir,
"reduction_add_inner",
4,
[(1024, 32), (1,)],
(1024,),
np.int32,
lambda input, init: np.sum(input, axis=1) + init,
)
def test_reduction_add_outer(self):
ir = """
module @reduction_add_outer {
xtile.entry_func @reduction_add_outer(
%input: memref<1024x32xf32>,
%init: memref<f32>,
%output: memref<32xf32>,
%tile_id: index) attributes {xtile.tiling_info = #xtile.tiling_info<tile_count:4, tiles_per_workgroup:1>} {
%c_0 = arith.constant 0 : index
%c_8 = arith.constant 8 : index
%init_tile = xtile.extract %init[][][] : memref<f32> -> tensor<f32>
%index = arith.muli %tile_id, %c_8 : index
%input_tile = xtile.extract %input[%c_0, %index][1024, 8][1, 1] : memref<1024x32xf32> -> tensor<1024x8xf32>
%result = stablehlo.reduce(%input_tile init: %init_tile)
across dimensions = [0]
: (tensor<1024x8xf32>, tensor<f32>) -> tensor<8xf32>
reducer(%arg0: tensor<f32>, %arg1: tensor<f32>) {
%add = arith.addf %arg0, %arg1 : tensor<f32>
stablehlo.return %add : tensor<f32>
}
xtile.insert %result into %output[%index][8][1] : tensor<8xf32> -> memref<32xf32>
xtile.return
}
}
"""
compare_kernel(
ir,
"reduction_add_outer",
4,
[(1024, 32), (1,)],
(32,),
np.float32,
lambda input, init: np.sum(input, axis=0) + init,
)
def test_reduction_middle(self):
ir = """
module @reduction_add_middle {
xtile.entry_func @reduction_add_middle(
%input: memref<8x4x2xf32>,
%init: memref<f32>,
%output: memref<8x2xf32>,
%tile_id: index) attributes {xtile.tiling_info = #xtile.tiling_info<tile_count:1, tiles_per_workgroup:1>} {
%init_val = xtile.extract %init[][][] : memref<f32> -> tensor<f32>
%input_tile = xtile.extract %input[%tile_id, %tile_id, %tile_id][8, 4, 2][1, 1, 1] : memref<8x4x2xf32> -> tensor<8x4x2xf32>
%result = stablehlo.reduce(%input_tile init: %init_val)
across dimensions = [1]
: (tensor<8x4x2xf32>, tensor<f32>) -> tensor<8x2xf32>
reducer(%arg0: tensor<f32>, %arg1: tensor<f32>) {
%add = arith.addf %arg0, %arg1 : tensor<f32>
stablehlo.return %add : tensor<f32>
}
xtile.insert %result into %output[%tile_id, %tile_id][8, 2][1, 1] : tensor<8x2xf32> -> memref<8x2xf32>
xtile.return
}
}
"""
compare_kernel(
ir,
"reduction_add_middle",
1,
[(8, 4, 2), (1,)],
(8, 2),
np.float32,
lambda input, init: np.sum(input, axis=1) + init,
)
def test_reduction_outer_inner(self):
ir = """
module @reduction_add_outer_inner {
xtile.entry_func @reduction_add_outer_inner(
%input: memref<8x4x2xf32>,
%init: memref<f32>,
%output: memref<4xf32>,
%tile_id: index) attributes {xtile.tiling_info = #xtile.tiling_info<tile_count:1, tiles_per_workgroup:1>} {
%init_val = xtile.extract %init[][][] : memref<f32> -> tensor<f32>
%input_tile = xtile.extract %input[%tile_id, %tile_id, %tile_id][8, 4, 2][1, 1, 1] : memref<8x4x2xf32> -> tensor<8x4x2xf32>
%result = stablehlo.reduce(%input_tile init: %init_val)
across dimensions = [0, 2]
: (tensor<8x4x2xf32>, tensor<f32>) -> tensor<4xf32>
reducer(%arg0: tensor<f32>, %arg1: tensor<f32>) {
%add = arith.addf %arg0, %arg1 : tensor<f32>
stablehlo.return %add : tensor<f32>
}
xtile.insert %result into %output[%tile_id][4][1] : tensor<4xf32> -> memref<4xf32>
xtile.return
}
}
"""
compare_kernel(
ir,
"reduction_add_outer_inner",
1,
[(8, 4, 2), (1,)],
(4,),
np.float32,
lambda input, init: np.sum(input, axis=(0, 2)) + init,
)
if __name__ == "__main__":
absltest.main()

View File

@@ -36,7 +36,9 @@ cc_library(
hdrs = ["lowering_utils.h"],
visibility = ["//visibility:private"],
deps = [
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:VectorDialect",
@@ -57,10 +59,14 @@ cc_library(
deps = [
":lowering_utils",
":passes_inc_gen",
":vectorized_reduce_emitter",
"//xla/backends/cpu/codegen/emitters/ir:xla_cpu",
"//xla/codegen/emitters/ir:xla",
"//xla/codegen/xtile/ir:xtile",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@llvm-project//llvm:Support",
@@ -72,6 +78,7 @@ cc_library(
"@llvm-project//mlir:FuncTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MathOpsIncGen",
"@llvm-project//mlir:MemRefDialect",
@@ -86,3 +93,24 @@ cc_library(
"@stablehlo//:stablehlo_ops",
],
)
cc_library(
name = "vectorized_reduce_emitter",
srcs = ["vectorized_reduce_emitter.cc"],
hdrs = ["vectorized_reduce_emitter.h"],
deps = [
":lowering_utils",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:VectorDialect",
],
)

View File

@@ -15,19 +15,20 @@ limitations under the License.
#include "xla/backends/cpu/codegen/tiled/transforms/lowering_utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
namespace xla::cpu {
mlir::VectorType GetVectorType(mlir::RankedTensorType tensor_type) {
return mlir::VectorType::get(tensor_type.getShape(),
tensor_type.getElementType());
mlir::VectorType GetVectorType(mlir::ShapedType type) {
return mlir::VectorType::get(type.getShape(), type.getElementType());
}
mlir::TypedValue<mlir::VectorType> CastToVector(mlir::OpBuilder& builder,
@@ -45,9 +46,8 @@ mlir::TypedValue<mlir::VectorType> CastToVector(mlir::OpBuilder& builder,
return mlir::cast<mlir::TypedValue<mlir::VectorType>>(cast_op.getResult(0));
}
mlir::RankedTensorType GetTensorType(mlir::VectorType vector_type) {
return mlir::RankedTensorType::get(vector_type.getShape(),
vector_type.getElementType());
mlir::RankedTensorType GetTensorType(mlir::ShapedType type) {
return mlir::RankedTensorType::get(type.getShape(), type.getElementType());
}
mlir::TypedValue<mlir::RankedTensorType> CastToTensor(mlir::OpBuilder& builder,
@@ -66,4 +66,12 @@ mlir::TypedValue<mlir::RankedTensorType> CastToTensor(mlir::OpBuilder& builder,
cast_op.getResult(0));
}
mlir::TypedValue<mlir::MemRefType> CreateBufferOfShape(mlir::OpBuilder& builder,
mlir::Location loc,
mlir::ShapedType shape) {
mlir::MemRefType memref_type =
mlir::MemRefType::get(shape.getShape(), shape.getElementType());
return mlir::memref::AllocaOp::create(builder, loc, memref_type);
}
} // namespace xla::cpu

View File

@@ -18,13 +18,14 @@ limitations under the License.
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Value.h"
namespace xla::cpu {
// Get the vector type that has the same shape and element type as the tensor
// type.
mlir::VectorType GetVectorType(mlir::RankedTensorType tensor_type);
mlir::VectorType GetVectorType(mlir::ShapedType tensor_type);
// Cast the input to a vector value.
// If the input is a scalar it will be simply constructed as a
@@ -36,7 +37,7 @@ mlir::TypedValue<mlir::VectorType> CastToVector(mlir::OpBuilder& builder,
// Get the tensor type that has the same shape and element type as the vector
// type.
mlir::RankedTensorType GetTensorType(mlir::VectorType vector_type);
mlir::RankedTensorType GetTensorType(mlir::ShapedType vector_type);
// Cast the input to a tensor value.
// If the input is a scalar it will be simply constructed as a
@@ -46,6 +47,10 @@ mlir::RankedTensorType GetTensorType(mlir::VectorType vector_type);
mlir::TypedValue<mlir::RankedTensorType> CastToTensor(mlir::OpBuilder& builder,
mlir::Value input);
mlir::TypedValue<mlir::MemRefType> CreateBufferOfShape(mlir::OpBuilder& builder,
mlir::Location loc,
mlir::ShapedType shape);
} // namespace xla::cpu
#endif // XLA_BACKENDS_CPU_CODEGEN_TILED_TRANSFORMS_LOWERING_UTILS_H_

View File

@@ -49,6 +49,8 @@ def ShloToVectorPass : Pass<"xtile-cpu-shlo-to-vector", "mlir::ModuleOp"> {
"mlir::tensor::TensorDialect",
"mlir::vector::VectorDialect",
"mlir::stablehlo::StablehloDialect",
"mlir::scf::SCFDialect",
"mlir::memref::MemRefDialect",
];
}

View File

@@ -18,20 +18,25 @@ limitations under the License.
#include <memory>
#include <utility>
#include "absl/algorithm/container.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // IWYU pragma: keep
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
@@ -41,6 +46,7 @@ limitations under the License.
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/backends/cpu/codegen/tiled/transforms/lowering_utils.h"
#include "xla/backends/cpu/codegen/tiled/transforms/passes.h"
#include "xla/backends/cpu/codegen/tiled/transforms/vectorized_reduce_emitter.h"
namespace xla::cpu {
@@ -218,6 +224,48 @@ struct LowerTranspose : mlir::OpRewritePattern<mlir::stablehlo::TransposeOp> {
}
};
// Lower stablehlo.reduce to vector operations by delegating the emission to
// EmitVectorizedReduction.
struct LowerReduce : mlir::OpRewritePattern<mlir::stablehlo::ReduceOp> {
using OpRewritePattern::OpRewritePattern;
mlir::LogicalResult matchAndRewrite(
mlir::stablehlo::ReduceOp op,
mlir::PatternRewriter& rewriter) const override {
if (op.getNumResults() != 1) {
return rewriter.notifyMatchFailure(
op, "reduce op with multiple results is not supported");
}
mlir::TypedValue<mlir::VectorType> source_vector =
CastToVector(rewriter, op.getInputs().front());
mlir::VectorType source_vector_type = source_vector.getType();
mlir::Value init_value = rewriter.create<mlir::tensor::ExtractOp>(
op->getLoc(), source_vector_type.getElementType(),
op.getInitValues().front());
mlir::Value result_tensor = op.getResult(0);
auto result_tensor_type =
mlir::cast<mlir::RankedTensorType>(result_tensor.getType());
auto result_vector_type = GetVectorType(result_tensor_type);
// Ensure the reduction dimensions are sorted so we can easily check if the
// minor dimension is reduced.
llvm::SmallVector<int64_t> reduction_dims(op.getDimensions());
absl::c_sort(reduction_dims);
mlir::Value reduced_vector = EmitVectorizedReduction(
rewriter, op->getLoc(), result_vector_type, source_vector, init_value,
reduction_dims, op.getBody().front());
rewriter.replaceOp(op, CastToTensor(rewriter, reduced_vector));
return mlir::success();
}
};
class ShloToVectorPass : public impl::ShloToVectorPassBase<ShloToVectorPass> {
public:
using ShloToVectorPassBase::ShloToVectorPassBase;
@@ -225,7 +273,7 @@ class ShloToVectorPass : public impl::ShloToVectorPassBase<ShloToVectorPass> {
void runOnOperation() override {
mlir::MLIRContext* context = &getContext();
mlir::RewritePatternSet patterns(context);
patterns.add<LowerTranspose, LowerDotGeneral>(context);
patterns.add<LowerTranspose, LowerDotGeneral, LowerReduce>(context);
if (mlir::failed(
mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
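Note on the semantics the LowerReduce pattern has to preserve: stablehlo.reduce folds the scalar init value into each output element once, which is why the pattern extracts it with tensor.extract and hands it to EmitVectorizedReduction separately from the source vector. A minimal NumPy sketch of that contract (illustrative only; the helper name below is made up, and this mirrors what the test lambdas in this change encode as np.sum(...) + init):

import numpy as np

# Hypothetical reference helper for an add-reduction; not an API in this change.
def stablehlo_reduce_add(inputs, init, dims):
    # The init value participates once per output element.
    return np.add.reduce(inputs, axis=tuple(dims)) + init

x = np.arange(8 * 32, dtype=np.float32).reshape(8, 32)
assert stablehlo_reduce_add(x, 1.0, [1]).shape == (8,)
assert np.allclose(stablehlo_reduce_add(x, 1.0, [1]), x.sum(axis=1) + 1.0)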

View File

@@ -37,3 +37,107 @@ func.func @dot_scalar_output(%lhs : tensor<1024xf32>, %rhs : tensor<1024xf32>) -
// CHECK: return %[[RESULT_TENSOR]] : tensor<f32>
return %result : tensor<f32>
}
// -----
func.func @reduce_outer(%input : tensor<1024x32xf32>, %init : tensor<f32>) -> tensor<32xf32> {
%result = stablehlo.reduce(%input init: %init) across dimensions = [0] : (tensor<1024x32xf32>, tensor<f32>) -> tensor<32xf32>
reducer(%arg0: tensor<f32>, %arg1: tensor<f32>) {
%add = arith.addf %arg0, %arg1 : tensor<f32>
stablehlo.return %add : tensor<f32>
}
return %result : tensor<32xf32>
}
// CHECK: func.func @reduce_outer
// CHECK: memref.alloca() : memref<32xf32>
// CHECK: vector.extract %{{.*}} : vector<32xf32> from vector<1024x32xf32>
// CHECK: scf.for
// CHECK: vector.extract %{{.*}} : vector<32xf32> from vector<1024x32xf32>
// CHECK: arith.addf {{.*}} : vector<32xf32>
// CHECK: scf.yield %{{.*}} : vector<32xf32>
// CHECK: }
// CHECK: vector.transfer_write %{{.*}} : vector<32xf32>, memref<32xf32>
// CHECK: vector.transfer_read %{{.*}} : memref<32xf32>, vector<32xf32>
// CHECK: vector.broadcast %{{.*}} : f32 to vector<32xf32>
// CHECK: arith.addf {{.*}} : vector<32xf32>
// -----
func.func @reduce_inner(%input : tensor<1024x32xf32>, %init : tensor<f32>) -> tensor<1024xf32> {
%result = stablehlo.reduce(%input init: %init) across dimensions = [1] : (tensor<1024x32xf32>, tensor<f32>) -> tensor<1024xf32>
reducer(%arg0: tensor<f32>, %arg1: tensor<f32>) {
%add = arith.addf %arg0, %arg1 : tensor<f32>
stablehlo.return %add : tensor<f32>
}
return %result : tensor<1024xf32>
}
// CHECK: func.func @reduce_inner
// CHECK: memref.alloca() : memref<1024xf32>
// CHECK: scf.for
// CHECK: vector.extract {{.*}} : vector<32xf32> from vector<1024x32xf32>
// CHECK: vector.reduction <add>, {{.*}} : vector<32xf32> into f32
// CHECK: memref.store {{.*}} : memref<1024xf32>
// CHECK: }
// CHECK: vector.transfer_read {{.*}} : memref<1024xf32>, vector<1024xf32>
// -----
func.func @reduce_middle(%input : tensor<1024x32x8xf32>, %init : tensor<f32>) -> tensor<1024x8xf32> {
%result = stablehlo.reduce(%input init: %init) across dimensions = [1] : (tensor<1024x32x8xf32>, tensor<f32>) -> tensor<1024x8xf32>
reducer(%arg0: tensor<f32>, %arg1: tensor<f32>) {
%add = arith.addf %arg0, %arg1 : tensor<f32>
stablehlo.return %add : tensor<f32>
}
return %result : tensor<1024x8xf32>
}
// CHECK: func.func @reduce_middle
// CHECK: memref.alloca() : memref<1024x8xf32>
// CHECK: scf.for
// CHECK: vector.extract {{.*}} : vector<8xf32> from vector<1024x32x8xf32>
// CHECK: scf.for
// CHECK: vector.extract %{{.*}} : vector<8xf32> from vector<1024x32x8xf32>
// CHECK: arith.addf {{.*}} : vector<8xf32>
// CHECK: scf.yield {{.*}} : vector<8xf32>
// CHECK: }
// CHECK: vector.transfer_write {{.*}} : vector<8xf32>, memref<1024x8xf32>
// CHECK: }
// CHECK: vector.transfer_read {{.*}} : memref<1024x8xf32>, vector<1024x8xf32>
// CHECK: vector.broadcast {{.*}} : f32 to vector<1024x8xf32>
// CHECK: arith.addf {{.*}} : vector<1024x8xf32>
// CHECK: }
// -----
func.func @reduce_outer_and_inner(%input : tensor<1024x32x8xf32>, %init : tensor<f32>) -> tensor<32xf32> {
%result = stablehlo.reduce(%input init: %init) across dimensions = [0, 2] : (tensor<1024x32x8xf32>, tensor<f32>) -> tensor<32xf32>
reducer(%arg0: tensor<f32>, %arg1: tensor<f32>) {
%add = arith.addf %arg0, %arg1 : tensor<f32>
stablehlo.return %add : tensor<f32>
}
return %result : tensor<32xf32>
}
// CHECK: func.func @reduce_outer_and_inner
// CHECK: %[[BUFFER0:.*]] = memref.alloca() : memref<32x8xf32>
// CHECK: scf.for
// CHECK: vector.extract %{{.*}} : vector<8xf32> from vector<1024x32x8xf32>
// CHECK: scf.for
// CHECK: vector.extract %{{.*}} : vector<8xf32> from vector<1024x32x8xf32>
// CHECK: arith.addf %{{.*}} : vector<8xf32>
// CHECK: scf.yield {{.*}} : vector<8xf32>
// CHECK: }
// CHECK: vector.transfer_write {{.*}} : vector<8xf32>, memref<32x8xf32>
// CHECK: }
// CHECK: %[[BUFFER1:.*]] = memref.alloca() : memref<32xf32>
// CHECK: scf.for
// CHECK: vector.transfer_read %[[BUFFER0]]{{.*}} : memref<32x8xf32>, vector<8xf32>
// CHECK: vector.reduction <add>, {{.*}} : vector<8xf32> into f32
// CHECK: memref.store %{{.*}}, %[[BUFFER1]]{{.*}} : memref<32xf32>
// CHECK: }
// CHECK: vector.transfer_read %[[BUFFER1]]{{.*}} : memref<32xf32>, vector<32xf32>
// CHECK: }
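For reference, a rough NumPy model of the two-phase structure the @reduce_outer_and_inner checks above describe: the non-minor reduced dimension is folded into a 32x8 scratch buffer (BUFFER0) using whole size-8 lanes, then each row of that buffer is horizontally reduced together with the init value into BUFFER1. The names below are illustrative only, not part of the commit.

import numpy as np

def reduce_outer_and_inner_model(x, init):
    # Phase 1 (BUFFER0): loop over the non-reduced dim, folding the non-minor
    # reduced dim while keeping whole size-8 minor lanes.
    buffer0 = np.empty(x.shape[1:], dtype=x.dtype)
    for j in range(x.shape[1]):
        acc = x[0, j].copy()                 # first vector<8xf32> slice
        for i in range(1, x.shape[0]):
            acc = acc + x[i, j]              # arith.addf on vector<8xf32>
        buffer0[j] = acc                     # vector.transfer_write
    # Phase 2 (BUFFER1): horizontally reduce each size-8 row, seeded with init.
    buffer1 = np.empty(x.shape[1], dtype=x.dtype)
    for j in range(x.shape[1]):
        buffer1[j] = init + buffer0[j].sum() # vector.reduction <add>
    return buffer1

x = np.random.rand(1024, 32, 8).astype(np.float32)
assert np.allclose(reduce_outer_and_inner_model(x, 1.0),
                   x.sum(axis=(0, 2)) + 1.0, rtol=1e-3)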

View File

@@ -0,0 +1,361 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/backends/cpu/codegen/tiled/transforms/vectorized_reduce_emitter.h"
#include <algorithm>
#include <array>
#include <cstdint>
#include <optional>
#include "absl/algorithm/container.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "xla/backends/cpu/codegen/tiled/transforms/lowering_utils.h"
namespace xla::cpu {
static absl::StatusOr<mlir::vector::CombiningKind> GetCombiningKind(
mlir::Block& reduction_body) {
mlir::Operation* op =
reduction_body.getTerminator()->getOperand(0).getDefiningOp();
if (!op) {
return absl::InternalError("No reduction combiner");
}
for (mlir::Value operand : op->getOperands()) {
if (operand.getDefiningOp()) {
return absl::InternalError("Non trivial reduction combiner");
}
}
if (auto kind = mlir::linalg::getCombinerOpKind(op)) {
return *kind;
}
return absl::InternalError("Unsupported reduction combiner");
}
static mlir::Value ExtractVector(mlir::OpBuilder& builder, mlir::Location loc,
mlir::Value source, mlir::ValueRange indices) {
return mlir::vector::ExtractOp::create(
builder, loc, source, llvm::map_to_vector(indices, [](mlir::Value idx) {
return mlir::OpFoldResult(idx);
}));
}
static void InsertVectorIntoBuffer(mlir::OpBuilder& builder, mlir::Location loc,
mlir::Value value,
mlir::TypedValue<mlir::MemRefType> buffer,
mlir::ValueRange indices) {
llvm::SmallVector<mlir::Value> padded_indices(indices);
while (padded_indices.size() < buffer.getType().getRank()) {
padded_indices.push_back(
builder.create<mlir::arith::ConstantIndexOp>(loc, 0));
}
if (mlir::isa<mlir::VectorType>(value.getType())) {
mlir::vector::TransferWriteOp::create(builder, loc, value, buffer,
padded_indices);
} else {
mlir::memref::StoreOp::create(builder, loc, value, buffer, padded_indices);
}
}
static mlir::TypedValue<mlir::VectorType> ExtractVectorFromBuffer(
mlir::OpBuilder& builder, mlir::Location loc,
mlir::TypedValue<mlir::MemRefType> buffer, mlir::ValueRange indices = {}) {
llvm::SmallVector<mlir::Value> padded_indices(indices);
while (padded_indices.size() < buffer.getType().getRank()) {
padded_indices.push_back(
builder.create<mlir::arith::ConstantIndexOp>(loc, 0));
}
mlir::VectorType vector_type = mlir::VectorType::get(
buffer.getType().getShape().drop_front(indices.size()),
buffer.getType().getElementType());
return mlir::vector::TransferReadOp::create(builder, loc, vector_type, buffer,
padded_indices,
/*padding=*/std::nullopt);
}
static std::array<llvm::SmallVector<mlir::Value>, 3> GetLoopBounds(
mlir::OpBuilder& builder, mlir::Location loc,
llvm::ArrayRef<int64_t> upper_bounds, int64_t lower_bound = 0) {
llvm::SmallVector<mlir::Value> lbs(
upper_bounds.size(),
builder.create<mlir::arith::ConstantIndexOp>(loc, lower_bound));
llvm::SmallVector<mlir::Value> ubs =
llvm::map_to_vector(upper_bounds, [&](int64_t size) -> mlir::Value {
return builder.create<mlir::arith::ConstantIndexOp>(loc, size);
});
llvm::SmallVector<mlir::Value> step(
upper_bounds.size(),
builder.create<mlir::arith::ConstantIndexOp>(loc, 1));
return {lbs, ubs, step};
}
mlir::Value VectorizeBody(mlir::OpBuilder& builder, mlir::Location loc,
mlir::Block& old_body, mlir::Value lhs_vector,
mlir::Value rhs_vector) {
mlir::IRMapping mapping;
mapping.map(old_body.getArgument(0), lhs_vector);
mapping.map(old_body.getArgument(1), rhs_vector);
for (mlir::Operation& op : old_body.without_terminator()) {
// TODO(willfroom): Check
// mlir::OpTrait::hasElementwiseMappableTraits
auto new_operands = llvm::map_to_vector(
op.getOperands(),
[&](mlir::Value operand) { return mapping.lookup(operand); });
mlir::Operation* new_op = op.create(
loc, op.getName(), {lhs_vector.getType()}, new_operands, op.getAttrs(),
op.getPropertiesStorage(), op.getSuccessors(), op.getNumRegions());
mapping.map(&op, new_op);
for (auto [old_res, new_res] :
llvm::zip(op.getResults(), new_op->getResults())) {
mapping.map(old_res, new_res);
}
builder.insert(new_op);
}
return mapping.lookup(old_body.getTerminator()->getOperand(0));
}
mlir::Value EmitNonMinorReduction(
mlir::OpBuilder& builder, mlir::Location loc, mlir::VectorType result_type,
mlir::TypedValue<mlir::VectorType> source_vector,
llvm::ArrayRef<int64_t> reduction_dims, mlir::Block& body,
bool minor_dim_reduced) {
mlir::VectorType source_vector_type = source_vector.getType();
int64_t rank = source_vector_type.getRank();
int64_t minor_dim = rank - 1;
int64_t minor_dim_size = source_vector_type.getDimSize(minor_dim);
llvm::SmallVector<int64_t> non_reduced_dims(rank);
absl::c_iota(non_reduced_dims, 0);
non_reduced_dims.erase(
std::remove_if(non_reduced_dims.begin(), non_reduced_dims.end(),
[&](int64_t dim) {
return absl::c_find(reduction_dims, dim) !=
reduction_dims.end();
}),
non_reduced_dims.end());
// The set of non-reduced dimensions that are not the minor dimension.
llvm::SmallVector<int64_t> non_reduced_non_minor_dims(non_reduced_dims);
if (auto itr = absl::c_find(non_reduced_non_minor_dims, minor_dim);
itr != non_reduced_non_minor_dims.end()) {
non_reduced_non_minor_dims.erase(itr);
}
// The set of reduced dimensions that are not the minor dimension.
llvm::SmallVector<int64_t> non_minor_reduced_dims(reduction_dims);
if (auto itr = absl::c_find(non_minor_reduced_dims, minor_dim);
itr != non_minor_reduced_dims.end()) {
non_minor_reduced_dims.erase(itr);
}
// The shape of the non-minor-reduced output.
llvm::SmallVector<int64_t> output_shape(result_type.getShape());
if (minor_dim_reduced) {
output_shape.push_back(minor_dim_size);
}
auto output_buffer_shape =
mlir::MemRefType::get(output_shape, result_type.getElementType());
auto buffer = CreateBufferOfShape(builder, loc, output_buffer_shape);
auto get_source_vector_dim_size = [&](llvm::ArrayRef<int64_t> dims) {
return llvm::map_to_vector(
dims, [&](int64_t dim) { return source_vector_type.getDimSize(dim); });
};
// Outer loop is non-minor non-reduced dimensions.
auto [lbs, ubs, step] = GetLoopBounds(
builder, loc, get_source_vector_dim_size(non_reduced_non_minor_dims));
mlir::scf::buildLoopNest(
builder, loc, lbs, ubs, step,
[&](mlir::OpBuilder& builder, mlir::Location loc,
mlir::ValueRange outer_induction_vars) {
auto [lbs, ubs, step] = GetLoopBounds(
builder, loc, get_source_vector_dim_size(non_minor_reduced_dims),
1);
llvm::SmallVector<mlir::Value> zeroth_step_indices(
rank - 1, mlir::arith::ConstantIndexOp::create(builder, loc, 0));
for (auto [idx, var] :
llvm::zip(non_reduced_non_minor_dims, outer_induction_vars)) {
zeroth_step_indices[idx] = var;
}
// Seed the accumulator with the first reduction step.
mlir::Value minor_accumulator =
ExtractVector(builder, loc, source_vector, zeroth_step_indices);
// The inner loop runs over the non-minor reduced dimensions.
mlir::scf::LoopNest loop_nest = mlir::scf::buildLoopNest(
builder, loc, lbs, ubs, step, minor_accumulator,
[&](mlir::OpBuilder& builder, mlir::Location loc,
mlir::ValueRange inner_induction_vars,
mlir::ValueRange minor_accumulator)
-> mlir::SmallVector<mlir::Value> {
llvm::SmallVector<mlir::Value> indices(rank - 1);
for (auto [idx, var] : llvm::zip(non_reduced_non_minor_dims,
outer_induction_vars)) {
indices[idx] = var;
}
for (auto [idx, var] :
llvm::zip(non_minor_reduced_dims, inner_induction_vars)) {
indices[idx] = var;
}
mlir::Value vector_slice =
ExtractVector(builder, loc, source_vector, indices);
return {VectorizeBody(builder, loc, body, vector_slice,
minor_accumulator.front())};
});
InsertVectorIntoBuffer(builder, loc, loop_nest.results.front(), buffer,
outer_induction_vars);
return;
});
// If the minor dimension is also reduced, the minor reduction step reads
// directly from the buffer, avoiding an extra vector -> subvector extraction.
if (minor_dim_reduced) {
return buffer;
}
return ExtractVectorFromBuffer(builder, loc, buffer);
}
mlir::TypedValue<mlir::VectorType> EmitMinorReduction(
mlir::OpBuilder& builder, mlir::Location loc, mlir::VectorType result_type,
mlir::Value input, mlir::Value init_value, mlir::Block& body) {
absl::StatusOr<mlir::vector::CombiningKind> kind_or = GetCombiningKind(body);
if (!kind_or.ok()) {
body.getParentOp()->emitRemark() << kind_or.status().ToString();
}
// TODO(willfroom): We could reuse the non-minor result buffer.
auto minor_result_buffer = CreateBufferOfShape(builder, loc, result_type);
auto maybe_input_buffer =
mlir::dyn_cast<mlir::TypedValue<mlir::MemRefType>>(input);
auto maybe_input_type =
llvm::TypeSwitch<mlir::Type, std::optional<mlir::ShapedType>>(
input.getType())
.Case<mlir::MemRefType>([&](auto op) { return input.getType(); })
.Case<mlir::VectorType>([&](auto op) { return input.getType(); })
.Default([&](auto op) { return std::nullopt; });
if (!maybe_input_type.has_value()) {
return nullptr;
}
int64_t minor_dim_size = maybe_input_type->getShape().back();
auto [lbs, ubs, step] = GetLoopBounds(builder, loc, result_type.getShape());
mlir::scf::buildLoopNest(
builder, loc, lbs, ubs, step,
[&](mlir::OpBuilder& builder, mlir::Location loc,
mlir::ValueRange induction_vars) {
mlir::Value vector_slice =
maybe_input_buffer
? ExtractVectorFromBuffer(builder, loc, maybe_input_buffer,
induction_vars)
: ExtractVector(builder, loc, input, induction_vars);
if (kind_or.ok()) {
// TODO(willfroom): Investigate tree-reduction to split the reduction
// op into natural sizes (2, 4, 8, 16, ...) and then remove the
// reassociation flag.
mlir::Value reduced_scalar =
builder.create<mlir::vector::ReductionOp>(
loc, *kind_or, vector_slice, init_value,
mlir::arith::FastMathFlags::reassoc);
InsertVectorIntoBuffer(builder, loc, reduced_scalar,
minor_result_buffer, induction_vars);
return;
}
auto [lbs, ubs, step] = GetLoopBounds(builder, loc, {minor_dim_size});
mlir::scf::LoopNest minor_reduction_loop = mlir::scf::buildLoopNest(
builder, loc, lbs, ubs, step, {init_value},
[&](mlir::OpBuilder& builder, mlir::Location loc,
mlir::ValueRange index, mlir::ValueRange carry_value)
-> mlir::SmallVector<mlir::Value> {
mlir::Value element =
ExtractVector(builder, loc, vector_slice, index);
return {VectorizeBody(builder, loc, body, element,
carry_value.front())};
});
InsertVectorIntoBuffer(builder, loc,
minor_reduction_loop.results.front(),
minor_result_buffer, induction_vars);
return;
});
return ExtractVectorFromBuffer(builder, loc, minor_result_buffer);
}
mlir::Value EmitVectorizedReduction(
mlir::OpBuilder& builder, mlir::Location loc, mlir::VectorType result_type,
mlir::TypedValue<mlir::VectorType> source, mlir::Value init_value,
llvm::ArrayRef<int64_t> reduction_dims, mlir::Block& body) {
int64_t rank = source.getType().getRank();
int64_t minor_dim = rank - 1;
bool minor_dim_reduced = reduction_dims.back() == minor_dim;
bool non_minor_dim_reduced = reduction_dims.size() > 1 || !minor_dim_reduced;
mlir::Value non_minor_result;
if (non_minor_dim_reduced) {
non_minor_result =
EmitNonMinorReduction(builder, loc, result_type, source, reduction_dims,
body, minor_dim_reduced);
}
if (!minor_dim_reduced) {
// The init value is applied inside the minor reduction loop; when that loop
// is not emitted, apply it here instead.
mlir::Value init_value_vector =
builder.create<mlir::vector::BroadcastOp>(loc, result_type, init_value);
return VectorizeBody(builder, loc, body, non_minor_result,
init_value_vector);
}
return EmitMinorReduction(builder, loc, result_type,
non_minor_result ? non_minor_result : source,
init_value, body);
}
} // namespace xla::cpu
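A rough Python model of the two code paths in EmitMinorReduction above: when GetCombiningKind recognizes the reducer body, each minor slice becomes a single vector.reduction seeded with the init value (reassociation allowed); otherwise the body is applied element by element in a scalar loop. The names below are illustrative only, not part of the commit.

import numpy as np

def minor_reduction_model(source, init, combine=None):
    # source: array whose last axis is reduced; init: scalar seed value.
    out = np.empty(source.shape[:-1], dtype=source.dtype)
    for idx in np.ndindex(*source.shape[:-1]):
        row = source[idx]
        if combine is None:
            # Fast path (modelled here for the add kind only): a recognized
            # combining kind lowers to one vector.reduction seeded with init;
            # reassociation is allowed, so the lane order is not guaranteed.
            out[idx] = init + row.sum()
        else:
            # Fallback: no recognized combiner, so loop over the scalars and
            # apply the reducer body one element at a time.
            acc = init
            for x in row:
                acc = combine(acc, x)
            out[idx] = acc
    return out

x = np.random.rand(1024, 32).astype(np.float32)
assert np.allclose(minor_reduction_model(x, 0.0), x.sum(axis=1), rtol=1e-4)
assert np.allclose(minor_reduction_model(x, 0.0, combine=np.maximum),
                   x.max(axis=1))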

View File

@@ -0,0 +1,47 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_BACKENDS_CPU_CODEGEN_TILED_TRANSFORMS_VECTORIZED_REDUCE_EMITTER_H_
#define XLA_BACKENDS_CPU_CODEGEN_TILED_TRANSFORMS_VECTORIZED_REDUCE_EMITTER_H_
#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Value.h"
namespace xla::cpu {
// Create a vectorized reduction of the given source vector.
//
// The implementation is as follows:
// 1. If only the most minor dimension is reduced, we convert it into a nested
// scf loop of horizontal reductions; if the body of the reduce is a single
// binary operation supported by ReductionOp we use that, otherwise we simply
// loop over the scalar values.
// 2. If the reduction dimensions do not include the most minor dimension, we
// loop over the reduction dimensions and apply the body with vectorized
// inputs.
// 3. If the dimensions are a combination of minor and non-minor dimensions, we
// simply apply strategy 2 followed by strategy 1.
mlir::Value EmitVectorizedReduction(
mlir::OpBuilder& builder, mlir::Location loc, mlir::VectorType result_type,
mlir::TypedValue<mlir::VectorType> source, mlir::Value init_value,
llvm::ArrayRef<int64_t> reduction_dims, mlir::Block& body);
} // namespace xla::cpu
#endif // XLA_BACKENDS_CPU_CODEGEN_TILED_TRANSFORMS_VECTORIZED_REDUCE_EMITTER_H_
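A rough NumPy model of the three strategies described in the comment above, under the assumption of an elementwise-mappable reducer body; the function and variable names below are hypothetical and not part of this change.

import numpy as np

def vectorized_reduce_model(source, init, dims, combine=np.add):
    dims = sorted(dims)
    minor = source.ndim - 1
    minor_reduced = minor in dims
    partial = source
    # Strategy 2: fold the non-minor reduced dimensions while keeping whole
    # minor-dimension lanes ("vectorized" applications of the body).
    for d in sorted((d for d in dims if d != minor), reverse=True):
        acc = np.take(partial, 0, axis=d)
        for i in range(1, partial.shape[d]):
            acc = combine(acc, np.take(partial, i, axis=d))
        partial = acc
    if not minor_reduced:
        # Only non-minor dimensions were reduced: apply init once at the end
        # as a broadcast (mirrors the vector.broadcast in the checks).
        return combine(partial, init)
    # Strategy 1: horizontal reduction of the minor dimension, seeded with
    # the init value; strategy 3 is simply strategy 2 followed by this step.
    out = np.empty(partial.shape[:-1], dtype=partial.dtype)
    for idx in np.ndindex(*partial.shape[:-1]):
        acc = init
        for x in partial[idx]:
            acc = combine(acc, x)
        out[idx] = acc
    return out

x = np.random.rand(8, 4, 2).astype(np.float32)
assert np.allclose(vectorized_reduce_model(x, 1.0, [0, 2]),
                   x.sum(axis=(0, 2)) + 1.0, rtol=1e-4)
assert np.allclose(vectorized_reduce_model(x, 1.0, [1]),
                   x.sum(axis=1) + 1.0, rtol=1e-4)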

View File

@@ -71,6 +71,7 @@ def TiledBufferInterface : OpInterface<"TiledBufferInterface"> {
def EntryFuncOp : XTile_Op<"entry_func", [
Symbol,
IsolatedFromAbove,
AutomaticAllocationScope,
FunctionOpInterface]>
{
let summary = "My custom entry function operation";