mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[xla:codegen] Remove LlvmKernelDefinition alias
Rely on the KernelEmitter<T>::KernelDefinition alias and CTAD (class template argument deduction) to avoid spelling out the full type.
PiperOrigin-RevId: 825336093
This commit is contained in:
parent
8e76d82f01
commit
1444679887
|
|
@ -132,7 +132,7 @@ ComputationKernelEmitter::ComputationKernelEmitter(
|
|||
buffer_assignment_(buffer_assignment),
|
||||
target_machine_(target_machine) {}
|
||||
|
||||
absl::StatusOr<LlvmKernelDefinition>
|
||||
absl::StatusOr<ComputationKernelEmitter::KernelDefinition>
|
||||
ComputationKernelEmitter::EmitKernelDefinition() {
|
||||
VLOG(2) << "Emit Computation host kernel: " << instr_->name();
|
||||
|
||||
|
|
@ -221,7 +221,7 @@ ComputationKernelEmitter::EmitKernelDefinition() {
|
|||
std::move(kernel_prototype.result_buffers),
|
||||
std::move(kernel_prototype.invariant_arguments));
|
||||
|
||||
return LlvmKernelDefinition(std::move(spec), std::move(source));
|
||||
return KernelDefinition(std::move(spec), std::move(source));
|
||||
}
|
||||
|
||||
absl::StatusOr<llvm::Function*> ComputationKernelEmitter::EmitNestedComputation(
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ class ComputationKernelEmitter final : public KernelEmitter<LlvmKernelSource> {
|
|||
const TargetMachineFeatures* target_machine);
|
||||
|
||||
absl::string_view name() const final { return "computation_kernel_emitter"; }
|
||||
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() final;
|
||||
absl::StatusOr<KernelDefinition> EmitKernelDefinition() final;
|
||||
|
||||
private:
|
||||
absl::StatusOr<llvm::Function*> EmitNestedComputation(
|
||||
|
|
|
|||
|
|
@ -56,7 +56,8 @@ DotKernelEmitter::DotKernelEmitter(const HloInstruction* instr,
|
|||
buffer_assignment_(buffer_assignment),
|
||||
target_machine_(target_machine) {}
|
||||
|
||||
absl::StatusOr<LlvmKernelDefinition> DotKernelEmitter::EmitKernelDefinition() {
|
||||
absl::StatusOr<DotKernelEmitter::KernelDefinition>
|
||||
DotKernelEmitter::EmitKernelDefinition() {
|
||||
const HloModuleConfig& config = instr_->GetModule()->config();
|
||||
|
||||
DotImplementationStrategy strategy = GetDotImplementationStrategy(
|
||||
|
|
@ -111,7 +112,7 @@ absl::StatusOr<LlvmKernelDefinition> DotKernelEmitter::EmitKernelDefinition() {
|
|||
std::move(kernel_prototype.result_buffers),
|
||||
std::move(kernel_prototype.invariant_arguments));
|
||||
|
||||
return LlvmKernelDefinition(std::move(spec), std::move(source));
|
||||
return KernelDefinition(std::move(spec), std::move(source));
|
||||
}
|
||||
|
||||
} // namespace xla::cpu
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class DotKernelEmitter final : public KernelEmitter<LlvmKernelSource> {
|
|||
const TargetMachineFeatures* target_machine);
|
||||
|
||||
absl::string_view name() const final { return "dot_kernel_emitter"; }
|
||||
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
|
||||
absl::StatusOr<KernelDefinition> EmitKernelDefinition() override;
|
||||
|
||||
private:
|
||||
const HloInstruction* instr_;
|
||||
|
|
|
|||
|
|
@ -43,7 +43,6 @@ limitations under the License.
|
|||
#include "xla/service/cpu/ir_emitter.h"
|
||||
#include "xla/service/llvm_ir/ir_array.h"
|
||||
#include "xla/shape.h"
|
||||
#include "xla/tsl/platform/errors.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
#include "xla/util.h"
|
||||
|
||||
|
|
@ -67,7 +66,7 @@ ConcatenateKernelEmitter::ConcatenateKernelEmitter(
|
|||
buffer_assignment_(buffer_assignment),
|
||||
target_machine_(target_machine) {}
|
||||
|
||||
absl::StatusOr<LlvmKernelDefinition>
|
||||
absl::StatusOr<ConcatenateKernelEmitter::KernelDefinition>
|
||||
ConcatenateKernelEmitter::EmitKernelDefinition() {
|
||||
if (absl::Status status = CanDoFastConcatenate(instr_); !status.ok()) {
|
||||
VLOG(1) << "Could not emit fast concatenate for " << instr_->ToString()
|
||||
|
|
@ -122,7 +121,7 @@ ConcatenateKernelEmitter::EmitKernelDefinition() {
|
|||
std::move(kernel_prototype.result_buffers),
|
||||
std::move(kernel_prototype.invariant_arguments));
|
||||
|
||||
return LlvmKernelDefinition(std::move(spec), std::move(source));
|
||||
return KernelDefinition(std::move(spec), std::move(source));
|
||||
}
|
||||
|
||||
} // namespace xla::cpu
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class ConcatenateKernelEmitter final : public KernelEmitter<LlvmKernelSource> {
|
|||
const TargetMachineFeatures* target_machine);
|
||||
|
||||
absl::string_view name() const final { return "concatenate_kernel_emitter"; }
|
||||
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
|
||||
absl::StatusOr<KernelDefinition> EmitKernelDefinition() override;
|
||||
|
||||
private:
|
||||
const HloInstruction* instr_;
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ ElementalKernelEmitter::ElementalKernelEmitter(
|
|||
buffer_assignment_(buffer_assignment),
|
||||
target_machine_(target_machine) {}
|
||||
|
||||
absl::StatusOr<LlvmKernelDefinition>
|
||||
absl::StatusOr<ElementalKernelEmitter::KernelDefinition>
|
||||
ElementalKernelEmitter::EmitKernelDefinition() {
|
||||
VLOG(2) << "Emit elemental host kernel: " << instr_->name();
|
||||
|
||||
|
|
@ -234,7 +234,7 @@ ElementalKernelEmitter::EmitKernelDefinition() {
|
|||
std::move(kernel_prototype.result_buffers),
|
||||
std::move(kernel_prototype.invariant_arguments));
|
||||
|
||||
return LlvmKernelDefinition(std::move(spec), std::move(source));
|
||||
return KernelDefinition(std::move(spec), std::move(source));
|
||||
}
|
||||
|
||||
absl::StatusOr<NumWorkGroups> ElementalKernelEmitter::EmitElementalLoops(
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class ElementalKernelEmitter final : public KernelEmitter<LlvmKernelSource> {
|
|||
const TargetMachineFeatures* target_machine);
|
||||
|
||||
absl::string_view name() const final { return "elemental_kernel_emitter"; }
|
||||
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() override;
|
||||
absl::StatusOr<KernelDefinition> EmitKernelDefinition() override;
|
||||
|
||||
private:
|
||||
// Emits LLVM IR using elemental loop emitter and the given element generator.
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class ElementalKernelEmitterTest : public HloHardwareIndependentTestBase {
|
|||
ElementalKernelEmitterTest()
|
||||
: target_machine_features_([](int64_t size) { return 1; }) {}
|
||||
|
||||
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition(
|
||||
absl::StatusOr<KernelDefinition<LlvmKernelSource>> EmitKernelDefinition(
|
||||
const HloInstruction* instr, const BufferAssignment* buffer_assignment) {
|
||||
ElementalKernelEmitter emitter(instr, buffer_assignment,
|
||||
&target_machine_features_);
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ cc_library(
|
|||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:Target",
|
||||
"@llvm-project//llvm:ir_headers",
|
||||
|
|
|
|||
|
|
@ -16,12 +16,12 @@ limitations under the License.
|
|||
#include "xla/backends/cpu/testlib/kernel_runner.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/IR/DataLayout.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
|
|
@ -51,10 +51,9 @@ limitations under the License.
|
|||
namespace xla::cpu {
|
||||
|
||||
absl::StatusOr<KernelRunner> KernelRunner::Create(
|
||||
LlvmKernelDefinition kernel_definition, JitCompiler compiler) {
|
||||
auto spec = kernel_definition.spec();
|
||||
auto thread_safe_module =
|
||||
std::move(kernel_definition).TakeSource().thread_safe_module();
|
||||
KernelDefinition<LlvmKernelSource> kernel, JitCompiler compiler) {
|
||||
auto spec = kernel.spec();
|
||||
auto thread_safe_module = std::move(kernel).TakeSource().thread_safe_module();
|
||||
SetModuleMemoryRegionName(*thread_safe_module.getModuleUnlocked(),
|
||||
"kernel_runner_test");
|
||||
|
||||
|
|
@ -73,12 +72,12 @@ absl::StatusOr<KernelRunner> KernelRunner::Create(
|
|||
}
|
||||
|
||||
absl::StatusOr<KernelRunner> KernelRunner::Create(
|
||||
MlirKernelDefinition kernel_definition, JitCompiler compiler) {
|
||||
auto spec = kernel_definition.spec();
|
||||
auto source = std::move(kernel_definition).TakeSource();
|
||||
KernelDefinition<MlirKernelSource> kernel, JitCompiler compiler) {
|
||||
auto spec = kernel.spec();
|
||||
auto source = std::move(kernel).TakeSource();
|
||||
TF_ASSIGN_OR_RETURN(LlvmKernelSource llvm_kernel_source, LowerToLlvm(source));
|
||||
|
||||
return Create(LlvmKernelDefinition(spec, std::move(llvm_kernel_source)),
|
||||
return Create(KernelDefinition(spec, std::move(llvm_kernel_source)),
|
||||
std::move(compiler));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -37,9 +37,9 @@ namespace xla::cpu {
|
|||
class KernelRunner final : public xla::KernelRunner {
|
||||
public:
|
||||
static absl::StatusOr<KernelRunner> Create(
|
||||
LlvmKernelDefinition kernel_definition, JitCompiler compiler);
|
||||
KernelDefinition<LlvmKernelSource> kernel, JitCompiler compiler);
|
||||
static absl::StatusOr<KernelRunner> Create(
|
||||
MlirKernelDefinition kernel_definition, JitCompiler compiler);
|
||||
KernelDefinition<MlirKernelSource> kernel, JitCompiler compiler);
|
||||
|
||||
KernelRunner(KernelRunner&&) = default;
|
||||
KernelRunner& operator=(KernelRunner&&) = default;
|
||||
|
|
|
|||
|
|
@ -258,11 +258,12 @@ NB_MODULE(_extension, kernel_runner_module) {
|
|||
return *std::move(runner);
|
||||
})
|
||||
.def_static(
|
||||
"create", [](std::unique_ptr<LlvmKernelDefinition,
|
||||
nb::deleter<LlvmKernelDefinition>>
|
||||
kernel_definition,
|
||||
std::unique_ptr<JitCompiler, nb::deleter<JitCompiler>>
|
||||
jit_compiler) {
|
||||
"create",
|
||||
[](std::unique_ptr<KernelDefinition<LlvmKernelSource>,
|
||||
nb::deleter<KernelDefinition<LlvmKernelSource>>>
|
||||
kernel_definition,
|
||||
std::unique_ptr<JitCompiler, nb::deleter<JitCompiler>>
|
||||
jit_compiler) {
|
||||
absl::StatusOr<KernelRunner> runner = KernelRunner::Create(
|
||||
std::move(*kernel_definition), std::move(*jit_compiler));
|
||||
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ LlvmTestKernelEmitter::LlvmTestKernelEmitter(absl::string_view llvm_ir,
|
|||
}
|
||||
}
|
||||
|
||||
absl::StatusOr<LlvmKernelDefinition>
|
||||
absl::StatusOr<LlvmTestKernelEmitter::KernelDefinition>
|
||||
LlvmTestKernelEmitter::EmitKernelDefinition() {
|
||||
auto context = std::make_unique<llvm::LLVMContext>();
|
||||
|
||||
|
|
@ -84,7 +84,7 @@ LlvmTestKernelEmitter::EmitKernelDefinition() {
|
|||
KernelSpec kernel_spec(kernel_name_, num_workgroups_,
|
||||
std::move(argument_buffers), std::move(result_buffers),
|
||||
/*invariant_arguments=*/{});
|
||||
return LlvmKernelDefinition(std::move(kernel_spec), std::move(source));
|
||||
return KernelDefinition(std::move(kernel_spec), std::move(source));
|
||||
}
|
||||
|
||||
} // namespace xla::cpu
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ class LlvmTestKernelEmitter : public KernelEmitter<LlvmKernelSource> {
|
|||
absl::Span<const KernelArg> args);
|
||||
|
||||
absl::string_view name() const override { return "llvm_test_kernel_emitter"; }
|
||||
absl::StatusOr<LlvmKernelDefinition> EmitKernelDefinition() final;
|
||||
absl::StatusOr<KernelDefinition> EmitKernelDefinition() final;
|
||||
|
||||
private:
|
||||
std::string llvm_ir_;
|
||||
|
|
|
|||
6
third_party/xla/xla/codegen/kernel_emitter.h
vendored
6
third_party/xla/xla/codegen/kernel_emitter.h
vendored
|
|
@ -58,14 +58,14 @@ class KernelEmitter : public KernelEmitterBase {
|
|||
static_assert(std::is_base_of_v<KernelSource, Source>,
|
||||
"Source must be a subclass of KernelSource");
|
||||
|
||||
virtual absl::StatusOr<KernelDefinition<Source>> EmitKernelDefinition() = 0;
|
||||
using KernelDefinition = ::xla::KernelDefinition<Source>;
|
||||
virtual absl::StatusOr<KernelDefinition> EmitKernelDefinition() = 0;
|
||||
|
||||
private:
|
||||
absl::StatusOr<std::unique_ptr<KernelDefinitionBase>>
|
||||
EmitKernelDefinitionBase() final {
|
||||
TF_ASSIGN_OR_RETURN(auto kernel_definition, EmitKernelDefinition());
|
||||
return std::make_unique<KernelDefinition<Source>>(
|
||||
std::move(kernel_definition));
|
||||
return std::make_unique<KernelDefinition>(std::move(kernel_definition));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ limitations under the License.
|
|||
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "xla/codegen/kernel_definition.h"
|
||||
#include "xla/codegen/kernel_source.h"
|
||||
|
||||
namespace xla {
|
||||
|
|
@ -51,8 +50,6 @@ class LlvmKernelSource final : public KernelSource {
|
|||
llvm::orc::ThreadSafeModule module_;
|
||||
};
|
||||
|
||||
using LlvmKernelDefinition = KernelDefinition<LlvmKernelSource>; // NOLINT
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // XLA_CODEGEN_LLVM_KERNEL_SOURCE_H_
|
||||
|
|
|
|||
|
|
@ -206,9 +206,9 @@ NB_MODULE(_extension, kernel_runner_module) {
|
|||
},
|
||||
nb::rv_policy::reference_internal);
|
||||
|
||||
nb::class_<MlirKernelDefinition, KernelDefinitionBase>(
|
||||
nb::class_<KernelDefinition<MlirKernelSource>, KernelDefinitionBase>(
|
||||
kernel_runner_module, "MlirKernelDefinition");
|
||||
nb::class_<LlvmKernelDefinition, KernelDefinitionBase>(
|
||||
nb::class_<KernelDefinition<LlvmKernelSource>, KernelDefinitionBase>(
|
||||
kernel_runner_module, "LlvmKernelDefinition");
|
||||
|
||||
nb::class_<KernelEmitterBase>(kernel_runner_module, "KernelEmitterBase")
|
||||
|
|
|
|||
1
third_party/xla/xla/service/cpu/BUILD
vendored
1
third_party/xla/xla/service/cpu/BUILD
vendored
|
|
@ -2047,6 +2047,7 @@ cc_library(
|
|||
deps = [
|
||||
"//xla/backends/cpu/codegen:fusion_compiler",
|
||||
"//xla/backends/cpu/codegen:fusion_emitter",
|
||||
"//xla/codegen:kernel_definition",
|
||||
"//xla/codegen:kernel_spec",
|
||||
"//xla/codegen:llvm_kernel_source",
|
||||
"//xla/codegen:mlir_kernel_source",
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ limitations under the License.
|
|||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "xla/backends/cpu/codegen/fusion_compiler.h"
|
||||
#include "xla/backends/cpu/codegen/fusion_emitter.h"
|
||||
#include "xla/codegen/kernel_definition.h"
|
||||
#include "xla/codegen/kernel_spec.h"
|
||||
#include "xla/codegen/llvm_kernel_source.h"
|
||||
#include "xla/codegen/mlir_kernel_source.h"
|
||||
|
|
@ -190,7 +191,7 @@ absl::StatusOr<KernelSpec> ParallelFusionEmitter::AddFusion(
|
|||
return spec;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::vector<LlvmKernelDefinition>>
|
||||
absl::StatusOr<std::vector<KernelDefinition<LlvmKernelSource>>>
|
||||
ParallelFusionEmitter::ConsumeKernels() {
|
||||
absl::MutexLock lock(kernels_mutex_);
|
||||
|
||||
|
|
@ -203,8 +204,8 @@ ParallelFusionEmitter::ConsumeKernels() {
|
|||
}
|
||||
|
||||
// Sort the kernels by name to ensure a deterministic order.
|
||||
absl::c_sort(kernels_, [](const LlvmKernelDefinition& lhs,
|
||||
const LlvmKernelDefinition& rhs) {
|
||||
absl::c_sort(kernels_, [](const KernelDefinition<LlvmKernelSource>& lhs,
|
||||
const KernelDefinition<LlvmKernelSource>& rhs) {
|
||||
return lhs.spec().name() < rhs.spec().name();
|
||||
});
|
||||
|
||||
|
|
@ -212,12 +213,12 @@ ParallelFusionEmitter::ConsumeKernels() {
|
|||
}
|
||||
|
||||
void ParallelFusionEmitter::CompileFusion(
|
||||
std::shared_ptr<MlirKernelDefinition> mlir_kernel_definition,
|
||||
std::shared_ptr<KernelDefinition<MlirKernelSource>> mlir_kernel,
|
||||
std::shared_ptr<CompilerInstance> compiler_instance) {
|
||||
KernelSpec spec = mlir_kernel_definition->spec();
|
||||
KernelSpec spec = mlir_kernel->spec();
|
||||
absl::StatusOr<LlvmKernelSource> llvm_kernel_source =
|
||||
compiler_instance->compiler->Compile(
|
||||
std::move(*mlir_kernel_definition).TakeSource());
|
||||
std::move(*mlir_kernel).TakeSource());
|
||||
|
||||
absl::MutexLock lock(kernels_mutex_);
|
||||
outstanding_kernels_--;
|
||||
|
|
|
|||
|
|
@ -25,8 +25,10 @@ limitations under the License.
|
|||
#include "absl/status/statusor.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "xla/backends/cpu/codegen/fusion_compiler.h"
|
||||
#include "xla/codegen/kernel_definition.h"
|
||||
#include "xla/codegen/kernel_spec.h"
|
||||
#include "xla/codegen/llvm_kernel_source.h"
|
||||
#include "xla/codegen/mlir_kernel_source.h"
|
||||
#include "xla/hlo/ir/hlo_instructions.h"
|
||||
#include "xla/service/buffer_assignment.h"
|
||||
#include "xla/tsl/platform/threadpool.h"
|
||||
|
|
@ -52,14 +54,15 @@ class ParallelFusionEmitter {
|
|||
|
||||
// Returns the kernels for all the added fusions, blocks until all kernels
|
||||
// have been compiled.
|
||||
absl::StatusOr<std::vector<LlvmKernelDefinition>> ConsumeKernels();
|
||||
absl::StatusOr<std::vector<KernelDefinition<LlvmKernelSource>>>
|
||||
ConsumeKernels();
|
||||
|
||||
private:
|
||||
struct CompilerInstance;
|
||||
class FusionCompilerPool;
|
||||
|
||||
void CompileFusion(
|
||||
std::shared_ptr<MlirKernelDefinition> mlir_kernel_definition,
|
||||
std::shared_ptr<KernelDefinition<MlirKernelSource>> mlir_kernel,
|
||||
std::shared_ptr<CompilerInstance> compiler_instance);
|
||||
|
||||
tsl::thread::ThreadPool& thread_pool_;
|
||||
|
|
@ -70,7 +73,8 @@ class ParallelFusionEmitter {
|
|||
absl::Mutex kernels_mutex_;
|
||||
int64_t outstanding_kernels_ ABSL_GUARDED_BY(kernels_mutex_) = 0;
|
||||
absl::Status kernels_status_ ABSL_GUARDED_BY(kernels_mutex_);
|
||||
std::vector<LlvmKernelDefinition> kernels_ ABSL_GUARDED_BY(kernels_mutex_);
|
||||
std::vector<KernelDefinition<LlvmKernelSource>> kernels_
|
||||
ABSL_GUARDED_BY(kernels_mutex_);
|
||||
};
|
||||
|
||||
} // namespace xla::cpu
|
||||
|
|
|
|||
|
|
@ -123,7 +123,7 @@ TEST_F(ParallelFusionEmitterTest, HappyPathSingleFusion) {
|
|||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto kernels, fussion_emitter.ConsumeKernels());
|
||||
ASSERT_EQ(kernels.size(), 1);
|
||||
LlvmKernelDefinition& lowered_kernel = kernels[0];
|
||||
KernelDefinition<LlvmKernelSource>& lowered_kernel = kernels[0];
|
||||
EXPECT_EQ(lowered_kernel.spec().name(), expected_name);
|
||||
auto source = std::move(lowered_kernel).TakeSource();
|
||||
|
||||
|
|
|
|||
15
third_party/xla/xla/service/cpu/thunk_emitter.cc
vendored
15
third_party/xla/xla/service/cpu/thunk_emitter.cc
vendored
|
|
@ -234,11 +234,12 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitEntryComputation(
|
|||
absl::StatusOr<std::vector<ThunkEmitter::EmittedKernel>>
|
||||
ThunkEmitter::ConsumeKernels() {
|
||||
tsl::profiler::TraceMe trace("ThunkEmitter::ConsumeKernels");
|
||||
TF_ASSIGN_OR_RETURN(std::vector<LlvmKernelDefinition> fusion_kernels,
|
||||
parallel_fusion_emitter_.ConsumeKernels());
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<KernelDefinition<LlvmKernelSource>> fusion_kernels,
|
||||
parallel_fusion_emitter_.ConsumeKernels());
|
||||
|
||||
kernels_.reserve(kernels_.size() + fusion_kernels.size());
|
||||
for (LlvmKernelDefinition& kernel : fusion_kernels) {
|
||||
for (KernelDefinition<LlvmKernelSource>& kernel : fusion_kernels) {
|
||||
std::string name(kernel.spec().name());
|
||||
auto source = std::move(kernel).TakeSource();
|
||||
kernels_.push_back({name, std::move(source).thread_safe_module()});
|
||||
|
|
@ -688,7 +689,7 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitCallThunk(
|
|||
maybe_small_call.has_value() && *maybe_small_call == "true") {
|
||||
ComputationKernelEmitter emitter(instruction, &buffer_assignment_,
|
||||
&target_machine_features_);
|
||||
TF_ASSIGN_OR_RETURN(LlvmKernelDefinition kernel_definition,
|
||||
TF_ASSIGN_OR_RETURN(KernelDefinition kernel_definition,
|
||||
emitter.EmitKernelDefinition());
|
||||
|
||||
auto kernel_spec = kernel_definition.spec();
|
||||
|
|
@ -712,7 +713,7 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitConcatenateKernelThunk(
|
|||
const HloInstruction* instruction) {
|
||||
ConcatenateKernelEmitter emitter(instruction, &buffer_assignment_,
|
||||
&target_machine_features_);
|
||||
TF_ASSIGN_OR_RETURN(LlvmKernelDefinition kernel_definition,
|
||||
TF_ASSIGN_OR_RETURN(KernelDefinition kernel_definition,
|
||||
emitter.EmitKernelDefinition());
|
||||
|
||||
auto kernel_spec = kernel_definition.spec();
|
||||
|
|
@ -818,7 +819,7 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitElementalKernelThunk(
|
|||
const HloInstruction* instruction) {
|
||||
ElementalKernelEmitter emitter(instruction, &buffer_assignment_,
|
||||
&target_machine_features_);
|
||||
TF_ASSIGN_OR_RETURN(LlvmKernelDefinition kernel_definition,
|
||||
TF_ASSIGN_OR_RETURN(KernelDefinition kernel_definition,
|
||||
emitter.EmitKernelDefinition());
|
||||
|
||||
auto kernel_spec = kernel_definition.spec();
|
||||
|
|
@ -1062,7 +1063,7 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitDotThunk(
|
|||
case DotImplementationStrategy::kTiledLlvmIrGemv: {
|
||||
DotKernelEmitter emitter(instruction, &buffer_assignment_,
|
||||
&target_machine_features_);
|
||||
TF_ASSIGN_OR_RETURN(LlvmKernelDefinition kernel_definition,
|
||||
TF_ASSIGN_OR_RETURN(KernelDefinition kernel_definition,
|
||||
emitter.EmitKernelDefinition());
|
||||
|
||||
auto kernel_spec = kernel_definition.spec();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user