[XLA:CPU][XTile] Add lowering for broadcast.
PiperOrigin-RevId: 825578568
This commit is contained in:
parent 98a24eb07e
commit d717d76122
@@ -40,6 +40,7 @@ def compare_kernel(
    dtype,
    expected_output: Callable[[np.ndarray, ...], np.ndarray],
    maxulp: Optional[int] = None,
    random_inputs: bool = False,
) -> None:
  mlir_emitter = cpu_testlib.MlirTestKernelEmitter(
      ir, kernel_name, (num_workgroups, 1, 1)
@@ -51,8 +52,14 @@ def compare_kernel(
      cpu_testlib.JitCompiler(base_testlib.HloModuleConfig()),
  )

  # Simply use all-ones arrays as inputs to make it easy to debug the kernel.
  inputs = [np.ones(shape=shape, dtype=dtype) for shape in input_shapes]
  # Simply use all-ones arrays as inputs to make it easy to debug the kernel
  # unless random inputs are requested.
  def get_input(shape):
    if random_inputs:
      return get_random_array(shape, dtype)
    return np.ones(shape=shape, dtype=dtype)

  inputs = [get_input(shape) for shape in input_shapes]

  input_tensors = [create_literal(input) for input in inputs]
  # Use a random array as the output to ensure all values are written to.
@@ -407,6 +414,58 @@ class XtileLoweringTest(absltest.TestCase):
        lambda input, init: np.sum(input, axis=(0, 2)) + init,
    )

  def test_broadcast_in_dim_inner(self):
    ir = """
      module @broadcast_in_dim_inner {
        xtile.entry_func @broadcast_in_dim_inner(
            %input: memref<4xf32>,
            %output: memref<32x4xf32>,
            %tile_id: index) attributes {xtile.tiling_info = #xtile.tiling_info<tile_count:1, tiles_per_workgroup:1>} {
          %input_tile = xtile.extract %input[%tile_id][4][1] : memref<4xf32> -> tensor<4xf32>
          %result = stablehlo.broadcast_in_dim %input_tile, dims = [1] : (tensor<4xf32>) -> tensor<32x4xf32>
          xtile.insert %result into %output[%tile_id, %tile_id][32,4][1,1] : tensor<32x4xf32> -> memref<32x4xf32>
          xtile.return
        }
      }
    """

    compare_kernel(
        ir,
        "broadcast_in_dim_inner",
        1,
        [(4,)],
        (32, 4),
        np.float32,
        lambda input: np.broadcast_to(input, (32, 4)),
        random_inputs=True,
    )

  def test_broadcast_in_dim_outer(self):
    ir = """
      module @broadcast_in_dim_outer {
        xtile.entry_func @broadcast_in_dim_outer(
            %input: memref<4xf32>,
            %output: memref<4x32xf32>,
            %tile_id: index) attributes {xtile.tiling_info = #xtile.tiling_info<tile_count:1, tiles_per_workgroup:1>} {
          %input_tile = xtile.extract %input[%tile_id][4][1] : memref<4xf32> -> tensor<4xf32>
          %result = stablehlo.broadcast_in_dim %input_tile, dims = [0] : (tensor<4xf32>) -> tensor<4x32xf32>
          xtile.insert %result into %output[%tile_id, %tile_id][4,32][1,1] : tensor<4x32xf32> -> memref<4x32xf32>
          xtile.return
        }
      }
    """

    compare_kernel(
        ir,
        "broadcast_in_dim_outer",
        1,
        [(4,)],
        (4, 32),
        np.float32,
        lambda input: np.transpose(np.broadcast_to(input, (32, 4))),
        random_inputs=True,
    )


if __name__ == "__main__":
  absltest.main()
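
Aside, not part of the commit: a minimal NumPy sketch of what the two reference lambdas above expect, assuming a length-4 input vector; the variable names here are illustrative only.

import numpy as np

x = np.arange(4, dtype=np.float32)

# dims = [1]: the source feeds the innermost result dimension, so plain
# NumPy broadcasting to (32, 4) matches directly.
inner_expected = np.broadcast_to(x, (32, 4))

# dims = [0]: the source feeds the outermost result dimension; one way to
# express that in NumPy is to broadcast to (32, 4) and transpose.
outer_expected = np.transpose(np.broadcast_to(x, (32, 4)))
# Equivalent: np.broadcast_to(x[:, None], (4, 32)).

assert inner_expected.shape == (32, 4)
assert outer_expected.shape == (4, 32)
assert np.array_equal(outer_expected, inner_expected.T)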
@@ -64,9 +64,6 @@ cc_library(
        "//xla/codegen/emitters/ir:xla",
        "//xla/codegen/xtile/ir:xtile",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@llvm-project//llvm:Support",
@@ -78,7 +75,6 @@ cc_library(
        "@llvm-project//mlir:FuncTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LLVMDialect",
        "@llvm-project//mlir:LinalgTransforms",
        "@llvm-project//mlir:MathDialect",
        "@llvm-project//mlir:MathOpsIncGen",
        "@llvm-project//mlir:MemRefDialect",
@@ -20,6 +20,7 @@ limitations under the License.

#include "absl/algorithm/container.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -224,9 +225,6 @@ struct LowerTranspose : mlir::OpRewritePattern<mlir::stablehlo::TransposeOp> {
  }
};

// Lower stablehlo.reduce to vector operations.
//

struct LowerReduce : mlir::OpRewritePattern<mlir::stablehlo::ReduceOp> {
  using OpRewritePattern::OpRewritePattern;

@@ -266,6 +264,42 @@ struct LowerReduce : mlir::OpRewritePattern<mlir::stablehlo::ReduceOp> {
  }
};

struct LowerBroadcastInDim
    : mlir::OpRewritePattern<mlir::stablehlo::BroadcastInDimOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::stablehlo::BroadcastInDimOp op,
      mlir::PatternRewriter& rewriter) const override {
    auto source_vector = CastToVector(rewriter, op.getOperand());
    auto result_vector_type = GetVectorType(op.getType());

    llvm::ArrayRef<int64_t> source_shape = source_vector.getType().getShape();
    llvm::ArrayRef<int64_t> broadcast_dims = op.getBroadcastDimensions();

    // First create an intermediate vector with the rank of the result vector
    // but with the broadcasted dimensions set to the source shape with all
    // additional dimensions set to 1.
    llvm::SmallVector<int64_t> intermediate_shape(result_vector_type.getRank(),
                                                  1);
    for (auto [input_dim, result_dim] : llvm::enumerate(broadcast_dims)) {
      intermediate_shape[result_dim] = source_shape[input_dim];
    }
    mlir::Value intermediate_vector = mlir::vector::ShapeCastOp::create(
        rewriter, op->getLoc(),
        mlir::VectorType::get(intermediate_shape,
                              result_vector_type.getElementType()),
        source_vector);
    // Now that all the inserted dimensions are size 1 we can legally call
    // broadcast even if they are not the most major dimensions.
    mlir::Value broadcast_op = mlir::vector::BroadcastOp::create(
        rewriter, op->getLoc(), result_vector_type, intermediate_vector);

    rewriter.replaceOp(op, CastToTensor(rewriter, broadcast_op));
    return mlir::success();
  }
};

class ShloToVectorPass : public impl::ShloToVectorPassBase<ShloToVectorPass> {
 public:
  using ShloToVectorPassBase::ShloToVectorPassBase;
@@ -273,7 +307,9 @@ class ShloToVectorPass : public impl::ShloToVectorPassBase<ShloToVectorPass> {
  void runOnOperation() override {
    mlir::MLIRContext* context = &getContext();
    mlir::RewritePatternSet patterns(context);
    patterns.add<LowerTranspose, LowerDotGeneral, LowerReduce>(context);
    patterns
        .add<LowerTranspose, LowerDotGeneral, LowerReduce, LowerBroadcastInDim>(
            context);
    if (mlir::failed(
            mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      signalPassFailure();
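
Aside, not part of the diff: a minimal NumPy sketch of the shape logic LowerBroadcastInDim implements, i.e. build an intermediate shape that is 1 everywhere except the broadcast dimensions, then broadcast to the result shape. The helper name and the concrete shapes below are assumptions for illustration.

import numpy as np

def broadcast_in_dim_reference(source, result_shape, broadcast_dims):
  # Mirror of the intermediate_shape loop: result rank, all ones, with the
  # broadcast dimensions carrying the corresponding source extents.
  intermediate_shape = [1] * len(result_shape)
  for input_dim, result_dim in enumerate(broadcast_dims):
    intermediate_shape[result_dim] = source.shape[input_dim]
  # shape_cast analogue: a reshape that keeps the element count.
  intermediate = source.reshape(intermediate_shape)
  # vector.broadcast analogue: size-1 dimensions stretch to the result extents.
  return np.broadcast_to(intermediate, result_shape)

# The two lowering test cases above: dims=[1] -> (32, 4) and dims=[0] -> (4, 32).
x = np.arange(4, dtype=np.float32)
assert np.array_equal(broadcast_in_dim_reference(x, (32, 4), [1]),
                      np.broadcast_to(x, (32, 4)))
assert np.array_equal(broadcast_in_dim_reference(x, (4, 32), [0]),
                      np.broadcast_to(x, (32, 4)).T)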
@@ -141,3 +141,36 @@ func.func @reduce_outer_and_inner(%input : tensor<1024x32x8xf32>, %init : tensor
// CHECK: }
// CHECK: vector.transfer_read %[[BUFFER1]]{{.*}} : memref<32xf32>, vector<32xf32>
// CHECK: }

// -----

func.func @broadcast_0D_tensor(%input : tensor<f32>) -> tensor<32xf32> {
  %result = stablehlo.broadcast_in_dim %input, dims = [] : (tensor<f32>) -> tensor<32xf32>
  return %result : tensor<32xf32>
}

// CHECK-LABEL: @broadcast_0D_tensor
// CHECK-NOT: vector.shape_cast
// CHECK: vector.broadcast {{.*}} : vector<f32> to vector<32xf32>

// -----

func.func @broadcast_2D_tensor_inner(%input : tensor<4xf32>) -> tensor<32x4xf32> {
  %result = stablehlo.broadcast_in_dim %input, dims = [1] : (tensor<4xf32>) -> tensor<32x4xf32>
  return %result : tensor<32x4xf32>
}

// CHECK-LABEL: @broadcast_2D_tensor_inner
// CHECK-NOT: vector.shape_cast
// CHECK: vector.broadcast {{.*}} : vector<4xf32> to vector<32x4xf32>

// -----

func.func @broadcast_2D_tensor_outer(%input : tensor<4xf32>) -> tensor<4x32xf32> {
  %result = stablehlo.broadcast_in_dim %input, dims = [0] : (tensor<4xf32>) -> tensor<4x32xf32>
  return %result : tensor<4x32xf32>
}

// CHECK-LABEL: @broadcast_2D_tensor_outer
// CHECK: vector.shape_cast {{.*}} : vector<4xf32> to vector<4x1xf32>
// CHECK: vector.broadcast {{.*}} : vector<4x1xf32> to vector<4x32xf32>