[XLA:CPU][XTile] Implement vectorized reduce.
PiperOrigin-RevId: 825027697
parent 11c00ca2db
commit de7a63363c
@@ -198,6 +198,7 @@ cc_library(
        "@llvm-project//mlir:VectorDialect",
        "@llvm-project//mlir:VectorToLLVM",
        "@llvm-project//mlir:VectorToSCF",
        "@llvm-project//mlir:VectorTransforms",
        "@local_tsl//tsl/profiler/lib:traceme",
        "@local_tsl//tsl/profiler/lib:traceme_encode",
        "@stablehlo//:stablehlo_passes",
@@ -61,6 +61,8 @@ limitations under the License.
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
@@ -244,15 +246,23 @@ static void AddTiledOptimizationPasses(mlir::OpPassManager& pm) {
// The input IR is from the xtile dialect which uses tensors that are converted
// first to the vector dialect and then to LLVM.
static void AddTiledLoweringPasses(mlir::OpPassManager& pm) {
  pm.addPass(CreateXTileToVectorPass());
  pm.addPass(CreateElementalTensorToVectorPass());
  pm.addPass(CreateShloToVectorPass());
  pm.addPass(CreateXTileToVectorPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(CreateRewriteDynamicVectorExtractPass());
  pm.addPass(CreateElementalTensorToVectorPass());
  pm.addPass(CreateLowerXTileEntryPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::vector::createLowerVectorMultiReductionPass(
          mlir::vector::VectorMultiReductionLowering::InnerParallel));
  pm.addPass(CreateTensorOpsToVectorPass());
  pm.addPass(cpu::createLowerToLLVMPass());
  pm.addPass(mlir::createConvertVectorToSCFPass(
      mlir::VectorTransferToSCFOptions().enableFullUnroll(false)));
  pm.addPass(mlir::createConvertVectorToLLVMPass());
  mlir::ConvertVectorToLLVMPassOptions options;
  options.vectorTransposeLowering =
      mlir::vector::VectorTransposeLowering::Shuffle1D;
  pm.addPass(mlir::createConvertVectorToLLVMPass(options));

  pm.addPass(mlir::createConvertComplexToStandardPass());
  pm.addPass(mlir::memref::createExpandStridedMetadataPass());
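For context, a minimal sketch (not part of this commit; a hypothetical driver) of how these pipeline-building helpers might be driven end to end. AddTiledOptimizationPasses and AddTiledLoweringPasses are file-local (static), so such a driver would live in the same translation unit:

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"

// Hypothetical driver: assembles the tiled pipeline and runs it, taking the
// module from the xtile dialect through the vector dialect down to LLVM.
static mlir::LogicalResult RunTiledPipeline(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  AddTiledOptimizationPasses(pm);  // optimization passes defined above
  AddTiledLoweringPasses(pm);      // the lowering sequence extended here
  return pm.run(module);
}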
@@ -269,6 +269,144 @@ class XtileLoweringTest(absltest.TestCase):
        maxulp=5,
    )

  def test_reduction_add_inner(self):
    ir = """
      module @reduction_add_inner {
        xtile.entry_func @reduction_add_inner(
            %input: memref<1024x32xf32>,
            %init: memref<f32>,
            %output: memref<1024xf32>,
            %tile_id: index) attributes {xtile.tiling_info = #xtile.tiling_info<tile_count:128, tiles_per_workgroup:32>} {
          %c_0 = arith.constant 0 : index
          %c_8 = arith.constant 8 : index
          %init_tile = xtile.extract %init[][][] : memref<f32> -> tensor<f32>
          %index = arith.muli %tile_id, %c_8 : index
          %input_tile = xtile.extract %input[%index, %c_0][8, 32][1, 1] : memref<1024x32xf32> -> tensor<8x32xf32>
          %result = stablehlo.reduce(%input_tile init: %init_tile)
              across dimensions = [1]
              : (tensor<8x32xf32>, tensor<f32>) -> tensor<8xf32>
            reducer(%arg0: tensor<f32>, %arg1: tensor<f32>) {
              %add = arith.addf %arg0, %arg1 : tensor<f32>
              stablehlo.return %add : tensor<f32>
            }
          xtile.insert %result into %output[%index][8][1] : tensor<8xf32> -> memref<1024xf32>
          xtile.return
        }
      }
    """

    compare_kernel(
        ir,
        "reduction_add_inner",
        4,
        [(1024, 32), (1,)],
        (1024,),
        np.int32,
        lambda input, init: np.sum(input, axis=1) + init,
    )

  def test_reduction_add_outer(self):
    ir = """
      module @reduction_add_outer {
        xtile.entry_func @reduction_add_outer(
            %input: memref<1024x32xf32>,
            %init: memref<f32>,
            %output: memref<32xf32>,
            %tile_id: index) attributes {xtile.tiling_info = #xtile.tiling_info<tile_count:4, tiles_per_workgroup:1>} {
          %c_0 = arith.constant 0 : index
          %c_8 = arith.constant 8 : index
          %init_tile = xtile.extract %init[][][] : memref<f32> -> tensor<f32>
          %index = arith.muli %tile_id, %c_8 : index
          %input_tile = xtile.extract %input[%c_0, %index][1024, 8][1, 1] : memref<1024x32xf32> -> tensor<1024x8xf32>
          %result = stablehlo.reduce(%input_tile init: %init_tile)
              across dimensions = [0]
              : (tensor<1024x8xf32>, tensor<f32>) -> tensor<8xf32>
            reducer(%arg0: tensor<f32>, %arg1: tensor<f32>) {
              %add = arith.addf %arg0, %arg1 : tensor<f32>
              stablehlo.return %add : tensor<f32>
            }
          xtile.insert %result into %output[%index][8][1] : tensor<8xf32> -> memref<32xf32>
          xtile.return
        }
      }
    """

    compare_kernel(
        ir,
        "reduction_add_outer",
        4,
        [(1024, 32), (1,)],
        (32,),
        np.float32,
        lambda input, init: np.sum(input, axis=0) + init,
    )

  def test_reduction_middle(self):
    ir = """
      module @reduction_add_middle {
        xtile.entry_func @reduction_add_middle(
            %input: memref<8x4x2xf32>,
            %init: memref<f32>,
            %output: memref<8x2xf32>,
            %tile_id: index) attributes {xtile.tiling_info = #xtile.tiling_info<tile_count:1, tiles_per_workgroup:1>} {
          %init_val = xtile.extract %init[][][] : memref<f32> -> tensor<f32>
          %input_tile = xtile.extract %input[%tile_id, %tile_id, %tile_id][8, 4, 2][1, 1, 1] : memref<8x4x2xf32> -> tensor<8x4x2xf32>
          %result = stablehlo.reduce(%input_tile init: %init_val)
              across dimensions = [1]
              : (tensor<8x4x2xf32>, tensor<f32>) -> tensor<8x2xf32>
            reducer(%arg0: tensor<f32>, %arg1: tensor<f32>) {
              %add = arith.addf %arg0, %arg1 : tensor<f32>
              stablehlo.return %add : tensor<f32>
            }
          xtile.insert %result into %output[%tile_id, %tile_id][8, 2][1, 1] : tensor<8x2xf32> -> memref<8x2xf32>
          xtile.return
        }
      }
    """

    compare_kernel(
        ir,
        "reduction_add_middle",
        1,
        [(8, 4, 2), (1,)],
        (8, 2),
        np.float32,
        lambda input, init: np.sum(input, axis=1) + init,
    )

  def test_reduction_outer_inner(self):
    ir = """
      module @reduction_add_outer_inner {
        xtile.entry_func @reduction_add_outer_inner(
            %input: memref<8x4x2xf32>,
            %init: memref<f32>,
            %output: memref<4xf32>,
            %tile_id: index) attributes {xtile.tiling_info = #xtile.tiling_info<tile_count:1, tiles_per_workgroup:1>} {
          %init_val = xtile.extract %init[][][] : memref<f32> -> tensor<f32>
          %input_tile = xtile.extract %input[%tile_id, %tile_id, %tile_id][8, 4, 2][1, 1, 1] : memref<8x4x2xf32> -> tensor<8x4x2xf32>
          %result = stablehlo.reduce(%input_tile init: %init_val)
              across dimensions = [0, 2]
              : (tensor<8x4x2xf32>, tensor<f32>) -> tensor<4xf32>
            reducer(%arg0: tensor<f32>, %arg1: tensor<f32>) {
              %add = arith.addf %arg0, %arg1 : tensor<f32>
              stablehlo.return %add : tensor<f32>
            }
          xtile.insert %result into %output[%tile_id][4][1] : tensor<4xf32> -> memref<4xf32>
          xtile.return
        }
      }
    """

    compare_kernel(
        ir,
        "reduction_add_outer_inner",
        1,
        [(8, 4, 2), (1,)],
        (4,),
        np.float32,
        lambda input, init: np.sum(input, axis=(0, 2)) + init,
    )


if __name__ == "__main__":
  absltest.main()
@@ -36,7 +36,9 @@ cc_library(
    hdrs = ["lowering_utils.h"],
    visibility = ["//visibility:private"],
    deps = [
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MemRefDialect",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:VectorDialect",
@@ -57,10 +59,14 @@ cc_library(
    deps = [
        ":lowering_utils",
        ":passes_inc_gen",
        ":vectorized_reduce_emitter",
        "//xla/backends/cpu/codegen/emitters/ir:xla_cpu",
        "//xla/codegen/emitters/ir:xla",
        "//xla/codegen/xtile/ir:xtile",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@llvm-project//llvm:Support",
@@ -72,6 +78,7 @@ cc_library(
        "@llvm-project//mlir:FuncTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LLVMDialect",
        "@llvm-project//mlir:LinalgTransforms",
        "@llvm-project//mlir:MathDialect",
        "@llvm-project//mlir:MathOpsIncGen",
        "@llvm-project//mlir:MemRefDialect",
@@ -86,3 +93,24 @@ cc_library(
        "@stablehlo//:stablehlo_ops",
    ],
)

cc_library(
    name = "vectorized_reduce_emitter",
    srcs = ["vectorized_reduce_emitter.cc"],
    hdrs = ["vectorized_reduce_emitter.h"],
    deps = [
        ":lowering_utils",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LinalgTransforms",
        "@llvm-project//mlir:MemRefDialect",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:VectorDialect",
    ],
)
@@ -15,19 +15,20 @@ limitations under the License.

#include "xla/backends/cpu/codegen/tiled/transforms/lowering_utils.h"

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

namespace xla::cpu {

mlir::VectorType GetVectorType(mlir::RankedTensorType tensor_type) {
  return mlir::VectorType::get(tensor_type.getShape(),
                               tensor_type.getElementType());
mlir::VectorType GetVectorType(mlir::ShapedType type) {
  return mlir::VectorType::get(type.getShape(), type.getElementType());
}

mlir::TypedValue<mlir::VectorType> CastToVector(mlir::OpBuilder& builder,
@@ -45,9 +46,8 @@ mlir::TypedValue<mlir::VectorType> CastToVector(mlir::OpBuilder& builder,
  return mlir::cast<mlir::TypedValue<mlir::VectorType>>(cast_op.getResult(0));
}

mlir::RankedTensorType GetTensorType(mlir::VectorType vector_type) {
  return mlir::RankedTensorType::get(vector_type.getShape(),
                                     vector_type.getElementType());
mlir::RankedTensorType GetTensorType(mlir::ShapedType type) {
  return mlir::RankedTensorType::get(type.getShape(), type.getElementType());
}

mlir::TypedValue<mlir::RankedTensorType> CastToTensor(mlir::OpBuilder& builder,
@@ -66,4 +66,12 @@ mlir::TypedValue<mlir::RankedTensorType> CastToTensor(mlir::OpBuilder& builder,
      cast_op.getResult(0));
}

mlir::TypedValue<mlir::MemRefType> CreateBufferOfShape(mlir::OpBuilder& builder,
                                                       mlir::Location loc,
                                                       mlir::ShapedType shape) {
  mlir::MemRefType memrefType =
      mlir::MemRefType::get(shape.getShape(), shape.getElementType());
  return mlir::memref::AllocaOp::create(builder, loc, memrefType);
}

}  // namespace xla::cpu
@@ -18,13 +18,14 @@ limitations under the License.

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Value.h"

namespace xla::cpu {

// Get the vector type that has the same shape and element type as the tensor
// type.
mlir::VectorType GetVectorType(mlir::RankedTensorType tensor_type);
mlir::VectorType GetVectorType(mlir::ShapedType tensor_type);

// Cast the input to a vector value.
// If the input is a scalar it will be simply constructed as a
@@ -36,7 +37,7 @@ mlir::TypedValue<mlir::VectorType> CastToVector(mlir::OpBuilder& builder,

// Get the tensor type that has the same shape and element type as the vector
// type.
mlir::RankedTensorType GetTensorType(mlir::VectorType vector_type);
mlir::RankedTensorType GetTensorType(mlir::ShapedType vector_type);

// Cast the input to a tensor value.
// If the input is a scalar it will be simply constructed as a
@@ -46,6 +47,10 @@ mlir::RankedTensorType GetTensorType(mlir::VectorType vector_type);
mlir::TypedValue<mlir::RankedTensorType> CastToTensor(mlir::OpBuilder& builder,
                                                      mlir::Value input);

mlir::TypedValue<mlir::MemRefType> CreateBufferOfShape(mlir::OpBuilder& builder,
                                                       mlir::Location loc,
                                                       mlir::ShapedType shape);

}  // namespace xla::cpu

#endif  // XLA_BACKENDS_CPU_CODEGEN_TILED_TRANSFORMS_LOWERING_UTILS_H_
@@ -49,6 +49,8 @@ def ShloToVectorPass : Pass<"xtile-cpu-shlo-to-vector", "mlir::ModuleOp"> {
    "mlir::tensor::TensorDialect",
    "mlir::vector::VectorDialect",
    "mlir::stablehlo::StablehloDialect",
    "mlir::scf::SCFDialect",
    "mlir::memref::MemRefDialect",
  ];
}
@@ -18,20 +18,25 @@ limitations under the License.
#include <memory>
#include <utility>

#include "absl/algorithm/container.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // IWYU pragma: keep
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
@@ -41,6 +46,7 @@ limitations under the License.
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/backends/cpu/codegen/tiled/transforms/lowering_utils.h"
#include "xla/backends/cpu/codegen/tiled/transforms/passes.h"
#include "xla/backends/cpu/codegen/tiled/transforms/vectorized_reduce_emitter.h"

namespace xla::cpu {
@@ -218,6 +224,48 @@ struct LowerTranspose : mlir::OpRewritePattern<mlir::stablehlo::TransposeOp> {
  }
};

// Lower stablehlo.reduce to vector operations.
struct LowerReduce : mlir::OpRewritePattern<mlir::stablehlo::ReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::stablehlo::ReduceOp op,
      mlir::PatternRewriter& rewriter) const override {
    if (op.getNumResults() != 1) {
      return rewriter.notifyMatchFailure(
          op, "reduce op with multiple results is not supported");
    }

    mlir::TypedValue<mlir::VectorType> source_vector =
        CastToVector(rewriter, op.getInputs().front());
    mlir::VectorType source_vector_type = source_vector.getType();

    mlir::Value init_value = rewriter.create<mlir::tensor::ExtractOp>(
        op->getLoc(), source_vector_type.getElementType(),
        op.getInitValues().front());

    mlir::Value result_tensor = op.getResult(0);
    auto result_tensor_type =
        mlir::cast<mlir::RankedTensorType>(result_tensor.getType());
    auto result_vector_type = GetVectorType(result_tensor_type);

    // Ensure the reduction dimensions are sorted so we can easily check if the
    // minor dimension is reduced.
    llvm::SmallVector<int64_t> reduction_dims(op.getDimensions());
    absl::c_sort(reduction_dims);

    mlir::Value reduced_vector = EmitVectorizedReduction(
        rewriter, op->getLoc(), result_vector_type, source_vector, init_value,
        reduction_dims, op.getBody().front());

    rewriter.replaceOp(op, CastToTensor(rewriter, reduced_vector));

    return mlir::success();
  }
};

class ShloToVectorPass : public impl::ShloToVectorPassBase<ShloToVectorPass> {
 public:
  using ShloToVectorPassBase::ShloToVectorPassBase;
@@ -225,7 +273,7 @@ class ShloToVectorPass : public impl::ShloToVectorPassBase<ShloToVectorPass> {
  void runOnOperation() override {
    mlir::MLIRContext* context = &getContext();
    mlir::RewritePatternSet patterns(context);
    patterns.add<LowerTranspose, LowerDotGeneral>(context);
    patterns.add<LowerTranspose, LowerDotGeneral, LowerReduce>(context);
    if (mlir::failed(
            mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      signalPassFailure();
@@ -37,3 +37,107 @@ func.func @dot_scalar_output(%lhs : tensor<1024xf32>, %rhs : tensor<1024xf32>) -
// CHECK: return %[[RESULT_TENSOR]] : tensor<f32>
  return %result : tensor<f32>
}

// -----

func.func @reduce_outer(%input : tensor<1024x32xf32>, %init : tensor<f32>) -> tensor<32xf32> {
  %result = stablehlo.reduce(%input init: %init) across dimensions = [0] : (tensor<1024x32xf32>, tensor<f32>) -> tensor<32xf32>
    reducer(%arg0: tensor<f32>, %arg1: tensor<f32>) {
      %add = arith.addf %arg0, %arg1 : tensor<f32>
      stablehlo.return %add : tensor<f32>
    }
  return %result : tensor<32xf32>
}

// CHECK: func.func @reduce_outer
// CHECK: memref.alloca() : memref<32xf32>
// CHECK: vector.extract %{{.*}} : vector<32xf32> from vector<1024x32xf32>
// CHECK: scf.for
// CHECK: vector.extract %{{.*}} : vector<32xf32> from vector<1024x32xf32>
// CHECK: arith.addf {{.*}} : vector<32xf32>
// CHECK: scf.yield %{{.*}} : vector<32xf32>
// CHECK: }
// CHECK: vector.transfer_write %{{.*}} : vector<32xf32>, memref<32xf32>
// CHECK: vector.transfer_read %{{.*}} : memref<32xf32>, vector<32xf32>
// CHECK: vector.broadcast %{{.*}} : f32 to vector<32xf32>
// CHECK: arith.addf {{.*}} : vector<32xf32>

// -----

func.func @reduce_inner(%input : tensor<1024x32xf32>, %init : tensor<f32>) -> tensor<1024xf32> {
  %result = stablehlo.reduce(%input init: %init) across dimensions = [1] : (tensor<1024x32xf32>, tensor<f32>) -> tensor<1024xf32>
    reducer(%arg0: tensor<f32>, %arg1: tensor<f32>) {
      %add = arith.addf %arg0, %arg1 : tensor<f32>
      stablehlo.return %add : tensor<f32>
    }
  return %result : tensor<1024xf32>
}

// CHECK: func.func @reduce_inner
// CHECK: memref.alloca() : memref<1024xf32>
// CHECK: scf.for
// CHECK: vector.extract {{.*}} : vector<32xf32> from vector<1024x32xf32>
// CHECK: vector.reduction <add>, {{.*}} : vector<32xf32> into f32
// CHECK: memref.store {{.*}} : memref<1024xf32>
// CHECK: }
// CHECK: vector.transfer_read {{.*}} : memref<1024xf32>, vector<1024xf32>

// -----

func.func @reduce_middle(%input : tensor<1024x32x8xf32>, %init : tensor<f32>) -> tensor<1024x8xf32> {
  %result = stablehlo.reduce(%input init: %init) across dimensions = [1] : (tensor<1024x32x8xf32>, tensor<f32>) -> tensor<1024x8xf32>
    reducer(%arg0: tensor<f32>, %arg1: tensor<f32>) {
      %add = arith.addf %arg0, %arg1 : tensor<f32>
      stablehlo.return %add : tensor<f32>
    }
  return %result : tensor<1024x8xf32>
}

// CHECK: func.func @reduce_middle
// CHECK: memref.alloca() : memref<1024x8xf32>
// CHECK: scf.for
// CHECK: vector.extract {{.*}} : vector<8xf32> from vector<1024x32x8xf32>
// CHECK: scf.for
// CHECK: vector.extract %{{.*}} : vector<8xf32> from vector<1024x32x8xf32>
// CHECK: arith.addf {{.*}} : vector<8xf32>
// CHECK: scf.yield {{.*}} : vector<8xf32>
// CHECK: }
// CHECK: vector.transfer_write {{.*}} : vector<8xf32>, memref<1024x8xf32>
// CHECK: }
// CHECK: vector.transfer_read {{.*}} : memref<1024x8xf32>, vector<1024x8xf32>
// CHECK: vector.broadcast {{.*}} : f32 to vector<1024x8xf32>
// CHECK: arith.addf {{.*}} : vector<1024x8xf32>
// CHECK: }

// -----

func.func @reduce_outer_and_inner(%input : tensor<1024x32x8xf32>, %init : tensor<f32>) -> tensor<32xf32> {
  %result = stablehlo.reduce(%input init: %init) across dimensions = [0, 2] : (tensor<1024x32x8xf32>, tensor<f32>) -> tensor<32xf32>
    reducer(%arg0: tensor<f32>, %arg1: tensor<f32>) {
      %add = arith.addf %arg0, %arg1 : tensor<f32>
      stablehlo.return %add : tensor<f32>
    }
  return %result : tensor<32xf32>
}

// CHECK: func.func @reduce_outer_and_inner
// CHECK: %[[BUFFER0:.*]] = memref.alloca() : memref<32x8xf32>
// CHECK: scf.for
// CHECK: vector.extract %{{.*}} : vector<8xf32> from vector<1024x32x8xf32>
// CHECK: scf.for
// CHECK: vector.extract %{{.*}} : vector<8xf32> from vector<1024x32x8xf32>
// CHECK: arith.addf %{{.*}} : vector<8xf32>
// CHECK: scf.yield {{.*}} : vector<8xf32>
// CHECK: }
// CHECK: vector.transfer_write {{.*}} : vector<8xf32>, memref<32x8xf32>
// CHECK: }
// CHECK: %[[BUFFER1:.*]] = memref.alloca() : memref<32xf32>
// CHECK: scf.for
// CHECK: vector.transfer_read %[[BUFFER0]]{{.*}} : memref<32x8xf32>, vector<8xf32>
// CHECK: vector.reduction <add>, {{.*}} : vector<8xf32> into f32
// CHECK: memref.store %{{.*}}, %[[BUFFER1]]{{.*}} : memref<32xf32>
// CHECK: }
// CHECK: vector.transfer_read %[[BUFFER1]]{{.*}} : memref<32xf32>, vector<32xf32>
// CHECK: }
third_party/xla/xla/backends/cpu/codegen/tiled/transforms/vectorized_reduce_emitter.cc (new file, 361 lines)
@@ -0,0 +1,361 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/cpu/codegen/tiled/transforms/vectorized_reduce_emitter.h"

#include <algorithm>
#include <array>
#include <cstdint>
#include <optional>

#include "absl/algorithm/container.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "xla/backends/cpu/codegen/tiled/transforms/lowering_utils.h"

namespace xla::cpu {

static absl::StatusOr<mlir::vector::CombiningKind> GetCombiningKind(
    mlir::Block& reduction_body) {
  mlir::Operation* op =
      reduction_body.getTerminator()->getOperand(0).getDefiningOp();
  if (!op) {
    return absl::InternalError("No reduction combiner");
  }

  for (mlir::Value operand : op->getOperands()) {
    if (operand.getDefiningOp()) {
      return absl::InternalError("Non-trivial reduction combiner");
    }
  }

  if (auto kind = mlir::linalg::getCombinerOpKind(op)) {
    return *kind;
  }

  return absl::InternalError("Unsupported reduction combiner");
}

static mlir::Value ExtractVector(mlir::OpBuilder& builder, mlir::Location loc,
                                 mlir::Value source, mlir::ValueRange indices) {
  return mlir::vector::ExtractOp::create(
      builder, loc, source, llvm::map_to_vector(indices, [](mlir::Value idx) {
        return mlir::OpFoldResult(idx);
      }));
}

static void InsertVectorIntoBuffer(mlir::OpBuilder& builder, mlir::Location loc,
                                   mlir::Value value,
                                   mlir::TypedValue<mlir::MemRefType> buffer,
                                   mlir::ValueRange indices) {
  llvm::SmallVector<mlir::Value> padded_indices(indices);
  while (padded_indices.size() < buffer.getType().getRank()) {
    padded_indices.push_back(
        builder.create<mlir::arith::ConstantIndexOp>(loc, 0));
  }

  if (mlir::isa<mlir::VectorType>(value.getType())) {
    mlir::vector::TransferWriteOp::create(builder, loc, value, buffer,
                                          padded_indices);
  } else {
    mlir::memref::StoreOp::create(builder, loc, value, buffer, padded_indices);
  }
}

static mlir::TypedValue<mlir::VectorType> ExtractVectorFromBuffer(
    mlir::OpBuilder& builder, mlir::Location loc,
    mlir::TypedValue<mlir::MemRefType> buffer, mlir::ValueRange indices = {}) {
  llvm::SmallVector<mlir::Value> padded_indices(indices);
  while (padded_indices.size() < buffer.getType().getRank()) {
    padded_indices.push_back(
        builder.create<mlir::arith::ConstantIndexOp>(loc, 0));
  }
  mlir::VectorType vector_type = mlir::VectorType::get(
      buffer.getType().getShape().drop_front(indices.size()),
      buffer.getType().getElementType());
  return mlir::vector::TransferReadOp::create(builder, loc, vector_type, buffer,
                                              padded_indices,
                                              /*padding=*/std::nullopt);
}

static std::array<llvm::SmallVector<mlir::Value>, 3> GetLoopBounds(
    mlir::OpBuilder& builder, mlir::Location loc,
    llvm::ArrayRef<int64_t> upper_bounds, int64_t lower_bound = 0) {
  llvm::SmallVector<mlir::Value> lbs(
      upper_bounds.size(),
      builder.create<mlir::arith::ConstantIndexOp>(loc, lower_bound));
  llvm::SmallVector<mlir::Value> ubs =
      llvm::map_to_vector(upper_bounds, [&](int64_t size) -> mlir::Value {
        return builder.create<mlir::arith::ConstantIndexOp>(loc, size);
      });
  llvm::SmallVector<mlir::Value> step(
      upper_bounds.size(),
      builder.create<mlir::arith::ConstantIndexOp>(loc, 1));
  return {lbs, ubs, step};
}

mlir::Value VectorizeBody(mlir::OpBuilder& builder, mlir::Location loc,
                          mlir::Block& old_body, mlir::Value lhs_vector,
                          mlir::Value rhs_vector) {
  mlir::IRMapping mapping;

  mapping.map(old_body.getArgument(0), lhs_vector);
  mapping.map(old_body.getArgument(1), rhs_vector);

  for (mlir::Operation& op : old_body.without_terminator()) {
    // TODO(willfroom): Check mlir::OpTrait::hasElementwiseMappableTraits.
    auto new_operands = llvm::map_to_vector(
        op.getOperands(),
        [&](mlir::Value operand) { return mapping.lookup(operand); });
    mlir::Operation* new_op = op.create(
        loc, op.getName(), {lhs_vector.getType()}, new_operands, op.getAttrs(),
        op.getPropertiesStorage(), op.getSuccessors(), op.getNumRegions());
    mapping.map(&op, new_op);
    for (auto [old_res, new_res] :
         llvm::zip(op.getResults(), new_op->getResults())) {
      mapping.map(old_res, new_res);
    }
    builder.insert(new_op);
  }
  return mapping.lookup(old_body.getTerminator()->getOperand(0));
}

mlir::Value EmitNonMinorReduction(
    mlir::OpBuilder& builder, mlir::Location loc, mlir::VectorType result_type,
    mlir::TypedValue<mlir::VectorType> source_vector,
    llvm::ArrayRef<int64_t> reduction_dims, mlir::Block& body,
    bool minor_dim_reduced) {
  mlir::VectorType source_vector_type = source_vector.getType();
  int64_t rank = source_vector_type.getRank();
  int64_t minor_dim = rank - 1;
  int64_t minor_dim_size = source_vector_type.getDimSize(minor_dim);
  llvm::SmallVector<int64_t> non_reduced_dims(rank);
  absl::c_iota(non_reduced_dims, 0);
  non_reduced_dims.erase(
      std::remove_if(non_reduced_dims.begin(), non_reduced_dims.end(),
                     [&](int64_t dim) {
                       return absl::c_find(reduction_dims, dim) !=
                              reduction_dims.end();
                     }),
      non_reduced_dims.end());

  // The set of non-reduced dimensions that are not the minor dimension.
  llvm::SmallVector<int64_t> non_reduced_non_minor_dims(non_reduced_dims);
  if (auto itr = absl::c_find(non_reduced_non_minor_dims, minor_dim);
      itr != non_reduced_non_minor_dims.end()) {
    non_reduced_non_minor_dims.erase(itr);
  }

  // The set of reduced dimensions that are not the minor dimension.
  llvm::SmallVector<int64_t> non_minor_reduced_dims(reduction_dims);
  if (auto itr = absl::c_find(non_minor_reduced_dims, minor_dim);
      itr != non_minor_reduced_dims.end()) {
    non_minor_reduced_dims.erase(itr);
  }

  // The shape of the non-minor-reduced output.
  llvm::SmallVector<int64_t> output_shape(result_type.getShape());
  if (minor_dim_reduced) {
    output_shape.push_back(minor_dim_size);
  }
  auto output_buffer_shape =
      mlir::MemRefType::get(output_shape, result_type.getElementType());
  auto buffer = CreateBufferOfShape(builder, loc, output_buffer_shape);

  auto get_source_vector_dim_size = [&](llvm::ArrayRef<int64_t> dims) {
    return llvm::map_to_vector(
        dims, [&](int64_t dim) { return source_vector_type.getDimSize(dim); });
  };

  // Outer loop is non-minor non-reduced dimensions.
  auto [lbs, ubs, step] = GetLoopBounds(
      builder, loc, get_source_vector_dim_size(non_reduced_non_minor_dims));

  mlir::scf::buildLoopNest(
      builder, loc, lbs, ubs, step,
      [&](mlir::OpBuilder& builder, mlir::Location loc,
          mlir::ValueRange outer_induction_vars) {
        auto [lbs, ubs, step] = GetLoopBounds(
            builder, loc, get_source_vector_dim_size(non_minor_reduced_dims),
            1);

        llvm::SmallVector<mlir::Value> zeroth_step_indices(
            rank - 1, mlir::arith::ConstantIndexOp::create(builder, loc, 0));
        for (auto [idx, var] :
             llvm::zip(non_reduced_non_minor_dims, outer_induction_vars)) {
          zeroth_step_indices[idx] = var;
        }
        // Get the first iteration.
        mlir::Value minor_accumulator =
            ExtractVector(builder, loc, source_vector, zeroth_step_indices);
        // Inner loop is the non-minor reduced dimension.
        mlir::scf::LoopNest loop_nest = mlir::scf::buildLoopNest(
            builder, loc, lbs, ubs, step, minor_accumulator,
            [&](mlir::OpBuilder& builder, mlir::Location loc,
                mlir::ValueRange inner_induction_vars,
                mlir::ValueRange minor_accumulator)
                -> mlir::SmallVector<mlir::Value> {
              llvm::SmallVector<mlir::Value> indices(rank - 1);
              for (auto [idx, var] : llvm::zip(non_reduced_non_minor_dims,
                                               outer_induction_vars)) {
                indices[idx] = var;
              }
              for (auto [idx, var] :
                   llvm::zip(non_minor_reduced_dims, inner_induction_vars)) {
                indices[idx] = var;
              }

              mlir::Value vector_slice =
                  ExtractVector(builder, loc, source_vector, indices);

              return {VectorizeBody(builder, loc, body, vector_slice,
                                    minor_accumulator.front())};
            });

        InsertVectorIntoBuffer(builder, loc, loop_nest.results.front(), buffer,
                               outer_induction_vars);
        return;
      });

  // If the minor dimension is also reduced then it extracts directly from the
  // buffer to avoid the additional vector -> subvector operation.
  if (minor_dim_reduced) {
    return buffer;
  }

  return ExtractVectorFromBuffer(builder, loc, buffer);
}

mlir::TypedValue<mlir::VectorType> EmitMinorReduction(
    mlir::OpBuilder& builder, mlir::Location loc, mlir::VectorType result_type,
    mlir::Value input, mlir::Value init_value, mlir::Block& body) {
  absl::StatusOr<mlir::vector::CombiningKind> kind_or = GetCombiningKind(body);
  if (!kind_or.ok()) {
    body.getParentOp()->emitRemark() << kind_or.status().ToString();
  }

  // TODO(willfroom): We could reuse the non-minor result buffer.
  auto minor_result_buffer = CreateBufferOfShape(builder, loc, result_type);
  auto maybe_input_buffer =
      mlir::dyn_cast<mlir::TypedValue<mlir::MemRefType>>(input);

  auto maybe_input_type =
      llvm::TypeSwitch<mlir::Type, std::optional<mlir::ShapedType>>(
          input.getType())
          .Case<mlir::MemRefType>([&](auto op) { return input.getType(); })
          .Case<mlir::VectorType>([&](auto op) { return input.getType(); })
          .Default([&](auto op) { return std::nullopt; });

  if (!maybe_input_type.has_value()) {
    return nullptr;
  }

  int64_t minor_dim_size = maybe_input_type->getShape().back();

  auto [lbs, ubs, step] = GetLoopBounds(builder, loc, result_type.getShape());

  mlir::scf::buildLoopNest(
      builder, loc, lbs, ubs, step,
      [&](mlir::OpBuilder& builder, mlir::Location loc,
          mlir::ValueRange induction_vars) {
        mlir::Value vector_slice =
            maybe_input_buffer
                ? ExtractVectorFromBuffer(builder, loc, maybe_input_buffer,
                                          induction_vars)
                : ExtractVector(builder, loc, input, induction_vars);

        if (kind_or.ok()) {
          // TODO(willfroom): Investigate tree-reduction to split the reduction
          // op into natural sizes (2, 4, 8, 16, ...) and then remove the
          // reassociation flag.
          mlir::Value reduced_scalar =
              builder.create<mlir::vector::ReductionOp>(
                  loc, *kind_or, vector_slice, init_value,
                  mlir::arith::FastMathFlags::reassoc);
          InsertVectorIntoBuffer(builder, loc, reduced_scalar,
                                 minor_result_buffer, induction_vars);
          return;
        }

        auto [lbs, ubs, step] = GetLoopBounds(builder, loc, {minor_dim_size});
        mlir::scf::LoopNest minor_reduction_loop = mlir::scf::buildLoopNest(
            builder, loc, lbs, ubs, step, {init_value},
            [&](mlir::OpBuilder& builder, mlir::Location loc,
                mlir::ValueRange index, mlir::ValueRange carry_value)
                -> mlir::SmallVector<mlir::Value> {
              mlir::Value element =
                  ExtractVector(builder, loc, vector_slice, index);
              return {VectorizeBody(builder, loc, body, element,
                                    carry_value.front())};
            });

        InsertVectorIntoBuffer(builder, loc,
                               minor_reduction_loop.results.front(),
                               minor_result_buffer, induction_vars);
        return;
      });

  return ExtractVectorFromBuffer(builder, loc, minor_result_buffer);
}

mlir::Value EmitVectorizedReduction(
    mlir::OpBuilder& builder, mlir::Location loc, mlir::VectorType result_type,
    mlir::TypedValue<mlir::VectorType> source, mlir::Value init_value,
    llvm::ArrayRef<int64_t> reduction_dims, mlir::Block& body) {
  int64_t rank = source.getType().getRank();
  int64_t minor_dim = rank - 1;

  bool minor_dim_reduced = reduction_dims.back() == minor_dim;
  bool non_minor_dim_reduced = reduction_dims.size() > 1 || !minor_dim_reduced;

  mlir::Value non_minor_result;
  if (non_minor_dim_reduced) {
    non_minor_result =
        EmitNonMinorReduction(builder, loc, result_type, source, reduction_dims,
                              body, minor_dim_reduced);
  }
  if (!minor_dim_reduced) {
    // We add the init value during the minor reduction loop; if that wasn't
    // done then we must apply it here.
    mlir::Value init_value_vector =
        builder.create<mlir::vector::BroadcastOp>(loc, result_type, init_value);

    return VectorizeBody(builder, loc, body, non_minor_result,
                         init_value_vector);
  }

  return EmitMinorReduction(builder, loc, result_type,
                            non_minor_result ? non_minor_result : source,
                            init_value, body);
}

}  // namespace xla::cpu
third_party/xla/xla/backends/cpu/codegen/tiled/transforms/vectorized_reduce_emitter.h (new file, 47 lines)
@@ -0,0 +1,47 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_CPU_CODEGEN_TILED_TRANSFORMS_VECTORIZED_REDUCE_EMITTER_H_
#define XLA_BACKENDS_CPU_CODEGEN_TILED_TRANSFORMS_VECTORIZED_REDUCE_EMITTER_H_

#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Value.h"

namespace xla::cpu {

// Create a vectorized reduction of the given source vector.
//
// The implementation is as follows:
// 1. If the reduction dimension is only the most minor, we convert it into a
// nested scf.loop of horizontal reductions; if the body of the reduce is a
// single binary operation that is supported by ReductionOp we use that,
// otherwise we simply loop over the scalar values.
// 2. If the reduction dimensions do not include the most minor dimension, we
// loop over the reduction dimensions and apply the body with vectorized
// inputs.
// 3. If the dimensions are a combination of minor & non-minor dimensions, we
// simply apply strategy 2 followed by strategy 1.
mlir::Value EmitVectorizedReduction(
    mlir::OpBuilder& builder, mlir::Location loc, mlir::VectorType result_type,
    mlir::TypedValue<mlir::VectorType> source, mlir::Value init_value,
    llvm::ArrayRef<int64_t> reduction_dims, mlir::Block& body);

}  // namespace xla::cpu

#endif  // XLA_BACKENDS_CPU_CODEGEN_TILED_TRANSFORMS_VECTORIZED_REDUCE_EMITTER_H_
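For orientation, a minimal usage sketch of this new entry point, mirroring the LowerReduce pattern that the commit adds in shlo_to_vector.cc (the helper names come from lowering_utils.h; assume `rewriter` is already positioned at the reduce op):

#include <cstdint>

#include "absl/algorithm/container.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/backends/cpu/codegen/tiled/transforms/lowering_utils.h"
#include "xla/backends/cpu/codegen/tiled/transforms/vectorized_reduce_emitter.h"

namespace xla::cpu {

// Sketch: lower one single-result stablehlo.reduce via the vectorized emitter.
static mlir::Value LowerOneReduce(mlir::PatternRewriter& rewriter,
                                  mlir::stablehlo::ReduceOp op) {
  // Reinterpret the tensor input as an equivalently shaped vector.
  mlir::TypedValue<mlir::VectorType> source =
      CastToVector(rewriter, op.getInputs().front());

  // The init value is a 0-d tensor; extract it as a scalar.
  mlir::Value init = rewriter.create<mlir::tensor::ExtractOp>(
      op->getLoc(), source.getType().getElementType(),
      op.getInitValues().front());

  // Emitter precondition: sorted dims, so "is the minor (last) dimension
  // reduced?" is just a check on reduction_dims.back().
  llvm::SmallVector<int64_t> reduction_dims(op.getDimensions());
  absl::c_sort(reduction_dims);

  mlir::VectorType result_type = GetVectorType(
      mlir::cast<mlir::RankedTensorType>(op.getResult(0).getType()));

  // Dispatches to strategy 1, 2, or 3 above, depending on which dimensions
  // are reduced.
  return EmitVectorizedReduction(rewriter, op->getLoc(), result_type, source,
                                 init, reduction_dims, op.getBody().front());
}

}  // namespace xla::cpu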
@@ -71,6 +71,7 @@ def TiledBufferInterface : OpInterface<"TiledBufferInterface"> {
def EntryFuncOp : XTile_Op<"entry_func", [
    Symbol,
    IsolatedFromAbove,
    AutomaticAllocationScope,
    FunctionOpInterface]>
{
  let summary = "My custom entry function operation";