[xla:cpu] Cleanup KernelSpec to use absl::string_view and absl::Span

PiperOrigin-RevId: 825075908
This commit is contained in:
Eugene Zhulenev 2025-10-28 09:30:30 -07:00 committed by TensorFlower Gardener
parent 10fd9cfebb
commit d75ad2c4ff
8 changed files with 38 additions and 21 deletions

View File

@ -136,9 +136,9 @@ cc_library(
deps = [
":kernel_c_api",
"//xla/tsl/lib/gtl:int_type",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:statusor",
],
)

View File

@ -19,13 +19,12 @@ limitations under the License.
#include <cstdint>
#include <string>
#include <type_traits>
#include <utility>
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/backends/cpu/runtime/kernel_c_api.h"
#include "xla/tsl/lib/gtl/int_type.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/statusor.h"
namespace xla::cpu {
@ -64,14 +63,14 @@ class FunctionLibrary {
};
template <typename F, std::enable_if_t<std::is_function_v<F>>* = nullptr>
static Symbol Sym(std::string name) {
return Symbol{GetTypeId<F>(), std::move(name)};
static Symbol Sym(absl::string_view name) {
return Symbol{GetTypeId<F>(), std::string(name)};
}
template <typename F, std::enable_if_t<std::is_function_v<F>>* = nullptr>
absl::StatusOr<F*> ResolveFunction(absl::string_view name) {
TF_ASSIGN_OR_RETURN(void* ptr, ResolveFunction(GetTypeId<F>(), name));
return reinterpret_cast<F*>(ptr);
return reinterpret_cast<F*>(ptr); // NOLINT
}
protected:

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/runtime/buffer_allocations.h"
#include "xla/backends/cpu/runtime/function_library.h"
@ -46,7 +47,6 @@ limitations under the License.
#include "xla/runtime/work_group.h"
#include "xla/service/buffer_assignment.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
@ -64,7 +64,9 @@ namespace internal {
static absl::Status CheckBufferAlignment(
const Thunk::Info& info, uint64_t min_alignment,
absl::Span<const XLA_CPU_KernelArg> kernel_args) {
if (min_alignment == 0) return absl::OkStatus();
if (min_alignment == 0) {
return absl::OkStatus();
}
for (int64_t i = 0; i < kernel_args.size(); ++i) {
auto ptr = reinterpret_cast<uintptr_t>(kernel_args[i].data);
@ -114,8 +116,9 @@ template <int64_t num_arguments, int64_t num_results>
KernelThunk<num_arguments, num_results>::KernelThunk(
Info info, absl::Span<const BufferAllocation::Slice> arguments_buffers,
absl::Span<const BufferAllocation::Slice> results_buffers,
absl::flat_hash_set<int64_t> invariant_arguments, std::string kernel_name,
NumWorkGroups num_workgroups, std::optional<uint64_t> min_alignment)
absl::flat_hash_set<int64_t> invariant_arguments,
absl::string_view kernel_name, NumWorkGroups num_workgroups,
std::optional<uint64_t> min_alignment)
: KernelThunkBase(Kind::kKernel, std::move(info)),
invariant_arguments_(std::move(invariant_arguments)),
num_kernel_args_(arguments_buffers.size() + results_buffers.size()),
@ -312,7 +315,7 @@ absl::StatusOr<std::unique_ptr<Thunk>> KernelThunk::Create(
Thunk::Info info,
absl::Span<const BufferAllocation::Slice> arguments_buffers,
absl::Span<const BufferAllocation::Slice> results_buffers,
std::string kernel_name, NumWorkGroups num_workgroups,
absl::string_view kernel_name, NumWorkGroups num_workgroups,
absl::flat_hash_set<int64_t> invariant_arguments,
std::optional<uint64_t> min_alignment) {
if (min_alignment.has_value() && !absl::has_single_bit(*min_alignment)) {
@ -324,8 +327,8 @@ absl::StatusOr<std::unique_ptr<Thunk>> KernelThunk::Create(
return absl::WrapUnique(
new SmallKernelThunk<num_arguments(), num_results()>(
std::move(info), arguments_buffers, results_buffers,
std::move(invariant_arguments), std::move(kernel_name),
num_workgroups, min_alignment));
std::move(invariant_arguments), kernel_name, num_workgroups,
min_alignment));
};
static constexpr auto _0 = std::integral_constant<size_t, 0>{};

View File

@ -137,7 +137,7 @@ class KernelThunk : public KernelThunkBase {
absl::Span<const BufferAllocation::Slice> arguments_buffers,
absl::Span<const BufferAllocation::Slice> results_buffers,
absl::flat_hash_set<int64_t> invariant_arguments,
std::string kernel_name, NumWorkGroups num_workgroups,
absl::string_view kernel_name, NumWorkGroups num_workgroups,
std::optional<uint64_t> min_alignment);
absl::Status CheckInvariantBuffersMemory(const KernelArgs& kernel_args) const;
@ -196,7 +196,7 @@ class KernelThunk final : public internal::KernelThunk<> {
Thunk::Info info,
absl::Span<const BufferAllocation::Slice> arguments_buffers,
absl::Span<const BufferAllocation::Slice> results_buffers,
std::string kernel_name, NumWorkGroups num_workgroups,
absl::string_view kernel_name, NumWorkGroups num_workgroups,
absl::flat_hash_set<int64_t> invariant_arguments,
std::optional<uint64_t> min_alignment = std::nullopt);

View File

@ -62,7 +62,7 @@ absl::StatusOr<KernelRunner> KernelRunner::Create(
TF_RETURN_IF_ERROR(compiler.AddModule(std::move(thread_safe_module)));
const std::string& kernel_name = spec.name();
absl::string_view kernel_name = spec.name();
TF_ASSIGN_OR_RETURN(std::unique_ptr<FunctionLibrary> library,
std::move(compiler).Compile(
{FunctionLibrary::Sym<XLA_CPU_Kernel>(kernel_name)}));

View File

@ -81,6 +81,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
],
)

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/runtime/work_cluster.h"
#include "xla/runtime/work_dimensions.h"
#include "xla/runtime/work_group.h"
@ -50,9 +51,9 @@ class KernelSpec {
absl::flat_hash_set<int64_t> invariant_arguments,
std::optional<size_t> scratch_bytes = std::nullopt);
// Get the backend specific name of the kernel.
// This may be used to identify the kernel in the backend specific runtime.
const std::string& name() const { return name_; }
// Get the backend specific name of the kernel. This may be used to identify
// the kernel in the backend specific runtime.
absl::string_view name() const { return name_; }
// Kernel work dimensions define how the kernel execution must be
// parallelized. The meaning of these dimensions is backend specific, i.e.
@ -66,12 +67,15 @@ class KernelSpec {
// on the exact meaning of these dimensions and how they are mapped to the
// underlying hardware, and how to use them for perfrormance optimization.
WorkDimensions work_dimensions() const { return work_dimensions_; }
NumWorkClusters num_workclusters() const {
return work_dimensions_.num_work_clusters;
}
NumWorkGroups num_workgroups() const {
return work_dimensions_.num_work_groups;
}
NumWorkItems num_workitems() const { return work_dimensions_.num_work_items; }
// Requested amount of scratch bytes for the kernel (backed by backend
@ -80,9 +84,14 @@ class KernelSpec {
std::optional<size_t> scratch_bytes() const { return scratch_bytes_; }
// Argument buffers read by the kernel.
const Buffers& argument_buffers() const { return argument_buffers_; }
absl::Span<const BufferAllocation::Slice> argument_buffers() const {
return argument_buffers_;
}
// Result buffers written to by the kernel.
const Buffers& result_buffers() const { return result_buffers_; }
absl::Span<const BufferAllocation::Slice> result_buffers() const {
return result_buffers_;
}
// Returns a set of invariant arguments (corresponding to the indices in the
// argument buffers list).

View File

@ -20,11 +20,13 @@ limitations under the License.
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "mlir/IR/MLIRContext.h"
@ -69,6 +71,9 @@ class ThunkEmitter {
};
struct EmittedKernel {
EmittedKernel(absl::string_view name, llvm::orc::ThreadSafeModule module)
: kernel_name(name), module(std::move(module)) {}
std::string kernel_name;
llvm::orc::ThreadSafeModule module;
};