[XLA:CPU][XTile] Add lowering for broadcast.
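
Lower stablehlo.broadcast_in_dim by shape-casting the source vector to the
rank of the result (unit dimensions in the non-broadcast positions) and then
applying vector.broadcast. As a rough sketch of the outer-broadcast case
exercised in the lit tests below (value names are illustrative):

  %r = stablehlo.broadcast_in_dim %in, dims = [0]
      : (tensor<4xf32>) -> tensor<4x32xf32>

lowers to approximately:

  %cast = vector.shape_cast %src : vector<4xf32> to vector<4x1xf32>
  %bcast = vector.broadcast %cast : vector<4x1xf32> to vector<4x32xf32>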

PiperOrigin-RevId: 825578568
Will Froom 2025-10-29 09:21:02 -07:00 committed by TensorFlower Gardener
parent 98a24eb07e
commit d717d76122
4 changed files with 134 additions and 10 deletions


@@ -40,6 +40,7 @@ def compare_kernel(
    dtype,
    expected_output: Callable[[np.ndarray, ...], np.ndarray],
    maxulp: Optional[int] = None,
    random_inputs: bool = False,
) -> None:
  mlir_emitter = cpu_testlib.MlirTestKernelEmitter(
      ir, kernel_name, (num_workgroups, 1, 1)
@@ -51,8 +52,14 @@ def compare_kernel(
      cpu_testlib.JitCompiler(base_testlib.HloModuleConfig()),
  )
  # Simply use all-ones arrays as inputs to make it easy to debug the kernel.
  inputs = [np.ones(shape=shape, dtype=dtype) for shape in input_shapes]
  # Simply use all-ones arrays as inputs to make it easy to debug the kernel
  # unless random inputs are requested.
  def get_input(shape):
    if random_inputs:
      return get_random_array(shape, dtype)
    return np.ones(shape=shape, dtype=dtype)

  inputs = [get_input(shape) for shape in input_shapes]
  input_tensors = [create_literal(input) for input in inputs]
  # Use a random array as the output to ensure all values are written to.
@@ -407,6 +414,58 @@ class XtileLoweringTest(absltest.TestCase):
        lambda input, init: np.sum(input, axis=(0, 2)) + init,
    )

  def test_broadcast_in_dim_inner(self):
    ir = """
      module @broadcast_in_dim_inner {
        xtile.entry_func @broadcast_in_dim_inner(
            %input: memref<4xf32>,
            %output: memref<32x4xf32>,
            %tile_id: index) attributes {xtile.tiling_info = #xtile.tiling_info<tile_count:1, tiles_per_workgroup:1>} {
          %input_tile = xtile.extract %input[%tile_id][4][1] : memref<4xf32> -> tensor<4xf32>
          %result = stablehlo.broadcast_in_dim %input_tile, dims = [1] : (tensor<4xf32>) -> tensor<32x4xf32>
          xtile.insert %result into %output[%tile_id, %tile_id][32,4][1,1] : tensor<32x4xf32> -> memref<32x4xf32>
          xtile.return
        }
      }
    """
    compare_kernel(
        ir,
        "broadcast_in_dim_inner",
        1,
        [(4,)],
        (32, 4),
        np.float32,
        lambda input: np.broadcast_to(input, (32, 4)),
        random_inputs=True,
    )

  def test_broadcast_in_dim_outer(self):
    ir = """
      module @broadcast_in_dim_outer {
        xtile.entry_func @broadcast_in_dim_outer(
            %input: memref<4xf32>,
            %output: memref<4x32xf32>,
            %tile_id: index) attributes {xtile.tiling_info = #xtile.tiling_info<tile_count:1, tiles_per_workgroup:1>} {
          %input_tile = xtile.extract %input[%tile_id][4][1] : memref<4xf32> -> tensor<4xf32>
          %result = stablehlo.broadcast_in_dim %input_tile, dims = [0] : (tensor<4xf32>) -> tensor<4x32xf32>
          xtile.insert %result into %output[%tile_id, %tile_id][4,32][1,1] : tensor<4x32xf32> -> memref<4x32xf32>
          xtile.return
        }
      }
    """
    compare_kernel(
        ir,
        "broadcast_in_dim_outer",
        1,
        [(4,)],
        (4, 32),
        np.float32,
        lambda input: np.transpose(np.broadcast_to(input, (32, 4))),
        random_inputs=True,
    )


if __name__ == "__main__":
  absltest.main()


@@ -64,9 +64,6 @@ cc_library(
        "//xla/codegen/emitters/ir:xla",
        "//xla/codegen/xtile/ir:xtile",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@llvm-project//llvm:Support",
@@ -78,7 +75,6 @@ cc_library(
        "@llvm-project//mlir:FuncTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LLVMDialect",
        "@llvm-project//mlir:LinalgTransforms",
        "@llvm-project//mlir:MathDialect",
        "@llvm-project//mlir:MathOpsIncGen",
        "@llvm-project//mlir:MemRefDialect",


@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -224,9 +225,6 @@ struct LowerTranspose : mlir::OpRewritePattern<mlir::stablehlo::TransposeOp> {
  }
};

// Lower stablehlo.reduce to vector operations.
//
struct LowerReduce : mlir::OpRewritePattern<mlir::stablehlo::ReduceOp> {
  using OpRewritePattern::OpRewritePattern;
@@ -266,6 +264,42 @@ struct LowerReduce : mlir::OpRewritePattern<mlir::stablehlo::ReduceOp> {
  }
};
// Lower stablehlo.broadcast_in_dim to vector operations.
struct LowerBroadcastInDim
    : mlir::OpRewritePattern<mlir::stablehlo::BroadcastInDimOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::stablehlo::BroadcastInDimOp op,
      mlir::PatternRewriter& rewriter) const override {
    auto source_vector = CastToVector(rewriter, op.getOperand());
    auto result_vector_type = GetVectorType(op.getType());

    llvm::ArrayRef<int64_t> source_shape = source_vector.getType().getShape();
    llvm::ArrayRef<int64_t> broadcast_dims = op.getBroadcastDimensions();

    // First create an intermediate vector with the rank of the result vector,
    // with the broadcast dimensions set to the source shape and all additional
    // dimensions set to 1.
    llvm::SmallVector<int64_t> intermediate_shape(result_vector_type.getRank(),
                                                  1);
    for (auto [input_dim, result_dim] : llvm::enumerate(broadcast_dims)) {
      intermediate_shape[result_dim] = source_shape[input_dim];
    }
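    // For example: broadcasting a vector<4xf32> source with dims = [0] into a
    // rank-2 result gives intermediate_shape = {4, 1}, i.e. the
    // vector<4x1xf32> shape_cast checked in the lit tests.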
    mlir::Value intermediate_vector = mlir::vector::ShapeCastOp::create(
        rewriter, op->getLoc(),
        mlir::VectorType::get(intermediate_shape,
                              result_vector_type.getElementType()),
        source_vector);

    // Now that all the inserted dimensions are size 1 we can legally call
    // broadcast even if they are not the most major dimensions.
    mlir::Value broadcast_op = mlir::vector::BroadcastOp::create(
        rewriter, op->getLoc(), result_vector_type, intermediate_vector);

    rewriter.replaceOp(op, CastToTensor(rewriter, broadcast_op));
    return mlir::success();
  }
};
class ShloToVectorPass : public impl::ShloToVectorPassBase<ShloToVectorPass> {
 public:
  using ShloToVectorPassBase::ShloToVectorPassBase;

@@ -273,7 +307,9 @@ class ShloToVectorPass : public impl::ShloToVectorPassBase<ShloToVectorPass> {
  void runOnOperation() override {
    mlir::MLIRContext* context = &getContext();
    mlir::RewritePatternSet patterns(context);
    patterns.add<LowerTranspose, LowerDotGeneral, LowerReduce>(context);
    patterns
        .add<LowerTranspose, LowerDotGeneral, LowerReduce, LowerBroadcastInDim>(
            context);
    if (mlir::failed(
            mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      signalPassFailure();


@@ -141,3 +141,36 @@ func.func @reduce_outer_and_inner(%input : tensor<1024x32x8xf32>, %init : tensor
// CHECK: }
// CHECK: vector.transfer_read %[[BUFFER1]]{{.*}} : memref<32xf32>, vector<32xf32>
// CHECK: }

// -----

func.func @broadcast_0D_tensor(%input : tensor<f32>) -> tensor<32xf32> {
  %result = stablehlo.broadcast_in_dim %input, dims = [] : (tensor<f32>) -> tensor<32xf32>
  return %result : tensor<32xf32>
}

// CHECK-LABEL: @broadcast_0D_tensor
// CHECK-NOT: vector.shape_cast
// CHECK: vector.broadcast {{.*}} : vector<f32> to vector<32xf32>

// -----

func.func @broadcast_2D_tensor_inner(%input : tensor<4xf32>) -> tensor<32x4xf32> {
  %result = stablehlo.broadcast_in_dim %input, dims = [1] : (tensor<4xf32>) -> tensor<32x4xf32>
  return %result : tensor<32x4xf32>
}

// CHECK-LABEL: @broadcast_2D_tensor_inner
// CHECK-NOT: vector.shape_cast
// CHECK: vector.broadcast {{.*}} : vector<4xf32> to vector<32x4xf32>

// -----

func.func @broadcast_2D_tensor_outer(%input : tensor<4xf32>) -> tensor<4x32xf32> {
  %result = stablehlo.broadcast_in_dim %input, dims = [0] : (tensor<4xf32>) -> tensor<4x32xf32>
  return %result : tensor<4x32xf32>
}

// CHECK-LABEL: @broadcast_2D_tensor_outer
// CHECK: vector.shape_cast {{.*}} : vector<4xf32> to vector<4x1xf32>
// CHECK: vector.broadcast {{.*}} : vector<4x1xf32> to vector<4x32xf32>