[xla:ffi] Add TypeRegistry::TypeInfo to be able to register functions to manipulate user-defined types

PiperOrigin-RevId: 820811829
This commit is contained in:
Eugene Zhulenev 2025-10-17 13:34:40 -07:00 committed by TensorFlower Gardener
parent 46522b8a20
commit d531cdce30
26 changed files with 479 additions and 265 deletions

View File

@ -62,7 +62,7 @@ cc_library(
srcs = ["execution_context.cc"],
hdrs = ["execution_context.h"],
deps = [
":type_id_registry",
":type_registry",
"//xla:util",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:logging",
@ -79,7 +79,7 @@ xla_cc_test(
srcs = ["execution_context_test.cc"],
deps = [
":execution_context",
":type_id_registry",
":type_registry",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
@ -94,13 +94,14 @@ cc_library(
srcs = ["execution_state.cc"],
hdrs = ["execution_state.h"],
deps = [
":type_id_registry",
":type_registry",
"//xla:util",
"//xla/tsl/platform:statusor",
"//xla/tsl/util:safe_reinterpret_cast",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:statusor",
],
)
@ -109,11 +110,12 @@ xla_cc_test(
srcs = ["execution_state_test.cc"],
deps = [
":execution_state",
":type_registry",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
],
)
@ -124,7 +126,7 @@ cc_library(
":api",
":execution_context",
":execution_state",
":type_id_registry",
":type_registry",
"//xla:executable_run_options",
"//xla:shape_util",
"//xla:types",
@ -141,6 +143,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
@ -160,7 +163,7 @@ cc_library(
":call_frame",
":execution_context",
":execution_state",
":type_id_registry",
":type_registry",
"//xla:executable_run_options",
"//xla:util",
"//xla/ffi/api:c_api",
@ -218,7 +221,7 @@ xla_cc_test(
":execution_state",
":ffi",
":ffi_api",
":type_id_registry",
":type_registry",
"//xla:executable_run_options",
"//xla:xla_data_proto_cc",
"//xla/ffi/api:c_api",
@ -243,27 +246,31 @@ xla_cc_test(
)
cc_library(
name = "type_id_registry",
srcs = ["type_id_registry.cc"],
hdrs = ["type_id_registry.h"],
name = "type_registry",
srcs = ["type_registry.cc"],
hdrs = ["type_registry.h"],
deps = [
"//xla:util",
"//xla/tsl/lib/gtl:int_type",
"//xla/tsl/util:safe_reinterpret_cast",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/synchronization",
],
)
xla_cc_test(
name = "type_id_registry_test",
srcs = ["type_id_registry_test.cc"],
name = "type_registry_test",
srcs = ["type_registry_test.cc"],
deps = [
":type_id_registry",
":type_registry",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",

View File

@ -89,7 +89,7 @@ xla_cc_test(
"//xla/ffi:execution_context",
"//xla/ffi:execution_state",
"//xla/ffi:ffi_api",
"//xla/ffi:type_id_registry",
"//xla/ffi:type_registry",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:device_memory_allocator",
"//xla/tsl/concurrency:async_value",

View File

@ -348,11 +348,13 @@ inline XLA_FFI_Error* Ffi::RegisterTypeId(const XLA_FFI_Api* api,
std::string_view name,
XLA_FFI_TypeId* type_id,
XLA_FFI_TypeInfo type_info) {
assert(type_id && "type_id must not be null");
XLA_FFI_TypeId_Register_Args args;
args.struct_size = XLA_FFI_TypeId_Register_Args_STRUCT_SIZE;
args.extension_start = nullptr;
args.name = XLA_FFI_ByteSpan{name.data(), name.size()};
args.type_id = type_id;
args.type_info = &type_info;
return api->XLA_FFI_TypeId_Register(&args);
}

View File

@ -67,7 +67,7 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Extension_Base, next);
// * Deleting a method or argument
// * Changing the type of an argument
// * Rearranging fields in the XLA_FFI_Api or argument structs
#define XLA_FFI_API_MAJOR 0
#define XLA_FFI_API_MAJOR 1
// Incremented when the interface is updated in a way that is potentially
// ABI-compatible with older versions, if supported by the caller and/or
@ -82,7 +82,7 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Extension_Base, next);
// Minor changes include:
// * Adding a new field to the XLA_FFI_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
#define XLA_FFI_API_MINOR 1
#define XLA_FFI_API_MINOR 0
struct XLA_FFI_Api_Version {
size_t struct_size;
@ -491,6 +491,7 @@ struct XLA_FFI_TypeId_Register_Args {
XLA_FFI_ByteSpan name;
XLA_FFI_TypeId* type_id; // in-out
XLA_FFI_TypeInfo* type_info;
};
XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_TypeId_Register_Args, type_id);

View File

@ -437,7 +437,7 @@ class CountDownPromise {
assert(state_->count.load() >= count && "Invalid count down value");
if (XLA_FFI_PREDICT_FALSE(!error.success())) {
const std::lock_guard<std::mutex> lock(state_->mutex);
std::lock_guard<std::mutex> lock(state_->mutex); // NOLINT
state_->is_error.store(true, std::memory_order_release);
state_->error = error;
}
@ -448,7 +448,7 @@ class CountDownPromise {
bool is_error = state_->is_error.load(std::memory_order_acquire);
if (XLA_FFI_PREDICT_FALSE(is_error)) {
auto take_error = [&] {
const std::lock_guard<std::mutex> lock(state_->mutex);
std::lock_guard<std::mutex> lock(state_->mutex); // NOLINT
return state_->error;
};
state_->promise.SetError(take_error());
@ -476,7 +476,7 @@ class CountDownPromise {
std::atomic<int64_t> count;
std::atomic<bool> is_error;
std::mutex mutex;
std::mutex mutex; // NOLINT
Error error;
};

View File

@ -39,7 +39,7 @@ limitations under the License.
#include "xla/ffi/execution_context.h"
#include "xla/ffi/execution_state.h"
#include "xla/ffi/ffi_api.h"
#include "xla/ffi/type_id_registry.h"
#include "xla/ffi/type_registry.h"
#include "xla/primitive_util.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/device_memory_allocator.h"
@ -1220,9 +1220,9 @@ TEST(FfiTest, UserData) {
ExecutionContext execution_context;
TF_ASSERT_OK(execution_context.Insert(
TypeIdRegistry::TypeId(MyDataWithAutoTypeId::id.type_id), &data0));
TypeRegistry::TypeId(MyDataWithAutoTypeId::id.type_id), &data0));
TF_ASSERT_OK(execution_context.Insert(
TypeIdRegistry::TypeId(MyDataWithExplicitTypeId::id.type_id), &data1));
TypeRegistry::TypeId(MyDataWithExplicitTypeId::id.type_id), &data1));
CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
auto call_frame = builder.Build();

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/ffi/type_id_registry.h"
#include "xla/ffi/type_registry.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/statusor.h"
@ -45,7 +45,7 @@ namespace xla::ffi {
// unique between separate calls to XLA execute.
class ExecutionContext {
public:
using TypeId = TypeIdRegistry::TypeId;
using TypeId = TypeRegistry::TypeId;
template <typename T>
using Deleter = std::function<void(T*)>;
@ -67,7 +67,7 @@ class ExecutionContext {
template <typename T>
absl::StatusOr<T*> Lookup() const {
TF_ASSIGN_OR_RETURN(auto user_data,
LookupUserData(TypeIdRegistry::GetTypeId<T>()));
LookupUserData(TypeRegistry::GetTypeId<T>()));
return static_cast<T*>(user_data->data());
}
@ -110,7 +110,7 @@ class ExecutionContext {
template <typename T>
absl::Status ExecutionContext::Insert(T* data, Deleter<T> deleter) {
return InsertUserData(TypeIdRegistry::GetTypeId<T>(),
return InsertUserData(TypeRegistry::GetTypeId<T>(),
std::make_unique<UserData>(
data, [deleter = std::move(deleter)](void* data) {
if (deleter) deleter(static_cast<T*>(data));
@ -119,7 +119,7 @@ absl::Status ExecutionContext::Insert(T* data, Deleter<T> deleter) {
template <typename T, typename... Args>
absl::Status ExecutionContext::Emplace(Args&&... args) {
return InsertUserData(TypeIdRegistry::GetTypeId<T>(),
return InsertUserData(TypeRegistry::GetTypeId<T>(),
std::make_unique<UserData>(
new T(std::forward<Args>(args)...),
[](void* data) { delete static_cast<T*>(data); }));

View File

@ -20,7 +20,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "xla/ffi/type_id_registry.h"
#include "xla/ffi/type_registry.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
@ -62,8 +62,9 @@ TEST(ExecutionContextTest, InsertUserOwned) {
}
TEST(ExecutionContextTest, InsertUserOwnedWithTypeId) {
TF_ASSERT_OK_AND_ASSIGN(TypeIdRegistry::TypeId type_id,
TypeIdRegistry::AssignExternalTypeId("I32UserData"));
TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeId type_id,
TypeRegistry::AssignExternalTypeId(
"I32UserData", TypeRegistry::TypeInfo{}));
I32UserData user_data(42);

View File

@ -15,38 +15,54 @@ limitations under the License.
#include "xla/ffi/execution_state.h"
#include <utility>
#include "absl/base/attributes.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/ffi/type_id_registry.h"
#include "xla/ffi/type_registry.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "tsl/platform/logging.h"
namespace xla::ffi {
ExecutionState::ExecutionState()
: type_id_(TypeIdRegistry::kUnknownTypeId),
state_(nullptr),
deleter_(nullptr) {}
: type_id_(TypeRegistry::kUnknownTypeId), state_(nullptr) {}
ExecutionState::~ExecutionState() {
if (deleter_) deleter_(state_);
if (type_info_.deleter) {
type_info_.deleter(state_);
}
}
absl::Status ExecutionState::Set(TypeId type_id, void* state,
Deleter<void> deleter) {
DCHECK(state && deleter) << "State and deleter must not be null";
absl::Status ExecutionState::Set(TypeId type_id, void* state) {
TF_ASSIGN_OR_RETURN(auto type_info,
TypeRegistry::GetExternalTypeInfo(type_id));
if (type_info.deleter == nullptr) {
return InvalidArgument(
"Type id %d does not have a registered type info with a deleter",
type_id.value());
}
return Set(type_id, type_info, state);
}
if (type_id_ != TypeIdRegistry::kUnknownTypeId) {
ABSL_DEPRECATED("FFI users must rely in TypeInfo registration")
absl::Status ExecutionState::Set(TypeId type_id, void* state,
void (*deleter)(void*)) {
return Set(type_id, TypeInfo{deleter}, state);
}
absl::Status ExecutionState::Set(TypeId type_id, TypeInfo type_info,
void* state) {
DCHECK(state && type_info.deleter) << "State and deleter must not be null";
if (type_id_ != TypeRegistry::kUnknownTypeId) {
return FailedPrecondition("State is already set with a type id %d",
type_id_.value());
}
type_id_ = type_id;
type_info_ = type_info;
state_ = state;
deleter_ = std::move(deleter);
return absl::OkStatus();
}
@ -54,7 +70,7 @@ absl::Status ExecutionState::Set(TypeId type_id, void* state,
// Returns opaque state of the given type id. If set state type id does not
// match the requested one, returns an error.
absl::StatusOr<void*> ExecutionState::Get(TypeId type_id) const {
if (type_id_ == TypeIdRegistry::kUnknownTypeId) {
if (type_id_ == TypeRegistry::kUnknownTypeId) {
return NotFound("State is not set");
}
@ -68,7 +84,7 @@ absl::StatusOr<void*> ExecutionState::Get(TypeId type_id) const {
}
bool ExecutionState::IsSet() const {
return type_id_ != TypeIdRegistry::kUnknownTypeId;
return type_id_ != TypeRegistry::kUnknownTypeId;
}
} // namespace xla::ffi

View File

@ -16,13 +16,13 @@ limitations under the License.
#ifndef XLA_FFI_EXECUTION_STATE_H_
#define XLA_FFI_EXECUTION_STATE_H_
#include <functional>
#include <memory>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/ffi/type_id_registry.h"
#include "tsl/platform/statusor.h"
#include "xla/ffi/type_registry.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/util/safe_reinterpret_cast.h"
namespace xla::ffi {
@ -41,10 +41,8 @@ namespace xla::ffi {
//
class ExecutionState {
public:
using TypeId = TypeIdRegistry::TypeId;
template <typename T>
using Deleter = std::function<void(T*)>;
using TypeId = TypeRegistry::TypeId;
using TypeInfo = TypeRegistry::TypeInfo;
ExecutionState();
~ExecutionState();
@ -52,9 +50,13 @@ class ExecutionState {
ExecutionState(const ExecutionState&) = delete;
ExecutionState& operator=(const ExecutionState&) = delete;
// Sets opaque state with a given type id and deleter. Returns an error if
// state is already set.
absl::Status Set(TypeId type_id, void* state, Deleter<void> deleter);
// Sets opaque state with a given type id. Returns an error if state is
// already set, or if type id is not supported as a state.
absl::Status Set(TypeId type_id, void* state);
// Sets opaque state with a given type id and custom deleter. Returns an error
// if state is already set, or if type id is not supported as a state.
absl::Status Set(TypeId type_id, void* state, void (*deleter)(void*));
// Returns opaque state of the given type id. If set state type id does not
// match the requested one, returns an error.
@ -73,21 +75,23 @@ class ExecutionState {
bool IsSet() const;
private:
absl::Status Set(TypeId type_id, TypeInfo type_info, void* state);
TypeId type_id_;
TypeInfo type_info_;
void* state_;
Deleter<void> deleter_;
};
template <typename T>
absl::Status ExecutionState::Set(std::unique_ptr<T> state) {
return Set(TypeIdRegistry::GetTypeId<T>(), state.release(),
[](void* state) { delete reinterpret_cast<T*>(state); });
return Set(TypeRegistry::GetTypeId<T>(), TypeRegistry::GetTypeInfo<T>(),
state.release());
}
template <typename T>
absl::StatusOr<T*> ExecutionState::Get() const {
TF_ASSIGN_OR_RETURN(void* state, Get(TypeIdRegistry::GetTypeId<T>()));
return reinterpret_cast<T*>(state);
TF_ASSIGN_OR_RETURN(void* state, Get(TypeRegistry::GetTypeId<T>()));
return tsl::safe_reinterpret_cast<T*>(state);
}
} // namespace xla::ffi

View File

@ -20,9 +20,10 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "xla/ffi/type_registry.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
namespace xla::ffi {
@ -30,7 +31,7 @@ using TypeId = ExecutionState::TypeId;
using ::testing::HasSubstr;
TEST(ExecutionStateTest, SetAndGet) {
TEST(ExecutionStateTest, SetAndGetForInternalType) {
ExecutionState state;
EXPECT_FALSE(state.IsSet());
@ -52,4 +53,34 @@ TEST(ExecutionStateTest, SetAndGet) {
EXPECT_EQ(*data, 42);
}
TEST(ExecutionStateTest, SetAndGetForExternalType) {
ExecutionState state;
EXPECT_FALSE(state.IsSet());
{ // Empty state returns an error from Get().
auto data = state.Get(TypeId(1));
EXPECT_THAT(data.status().message(), HasSubstr("State is not set"));
}
{ // Empty state returns an error from Get().
auto data = state.Get<int32_t>();
EXPECT_THAT(data.status().message(), HasSubstr("State is not set"));
}
TypeRegistry::TypeInfo type_info = {
[](void* ptr) { delete static_cast<int32_t*>(ptr); }};
TF_ASSERT_OK_AND_ASSIGN(
TypeRegistry::TypeId type_id,
TypeRegistry::AssignExternalTypeId("int32_t", type_info));
int32_t* value = new int32_t(42);
// Once set, state can be retrieved.
TF_ASSERT_OK(state.Set(type_id, value));
EXPECT_TRUE(state.IsSet());
TF_ASSERT_OK_AND_ASSIGN(void* data, state.Get(type_id));
EXPECT_EQ(data, value);
}
} // namespace xla::ffi

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
@ -46,7 +47,7 @@ limitations under the License.
#include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep
#include "xla/ffi/execution_context.h"
#include "xla/ffi/execution_state.h"
#include "xla/ffi/type_id_registry.h"
#include "xla/ffi/type_registry.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/primitive_util.h"
#include "xla/stream_executor/device_memory.h"
@ -724,7 +725,7 @@ template <typename T>
struct ResultEncoding<ExecutionStage::kInstantiate,
absl::StatusOr<std::unique_ptr<T>>> {
static XLA_FFI_TypeId state_type_id() {
return XLA_FFI_TypeId{TypeIdRegistry::GetTypeId<T>().value()};
return XLA_FFI_TypeId{TypeRegistry::GetTypeId<T>().value()};
}
static XLA_FFI_Error* Encode(const XLA_FFI_Api* api,

View File

@ -43,7 +43,7 @@ limitations under the License.
#include "xla/ffi/call_frame.h"
#include "xla/ffi/execution_context.h"
#include "xla/ffi/execution_state.h"
#include "xla/ffi/type_id_registry.h"
#include "xla/ffi/type_registry.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/service/platform_util.h"
#include "xla/stream_executor/device_memory.h"
@ -659,12 +659,14 @@ static XLA_FFI_Error* XLA_FFI_TypeId_Register(
XLA_FFI_ExecutionContext_Get_Args_STRUCT_SIZE, args->struct_size));
absl::string_view type_name(args->name.ptr, args->name.len);
TypeIdRegistry::TypeId type_id(args->type_id->type_id);
TypeRegistry::TypeId type_id(args->type_id->type_id);
TypeRegistry::TypeInfo type_info = {args->type_info->deleter};
// If type_id is unknown, we are registering a new type and XLA will assign a
// unique type id to it.
if (type_id == TypeIdRegistry::kUnknownTypeId) {
auto assigned_type_id = TypeIdRegistry::AssignExternalTypeId(type_name);
if (type_id == TypeRegistry::kUnknownTypeId) {
auto assigned_type_id =
TypeRegistry::AssignExternalTypeId(type_name, type_info);
if (!assigned_type_id.ok()) {
return new XLA_FFI_Error{std::move(assigned_type_id).status()};
}
@ -674,9 +676,10 @@ static XLA_FFI_Error* XLA_FFI_TypeId_Register(
}
// If type_id is set, we are relying on the caller-provided unique type id.
if (auto status = TypeIdRegistry::RegisterExternalTypeId(type_name, type_id);
!status.ok()) {
return new XLA_FFI_Error{std::move(status)};
auto registered_type_id =
TypeRegistry::RegisterExternalTypeId(type_name, type_id, type_info);
if (!registered_type_id.ok()) {
return new XLA_FFI_Error{std::move(registered_type_id)};
}
return nullptr;
@ -690,7 +693,7 @@ static XLA_FFI_Error* XLA_FFI_ExecutionContext_Get(
DCHECK(args->ctx->execution_context) << "ExecutionContext must be set";
auto user_data = args->ctx->execution_context->Lookup(
TypeIdRegistry::TypeId(args->type_id->type_id));
TypeRegistry::TypeId(args->type_id->type_id));
if (!user_data.ok()) {
return new XLA_FFI_Error{std::move(user_data).status()};
}
@ -705,9 +708,16 @@ static XLA_FFI_Error* XLA_FFI_State_Set(XLA_FFI_State_Set_Args* args) {
args->struct_size));
DCHECK(args->ctx->execution_state) << "ExecutionState must be set";
absl::Status status = args->ctx->execution_state->Set(
TypeIdRegistry::TypeId(args->type_id->type_id), args->state,
[deleter = args->deleter](void* state) { deleter(state); });
absl::Status status;
if (args->deleter == nullptr) {
status = args->ctx->execution_state->Set(
TypeRegistry::TypeId(args->type_id->type_id), args->state);
} else {
status = args->ctx->execution_state->Set(
TypeRegistry::TypeId(args->type_id->type_id), args->state,
args->deleter);
}
if (!status.ok()) {
return new XLA_FFI_Error{std::move(status)};
@ -723,7 +733,7 @@ static XLA_FFI_Error* XLA_FFI_State_Get(XLA_FFI_State_Get_Args* args) {
DCHECK(args->ctx->execution_state) << "ExecutionState must be set";
absl::StatusOr<void*> state = args->ctx->execution_state->Get(
TypeIdRegistry::TypeId(args->type_id->type_id));
TypeRegistry::TypeId(args->type_id->type_id));
if (!state.ok()) {
return new XLA_FFI_Error{std::move(state).status()};
}

View File

@ -39,7 +39,7 @@ limitations under the License.
#include "xla/ffi/execution_context.h"
#include "xla/ffi/execution_state.h"
#include "xla/ffi/ffi_api.h"
#include "xla/ffi/type_id_registry.h"
#include "xla/ffi/type_registry.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/stream.h"
#include "xla/tsl/concurrency/async_value_ref.h"
@ -1102,7 +1102,7 @@ TEST(FfiTest, Metadata) {
EXPECT_EQ(metadata.api_version.major_version, XLA_FFI_API_MAJOR);
EXPECT_EQ(metadata.api_version.minor_version, XLA_FFI_API_MINOR);
TypeIdRegistry::TypeId type_id = TypeIdRegistry::GetTypeId<StrState>();
TypeRegistry::TypeId type_id = TypeRegistry::GetTypeId<StrState>();
EXPECT_EQ(metadata.state_type_id.type_id, type_id);
}

View File

@ -1,96 +0,0 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/ffi/type_id_registry.h"
#include <atomic>
#include <cstdint>
#include <string>
#include "absl/algorithm/container.h"
#include "absl/base/attributes.h"
#include "absl/base/const_init.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "xla/util.h"
namespace xla::ffi {
ABSL_CONST_INIT absl::Mutex type_registry_mutex(absl::kConstInit);
using ExternalTypeIdRegistry =
absl::flat_hash_map<std::string, TypeIdRegistry::TypeId>;
static ExternalTypeIdRegistry& StaticExternalTypeIdRegistry() {
static auto* const registry = new ExternalTypeIdRegistry();
return *registry;
}
TypeIdRegistry::TypeId TypeIdRegistry::GetNextInternalTypeId() {
static auto* counter = new std::atomic<int64_t>(1);
return TypeId(counter->fetch_add(1));
}
TypeIdRegistry::TypeId TypeIdRegistry::GetNextExternalTypeId() {
static auto* counter = new std::atomic<int64_t>(1);
return TypeId(counter->fetch_add(1));
}
absl::StatusOr<TypeIdRegistry::TypeId> TypeIdRegistry::AssignExternalTypeId(
absl::string_view name) {
absl::MutexLock lock(type_registry_mutex);
auto& registry = StaticExternalTypeIdRegistry();
// Try to emplace with unknow type id and fill it with real type id only if we
// successfully acquired an entry for a given name.
auto emplaced = registry.emplace(name, kUnknownTypeId);
if (!emplaced.second) {
return Internal("Type name %s already registered with type id %d", name,
emplaced.first->second.value());
}
// Returns true if the registry contains an entry with a given type id.
auto type_id_is_in_use = [&registry](TypeId type_id) {
return absl::c_any_of(registry,
[&](const auto& e) { return e.second == type_id; });
};
// Create a new type id that is not already in use.
TypeId type_id = GetNextExternalTypeId();
while (type_id_is_in_use(type_id)) {
type_id = GetNextExternalTypeId();
}
return emplaced.first->second = type_id;
}
absl::Status TypeIdRegistry::RegisterExternalTypeId(absl::string_view name,
TypeId type_id) {
absl::MutexLock lock(type_registry_mutex);
auto& registry = StaticExternalTypeIdRegistry();
auto emplaced = registry.emplace(name, type_id);
if (!emplaced.second && emplaced.first->second != type_id) {
return Internal("Type name %s already registered with type id %d vs %d)",
name, emplaced.first->second.value(), type_id.value());
}
return absl::OkStatus();
}
} // namespace xla::ffi

View File

@ -1,63 +0,0 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/ffi/type_id_registry.h"
#include <cstdint>
#include <limits>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
namespace xla::ffi {
namespace {
using ::testing::HasSubstr;
TEST(TypeIdRegistryTest, RegisterExternalTypeId) {
TF_ASSERT_OK_AND_ASSIGN(auto type_id,
TypeIdRegistry::AssignExternalTypeId("foo"));
EXPECT_GE(type_id.value(), 0);
auto duplicate_type_id = TypeIdRegistry::AssignExternalTypeId("foo");
EXPECT_THAT(duplicate_type_id.status().message(),
HasSubstr("Type name foo already registered with type id"));
// It's ok to register the same type with same type id.
TF_ASSERT_OK(TypeIdRegistry::RegisterExternalTypeId("foo", type_id));
// It's an error to register the same type with a different type id.
auto wrong_type_id = TypeIdRegistry::RegisterExternalTypeId(
"foo", TypeIdRegistry::TypeId(std::numeric_limits<int64_t>::max()));
EXPECT_THAT(wrong_type_id.message(),
HasSubstr("Type name foo already registered with type id"));
// It's ok to register a new type with a user-provided type id.
TF_ASSERT_OK(TypeIdRegistry::RegisterExternalTypeId(
"bar", TypeIdRegistry::TypeId(std::numeric_limits<int64_t>::max())));
}
TEST(TypeIdRegistryTest, RegisterInternalTypeId) {
auto int32_type_id = TypeIdRegistry::GetTypeId<int32_t>();
auto int64_type_id = TypeIdRegistry::GetTypeId<int64_t>();
EXPECT_NE(int32_type_id, int64_type_id);
}
} // namespace
} // namespace xla::ffi

134
third_party/xla/xla/ffi/type_registry.cc vendored Normal file
View File

@ -0,0 +1,134 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/ffi/type_registry.h"
#include <atomic>
#include <cstdint>
#include <string>
#include "absl/algorithm/container.h"
#include "absl/base/attributes.h"
#include "absl/base/const_init.h"
#include "absl/base/no_destructor.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "xla/util.h"
namespace xla::ffi {
namespace {
struct TypeRegistration {
TypeRegistry::TypeId type_id;
TypeRegistry::TypeInfo type_info;
};
using ExternalTypeRegistry = absl::flat_hash_map<std::string, TypeRegistration>;
} // namespace
ABSL_CONST_INIT absl::Mutex type_registry_mutex(absl::kConstInit);
static ExternalTypeRegistry& StaticExternalTypeRegistry() {
static absl::NoDestructor<ExternalTypeRegistry> registry;
return *registry;
}
TypeRegistry::TypeId TypeRegistry::GetNextInternalTypeId() {
static auto* counter = new std::atomic<int64_t>(1);
return TypeId(counter->fetch_add(1));
}
TypeRegistry::TypeId TypeRegistry::GetNextExternalTypeId() {
static auto* counter = new std::atomic<int64_t>(1);
return TypeId(counter->fetch_add(1));
}
absl::StatusOr<TypeRegistry::TypeId> TypeRegistry::AssignExternalTypeId(
absl::string_view name, TypeInfo type_info) {
VLOG(3) << absl::StrFormat("Assign external type id: name=%s", name);
absl::MutexLock lock(type_registry_mutex);
auto& registry = StaticExternalTypeRegistry();
// Try to emplace with unknow type id and fill it with real type id only if we
// successfully acquired an entry for a given name.
auto emplaced =
registry.emplace(name, TypeRegistration{kUnknownTypeId, type_info});
if (!emplaced.second) {
return Internal("Type name %s already registered with type id %d", name,
emplaced.first->second.type_id.value());
}
// Returns true if the registry contains an entry with a given type id.
auto type_id_is_in_use = [&registry](TypeId type_id) {
return absl::c_any_of(
registry, [&](const auto& e) { return e.second.type_id == type_id; });
};
// Create a new type id that is not already in use.
TypeId type_id = GetNextExternalTypeId();
while (type_id_is_in_use(type_id)) {
type_id = GetNextExternalTypeId();
}
VLOG(3) << absl::StrFormat("Assigned external type id: name=%s type_id=%d",
name, type_id.value());
return emplaced.first->second.type_id = type_id;
}
absl::Status TypeRegistry::RegisterExternalTypeId(absl::string_view name,
TypeId type_id,
TypeInfo type_info) {
VLOG(3) << absl::StrFormat("Register external type id: name=%s type_id=%d",
name, type_id.value());
absl::MutexLock lock(type_registry_mutex);
auto& registry = StaticExternalTypeRegistry();
auto emplaced = registry.emplace(name, TypeRegistration{type_id, type_info});
if (!emplaced.second && emplaced.first->second.type_id != type_id) {
return Internal("Type name %s already registered with type id %d vs %d)",
name, emplaced.first->second.type_id.value(),
type_id.value());
}
return absl::OkStatus();
}
absl::StatusOr<TypeRegistry::TypeInfo> TypeRegistry::GetExternalTypeInfo(
TypeId type_id) {
absl::MutexLock lock(type_registry_mutex);
auto& registry = StaticExternalTypeRegistry();
auto it = absl::c_find_if(registry, [&](const auto& kv) {
auto& [name, registration] = kv;
return registration.type_id == type_id;
});
if (it == registry.end()) {
return Internal("Type id %d is not registered with a static registry",
type_id.value());
}
return it->second.type_info;
}
} // namespace xla::ffi

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_FFI_TYPE_ID_REGISTRY_H_
#define XLA_FFI_TYPE_ID_REGISTRY_H_
#ifndef XLA_FFI_TYPE_REGISTRY_H_
#define XLA_FFI_TYPE_REGISTRY_H_
#include <cstdint>
@ -22,6 +22,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/tsl/lib/gtl/int_type.h"
#include "xla/tsl/util/safe_reinterpret_cast.h"
namespace xla::ffi {
@ -43,26 +44,50 @@ namespace xla::ffi {
//
// 2. Internal type id. When FFI handler defined in the same binary we rely
// on a global static registry to automatically assign type ids.
class TypeIdRegistry {
//
// TypeInfo defines a set of functions that allow XLA runtime to manipulate
// external types. For user data, that is forwarded to FFI handlers, they all
// can be `nullptr` as XLA runtime doesn't manage their lifetime. For stateful
// handlers, XLA runtime at least must know how to destroy the state when XLA
// executable is destroyed.
class TypeRegistry {
public:
// Unique (within a process) identifier for a type.
TSL_LIB_GTL_DEFINE_INT_TYPE(TypeId, int64_t);
static constexpr TypeId kUnknownTypeId = TypeId(0);
// Pointers to functions that allow XLA runtime to manipulate external types.
struct TypeInfo {
using Deleter = void (*)(void*);
Deleter deleter = nullptr;
};
// Assigns a unique type id to an external type with a given name. Returns an
// error if a type with a given name is already registered in the process.
static absl::StatusOr<TypeId> AssignExternalTypeId(absl::string_view name);
static absl::StatusOr<TypeId> AssignExternalTypeId(absl::string_view name,
TypeInfo type_info);
// Registers external type with a given name and type id. Type id is provided
// by the caller, and must be unique. Returns an error if a type with a given
// name is already registered with a different type id.
static absl::Status RegisterExternalTypeId(absl::string_view name,
TypeId type_id);
TypeId type_id,
TypeInfo type_info);
// Returns type info for a given external type id. Returns an error if type
// id is not registered.
static absl::StatusOr<TypeInfo> GetExternalTypeInfo(TypeId type_id);
// Returns a type id for a given type. For internal type ids only.
template <typename T>
static TypeId GetTypeId();
// Returns type info for a given type id. For internal type ids only.
template <typename T>
static TypeInfo GetTypeInfo();
private:
// We never mix external and internal type ids, so we can use different type
// id spaces to assign unique ids to each type.
@ -71,11 +96,18 @@ class TypeIdRegistry {
};
template <typename T>
TypeIdRegistry::TypeId TypeIdRegistry::GetTypeId() {
TypeRegistry::TypeId TypeRegistry::GetTypeId() {
static const TypeId id = GetNextInternalTypeId();
return id;
}
template <typename T>
TypeRegistry::TypeInfo TypeRegistry::GetTypeInfo() {
return TypeInfo{
[](void* state) { delete tsl::safe_reinterpret_cast<T*>(state); },
};
}
} // namespace xla::ffi
#endif // XLA_FFI_TYPE_ID_REGISTRY_H_
#endif // XLA_FFI_TYPE_REGISTRY_H_

View File

@ -0,0 +1,85 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/ffi/type_registry.h"
#include <cstdint>
#include <limits>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
namespace xla::ffi {
namespace {
using ::testing::HasSubstr;
TEST(TypeRegistryTest, RegisterExternalTypeId) {
TypeRegistry::TypeInfo type_info = {+[](void* state) {}};
TF_ASSERT_OK_AND_ASSIGN(auto foo_id,
TypeRegistry::AssignExternalTypeId("foo", type_info));
EXPECT_GE(foo_id.value(), 0);
auto duplicate_foo_id = TypeRegistry::AssignExternalTypeId("foo", type_info);
EXPECT_THAT(duplicate_foo_id.status().message(),
HasSubstr("Type name foo already registered with type id"));
// It's ok to register the same type with same type id.
TF_ASSERT_OK(TypeRegistry::RegisterExternalTypeId("foo", foo_id, type_info));
// It's an error to register the same type with a different type id.
auto wrong_foo_id = TypeRegistry::RegisterExternalTypeId(
"foo", TypeRegistry::TypeId(std::numeric_limits<int64_t>::max()),
type_info);
EXPECT_THAT(wrong_foo_id.message(),
HasSubstr("Type name foo already registered with type id"));
// Registered type has a correct type info.
TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeInfo foo_info,
TypeRegistry::GetExternalTypeInfo(foo_id));
EXPECT_EQ(foo_info.deleter, type_info.deleter);
// It's ok to register a new type with a user-provided type id.
auto bar_id = TypeRegistry::TypeId(std::numeric_limits<int64_t>::max());
TF_ASSERT_OK(TypeRegistry::RegisterExternalTypeId(
"bar", TypeRegistry::TypeId(std::numeric_limits<int64_t>::max()),
type_info));
// And a new type has a correct type info.
TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeInfo bar_info,
TypeRegistry::GetExternalTypeInfo(bar_id));
EXPECT_EQ(bar_info.deleter, type_info.deleter);
}
TEST(TypeRegistryTest, RegisterInternalTypeId) {
auto int32_type_id = TypeRegistry::GetTypeId<int32_t>();
auto int64_type_id = TypeRegistry::GetTypeId<int64_t>();
EXPECT_NE(int32_type_id, int64_type_id);
}
TEST(TypeRegistryTest, InternalTypeInfo) {
int32_t* ptr = new int32_t{42};
TypeRegistry::TypeInfo type_info = TypeRegistry::GetTypeInfo<int32_t>();
type_info.deleter(ptr);
}
} // namespace
} // namespace xla::ffi

View File

@ -82,9 +82,10 @@ cc_library(
":pjrt_c_api_hdrs",
":pjrt_c_api_helpers",
":pjrt_c_api_wrapper_impl",
"//xla/ffi",
"//xla/ffi:execution_context",
"//xla/ffi:ffi_api",
"//xla/ffi:type_id_registry",
"//xla/ffi:type_registry",
"//xla/ffi/api:c_api",
"//xla/ffi/api:ffi",
"@com_google_absl//absl/status",
@ -567,7 +568,7 @@ xla_test(
"//xla/client:client_library",
"//xla/ffi:execution_context",
"//xla/ffi:ffi_api",
"//xla/ffi:type_id_registry",
"//xla/ffi:type_registry",
"//xla/ffi/api:ffi",
"//xla/pjrt:pjrt_common",
"//xla/pjrt:pjrt_compiler",

View File

@ -30,7 +30,24 @@ extern "C" {
// and GPU backends it gives access to the XLA FFI internals.
//
// See: https://en.wikipedia.org/wiki/Foreign_function_interface
#define PJRT_API_FFI_EXTENSION_VERSION 2
#define PJRT_API_FFI_EXTENSION_VERSION 3
struct PJRT_FFI_Type_Info {
void (*deleter)(void* object);
void (*serialize)(); // placeholder for future use
void (*deserialize)(); // placeholder for future use
};
struct PJRT_FFI_Type_Register_Args {
size_t struct_size;
PJRT_Extension_Base* extension_start;
const char* type_name;
size_t type_name_size;
int64_t type_id; // in-out
PJRT_FFI_Type_Info* type_info;
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Type_Register_Args, type_info);
struct PJRT_FFI_TypeID_Register_Args {
size_t struct_size;
@ -46,6 +63,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_TypeID_Register_Args, type_id);
// XLA will assign a unique type id to it and return via out argument, otherwise
// it will verify that user-provided type id matches previously registered type
// id for the given type name.
typedef PJRT_Error* PJRT_FFI_Type_Register(PJRT_FFI_Type_Register_Args* args);
typedef PJRT_Error* PJRT_FFI_TypeID_Register(
PJRT_FFI_TypeID_Register_Args* args);
@ -94,8 +112,9 @@ typedef struct PJRT_FFI_Extension {
PJRT_FFI_TypeID_Register* type_id_register;
PJRT_FFI_UserData_Add* user_data_add;
PJRT_FFI_Register_Handler* register_handler;
PJRT_FFI_Type_Register* type_register;
} PJRT_FFI;
PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, register_handler);
PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, type_register);
#ifdef __cplusplus
}

View File

@ -19,8 +19,9 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/execution_context.h"
#include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h"
#include "xla/ffi/type_id_registry.h"
#include "xla/ffi/type_registry.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
@ -35,20 +36,48 @@ static PJRT_Error* PJRT_FFI_TypeID_Register(
PJRT_FFI_TypeID_Register_Args_STRUCT_SIZE, args->struct_size));
absl::string_view type_name(args->type_name, args->type_name_size);
xla::ffi::TypeIdRegistry::TypeId type_id(args->type_id);
xla::ffi::TypeRegistry::TypeId type_id(args->type_id);
if (type_id == xla::ffi::TypeIdRegistry::kUnknownTypeId) {
if (type_id == xla::ffi::TypeRegistry::kUnknownTypeId) {
// If type_id is unknown, we are registering a new type and XLA will assign
// a unique type id to it.
PJRT_ASSIGN_OR_RETURN(
auto assigned_type_id,
xla::ffi::TypeIdRegistry::AssignExternalTypeId(type_name));
xla::ffi::TypeRegistry::AssignExternalTypeId(type_name, {}));
args->type_id = assigned_type_id.value();
} else {
// If type_id is set, we are relying on the caller-provided unique type id.
PJRT_RETURN_IF_ERROR(
xla::ffi::TypeIdRegistry::RegisterExternalTypeId(type_name, type_id));
xla::ffi::TypeRegistry::RegisterExternalTypeId(type_name, type_id, {}));
}
return nullptr;
}
static PJRT_Error* PJRT_FFI_Type_Register(PJRT_FFI_Type_Register_Args* args) {
PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
"PJRT_FFI_Type_Register_Args", PJRT_FFI_Type_Register_Args_STRUCT_SIZE,
args->struct_size));
absl::string_view type_name(args->type_name, args->type_name_size);
xla::ffi::TypeRegistry::TypeId type_id(args->type_id);
xla::ffi::TypeRegistry::TypeInfo type_info = {
args->type_info->deleter,
};
if (type_id == xla::ffi::TypeRegistry::kUnknownTypeId) {
// If type_id is unknown, we are registering a new type and XLA will assign
// a unique type id to it.
PJRT_ASSIGN_OR_RETURN(
auto assigned_type_id,
xla::ffi::TypeRegistry::AssignExternalTypeId(type_name, type_info));
args->type_id = assigned_type_id.value();
} else {
// If type_id is set, we are relying on the caller-provided unique type id.
PJRT_RETURN_IF_ERROR(xla::ffi::TypeRegistry::RegisterExternalTypeId(
type_name, type_id, type_info));
}
return nullptr;
@ -64,7 +93,7 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) {
"PJRT FFI extension requires execute context to be not nullptr")};
}
xla::ffi::TypeIdRegistry::TypeId type_id(args->user_data.type_id);
xla::ffi::TypeRegistry::TypeId type_id(args->user_data.type_id);
PJRT_RETURN_IF_ERROR(args->context->execute_context->ffi_context().Insert(
type_id, args->user_data.data, args->user_data.deleter));
return nullptr;
@ -102,6 +131,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
/*type_id_register=*/PJRT_FFI_TypeID_Register,
/*user_data_add=*/PJRT_FFI_UserData_Add,
/*register_handler=*/PJRT_FFI_Register_Handler,
/*type_register=*/PJRT_FFI_Type_Register,
};
}

View File

@ -44,7 +44,7 @@ limitations under the License.
#include "xla/ffi/api/ffi.h"
#include "xla/ffi/execution_context.h"
#include "xla/ffi/ffi_api.h"
#include "xla/ffi/type_id_registry.h"
#include "xla/ffi/type_registry.h"
#include "xla/future.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
@ -366,7 +366,7 @@ TEST_F(PjrtCApiGpuTest, CreateAndDestroyExecuteContext) {
TF_ASSERT_OK_AND_ASSIGN(
auto lookup_user_data,
create_arg.context->execute_context->ffi_context().Lookup(
xla::ffi::TypeIdRegistry::TypeId(42)));
xla::ffi::TypeRegistry::TypeId(42)));
EXPECT_EQ(lookup_user_data, &string_data);
PJRT_ExecuteContext_Destroy_Args destroy_args;

View File

@ -324,7 +324,7 @@ cc_library(
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/ffi:execution_context",
"//xla/ffi:type_id_registry",
"//xla/ffi:type_registry",
"//xla/hlo/ir:hlo",
"//xla/hlo/translate/mhlo_to_hlo:type_to_shape",
"//xla/pjrt:host_callback",

View File

@ -35,7 +35,7 @@ limitations under the License.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "xla/ffi/execution_context.h"
#include "xla/ffi/type_id_registry.h"
#include "xla/ffi/type_registry.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h"
#include "xla/layout.h"
@ -768,8 +768,7 @@ PjRtLoadedExecutable::Execute(absl::Span<ArrayRef> args,
}
ffi_callbacks->callbacks = callbacks->data();
ffi_callbacks->num_callbacks = callbacks->size();
auto type_id = xla::ffi::TypeIdRegistry::TypeId(
xla::FfiLoadedHostCallbacks::id.type_id);
ffi::TypeRegistry::TypeId type_id(FfiLoadedHostCallbacks::id.type_id);
CHECK_OK(context->ffi_context().Insert(type_id, ffi_callbacks.get()));
opts.context = context.get();
}

View File

@ -351,7 +351,7 @@ NB_MODULE(py_hlo_multihost_runner, m) {
m.def("custom_call_targets", GetRegisteredCustomCallTargets,
nb::arg("platform"));
m.def(
"register_custom_type_id",
"register_custom_type",
[](absl::string_view type_name, nb::object type_id) {
xla::ThrowIfError(RegisterCustomTypeId(type_name, type_id));
},