diff --git a/third_party/xla/xla/backends/cpu/runtime/BUILD b/third_party/xla/xla/backends/cpu/runtime/BUILD index 2a891e0fe78..e43a0663e26 100644 --- a/third_party/xla/xla/backends/cpu/runtime/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/BUILD @@ -136,9 +136,9 @@ cc_library( deps = [ ":kernel_c_api", "//xla/tsl/lib/gtl:int_type", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/backends/cpu/runtime/function_library.h b/third_party/xla/xla/backends/cpu/runtime/function_library.h index d9a685fb16a..3b269964238 100644 --- a/third_party/xla/xla/backends/cpu/runtime/function_library.h +++ b/third_party/xla/xla/backends/cpu/runtime/function_library.h @@ -19,13 +19,12 @@ limitations under the License. #include #include #include -#include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/backends/cpu/runtime/kernel_c_api.h" #include "xla/tsl/lib/gtl/int_type.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" namespace xla::cpu { @@ -64,14 +63,14 @@ class FunctionLibrary { }; template >* = nullptr> - static Symbol Sym(std::string name) { - return Symbol{GetTypeId(), std::move(name)}; + static Symbol Sym(absl::string_view name) { + return Symbol{GetTypeId(), std::string(name)}; } template >* = nullptr> absl::StatusOr ResolveFunction(absl::string_view name) { TF_ASSIGN_OR_RETURN(void* ptr, ResolveFunction(GetTypeId(), name)); - return reinterpret_cast(ptr); + return reinterpret_cast(ptr); // NOLINT } protected: diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc index efa44a13fde..90fd1b74c3c 100644 --- a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc @@ -35,6 +35,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/function_library.h" @@ -46,7 +47,6 @@ limitations under the License. #include "xla/runtime/work_group.h" #include "xla/service/buffer_assignment.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" @@ -64,7 +64,9 @@ namespace internal { static absl::Status CheckBufferAlignment( const Thunk::Info& info, uint64_t min_alignment, absl::Span kernel_args) { - if (min_alignment == 0) return absl::OkStatus(); + if (min_alignment == 0) { + return absl::OkStatus(); + } for (int64_t i = 0; i < kernel_args.size(); ++i) { auto ptr = reinterpret_cast(kernel_args[i].data); @@ -114,8 +116,9 @@ template KernelThunk::KernelThunk( Info info, absl::Span arguments_buffers, absl::Span results_buffers, - absl::flat_hash_set invariant_arguments, std::string kernel_name, - NumWorkGroups num_workgroups, std::optional min_alignment) + absl::flat_hash_set invariant_arguments, + absl::string_view kernel_name, NumWorkGroups num_workgroups, + std::optional min_alignment) : KernelThunkBase(Kind::kKernel, std::move(info)), invariant_arguments_(std::move(invariant_arguments)), num_kernel_args_(arguments_buffers.size() + results_buffers.size()), @@ -312,7 +315,7 @@ absl::StatusOr> KernelThunk::Create( Thunk::Info info, absl::Span arguments_buffers, absl::Span results_buffers, - std::string kernel_name, NumWorkGroups num_workgroups, + absl::string_view kernel_name, NumWorkGroups num_workgroups, absl::flat_hash_set invariant_arguments, std::optional min_alignment) { if (min_alignment.has_value() && !absl::has_single_bit(*min_alignment)) { @@ -324,8 +327,8 @@ absl::StatusOr> KernelThunk::Create( return absl::WrapUnique( new SmallKernelThunk( std::move(info), arguments_buffers, results_buffers, - std::move(invariant_arguments), std::move(kernel_name), - num_workgroups, min_alignment)); + std::move(invariant_arguments), kernel_name, num_workgroups, + min_alignment)); }; static constexpr auto _0 = std::integral_constant{}; diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.h b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.h index 4b7519bcf9f..1a4500bbbfa 100644 --- a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.h @@ -137,7 +137,7 @@ class KernelThunk : public KernelThunkBase { absl::Span arguments_buffers, absl::Span results_buffers, absl::flat_hash_set invariant_arguments, - std::string kernel_name, NumWorkGroups num_workgroups, + absl::string_view kernel_name, NumWorkGroups num_workgroups, std::optional min_alignment); absl::Status CheckInvariantBuffersMemory(const KernelArgs& kernel_args) const; @@ -196,7 +196,7 @@ class KernelThunk final : public internal::KernelThunk<> { Thunk::Info info, absl::Span arguments_buffers, absl::Span results_buffers, - std::string kernel_name, NumWorkGroups num_workgroups, + absl::string_view kernel_name, NumWorkGroups num_workgroups, absl::flat_hash_set invariant_arguments, std::optional min_alignment = std::nullopt); diff --git a/third_party/xla/xla/backends/cpu/testlib/kernel_runner.cc b/third_party/xla/xla/backends/cpu/testlib/kernel_runner.cc index 49b7bc25c9e..4f36d7614c0 100644 --- a/third_party/xla/xla/backends/cpu/testlib/kernel_runner.cc +++ b/third_party/xla/xla/backends/cpu/testlib/kernel_runner.cc @@ -62,7 +62,7 @@ absl::StatusOr KernelRunner::Create( TF_RETURN_IF_ERROR(compiler.AddModule(std::move(thread_safe_module))); - const std::string& kernel_name = spec.name(); + absl::string_view kernel_name = spec.name(); TF_ASSIGN_OR_RETURN(std::unique_ptr library, std::move(compiler).Compile( {FunctionLibrary::Sym(kernel_name)})); diff --git a/third_party/xla/xla/codegen/BUILD b/third_party/xla/xla/codegen/BUILD index e96ef0b88a1..23e72d553ed 100644 --- a/third_party/xla/xla/codegen/BUILD +++ b/third_party/xla/xla/codegen/BUILD @@ -81,6 +81,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/xla/xla/codegen/kernel_spec.h b/third_party/xla/xla/codegen/kernel_spec.h index 01f9ff440dd..5ca578767aa 100644 --- a/third_party/xla/xla/codegen/kernel_spec.h +++ b/third_party/xla/xla/codegen/kernel_spec.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/runtime/work_cluster.h" #include "xla/runtime/work_dimensions.h" #include "xla/runtime/work_group.h" @@ -50,9 +51,9 @@ class KernelSpec { absl::flat_hash_set invariant_arguments, std::optional scratch_bytes = std::nullopt); - // Get the backend specific name of the kernel. - // This may be used to identify the kernel in the backend specific runtime. - const std::string& name() const { return name_; } + // Get the backend specific name of the kernel. This may be used to identify + // the kernel in the backend specific runtime. + absl::string_view name() const { return name_; } // Kernel work dimensions define how the kernel execution must be // parallelized. The meaning of these dimensions is backend specific, i.e. @@ -66,12 +67,15 @@ class KernelSpec { // on the exact meaning of these dimensions and how they are mapped to the // underlying hardware, and how to use them for perfrormance optimization. WorkDimensions work_dimensions() const { return work_dimensions_; } + NumWorkClusters num_workclusters() const { return work_dimensions_.num_work_clusters; } + NumWorkGroups num_workgroups() const { return work_dimensions_.num_work_groups; } + NumWorkItems num_workitems() const { return work_dimensions_.num_work_items; } // Requested amount of scratch bytes for the kernel (backed by backend @@ -80,9 +84,14 @@ class KernelSpec { std::optional scratch_bytes() const { return scratch_bytes_; } // Argument buffers read by the kernel. - const Buffers& argument_buffers() const { return argument_buffers_; } + absl::Span argument_buffers() const { + return argument_buffers_; + } + // Result buffers written to by the kernel. - const Buffers& result_buffers() const { return result_buffers_; } + absl::Span result_buffers() const { + return result_buffers_; + } // Returns a set of invariant arguments (corresponding to the indices in the // argument buffers list). diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.h b/third_party/xla/xla/service/cpu/thunk_emitter.h index 8b052b416d8..bd4e88fae6f 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.h +++ b/third_party/xla/xla/service/cpu/thunk_emitter.h @@ -20,11 +20,13 @@ limitations under the License. #include #include #include +#include #include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" #include "mlir/IR/MLIRContext.h" @@ -69,6 +71,9 @@ class ThunkEmitter { }; struct EmittedKernel { + EmittedKernel(absl::string_view name, llvm::orc::ThreadSafeModule module) + : kernel_name(name), module(std::move(module)) {} + std::string kernel_name; llvm::orc::ThreadSafeModule module; };