[xla:codegen] Remove MlirKernelDefinition alias

PiperOrigin-RevId: 825724819
This commit is contained in:
Eugene Zhulenev 2025-10-29 15:07:01 -07:00 committed by TensorFlower Gardener
parent ff2b8b600d
commit 756a72760a
13 changed files with 40 additions and 37 deletions

View File

@ -632,6 +632,7 @@ cc_library(
"//xla/backends/cpu:alignment",
"//xla/codegen:hlo_fusion_spec",
"//xla/codegen:ir_emission_utils",
"//xla/codegen:kernel_definition",
"//xla/codegen:kernel_spec",
"//xla/codegen:mlir_kernel_source",
"//xla/codegen/emitters:concatenate_kernel_emitter",

View File

@ -249,7 +249,8 @@ IndexingMap GetScatterIndexingMap(
{}, constraints);
}
absl::StatusOr<MlirKernelDefinition> CpuScatterFusion::EmitKernelDefinition() {
absl::StatusOr<CpuScatterFusion::KernelDefinition>
CpuScatterFusion::EmitKernelDefinition() {
mlir::OpBuilder builder(symbolic_expr_context_->GetMLIRContext());
TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> mlir_module,
CreateNamedMlirModuleOp(*fusion_, builder));
@ -325,8 +326,8 @@ absl::StatusOr<MlirKernelDefinition> CpuScatterFusion::EmitKernelDefinition() {
std::move(argument_buffers), std::move(result_buffers),
std::move(invariant_arguments));
return MlirKernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(mlir_module)));
return KernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(mlir_module)));
}
absl::Status CpuScatterFusion::EmitEntryFunction(

View File

@ -42,7 +42,7 @@ limitations under the License.
#include "xla/codegen/emitters/loop_kernel_emitter.h"
#include "xla/codegen/hlo_fusion_spec.h"
#include "xla/codegen/ir_emission_utils.h"
#include "xla/codegen/kernel_spec.h"
#include "xla/codegen/kernel_definition.h"
#include "xla/codegen/mlir_kernel_source.h"
#include "xla/hlo/analysis/symbolic_expr.h"
#include "xla/hlo/ir/hlo_instruction.h"
@ -57,7 +57,6 @@ limitations under the License.
#include "xla/runtime/work_tile_size.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/cpu/backend_config.pb.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/platform/statusor.h"
@ -208,7 +207,7 @@ static HloFusionSpec GetLoopFusionSpec(const HloFusionInstruction& fusion) {
std::move(heroes));
}
static absl::StatusOr<MlirKernelDefinition> EmitLoopFusionKernel(
static absl::StatusOr<KernelDefinition<MlirKernelSource>> EmitLoopFusionKernel(
SymbolicExprContext& context, const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment, absl::string_view name) {
VLOG(2) << "Emitting loop fusion kernel: " << name;
@ -230,9 +229,11 @@ static absl::StatusOr<MlirKernelDefinition> EmitLoopFusionKernel(
return mlir_kernel_definition;
}
static absl::StatusOr<MlirKernelDefinition> EmitConcatenateFusionKernel(
SymbolicExprContext& context, const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment, absl::string_view name) {
static absl::StatusOr<KernelDefinition<MlirKernelSource>>
EmitConcatenateFusionKernel(SymbolicExprContext& context,
const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment,
absl::string_view name) {
VLOG(2) << "Emitting concatenate fusion kernel: " << name;
HloFusionSpec fusion_spec = GetLoopFusionSpec(fusion);
auto work_dimensions = GetConcatenateEmitterWorkDims(fusion, fusion_spec);
@ -252,9 +253,11 @@ static absl::StatusOr<MlirKernelDefinition> EmitConcatenateFusionKernel(
return mlir_kernel_definition;
}
static absl::StatusOr<MlirKernelDefinition> EmitDynamicUpdateSliceFusionKernel(
SymbolicExprContext& context, const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment, absl::string_view name) {
static absl::StatusOr<KernelDefinition<MlirKernelSource>>
EmitDynamicUpdateSliceFusionKernel(SymbolicExprContext& context,
const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment,
absl::string_view name) {
VLOG(2) << "Emitting dynamic update slice fusion kernel: " << name;
HloFusionSpec fusion_spec = GetLoopFusionSpec(fusion);
auto work_dimensions =
@ -275,7 +278,7 @@ static absl::StatusOr<MlirKernelDefinition> EmitDynamicUpdateSliceFusionKernel(
return mlir_kernel_definition;
}
absl::StatusOr<MlirKernelDefinition> EmitFusionKernel(
absl::StatusOr<KernelDefinition<MlirKernelSource>> EmitFusionKernel(
SymbolicExprContext& context, const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment, bool use_unique_c_name) {
if (fusion.fusion_kind() == HloFusionInstruction::FusionKind::kLoop) {

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "xla/codegen/emitters/kernel_arguments.h"
#include "xla/codegen/kernel_definition.h"
#include "xla/codegen/mlir_kernel_source.h"
#include "xla/hlo/analysis/symbolic_expr.h"
#include "xla/hlo/ir/hlo_instructions.h"
@ -27,7 +28,7 @@ namespace xla::cpu {
emitters::KernelArguments::BufferAlignment GetDefaultBufferAlignment();
absl::StatusOr<MlirKernelDefinition> EmitFusionKernel(
absl::StatusOr<KernelDefinition<MlirKernelSource>> EmitFusionKernel(
SymbolicExprContext& context, const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment, bool use_unique_c_name);

View File

@ -39,7 +39,7 @@ absl::Status Run(const std::string& filename) {
module->entry_computation()->root_instruction());
fusion->SetAndSanitizeName("main");
TF_ASSIGN_OR_RETURN(
MlirKernelDefinition kernel_definition,
KernelDefinition kernel_definition,
EmitFusionKernel(*symbolic_expr_context, *fusion, nullptr, false));
llvm::outs() << kernel_definition.source().ToString();
return absl::OkStatus();

View File

@ -206,9 +206,8 @@ NB_MODULE(_extension, kernel_runner_module) {
[](SymbolicExprContext& symbolic_expr_context,
const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment) {
absl::StatusOr<MlirKernelDefinition> kernel_definition =
EmitFusionKernel(symbolic_expr_context, fusion, buffer_assignment,
false);
auto kernel_definition = EmitFusionKernel(symbolic_expr_context, fusion,
buffer_assignment, false);
if (!kernel_definition.ok()) {
throw std::runtime_error(kernel_definition.status().ToString());
}
@ -242,8 +241,8 @@ NB_MODULE(_extension, kernel_runner_module) {
"KernelRunner")
.def_static(
"create",
[](std::unique_ptr<MlirKernelDefinition,
nb::deleter<MlirKernelDefinition>>
[](std::unique_ptr<KernelDefinition<MlirKernelSource>,
nb::deleter<KernelDefinition<MlirKernelSource>>>
kernel_definition,
std::unique_ptr<JitCompiler, nb::deleter<JitCompiler>>
jit_compiler) {

View File

@ -47,7 +47,7 @@ MlirTestKernelEmitter::MlirTestKernelEmitter(absl::string_view mlir,
}
}
absl::StatusOr<MlirKernelDefinition>
absl::StatusOr<MlirTestKernelEmitter::KernelDefinition>
MlirTestKernelEmitter::EmitKernelDefinition() {
std::unique_ptr<mlir::MLIRContext> context = FusionCompiler::CreateContext();
@ -71,6 +71,6 @@ MlirTestKernelEmitter::EmitKernelDefinition() {
KernelSpec kernel_spec(kernel_name_, num_workgroups_,
std::move(argument_buffers), std::move(result_buffers),
/*invariant_arguments=*/{});
return MlirKernelDefinition(std::move(kernel_spec), std::move(source));
return KernelDefinition(std::move(kernel_spec), std::move(source));
}
} // namespace xla::cpu

View File

@ -89,7 +89,7 @@ ConcatenateFusionKernelEmitter::ConcatenateFusionKernelEmitter(
entry_function_name_(entry_function_name),
backend_kind_(backend_kind) {}
absl::StatusOr<MlirKernelDefinition>
absl::StatusOr<ConcatenateFusionKernelEmitter::KernelDefinition>
ConcatenateFusionKernelEmitter::EmitKernelDefinition() {
mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext());
auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name()));
@ -121,8 +121,8 @@ ConcatenateFusionKernelEmitter::EmitKernelDefinition() {
GetKernelSpec(entry_function_name_, fusion_,
buffer_assignment_, work_dimensions_));
return MlirKernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(module)));
return KernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(module)));
}
const Shape& ConcatenateFusionKernelEmitter::GetIndexingShape(

View File

@ -92,7 +92,7 @@ DynamicUpdateSliceKernelEmitter::DynamicUpdateSliceKernelEmitter(
entry_function_name_(entry_function_name),
backend_kind_(backend_kind) {}
absl::StatusOr<MlirKernelDefinition>
absl::StatusOr<DynamicUpdateSliceKernelEmitter::KernelDefinition>
DynamicUpdateSliceKernelEmitter::EmitKernelDefinition() {
mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext());
auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name()));
@ -120,8 +120,8 @@ DynamicUpdateSliceKernelEmitter::EmitKernelDefinition() {
TF_ASSIGN_OR_RETURN(auto kernel_spec, GetKernelSpec());
return MlirKernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(module)));
return KernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(module)));
}
IndexingMap DynamicUpdateSliceKernelEmitter::ComputeWorkItemIdToInputIndexing(

View File

@ -84,7 +84,7 @@ LoopFusionKernelEmitter::LoopFusionKernelEmitter(
entry_function_name_(entry_function_name),
backend_kind_(backend_kind) {}
absl::StatusOr<MlirKernelDefinition>
absl::StatusOr<LoopFusionKernelEmitter::KernelDefinition>
LoopFusionKernelEmitter::EmitKernelDefinition() {
mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext());
auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name()));
@ -114,8 +114,8 @@ LoopFusionKernelEmitter::EmitKernelDefinition() {
GetKernelSpec(entry_function_name_, fusion_,
buffer_assignment_, work_dimensions_));
return MlirKernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(module)));
return KernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(module)));
}
IndexingMap LoopFusionKernelEmitter::ComputeWorkItemIdToOutputIndexing(

View File

@ -81,8 +81,6 @@ class MlirKernelSource final : public KernelSource {
Storage storage_;
};
using MlirKernelDefinition = KernelDefinition<MlirKernelSource>; // NOLINT
} // namespace xla
#endif // XLA_CODEGEN_MLIR_KERNEL_SOURCE_H_

View File

@ -171,7 +171,7 @@ absl::StatusOr<KernelSpec> ParallelFusionEmitter::AddFusion(
// fixed but will require a rework of the ThunkEmitter.
auto compiler_instance = fusion_compiler_pool_->GetInstance();
TF_ASSIGN_OR_RETURN(
MlirKernelDefinition mlir_kernel_definition,
KernelDefinition mlir_kernel_definition,
EmitFusionKernel(*compiler_instance->symbolic_expr_context, *fusion,
buffer_assignment_, use_unique_c_name_));
@ -181,8 +181,8 @@ absl::StatusOr<KernelSpec> ParallelFusionEmitter::AddFusion(
}
KernelSpec spec = mlir_kernel_definition.spec();
auto shared_source =
std::make_shared<MlirKernelDefinition>(std::move(mlir_kernel_definition));
auto shared_source = std::make_shared<KernelDefinition<MlirKernelSource>>(
std::move(mlir_kernel_definition));
thread_pool_.Schedule(absl::bind_front(&ParallelFusionEmitter::CompileFusion,
this, std::move(shared_source),

View File

@ -851,7 +851,7 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitFusionKernelThunk(
auto kernel_emitter = std::make_unique<CpuScatterFusion>(
buffer_assignment_, fusion, &symbolic_expr_context_);
TF_ASSIGN_OR_RETURN(MlirKernelDefinition kernel_definition,
TF_ASSIGN_OR_RETURN(KernelDefinition kernel_definition,
kernel_emitter->EmitKernelDefinition());
auto kernel_spec = kernel_definition.spec();