mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[xla:ffi] Add TypeRegistry::TypeInfo to be able to register functions to manipulate user-defined types
PiperOrigin-RevId: 820811829
This commit is contained in:
parent
46522b8a20
commit
d531cdce30
39
third_party/xla/xla/ffi/BUILD
vendored
39
third_party/xla/xla/ffi/BUILD
vendored
|
|
@ -62,7 +62,7 @@ cc_library(
|
|||
srcs = ["execution_context.cc"],
|
||||
hdrs = ["execution_context.h"],
|
||||
deps = [
|
||||
":type_id_registry",
|
||||
":type_registry",
|
||||
"//xla:util",
|
||||
"//xla/tsl/platform:errors",
|
||||
"//xla/tsl/platform:logging",
|
||||
|
|
@ -79,7 +79,7 @@ xla_cc_test(
|
|||
srcs = ["execution_context_test.cc"],
|
||||
deps = [
|
||||
":execution_context",
|
||||
":type_id_registry",
|
||||
":type_registry",
|
||||
"//xla/tsl/lib/core:status_test_util",
|
||||
"//xla/tsl/platform:statusor",
|
||||
"//xla/tsl/platform:test",
|
||||
|
|
@ -94,13 +94,14 @@ cc_library(
|
|||
srcs = ["execution_state.cc"],
|
||||
hdrs = ["execution_state.h"],
|
||||
deps = [
|
||||
":type_id_registry",
|
||||
":type_registry",
|
||||
"//xla:util",
|
||||
"//xla/tsl/platform:statusor",
|
||||
"//xla/tsl/util:safe_reinterpret_cast",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@local_tsl//tsl/platform:logging",
|
||||
"@local_tsl//tsl/platform:statusor",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -109,11 +110,12 @@ xla_cc_test(
|
|||
srcs = ["execution_state_test.cc"],
|
||||
deps = [
|
||||
":execution_state",
|
||||
":type_registry",
|
||||
"//xla/tsl/lib/core:status_test_util",
|
||||
"//xla/tsl/platform:statusor",
|
||||
"//xla/tsl/platform:test",
|
||||
"@com_google_googletest//:gtest",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
"@local_tsl//tsl/platform:statusor",
|
||||
"@local_tsl//tsl/platform:test",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -124,7 +126,7 @@ cc_library(
|
|||
":api",
|
||||
":execution_context",
|
||||
":execution_state",
|
||||
":type_id_registry",
|
||||
":type_registry",
|
||||
"//xla:executable_run_options",
|
||||
"//xla:shape_util",
|
||||
"//xla:types",
|
||||
|
|
@ -141,6 +143,7 @@ cc_library(
|
|||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/base:nullability",
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
|
|
@ -160,7 +163,7 @@ cc_library(
|
|||
":call_frame",
|
||||
":execution_context",
|
||||
":execution_state",
|
||||
":type_id_registry",
|
||||
":type_registry",
|
||||
"//xla:executable_run_options",
|
||||
"//xla:util",
|
||||
"//xla/ffi/api:c_api",
|
||||
|
|
@ -218,7 +221,7 @@ xla_cc_test(
|
|||
":execution_state",
|
||||
":ffi",
|
||||
":ffi_api",
|
||||
":type_id_registry",
|
||||
":type_registry",
|
||||
"//xla:executable_run_options",
|
||||
"//xla:xla_data_proto_cc",
|
||||
"//xla/ffi/api:c_api",
|
||||
|
|
@ -243,27 +246,31 @@ xla_cc_test(
|
|||
)
|
||||
|
||||
cc_library(
|
||||
name = "type_id_registry",
|
||||
srcs = ["type_id_registry.cc"],
|
||||
hdrs = ["type_id_registry.h"],
|
||||
name = "type_registry",
|
||||
srcs = ["type_registry.cc"],
|
||||
hdrs = ["type_registry.h"],
|
||||
deps = [
|
||||
"//xla:util",
|
||||
"//xla/tsl/lib/gtl:int_type",
|
||||
"//xla/tsl/util:safe_reinterpret_cast",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/base:no_destructor",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
xla_cc_test(
|
||||
name = "type_id_registry_test",
|
||||
srcs = ["type_id_registry_test.cc"],
|
||||
name = "type_registry_test",
|
||||
srcs = ["type_registry_test.cc"],
|
||||
deps = [
|
||||
":type_id_registry",
|
||||
":type_registry",
|
||||
"//xla/tsl/lib/core:status_test_util",
|
||||
"//xla/tsl/platform:statusor",
|
||||
"//xla/tsl/platform:test",
|
||||
|
|
|
|||
2
third_party/xla/xla/ffi/api/BUILD
vendored
2
third_party/xla/xla/ffi/api/BUILD
vendored
|
|
@ -89,7 +89,7 @@ xla_cc_test(
|
|||
"//xla/ffi:execution_context",
|
||||
"//xla/ffi:execution_state",
|
||||
"//xla/ffi:ffi_api",
|
||||
"//xla/ffi:type_id_registry",
|
||||
"//xla/ffi:type_registry",
|
||||
"//xla/stream_executor:device_memory",
|
||||
"//xla/stream_executor:device_memory_allocator",
|
||||
"//xla/tsl/concurrency:async_value",
|
||||
|
|
|
|||
2
third_party/xla/xla/ffi/api/api.h
vendored
2
third_party/xla/xla/ffi/api/api.h
vendored
|
|
@ -348,11 +348,13 @@ inline XLA_FFI_Error* Ffi::RegisterTypeId(const XLA_FFI_Api* api,
|
|||
std::string_view name,
|
||||
XLA_FFI_TypeId* type_id,
|
||||
XLA_FFI_TypeInfo type_info) {
|
||||
assert(type_id && "type_id must not be null");
|
||||
XLA_FFI_TypeId_Register_Args args;
|
||||
args.struct_size = XLA_FFI_TypeId_Register_Args_STRUCT_SIZE;
|
||||
args.extension_start = nullptr;
|
||||
args.name = XLA_FFI_ByteSpan{name.data(), name.size()};
|
||||
args.type_id = type_id;
|
||||
args.type_info = &type_info;
|
||||
return api->XLA_FFI_TypeId_Register(&args);
|
||||
}
|
||||
|
||||
|
|
|
|||
5
third_party/xla/xla/ffi/api/c_api.h
vendored
5
third_party/xla/xla/ffi/api/c_api.h
vendored
|
|
@ -67,7 +67,7 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Extension_Base, next);
|
|||
// * Deleting a method or argument
|
||||
// * Changing the type of an argument
|
||||
// * Rearranging fields in the XLA_FFI_Api or argument structs
|
||||
#define XLA_FFI_API_MAJOR 0
|
||||
#define XLA_FFI_API_MAJOR 1
|
||||
|
||||
// Incremented when the interface is updated in a way that is potentially
|
||||
// ABI-compatible with older versions, if supported by the caller and/or
|
||||
|
|
@ -82,7 +82,7 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Extension_Base, next);
|
|||
// Minor changes include:
|
||||
// * Adding a new field to the XLA_FFI_Api or argument structs
|
||||
// * Renaming a method or argument (doesn't affect ABI)
|
||||
#define XLA_FFI_API_MINOR 1
|
||||
#define XLA_FFI_API_MINOR 0
|
||||
|
||||
struct XLA_FFI_Api_Version {
|
||||
size_t struct_size;
|
||||
|
|
@ -491,6 +491,7 @@ struct XLA_FFI_TypeId_Register_Args {
|
|||
|
||||
XLA_FFI_ByteSpan name;
|
||||
XLA_FFI_TypeId* type_id; // in-out
|
||||
XLA_FFI_TypeInfo* type_info;
|
||||
};
|
||||
|
||||
XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_TypeId_Register_Args, type_id);
|
||||
|
|
|
|||
6
third_party/xla/xla/ffi/api/ffi.h
vendored
6
third_party/xla/xla/ffi/api/ffi.h
vendored
|
|
@ -437,7 +437,7 @@ class CountDownPromise {
|
|||
assert(state_->count.load() >= count && "Invalid count down value");
|
||||
|
||||
if (XLA_FFI_PREDICT_FALSE(!error.success())) {
|
||||
const std::lock_guard<std::mutex> lock(state_->mutex);
|
||||
std::lock_guard<std::mutex> lock(state_->mutex); // NOLINT
|
||||
state_->is_error.store(true, std::memory_order_release);
|
||||
state_->error = error;
|
||||
}
|
||||
|
|
@ -448,7 +448,7 @@ class CountDownPromise {
|
|||
bool is_error = state_->is_error.load(std::memory_order_acquire);
|
||||
if (XLA_FFI_PREDICT_FALSE(is_error)) {
|
||||
auto take_error = [&] {
|
||||
const std::lock_guard<std::mutex> lock(state_->mutex);
|
||||
std::lock_guard<std::mutex> lock(state_->mutex); // NOLINT
|
||||
return state_->error;
|
||||
};
|
||||
state_->promise.SetError(take_error());
|
||||
|
|
@ -476,7 +476,7 @@ class CountDownPromise {
|
|||
std::atomic<int64_t> count;
|
||||
std::atomic<bool> is_error;
|
||||
|
||||
std::mutex mutex;
|
||||
std::mutex mutex; // NOLINT
|
||||
Error error;
|
||||
};
|
||||
|
||||
|
|
|
|||
6
third_party/xla/xla/ffi/api/ffi_test.cc
vendored
6
third_party/xla/xla/ffi/api/ffi_test.cc
vendored
|
|
@ -39,7 +39,7 @@ limitations under the License.
|
|||
#include "xla/ffi/execution_context.h"
|
||||
#include "xla/ffi/execution_state.h"
|
||||
#include "xla/ffi/ffi_api.h"
|
||||
#include "xla/ffi/type_id_registry.h"
|
||||
#include "xla/ffi/type_registry.h"
|
||||
#include "xla/primitive_util.h"
|
||||
#include "xla/stream_executor/device_memory.h"
|
||||
#include "xla/stream_executor/device_memory_allocator.h"
|
||||
|
|
@ -1220,9 +1220,9 @@ TEST(FfiTest, UserData) {
|
|||
|
||||
ExecutionContext execution_context;
|
||||
TF_ASSERT_OK(execution_context.Insert(
|
||||
TypeIdRegistry::TypeId(MyDataWithAutoTypeId::id.type_id), &data0));
|
||||
TypeRegistry::TypeId(MyDataWithAutoTypeId::id.type_id), &data0));
|
||||
TF_ASSERT_OK(execution_context.Insert(
|
||||
TypeIdRegistry::TypeId(MyDataWithExplicitTypeId::id.type_id), &data1));
|
||||
TypeRegistry::TypeId(MyDataWithExplicitTypeId::id.type_id), &data1));
|
||||
|
||||
CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
|
||||
auto call_frame = builder.Build();
|
||||
|
|
|
|||
10
third_party/xla/xla/ffi/execution_context.h
vendored
10
third_party/xla/xla/ffi/execution_context.h
vendored
|
|
@ -25,7 +25,7 @@ limitations under the License.
|
|||
#include "absl/functional/function_ref.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "xla/ffi/type_id_registry.h"
|
||||
#include "xla/ffi/type_registry.h"
|
||||
#include "xla/tsl/platform/logging.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
|
||||
|
|
@ -45,7 +45,7 @@ namespace xla::ffi {
|
|||
// unique between separate calls to XLA execute.
|
||||
class ExecutionContext {
|
||||
public:
|
||||
using TypeId = TypeIdRegistry::TypeId;
|
||||
using TypeId = TypeRegistry::TypeId;
|
||||
|
||||
template <typename T>
|
||||
using Deleter = std::function<void(T*)>;
|
||||
|
|
@ -67,7 +67,7 @@ class ExecutionContext {
|
|||
template <typename T>
|
||||
absl::StatusOr<T*> Lookup() const {
|
||||
TF_ASSIGN_OR_RETURN(auto user_data,
|
||||
LookupUserData(TypeIdRegistry::GetTypeId<T>()));
|
||||
LookupUserData(TypeRegistry::GetTypeId<T>()));
|
||||
return static_cast<T*>(user_data->data());
|
||||
}
|
||||
|
||||
|
|
@ -110,7 +110,7 @@ class ExecutionContext {
|
|||
|
||||
template <typename T>
|
||||
absl::Status ExecutionContext::Insert(T* data, Deleter<T> deleter) {
|
||||
return InsertUserData(TypeIdRegistry::GetTypeId<T>(),
|
||||
return InsertUserData(TypeRegistry::GetTypeId<T>(),
|
||||
std::make_unique<UserData>(
|
||||
data, [deleter = std::move(deleter)](void* data) {
|
||||
if (deleter) deleter(static_cast<T*>(data));
|
||||
|
|
@ -119,7 +119,7 @@ absl::Status ExecutionContext::Insert(T* data, Deleter<T> deleter) {
|
|||
|
||||
template <typename T, typename... Args>
|
||||
absl::Status ExecutionContext::Emplace(Args&&... args) {
|
||||
return InsertUserData(TypeIdRegistry::GetTypeId<T>(),
|
||||
return InsertUserData(TypeRegistry::GetTypeId<T>(),
|
||||
std::make_unique<UserData>(
|
||||
new T(std::forward<Args>(args)...),
|
||||
[](void* data) { delete static_cast<T*>(data); }));
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/status/status.h"
|
||||
#include "xla/ffi/type_id_registry.h"
|
||||
#include "xla/ffi/type_registry.h"
|
||||
#include "xla/tsl/lib/core/status_test_util.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
#include "xla/tsl/platform/test.h"
|
||||
|
|
@ -62,8 +62,9 @@ TEST(ExecutionContextTest, InsertUserOwned) {
|
|||
}
|
||||
|
||||
TEST(ExecutionContextTest, InsertUserOwnedWithTypeId) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(TypeIdRegistry::TypeId type_id,
|
||||
TypeIdRegistry::AssignExternalTypeId("I32UserData"));
|
||||
TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeId type_id,
|
||||
TypeRegistry::AssignExternalTypeId(
|
||||
"I32UserData", TypeRegistry::TypeInfo{}));
|
||||
|
||||
I32UserData user_data(42);
|
||||
|
||||
|
|
|
|||
46
third_party/xla/xla/ffi/execution_state.cc
vendored
46
third_party/xla/xla/ffi/execution_state.cc
vendored
|
|
@ -15,38 +15,54 @@ limitations under the License.
|
|||
|
||||
#include "xla/ffi/execution_state.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "absl/base/attributes.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "xla/ffi/type_id_registry.h"
|
||||
#include "xla/ffi/type_registry.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
#include "xla/util.h"
|
||||
#include "tsl/platform/logging.h"
|
||||
|
||||
namespace xla::ffi {
|
||||
|
||||
ExecutionState::ExecutionState()
|
||||
: type_id_(TypeIdRegistry::kUnknownTypeId),
|
||||
state_(nullptr),
|
||||
deleter_(nullptr) {}
|
||||
: type_id_(TypeRegistry::kUnknownTypeId), state_(nullptr) {}
|
||||
|
||||
ExecutionState::~ExecutionState() {
|
||||
if (deleter_) deleter_(state_);
|
||||
if (type_info_.deleter) {
|
||||
type_info_.deleter(state_);
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status ExecutionState::Set(TypeId type_id, void* state,
|
||||
Deleter<void> deleter) {
|
||||
DCHECK(state && deleter) << "State and deleter must not be null";
|
||||
absl::Status ExecutionState::Set(TypeId type_id, void* state) {
|
||||
TF_ASSIGN_OR_RETURN(auto type_info,
|
||||
TypeRegistry::GetExternalTypeInfo(type_id));
|
||||
if (type_info.deleter == nullptr) {
|
||||
return InvalidArgument(
|
||||
"Type id %d does not have a registered type info with a deleter",
|
||||
type_id.value());
|
||||
}
|
||||
return Set(type_id, type_info, state);
|
||||
}
|
||||
|
||||
if (type_id_ != TypeIdRegistry::kUnknownTypeId) {
|
||||
ABSL_DEPRECATED("FFI users must rely in TypeInfo registration")
|
||||
absl::Status ExecutionState::Set(TypeId type_id, void* state,
|
||||
void (*deleter)(void*)) {
|
||||
return Set(type_id, TypeInfo{deleter}, state);
|
||||
}
|
||||
|
||||
absl::Status ExecutionState::Set(TypeId type_id, TypeInfo type_info,
|
||||
void* state) {
|
||||
DCHECK(state && type_info.deleter) << "State and deleter must not be null";
|
||||
|
||||
if (type_id_ != TypeRegistry::kUnknownTypeId) {
|
||||
return FailedPrecondition("State is already set with a type id %d",
|
||||
type_id_.value());
|
||||
}
|
||||
|
||||
type_id_ = type_id;
|
||||
type_info_ = type_info;
|
||||
state_ = state;
|
||||
deleter_ = std::move(deleter);
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
|
@ -54,7 +70,7 @@ absl::Status ExecutionState::Set(TypeId type_id, void* state,
|
|||
// Returns opaque state of the given type id. If set state type id does not
|
||||
// match the requested one, returns an error.
|
||||
absl::StatusOr<void*> ExecutionState::Get(TypeId type_id) const {
|
||||
if (type_id_ == TypeIdRegistry::kUnknownTypeId) {
|
||||
if (type_id_ == TypeRegistry::kUnknownTypeId) {
|
||||
return NotFound("State is not set");
|
||||
}
|
||||
|
||||
|
|
@ -68,7 +84,7 @@ absl::StatusOr<void*> ExecutionState::Get(TypeId type_id) const {
|
|||
}
|
||||
|
||||
bool ExecutionState::IsSet() const {
|
||||
return type_id_ != TypeIdRegistry::kUnknownTypeId;
|
||||
return type_id_ != TypeRegistry::kUnknownTypeId;
|
||||
}
|
||||
|
||||
} // namespace xla::ffi
|
||||
|
|
|
|||
34
third_party/xla/xla/ffi/execution_state.h
vendored
34
third_party/xla/xla/ffi/execution_state.h
vendored
|
|
@ -16,13 +16,13 @@ limitations under the License.
|
|||
#ifndef XLA_FFI_EXECUTION_STATE_H_
|
||||
#define XLA_FFI_EXECUTION_STATE_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "xla/ffi/type_id_registry.h"
|
||||
#include "tsl/platform/statusor.h"
|
||||
#include "xla/ffi/type_registry.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
#include "xla/tsl/util/safe_reinterpret_cast.h"
|
||||
|
||||
namespace xla::ffi {
|
||||
|
||||
|
|
@ -41,10 +41,8 @@ namespace xla::ffi {
|
|||
//
|
||||
class ExecutionState {
|
||||
public:
|
||||
using TypeId = TypeIdRegistry::TypeId;
|
||||
|
||||
template <typename T>
|
||||
using Deleter = std::function<void(T*)>;
|
||||
using TypeId = TypeRegistry::TypeId;
|
||||
using TypeInfo = TypeRegistry::TypeInfo;
|
||||
|
||||
ExecutionState();
|
||||
~ExecutionState();
|
||||
|
|
@ -52,9 +50,13 @@ class ExecutionState {
|
|||
ExecutionState(const ExecutionState&) = delete;
|
||||
ExecutionState& operator=(const ExecutionState&) = delete;
|
||||
|
||||
// Sets opaque state with a given type id and deleter. Returns an error if
|
||||
// state is already set.
|
||||
absl::Status Set(TypeId type_id, void* state, Deleter<void> deleter);
|
||||
// Sets opaque state with a given type id. Returns an error if state is
|
||||
// already set, or if type id is not supported as a state.
|
||||
absl::Status Set(TypeId type_id, void* state);
|
||||
|
||||
// Sets opaque state with a given type id and custom deleter. Returns an error
|
||||
// if state is already set, or if type id is not supported as a state.
|
||||
absl::Status Set(TypeId type_id, void* state, void (*deleter)(void*));
|
||||
|
||||
// Returns opaque state of the given type id. If set state type id does not
|
||||
// match the requested one, returns an error.
|
||||
|
|
@ -73,21 +75,23 @@ class ExecutionState {
|
|||
bool IsSet() const;
|
||||
|
||||
private:
|
||||
absl::Status Set(TypeId type_id, TypeInfo type_info, void* state);
|
||||
|
||||
TypeId type_id_;
|
||||
TypeInfo type_info_;
|
||||
void* state_;
|
||||
Deleter<void> deleter_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
absl::Status ExecutionState::Set(std::unique_ptr<T> state) {
|
||||
return Set(TypeIdRegistry::GetTypeId<T>(), state.release(),
|
||||
[](void* state) { delete reinterpret_cast<T*>(state); });
|
||||
return Set(TypeRegistry::GetTypeId<T>(), TypeRegistry::GetTypeInfo<T>(),
|
||||
state.release());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
absl::StatusOr<T*> ExecutionState::Get() const {
|
||||
TF_ASSIGN_OR_RETURN(void* state, Get(TypeIdRegistry::GetTypeId<T>()));
|
||||
return reinterpret_cast<T*>(state);
|
||||
TF_ASSIGN_OR_RETURN(void* state, Get(TypeRegistry::GetTypeId<T>()));
|
||||
return tsl::safe_reinterpret_cast<T*>(state);
|
||||
}
|
||||
|
||||
} // namespace xla::ffi
|
||||
|
|
|
|||
37
third_party/xla/xla/ffi/execution_state_test.cc
vendored
37
third_party/xla/xla/ffi/execution_state_test.cc
vendored
|
|
@ -20,9 +20,10 @@ limitations under the License.
|
|||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "xla/ffi/type_registry.h"
|
||||
#include "xla/tsl/lib/core/status_test_util.h"
|
||||
#include "tsl/platform/statusor.h"
|
||||
#include "tsl/platform/test.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
#include "xla/tsl/platform/test.h"
|
||||
|
||||
namespace xla::ffi {
|
||||
|
||||
|
|
@ -30,7 +31,7 @@ using TypeId = ExecutionState::TypeId;
|
|||
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
TEST(ExecutionStateTest, SetAndGet) {
|
||||
TEST(ExecutionStateTest, SetAndGetForInternalType) {
|
||||
ExecutionState state;
|
||||
EXPECT_FALSE(state.IsSet());
|
||||
|
||||
|
|
@ -52,4 +53,34 @@ TEST(ExecutionStateTest, SetAndGet) {
|
|||
EXPECT_EQ(*data, 42);
|
||||
}
|
||||
|
||||
TEST(ExecutionStateTest, SetAndGetForExternalType) {
|
||||
ExecutionState state;
|
||||
EXPECT_FALSE(state.IsSet());
|
||||
|
||||
{ // Empty state returns an error from Get().
|
||||
auto data = state.Get(TypeId(1));
|
||||
EXPECT_THAT(data.status().message(), HasSubstr("State is not set"));
|
||||
}
|
||||
|
||||
{ // Empty state returns an error from Get().
|
||||
auto data = state.Get<int32_t>();
|
||||
EXPECT_THAT(data.status().message(), HasSubstr("State is not set"));
|
||||
}
|
||||
|
||||
TypeRegistry::TypeInfo type_info = {
|
||||
[](void* ptr) { delete static_cast<int32_t*>(ptr); }};
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
TypeRegistry::TypeId type_id,
|
||||
TypeRegistry::AssignExternalTypeId("int32_t", type_info));
|
||||
|
||||
int32_t* value = new int32_t(42);
|
||||
|
||||
// Once set, state can be retrieved.
|
||||
TF_ASSERT_OK(state.Set(type_id, value));
|
||||
EXPECT_TRUE(state.IsSet());
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(void* data, state.Get(type_id));
|
||||
EXPECT_EQ(data, value);
|
||||
}
|
||||
|
||||
} // namespace xla::ffi
|
||||
|
|
|
|||
5
third_party/xla/xla/ffi/ffi.h
vendored
5
third_party/xla/xla/ffi/ffi.h
vendored
|
|
@ -37,6 +37,7 @@ limitations under the License.
|
|||
#include "absl/base/attributes.h"
|
||||
#include "absl/base/nullability.h"
|
||||
#include "absl/base/optimization.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
|
|
@ -46,7 +47,7 @@ limitations under the License.
|
|||
#include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep
|
||||
#include "xla/ffi/execution_context.h"
|
||||
#include "xla/ffi/execution_state.h"
|
||||
#include "xla/ffi/type_id_registry.h"
|
||||
#include "xla/ffi/type_registry.h"
|
||||
#include "xla/hlo/ir/hlo_computation.h"
|
||||
#include "xla/primitive_util.h"
|
||||
#include "xla/stream_executor/device_memory.h"
|
||||
|
|
@ -724,7 +725,7 @@ template <typename T>
|
|||
struct ResultEncoding<ExecutionStage::kInstantiate,
|
||||
absl::StatusOr<std::unique_ptr<T>>> {
|
||||
static XLA_FFI_TypeId state_type_id() {
|
||||
return XLA_FFI_TypeId{TypeIdRegistry::GetTypeId<T>().value()};
|
||||
return XLA_FFI_TypeId{TypeRegistry::GetTypeId<T>().value()};
|
||||
}
|
||||
|
||||
static XLA_FFI_Error* Encode(const XLA_FFI_Api* api,
|
||||
|
|
|
|||
34
third_party/xla/xla/ffi/ffi_api.cc
vendored
34
third_party/xla/xla/ffi/ffi_api.cc
vendored
|
|
@ -43,7 +43,7 @@ limitations under the License.
|
|||
#include "xla/ffi/call_frame.h"
|
||||
#include "xla/ffi/execution_context.h"
|
||||
#include "xla/ffi/execution_state.h"
|
||||
#include "xla/ffi/type_id_registry.h"
|
||||
#include "xla/ffi/type_registry.h"
|
||||
#include "xla/hlo/ir/hlo_computation.h"
|
||||
#include "xla/service/platform_util.h"
|
||||
#include "xla/stream_executor/device_memory.h"
|
||||
|
|
@ -659,12 +659,14 @@ static XLA_FFI_Error* XLA_FFI_TypeId_Register(
|
|||
XLA_FFI_ExecutionContext_Get_Args_STRUCT_SIZE, args->struct_size));
|
||||
|
||||
absl::string_view type_name(args->name.ptr, args->name.len);
|
||||
TypeIdRegistry::TypeId type_id(args->type_id->type_id);
|
||||
TypeRegistry::TypeId type_id(args->type_id->type_id);
|
||||
TypeRegistry::TypeInfo type_info = {args->type_info->deleter};
|
||||
|
||||
// If type_id is unknown, we are registering a new type and XLA will assign a
|
||||
// unique type id to it.
|
||||
if (type_id == TypeIdRegistry::kUnknownTypeId) {
|
||||
auto assigned_type_id = TypeIdRegistry::AssignExternalTypeId(type_name);
|
||||
if (type_id == TypeRegistry::kUnknownTypeId) {
|
||||
auto assigned_type_id =
|
||||
TypeRegistry::AssignExternalTypeId(type_name, type_info);
|
||||
if (!assigned_type_id.ok()) {
|
||||
return new XLA_FFI_Error{std::move(assigned_type_id).status()};
|
||||
}
|
||||
|
|
@ -674,9 +676,10 @@ static XLA_FFI_Error* XLA_FFI_TypeId_Register(
|
|||
}
|
||||
|
||||
// If type_id is set, we are relying on the caller-provided unique type id.
|
||||
if (auto status = TypeIdRegistry::RegisterExternalTypeId(type_name, type_id);
|
||||
!status.ok()) {
|
||||
return new XLA_FFI_Error{std::move(status)};
|
||||
auto registered_type_id =
|
||||
TypeRegistry::RegisterExternalTypeId(type_name, type_id, type_info);
|
||||
if (!registered_type_id.ok()) {
|
||||
return new XLA_FFI_Error{std::move(registered_type_id)};
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
|
|
@ -690,7 +693,7 @@ static XLA_FFI_Error* XLA_FFI_ExecutionContext_Get(
|
|||
|
||||
DCHECK(args->ctx->execution_context) << "ExecutionContext must be set";
|
||||
auto user_data = args->ctx->execution_context->Lookup(
|
||||
TypeIdRegistry::TypeId(args->type_id->type_id));
|
||||
TypeRegistry::TypeId(args->type_id->type_id));
|
||||
if (!user_data.ok()) {
|
||||
return new XLA_FFI_Error{std::move(user_data).status()};
|
||||
}
|
||||
|
|
@ -705,9 +708,16 @@ static XLA_FFI_Error* XLA_FFI_State_Set(XLA_FFI_State_Set_Args* args) {
|
|||
args->struct_size));
|
||||
|
||||
DCHECK(args->ctx->execution_state) << "ExecutionState must be set";
|
||||
absl::Status status = args->ctx->execution_state->Set(
|
||||
TypeIdRegistry::TypeId(args->type_id->type_id), args->state,
|
||||
[deleter = args->deleter](void* state) { deleter(state); });
|
||||
|
||||
absl::Status status;
|
||||
if (args->deleter == nullptr) {
|
||||
status = args->ctx->execution_state->Set(
|
||||
TypeRegistry::TypeId(args->type_id->type_id), args->state);
|
||||
} else {
|
||||
status = args->ctx->execution_state->Set(
|
||||
TypeRegistry::TypeId(args->type_id->type_id), args->state,
|
||||
args->deleter);
|
||||
}
|
||||
|
||||
if (!status.ok()) {
|
||||
return new XLA_FFI_Error{std::move(status)};
|
||||
|
|
@ -723,7 +733,7 @@ static XLA_FFI_Error* XLA_FFI_State_Get(XLA_FFI_State_Get_Args* args) {
|
|||
|
||||
DCHECK(args->ctx->execution_state) << "ExecutionState must be set";
|
||||
absl::StatusOr<void*> state = args->ctx->execution_state->Get(
|
||||
TypeIdRegistry::TypeId(args->type_id->type_id));
|
||||
TypeRegistry::TypeId(args->type_id->type_id));
|
||||
if (!state.ok()) {
|
||||
return new XLA_FFI_Error{std::move(state).status()};
|
||||
}
|
||||
|
|
|
|||
4
third_party/xla/xla/ffi/ffi_test.cc
vendored
4
third_party/xla/xla/ffi/ffi_test.cc
vendored
|
|
@ -39,7 +39,7 @@ limitations under the License.
|
|||
#include "xla/ffi/execution_context.h"
|
||||
#include "xla/ffi/execution_state.h"
|
||||
#include "xla/ffi/ffi_api.h"
|
||||
#include "xla/ffi/type_id_registry.h"
|
||||
#include "xla/ffi/type_registry.h"
|
||||
#include "xla/stream_executor/device_memory.h"
|
||||
#include "xla/stream_executor/stream.h"
|
||||
#include "xla/tsl/concurrency/async_value_ref.h"
|
||||
|
|
@ -1102,7 +1102,7 @@ TEST(FfiTest, Metadata) {
|
|||
EXPECT_EQ(metadata.api_version.major_version, XLA_FFI_API_MAJOR);
|
||||
EXPECT_EQ(metadata.api_version.minor_version, XLA_FFI_API_MINOR);
|
||||
|
||||
TypeIdRegistry::TypeId type_id = TypeIdRegistry::GetTypeId<StrState>();
|
||||
TypeRegistry::TypeId type_id = TypeRegistry::GetTypeId<StrState>();
|
||||
EXPECT_EQ(metadata.state_type_id.type_id, type_id);
|
||||
}
|
||||
|
||||
|
|
|
|||
96
third_party/xla/xla/ffi/type_id_registry.cc
vendored
96
third_party/xla/xla/ffi/type_id_registry.cc
vendored
|
|
@ -1,96 +0,0 @@
|
|||
/* Copyright 2024 The OpenXLA Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "xla/ffi/type_id_registry.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/base/attributes.h"
|
||||
#include "absl/base/const_init.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "xla/util.h"
|
||||
|
||||
namespace xla::ffi {
|
||||
|
||||
ABSL_CONST_INIT absl::Mutex type_registry_mutex(absl::kConstInit);
|
||||
|
||||
using ExternalTypeIdRegistry =
|
||||
absl::flat_hash_map<std::string, TypeIdRegistry::TypeId>;
|
||||
|
||||
static ExternalTypeIdRegistry& StaticExternalTypeIdRegistry() {
|
||||
static auto* const registry = new ExternalTypeIdRegistry();
|
||||
return *registry;
|
||||
}
|
||||
|
||||
TypeIdRegistry::TypeId TypeIdRegistry::GetNextInternalTypeId() {
|
||||
static auto* counter = new std::atomic<int64_t>(1);
|
||||
return TypeId(counter->fetch_add(1));
|
||||
}
|
||||
|
||||
TypeIdRegistry::TypeId TypeIdRegistry::GetNextExternalTypeId() {
|
||||
static auto* counter = new std::atomic<int64_t>(1);
|
||||
return TypeId(counter->fetch_add(1));
|
||||
}
|
||||
|
||||
absl::StatusOr<TypeIdRegistry::TypeId> TypeIdRegistry::AssignExternalTypeId(
|
||||
absl::string_view name) {
|
||||
absl::MutexLock lock(type_registry_mutex);
|
||||
auto& registry = StaticExternalTypeIdRegistry();
|
||||
|
||||
// Try to emplace with unknow type id and fill it with real type id only if we
|
||||
// successfully acquired an entry for a given name.
|
||||
auto emplaced = registry.emplace(name, kUnknownTypeId);
|
||||
if (!emplaced.second) {
|
||||
return Internal("Type name %s already registered with type id %d", name,
|
||||
emplaced.first->second.value());
|
||||
}
|
||||
|
||||
// Returns true if the registry contains an entry with a given type id.
|
||||
auto type_id_is_in_use = [®istry](TypeId type_id) {
|
||||
return absl::c_any_of(registry,
|
||||
[&](const auto& e) { return e.second == type_id; });
|
||||
};
|
||||
|
||||
// Create a new type id that is not already in use.
|
||||
TypeId type_id = GetNextExternalTypeId();
|
||||
while (type_id_is_in_use(type_id)) {
|
||||
type_id = GetNextExternalTypeId();
|
||||
}
|
||||
|
||||
return emplaced.first->second = type_id;
|
||||
}
|
||||
|
||||
absl::Status TypeIdRegistry::RegisterExternalTypeId(absl::string_view name,
|
||||
TypeId type_id) {
|
||||
absl::MutexLock lock(type_registry_mutex);
|
||||
auto& registry = StaticExternalTypeIdRegistry();
|
||||
|
||||
auto emplaced = registry.emplace(name, type_id);
|
||||
if (!emplaced.second && emplaced.first->second != type_id) {
|
||||
return Internal("Type name %s already registered with type id %d vs %d)",
|
||||
name, emplaced.first->second.value(), type_id.value());
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace xla::ffi
|
||||
63
third_party/xla/xla/ffi/type_id_registry_test.cc
vendored
63
third_party/xla/xla/ffi/type_id_registry_test.cc
vendored
|
|
@ -1,63 +0,0 @@
|
|||
/* Copyright 2024 The OpenXLA Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "xla/ffi/type_id_registry.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/status/status.h"
|
||||
#include "xla/tsl/lib/core/status_test_util.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
#include "xla/tsl/platform/test.h"
|
||||
|
||||
namespace xla::ffi {
|
||||
namespace {
|
||||
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
TEST(TypeIdRegistryTest, RegisterExternalTypeId) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto type_id,
|
||||
TypeIdRegistry::AssignExternalTypeId("foo"));
|
||||
EXPECT_GE(type_id.value(), 0);
|
||||
|
||||
auto duplicate_type_id = TypeIdRegistry::AssignExternalTypeId("foo");
|
||||
EXPECT_THAT(duplicate_type_id.status().message(),
|
||||
HasSubstr("Type name foo already registered with type id"));
|
||||
|
||||
// It's ok to register the same type with same type id.
|
||||
TF_ASSERT_OK(TypeIdRegistry::RegisterExternalTypeId("foo", type_id));
|
||||
|
||||
// It's an error to register the same type with a different type id.
|
||||
auto wrong_type_id = TypeIdRegistry::RegisterExternalTypeId(
|
||||
"foo", TypeIdRegistry::TypeId(std::numeric_limits<int64_t>::max()));
|
||||
EXPECT_THAT(wrong_type_id.message(),
|
||||
HasSubstr("Type name foo already registered with type id"));
|
||||
|
||||
// It's ok to register a new type with a user-provided type id.
|
||||
TF_ASSERT_OK(TypeIdRegistry::RegisterExternalTypeId(
|
||||
"bar", TypeIdRegistry::TypeId(std::numeric_limits<int64_t>::max())));
|
||||
}
|
||||
|
||||
TEST(TypeIdRegistryTest, RegisterInternalTypeId) {
|
||||
auto int32_type_id = TypeIdRegistry::GetTypeId<int32_t>();
|
||||
auto int64_type_id = TypeIdRegistry::GetTypeId<int64_t>();
|
||||
EXPECT_NE(int32_type_id, int64_type_id);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla::ffi
|
||||
134
third_party/xla/xla/ffi/type_registry.cc
vendored
Normal file
134
third_party/xla/xla/ffi/type_registry.cc
vendored
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
/* Copyright 2024 The OpenXLA Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "xla/ffi/type_registry.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/base/attributes.h"
|
||||
#include "absl/base/const_init.h"
|
||||
#include "absl/base/no_destructor.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "xla/util.h"
|
||||
|
||||
namespace xla::ffi {
|
||||
namespace {
|
||||
|
||||
struct TypeRegistration {
|
||||
TypeRegistry::TypeId type_id;
|
||||
TypeRegistry::TypeInfo type_info;
|
||||
};
|
||||
|
||||
using ExternalTypeRegistry = absl::flat_hash_map<std::string, TypeRegistration>;
|
||||
|
||||
} // namespace
|
||||
|
||||
ABSL_CONST_INIT absl::Mutex type_registry_mutex(absl::kConstInit);
|
||||
|
||||
static ExternalTypeRegistry& StaticExternalTypeRegistry() {
|
||||
static absl::NoDestructor<ExternalTypeRegistry> registry;
|
||||
return *registry;
|
||||
}
|
||||
|
||||
TypeRegistry::TypeId TypeRegistry::GetNextInternalTypeId() {
|
||||
static auto* counter = new std::atomic<int64_t>(1);
|
||||
return TypeId(counter->fetch_add(1));
|
||||
}
|
||||
|
||||
TypeRegistry::TypeId TypeRegistry::GetNextExternalTypeId() {
|
||||
static auto* counter = new std::atomic<int64_t>(1);
|
||||
return TypeId(counter->fetch_add(1));
|
||||
}
|
||||
|
||||
absl::StatusOr<TypeRegistry::TypeId> TypeRegistry::AssignExternalTypeId(
|
||||
absl::string_view name, TypeInfo type_info) {
|
||||
VLOG(3) << absl::StrFormat("Assign external type id: name=%s", name);
|
||||
|
||||
absl::MutexLock lock(type_registry_mutex);
|
||||
auto& registry = StaticExternalTypeRegistry();
|
||||
|
||||
// Try to emplace with unknow type id and fill it with real type id only if we
|
||||
// successfully acquired an entry for a given name.
|
||||
auto emplaced =
|
||||
registry.emplace(name, TypeRegistration{kUnknownTypeId, type_info});
|
||||
if (!emplaced.second) {
|
||||
return Internal("Type name %s already registered with type id %d", name,
|
||||
emplaced.first->second.type_id.value());
|
||||
}
|
||||
|
||||
// Returns true if the registry contains an entry with a given type id.
|
||||
auto type_id_is_in_use = [®istry](TypeId type_id) {
|
||||
return absl::c_any_of(
|
||||
registry, [&](const auto& e) { return e.second.type_id == type_id; });
|
||||
};
|
||||
|
||||
// Create a new type id that is not already in use.
|
||||
TypeId type_id = GetNextExternalTypeId();
|
||||
while (type_id_is_in_use(type_id)) {
|
||||
type_id = GetNextExternalTypeId();
|
||||
}
|
||||
|
||||
VLOG(3) << absl::StrFormat("Assigned external type id: name=%s type_id=%d",
|
||||
name, type_id.value());
|
||||
return emplaced.first->second.type_id = type_id;
|
||||
}
|
||||
|
||||
absl::Status TypeRegistry::RegisterExternalTypeId(absl::string_view name,
|
||||
TypeId type_id,
|
||||
TypeInfo type_info) {
|
||||
VLOG(3) << absl::StrFormat("Register external type id: name=%s type_id=%d",
|
||||
name, type_id.value());
|
||||
|
||||
absl::MutexLock lock(type_registry_mutex);
|
||||
auto& registry = StaticExternalTypeRegistry();
|
||||
|
||||
auto emplaced = registry.emplace(name, TypeRegistration{type_id, type_info});
|
||||
if (!emplaced.second && emplaced.first->second.type_id != type_id) {
|
||||
return Internal("Type name %s already registered with type id %d vs %d)",
|
||||
name, emplaced.first->second.type_id.value(),
|
||||
type_id.value());
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::StatusOr<TypeRegistry::TypeInfo> TypeRegistry::GetExternalTypeInfo(
|
||||
TypeId type_id) {
|
||||
absl::MutexLock lock(type_registry_mutex);
|
||||
auto& registry = StaticExternalTypeRegistry();
|
||||
|
||||
auto it = absl::c_find_if(registry, [&](const auto& kv) {
|
||||
auto& [name, registration] = kv;
|
||||
return registration.type_id == type_id;
|
||||
});
|
||||
|
||||
if (it == registry.end()) {
|
||||
return Internal("Type id %d is not registered with a static registry",
|
||||
type_id.value());
|
||||
}
|
||||
|
||||
return it->second.type_info;
|
||||
}
|
||||
|
||||
} // namespace xla::ffi
|
||||
|
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef XLA_FFI_TYPE_ID_REGISTRY_H_
|
||||
#define XLA_FFI_TYPE_ID_REGISTRY_H_
|
||||
#ifndef XLA_FFI_TYPE_REGISTRY_H_
|
||||
#define XLA_FFI_TYPE_REGISTRY_H_
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "xla/tsl/lib/gtl/int_type.h"
|
||||
#include "xla/tsl/util/safe_reinterpret_cast.h"
|
||||
|
||||
namespace xla::ffi {
|
||||
|
||||
|
|
@ -43,26 +44,50 @@ namespace xla::ffi {
|
|||
//
|
||||
// 2. Internal type id. When FFI handler defined in the same binary we rely
|
||||
// on a global static registry to automatically assign type ids.
|
||||
class TypeIdRegistry {
|
||||
//
|
||||
// TypeInfo defines a set of functions that allow XLA runtime to manipulate
|
||||
// external types. For user data, that is forwarded to FFI handlers, they all
|
||||
// can be `nullptr` as XLA runtime doesn't manage their lifetime. For stateful
|
||||
// handlers, XLA runtime at least must know how to destroy the state when XLA
|
||||
// executable is destroyed.
|
||||
class TypeRegistry {
|
||||
public:
|
||||
// Unique (within a process) identifier for a type.
|
||||
TSL_LIB_GTL_DEFINE_INT_TYPE(TypeId, int64_t);
|
||||
|
||||
static constexpr TypeId kUnknownTypeId = TypeId(0);
|
||||
|
||||
// Pointers to functions that allow XLA runtime to manipulate external types.
|
||||
struct TypeInfo {
|
||||
using Deleter = void (*)(void*);
|
||||
|
||||
Deleter deleter = nullptr;
|
||||
};
|
||||
|
||||
// Assigns a unique type id to an external type with a given name. Returns an
|
||||
// error if a type with a given name is already registered in the process.
|
||||
static absl::StatusOr<TypeId> AssignExternalTypeId(absl::string_view name);
|
||||
static absl::StatusOr<TypeId> AssignExternalTypeId(absl::string_view name,
|
||||
TypeInfo type_info);
|
||||
|
||||
// Registers external type with a given name and type id. Type id is provided
|
||||
// by the caller, and must be unique. Returns an error if a type with a given
|
||||
// name is already registered with a different type id.
|
||||
static absl::Status RegisterExternalTypeId(absl::string_view name,
|
||||
TypeId type_id);
|
||||
TypeId type_id,
|
||||
TypeInfo type_info);
|
||||
|
||||
// Returns type info for a given external type id. Returns an error if type
|
||||
// id is not registered.
|
||||
static absl::StatusOr<TypeInfo> GetExternalTypeInfo(TypeId type_id);
|
||||
|
||||
// Returns a type id for a given type. For internal type ids only.
|
||||
template <typename T>
|
||||
static TypeId GetTypeId();
|
||||
|
||||
// Returns type info for a given type id. For internal type ids only.
|
||||
template <typename T>
|
||||
static TypeInfo GetTypeInfo();
|
||||
|
||||
private:
|
||||
// We never mix external and internal type ids, so we can use different type
|
||||
// id spaces to assign unique ids to each type.
|
||||
|
|
@ -71,11 +96,18 @@ class TypeIdRegistry {
|
|||
};
|
||||
|
||||
template <typename T>
|
||||
TypeIdRegistry::TypeId TypeIdRegistry::GetTypeId() {
|
||||
TypeRegistry::TypeId TypeRegistry::GetTypeId() {
|
||||
static const TypeId id = GetNextInternalTypeId();
|
||||
return id;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
TypeRegistry::TypeInfo TypeRegistry::GetTypeInfo() {
|
||||
return TypeInfo{
|
||||
[](void* state) { delete tsl::safe_reinterpret_cast<T*>(state); },
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace xla::ffi
|
||||
|
||||
#endif // XLA_FFI_TYPE_ID_REGISTRY_H_
|
||||
#endif // XLA_FFI_TYPE_REGISTRY_H_
|
||||
85
third_party/xla/xla/ffi/type_registry_test.cc
vendored
Normal file
85
third_party/xla/xla/ffi/type_registry_test.cc
vendored
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
/* Copyright 2024 The OpenXLA Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "xla/ffi/type_registry.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/status/status.h"
|
||||
#include "xla/tsl/lib/core/status_test_util.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
#include "xla/tsl/platform/test.h"
|
||||
|
||||
namespace xla::ffi {
|
||||
namespace {
|
||||
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
TEST(TypeRegistryTest, RegisterExternalTypeId) {
|
||||
TypeRegistry::TypeInfo type_info = {+[](void* state) {}};
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto foo_id,
|
||||
TypeRegistry::AssignExternalTypeId("foo", type_info));
|
||||
EXPECT_GE(foo_id.value(), 0);
|
||||
|
||||
auto duplicate_foo_id = TypeRegistry::AssignExternalTypeId("foo", type_info);
|
||||
EXPECT_THAT(duplicate_foo_id.status().message(),
|
||||
HasSubstr("Type name foo already registered with type id"));
|
||||
|
||||
// It's ok to register the same type with same type id.
|
||||
TF_ASSERT_OK(TypeRegistry::RegisterExternalTypeId("foo", foo_id, type_info));
|
||||
|
||||
// It's an error to register the same type with a different type id.
|
||||
auto wrong_foo_id = TypeRegistry::RegisterExternalTypeId(
|
||||
"foo", TypeRegistry::TypeId(std::numeric_limits<int64_t>::max()),
|
||||
type_info);
|
||||
EXPECT_THAT(wrong_foo_id.message(),
|
||||
HasSubstr("Type name foo already registered with type id"));
|
||||
|
||||
// Registered type has a correct type info.
|
||||
TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeInfo foo_info,
|
||||
TypeRegistry::GetExternalTypeInfo(foo_id));
|
||||
EXPECT_EQ(foo_info.deleter, type_info.deleter);
|
||||
|
||||
// It's ok to register a new type with a user-provided type id.
|
||||
auto bar_id = TypeRegistry::TypeId(std::numeric_limits<int64_t>::max());
|
||||
TF_ASSERT_OK(TypeRegistry::RegisterExternalTypeId(
|
||||
"bar", TypeRegistry::TypeId(std::numeric_limits<int64_t>::max()),
|
||||
type_info));
|
||||
|
||||
// And a new type has a correct type info.
|
||||
TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeInfo bar_info,
|
||||
TypeRegistry::GetExternalTypeInfo(bar_id));
|
||||
EXPECT_EQ(bar_info.deleter, type_info.deleter);
|
||||
}
|
||||
|
||||
TEST(TypeRegistryTest, RegisterInternalTypeId) {
|
||||
auto int32_type_id = TypeRegistry::GetTypeId<int32_t>();
|
||||
auto int64_type_id = TypeRegistry::GetTypeId<int64_t>();
|
||||
EXPECT_NE(int32_type_id, int64_type_id);
|
||||
}
|
||||
|
||||
TEST(TypeRegistryTest, InternalTypeInfo) {
|
||||
int32_t* ptr = new int32_t{42};
|
||||
|
||||
TypeRegistry::TypeInfo type_info = TypeRegistry::GetTypeInfo<int32_t>();
|
||||
type_info.deleter(ptr);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla::ffi
|
||||
5
third_party/xla/xla/pjrt/c/BUILD
vendored
5
third_party/xla/xla/pjrt/c/BUILD
vendored
|
|
@ -82,9 +82,10 @@ cc_library(
|
|||
":pjrt_c_api_hdrs",
|
||||
":pjrt_c_api_helpers",
|
||||
":pjrt_c_api_wrapper_impl",
|
||||
"//xla/ffi",
|
||||
"//xla/ffi:execution_context",
|
||||
"//xla/ffi:ffi_api",
|
||||
"//xla/ffi:type_id_registry",
|
||||
"//xla/ffi:type_registry",
|
||||
"//xla/ffi/api:c_api",
|
||||
"//xla/ffi/api:ffi",
|
||||
"@com_google_absl//absl/status",
|
||||
|
|
@ -567,7 +568,7 @@ xla_test(
|
|||
"//xla/client:client_library",
|
||||
"//xla/ffi:execution_context",
|
||||
"//xla/ffi:ffi_api",
|
||||
"//xla/ffi:type_id_registry",
|
||||
"//xla/ffi:type_registry",
|
||||
"//xla/ffi/api:ffi",
|
||||
"//xla/pjrt:pjrt_common",
|
||||
"//xla/pjrt:pjrt_compiler",
|
||||
|
|
|
|||
|
|
@ -30,7 +30,24 @@ extern "C" {
|
|||
// and GPU backends it gives access to the XLA FFI internals.
|
||||
//
|
||||
// See: https://en.wikipedia.org/wiki/Foreign_function_interface
|
||||
#define PJRT_API_FFI_EXTENSION_VERSION 2
|
||||
#define PJRT_API_FFI_EXTENSION_VERSION 3
|
||||
|
||||
struct PJRT_FFI_Type_Info {
|
||||
void (*deleter)(void* object);
|
||||
void (*serialize)(); // placeholder for future use
|
||||
void (*deserialize)(); // placeholder for future use
|
||||
};
|
||||
|
||||
struct PJRT_FFI_Type_Register_Args {
|
||||
size_t struct_size;
|
||||
PJRT_Extension_Base* extension_start;
|
||||
|
||||
const char* type_name;
|
||||
size_t type_name_size;
|
||||
int64_t type_id; // in-out
|
||||
PJRT_FFI_Type_Info* type_info;
|
||||
};
|
||||
PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Type_Register_Args, type_info);
|
||||
|
||||
struct PJRT_FFI_TypeID_Register_Args {
|
||||
size_t struct_size;
|
||||
|
|
@ -46,6 +63,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_TypeID_Register_Args, type_id);
|
|||
// XLA will assign a unique type id to it and return via out argument, otherwise
|
||||
// it will verify that user-provided type id matches previously registered type
|
||||
// id for the given type name.
|
||||
typedef PJRT_Error* PJRT_FFI_Type_Register(PJRT_FFI_Type_Register_Args* args);
|
||||
typedef PJRT_Error* PJRT_FFI_TypeID_Register(
|
||||
PJRT_FFI_TypeID_Register_Args* args);
|
||||
|
||||
|
|
@ -94,8 +112,9 @@ typedef struct PJRT_FFI_Extension {
|
|||
PJRT_FFI_TypeID_Register* type_id_register;
|
||||
PJRT_FFI_UserData_Add* user_data_add;
|
||||
PJRT_FFI_Register_Handler* register_handler;
|
||||
PJRT_FFI_Type_Register* type_register;
|
||||
} PJRT_FFI;
|
||||
PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, register_handler);
|
||||
PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, type_register);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,8 +19,9 @@ limitations under the License.
|
|||
#include "absl/strings/string_view.h"
|
||||
#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/ffi/execution_context.h"
|
||||
#include "xla/ffi/ffi.h"
|
||||
#include "xla/ffi/ffi_api.h"
|
||||
#include "xla/ffi/type_id_registry.h"
|
||||
#include "xla/ffi/type_registry.h"
|
||||
#include "xla/pjrt/c/pjrt_c_api.h"
|
||||
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
|
||||
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
|
||||
|
|
@ -35,20 +36,48 @@ static PJRT_Error* PJRT_FFI_TypeID_Register(
|
|||
PJRT_FFI_TypeID_Register_Args_STRUCT_SIZE, args->struct_size));
|
||||
|
||||
absl::string_view type_name(args->type_name, args->type_name_size);
|
||||
xla::ffi::TypeIdRegistry::TypeId type_id(args->type_id);
|
||||
xla::ffi::TypeRegistry::TypeId type_id(args->type_id);
|
||||
|
||||
if (type_id == xla::ffi::TypeIdRegistry::kUnknownTypeId) {
|
||||
if (type_id == xla::ffi::TypeRegistry::kUnknownTypeId) {
|
||||
// If type_id is unknown, we are registering a new type and XLA will assign
|
||||
// a unique type id to it.
|
||||
PJRT_ASSIGN_OR_RETURN(
|
||||
auto assigned_type_id,
|
||||
xla::ffi::TypeIdRegistry::AssignExternalTypeId(type_name));
|
||||
xla::ffi::TypeRegistry::AssignExternalTypeId(type_name, {}));
|
||||
args->type_id = assigned_type_id.value();
|
||||
|
||||
} else {
|
||||
// If type_id is set, we are relying on the caller-provided unique type id.
|
||||
PJRT_RETURN_IF_ERROR(
|
||||
xla::ffi::TypeIdRegistry::RegisterExternalTypeId(type_name, type_id));
|
||||
xla::ffi::TypeRegistry::RegisterExternalTypeId(type_name, type_id, {}));
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static PJRT_Error* PJRT_FFI_Type_Register(PJRT_FFI_Type_Register_Args* args) {
|
||||
PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
|
||||
"PJRT_FFI_Type_Register_Args", PJRT_FFI_Type_Register_Args_STRUCT_SIZE,
|
||||
args->struct_size));
|
||||
|
||||
absl::string_view type_name(args->type_name, args->type_name_size);
|
||||
xla::ffi::TypeRegistry::TypeId type_id(args->type_id);
|
||||
xla::ffi::TypeRegistry::TypeInfo type_info = {
|
||||
args->type_info->deleter,
|
||||
};
|
||||
|
||||
if (type_id == xla::ffi::TypeRegistry::kUnknownTypeId) {
|
||||
// If type_id is unknown, we are registering a new type and XLA will assign
|
||||
// a unique type id to it.
|
||||
PJRT_ASSIGN_OR_RETURN(
|
||||
auto assigned_type_id,
|
||||
xla::ffi::TypeRegistry::AssignExternalTypeId(type_name, type_info));
|
||||
args->type_id = assigned_type_id.value();
|
||||
|
||||
} else {
|
||||
// If type_id is set, we are relying on the caller-provided unique type id.
|
||||
PJRT_RETURN_IF_ERROR(xla::ffi::TypeRegistry::RegisterExternalTypeId(
|
||||
type_name, type_id, type_info));
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
|
|
@ -64,7 +93,7 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) {
|
|||
"PJRT FFI extension requires execute context to be not nullptr")};
|
||||
}
|
||||
|
||||
xla::ffi::TypeIdRegistry::TypeId type_id(args->user_data.type_id);
|
||||
xla::ffi::TypeRegistry::TypeId type_id(args->user_data.type_id);
|
||||
PJRT_RETURN_IF_ERROR(args->context->execute_context->ffi_context().Insert(
|
||||
type_id, args->user_data.data, args->user_data.deleter));
|
||||
return nullptr;
|
||||
|
|
@ -102,6 +131,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
|
|||
/*type_id_register=*/PJRT_FFI_TypeID_Register,
|
||||
/*user_data_add=*/PJRT_FFI_UserData_Add,
|
||||
/*register_handler=*/PJRT_FFI_Register_Handler,
|
||||
/*type_register=*/PJRT_FFI_Type_Register,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ limitations under the License.
|
|||
#include "xla/ffi/api/ffi.h"
|
||||
#include "xla/ffi/execution_context.h"
|
||||
#include "xla/ffi/ffi_api.h"
|
||||
#include "xla/ffi/type_id_registry.h"
|
||||
#include "xla/ffi/type_registry.h"
|
||||
#include "xla/future.h"
|
||||
#include "xla/literal.h"
|
||||
#include "xla/literal_util.h"
|
||||
|
|
@ -366,7 +366,7 @@ TEST_F(PjrtCApiGpuTest, CreateAndDestroyExecuteContext) {
|
|||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto lookup_user_data,
|
||||
create_arg.context->execute_context->ffi_context().Lookup(
|
||||
xla::ffi::TypeIdRegistry::TypeId(42)));
|
||||
xla::ffi::TypeRegistry::TypeId(42)));
|
||||
EXPECT_EQ(lookup_user_data, &string_data);
|
||||
|
||||
PJRT_ExecuteContext_Destroy_Args destroy_args;
|
||||
|
|
|
|||
2
third_party/xla/xla/python/pjrt_ifrt/BUILD
vendored
2
third_party/xla/xla/python/pjrt_ifrt/BUILD
vendored
|
|
@ -324,7 +324,7 @@ cc_library(
|
|||
"//xla:util",
|
||||
"//xla:xla_data_proto_cc",
|
||||
"//xla/ffi:execution_context",
|
||||
"//xla/ffi:type_id_registry",
|
||||
"//xla/ffi:type_registry",
|
||||
"//xla/hlo/ir:hlo",
|
||||
"//xla/hlo/translate/mhlo_to_hlo:type_to_shape",
|
||||
"//xla/pjrt:host_callback",
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ limitations under the License.
|
|||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "xla/ffi/execution_context.h"
|
||||
#include "xla/ffi/type_id_registry.h"
|
||||
#include "xla/ffi/type_registry.h"
|
||||
#include "xla/hlo/ir/hlo_sharding.h"
|
||||
#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h"
|
||||
#include "xla/layout.h"
|
||||
|
|
@ -768,8 +768,7 @@ PjRtLoadedExecutable::Execute(absl::Span<ArrayRef> args,
|
|||
}
|
||||
ffi_callbacks->callbacks = callbacks->data();
|
||||
ffi_callbacks->num_callbacks = callbacks->size();
|
||||
auto type_id = xla::ffi::TypeIdRegistry::TypeId(
|
||||
xla::FfiLoadedHostCallbacks::id.type_id);
|
||||
ffi::TypeRegistry::TypeId type_id(FfiLoadedHostCallbacks::id.type_id);
|
||||
CHECK_OK(context->ffi_context().Insert(type_id, ffi_callbacks.get()));
|
||||
opts.context = context.get();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -351,7 +351,7 @@ NB_MODULE(py_hlo_multihost_runner, m) {
|
|||
m.def("custom_call_targets", GetRegisteredCustomCallTargets,
|
||||
nb::arg("platform"));
|
||||
m.def(
|
||||
"register_custom_type_id",
|
||||
"register_custom_type",
|
||||
[](absl::string_view type_name, nb::object type_id) {
|
||||
xla::ThrowIfError(RegisterCustomTypeId(type_name, type_id));
|
||||
},
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user