[xla:codegen] Remove MlirKernelDefinition alias

PiperOrigin-RevId: 825724819
This commit is contained in:
Eugene Zhulenev 2025-10-29 15:07:01 -07:00 committed by TensorFlower Gardener
parent ff2b8b600d
commit 756a72760a
13 changed files with 40 additions and 37 deletions

View File

@ -632,6 +632,7 @@ cc_library(
"//xla/backends/cpu:alignment", "//xla/backends/cpu:alignment",
"//xla/codegen:hlo_fusion_spec", "//xla/codegen:hlo_fusion_spec",
"//xla/codegen:ir_emission_utils", "//xla/codegen:ir_emission_utils",
"//xla/codegen:kernel_definition",
"//xla/codegen:kernel_spec", "//xla/codegen:kernel_spec",
"//xla/codegen:mlir_kernel_source", "//xla/codegen:mlir_kernel_source",
"//xla/codegen/emitters:concatenate_kernel_emitter", "//xla/codegen/emitters:concatenate_kernel_emitter",

View File

@ -249,7 +249,8 @@ IndexingMap GetScatterIndexingMap(
{}, constraints); {}, constraints);
} }
absl::StatusOr<MlirKernelDefinition> CpuScatterFusion::EmitKernelDefinition() { absl::StatusOr<CpuScatterFusion::KernelDefinition>
CpuScatterFusion::EmitKernelDefinition() {
mlir::OpBuilder builder(symbolic_expr_context_->GetMLIRContext()); mlir::OpBuilder builder(symbolic_expr_context_->GetMLIRContext());
TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> mlir_module, TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> mlir_module,
CreateNamedMlirModuleOp(*fusion_, builder)); CreateNamedMlirModuleOp(*fusion_, builder));
@ -325,8 +326,8 @@ absl::StatusOr<MlirKernelDefinition> CpuScatterFusion::EmitKernelDefinition() {
std::move(argument_buffers), std::move(result_buffers), std::move(argument_buffers), std::move(result_buffers),
std::move(invariant_arguments)); std::move(invariant_arguments));
return MlirKernelDefinition(std::move(kernel_spec), return KernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(mlir_module))); MlirKernelSource(std::move(mlir_module)));
} }
absl::Status CpuScatterFusion::EmitEntryFunction( absl::Status CpuScatterFusion::EmitEntryFunction(

View File

@ -42,7 +42,7 @@ limitations under the License.
#include "xla/codegen/emitters/loop_kernel_emitter.h" #include "xla/codegen/emitters/loop_kernel_emitter.h"
#include "xla/codegen/hlo_fusion_spec.h" #include "xla/codegen/hlo_fusion_spec.h"
#include "xla/codegen/ir_emission_utils.h" #include "xla/codegen/ir_emission_utils.h"
#include "xla/codegen/kernel_spec.h" #include "xla/codegen/kernel_definition.h"
#include "xla/codegen/mlir_kernel_source.h" #include "xla/codegen/mlir_kernel_source.h"
#include "xla/hlo/analysis/symbolic_expr.h" #include "xla/hlo/analysis/symbolic_expr.h"
#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instruction.h"
@ -57,7 +57,6 @@ limitations under the License.
#include "xla/runtime/work_tile_size.h" #include "xla/runtime/work_tile_size.h"
#include "xla/service/buffer_assignment.h" #include "xla/service/buffer_assignment.h"
#include "xla/service/cpu/backend_config.pb.h" #include "xla/service/cpu/backend_config.pb.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/shape.h" #include "xla/shape.h"
#include "xla/shape_util.h" #include "xla/shape_util.h"
#include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/statusor.h"
@ -208,7 +207,7 @@ static HloFusionSpec GetLoopFusionSpec(const HloFusionInstruction& fusion) {
std::move(heroes)); std::move(heroes));
} }
static absl::StatusOr<MlirKernelDefinition> EmitLoopFusionKernel( static absl::StatusOr<KernelDefinition<MlirKernelSource>> EmitLoopFusionKernel(
SymbolicExprContext& context, const HloFusionInstruction& fusion, SymbolicExprContext& context, const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment, absl::string_view name) { const BufferAssignment* buffer_assignment, absl::string_view name) {
VLOG(2) << "Emitting loop fusion kernel: " << name; VLOG(2) << "Emitting loop fusion kernel: " << name;
@ -230,9 +229,11 @@ static absl::StatusOr<MlirKernelDefinition> EmitLoopFusionKernel(
return mlir_kernel_definition; return mlir_kernel_definition;
} }
static absl::StatusOr<MlirKernelDefinition> EmitConcatenateFusionKernel( static absl::StatusOr<KernelDefinition<MlirKernelSource>>
SymbolicExprContext& context, const HloFusionInstruction& fusion, EmitConcatenateFusionKernel(SymbolicExprContext& context,
const BufferAssignment* buffer_assignment, absl::string_view name) { const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment,
absl::string_view name) {
VLOG(2) << "Emitting concatenate fusion kernel: " << name; VLOG(2) << "Emitting concatenate fusion kernel: " << name;
HloFusionSpec fusion_spec = GetLoopFusionSpec(fusion); HloFusionSpec fusion_spec = GetLoopFusionSpec(fusion);
auto work_dimensions = GetConcatenateEmitterWorkDims(fusion, fusion_spec); auto work_dimensions = GetConcatenateEmitterWorkDims(fusion, fusion_spec);
@ -252,9 +253,11 @@ static absl::StatusOr<MlirKernelDefinition> EmitConcatenateFusionKernel(
return mlir_kernel_definition; return mlir_kernel_definition;
} }
static absl::StatusOr<MlirKernelDefinition> EmitDynamicUpdateSliceFusionKernel( static absl::StatusOr<KernelDefinition<MlirKernelSource>>
SymbolicExprContext& context, const HloFusionInstruction& fusion, EmitDynamicUpdateSliceFusionKernel(SymbolicExprContext& context,
const BufferAssignment* buffer_assignment, absl::string_view name) { const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment,
absl::string_view name) {
VLOG(2) << "Emitting dynamic update slice fusion kernel: " << name; VLOG(2) << "Emitting dynamic update slice fusion kernel: " << name;
HloFusionSpec fusion_spec = GetLoopFusionSpec(fusion); HloFusionSpec fusion_spec = GetLoopFusionSpec(fusion);
auto work_dimensions = auto work_dimensions =
@ -275,7 +278,7 @@ static absl::StatusOr<MlirKernelDefinition> EmitDynamicUpdateSliceFusionKernel(
return mlir_kernel_definition; return mlir_kernel_definition;
} }
absl::StatusOr<MlirKernelDefinition> EmitFusionKernel( absl::StatusOr<KernelDefinition<MlirKernelSource>> EmitFusionKernel(
SymbolicExprContext& context, const HloFusionInstruction& fusion, SymbolicExprContext& context, const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment, bool use_unique_c_name) { const BufferAssignment* buffer_assignment, bool use_unique_c_name) {
if (fusion.fusion_kind() == HloFusionInstruction::FusionKind::kLoop) { if (fusion.fusion_kind() == HloFusionInstruction::FusionKind::kLoop) {

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "xla/codegen/emitters/kernel_arguments.h" #include "xla/codegen/emitters/kernel_arguments.h"
#include "xla/codegen/kernel_definition.h"
#include "xla/codegen/mlir_kernel_source.h" #include "xla/codegen/mlir_kernel_source.h"
#include "xla/hlo/analysis/symbolic_expr.h" #include "xla/hlo/analysis/symbolic_expr.h"
#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_instructions.h"
@ -27,7 +28,7 @@ namespace xla::cpu {
emitters::KernelArguments::BufferAlignment GetDefaultBufferAlignment(); emitters::KernelArguments::BufferAlignment GetDefaultBufferAlignment();
absl::StatusOr<MlirKernelDefinition> EmitFusionKernel( absl::StatusOr<KernelDefinition<MlirKernelSource>> EmitFusionKernel(
SymbolicExprContext& context, const HloFusionInstruction& fusion, SymbolicExprContext& context, const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment, bool use_unique_c_name); const BufferAssignment* buffer_assignment, bool use_unique_c_name);

View File

@ -39,7 +39,7 @@ absl::Status Run(const std::string& filename) {
module->entry_computation()->root_instruction()); module->entry_computation()->root_instruction());
fusion->SetAndSanitizeName("main"); fusion->SetAndSanitizeName("main");
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
MlirKernelDefinition kernel_definition, KernelDefinition kernel_definition,
EmitFusionKernel(*symbolic_expr_context, *fusion, nullptr, false)); EmitFusionKernel(*symbolic_expr_context, *fusion, nullptr, false));
llvm::outs() << kernel_definition.source().ToString(); llvm::outs() << kernel_definition.source().ToString();
return absl::OkStatus(); return absl::OkStatus();

View File

@ -206,9 +206,8 @@ NB_MODULE(_extension, kernel_runner_module) {
[](SymbolicExprContext& symbolic_expr_context, [](SymbolicExprContext& symbolic_expr_context,
const HloFusionInstruction& fusion, const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment) { const BufferAssignment* buffer_assignment) {
absl::StatusOr<MlirKernelDefinition> kernel_definition = auto kernel_definition = EmitFusionKernel(symbolic_expr_context, fusion,
EmitFusionKernel(symbolic_expr_context, fusion, buffer_assignment, buffer_assignment, false);
false);
if (!kernel_definition.ok()) { if (!kernel_definition.ok()) {
throw std::runtime_error(kernel_definition.status().ToString()); throw std::runtime_error(kernel_definition.status().ToString());
} }
@ -242,8 +241,8 @@ NB_MODULE(_extension, kernel_runner_module) {
"KernelRunner") "KernelRunner")
.def_static( .def_static(
"create", "create",
[](std::unique_ptr<MlirKernelDefinition, [](std::unique_ptr<KernelDefinition<MlirKernelSource>,
nb::deleter<MlirKernelDefinition>> nb::deleter<KernelDefinition<MlirKernelSource>>>
kernel_definition, kernel_definition,
std::unique_ptr<JitCompiler, nb::deleter<JitCompiler>> std::unique_ptr<JitCompiler, nb::deleter<JitCompiler>>
jit_compiler) { jit_compiler) {

View File

@ -47,7 +47,7 @@ MlirTestKernelEmitter::MlirTestKernelEmitter(absl::string_view mlir,
} }
} }
absl::StatusOr<MlirKernelDefinition> absl::StatusOr<MlirTestKernelEmitter::KernelDefinition>
MlirTestKernelEmitter::EmitKernelDefinition() { MlirTestKernelEmitter::EmitKernelDefinition() {
std::unique_ptr<mlir::MLIRContext> context = FusionCompiler::CreateContext(); std::unique_ptr<mlir::MLIRContext> context = FusionCompiler::CreateContext();
@ -71,6 +71,6 @@ MlirTestKernelEmitter::EmitKernelDefinition() {
KernelSpec kernel_spec(kernel_name_, num_workgroups_, KernelSpec kernel_spec(kernel_name_, num_workgroups_,
std::move(argument_buffers), std::move(result_buffers), std::move(argument_buffers), std::move(result_buffers),
/*invariant_arguments=*/{}); /*invariant_arguments=*/{});
return MlirKernelDefinition(std::move(kernel_spec), std::move(source)); return KernelDefinition(std::move(kernel_spec), std::move(source));
} }
} // namespace xla::cpu } // namespace xla::cpu

View File

@ -89,7 +89,7 @@ ConcatenateFusionKernelEmitter::ConcatenateFusionKernelEmitter(
entry_function_name_(entry_function_name), entry_function_name_(entry_function_name),
backend_kind_(backend_kind) {} backend_kind_(backend_kind) {}
absl::StatusOr<MlirKernelDefinition> absl::StatusOr<ConcatenateFusionKernelEmitter::KernelDefinition>
ConcatenateFusionKernelEmitter::EmitKernelDefinition() { ConcatenateFusionKernelEmitter::EmitKernelDefinition() {
mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext()); mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext());
auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name())); auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name()));
@ -121,8 +121,8 @@ ConcatenateFusionKernelEmitter::EmitKernelDefinition() {
GetKernelSpec(entry_function_name_, fusion_, GetKernelSpec(entry_function_name_, fusion_,
buffer_assignment_, work_dimensions_)); buffer_assignment_, work_dimensions_));
return MlirKernelDefinition(std::move(kernel_spec), return KernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(module))); MlirKernelSource(std::move(module)));
} }
const Shape& ConcatenateFusionKernelEmitter::GetIndexingShape( const Shape& ConcatenateFusionKernelEmitter::GetIndexingShape(

View File

@ -92,7 +92,7 @@ DynamicUpdateSliceKernelEmitter::DynamicUpdateSliceKernelEmitter(
entry_function_name_(entry_function_name), entry_function_name_(entry_function_name),
backend_kind_(backend_kind) {} backend_kind_(backend_kind) {}
absl::StatusOr<MlirKernelDefinition> absl::StatusOr<DynamicUpdateSliceKernelEmitter::KernelDefinition>
DynamicUpdateSliceKernelEmitter::EmitKernelDefinition() { DynamicUpdateSliceKernelEmitter::EmitKernelDefinition() {
mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext()); mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext());
auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name())); auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name()));
@ -120,8 +120,8 @@ DynamicUpdateSliceKernelEmitter::EmitKernelDefinition() {
TF_ASSIGN_OR_RETURN(auto kernel_spec, GetKernelSpec()); TF_ASSIGN_OR_RETURN(auto kernel_spec, GetKernelSpec());
return MlirKernelDefinition(std::move(kernel_spec), return KernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(module))); MlirKernelSource(std::move(module)));
} }
IndexingMap DynamicUpdateSliceKernelEmitter::ComputeWorkItemIdToInputIndexing( IndexingMap DynamicUpdateSliceKernelEmitter::ComputeWorkItemIdToInputIndexing(

View File

@ -84,7 +84,7 @@ LoopFusionKernelEmitter::LoopFusionKernelEmitter(
entry_function_name_(entry_function_name), entry_function_name_(entry_function_name),
backend_kind_(backend_kind) {} backend_kind_(backend_kind) {}
absl::StatusOr<MlirKernelDefinition> absl::StatusOr<LoopFusionKernelEmitter::KernelDefinition>
LoopFusionKernelEmitter::EmitKernelDefinition() { LoopFusionKernelEmitter::EmitKernelDefinition() {
mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext()); mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext());
auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name())); auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name()));
@ -114,8 +114,8 @@ LoopFusionKernelEmitter::EmitKernelDefinition() {
GetKernelSpec(entry_function_name_, fusion_, GetKernelSpec(entry_function_name_, fusion_,
buffer_assignment_, work_dimensions_)); buffer_assignment_, work_dimensions_));
return MlirKernelDefinition(std::move(kernel_spec), return KernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(module))); MlirKernelSource(std::move(module)));
} }
IndexingMap LoopFusionKernelEmitter::ComputeWorkItemIdToOutputIndexing( IndexingMap LoopFusionKernelEmitter::ComputeWorkItemIdToOutputIndexing(

View File

@ -81,8 +81,6 @@ class MlirKernelSource final : public KernelSource {
Storage storage_; Storage storage_;
}; };
using MlirKernelDefinition = KernelDefinition<MlirKernelSource>; // NOLINT
} // namespace xla } // namespace xla
#endif // XLA_CODEGEN_MLIR_KERNEL_SOURCE_H_ #endif // XLA_CODEGEN_MLIR_KERNEL_SOURCE_H_

View File

@ -171,7 +171,7 @@ absl::StatusOr<KernelSpec> ParallelFusionEmitter::AddFusion(
// fixed but will require a rework of the ThunkEmitter. // fixed but will require a rework of the ThunkEmitter.
auto compiler_instance = fusion_compiler_pool_->GetInstance(); auto compiler_instance = fusion_compiler_pool_->GetInstance();
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
MlirKernelDefinition mlir_kernel_definition, KernelDefinition mlir_kernel_definition,
EmitFusionKernel(*compiler_instance->symbolic_expr_context, *fusion, EmitFusionKernel(*compiler_instance->symbolic_expr_context, *fusion,
buffer_assignment_, use_unique_c_name_)); buffer_assignment_, use_unique_c_name_));
@ -181,8 +181,8 @@ absl::StatusOr<KernelSpec> ParallelFusionEmitter::AddFusion(
} }
KernelSpec spec = mlir_kernel_definition.spec(); KernelSpec spec = mlir_kernel_definition.spec();
auto shared_source = auto shared_source = std::make_shared<KernelDefinition<MlirKernelSource>>(
std::make_shared<MlirKernelDefinition>(std::move(mlir_kernel_definition)); std::move(mlir_kernel_definition));
thread_pool_.Schedule(absl::bind_front(&ParallelFusionEmitter::CompileFusion, thread_pool_.Schedule(absl::bind_front(&ParallelFusionEmitter::CompileFusion,
this, std::move(shared_source), this, std::move(shared_source),

View File

@ -851,7 +851,7 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitFusionKernelThunk(
auto kernel_emitter = std::make_unique<CpuScatterFusion>( auto kernel_emitter = std::make_unique<CpuScatterFusion>(
buffer_assignment_, fusion, &symbolic_expr_context_); buffer_assignment_, fusion, &symbolic_expr_context_);
TF_ASSIGN_OR_RETURN(MlirKernelDefinition kernel_definition, TF_ASSIGN_OR_RETURN(KernelDefinition kernel_definition,
kernel_emitter->EmitKernelDefinition()); kernel_emitter->EmitKernelDefinition());
auto kernel_spec = kernel_definition.spec(); auto kernel_spec = kernel_definition.spec();