[xla:codegen] Remove LlvmKernelDefinition alias

Rely on KernelEmitter<T>::KernelDefinition alias and CTAD to avoid spelling full type.

PiperOrigin-RevId: 825336093
This commit is contained in:
Eugene Zhulenev 2025-10-28 20:39:07 -07:00 committed by TensorFlower Gardener
parent 8e76d82f01
commit 1444679887
23 changed files with 63 additions and 58 deletions

View File

@ -132,7 +132,7 @@ ComputationKernelEmitter::ComputationKernelEmitter(
buffer_assignment_(buffer_assignment),
target_machine_(target_machine) {}
absl::StatusOr<LlvmKernelDefinition>
absl::StatusOr<ComputationKernelEmitter::KernelDefinition>
ComputationKernelEmitter::EmitKernelDefinition() {
VLOG(2) << "Emit Computation host kernel: " << instr_->name();
@ -221,7 +221,7 @@ ComputationKernelEmitter::EmitKernelDefinition() {
std::move(kernel_prototype.result_buffers),
std::move(kernel_prototype.invariant_arguments));
return LlvmKernelDefinition(std::move(spec), std::move(source));
return KernelDefinition(std::move(spec), std::move(source));
}
absl::StatusOr<llvm::Function*> ComputationKernelEmitter::EmitNestedComputation(

View File

@ -51,7 +51,7 @@ class ComputationKernelEmitter final : public KernelEmitter<LlvmKernelSource> {
const TargetMachineFeatures* target_machine);
absl::string_view name() const final { return "computation_kernel_emitter"; }
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() final;
absl::StatusOr<KernelDefinition> EmitKernelDefinition() final;
private:
absl::StatusOr<llvm::Function*> EmitNestedComputation(

View File

@ -56,7 +56,8 @@ DotKernelEmitter::DotKernelEmitter(const HloInstruction* instr,
buffer_assignment_(buffer_assignment),
target_machine_(target_machine) {}
absl::StatusOr<LlvmKernelDefinition> DotKernelEmitter::EmitKernelDefinition() {
absl::StatusOr<DotKernelEmitter::KernelDefinition>
DotKernelEmitter::EmitKernelDefinition() {
const HloModuleConfig& config = instr_->GetModule()->config();
DotImplementationStrategy strategy = GetDotImplementationStrategy(
@ -111,7 +112,7 @@ absl::StatusOr<LlvmKernelDefinition> DotKernelEmitter::EmitKernelDefinition() {
std::move(kernel_prototype.result_buffers),
std::move(kernel_prototype.invariant_arguments));
return LlvmKernelDefinition(std::move(spec), std::move(source));
return KernelDefinition(std::move(spec), std::move(source));
}
} // namespace xla::cpu

View File

@ -33,7 +33,7 @@ class DotKernelEmitter final : public KernelEmitter<LlvmKernelSource> {
const TargetMachineFeatures* target_machine);
absl::string_view name() const final { return "dot_kernel_emitter"; }
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
absl::StatusOr<KernelDefinition> EmitKernelDefinition() override;
private:
const HloInstruction* instr_;

View File

@ -43,7 +43,6 @@ limitations under the License.
#include "xla/service/cpu/ir_emitter.h"
#include "xla/service/llvm_ir/ir_array.h"
#include "xla/shape.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
@ -67,7 +66,7 @@ ConcatenateKernelEmitter::ConcatenateKernelEmitter(
buffer_assignment_(buffer_assignment),
target_machine_(target_machine) {}
absl::StatusOr<LlvmKernelDefinition>
absl::StatusOr<ConcatenateKernelEmitter::KernelDefinition>
ConcatenateKernelEmitter::EmitKernelDefinition() {
if (absl::Status status = CanDoFastConcatenate(instr_); !status.ok()) {
VLOG(1) << "Could not emit fast concatenate for " << instr_->ToString()
@ -122,7 +121,7 @@ ConcatenateKernelEmitter::EmitKernelDefinition() {
std::move(kernel_prototype.result_buffers),
std::move(kernel_prototype.invariant_arguments));
return LlvmKernelDefinition(std::move(spec), std::move(source));
return KernelDefinition(std::move(spec), std::move(source));
}
} // namespace xla::cpu

View File

@ -33,7 +33,7 @@ class ConcatenateKernelEmitter final : public KernelEmitter<LlvmKernelSource> {
const TargetMachineFeatures* target_machine);
absl::string_view name() const final { return "concatenate_kernel_emitter"; }
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
absl::StatusOr<KernelDefinition> EmitKernelDefinition() override;
private:
const HloInstruction* instr_;

View File

@ -168,7 +168,7 @@ ElementalKernelEmitter::ElementalKernelEmitter(
buffer_assignment_(buffer_assignment),
target_machine_(target_machine) {}
absl::StatusOr<LlvmKernelDefinition>
absl::StatusOr<ElementalKernelEmitter::KernelDefinition>
ElementalKernelEmitter::EmitKernelDefinition() {
VLOG(2) << "Emit elemental host kernel: " << instr_->name();
@ -234,7 +234,7 @@ ElementalKernelEmitter::EmitKernelDefinition() {
std::move(kernel_prototype.result_buffers),
std::move(kernel_prototype.invariant_arguments));
return LlvmKernelDefinition(std::move(spec), std::move(source));
return KernelDefinition(std::move(spec), std::move(source));
}
absl::StatusOr<NumWorkGroups> ElementalKernelEmitter::EmitElementalLoops(

View File

@ -39,7 +39,7 @@ class ElementalKernelEmitter final : public KernelEmitter<LlvmKernelSource> {
const TargetMachineFeatures* target_machine);
absl::string_view name() const final { return "elemental_kernel_emitter"; }
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
absl::StatusOr<KernelDefinition> EmitKernelDefinition() override;
private:
// Emits LLVM IR using elemental loop emitter and the given element generator.

View File

@ -46,7 +46,7 @@ class ElementalKernelEmitterTest : public HloHardwareIndependentTestBase {
ElementalKernelEmitterTest()
: target_machine_features_([](int64_t size) { return 1; }) {}
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition(
absl::StatusOr<KernelDefinition<LlvmKernelSource>> EmitKernelDefinition(
const HloInstruction* instr, const BufferAssignment* buffer_assignment) {
ElementalKernelEmitter emitter(instr, buffer_assignment,
&target_machine_features_);

View File

@ -48,6 +48,7 @@ cc_library(
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Target",
"@llvm-project//llvm:ir_headers",

View File

@ -16,12 +16,12 @@ limitations under the License.
#include "xla/backends/cpu/testlib/kernel_runner.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Module.h"
@ -51,10 +51,9 @@ limitations under the License.
namespace xla::cpu {
absl::StatusOr<KernelRunner> KernelRunner::Create(
LlvmKernelDefinition kernel_definition, JitCompiler compiler) {
auto spec = kernel_definition.spec();
auto thread_safe_module =
std::move(kernel_definition).TakeSource().thread_safe_module();
KernelDefinition<LlvmKernelSource> kernel, JitCompiler compiler) {
auto spec = kernel.spec();
auto thread_safe_module = std::move(kernel).TakeSource().thread_safe_module();
SetModuleMemoryRegionName(*thread_safe_module.getModuleUnlocked(),
"kernel_runner_test");
@ -73,12 +72,12 @@ absl::StatusOr<KernelRunner> KernelRunner::Create(
}
absl::StatusOr<KernelRunner> KernelRunner::Create(
MlirKernelDefinition kernel_definition, JitCompiler compiler) {
auto spec = kernel_definition.spec();
auto source = std::move(kernel_definition).TakeSource();
KernelDefinition<MlirKernelSource> kernel, JitCompiler compiler) {
auto spec = kernel.spec();
auto source = std::move(kernel).TakeSource();
TF_ASSIGN_OR_RETURN(LlvmKernelSource llvm_kernel_source, LowerToLlvm(source));
return Create(LlvmKernelDefinition(spec, std::move(llvm_kernel_source)),
return Create(KernelDefinition(spec, std::move(llvm_kernel_source)),
std::move(compiler));
}

View File

@ -37,9 +37,9 @@ namespace xla::cpu {
class KernelRunner final : public xla::KernelRunner {
public:
static absl::StatusOr<KernelRunner> Create(
LlvmKernelDefinition kernel_definition, JitCompiler compiler);
KernelDefinition<LlvmKernelSource> kernel, JitCompiler compiler);
static absl::StatusOr<KernelRunner> Create(
MlirKernelDefinition kernel_definition, JitCompiler compiler);
KernelDefinition<MlirKernelSource> kernel, JitCompiler compiler);
KernelRunner(KernelRunner&&) = default;
KernelRunner& operator=(KernelRunner&&) = default;

View File

@ -258,8 +258,9 @@ NB_MODULE(_extension, kernel_runner_module) {
return *std::move(runner);
})
.def_static(
"create", [](std::unique_ptr<LlvmKernelDefinition,
nb::deleter<LlvmKernelDefinition>>
"create",
[](std::unique_ptr<KernelDefinition<LlvmKernelSource>,
nb::deleter<KernelDefinition<LlvmKernelSource>>>
kernel_definition,
std::unique_ptr<JitCompiler, nb::deleter<JitCompiler>>
jit_compiler) {

View File

@ -52,7 +52,7 @@ LlvmTestKernelEmitter::LlvmTestKernelEmitter(absl::string_view llvm_ir,
}
}
absl::StatusOr<LlvmKernelDefinition>
absl::StatusOr<LlvmTestKernelEmitter::KernelDefinition>
LlvmTestKernelEmitter::EmitKernelDefinition() {
auto context = std::make_unique<llvm::LLVMContext>();
@ -84,7 +84,7 @@ LlvmTestKernelEmitter::EmitKernelDefinition() {
KernelSpec kernel_spec(kernel_name_, num_workgroups_,
std::move(argument_buffers), std::move(result_buffers),
/*invariant_arguments=*/{});
return LlvmKernelDefinition(std::move(kernel_spec), std::move(source));
return KernelDefinition(std::move(kernel_spec), std::move(source));
}
} // namespace xla::cpu

View File

@ -51,7 +51,7 @@ class LlvmTestKernelEmitter : public KernelEmitter<LlvmKernelSource> {
absl::Span<const KernelArg> args);
absl::string_view name() const override { return "llvm_test_kernel_emitter"; }
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() final;
absl::StatusOr<KernelDefinition> EmitKernelDefinition() final;
private:
std::string llvm_ir_;

View File

@ -58,14 +58,14 @@ class KernelEmitter : public KernelEmitterBase {
static_assert(std::is_base_of_v<KernelSource, Source>,
"Source must be a subclass of KernelSource");
virtual absl::StatusOr<KernelDefinition<Source>> EmitKernelDefinition() = 0;
using KernelDefinition = ::xla::KernelDefinition<Source>;
virtual absl::StatusOr<KernelDefinition> EmitKernelDefinition() = 0;
private:
absl::StatusOr<std::unique_ptr<KernelDefinitionBase>>
EmitKernelDefinitionBase() final {
TF_ASSIGN_OR_RETURN(auto kernel_definition, EmitKernelDefinition());
return std::make_unique<KernelDefinition<Source>>(
std::move(kernel_definition));
return std::make_unique<KernelDefinition>(std::move(kernel_definition));
}
};

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "xla/codegen/kernel_definition.h"
#include "xla/codegen/kernel_source.h"
namespace xla {
@ -51,8 +50,6 @@ class LlvmKernelSource final : public KernelSource {
llvm::orc::ThreadSafeModule module_;
};
using LlvmKernelDefinition = KernelDefinition<LlvmKernelSource>; // NOLINT
} // namespace xla
#endif // XLA_CODEGEN_LLVM_KERNEL_SOURCE_H_

View File

@ -206,9 +206,9 @@ NB_MODULE(_extension, kernel_runner_module) {
},
nb::rv_policy::reference_internal);
nb::class_<MlirKernelDefinition, KernelDefinitionBase>(
nb::class_<KernelDefinition<MlirKernelSource>, KernelDefinitionBase>(
kernel_runner_module, "MlirKernelDefinition");
nb::class_<LlvmKernelDefinition, KernelDefinitionBase>(
nb::class_<KernelDefinition<LlvmKernelSource>, KernelDefinitionBase>(
kernel_runner_module, "LlvmKernelDefinition");
nb::class_<KernelEmitterBase>(kernel_runner_module, "KernelEmitterBase")

View File

@ -2047,6 +2047,7 @@ cc_library(
deps = [
"//xla/backends/cpu/codegen:fusion_compiler",
"//xla/backends/cpu/codegen:fusion_emitter",
"//xla/codegen:kernel_definition",
"//xla/codegen:kernel_spec",
"//xla/codegen:llvm_kernel_source",
"//xla/codegen:mlir_kernel_source",

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h"
#include "xla/backends/cpu/codegen/fusion_compiler.h"
#include "xla/backends/cpu/codegen/fusion_emitter.h"
#include "xla/codegen/kernel_definition.h"
#include "xla/codegen/kernel_spec.h"
#include "xla/codegen/llvm_kernel_source.h"
#include "xla/codegen/mlir_kernel_source.h"
@ -190,7 +191,7 @@ absl::StatusOr<KernelSpec> ParallelFusionEmitter::AddFusion(
return spec;
}
absl::StatusOr<std::vector<LlvmKernelDefinition>>
absl::StatusOr<std::vector<KernelDefinition<LlvmKernelSource>>>
ParallelFusionEmitter::ConsumeKernels() {
absl::MutexLock lock(kernels_mutex_);
@ -203,8 +204,8 @@ ParallelFusionEmitter::ConsumeKernels() {
}
// Sort the kernels by name to ensure a deterministic order.
absl::c_sort(kernels_, [](const LlvmKernelDefinition& lhs,
const LlvmKernelDefinition& rhs) {
absl::c_sort(kernels_, [](const KernelDefinition<LlvmKernelSource>& lhs,
const KernelDefinition<LlvmKernelSource>& rhs) {
return lhs.spec().name() < rhs.spec().name();
});
@ -212,12 +213,12 @@ ParallelFusionEmitter::ConsumeKernels() {
}
void ParallelFusionEmitter::CompileFusion(
std::shared_ptr<MlirKernelDefinition> mlir_kernel_definition,
std::shared_ptr<KernelDefinition<MlirKernelSource>> mlir_kernel,
std::shared_ptr<CompilerInstance> compiler_instance) {
KernelSpec spec = mlir_kernel_definition->spec();
KernelSpec spec = mlir_kernel->spec();
absl::StatusOr<LlvmKernelSource> llvm_kernel_source =
compiler_instance->compiler->Compile(
std::move(*mlir_kernel_definition).TakeSource());
std::move(*mlir_kernel).TakeSource());
absl::MutexLock lock(kernels_mutex_);
outstanding_kernels_--;

View File

@ -25,8 +25,10 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "xla/backends/cpu/codegen/fusion_compiler.h"
#include "xla/codegen/kernel_definition.h"
#include "xla/codegen/kernel_spec.h"
#include "xla/codegen/llvm_kernel_source.h"
#include "xla/codegen/mlir_kernel_source.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/buffer_assignment.h"
#include "xla/tsl/platform/threadpool.h"
@ -52,14 +54,15 @@ class ParallelFusionEmitter {
// Returns the kernels for all the added fusions, blocks until all kernels
// have been compiled.
absl::StatusOr<std::vector<LlvmKernelDefinition>> ConsumeKernels();
absl::StatusOr<std::vector<KernelDefinition<LlvmKernelSource>>>
ConsumeKernels();
private:
struct CompilerInstance;
class FusionCompilerPool;
void CompileFusion(
std::shared_ptr<MlirKernelDefinition> mlir_kernel_definition,
std::shared_ptr<KernelDefinition<MlirKernelSource>> mlir_kernel,
std::shared_ptr<CompilerInstance> compiler_instance);
tsl::thread::ThreadPool& thread_pool_;
@ -70,7 +73,8 @@ class ParallelFusionEmitter {
absl::Mutex kernels_mutex_;
int64_t outstanding_kernels_ ABSL_GUARDED_BY(kernels_mutex_) = 0;
absl::Status kernels_status_ ABSL_GUARDED_BY(kernels_mutex_);
std::vector<LlvmKernelDefinition> kernels_ ABSL_GUARDED_BY(kernels_mutex_);
std::vector<KernelDefinition<LlvmKernelSource>> kernels_
ABSL_GUARDED_BY(kernels_mutex_);
};
} // namespace xla::cpu

View File

@ -123,7 +123,7 @@ TEST_F(ParallelFusionEmitterTest, HappyPathSingleFusion) {
TF_ASSERT_OK_AND_ASSIGN(auto kernels, fussion_emitter.ConsumeKernels());
ASSERT_EQ(kernels.size(), 1);
LlvmKernelDefinition& lowered_kernel = kernels[0];
KernelDefinition<LlvmKernelSource>& lowered_kernel = kernels[0];
EXPECT_EQ(lowered_kernel.spec().name(), expected_name);
auto source = std::move(lowered_kernel).TakeSource();

View File

@ -234,11 +234,12 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitEntryComputation(
absl::StatusOr<std::vector<ThunkEmitter::EmittedKernel>>
ThunkEmitter::ConsumeKernels() {
tsl::profiler::TraceMe trace("ThunkEmitter::ConsumeKernels");
TF_ASSIGN_OR_RETURN(std::vector<LlvmKernelDefinition> fusion_kernels,
TF_ASSIGN_OR_RETURN(
std::vector<KernelDefinition<LlvmKernelSource>> fusion_kernels,
parallel_fusion_emitter_.ConsumeKernels());
kernels_.reserve(kernels_.size() + fusion_kernels.size());
for (LlvmKernelDefinition& kernel : fusion_kernels) {
for (KernelDefinition<LlvmKernelSource>& kernel : fusion_kernels) {
std::string name(kernel.spec().name());
auto source = std::move(kernel).TakeSource();
kernels_.push_back({name, std::move(source).thread_safe_module()});
@ -688,7 +689,7 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitCallThunk(
maybe_small_call.has_value() && *maybe_small_call == "true") {
ComputationKernelEmitter emitter(instruction, &buffer_assignment_,
&target_machine_features_);
TF_ASSIGN_OR_RETURN(LlvmKernelDefinition kernel_definition,
TF_ASSIGN_OR_RETURN(KernelDefinition kernel_definition,
emitter.EmitKernelDefinition());
auto kernel_spec = kernel_definition.spec();
@ -712,7 +713,7 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitConcatenateKernelThunk(
const HloInstruction* instruction) {
ConcatenateKernelEmitter emitter(instruction, &buffer_assignment_,
&target_machine_features_);
TF_ASSIGN_OR_RETURN(LlvmKernelDefinition kernel_definition,
TF_ASSIGN_OR_RETURN(KernelDefinition kernel_definition,
emitter.EmitKernelDefinition());
auto kernel_spec = kernel_definition.spec();
@ -818,7 +819,7 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitElementalKernelThunk(
const HloInstruction* instruction) {
ElementalKernelEmitter emitter(instruction, &buffer_assignment_,
&target_machine_features_);
TF_ASSIGN_OR_RETURN(LlvmKernelDefinition kernel_definition,
TF_ASSIGN_OR_RETURN(KernelDefinition kernel_definition,
emitter.EmitKernelDefinition());
auto kernel_spec = kernel_definition.spec();
@ -1062,7 +1063,7 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitDotThunk(
case DotImplementationStrategy::kTiledLlvmIrGemv: {
DotKernelEmitter emitter(instruction, &buffer_assignment_,
&target_machine_features_);
TF_ASSIGN_OR_RETURN(LlvmKernelDefinition kernel_definition,
TF_ASSIGN_OR_RETURN(KernelDefinition kernel_definition,
emitter.EmitKernelDefinition());
auto kernel_spec = kernel_definition.spec();