[xla:codegen] Move KernelEmitter::name() to base class and change type to absl::string_view

PiperOrigin-RevId: 825237762
This commit is contained in:
Eugene Zhulenev 2025-10-28 15:45:59 -07:00 committed by TensorFlower Gardener
parent 03f4c66dd1
commit a72cd9ceeb
14 changed files with 29 additions and 39 deletions

View File

@ -17,10 +17,10 @@ limitations under the License.
#define XLA_BACKENDS_CPU_CODEGEN_COMPUTATION_KERNEL_EMITTER_H_
#include <cstdint>
#include <string>
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
@ -50,10 +50,9 @@ class ComputationKernelEmitter final : public LlvmKernelEmitter {
const BufferAssignment* buffer_assignment,
const TargetMachineFeatures* target_machine);
absl::string_view name() const final { return "computation_kernel_emitter"; }
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() final;
std::string name() const final { return "computation_kernel_emitter"; }
private:
absl::StatusOr<llvm::Function*> EmitNestedComputation(
llvm::Function* function, llvm::BasicBlock* return_block,

View File

@ -16,12 +16,10 @@ limitations under the License.
#ifndef XLA_BACKENDS_CPU_CODEGEN_DOT_DOT_KERNEL_EMITTER_H_
#define XLA_BACKENDS_CPU_CODEGEN_DOT_DOT_KERNEL_EMITTER_H_
#include <string>
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/backends/cpu/codegen/target_machine_features.h"
#include "xla/codegen/kernel_definition.h"
#include "xla/codegen/kernel_emitter.h"
#include "xla/codegen/llvm_kernel_source.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/buffer_assignment.h"
@ -34,10 +32,9 @@ class DotKernelEmitter final : public LlvmKernelEmitter {
const BufferAssignment* buffer_assignment,
const TargetMachineFeatures* target_machine);
absl::string_view name() const final { return "dot_kernel_emitter"; }
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
std::string name() const final { return "dot_kernel_emitter"; }
private:
const HloInstruction* instr_;

View File

@ -16,9 +16,8 @@ limitations under the License.
#ifndef XLA_BACKENDS_CPU_CODEGEN_ELEMENTAL_CONCATENATE_KERNEL_EMITTER_H_
#define XLA_BACKENDS_CPU_CODEGEN_ELEMENTAL_CONCATENATE_KERNEL_EMITTER_H_
#include <string>
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/backends/cpu/codegen/target_machine_features.h"
#include "xla/codegen/kernel_definition.h"
#include "xla/codegen/llvm_kernel_source.h"
@ -33,10 +32,9 @@ class ConcatenateKernelEmitter final : public LlvmKernelEmitter {
const BufferAssignment* buffer_assignment,
const TargetMachineFeatures* target_machine);
absl::string_view name() const final { return "concatenate_kernel_emitter"; }
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
std::string name() const final { return "concatenate_kernel_emitter"; }
private:
const HloInstruction* instr_;

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h"
@ -39,10 +40,9 @@ class ElementalKernelEmitter final : public LlvmKernelEmitter {
const BufferAssignment* buffer_assignment,
const TargetMachineFeatures* target_machine);
absl::string_view name() const final { return "elemental_kernel_emitter"; }
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
std::string name() const final { return "elemental_kernel_emitter"; }
private:
// Emits LLVM IR using elemental loop emitter and the given element generator.
// If the instruction is parallelized, it will emit a parallel loop partition

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Value.h"
@ -43,10 +44,9 @@ class CpuScatterFusion final : public MlirKernelEmitter {
const HloFusionInstruction* fusion,
gpu::SymbolicExprContext* symbolic_expr_context);
absl::string_view name() const final { return "cpu_scatter_fusion"; }
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() final;
std::string name() const final { return "cpu_scatter_fusion"; }
private:
absl::Status EmitEntryFunction(
const emitters::PartitionedComputations& computations,

View File

@ -329,7 +329,7 @@ KernelApiIrBuilder::KernelApiIrBuilder(llvm::LLVMContext& context,
auto KernelApiIrBuilder::EmitKernelPrototype(
llvm::Module& module, const HloInstruction* instr,
const BufferAssignment* buffer_assignment,
const std::string& generating_emitter_name, absl::string_view suffix)
absl::string_view generating_emitter_name, absl::string_view suffix)
-> absl::StatusOr<KernelPrototype> {
TF_ASSIGN_OR_RETURN(std::vector<KernelParameter> arguments,
GetKernelArgumentsParameters(instr, buffer_assignment));
@ -347,7 +347,7 @@ auto KernelApiIrBuilder::EmitKernelPrototype(
llvm::Module& module, absl::string_view name,
absl::Span<const KernelParameter> arguments,
absl::Span<const KernelParameter> results,
const std::string& module_memory_region_name)
absl::string_view module_memory_region_name)
-> absl::StatusOr<KernelPrototype> {
CHECK(&module.getContext() == &context_) << "Module context mismatch";

View File

@ -124,14 +124,13 @@ class KernelApiIrBuilder {
absl::StatusOr<KernelPrototype> EmitKernelPrototype(
llvm::Module& module, const HloInstruction* instr,
const BufferAssignment* buffer_assignment,
const std::string& generating_emitter_name,
absl::string_view suffix = "");
absl::string_view generating_emitter_name, absl::string_view suffix = "");
absl::StatusOr<KernelPrototype> EmitKernelPrototype(
llvm::Module& module, absl::string_view name,
absl::Span<const KernelParameter> arguments,
absl::Span<const KernelParameter> results,
const std::string& module_memory_region_name);
absl::string_view module_memory_region_name);
// Get the kernel name for the given HLO instruction.
// If generate_unique_c_style_kernel_entry_points is enabled, the name will

View File

@ -49,10 +49,9 @@ class LlvmTestKernelEmitter : public LlvmKernelEmitter {
NumWorkGroups num_workgroups,
absl::Span<const KernelArg> args);
absl::string_view name() const override { return "llvm_test_kernel_emitter"; }
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() final;
std::string name() const override { return "llvm_test_kernel_emitter"; }
private:
std::string llvm_ir_;
std::string kernel_name_;

View File

@ -48,10 +48,9 @@ class MlirTestKernelEmitter : public MlirKernelEmitter {
NumWorkGroups num_workgroups,
absl::Span<const KernelArg> args);
absl::string_view name() const override { return "mlir_test_kernel_emitter"; }
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() final;
std::string name() const override { return "mlir_test_kernel_emitter"; }
private:
std::string mlir_;
std::string kernel_name_;

View File

@ -65,6 +65,7 @@ cc_library(
":kernel_definition",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
],
)

View File

@ -29,7 +29,6 @@ limitations under the License.
#include "xla/codegen/emitters/kernel_arguments.h"
#include "xla/codegen/hlo_fusion_spec.h"
#include "xla/codegen/kernel_definition.h"
#include "xla/codegen/kernel_spec.h"
#include "xla/codegen/mlir_kernel_source.h"
#include "xla/hlo/analysis/indexing_map.h"
#include "xla/hlo/ir/hlo_instructions.h"
@ -50,6 +49,10 @@ class ConcatenateFusionKernelEmitter final : public MlirKernelEmitter {
WorkDimensions work_dimensions, absl::string_view entry_function_name,
BackendKind backend_kind);
absl::string_view name() const final {
return "concatenate_fusion_kernel_emitter";
}
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() override;
static IndexingMap ComputeWorkItemIdToOutputIndexing(
@ -67,8 +70,6 @@ class ConcatenateFusionKernelEmitter final : public MlirKernelEmitter {
static int GetValidUnrollFactor(const HloFusionSpec& fusion_spec,
int max_unroll_factor);
std::string name() const final { return "concatenate_fusion_kernel_emitter"; }
private:
IndexingMap ComputeWorkItemIdToOutputIndexing(
gpu::SymbolicExprContext* ctx) const;

View File

@ -56,6 +56,10 @@ class DynamicUpdateSliceKernelEmitter final : public MlirKernelEmitter {
WorkDimensions work_dimensions, absl::string_view entry_function_name,
BackendKind backend_kind);
absl::string_view name() const final {
return "dynamic_update_slice_kernel_emitter";
}
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() override;
// Get the shape that will be used for loop indexing for the given fusion
@ -66,10 +70,6 @@ class DynamicUpdateSliceKernelEmitter final : public MlirKernelEmitter {
const WorkDimensions& work_dimensions, const Shape& update_shape,
gpu::SymbolicExprContext* ctx);
std::string name() const final {
return "dynamic_update_slice_kernel_emitter";
}
private:
IndexingMap ComputeWorkItemIdToInputIndexing(
gpu::SymbolicExprContext* symbolic_expr_context) const;

View File

@ -16,7 +16,6 @@ limitations under the License.
#ifndef XLA_CODEGEN_EMITTERS_LOOP_KERNEL_EMITTER_H_
#define XLA_CODEGEN_EMITTERS_LOOP_KERNEL_EMITTER_H_
#include <cstdint>
#include <string>
#include "absl/status/status.h"
@ -28,7 +27,6 @@ limitations under the License.
#include "xla/codegen/emitters/kernel_arguments.h"
#include "xla/codegen/hlo_fusion_spec.h"
#include "xla/codegen/kernel_definition.h"
#include "xla/codegen/kernel_spec.h"
#include "xla/codegen/mlir_kernel_source.h"
#include "xla/hlo/analysis/indexing_map.h"
#include "xla/hlo/ir/hlo_instructions.h"
@ -51,6 +49,7 @@ class LoopFusionKernelEmitter final : public MlirKernelEmitter {
absl::string_view entry_function_name,
BackendKind backend_kind);
absl::string_view name() const final { return "loop_fusion_kernel_emitter"; }
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() override;
static IndexingMap ComputeWorkItemIdToOutputIndexing(
@ -61,8 +60,6 @@ class LoopFusionKernelEmitter final : public MlirKernelEmitter {
// specification.
static Shape GetIndexingShape(const HloFusionSpec& fusion_spec);
std::string name() const final { return "loop_fusion_kernel_emitter"; }
private:
IndexingMap ComputeWorkItemIdToOutputIndexing(
gpu::SymbolicExprContext* ctx) const;

View File

@ -17,9 +17,9 @@ limitations under the License.
#define XLA_CODEGEN_KERNEL_EMITTER_H_
#include <memory>
#include <string>
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/codegen/kernel_definition.h"
#include "xla/tsl/platform/statusor.h"
@ -29,6 +29,8 @@ class KernelEmitterBase {
public:
virtual ~KernelEmitterBase() = default;
virtual absl::string_view name() const = 0;
virtual absl::StatusOr<std::unique_ptr<KernelDefinitionBase>>
EmitBaseKernelDefinition() = 0;
};
@ -42,8 +44,6 @@ class KernelEmitter : public KernelEmitterBase {
virtual absl::StatusOr<KernelDefinitionType> EmitKernelDefinition() = 0;
virtual std::string name() const = 0;
private:
absl::StatusOr<std::unique_ptr<KernelDefinitionBase>>
EmitBaseKernelDefinition() final {