mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[xla:cpu] Cleanup KernelSpec to use absl::string_view and absl::Span
PiperOrigin-RevId: 825075908
This commit is contained in:
parent
10fd9cfebb
commit
d75ad2c4ff
|
|
@ -136,9 +136,9 @@ cc_library(
|
|||
deps = [
|
||||
":kernel_c_api",
|
||||
"//xla/tsl/lib/gtl:int_type",
|
||||
"//xla/tsl/platform:statusor",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
"@local_tsl//tsl/platform:statusor",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -19,13 +19,12 @@ limitations under the License.
|
|||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "xla/backends/cpu/runtime/kernel_c_api.h"
|
||||
#include "xla/tsl/lib/gtl/int_type.h"
|
||||
#include "tsl/platform/statusor.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
|
||||
namespace xla::cpu {
|
||||
|
||||
|
|
@ -64,14 +63,14 @@ class FunctionLibrary {
|
|||
};
|
||||
|
||||
template <typename F, std::enable_if_t<std::is_function_v<F>>* = nullptr>
|
||||
static Symbol Sym(std::string name) {
|
||||
return Symbol{GetTypeId<F>(), std::move(name)};
|
||||
static Symbol Sym(absl::string_view name) {
|
||||
return Symbol{GetTypeId<F>(), std::string(name)};
|
||||
}
|
||||
|
||||
template <typename F, std::enable_if_t<std::is_function_v<F>>* = nullptr>
|
||||
absl::StatusOr<F*> ResolveFunction(absl::string_view name) {
|
||||
TF_ASSIGN_OR_RETURN(void* ptr, ResolveFunction(GetTypeId<F>(), name));
|
||||
return reinterpret_cast<F*>(ptr);
|
||||
return reinterpret_cast<F*>(ptr); // NOLINT
|
||||
}
|
||||
|
||||
protected:
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ limitations under the License.
|
|||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "xla/backends/cpu/runtime/buffer_allocations.h"
|
||||
#include "xla/backends/cpu/runtime/function_library.h"
|
||||
|
|
@ -46,7 +47,6 @@ limitations under the License.
|
|||
#include "xla/runtime/work_group.h"
|
||||
#include "xla/service/buffer_assignment.h"
|
||||
#include "xla/stream_executor/device_memory.h"
|
||||
#include "xla/stream_executor/launch_dim.h"
|
||||
#include "xla/tsl/concurrency/async_value_ref.h"
|
||||
#include "xla/tsl/platform/errors.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
|
|
@ -64,7 +64,9 @@ namespace internal {
|
|||
static absl::Status CheckBufferAlignment(
|
||||
const Thunk::Info& info, uint64_t min_alignment,
|
||||
absl::Span<const XLA_CPU_KernelArg> kernel_args) {
|
||||
if (min_alignment == 0) return absl::OkStatus();
|
||||
if (min_alignment == 0) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < kernel_args.size(); ++i) {
|
||||
auto ptr = reinterpret_cast<uintptr_t>(kernel_args[i].data);
|
||||
|
|
@ -114,8 +116,9 @@ template <int64_t num_arguments, int64_t num_results>
|
|||
KernelThunk<num_arguments, num_results>::KernelThunk(
|
||||
Info info, absl::Span<const BufferAllocation::Slice> arguments_buffers,
|
||||
absl::Span<const BufferAllocation::Slice> results_buffers,
|
||||
absl::flat_hash_set<int64_t> invariant_arguments, std::string kernel_name,
|
||||
NumWorkGroups num_workgroups, std::optional<uint64_t> min_alignment)
|
||||
absl::flat_hash_set<int64_t> invariant_arguments,
|
||||
absl::string_view kernel_name, NumWorkGroups num_workgroups,
|
||||
std::optional<uint64_t> min_alignment)
|
||||
: KernelThunkBase(Kind::kKernel, std::move(info)),
|
||||
invariant_arguments_(std::move(invariant_arguments)),
|
||||
num_kernel_args_(arguments_buffers.size() + results_buffers.size()),
|
||||
|
|
@ -312,7 +315,7 @@ absl::StatusOr<std::unique_ptr<Thunk>> KernelThunk::Create(
|
|||
Thunk::Info info,
|
||||
absl::Span<const BufferAllocation::Slice> arguments_buffers,
|
||||
absl::Span<const BufferAllocation::Slice> results_buffers,
|
||||
std::string kernel_name, NumWorkGroups num_workgroups,
|
||||
absl::string_view kernel_name, NumWorkGroups num_workgroups,
|
||||
absl::flat_hash_set<int64_t> invariant_arguments,
|
||||
std::optional<uint64_t> min_alignment) {
|
||||
if (min_alignment.has_value() && !absl::has_single_bit(*min_alignment)) {
|
||||
|
|
@ -324,8 +327,8 @@ absl::StatusOr<std::unique_ptr<Thunk>> KernelThunk::Create(
|
|||
return absl::WrapUnique(
|
||||
new SmallKernelThunk<num_arguments(), num_results()>(
|
||||
std::move(info), arguments_buffers, results_buffers,
|
||||
std::move(invariant_arguments), std::move(kernel_name),
|
||||
num_workgroups, min_alignment));
|
||||
std::move(invariant_arguments), kernel_name, num_workgroups,
|
||||
min_alignment));
|
||||
};
|
||||
|
||||
static constexpr auto _0 = std::integral_constant<size_t, 0>{};
|
||||
|
|
|
|||
|
|
@ -137,7 +137,7 @@ class KernelThunk : public KernelThunkBase {
|
|||
absl::Span<const BufferAllocation::Slice> arguments_buffers,
|
||||
absl::Span<const BufferAllocation::Slice> results_buffers,
|
||||
absl::flat_hash_set<int64_t> invariant_arguments,
|
||||
std::string kernel_name, NumWorkGroups num_workgroups,
|
||||
absl::string_view kernel_name, NumWorkGroups num_workgroups,
|
||||
std::optional<uint64_t> min_alignment);
|
||||
|
||||
absl::Status CheckInvariantBuffersMemory(const KernelArgs& kernel_args) const;
|
||||
|
|
@ -196,7 +196,7 @@ class KernelThunk final : public internal::KernelThunk<> {
|
|||
Thunk::Info info,
|
||||
absl::Span<const BufferAllocation::Slice> arguments_buffers,
|
||||
absl::Span<const BufferAllocation::Slice> results_buffers,
|
||||
std::string kernel_name, NumWorkGroups num_workgroups,
|
||||
absl::string_view kernel_name, NumWorkGroups num_workgroups,
|
||||
absl::flat_hash_set<int64_t> invariant_arguments,
|
||||
std::optional<uint64_t> min_alignment = std::nullopt);
|
||||
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ absl::StatusOr<KernelRunner> KernelRunner::Create(
|
|||
|
||||
TF_RETURN_IF_ERROR(compiler.AddModule(std::move(thread_safe_module)));
|
||||
|
||||
const std::string& kernel_name = spec.name();
|
||||
absl::string_view kernel_name = spec.name();
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<FunctionLibrary> library,
|
||||
std::move(compiler).Compile(
|
||||
{FunctionLibrary::Sym<XLA_CPU_Kernel>(kernel_name)}));
|
||||
|
|
|
|||
1
third_party/xla/xla/codegen/BUILD
vendored
1
third_party/xla/xla/codegen/BUILD
vendored
|
|
@ -81,6 +81,7 @@ cc_library(
|
|||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
19
third_party/xla/xla/codegen/kernel_spec.h
vendored
19
third_party/xla/xla/codegen/kernel_spec.h
vendored
|
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "xla/runtime/work_cluster.h"
|
||||
#include "xla/runtime/work_dimensions.h"
|
||||
#include "xla/runtime/work_group.h"
|
||||
|
|
@ -50,9 +51,9 @@ class KernelSpec {
|
|||
absl::flat_hash_set<int64_t> invariant_arguments,
|
||||
std::optional<size_t> scratch_bytes = std::nullopt);
|
||||
|
||||
// Get the backend specific name of the kernel.
|
||||
// This may be used to identify the kernel in the backend specific runtime.
|
||||
const std::string& name() const { return name_; }
|
||||
// Get the backend specific name of the kernel. This may be used to identify
|
||||
// the kernel in the backend specific runtime.
|
||||
absl::string_view name() const { return name_; }
|
||||
|
||||
// Kernel work dimensions define how the kernel execution must be
|
||||
// parallelized. The meaning of these dimensions is backend specific, i.e.
|
||||
|
|
@ -66,12 +67,15 @@ class KernelSpec {
|
|||
// on the exact meaning of these dimensions and how they are mapped to the
|
||||
// underlying hardware, and how to use them for perfrormance optimization.
|
||||
WorkDimensions work_dimensions() const { return work_dimensions_; }
|
||||
|
||||
NumWorkClusters num_workclusters() const {
|
||||
return work_dimensions_.num_work_clusters;
|
||||
}
|
||||
|
||||
NumWorkGroups num_workgroups() const {
|
||||
return work_dimensions_.num_work_groups;
|
||||
}
|
||||
|
||||
NumWorkItems num_workitems() const { return work_dimensions_.num_work_items; }
|
||||
|
||||
// Requested amount of scratch bytes for the kernel (backed by backend
|
||||
|
|
@ -80,9 +84,14 @@ class KernelSpec {
|
|||
std::optional<size_t> scratch_bytes() const { return scratch_bytes_; }
|
||||
|
||||
// Argument buffers read by the kernel.
|
||||
const Buffers& argument_buffers() const { return argument_buffers_; }
|
||||
absl::Span<const BufferAllocation::Slice> argument_buffers() const {
|
||||
return argument_buffers_;
|
||||
}
|
||||
|
||||
// Result buffers written to by the kernel.
|
||||
const Buffers& result_buffers() const { return result_buffers_; }
|
||||
absl::Span<const BufferAllocation::Slice> result_buffers() const {
|
||||
return result_buffers_;
|
||||
}
|
||||
|
||||
// Returns a set of invariant arguments (corresponding to the indices in the
|
||||
// argument buffers list).
|
||||
|
|
|
|||
|
|
@ -20,11 +20,13 @@ limitations under the License.
|
|||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
|
@ -69,6 +71,9 @@ class ThunkEmitter {
|
|||
};
|
||||
|
||||
struct EmittedKernel {
|
||||
EmittedKernel(absl::string_view name, llvm::orc::ThreadSafeModule module)
|
||||
: kernel_name(name), module(std::move(module)) {}
|
||||
|
||||
std::string kernel_name;
|
||||
llvm::orc::ThreadSafeModule module;
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user