diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD index 2e3c8825e84..09bbeaa657c 100644 --- a/third_party/xla/xla/ffi/BUILD +++ b/third_party/xla/xla/ffi/BUILD @@ -62,7 +62,7 @@ cc_library( srcs = ["execution_context.cc"], hdrs = ["execution_context.h"], deps = [ - ":type_id_registry", + ":type_registry", "//xla:util", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", @@ -79,7 +79,7 @@ xla_cc_test( srcs = ["execution_context_test.cc"], deps = [ ":execution_context", - ":type_id_registry", + ":type_registry", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", @@ -94,13 +94,14 @@ cc_library( srcs = ["execution_state.cc"], hdrs = ["execution_state.h"], deps = [ - ":type_id_registry", + ":type_registry", "//xla:util", + "//xla/tsl/platform:statusor", + "//xla/tsl/util:safe_reinterpret_cast", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", ], ) @@ -109,11 +110,12 @@ xla_cc_test( srcs = ["execution_state_test.cc"], deps = [ ":execution_state", + ":type_registry", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", ], ) @@ -124,7 +126,7 @@ cc_library( ":api", ":execution_context", ":execution_state", - ":type_id_registry", + ":type_registry", "//xla:executable_run_options", "//xla:shape_util", "//xla:types", @@ -141,6 +143,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -160,7 +163,7 @@ cc_library( ":call_frame", ":execution_context", ":execution_state", - ":type_id_registry", + ":type_registry", "//xla:executable_run_options", "//xla:util", "//xla/ffi/api:c_api", @@ -218,7 +221,7 @@ xla_cc_test( ":execution_state", ":ffi", ":ffi_api", - ":type_id_registry", + ":type_registry", "//xla:executable_run_options", "//xla:xla_data_proto_cc", "//xla/ffi/api:c_api", @@ -243,27 +246,31 @@ xla_cc_test( ) cc_library( - name = "type_id_registry", - srcs = ["type_id_registry.cc"], - hdrs = ["type_id_registry.h"], + name = "type_registry", + srcs = ["type_registry.cc"], + hdrs = ["type_registry.h"], deps = [ "//xla:util", "//xla/tsl/lib/gtl:int_type", + "//xla/tsl/util:safe_reinterpret_cast", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", ], ) xla_cc_test( - name = "type_id_registry_test", - srcs = ["type_id_registry_test.cc"], + name = "type_registry_test", + srcs = ["type_registry_test.cc"], deps = [ - ":type_id_registry", + ":type_registry", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", diff --git a/third_party/xla/xla/ffi/api/BUILD b/third_party/xla/xla/ffi/api/BUILD index 5b48e05c8f3..30f8fcddf3b 100644 --- a/third_party/xla/xla/ffi/api/BUILD +++ b/third_party/xla/xla/ffi/api/BUILD @@ -89,7 +89,7 @@ xla_cc_test( "//xla/ffi:execution_context", "//xla/ffi:execution_state", "//xla/ffi:ffi_api", - "//xla/ffi:type_id_registry", + "//xla/ffi:type_registry", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "//xla/tsl/concurrency:async_value", diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h index caa8b94e481..ce4e6e78944 100644 --- a/third_party/xla/xla/ffi/api/api.h +++ b/third_party/xla/xla/ffi/api/api.h @@ -348,11 +348,13 @@ inline XLA_FFI_Error* Ffi::RegisterTypeId(const XLA_FFI_Api* api, std::string_view name, XLA_FFI_TypeId* type_id, XLA_FFI_TypeInfo type_info) { + assert(type_id && "type_id must not be null"); XLA_FFI_TypeId_Register_Args args; args.struct_size = XLA_FFI_TypeId_Register_Args_STRUCT_SIZE; args.extension_start = nullptr; args.name = XLA_FFI_ByteSpan{name.data(), name.size()}; args.type_id = type_id; + args.type_info = &type_info; return api->XLA_FFI_TypeId_Register(&args); } diff --git a/third_party/xla/xla/ffi/api/c_api.h b/third_party/xla/xla/ffi/api/c_api.h index 229f31866b9..58dd228f0b2 100644 --- a/third_party/xla/xla/ffi/api/c_api.h +++ b/third_party/xla/xla/ffi/api/c_api.h @@ -67,7 +67,7 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Extension_Base, next); // * Deleting a method or argument // * Changing the type of an argument // * Rearranging fields in the XLA_FFI_Api or argument structs -#define XLA_FFI_API_MAJOR 0 +#define XLA_FFI_API_MAJOR 1 // Incremented when the interface is updated in a way that is potentially // ABI-compatible with older versions, if supported by the caller and/or @@ -82,7 +82,7 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Extension_Base, next); // Minor changes include: // * Adding a new field to the XLA_FFI_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define XLA_FFI_API_MINOR 1 +#define XLA_FFI_API_MINOR 0 struct XLA_FFI_Api_Version { size_t struct_size; @@ -491,6 +491,7 @@ struct XLA_FFI_TypeId_Register_Args { XLA_FFI_ByteSpan name; XLA_FFI_TypeId* type_id; // in-out + XLA_FFI_TypeInfo* type_info; }; XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_TypeId_Register_Args, type_id); diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index bc98f6b9db5..0a3f0dbf9c1 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -437,7 +437,7 @@ class CountDownPromise { assert(state_->count.load() >= count && "Invalid count down value"); if (XLA_FFI_PREDICT_FALSE(!error.success())) { - const std::lock_guard lock(state_->mutex); + std::lock_guard lock(state_->mutex); // NOLINT state_->is_error.store(true, std::memory_order_release); state_->error = error; } @@ -448,7 +448,7 @@ class CountDownPromise { bool is_error = state_->is_error.load(std::memory_order_acquire); if (XLA_FFI_PREDICT_FALSE(is_error)) { auto take_error = [&] { - const std::lock_guard lock(state_->mutex); + std::lock_guard lock(state_->mutex); // NOLINT return state_->error; }; state_->promise.SetError(take_error()); @@ -476,7 +476,7 @@ class CountDownPromise { std::atomic count; std::atomic is_error; - std::mutex mutex; + std::mutex mutex; // NOLINT Error error; }; diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc index dce80f46aa9..aa436822a43 100644 --- a/third_party/xla/xla/ffi/api/ffi_test.cc +++ b/third_party/xla/xla/ffi/api/ffi_test.cc @@ -39,7 +39,7 @@ limitations under the License. #include "xla/ffi/execution_context.h" #include "xla/ffi/execution_state.h" #include "xla/ffi/ffi_api.h" -#include "xla/ffi/type_id_registry.h" +#include "xla/ffi/type_registry.h" #include "xla/primitive_util.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" @@ -1220,9 +1220,9 @@ TEST(FfiTest, UserData) { ExecutionContext execution_context; TF_ASSERT_OK(execution_context.Insert( - TypeIdRegistry::TypeId(MyDataWithAutoTypeId::id.type_id), &data0)); + TypeRegistry::TypeId(MyDataWithAutoTypeId::id.type_id), &data0)); TF_ASSERT_OK(execution_context.Insert( - TypeIdRegistry::TypeId(MyDataWithExplicitTypeId::id.type_id), &data1)); + TypeRegistry::TypeId(MyDataWithExplicitTypeId::id.type_id), &data1)); CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0); auto call_frame = builder.Build(); diff --git a/third_party/xla/xla/ffi/execution_context.h b/third_party/xla/xla/ffi/execution_context.h index f6c76fec84e..827c85cdc45 100644 --- a/third_party/xla/xla/ffi/execution_context.h +++ b/third_party/xla/xla/ffi/execution_context.h @@ -25,7 +25,7 @@ limitations under the License. #include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "xla/ffi/type_id_registry.h" +#include "xla/ffi/type_registry.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" @@ -45,7 +45,7 @@ namespace xla::ffi { // unique between separate calls to XLA execute. class ExecutionContext { public: - using TypeId = TypeIdRegistry::TypeId; + using TypeId = TypeRegistry::TypeId; template using Deleter = std::function; @@ -67,7 +67,7 @@ class ExecutionContext { template absl::StatusOr Lookup() const { TF_ASSIGN_OR_RETURN(auto user_data, - LookupUserData(TypeIdRegistry::GetTypeId())); + LookupUserData(TypeRegistry::GetTypeId())); return static_cast(user_data->data()); } @@ -110,7 +110,7 @@ class ExecutionContext { template absl::Status ExecutionContext::Insert(T* data, Deleter deleter) { - return InsertUserData(TypeIdRegistry::GetTypeId(), + return InsertUserData(TypeRegistry::GetTypeId(), std::make_unique( data, [deleter = std::move(deleter)](void* data) { if (deleter) deleter(static_cast(data)); @@ -119,7 +119,7 @@ absl::Status ExecutionContext::Insert(T* data, Deleter deleter) { template absl::Status ExecutionContext::Emplace(Args&&... args) { - return InsertUserData(TypeIdRegistry::GetTypeId(), + return InsertUserData(TypeRegistry::GetTypeId(), std::make_unique( new T(std::forward(args)...), [](void* data) { delete static_cast(data); })); diff --git a/third_party/xla/xla/ffi/execution_context_test.cc b/third_party/xla/xla/ffi/execution_context_test.cc index 31439ff9562..11739ff8178 100644 --- a/third_party/xla/xla/ffi/execution_context_test.cc +++ b/third_party/xla/xla/ffi/execution_context_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "absl/status/status.h" -#include "xla/ffi/type_id_registry.h" +#include "xla/ffi/type_registry.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" @@ -62,8 +62,9 @@ TEST(ExecutionContextTest, InsertUserOwned) { } TEST(ExecutionContextTest, InsertUserOwnedWithTypeId) { - TF_ASSERT_OK_AND_ASSIGN(TypeIdRegistry::TypeId type_id, - TypeIdRegistry::AssignExternalTypeId("I32UserData")); + TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeId type_id, + TypeRegistry::AssignExternalTypeId( + "I32UserData", TypeRegistry::TypeInfo{})); I32UserData user_data(42); diff --git a/third_party/xla/xla/ffi/execution_state.cc b/third_party/xla/xla/ffi/execution_state.cc index 5aab4a7a3a5..87916729ef7 100644 --- a/third_party/xla/xla/ffi/execution_state.cc +++ b/third_party/xla/xla/ffi/execution_state.cc @@ -15,38 +15,54 @@ limitations under the License. #include "xla/ffi/execution_state.h" -#include - +#include "absl/base/attributes.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "xla/ffi/type_id_registry.h" +#include "xla/ffi/type_registry.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" -#include "tsl/platform/logging.h" namespace xla::ffi { ExecutionState::ExecutionState() - : type_id_(TypeIdRegistry::kUnknownTypeId), - state_(nullptr), - deleter_(nullptr) {} + : type_id_(TypeRegistry::kUnknownTypeId), state_(nullptr) {} ExecutionState::~ExecutionState() { - if (deleter_) deleter_(state_); + if (type_info_.deleter) { + type_info_.deleter(state_); + } } -absl::Status ExecutionState::Set(TypeId type_id, void* state, - Deleter deleter) { - DCHECK(state && deleter) << "State and deleter must not be null"; +absl::Status ExecutionState::Set(TypeId type_id, void* state) { + TF_ASSIGN_OR_RETURN(auto type_info, + TypeRegistry::GetExternalTypeInfo(type_id)); + if (type_info.deleter == nullptr) { + return InvalidArgument( + "Type id %d does not have a registered type info with a deleter", + type_id.value()); + } + return Set(type_id, type_info, state); +} - if (type_id_ != TypeIdRegistry::kUnknownTypeId) { +ABSL_DEPRECATED("FFI users must rely in TypeInfo registration") +absl::Status ExecutionState::Set(TypeId type_id, void* state, + void (*deleter)(void*)) { + return Set(type_id, TypeInfo{deleter}, state); +} + +absl::Status ExecutionState::Set(TypeId type_id, TypeInfo type_info, + void* state) { + DCHECK(state && type_info.deleter) << "State and deleter must not be null"; + + if (type_id_ != TypeRegistry::kUnknownTypeId) { return FailedPrecondition("State is already set with a type id %d", type_id_.value()); } type_id_ = type_id; + type_info_ = type_info; state_ = state; - deleter_ = std::move(deleter); return absl::OkStatus(); } @@ -54,7 +70,7 @@ absl::Status ExecutionState::Set(TypeId type_id, void* state, // Returns opaque state of the given type id. If set state type id does not // match the requested one, returns an error. absl::StatusOr ExecutionState::Get(TypeId type_id) const { - if (type_id_ == TypeIdRegistry::kUnknownTypeId) { + if (type_id_ == TypeRegistry::kUnknownTypeId) { return NotFound("State is not set"); } @@ -68,7 +84,7 @@ absl::StatusOr ExecutionState::Get(TypeId type_id) const { } bool ExecutionState::IsSet() const { - return type_id_ != TypeIdRegistry::kUnknownTypeId; + return type_id_ != TypeRegistry::kUnknownTypeId; } } // namespace xla::ffi diff --git a/third_party/xla/xla/ffi/execution_state.h b/third_party/xla/xla/ffi/execution_state.h index f2f263344aa..3be4d2339e8 100644 --- a/third_party/xla/xla/ffi/execution_state.h +++ b/third_party/xla/xla/ffi/execution_state.h @@ -16,13 +16,13 @@ limitations under the License. #ifndef XLA_FFI_EXECUTION_STATE_H_ #define XLA_FFI_EXECUTION_STATE_H_ -#include #include #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "xla/ffi/type_id_registry.h" -#include "tsl/platform/statusor.h" +#include "xla/ffi/type_registry.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/util/safe_reinterpret_cast.h" namespace xla::ffi { @@ -41,10 +41,8 @@ namespace xla::ffi { // class ExecutionState { public: - using TypeId = TypeIdRegistry::TypeId; - - template - using Deleter = std::function; + using TypeId = TypeRegistry::TypeId; + using TypeInfo = TypeRegistry::TypeInfo; ExecutionState(); ~ExecutionState(); @@ -52,9 +50,13 @@ class ExecutionState { ExecutionState(const ExecutionState&) = delete; ExecutionState& operator=(const ExecutionState&) = delete; - // Sets opaque state with a given type id and deleter. Returns an error if - // state is already set. - absl::Status Set(TypeId type_id, void* state, Deleter deleter); + // Sets opaque state with a given type id. Returns an error if state is + // already set, or if type id is not supported as a state. + absl::Status Set(TypeId type_id, void* state); + + // Sets opaque state with a given type id and custom deleter. Returns an error + // if state is already set, or if type id is not supported as a state. + absl::Status Set(TypeId type_id, void* state, void (*deleter)(void*)); // Returns opaque state of the given type id. If set state type id does not // match the requested one, returns an error. @@ -73,21 +75,23 @@ class ExecutionState { bool IsSet() const; private: + absl::Status Set(TypeId type_id, TypeInfo type_info, void* state); + TypeId type_id_; + TypeInfo type_info_; void* state_; - Deleter deleter_; }; template absl::Status ExecutionState::Set(std::unique_ptr state) { - return Set(TypeIdRegistry::GetTypeId(), state.release(), - [](void* state) { delete reinterpret_cast(state); }); + return Set(TypeRegistry::GetTypeId(), TypeRegistry::GetTypeInfo(), + state.release()); } template absl::StatusOr ExecutionState::Get() const { - TF_ASSIGN_OR_RETURN(void* state, Get(TypeIdRegistry::GetTypeId())); - return reinterpret_cast(state); + TF_ASSIGN_OR_RETURN(void* state, Get(TypeRegistry::GetTypeId())); + return tsl::safe_reinterpret_cast(state); } } // namespace xla::ffi diff --git a/third_party/xla/xla/ffi/execution_state_test.cc b/third_party/xla/xla/ffi/execution_state_test.cc index d32c80f6d92..f3183b0bdb0 100644 --- a/third_party/xla/xla/ffi/execution_state_test.cc +++ b/third_party/xla/xla/ffi/execution_state_test.cc @@ -20,9 +20,10 @@ limitations under the License. #include #include +#include "xla/ffi/type_registry.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla::ffi { @@ -30,7 +31,7 @@ using TypeId = ExecutionState::TypeId; using ::testing::HasSubstr; -TEST(ExecutionStateTest, SetAndGet) { +TEST(ExecutionStateTest, SetAndGetForInternalType) { ExecutionState state; EXPECT_FALSE(state.IsSet()); @@ -52,4 +53,34 @@ TEST(ExecutionStateTest, SetAndGet) { EXPECT_EQ(*data, 42); } +TEST(ExecutionStateTest, SetAndGetForExternalType) { + ExecutionState state; + EXPECT_FALSE(state.IsSet()); + + { // Empty state returns an error from Get(). + auto data = state.Get(TypeId(1)); + EXPECT_THAT(data.status().message(), HasSubstr("State is not set")); + } + + { // Empty state returns an error from Get(). + auto data = state.Get(); + EXPECT_THAT(data.status().message(), HasSubstr("State is not set")); + } + + TypeRegistry::TypeInfo type_info = { + [](void* ptr) { delete static_cast(ptr); }}; + TF_ASSERT_OK_AND_ASSIGN( + TypeRegistry::TypeId type_id, + TypeRegistry::AssignExternalTypeId("int32_t", type_info)); + + int32_t* value = new int32_t(42); + + // Once set, state can be retrieved. + TF_ASSERT_OK(state.Set(type_id, value)); + EXPECT_TRUE(state.IsSet()); + + TF_ASSERT_OK_AND_ASSIGN(void* data, state.Get(type_id)); + EXPECT_EQ(data, value); +} + } // namespace xla::ffi diff --git a/third_party/xla/xla/ffi/ffi.h b/third_party/xla/xla/ffi/ffi.h index 7a61cd96990..39eab7724e7 100644 --- a/third_party/xla/xla/ffi/ffi.h +++ b/third_party/xla/xla/ffi/ffi.h @@ -37,6 +37,7 @@ limitations under the License. #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/base/optimization.h" +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -46,7 +47,7 @@ limitations under the License. #include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep #include "xla/ffi/execution_context.h" #include "xla/ffi/execution_state.h" -#include "xla/ffi/type_id_registry.h" +#include "xla/ffi/type_registry.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/primitive_util.h" #include "xla/stream_executor/device_memory.h" @@ -724,7 +725,7 @@ template struct ResultEncoding>> { static XLA_FFI_TypeId state_type_id() { - return XLA_FFI_TypeId{TypeIdRegistry::GetTypeId().value()}; + return XLA_FFI_TypeId{TypeRegistry::GetTypeId().value()}; } static XLA_FFI_Error* Encode(const XLA_FFI_Api* api, diff --git a/third_party/xla/xla/ffi/ffi_api.cc b/third_party/xla/xla/ffi/ffi_api.cc index d064953a106..39244ff3ce3 100644 --- a/third_party/xla/xla/ffi/ffi_api.cc +++ b/third_party/xla/xla/ffi/ffi_api.cc @@ -43,7 +43,7 @@ limitations under the License. #include "xla/ffi/call_frame.h" #include "xla/ffi/execution_context.h" #include "xla/ffi/execution_state.h" -#include "xla/ffi/type_id_registry.h" +#include "xla/ffi/type_registry.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/device_memory.h" @@ -659,12 +659,14 @@ static XLA_FFI_Error* XLA_FFI_TypeId_Register( XLA_FFI_ExecutionContext_Get_Args_STRUCT_SIZE, args->struct_size)); absl::string_view type_name(args->name.ptr, args->name.len); - TypeIdRegistry::TypeId type_id(args->type_id->type_id); + TypeRegistry::TypeId type_id(args->type_id->type_id); + TypeRegistry::TypeInfo type_info = {args->type_info->deleter}; // If type_id is unknown, we are registering a new type and XLA will assign a // unique type id to it. - if (type_id == TypeIdRegistry::kUnknownTypeId) { - auto assigned_type_id = TypeIdRegistry::AssignExternalTypeId(type_name); + if (type_id == TypeRegistry::kUnknownTypeId) { + auto assigned_type_id = + TypeRegistry::AssignExternalTypeId(type_name, type_info); if (!assigned_type_id.ok()) { return new XLA_FFI_Error{std::move(assigned_type_id).status()}; } @@ -674,9 +676,10 @@ static XLA_FFI_Error* XLA_FFI_TypeId_Register( } // If type_id is set, we are relying on the caller-provided unique type id. - if (auto status = TypeIdRegistry::RegisterExternalTypeId(type_name, type_id); - !status.ok()) { - return new XLA_FFI_Error{std::move(status)}; + auto registered_type_id = + TypeRegistry::RegisterExternalTypeId(type_name, type_id, type_info); + if (!registered_type_id.ok()) { + return new XLA_FFI_Error{std::move(registered_type_id)}; } return nullptr; @@ -690,7 +693,7 @@ static XLA_FFI_Error* XLA_FFI_ExecutionContext_Get( DCHECK(args->ctx->execution_context) << "ExecutionContext must be set"; auto user_data = args->ctx->execution_context->Lookup( - TypeIdRegistry::TypeId(args->type_id->type_id)); + TypeRegistry::TypeId(args->type_id->type_id)); if (!user_data.ok()) { return new XLA_FFI_Error{std::move(user_data).status()}; } @@ -705,9 +708,16 @@ static XLA_FFI_Error* XLA_FFI_State_Set(XLA_FFI_State_Set_Args* args) { args->struct_size)); DCHECK(args->ctx->execution_state) << "ExecutionState must be set"; - absl::Status status = args->ctx->execution_state->Set( - TypeIdRegistry::TypeId(args->type_id->type_id), args->state, - [deleter = args->deleter](void* state) { deleter(state); }); + + absl::Status status; + if (args->deleter == nullptr) { + status = args->ctx->execution_state->Set( + TypeRegistry::TypeId(args->type_id->type_id), args->state); + } else { + status = args->ctx->execution_state->Set( + TypeRegistry::TypeId(args->type_id->type_id), args->state, + args->deleter); + } if (!status.ok()) { return new XLA_FFI_Error{std::move(status)}; @@ -723,7 +733,7 @@ static XLA_FFI_Error* XLA_FFI_State_Get(XLA_FFI_State_Get_Args* args) { DCHECK(args->ctx->execution_state) << "ExecutionState must be set"; absl::StatusOr state = args->ctx->execution_state->Get( - TypeIdRegistry::TypeId(args->type_id->type_id)); + TypeRegistry::TypeId(args->type_id->type_id)); if (!state.ok()) { return new XLA_FFI_Error{std::move(state).status()}; } diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc index de2481dbad7..47b3a85fa1a 100644 --- a/third_party/xla/xla/ffi/ffi_test.cc +++ b/third_party/xla/xla/ffi/ffi_test.cc @@ -39,7 +39,7 @@ limitations under the License. #include "xla/ffi/execution_context.h" #include "xla/ffi/execution_state.h" #include "xla/ffi/ffi_api.h" -#include "xla/ffi/type_id_registry.h" +#include "xla/ffi/type_registry.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -1102,7 +1102,7 @@ TEST(FfiTest, Metadata) { EXPECT_EQ(metadata.api_version.major_version, XLA_FFI_API_MAJOR); EXPECT_EQ(metadata.api_version.minor_version, XLA_FFI_API_MINOR); - TypeIdRegistry::TypeId type_id = TypeIdRegistry::GetTypeId(); + TypeRegistry::TypeId type_id = TypeRegistry::GetTypeId(); EXPECT_EQ(metadata.state_type_id.type_id, type_id); } diff --git a/third_party/xla/xla/ffi/type_id_registry.cc b/third_party/xla/xla/ffi/type_id_registry.cc deleted file mode 100644 index 7fdcfe66e2f..00000000000 --- a/third_party/xla/xla/ffi/type_id_registry.cc +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/ffi/type_id_registry.h" - -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/base/attributes.h" -#include "absl/base/const_init.h" -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "xla/util.h" - -namespace xla::ffi { - -ABSL_CONST_INIT absl::Mutex type_registry_mutex(absl::kConstInit); - -using ExternalTypeIdRegistry = - absl::flat_hash_map; - -static ExternalTypeIdRegistry& StaticExternalTypeIdRegistry() { - static auto* const registry = new ExternalTypeIdRegistry(); - return *registry; -} - -TypeIdRegistry::TypeId TypeIdRegistry::GetNextInternalTypeId() { - static auto* counter = new std::atomic(1); - return TypeId(counter->fetch_add(1)); -} - -TypeIdRegistry::TypeId TypeIdRegistry::GetNextExternalTypeId() { - static auto* counter = new std::atomic(1); - return TypeId(counter->fetch_add(1)); -} - -absl::StatusOr TypeIdRegistry::AssignExternalTypeId( - absl::string_view name) { - absl::MutexLock lock(type_registry_mutex); - auto& registry = StaticExternalTypeIdRegistry(); - - // Try to emplace with unknow type id and fill it with real type id only if we - // successfully acquired an entry for a given name. - auto emplaced = registry.emplace(name, kUnknownTypeId); - if (!emplaced.second) { - return Internal("Type name %s already registered with type id %d", name, - emplaced.first->second.value()); - } - - // Returns true if the registry contains an entry with a given type id. - auto type_id_is_in_use = [®istry](TypeId type_id) { - return absl::c_any_of(registry, - [&](const auto& e) { return e.second == type_id; }); - }; - - // Create a new type id that is not already in use. - TypeId type_id = GetNextExternalTypeId(); - while (type_id_is_in_use(type_id)) { - type_id = GetNextExternalTypeId(); - } - - return emplaced.first->second = type_id; -} - -absl::Status TypeIdRegistry::RegisterExternalTypeId(absl::string_view name, - TypeId type_id) { - absl::MutexLock lock(type_registry_mutex); - auto& registry = StaticExternalTypeIdRegistry(); - - auto emplaced = registry.emplace(name, type_id); - if (!emplaced.second && emplaced.first->second != type_id) { - return Internal("Type name %s already registered with type id %d vs %d)", - name, emplaced.first->second.value(), type_id.value()); - } - - return absl::OkStatus(); -} - -} // namespace xla::ffi diff --git a/third_party/xla/xla/ffi/type_id_registry_test.cc b/third_party/xla/xla/ffi/type_id_registry_test.cc deleted file mode 100644 index 7e555291afb..00000000000 --- a/third_party/xla/xla/ffi/type_id_registry_test.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/ffi/type_id_registry.h" - -#include -#include - -#include -#include -#include "absl/status/status.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/tsl/platform/test.h" - -namespace xla::ffi { -namespace { - -using ::testing::HasSubstr; - -TEST(TypeIdRegistryTest, RegisterExternalTypeId) { - TF_ASSERT_OK_AND_ASSIGN(auto type_id, - TypeIdRegistry::AssignExternalTypeId("foo")); - EXPECT_GE(type_id.value(), 0); - - auto duplicate_type_id = TypeIdRegistry::AssignExternalTypeId("foo"); - EXPECT_THAT(duplicate_type_id.status().message(), - HasSubstr("Type name foo already registered with type id")); - - // It's ok to register the same type with same type id. - TF_ASSERT_OK(TypeIdRegistry::RegisterExternalTypeId("foo", type_id)); - - // It's an error to register the same type with a different type id. - auto wrong_type_id = TypeIdRegistry::RegisterExternalTypeId( - "foo", TypeIdRegistry::TypeId(std::numeric_limits::max())); - EXPECT_THAT(wrong_type_id.message(), - HasSubstr("Type name foo already registered with type id")); - - // It's ok to register a new type with a user-provided type id. - TF_ASSERT_OK(TypeIdRegistry::RegisterExternalTypeId( - "bar", TypeIdRegistry::TypeId(std::numeric_limits::max()))); -} - -TEST(TypeIdRegistryTest, RegisterInternalTypeId) { - auto int32_type_id = TypeIdRegistry::GetTypeId(); - auto int64_type_id = TypeIdRegistry::GetTypeId(); - EXPECT_NE(int32_type_id, int64_type_id); -} - -} // namespace -} // namespace xla::ffi diff --git a/third_party/xla/xla/ffi/type_registry.cc b/third_party/xla/xla/ffi/type_registry.cc new file mode 100644 index 00000000000..e6b8ea02898 --- /dev/null +++ b/third_party/xla/xla/ffi/type_registry.cc @@ -0,0 +1,134 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/ffi/type_registry.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/base/const_init.h" +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "xla/util.h" + +namespace xla::ffi { +namespace { + +struct TypeRegistration { + TypeRegistry::TypeId type_id; + TypeRegistry::TypeInfo type_info; +}; + +using ExternalTypeRegistry = absl::flat_hash_map; + +} // namespace + +ABSL_CONST_INIT absl::Mutex type_registry_mutex(absl::kConstInit); + +static ExternalTypeRegistry& StaticExternalTypeRegistry() { + static absl::NoDestructor registry; + return *registry; +} + +TypeRegistry::TypeId TypeRegistry::GetNextInternalTypeId() { + static auto* counter = new std::atomic(1); + return TypeId(counter->fetch_add(1)); +} + +TypeRegistry::TypeId TypeRegistry::GetNextExternalTypeId() { + static auto* counter = new std::atomic(1); + return TypeId(counter->fetch_add(1)); +} + +absl::StatusOr TypeRegistry::AssignExternalTypeId( + absl::string_view name, TypeInfo type_info) { + VLOG(3) << absl::StrFormat("Assign external type id: name=%s", name); + + absl::MutexLock lock(type_registry_mutex); + auto& registry = StaticExternalTypeRegistry(); + + // Try to emplace with unknow type id and fill it with real type id only if we + // successfully acquired an entry for a given name. + auto emplaced = + registry.emplace(name, TypeRegistration{kUnknownTypeId, type_info}); + if (!emplaced.second) { + return Internal("Type name %s already registered with type id %d", name, + emplaced.first->second.type_id.value()); + } + + // Returns true if the registry contains an entry with a given type id. + auto type_id_is_in_use = [®istry](TypeId type_id) { + return absl::c_any_of( + registry, [&](const auto& e) { return e.second.type_id == type_id; }); + }; + + // Create a new type id that is not already in use. + TypeId type_id = GetNextExternalTypeId(); + while (type_id_is_in_use(type_id)) { + type_id = GetNextExternalTypeId(); + } + + VLOG(3) << absl::StrFormat("Assigned external type id: name=%s type_id=%d", + name, type_id.value()); + return emplaced.first->second.type_id = type_id; +} + +absl::Status TypeRegistry::RegisterExternalTypeId(absl::string_view name, + TypeId type_id, + TypeInfo type_info) { + VLOG(3) << absl::StrFormat("Register external type id: name=%s type_id=%d", + name, type_id.value()); + + absl::MutexLock lock(type_registry_mutex); + auto& registry = StaticExternalTypeRegistry(); + + auto emplaced = registry.emplace(name, TypeRegistration{type_id, type_info}); + if (!emplaced.second && emplaced.first->second.type_id != type_id) { + return Internal("Type name %s already registered with type id %d vs %d)", + name, emplaced.first->second.type_id.value(), + type_id.value()); + } + + return absl::OkStatus(); +} + +absl::StatusOr TypeRegistry::GetExternalTypeInfo( + TypeId type_id) { + absl::MutexLock lock(type_registry_mutex); + auto& registry = StaticExternalTypeRegistry(); + + auto it = absl::c_find_if(registry, [&](const auto& kv) { + auto& [name, registration] = kv; + return registration.type_id == type_id; + }); + + if (it == registry.end()) { + return Internal("Type id %d is not registered with a static registry", + type_id.value()); + } + + return it->second.type_info; +} + +} // namespace xla::ffi diff --git a/third_party/xla/xla/ffi/type_id_registry.h b/third_party/xla/xla/ffi/type_registry.h similarity index 65% rename from third_party/xla/xla/ffi/type_id_registry.h rename to third_party/xla/xla/ffi/type_registry.h index 283ec977779..8c61214ddc0 100644 --- a/third_party/xla/xla/ffi/type_id_registry.h +++ b/third_party/xla/xla/ffi/type_registry.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_FFI_TYPE_ID_REGISTRY_H_ -#define XLA_FFI_TYPE_ID_REGISTRY_H_ +#ifndef XLA_FFI_TYPE_REGISTRY_H_ +#define XLA_FFI_TYPE_REGISTRY_H_ #include @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/tsl/lib/gtl/int_type.h" +#include "xla/tsl/util/safe_reinterpret_cast.h" namespace xla::ffi { @@ -43,26 +44,50 @@ namespace xla::ffi { // // 2. Internal type id. When FFI handler defined in the same binary we rely // on a global static registry to automatically assign type ids. -class TypeIdRegistry { +// +// TypeInfo defines a set of functions that allow XLA runtime to manipulate +// external types. For user data, that is forwarded to FFI handlers, they all +// can be `nullptr` as XLA runtime doesn't manage their lifetime. For stateful +// handlers, XLA runtime at least must know how to destroy the state when XLA +// executable is destroyed. +class TypeRegistry { public: + // Unique (within a process) identifier for a type. TSL_LIB_GTL_DEFINE_INT_TYPE(TypeId, int64_t); static constexpr TypeId kUnknownTypeId = TypeId(0); + // Pointers to functions that allow XLA runtime to manipulate external types. + struct TypeInfo { + using Deleter = void (*)(void*); + + Deleter deleter = nullptr; + }; + // Assigns a unique type id to an external type with a given name. Returns an // error if a type with a given name is already registered in the process. - static absl::StatusOr AssignExternalTypeId(absl::string_view name); + static absl::StatusOr AssignExternalTypeId(absl::string_view name, + TypeInfo type_info); // Registers external type with a given name and type id. Type id is provided // by the caller, and must be unique. Returns an error if a type with a given // name is already registered with a different type id. static absl::Status RegisterExternalTypeId(absl::string_view name, - TypeId type_id); + TypeId type_id, + TypeInfo type_info); + + // Returns type info for a given external type id. Returns an error if type + // id is not registered. + static absl::StatusOr GetExternalTypeInfo(TypeId type_id); // Returns a type id for a given type. For internal type ids only. template static TypeId GetTypeId(); + // Returns type info for a given type id. For internal type ids only. + template + static TypeInfo GetTypeInfo(); + private: // We never mix external and internal type ids, so we can use different type // id spaces to assign unique ids to each type. @@ -71,11 +96,18 @@ class TypeIdRegistry { }; template -TypeIdRegistry::TypeId TypeIdRegistry::GetTypeId() { +TypeRegistry::TypeId TypeRegistry::GetTypeId() { static const TypeId id = GetNextInternalTypeId(); return id; } +template +TypeRegistry::TypeInfo TypeRegistry::GetTypeInfo() { + return TypeInfo{ + [](void* state) { delete tsl::safe_reinterpret_cast(state); }, + }; +} + } // namespace xla::ffi -#endif // XLA_FFI_TYPE_ID_REGISTRY_H_ +#endif // XLA_FFI_TYPE_REGISTRY_H_ diff --git a/third_party/xla/xla/ffi/type_registry_test.cc b/third_party/xla/xla/ffi/type_registry_test.cc new file mode 100644 index 00000000000..c39110b9389 --- /dev/null +++ b/third_party/xla/xla/ffi/type_registry_test.cc @@ -0,0 +1,85 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/ffi/type_registry.h" + +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" + +namespace xla::ffi { +namespace { + +using ::testing::HasSubstr; + +TEST(TypeRegistryTest, RegisterExternalTypeId) { + TypeRegistry::TypeInfo type_info = {+[](void* state) {}}; + + TF_ASSERT_OK_AND_ASSIGN(auto foo_id, + TypeRegistry::AssignExternalTypeId("foo", type_info)); + EXPECT_GE(foo_id.value(), 0); + + auto duplicate_foo_id = TypeRegistry::AssignExternalTypeId("foo", type_info); + EXPECT_THAT(duplicate_foo_id.status().message(), + HasSubstr("Type name foo already registered with type id")); + + // It's ok to register the same type with same type id. + TF_ASSERT_OK(TypeRegistry::RegisterExternalTypeId("foo", foo_id, type_info)); + + // It's an error to register the same type with a different type id. + auto wrong_foo_id = TypeRegistry::RegisterExternalTypeId( + "foo", TypeRegistry::TypeId(std::numeric_limits::max()), + type_info); + EXPECT_THAT(wrong_foo_id.message(), + HasSubstr("Type name foo already registered with type id")); + + // Registered type has a correct type info. + TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeInfo foo_info, + TypeRegistry::GetExternalTypeInfo(foo_id)); + EXPECT_EQ(foo_info.deleter, type_info.deleter); + + // It's ok to register a new type with a user-provided type id. + auto bar_id = TypeRegistry::TypeId(std::numeric_limits::max()); + TF_ASSERT_OK(TypeRegistry::RegisterExternalTypeId( + "bar", TypeRegistry::TypeId(std::numeric_limits::max()), + type_info)); + + // And a new type has a correct type info. + TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeInfo bar_info, + TypeRegistry::GetExternalTypeInfo(bar_id)); + EXPECT_EQ(bar_info.deleter, type_info.deleter); +} + +TEST(TypeRegistryTest, RegisterInternalTypeId) { + auto int32_type_id = TypeRegistry::GetTypeId(); + auto int64_type_id = TypeRegistry::GetTypeId(); + EXPECT_NE(int32_type_id, int64_type_id); +} + +TEST(TypeRegistryTest, InternalTypeInfo) { + int32_t* ptr = new int32_t{42}; + + TypeRegistry::TypeInfo type_info = TypeRegistry::GetTypeInfo(); + type_info.deleter(ptr); +} + +} // namespace +} // namespace xla::ffi diff --git a/third_party/xla/xla/pjrt/c/BUILD b/third_party/xla/xla/pjrt/c/BUILD index 4f997f24b36..2f7528c82cc 100644 --- a/third_party/xla/xla/pjrt/c/BUILD +++ b/third_party/xla/xla/pjrt/c/BUILD @@ -82,9 +82,10 @@ cc_library( ":pjrt_c_api_hdrs", ":pjrt_c_api_helpers", ":pjrt_c_api_wrapper_impl", + "//xla/ffi", "//xla/ffi:execution_context", "//xla/ffi:ffi_api", - "//xla/ffi:type_id_registry", + "//xla/ffi:type_registry", "//xla/ffi/api:c_api", "//xla/ffi/api:ffi", "@com_google_absl//absl/status", @@ -567,7 +568,7 @@ xla_test( "//xla/client:client_library", "//xla/ffi:execution_context", "//xla/ffi:ffi_api", - "//xla/ffi:type_id_registry", + "//xla/ffi:type_registry", "//xla/ffi/api:ffi", "//xla/pjrt:pjrt_common", "//xla/pjrt:pjrt_compiler", diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_extension.h index d3bd5c4fc99..2b3e266fbd3 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_extension.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_extension.h @@ -30,7 +30,24 @@ extern "C" { // and GPU backends it gives access to the XLA FFI internals. // // See: https://en.wikipedia.org/wiki/Foreign_function_interface -#define PJRT_API_FFI_EXTENSION_VERSION 2 +#define PJRT_API_FFI_EXTENSION_VERSION 3 + +struct PJRT_FFI_Type_Info { + void (*deleter)(void* object); + void (*serialize)(); // placeholder for future use + void (*deserialize)(); // placeholder for future use +}; + +struct PJRT_FFI_Type_Register_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + + const char* type_name; + size_t type_name_size; + int64_t type_id; // in-out + PJRT_FFI_Type_Info* type_info; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Type_Register_Args, type_info); struct PJRT_FFI_TypeID_Register_Args { size_t struct_size; @@ -46,6 +63,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_TypeID_Register_Args, type_id); // XLA will assign a unique type id to it and return via out argument, otherwise // it will verify that user-provided type id matches previously registered type // id for the given type name. +typedef PJRT_Error* PJRT_FFI_Type_Register(PJRT_FFI_Type_Register_Args* args); typedef PJRT_Error* PJRT_FFI_TypeID_Register( PJRT_FFI_TypeID_Register_Args* args); @@ -94,8 +112,9 @@ typedef struct PJRT_FFI_Extension { PJRT_FFI_TypeID_Register* type_id_register; PJRT_FFI_UserData_Add* user_data_add; PJRT_FFI_Register_Handler* register_handler; + PJRT_FFI_Type_Register* type_register; } PJRT_FFI; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, register_handler); +PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, type_register); #ifdef __cplusplus } diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_internal.cc index 9a7a4c810ff..e65b6f1d6bc 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_internal.cc @@ -19,8 +19,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/execution_context.h" +#include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" -#include "xla/ffi/type_id_registry.h" +#include "xla/ffi/type_registry.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_ffi_extension.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" @@ -35,20 +36,48 @@ static PJRT_Error* PJRT_FFI_TypeID_Register( PJRT_FFI_TypeID_Register_Args_STRUCT_SIZE, args->struct_size)); absl::string_view type_name(args->type_name, args->type_name_size); - xla::ffi::TypeIdRegistry::TypeId type_id(args->type_id); + xla::ffi::TypeRegistry::TypeId type_id(args->type_id); - if (type_id == xla::ffi::TypeIdRegistry::kUnknownTypeId) { + if (type_id == xla::ffi::TypeRegistry::kUnknownTypeId) { // If type_id is unknown, we are registering a new type and XLA will assign // a unique type id to it. PJRT_ASSIGN_OR_RETURN( auto assigned_type_id, - xla::ffi::TypeIdRegistry::AssignExternalTypeId(type_name)); + xla::ffi::TypeRegistry::AssignExternalTypeId(type_name, {})); args->type_id = assigned_type_id.value(); } else { // If type_id is set, we are relying on the caller-provided unique type id. PJRT_RETURN_IF_ERROR( - xla::ffi::TypeIdRegistry::RegisterExternalTypeId(type_name, type_id)); + xla::ffi::TypeRegistry::RegisterExternalTypeId(type_name, type_id, {})); + } + + return nullptr; +} + +static PJRT_Error* PJRT_FFI_Type_Register(PJRT_FFI_Type_Register_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_FFI_Type_Register_Args", PJRT_FFI_Type_Register_Args_STRUCT_SIZE, + args->struct_size)); + + absl::string_view type_name(args->type_name, args->type_name_size); + xla::ffi::TypeRegistry::TypeId type_id(args->type_id); + xla::ffi::TypeRegistry::TypeInfo type_info = { + args->type_info->deleter, + }; + + if (type_id == xla::ffi::TypeRegistry::kUnknownTypeId) { + // If type_id is unknown, we are registering a new type and XLA will assign + // a unique type id to it. + PJRT_ASSIGN_OR_RETURN( + auto assigned_type_id, + xla::ffi::TypeRegistry::AssignExternalTypeId(type_name, type_info)); + args->type_id = assigned_type_id.value(); + + } else { + // If type_id is set, we are relying on the caller-provided unique type id. + PJRT_RETURN_IF_ERROR(xla::ffi::TypeRegistry::RegisterExternalTypeId( + type_name, type_id, type_info)); } return nullptr; @@ -64,7 +93,7 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) { "PJRT FFI extension requires execute context to be not nullptr")}; } - xla::ffi::TypeIdRegistry::TypeId type_id(args->user_data.type_id); + xla::ffi::TypeRegistry::TypeId type_id(args->user_data.type_id); PJRT_RETURN_IF_ERROR(args->context->execute_context->ffi_context().Insert( type_id, args->user_data.data, args->user_data.deleter)); return nullptr; @@ -102,6 +131,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) { /*type_id_register=*/PJRT_FFI_TypeID_Register, /*user_data_add=*/PJRT_FFI_UserData_Add, /*register_handler=*/PJRT_FFI_Register_Handler, + /*type_register=*/PJRT_FFI_Type_Register, }; } diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc index f032c3eb87c..1df28573a11 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -44,7 +44,7 @@ limitations under the License. #include "xla/ffi/api/ffi.h" #include "xla/ffi/execution_context.h" #include "xla/ffi/ffi_api.h" -#include "xla/ffi/type_id_registry.h" +#include "xla/ffi/type_registry.h" #include "xla/future.h" #include "xla/literal.h" #include "xla/literal_util.h" @@ -366,7 +366,7 @@ TEST_F(PjrtCApiGpuTest, CreateAndDestroyExecuteContext) { TF_ASSERT_OK_AND_ASSIGN( auto lookup_user_data, create_arg.context->execute_context->ffi_context().Lookup( - xla::ffi::TypeIdRegistry::TypeId(42))); + xla::ffi::TypeRegistry::TypeId(42))); EXPECT_EQ(lookup_user_data, &string_data); PJRT_ExecuteContext_Destroy_Args destroy_args; diff --git a/third_party/xla/xla/python/pjrt_ifrt/BUILD b/third_party/xla/xla/python/pjrt_ifrt/BUILD index 91b5ddb2dc8..b8f2712af01 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/BUILD +++ b/third_party/xla/xla/python/pjrt_ifrt/BUILD @@ -324,7 +324,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/ffi:execution_context", - "//xla/ffi:type_id_registry", + "//xla/ffi:type_registry", "//xla/hlo/ir:hlo", "//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "//xla/pjrt:host_callback", diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc index 7502c7693ce..4f4905fb5ba 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc @@ -35,7 +35,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "xla/ffi/execution_context.h" -#include "xla/ffi/type_id_registry.h" +#include "xla/ffi/type_registry.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/layout.h" @@ -768,8 +768,7 @@ PjRtLoadedExecutable::Execute(absl::Span args, } ffi_callbacks->callbacks = callbacks->data(); ffi_callbacks->num_callbacks = callbacks->size(); - auto type_id = xla::ffi::TypeIdRegistry::TypeId( - xla::FfiLoadedHostCallbacks::id.type_id); + ffi::TypeRegistry::TypeId type_id(FfiLoadedHostCallbacks::id.type_id); CHECK_OK(context->ffi_context().Insert(type_id, ffi_callbacks.get())); opts.context = context.get(); } diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/python_hlo_runner.cc b/third_party/xla/xla/tools/multihost_hlo_runner/python_hlo_runner.cc index 72051a4ad4a..da6992f4537 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/python_hlo_runner.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/python_hlo_runner.cc @@ -351,7 +351,7 @@ NB_MODULE(py_hlo_multihost_runner, m) { m.def("custom_call_targets", GetRegisteredCustomCallTargets, nb::arg("platform")); m.def( - "register_custom_type_id", + "register_custom_type", [](absl::string_view type_name, nb::object type_id) { xla::ThrowIfError(RegisterCustomTypeId(type_name, type_id)); },