mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[xla:codegen] Move KernelEmitter::name() to base class and change type to absl::string_view
PiperOrigin-RevId: 825237762
This commit is contained in:
parent
03f4c66dd1
commit
a72cd9ceeb
|
|
@ -17,10 +17,10 @@ limitations under the License.
|
|||
#define XLA_BACKENDS_CPU_CODEGEN_COMPUTATION_KERNEL_EMITTER_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/IR/BasicBlock.h"
|
||||
#include "llvm/IR/Function.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
|
|
@ -50,10 +50,9 @@ class ComputationKernelEmitter final : public LlvmKernelEmitter {
|
|||
const BufferAssignment* buffer_assignment,
|
||||
const TargetMachineFeatures* target_machine);
|
||||
|
||||
absl::string_view name() const final { return "computation_kernel_emitter"; }
|
||||
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() final;
|
||||
|
||||
std::string name() const final { return "computation_kernel_emitter"; }
|
||||
|
||||
private:
|
||||
absl::StatusOr<llvm::Function*> EmitNestedComputation(
|
||||
llvm::Function* function, llvm::BasicBlock* return_block,
|
||||
|
|
|
|||
|
|
@ -16,12 +16,10 @@ limitations under the License.
|
|||
#ifndef XLA_BACKENDS_CPU_CODEGEN_DOT_DOT_KERNEL_EMITTER_H_
|
||||
#define XLA_BACKENDS_CPU_CODEGEN_DOT_DOT_KERNEL_EMITTER_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "xla/backends/cpu/codegen/target_machine_features.h"
|
||||
#include "xla/codegen/kernel_definition.h"
|
||||
#include "xla/codegen/kernel_emitter.h"
|
||||
#include "xla/codegen/llvm_kernel_source.h"
|
||||
#include "xla/hlo/ir/hlo_instruction.h"
|
||||
#include "xla/service/buffer_assignment.h"
|
||||
|
|
@ -34,10 +32,9 @@ class DotKernelEmitter final : public LlvmKernelEmitter {
|
|||
const BufferAssignment* buffer_assignment,
|
||||
const TargetMachineFeatures* target_machine);
|
||||
|
||||
absl::string_view name() const final { return "dot_kernel_emitter"; }
|
||||
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
|
||||
|
||||
std::string name() const final { return "dot_kernel_emitter"; }
|
||||
|
||||
private:
|
||||
const HloInstruction* instr_;
|
||||
|
||||
|
|
|
|||
|
|
@ -16,9 +16,8 @@ limitations under the License.
|
|||
#ifndef XLA_BACKENDS_CPU_CODEGEN_ELEMENTAL_CONCATENATE_KERNEL_EMITTER_H_
|
||||
#define XLA_BACKENDS_CPU_CODEGEN_ELEMENTAL_CONCATENATE_KERNEL_EMITTER_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "xla/backends/cpu/codegen/target_machine_features.h"
|
||||
#include "xla/codegen/kernel_definition.h"
|
||||
#include "xla/codegen/llvm_kernel_source.h"
|
||||
|
|
@ -33,10 +32,9 @@ class ConcatenateKernelEmitter final : public LlvmKernelEmitter {
|
|||
const BufferAssignment* buffer_assignment,
|
||||
const TargetMachineFeatures* target_machine);
|
||||
|
||||
absl::string_view name() const final { return "concatenate_kernel_emitter"; }
|
||||
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
|
||||
|
||||
std::string name() const final { return "concatenate_kernel_emitter"; }
|
||||
|
||||
private:
|
||||
const HloInstruction* instr_;
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||
#include <string>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h"
|
||||
|
|
@ -39,10 +40,9 @@ class ElementalKernelEmitter final : public LlvmKernelEmitter {
|
|||
const BufferAssignment* buffer_assignment,
|
||||
const TargetMachineFeatures* target_machine);
|
||||
|
||||
absl::string_view name() const final { return "elemental_kernel_emitter"; }
|
||||
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
|
||||
|
||||
std::string name() const final { return "elemental_kernel_emitter"; }
|
||||
|
||||
private:
|
||||
// Emits LLVM IR using elemental loop emitter and the given element generator.
|
||||
// If the instruction is parallelized, it will emit a parallel loop partition
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
|
@ -43,10 +44,9 @@ class CpuScatterFusion final : public MlirKernelEmitter {
|
|||
const HloFusionInstruction* fusion,
|
||||
gpu::SymbolicExprContext* symbolic_expr_context);
|
||||
|
||||
absl::string_view name() const final { return "cpu_scatter_fusion"; }
|
||||
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() final;
|
||||
|
||||
std::string name() const final { return "cpu_scatter_fusion"; }
|
||||
|
||||
private:
|
||||
absl::Status EmitEntryFunction(
|
||||
const emitters::PartitionedComputations& computations,
|
||||
|
|
|
|||
|
|
@ -329,7 +329,7 @@ KernelApiIrBuilder::KernelApiIrBuilder(llvm::LLVMContext& context,
|
|||
auto KernelApiIrBuilder::EmitKernelPrototype(
|
||||
llvm::Module& module, const HloInstruction* instr,
|
||||
const BufferAssignment* buffer_assignment,
|
||||
const std::string& generating_emitter_name, absl::string_view suffix)
|
||||
absl::string_view generating_emitter_name, absl::string_view suffix)
|
||||
-> absl::StatusOr<KernelPrototype> {
|
||||
TF_ASSIGN_OR_RETURN(std::vector<KernelParameter> arguments,
|
||||
GetKernelArgumentsParameters(instr, buffer_assignment));
|
||||
|
|
@ -347,7 +347,7 @@ auto KernelApiIrBuilder::EmitKernelPrototype(
|
|||
llvm::Module& module, absl::string_view name,
|
||||
absl::Span<const KernelParameter> arguments,
|
||||
absl::Span<const KernelParameter> results,
|
||||
const std::string& module_memory_region_name)
|
||||
absl::string_view module_memory_region_name)
|
||||
-> absl::StatusOr<KernelPrototype> {
|
||||
CHECK(&module.getContext() == &context_) << "Module context mismatch";
|
||||
|
||||
|
|
|
|||
|
|
@ -124,14 +124,13 @@ class KernelApiIrBuilder {
|
|||
absl::StatusOr<KernelPrototype> EmitKernelPrototype(
|
||||
llvm::Module& module, const HloInstruction* instr,
|
||||
const BufferAssignment* buffer_assignment,
|
||||
const std::string& generating_emitter_name,
|
||||
absl::string_view suffix = "");
|
||||
absl::string_view generating_emitter_name, absl::string_view suffix = "");
|
||||
|
||||
absl::StatusOr<KernelPrototype> EmitKernelPrototype(
|
||||
llvm::Module& module, absl::string_view name,
|
||||
absl::Span<const KernelParameter> arguments,
|
||||
absl::Span<const KernelParameter> results,
|
||||
const std::string& module_memory_region_name);
|
||||
absl::string_view module_memory_region_name);
|
||||
|
||||
// Get the kernel name for the given HLO instruction.
|
||||
// If generate_unique_c_style_kernel_entry_points is enabled, the name will
|
||||
|
|
|
|||
|
|
@ -49,10 +49,9 @@ class LlvmTestKernelEmitter : public LlvmKernelEmitter {
|
|||
NumWorkGroups num_workgroups,
|
||||
absl::Span<const KernelArg> args);
|
||||
|
||||
absl::string_view name() const override { return "llvm_test_kernel_emitter"; }
|
||||
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() final;
|
||||
|
||||
std::string name() const override { return "llvm_test_kernel_emitter"; }
|
||||
|
||||
private:
|
||||
std::string llvm_ir_;
|
||||
std::string kernel_name_;
|
||||
|
|
|
|||
|
|
@ -48,10 +48,9 @@ class MlirTestKernelEmitter : public MlirKernelEmitter {
|
|||
NumWorkGroups num_workgroups,
|
||||
absl::Span<const KernelArg> args);
|
||||
|
||||
absl::string_view name() const override { return "mlir_test_kernel_emitter"; }
|
||||
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() final;
|
||||
|
||||
std::string name() const override { return "mlir_test_kernel_emitter"; }
|
||||
|
||||
private:
|
||||
std::string mlir_;
|
||||
std::string kernel_name_;
|
||||
|
|
|
|||
1
third_party/xla/xla/codegen/BUILD
vendored
1
third_party/xla/xla/codegen/BUILD
vendored
|
|
@ -65,6 +65,7 @@ cc_library(
|
|||
":kernel_definition",
|
||||
"//xla/tsl/platform:statusor",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,6 @@ limitations under the License.
|
|||
#include "xla/codegen/emitters/kernel_arguments.h"
|
||||
#include "xla/codegen/hlo_fusion_spec.h"
|
||||
#include "xla/codegen/kernel_definition.h"
|
||||
#include "xla/codegen/kernel_spec.h"
|
||||
#include "xla/codegen/mlir_kernel_source.h"
|
||||
#include "xla/hlo/analysis/indexing_map.h"
|
||||
#include "xla/hlo/ir/hlo_instructions.h"
|
||||
|
|
@ -50,6 +49,10 @@ class ConcatenateFusionKernelEmitter final : public MlirKernelEmitter {
|
|||
WorkDimensions work_dimensions, absl::string_view entry_function_name,
|
||||
BackendKind backend_kind);
|
||||
|
||||
absl::string_view name() const final {
|
||||
return "concatenate_fusion_kernel_emitter";
|
||||
}
|
||||
|
||||
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() override;
|
||||
|
||||
static IndexingMap ComputeWorkItemIdToOutputIndexing(
|
||||
|
|
@ -67,8 +70,6 @@ class ConcatenateFusionKernelEmitter final : public MlirKernelEmitter {
|
|||
static int GetValidUnrollFactor(const HloFusionSpec& fusion_spec,
|
||||
int max_unroll_factor);
|
||||
|
||||
std::string name() const final { return "concatenate_fusion_kernel_emitter"; }
|
||||
|
||||
private:
|
||||
IndexingMap ComputeWorkItemIdToOutputIndexing(
|
||||
gpu::SymbolicExprContext* ctx) const;
|
||||
|
|
|
|||
|
|
@ -56,6 +56,10 @@ class DynamicUpdateSliceKernelEmitter final : public MlirKernelEmitter {
|
|||
WorkDimensions work_dimensions, absl::string_view entry_function_name,
|
||||
BackendKind backend_kind);
|
||||
|
||||
absl::string_view name() const final {
|
||||
return "dynamic_update_slice_kernel_emitter";
|
||||
}
|
||||
|
||||
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() override;
|
||||
|
||||
// Get the shape that will be used for loop indexing for the given fusion
|
||||
|
|
@ -66,10 +70,6 @@ class DynamicUpdateSliceKernelEmitter final : public MlirKernelEmitter {
|
|||
const WorkDimensions& work_dimensions, const Shape& update_shape,
|
||||
gpu::SymbolicExprContext* ctx);
|
||||
|
||||
std::string name() const final {
|
||||
return "dynamic_update_slice_kernel_emitter";
|
||||
}
|
||||
|
||||
private:
|
||||
IndexingMap ComputeWorkItemIdToInputIndexing(
|
||||
gpu::SymbolicExprContext* symbolic_expr_context) const;
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ limitations under the License.
|
|||
#ifndef XLA_CODEGEN_EMITTERS_LOOP_KERNEL_EMITTER_H_
|
||||
#define XLA_CODEGEN_EMITTERS_LOOP_KERNEL_EMITTER_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
|
|
@ -28,7 +27,6 @@ limitations under the License.
|
|||
#include "xla/codegen/emitters/kernel_arguments.h"
|
||||
#include "xla/codegen/hlo_fusion_spec.h"
|
||||
#include "xla/codegen/kernel_definition.h"
|
||||
#include "xla/codegen/kernel_spec.h"
|
||||
#include "xla/codegen/mlir_kernel_source.h"
|
||||
#include "xla/hlo/analysis/indexing_map.h"
|
||||
#include "xla/hlo/ir/hlo_instructions.h"
|
||||
|
|
@ -51,6 +49,7 @@ class LoopFusionKernelEmitter final : public MlirKernelEmitter {
|
|||
absl::string_view entry_function_name,
|
||||
BackendKind backend_kind);
|
||||
|
||||
absl::string_view name() const final { return "loop_fusion_kernel_emitter"; }
|
||||
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() override;
|
||||
|
||||
static IndexingMap ComputeWorkItemIdToOutputIndexing(
|
||||
|
|
@ -61,8 +60,6 @@ class LoopFusionKernelEmitter final : public MlirKernelEmitter {
|
|||
// specification.
|
||||
static Shape GetIndexingShape(const HloFusionSpec& fusion_spec);
|
||||
|
||||
std::string name() const final { return "loop_fusion_kernel_emitter"; }
|
||||
|
||||
private:
|
||||
IndexingMap ComputeWorkItemIdToOutputIndexing(
|
||||
gpu::SymbolicExprContext* ctx) const;
|
||||
|
|
|
|||
6
third_party/xla/xla/codegen/kernel_emitter.h
vendored
6
third_party/xla/xla/codegen/kernel_emitter.h
vendored
|
|
@ -17,9 +17,9 @@ limitations under the License.
|
|||
#define XLA_CODEGEN_KERNEL_EMITTER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "xla/codegen/kernel_definition.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
|
||||
|
|
@ -29,6 +29,8 @@ class KernelEmitterBase {
|
|||
public:
|
||||
virtual ~KernelEmitterBase() = default;
|
||||
|
||||
virtual absl::string_view name() const = 0;
|
||||
|
||||
virtual absl::StatusOr<std::unique_ptr<KernelDefinitionBase>>
|
||||
EmitBaseKernelDefinition() = 0;
|
||||
};
|
||||
|
|
@ -42,8 +44,6 @@ class KernelEmitter : public KernelEmitterBase {
|
|||
|
||||
virtual absl::StatusOr<KernelDefinitionType> EmitKernelDefinition() = 0;
|
||||
|
||||
virtual std::string name() const = 0;
|
||||
|
||||
private:
|
||||
absl::StatusOr<std::unique_ptr<KernelDefinitionBase>>
|
||||
EmitBaseKernelDefinition() final {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user