mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
[xla:codegen] Move KernelEmitter::name() to base class and change type to absl::string_view
PiperOrigin-RevId: 825237762
This commit is contained in:
parent
03f4c66dd1
commit
a72cd9ceeb
|
|
@ -17,10 +17,10 @@ limitations under the License.
|
||||||
#define XLA_BACKENDS_CPU_CODEGEN_COMPUTATION_KERNEL_EMITTER_H_
|
#define XLA_BACKENDS_CPU_CODEGEN_COMPUTATION_KERNEL_EMITTER_H_
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
#include "llvm/IR/BasicBlock.h"
|
#include "llvm/IR/BasicBlock.h"
|
||||||
#include "llvm/IR/Function.h"
|
#include "llvm/IR/Function.h"
|
||||||
#include "llvm/IR/IRBuilder.h"
|
#include "llvm/IR/IRBuilder.h"
|
||||||
|
|
@ -50,10 +50,9 @@ class ComputationKernelEmitter final : public LlvmKernelEmitter {
|
||||||
const BufferAssignment* buffer_assignment,
|
const BufferAssignment* buffer_assignment,
|
||||||
const TargetMachineFeatures* target_machine);
|
const TargetMachineFeatures* target_machine);
|
||||||
|
|
||||||
|
absl::string_view name() const final { return "computation_kernel_emitter"; }
|
||||||
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() final;
|
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() final;
|
||||||
|
|
||||||
std::string name() const final { return "computation_kernel_emitter"; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
absl::StatusOr<llvm::Function*> EmitNestedComputation(
|
absl::StatusOr<llvm::Function*> EmitNestedComputation(
|
||||||
llvm::Function* function, llvm::BasicBlock* return_block,
|
llvm::Function* function, llvm::BasicBlock* return_block,
|
||||||
|
|
|
||||||
|
|
@ -16,12 +16,10 @@ limitations under the License.
|
||||||
#ifndef XLA_BACKENDS_CPU_CODEGEN_DOT_DOT_KERNEL_EMITTER_H_
|
#ifndef XLA_BACKENDS_CPU_CODEGEN_DOT_DOT_KERNEL_EMITTER_H_
|
||||||
#define XLA_BACKENDS_CPU_CODEGEN_DOT_DOT_KERNEL_EMITTER_H_
|
#define XLA_BACKENDS_CPU_CODEGEN_DOT_DOT_KERNEL_EMITTER_H_
|
||||||
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
#include "xla/backends/cpu/codegen/target_machine_features.h"
|
#include "xla/backends/cpu/codegen/target_machine_features.h"
|
||||||
#include "xla/codegen/kernel_definition.h"
|
#include "xla/codegen/kernel_definition.h"
|
||||||
#include "xla/codegen/kernel_emitter.h"
|
|
||||||
#include "xla/codegen/llvm_kernel_source.h"
|
#include "xla/codegen/llvm_kernel_source.h"
|
||||||
#include "xla/hlo/ir/hlo_instruction.h"
|
#include "xla/hlo/ir/hlo_instruction.h"
|
||||||
#include "xla/service/buffer_assignment.h"
|
#include "xla/service/buffer_assignment.h"
|
||||||
|
|
@ -34,10 +32,9 @@ class DotKernelEmitter final : public LlvmKernelEmitter {
|
||||||
const BufferAssignment* buffer_assignment,
|
const BufferAssignment* buffer_assignment,
|
||||||
const TargetMachineFeatures* target_machine);
|
const TargetMachineFeatures* target_machine);
|
||||||
|
|
||||||
|
absl::string_view name() const final { return "dot_kernel_emitter"; }
|
||||||
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
|
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
|
||||||
|
|
||||||
std::string name() const final { return "dot_kernel_emitter"; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const HloInstruction* instr_;
|
const HloInstruction* instr_;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,9 +16,8 @@ limitations under the License.
|
||||||
#ifndef XLA_BACKENDS_CPU_CODEGEN_ELEMENTAL_CONCATENATE_KERNEL_EMITTER_H_
|
#ifndef XLA_BACKENDS_CPU_CODEGEN_ELEMENTAL_CONCATENATE_KERNEL_EMITTER_H_
|
||||||
#define XLA_BACKENDS_CPU_CODEGEN_ELEMENTAL_CONCATENATE_KERNEL_EMITTER_H_
|
#define XLA_BACKENDS_CPU_CODEGEN_ELEMENTAL_CONCATENATE_KERNEL_EMITTER_H_
|
||||||
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
#include "xla/backends/cpu/codegen/target_machine_features.h"
|
#include "xla/backends/cpu/codegen/target_machine_features.h"
|
||||||
#include "xla/codegen/kernel_definition.h"
|
#include "xla/codegen/kernel_definition.h"
|
||||||
#include "xla/codegen/llvm_kernel_source.h"
|
#include "xla/codegen/llvm_kernel_source.h"
|
||||||
|
|
@ -33,10 +32,9 @@ class ConcatenateKernelEmitter final : public LlvmKernelEmitter {
|
||||||
const BufferAssignment* buffer_assignment,
|
const BufferAssignment* buffer_assignment,
|
||||||
const TargetMachineFeatures* target_machine);
|
const TargetMachineFeatures* target_machine);
|
||||||
|
|
||||||
|
absl::string_view name() const final { return "concatenate_kernel_emitter"; }
|
||||||
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
|
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
|
||||||
|
|
||||||
std::string name() const final { return "concatenate_kernel_emitter"; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const HloInstruction* instr_;
|
const HloInstruction* instr_;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
#include "llvm/IR/IRBuilder.h"
|
#include "llvm/IR/IRBuilder.h"
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h"
|
#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h"
|
||||||
|
|
@ -39,10 +40,9 @@ class ElementalKernelEmitter final : public LlvmKernelEmitter {
|
||||||
const BufferAssignment* buffer_assignment,
|
const BufferAssignment* buffer_assignment,
|
||||||
const TargetMachineFeatures* target_machine);
|
const TargetMachineFeatures* target_machine);
|
||||||
|
|
||||||
|
absl::string_view name() const final { return "elemental_kernel_emitter"; }
|
||||||
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
|
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
|
||||||
|
|
||||||
std::string name() const final { return "elemental_kernel_emitter"; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Emits LLVM IR using elemental loop emitter and the given element generator.
|
// Emits LLVM IR using elemental loop emitter and the given element generator.
|
||||||
// If the instruction is parallelized, it will emit a parallel loop partition
|
// If the instruction is parallelized, it will emit a parallel loop partition
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
|
@ -43,10 +44,9 @@ class CpuScatterFusion final : public MlirKernelEmitter {
|
||||||
const HloFusionInstruction* fusion,
|
const HloFusionInstruction* fusion,
|
||||||
gpu::SymbolicExprContext* symbolic_expr_context);
|
gpu::SymbolicExprContext* symbolic_expr_context);
|
||||||
|
|
||||||
|
absl::string_view name() const final { return "cpu_scatter_fusion"; }
|
||||||
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() final;
|
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() final;
|
||||||
|
|
||||||
std::string name() const final { return "cpu_scatter_fusion"; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
absl::Status EmitEntryFunction(
|
absl::Status EmitEntryFunction(
|
||||||
const emitters::PartitionedComputations& computations,
|
const emitters::PartitionedComputations& computations,
|
||||||
|
|
|
||||||
|
|
@ -329,7 +329,7 @@ KernelApiIrBuilder::KernelApiIrBuilder(llvm::LLVMContext& context,
|
||||||
auto KernelApiIrBuilder::EmitKernelPrototype(
|
auto KernelApiIrBuilder::EmitKernelPrototype(
|
||||||
llvm::Module& module, const HloInstruction* instr,
|
llvm::Module& module, const HloInstruction* instr,
|
||||||
const BufferAssignment* buffer_assignment,
|
const BufferAssignment* buffer_assignment,
|
||||||
const std::string& generating_emitter_name, absl::string_view suffix)
|
absl::string_view generating_emitter_name, absl::string_view suffix)
|
||||||
-> absl::StatusOr<KernelPrototype> {
|
-> absl::StatusOr<KernelPrototype> {
|
||||||
TF_ASSIGN_OR_RETURN(std::vector<KernelParameter> arguments,
|
TF_ASSIGN_OR_RETURN(std::vector<KernelParameter> arguments,
|
||||||
GetKernelArgumentsParameters(instr, buffer_assignment));
|
GetKernelArgumentsParameters(instr, buffer_assignment));
|
||||||
|
|
@ -347,7 +347,7 @@ auto KernelApiIrBuilder::EmitKernelPrototype(
|
||||||
llvm::Module& module, absl::string_view name,
|
llvm::Module& module, absl::string_view name,
|
||||||
absl::Span<const KernelParameter> arguments,
|
absl::Span<const KernelParameter> arguments,
|
||||||
absl::Span<const KernelParameter> results,
|
absl::Span<const KernelParameter> results,
|
||||||
const std::string& module_memory_region_name)
|
absl::string_view module_memory_region_name)
|
||||||
-> absl::StatusOr<KernelPrototype> {
|
-> absl::StatusOr<KernelPrototype> {
|
||||||
CHECK(&module.getContext() == &context_) << "Module context mismatch";
|
CHECK(&module.getContext() == &context_) << "Module context mismatch";
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -124,14 +124,13 @@ class KernelApiIrBuilder {
|
||||||
absl::StatusOr<KernelPrototype> EmitKernelPrototype(
|
absl::StatusOr<KernelPrototype> EmitKernelPrototype(
|
||||||
llvm::Module& module, const HloInstruction* instr,
|
llvm::Module& module, const HloInstruction* instr,
|
||||||
const BufferAssignment* buffer_assignment,
|
const BufferAssignment* buffer_assignment,
|
||||||
const std::string& generating_emitter_name,
|
absl::string_view generating_emitter_name, absl::string_view suffix = "");
|
||||||
absl::string_view suffix = "");
|
|
||||||
|
|
||||||
absl::StatusOr<KernelPrototype> EmitKernelPrototype(
|
absl::StatusOr<KernelPrototype> EmitKernelPrototype(
|
||||||
llvm::Module& module, absl::string_view name,
|
llvm::Module& module, absl::string_view name,
|
||||||
absl::Span<const KernelParameter> arguments,
|
absl::Span<const KernelParameter> arguments,
|
||||||
absl::Span<const KernelParameter> results,
|
absl::Span<const KernelParameter> results,
|
||||||
const std::string& module_memory_region_name);
|
absl::string_view module_memory_region_name);
|
||||||
|
|
||||||
// Get the kernel name for the given HLO instruction.
|
// Get the kernel name for the given HLO instruction.
|
||||||
// If generate_unique_c_style_kernel_entry_points is enabled, the name will
|
// If generate_unique_c_style_kernel_entry_points is enabled, the name will
|
||||||
|
|
|
||||||
|
|
@ -49,10 +49,9 @@ class LlvmTestKernelEmitter : public LlvmKernelEmitter {
|
||||||
NumWorkGroups num_workgroups,
|
NumWorkGroups num_workgroups,
|
||||||
absl::Span<const KernelArg> args);
|
absl::Span<const KernelArg> args);
|
||||||
|
|
||||||
|
absl::string_view name() const override { return "llvm_test_kernel_emitter"; }
|
||||||
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() final;
|
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() final;
|
||||||
|
|
||||||
std::string name() const override { return "llvm_test_kernel_emitter"; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string llvm_ir_;
|
std::string llvm_ir_;
|
||||||
std::string kernel_name_;
|
std::string kernel_name_;
|
||||||
|
|
|
||||||
|
|
@ -48,10 +48,9 @@ class MlirTestKernelEmitter : public MlirKernelEmitter {
|
||||||
NumWorkGroups num_workgroups,
|
NumWorkGroups num_workgroups,
|
||||||
absl::Span<const KernelArg> args);
|
absl::Span<const KernelArg> args);
|
||||||
|
|
||||||
|
absl::string_view name() const override { return "mlir_test_kernel_emitter"; }
|
||||||
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() final;
|
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() final;
|
||||||
|
|
||||||
std::string name() const override { return "mlir_test_kernel_emitter"; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string mlir_;
|
std::string mlir_;
|
||||||
std::string kernel_name_;
|
std::string kernel_name_;
|
||||||
|
|
|
||||||
1
third_party/xla/xla/codegen/BUILD
vendored
1
third_party/xla/xla/codegen/BUILD
vendored
|
|
@ -65,6 +65,7 @@ cc_library(
|
||||||
":kernel_definition",
|
":kernel_definition",
|
||||||
"//xla/tsl/platform:statusor",
|
"//xla/tsl/platform:statusor",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@com_google_absl//absl/strings:string_view",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,6 @@ limitations under the License.
|
||||||
#include "xla/codegen/emitters/kernel_arguments.h"
|
#include "xla/codegen/emitters/kernel_arguments.h"
|
||||||
#include "xla/codegen/hlo_fusion_spec.h"
|
#include "xla/codegen/hlo_fusion_spec.h"
|
||||||
#include "xla/codegen/kernel_definition.h"
|
#include "xla/codegen/kernel_definition.h"
|
||||||
#include "xla/codegen/kernel_spec.h"
|
|
||||||
#include "xla/codegen/mlir_kernel_source.h"
|
#include "xla/codegen/mlir_kernel_source.h"
|
||||||
#include "xla/hlo/analysis/indexing_map.h"
|
#include "xla/hlo/analysis/indexing_map.h"
|
||||||
#include "xla/hlo/ir/hlo_instructions.h"
|
#include "xla/hlo/ir/hlo_instructions.h"
|
||||||
|
|
@ -50,6 +49,10 @@ class ConcatenateFusionKernelEmitter final : public MlirKernelEmitter {
|
||||||
WorkDimensions work_dimensions, absl::string_view entry_function_name,
|
WorkDimensions work_dimensions, absl::string_view entry_function_name,
|
||||||
BackendKind backend_kind);
|
BackendKind backend_kind);
|
||||||
|
|
||||||
|
absl::string_view name() const final {
|
||||||
|
return "concatenate_fusion_kernel_emitter";
|
||||||
|
}
|
||||||
|
|
||||||
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() override;
|
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() override;
|
||||||
|
|
||||||
static IndexingMap ComputeWorkItemIdToOutputIndexing(
|
static IndexingMap ComputeWorkItemIdToOutputIndexing(
|
||||||
|
|
@ -67,8 +70,6 @@ class ConcatenateFusionKernelEmitter final : public MlirKernelEmitter {
|
||||||
static int GetValidUnrollFactor(const HloFusionSpec& fusion_spec,
|
static int GetValidUnrollFactor(const HloFusionSpec& fusion_spec,
|
||||||
int max_unroll_factor);
|
int max_unroll_factor);
|
||||||
|
|
||||||
std::string name() const final { return "concatenate_fusion_kernel_emitter"; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
IndexingMap ComputeWorkItemIdToOutputIndexing(
|
IndexingMap ComputeWorkItemIdToOutputIndexing(
|
||||||
gpu::SymbolicExprContext* ctx) const;
|
gpu::SymbolicExprContext* ctx) const;
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,10 @@ class DynamicUpdateSliceKernelEmitter final : public MlirKernelEmitter {
|
||||||
WorkDimensions work_dimensions, absl::string_view entry_function_name,
|
WorkDimensions work_dimensions, absl::string_view entry_function_name,
|
||||||
BackendKind backend_kind);
|
BackendKind backend_kind);
|
||||||
|
|
||||||
|
absl::string_view name() const final {
|
||||||
|
return "dynamic_update_slice_kernel_emitter";
|
||||||
|
}
|
||||||
|
|
||||||
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() override;
|
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() override;
|
||||||
|
|
||||||
// Get the shape that will be used for loop indexing for the given fusion
|
// Get the shape that will be used for loop indexing for the given fusion
|
||||||
|
|
@ -66,10 +70,6 @@ class DynamicUpdateSliceKernelEmitter final : public MlirKernelEmitter {
|
||||||
const WorkDimensions& work_dimensions, const Shape& update_shape,
|
const WorkDimensions& work_dimensions, const Shape& update_shape,
|
||||||
gpu::SymbolicExprContext* ctx);
|
gpu::SymbolicExprContext* ctx);
|
||||||
|
|
||||||
std::string name() const final {
|
|
||||||
return "dynamic_update_slice_kernel_emitter";
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
IndexingMap ComputeWorkItemIdToInputIndexing(
|
IndexingMap ComputeWorkItemIdToInputIndexing(
|
||||||
gpu::SymbolicExprContext* symbolic_expr_context) const;
|
gpu::SymbolicExprContext* symbolic_expr_context) const;
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@ limitations under the License.
|
||||||
#ifndef XLA_CODEGEN_EMITTERS_LOOP_KERNEL_EMITTER_H_
|
#ifndef XLA_CODEGEN_EMITTERS_LOOP_KERNEL_EMITTER_H_
|
||||||
#define XLA_CODEGEN_EMITTERS_LOOP_KERNEL_EMITTER_H_
|
#define XLA_CODEGEN_EMITTERS_LOOP_KERNEL_EMITTER_H_
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
|
|
@ -28,7 +27,6 @@ limitations under the License.
|
||||||
#include "xla/codegen/emitters/kernel_arguments.h"
|
#include "xla/codegen/emitters/kernel_arguments.h"
|
||||||
#include "xla/codegen/hlo_fusion_spec.h"
|
#include "xla/codegen/hlo_fusion_spec.h"
|
||||||
#include "xla/codegen/kernel_definition.h"
|
#include "xla/codegen/kernel_definition.h"
|
||||||
#include "xla/codegen/kernel_spec.h"
|
|
||||||
#include "xla/codegen/mlir_kernel_source.h"
|
#include "xla/codegen/mlir_kernel_source.h"
|
||||||
#include "xla/hlo/analysis/indexing_map.h"
|
#include "xla/hlo/analysis/indexing_map.h"
|
||||||
#include "xla/hlo/ir/hlo_instructions.h"
|
#include "xla/hlo/ir/hlo_instructions.h"
|
||||||
|
|
@ -51,6 +49,7 @@ class LoopFusionKernelEmitter final : public MlirKernelEmitter {
|
||||||
absl::string_view entry_function_name,
|
absl::string_view entry_function_name,
|
||||||
BackendKind backend_kind);
|
BackendKind backend_kind);
|
||||||
|
|
||||||
|
absl::string_view name() const final { return "loop_fusion_kernel_emitter"; }
|
||||||
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() override;
|
absl::StatusOr<MlirKernelDefinition> EmitKernelDefinition() override;
|
||||||
|
|
||||||
static IndexingMap ComputeWorkItemIdToOutputIndexing(
|
static IndexingMap ComputeWorkItemIdToOutputIndexing(
|
||||||
|
|
@ -61,8 +60,6 @@ class LoopFusionKernelEmitter final : public MlirKernelEmitter {
|
||||||
// specification.
|
// specification.
|
||||||
static Shape GetIndexingShape(const HloFusionSpec& fusion_spec);
|
static Shape GetIndexingShape(const HloFusionSpec& fusion_spec);
|
||||||
|
|
||||||
std::string name() const final { return "loop_fusion_kernel_emitter"; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
IndexingMap ComputeWorkItemIdToOutputIndexing(
|
IndexingMap ComputeWorkItemIdToOutputIndexing(
|
||||||
gpu::SymbolicExprContext* ctx) const;
|
gpu::SymbolicExprContext* ctx) const;
|
||||||
|
|
|
||||||
6
third_party/xla/xla/codegen/kernel_emitter.h
vendored
6
third_party/xla/xla/codegen/kernel_emitter.h
vendored
|
|
@ -17,9 +17,9 @@ limitations under the License.
|
||||||
#define XLA_CODEGEN_KERNEL_EMITTER_H_
|
#define XLA_CODEGEN_KERNEL_EMITTER_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
#include "xla/codegen/kernel_definition.h"
|
#include "xla/codegen/kernel_definition.h"
|
||||||
#include "xla/tsl/platform/statusor.h"
|
#include "xla/tsl/platform/statusor.h"
|
||||||
|
|
||||||
|
|
@ -29,6 +29,8 @@ class KernelEmitterBase {
|
||||||
public:
|
public:
|
||||||
virtual ~KernelEmitterBase() = default;
|
virtual ~KernelEmitterBase() = default;
|
||||||
|
|
||||||
|
virtual absl::string_view name() const = 0;
|
||||||
|
|
||||||
virtual absl::StatusOr<std::unique_ptr<KernelDefinitionBase>>
|
virtual absl::StatusOr<std::unique_ptr<KernelDefinitionBase>>
|
||||||
EmitBaseKernelDefinition() = 0;
|
EmitBaseKernelDefinition() = 0;
|
||||||
};
|
};
|
||||||
|
|
@ -42,8 +44,6 @@ class KernelEmitter : public KernelEmitterBase {
|
||||||
|
|
||||||
virtual absl::StatusOr<KernelDefinitionType> EmitKernelDefinition() = 0;
|
virtual absl::StatusOr<KernelDefinitionType> EmitKernelDefinition() = 0;
|
||||||
|
|
||||||
virtual std::string name() const = 0;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
absl::StatusOr<std::unique_ptr<KernelDefinitionBase>>
|
absl::StatusOr<std::unique_ptr<KernelDefinitionBase>>
|
||||||
EmitBaseKernelDefinition() final {
|
EmitBaseKernelDefinition() final {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user