mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[xla:codegen] Remove MlirKernelDefinition alias
PiperOrigin-RevId: 825724819
This commit is contained in:
parent
ff2b8b600d
commit
756a72760a
|
|
@ -632,6 +632,7 @@ cc_library(
|
|||
"//xla/backends/cpu:alignment",
|
||||
"//xla/codegen:hlo_fusion_spec",
|
||||
"//xla/codegen:ir_emission_utils",
|
||||
"//xla/codegen:kernel_definition",
|
||||
"//xla/codegen:kernel_spec",
|
||||
"//xla/codegen:mlir_kernel_source",
|
||||
"//xla/codegen/emitters:concatenate_kernel_emitter",
|
||||
|
|
|
|||
|
|
@ -249,7 +249,8 @@ IndexingMap GetScatterIndexingMap(
|
|||
{}, constraints);
|
||||
}
|
||||
|
||||
absl::StatusOr<MlirKernelDefinition> CpuScatterFusion::EmitKernelDefinition() {
|
||||
absl::StatusOr<CpuScatterFusion::KernelDefinition>
|
||||
CpuScatterFusion::EmitKernelDefinition() {
|
||||
mlir::OpBuilder builder(symbolic_expr_context_->GetMLIRContext());
|
||||
TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> mlir_module,
|
||||
CreateNamedMlirModuleOp(*fusion_, builder));
|
||||
|
|
@ -325,8 +326,8 @@ absl::StatusOr<MlirKernelDefinition> CpuScatterFusion::EmitKernelDefinition() {
|
|||
std::move(argument_buffers), std::move(result_buffers),
|
||||
std::move(invariant_arguments));
|
||||
|
||||
return MlirKernelDefinition(std::move(kernel_spec),
|
||||
MlirKernelSource(std::move(mlir_module)));
|
||||
return KernelDefinition(std::move(kernel_spec),
|
||||
MlirKernelSource(std::move(mlir_module)));
|
||||
}
|
||||
|
||||
absl::Status CpuScatterFusion::EmitEntryFunction(
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ limitations under the License.
|
|||
#include "xla/codegen/emitters/loop_kernel_emitter.h"
|
||||
#include "xla/codegen/hlo_fusion_spec.h"
|
||||
#include "xla/codegen/ir_emission_utils.h"
|
||||
#include "xla/codegen/kernel_spec.h"
|
||||
#include "xla/codegen/kernel_definition.h"
|
||||
#include "xla/codegen/mlir_kernel_source.h"
|
||||
#include "xla/hlo/analysis/symbolic_expr.h"
|
||||
#include "xla/hlo/ir/hlo_instruction.h"
|
||||
|
|
@ -57,7 +57,6 @@ limitations under the License.
|
|||
#include "xla/runtime/work_tile_size.h"
|
||||
#include "xla/service/buffer_assignment.h"
|
||||
#include "xla/service/cpu/backend_config.pb.h"
|
||||
#include "xla/service/gpu/ir_emission_utils.h"
|
||||
#include "xla/shape.h"
|
||||
#include "xla/shape_util.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
|
|
@ -208,7 +207,7 @@ static HloFusionSpec GetLoopFusionSpec(const HloFusionInstruction& fusion) {
|
|||
std::move(heroes));
|
||||
}
|
||||
|
||||
static absl::StatusOr<MlirKernelDefinition> EmitLoopFusionKernel(
|
||||
static absl::StatusOr<KernelDefinition<MlirKernelSource>> EmitLoopFusionKernel(
|
||||
SymbolicExprContext& context, const HloFusionInstruction& fusion,
|
||||
const BufferAssignment* buffer_assignment, absl::string_view name) {
|
||||
VLOG(2) << "Emitting loop fusion kernel: " << name;
|
||||
|
|
@ -230,9 +229,11 @@ static absl::StatusOr<MlirKernelDefinition> EmitLoopFusionKernel(
|
|||
return mlir_kernel_definition;
|
||||
}
|
||||
|
||||
static absl::StatusOr<MlirKernelDefinition> EmitConcatenateFusionKernel(
|
||||
SymbolicExprContext& context, const HloFusionInstruction& fusion,
|
||||
const BufferAssignment* buffer_assignment, absl::string_view name) {
|
||||
static absl::StatusOr<KernelDefinition<MlirKernelSource>>
|
||||
EmitConcatenateFusionKernel(SymbolicExprContext& context,
|
||||
const HloFusionInstruction& fusion,
|
||||
const BufferAssignment* buffer_assignment,
|
||||
absl::string_view name) {
|
||||
VLOG(2) << "Emitting concatenate fusion kernel: " << name;
|
||||
HloFusionSpec fusion_spec = GetLoopFusionSpec(fusion);
|
||||
auto work_dimensions = GetConcatenateEmitterWorkDims(fusion, fusion_spec);
|
||||
|
|
@ -252,9 +253,11 @@ static absl::StatusOr<MlirKernelDefinition> EmitConcatenateFusionKernel(
|
|||
return mlir_kernel_definition;
|
||||
}
|
||||
|
||||
static absl::StatusOr<MlirKernelDefinition> EmitDynamicUpdateSliceFusionKernel(
|
||||
SymbolicExprContext& context, const HloFusionInstruction& fusion,
|
||||
const BufferAssignment* buffer_assignment, absl::string_view name) {
|
||||
static absl::StatusOr<KernelDefinition<MlirKernelSource>>
|
||||
EmitDynamicUpdateSliceFusionKernel(SymbolicExprContext& context,
|
||||
const HloFusionInstruction& fusion,
|
||||
const BufferAssignment* buffer_assignment,
|
||||
absl::string_view name) {
|
||||
VLOG(2) << "Emitting dynamic update slice fusion kernel: " << name;
|
||||
HloFusionSpec fusion_spec = GetLoopFusionSpec(fusion);
|
||||
auto work_dimensions =
|
||||
|
|
@ -275,7 +278,7 @@ static absl::StatusOr<MlirKernelDefinition> EmitDynamicUpdateSliceFusionKernel(
|
|||
return mlir_kernel_definition;
|
||||
}
|
||||
|
||||
absl::StatusOr<MlirKernelDefinition> EmitFusionKernel(
|
||||
absl::StatusOr<KernelDefinition<MlirKernelSource>> EmitFusionKernel(
|
||||
SymbolicExprContext& context, const HloFusionInstruction& fusion,
|
||||
const BufferAssignment* buffer_assignment, bool use_unique_c_name) {
|
||||
if (fusion.fusion_kind() == HloFusionInstruction::FusionKind::kLoop) {
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "xla/codegen/emitters/kernel_arguments.h"
|
||||
#include "xla/codegen/kernel_definition.h"
|
||||
#include "xla/codegen/mlir_kernel_source.h"
|
||||
#include "xla/hlo/analysis/symbolic_expr.h"
|
||||
#include "xla/hlo/ir/hlo_instructions.h"
|
||||
|
|
@ -27,7 +28,7 @@ namespace xla::cpu {
|
|||
|
||||
emitters::KernelArguments::BufferAlignment GetDefaultBufferAlignment();
|
||||
|
||||
absl::StatusOr<MlirKernelDefinition> EmitFusionKernel(
|
||||
absl::StatusOr<KernelDefinition<MlirKernelSource>> EmitFusionKernel(
|
||||
SymbolicExprContext& context, const HloFusionInstruction& fusion,
|
||||
const BufferAssignment* buffer_assignment, bool use_unique_c_name);
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ absl::Status Run(const std::string& filename) {
|
|||
module->entry_computation()->root_instruction());
|
||||
fusion->SetAndSanitizeName("main");
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
MlirKernelDefinition kernel_definition,
|
||||
KernelDefinition kernel_definition,
|
||||
EmitFusionKernel(*symbolic_expr_context, *fusion, nullptr, false));
|
||||
llvm::outs() << kernel_definition.source().ToString();
|
||||
return absl::OkStatus();
|
||||
|
|
|
|||
|
|
@ -206,9 +206,8 @@ NB_MODULE(_extension, kernel_runner_module) {
|
|||
[](SymbolicExprContext& symbolic_expr_context,
|
||||
const HloFusionInstruction& fusion,
|
||||
const BufferAssignment* buffer_assignment) {
|
||||
absl::StatusOr<MlirKernelDefinition> kernel_definition =
|
||||
EmitFusionKernel(symbolic_expr_context, fusion, buffer_assignment,
|
||||
false);
|
||||
auto kernel_definition = EmitFusionKernel(symbolic_expr_context, fusion,
|
||||
buffer_assignment, false);
|
||||
if (!kernel_definition.ok()) {
|
||||
throw std::runtime_error(kernel_definition.status().ToString());
|
||||
}
|
||||
|
|
@ -242,8 +241,8 @@ NB_MODULE(_extension, kernel_runner_module) {
|
|||
"KernelRunner")
|
||||
.def_static(
|
||||
"create",
|
||||
[](std::unique_ptr<MlirKernelDefinition,
|
||||
nb::deleter<MlirKernelDefinition>>
|
||||
[](std::unique_ptr<KernelDefinition<MlirKernelSource>,
|
||||
nb::deleter<KernelDefinition<MlirKernelSource>>>
|
||||
kernel_definition,
|
||||
std::unique_ptr<JitCompiler, nb::deleter<JitCompiler>>
|
||||
jit_compiler) {
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ MlirTestKernelEmitter::MlirTestKernelEmitter(absl::string_view mlir,
|
|||
}
|
||||
}
|
||||
|
||||
absl::StatusOr<MlirKernelDefinition>
|
||||
absl::StatusOr<MlirTestKernelEmitter::KernelDefinition>
|
||||
MlirTestKernelEmitter::EmitKernelDefinition() {
|
||||
std::unique_ptr<mlir::MLIRContext> context = FusionCompiler::CreateContext();
|
||||
|
||||
|
|
@ -71,6 +71,6 @@ MlirTestKernelEmitter::EmitKernelDefinition() {
|
|||
KernelSpec kernel_spec(kernel_name_, num_workgroups_,
|
||||
std::move(argument_buffers), std::move(result_buffers),
|
||||
/*invariant_arguments=*/{});
|
||||
return MlirKernelDefinition(std::move(kernel_spec), std::move(source));
|
||||
return KernelDefinition(std::move(kernel_spec), std::move(source));
|
||||
}
|
||||
} // namespace xla::cpu
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ ConcatenateFusionKernelEmitter::ConcatenateFusionKernelEmitter(
|
|||
entry_function_name_(entry_function_name),
|
||||
backend_kind_(backend_kind) {}
|
||||
|
||||
absl::StatusOr<MlirKernelDefinition>
|
||||
absl::StatusOr<ConcatenateFusionKernelEmitter::KernelDefinition>
|
||||
ConcatenateFusionKernelEmitter::EmitKernelDefinition() {
|
||||
mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext());
|
||||
auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name()));
|
||||
|
|
@ -121,8 +121,8 @@ ConcatenateFusionKernelEmitter::EmitKernelDefinition() {
|
|||
GetKernelSpec(entry_function_name_, fusion_,
|
||||
buffer_assignment_, work_dimensions_));
|
||||
|
||||
return MlirKernelDefinition(std::move(kernel_spec),
|
||||
MlirKernelSource(std::move(module)));
|
||||
return KernelDefinition(std::move(kernel_spec),
|
||||
MlirKernelSource(std::move(module)));
|
||||
}
|
||||
|
||||
const Shape& ConcatenateFusionKernelEmitter::GetIndexingShape(
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ DynamicUpdateSliceKernelEmitter::DynamicUpdateSliceKernelEmitter(
|
|||
entry_function_name_(entry_function_name),
|
||||
backend_kind_(backend_kind) {}
|
||||
|
||||
absl::StatusOr<MlirKernelDefinition>
|
||||
absl::StatusOr<DynamicUpdateSliceKernelEmitter::KernelDefinition>
|
||||
DynamicUpdateSliceKernelEmitter::EmitKernelDefinition() {
|
||||
mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext());
|
||||
auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name()));
|
||||
|
|
@ -120,8 +120,8 @@ DynamicUpdateSliceKernelEmitter::EmitKernelDefinition() {
|
|||
|
||||
TF_ASSIGN_OR_RETURN(auto kernel_spec, GetKernelSpec());
|
||||
|
||||
return MlirKernelDefinition(std::move(kernel_spec),
|
||||
MlirKernelSource(std::move(module)));
|
||||
return KernelDefinition(std::move(kernel_spec),
|
||||
MlirKernelSource(std::move(module)));
|
||||
}
|
||||
|
||||
IndexingMap DynamicUpdateSliceKernelEmitter::ComputeWorkItemIdToInputIndexing(
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ LoopFusionKernelEmitter::LoopFusionKernelEmitter(
|
|||
entry_function_name_(entry_function_name),
|
||||
backend_kind_(backend_kind) {}
|
||||
|
||||
absl::StatusOr<MlirKernelDefinition>
|
||||
absl::StatusOr<LoopFusionKernelEmitter::KernelDefinition>
|
||||
LoopFusionKernelEmitter::EmitKernelDefinition() {
|
||||
mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext());
|
||||
auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name()));
|
||||
|
|
@ -114,8 +114,8 @@ LoopFusionKernelEmitter::EmitKernelDefinition() {
|
|||
GetKernelSpec(entry_function_name_, fusion_,
|
||||
buffer_assignment_, work_dimensions_));
|
||||
|
||||
return MlirKernelDefinition(std::move(kernel_spec),
|
||||
MlirKernelSource(std::move(module)));
|
||||
return KernelDefinition(std::move(kernel_spec),
|
||||
MlirKernelSource(std::move(module)));
|
||||
}
|
||||
|
||||
IndexingMap LoopFusionKernelEmitter::ComputeWorkItemIdToOutputIndexing(
|
||||
|
|
|
|||
|
|
@ -81,8 +81,6 @@ class MlirKernelSource final : public KernelSource {
|
|||
Storage storage_;
|
||||
};
|
||||
|
||||
using MlirKernelDefinition = KernelDefinition<MlirKernelSource>; // NOLINT
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // XLA_CODEGEN_MLIR_KERNEL_SOURCE_H_
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ absl::StatusOr<KernelSpec> ParallelFusionEmitter::AddFusion(
|
|||
// fixed but will require a rework of the ThunkEmitter.
|
||||
auto compiler_instance = fusion_compiler_pool_->GetInstance();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
MlirKernelDefinition mlir_kernel_definition,
|
||||
KernelDefinition mlir_kernel_definition,
|
||||
EmitFusionKernel(*compiler_instance->symbolic_expr_context, *fusion,
|
||||
buffer_assignment_, use_unique_c_name_));
|
||||
|
||||
|
|
@ -181,8 +181,8 @@ absl::StatusOr<KernelSpec> ParallelFusionEmitter::AddFusion(
|
|||
}
|
||||
|
||||
KernelSpec spec = mlir_kernel_definition.spec();
|
||||
auto shared_source =
|
||||
std::make_shared<MlirKernelDefinition>(std::move(mlir_kernel_definition));
|
||||
auto shared_source = std::make_shared<KernelDefinition<MlirKernelSource>>(
|
||||
std::move(mlir_kernel_definition));
|
||||
|
||||
thread_pool_.Schedule(absl::bind_front(&ParallelFusionEmitter::CompileFusion,
|
||||
this, std::move(shared_source),
|
||||
|
|
|
|||
|
|
@ -851,7 +851,7 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitFusionKernelThunk(
|
|||
auto kernel_emitter = std::make_unique<CpuScatterFusion>(
|
||||
buffer_assignment_, fusion, &symbolic_expr_context_);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(MlirKernelDefinition kernel_definition,
|
||||
TF_ASSIGN_OR_RETURN(KernelDefinition kernel_definition,
|
||||
kernel_emitter->EmitKernelDefinition());
|
||||
|
||||
auto kernel_spec = kernel_definition.spec();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user