[xla:cpu] Cleanup KernelSpec to use absl::string_view and absl::Span

PiperOrigin-RevId: 825075908
Authored by Eugene Zhulenev on 2025-10-28 09:30:30 -07:00; committed by TensorFlower Gardener
parent 10fd9cfebb
commit d75ad2c4ff
8 changed files with 38 additions and 21 deletions
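
Across these files the change applies the usual Abseil hygiene: accept absl::string_view at API boundaries, return absl::string_view / absl::Span from accessors, and copy into owned storage only inside the object that actually keeps the data. A minimal stand-in sketch of that pattern, using only standard and Abseil types (the SpecLike class below is illustrative, not XLA code):

#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"

// Stand-in for the cleanup pattern: take string_view at the constructor
// boundary, keep owned std::string / std::vector members, and hand out
// non-owning views from the accessors.
class SpecLike {
 public:
  SpecLike(absl::string_view name, std::vector<int> buffers)
      : name_(name), buffers_(std::move(buffers)) {}

  // Previously these would have returned const std::string& and
  // const std::vector<int>&; views avoid committing to the storage type.
  absl::string_view name() const { return name_; }
  absl::Span<const int> buffers() const { return buffers_; }

 private:
  std::string name_;
  std::vector<int> buffers_;
};

The diffs below make the same swap in FunctionLibrary, KernelThunk, KernelSpec, and ThunkEmitter; call sites such as the kernel_runner binding further down switch from const std::string& to absl::string_view accordingly.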


@@ -136,9 +136,9 @@ cc_library(
     deps = [
         ":kernel_c_api",
         "//xla/tsl/lib/gtl:int_type",
+        "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:statusor",
     ],
 )


@@ -19,13 +19,12 @@ limitations under the License.
 #include <cstdint>
 #include <string>
 #include <type_traits>
-#include <utility>

 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "xla/backends/cpu/runtime/kernel_c_api.h"
 #include "xla/tsl/lib/gtl/int_type.h"
-#include "tsl/platform/statusor.h"
+#include "xla/tsl/platform/statusor.h"

 namespace xla::cpu {
@@ -64,14 +63,14 @@ class FunctionLibrary {
   };

   template <typename F, std::enable_if_t<std::is_function_v<F>>* = nullptr>
-  static Symbol Sym(std::string name) {
-    return Symbol{GetTypeId<F>(), std::move(name)};
+  static Symbol Sym(absl::string_view name) {
+    return Symbol{GetTypeId<F>(), std::string(name)};
   }

   template <typename F, std::enable_if_t<std::is_function_v<F>>* = nullptr>
   absl::StatusOr<F*> ResolveFunction(absl::string_view name) {
     TF_ASSIGN_OR_RETURN(void* ptr, ResolveFunction(GetTypeId<F>(), name));
-    return reinterpret_cast<F*>(ptr);
+    return reinterpret_cast<F*>(ptr);  // NOLINT
   }

  protected:
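
As a hedged aside, this is what the new Sym spelling looks like from a caller. The MakeKernelSymbol helper is hypothetical; FunctionLibrary::Sym, Symbol, and XLA_CPU_Kernel are the names visible in this hunk, and the kernel_runner.cc hunk further down passes a string_view from KernelSpec::name() in the same way:

#include "absl/strings/string_view.h"
#include "xla/backends/cpu/runtime/function_library.h"
#include "xla/backends/cpu/runtime/kernel_c_api.h"

namespace xla::cpu {

// With Sym(absl::string_view) the caller can pass a literal or a view
// without materializing a std::string first; the returned Symbol still owns
// its own copy, because Sym constructs std::string(name) internally.
inline FunctionLibrary::Symbol MakeKernelSymbol(absl::string_view name) {
  return FunctionLibrary::Sym<XLA_CPU_Kernel>(name);
}

}  // namespace xla::cpu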


@@ -35,6 +35,7 @@ limitations under the License.
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "xla/backends/cpu/runtime/buffer_allocations.h"
 #include "xla/backends/cpu/runtime/function_library.h"
@@ -46,7 +47,6 @@ limitations under the License.
 #include "xla/runtime/work_group.h"
 #include "xla/service/buffer_assignment.h"
 #include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/launch_dim.h"
 #include "xla/tsl/concurrency/async_value_ref.h"
 #include "xla/tsl/platform/errors.h"
 #include "xla/tsl/platform/statusor.h"
@@ -64,7 +64,9 @@ namespace internal {
 static absl::Status CheckBufferAlignment(
     const Thunk::Info& info, uint64_t min_alignment,
     absl::Span<const XLA_CPU_KernelArg> kernel_args) {
-  if (min_alignment == 0) return absl::OkStatus();
+  if (min_alignment == 0) {
+    return absl::OkStatus();
+  }

   for (int64_t i = 0; i < kernel_args.size(); ++i) {
     auto ptr = reinterpret_cast<uintptr_t>(kernel_args[i].data);
@@ -114,8 +116,9 @@ template <int64_t num_arguments, int64_t num_results>
 KernelThunk<num_arguments, num_results>::KernelThunk(
     Info info, absl::Span<const BufferAllocation::Slice> arguments_buffers,
     absl::Span<const BufferAllocation::Slice> results_buffers,
-    absl::flat_hash_set<int64_t> invariant_arguments, std::string kernel_name,
-    NumWorkGroups num_workgroups, std::optional<uint64_t> min_alignment)
+    absl::flat_hash_set<int64_t> invariant_arguments,
+    absl::string_view kernel_name, NumWorkGroups num_workgroups,
+    std::optional<uint64_t> min_alignment)
     : KernelThunkBase(Kind::kKernel, std::move(info)),
       invariant_arguments_(std::move(invariant_arguments)),
       num_kernel_args_(arguments_buffers.size() + results_buffers.size()),
@@ -312,7 +315,7 @@ absl::StatusOr<std::unique_ptr<Thunk>> KernelThunk::Create(
     Thunk::Info info,
     absl::Span<const BufferAllocation::Slice> arguments_buffers,
     absl::Span<const BufferAllocation::Slice> results_buffers,
-    std::string kernel_name, NumWorkGroups num_workgroups,
+    absl::string_view kernel_name, NumWorkGroups num_workgroups,
     absl::flat_hash_set<int64_t> invariant_arguments,
     std::optional<uint64_t> min_alignment) {
   if (min_alignment.has_value() && !absl::has_single_bit(*min_alignment)) {
@@ -324,8 +327,8 @@ absl::StatusOr<std::unique_ptr<Thunk>> KernelThunk::Create(
     return absl::WrapUnique(
         new SmallKernelThunk<num_arguments(), num_results()>(
             std::move(info), arguments_buffers, results_buffers,
-            std::move(invariant_arguments), std::move(kernel_name),
-            num_workgroups, min_alignment));
+            std::move(invariant_arguments), kernel_name, num_workgroups,
+            min_alignment));
   };

   static constexpr auto _0 = std::integral_constant<size_t, 0>{};


@@ -137,7 +137,7 @@ class KernelThunk : public KernelThunkBase {
       absl::Span<const BufferAllocation::Slice> arguments_buffers,
       absl::Span<const BufferAllocation::Slice> results_buffers,
       absl::flat_hash_set<int64_t> invariant_arguments,
-      std::string kernel_name, NumWorkGroups num_workgroups,
+      absl::string_view kernel_name, NumWorkGroups num_workgroups,
       std::optional<uint64_t> min_alignment);

   absl::Status CheckInvariantBuffersMemory(const KernelArgs& kernel_args) const;
@@ -196,7 +196,7 @@ class KernelThunk final : public internal::KernelThunk<> {
       Thunk::Info info,
       absl::Span<const BufferAllocation::Slice> arguments_buffers,
       absl::Span<const BufferAllocation::Slice> results_buffers,
-      std::string kernel_name, NumWorkGroups num_workgroups,
+      absl::string_view kernel_name, NumWorkGroups num_workgroups,
       absl::flat_hash_set<int64_t> invariant_arguments,
       std::optional<uint64_t> min_alignment = std::nullopt);


@@ -62,7 +62,7 @@ absl::StatusOr<KernelRunner> KernelRunner::Create(
   TF_RETURN_IF_ERROR(compiler.AddModule(std::move(thread_safe_module)));

-  const std::string& kernel_name = spec.name();
+  absl::string_view kernel_name = spec.name();

   TF_ASSIGN_OR_RETURN(std::unique_ptr<FunctionLibrary> library,
                       std::move(compiler).Compile(
                           {FunctionLibrary::Sym<XLA_CPU_Kernel>(kernel_name)}));


@@ -81,6 +81,7 @@ cc_library(
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
     ],
 )


@@ -24,6 +24,7 @@ limitations under the License.
 #include "absl/container/flat_hash_set.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/span.h"
 #include "xla/runtime/work_cluster.h"
 #include "xla/runtime/work_dimensions.h"
 #include "xla/runtime/work_group.h"
@@ -50,9 +51,9 @@ class KernelSpec {
              absl::flat_hash_set<int64_t> invariant_arguments,
              std::optional<size_t> scratch_bytes = std::nullopt);

-  // Get the backend specific name of the kernel.
-  // This may be used to identify the kernel in the backend specific runtime.
-  const std::string& name() const { return name_; }
+  // Get the backend specific name of the kernel. This may be used to identify
+  // the kernel in the backend specific runtime.
+  absl::string_view name() const { return name_; }

   // Kernel work dimensions define how the kernel execution must be
   // parallelized. The meaning of these dimensions is backend specific, i.e.
@@ -66,12 +67,15 @@ class KernelSpec {
   // on the exact meaning of these dimensions and how they are mapped to the
   // underlying hardware, and how to use them for perfrormance optimization.
   WorkDimensions work_dimensions() const { return work_dimensions_; }
+
   NumWorkClusters num_workclusters() const {
     return work_dimensions_.num_work_clusters;
   }
+
   NumWorkGroups num_workgroups() const {
     return work_dimensions_.num_work_groups;
   }
+
   NumWorkItems num_workitems() const { return work_dimensions_.num_work_items; }

   // Requested amount of scratch bytes for the kernel (backed by backend
@@ -80,9 +84,14 @@ class KernelSpec {
   std::optional<size_t> scratch_bytes() const { return scratch_bytes_; }

   // Argument buffers read by the kernel.
-  const Buffers& argument_buffers() const { return argument_buffers_; }
+  absl::Span<const BufferAllocation::Slice> argument_buffers() const {
+    return argument_buffers_;
+  }

   // Result buffers written to by the kernel.
-  const Buffers& result_buffers() const { return result_buffers_; }
+  absl::Span<const BufferAllocation::Slice> result_buffers() const {
+    return result_buffers_;
+  }

   // Returns a set of invariant arguments (corresponding to the indices in the
   // argument buffers list).
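
A hedged caller-side sketch of the Span-returning accessors: the TotalKernelBufferBytes helper and the xla namespace qualification are assumptions, while the accessor shapes and BufferAllocation::Slice come from the hunk above (Slice::size() is the existing slice-size accessor):

#include <cstddef>

#include "xla/service/buffer_assignment.h"

// Sum the sizes of every buffer a kernel touches by iterating the spans
// returned by argument_buffers() and result_buffers(). Assumes the
// KernelSpec header is already included.
size_t TotalKernelBufferBytes(const xla::KernelSpec& spec) {
  size_t bytes = 0;
  for (const xla::BufferAllocation::Slice& arg : spec.argument_buffers()) {
    bytes += arg.size();
  }
  for (const xla::BufferAllocation::Slice& result : spec.result_buffers()) {
    bytes += result.size();
  }
  return bytes;
}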


@@ -20,11 +20,13 @@ limitations under the License.
 #include <memory>
 #include <optional>
 #include <string>
+#include <utility>
 #include <vector>

 #include "absl/container/flat_hash_map.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
 #include "mlir/IR/MLIRContext.h"
@@ -69,6 +71,9 @@ class ThunkEmitter {
   };

   struct EmittedKernel {
+    EmittedKernel(absl::string_view name, llvm::orc::ThreadSafeModule module)
+        : kernel_name(name), module(std::move(module)) {}
+
     std::string kernel_name;
     llvm::orc::ThreadSafeModule module;
   };