diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD index dff514d6bb5..1deca4bd865 100644 --- a/third_party/xla/xla/ffi/BUILD +++ b/third_party/xla/xla/ffi/BUILD @@ -278,6 +278,7 @@ xla_cc_test( "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", ], diff --git a/third_party/xla/xla/ffi/execution_state.cc b/third_party/xla/xla/ffi/execution_state.cc index 87916729ef7..8db2bf63241 100644 --- a/third_party/xla/xla/ffi/execution_state.cc +++ b/third_party/xla/xla/ffi/execution_state.cc @@ -35,8 +35,7 @@ ExecutionState::~ExecutionState() { } absl::Status ExecutionState::Set(TypeId type_id, void* state) { - TF_ASSIGN_OR_RETURN(auto type_info, - TypeRegistry::GetExternalTypeInfo(type_id)); + TF_ASSIGN_OR_RETURN(auto type_info, TypeRegistry::GetTypeInfo(type_id)); if (type_info.deleter == nullptr) { return InvalidArgument( "Type id %d does not have a registered type info with a deleter", diff --git a/third_party/xla/xla/ffi/type_registry.cc b/third_party/xla/xla/ffi/type_registry.cc index e6b8ea02898..fa45bb8dff8 100644 --- a/third_party/xla/xla/ffi/type_registry.cc +++ b/third_party/xla/xla/ffi/type_registry.cc @@ -40,24 +40,19 @@ struct TypeRegistration { TypeRegistry::TypeInfo type_info; }; -using ExternalTypeRegistry = absl::flat_hash_map; +using TypeRegistryMap = absl::flat_hash_map; } // namespace ABSL_CONST_INIT absl::Mutex type_registry_mutex(absl::kConstInit); -static ExternalTypeRegistry& StaticExternalTypeRegistry() { - static absl::NoDestructor registry; +static TypeRegistryMap& StaticTypeRegistryMap() { + static absl::NoDestructor registry; return *registry; } -TypeRegistry::TypeId TypeRegistry::GetNextInternalTypeId() { - static auto* counter = new std::atomic(1); - return TypeId(counter->fetch_add(1)); -} - -TypeRegistry::TypeId TypeRegistry::GetNextExternalTypeId() { - static auto* counter = new std::atomic(1); +TypeRegistry::TypeId TypeRegistry::GetNextTypeId() { + static absl::NoDestructor> counter(1); return TypeId(counter->fetch_add(1)); } @@ -66,7 +61,7 @@ absl::StatusOr TypeRegistry::AssignExternalTypeId( VLOG(3) << absl::StrFormat("Assign external type id: name=%s", name); absl::MutexLock lock(type_registry_mutex); - auto& registry = StaticExternalTypeRegistry(); + auto& registry = StaticTypeRegistryMap(); // Try to emplace with unknow type id and fill it with real type id only if we // successfully acquired an entry for a given name. @@ -84,9 +79,9 @@ absl::StatusOr TypeRegistry::AssignExternalTypeId( }; // Create a new type id that is not already in use. - TypeId type_id = GetNextExternalTypeId(); + TypeId type_id = GetNextTypeId(); while (type_id_is_in_use(type_id)) { - type_id = GetNextExternalTypeId(); + type_id = GetNextTypeId(); } VLOG(3) << absl::StrFormat("Assigned external type id: name=%s type_id=%d", @@ -101,7 +96,7 @@ absl::Status TypeRegistry::RegisterExternalTypeId(absl::string_view name, name, type_id.value()); absl::MutexLock lock(type_registry_mutex); - auto& registry = StaticExternalTypeRegistry(); + auto& registry = StaticTypeRegistryMap(); auto emplaced = registry.emplace(name, TypeRegistration{type_id, type_info}); if (!emplaced.second && emplaced.first->second.type_id != type_id) { @@ -113,10 +108,22 @@ absl::Status TypeRegistry::RegisterExternalTypeId(absl::string_view name, return absl::OkStatus(); } -absl::StatusOr TypeRegistry::GetExternalTypeInfo( +absl::StatusOr TypeRegistry::GetTypeId( + absl::string_view name) { + absl::MutexLock lock(type_registry_mutex); + auto& registry = StaticTypeRegistryMap(); + + auto it = registry.find(name); + if (it == registry.end()) { + return Internal("Type name %s is not registered", name); + } + return it->second.type_id; +} + +absl::StatusOr TypeRegistry::GetTypeInfo( TypeId type_id) { absl::MutexLock lock(type_registry_mutex); - auto& registry = StaticExternalTypeRegistry(); + auto& registry = StaticTypeRegistryMap(); auto it = absl::c_find_if(registry, [&](const auto& kv) { auto& [name, registration] = kv; diff --git a/third_party/xla/xla/ffi/type_registry.h b/third_party/xla/xla/ffi/type_registry.h index 8c61214ddc0..8fabc15f4ed 100644 --- a/third_party/xla/xla/ffi/type_registry.h +++ b/third_party/xla/xla/ffi/type_registry.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/base/no_destructor.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -64,6 +65,14 @@ class TypeRegistry { Deleter deleter = nullptr; }; + // Returns type id for a given type name. Returns an error if type is + // not registered. Works for both external and internal type ids. + static absl::StatusOr GetTypeId(absl::string_view name); + + // Returns type info for a given type id. Returns an error if type id is not + // registered. Works for both external and internal type ids. + static absl::StatusOr GetTypeInfo(TypeId type_id); + // Assigns a unique type id to an external type with a given name. Returns an // error if a type with a given name is already registered in the process. static absl::StatusOr AssignExternalTypeId(absl::string_view name, @@ -76,9 +85,9 @@ class TypeRegistry { TypeId type_id, TypeInfo type_info); - // Returns type info for a given external type id. Returns an error if type - // id is not registered. - static absl::StatusOr GetExternalTypeInfo(TypeId type_id); + // Returns a type name for a given type. For internal type ids only. + template + static absl::string_view GetTypeName(); // Returns a type id for a given type. For internal type ids only. template @@ -89,16 +98,21 @@ class TypeRegistry { static TypeInfo GetTypeInfo(); private: - // We never mix external and internal type ids, so we can use different type - // id spaces to assign unique ids to each type. - static TypeId GetNextInternalTypeId(); - static TypeId GetNextExternalTypeId(); + static TypeId GetNextTypeId(); }; +template +absl::string_view TypeRegistry::GetTypeName() { + return typeid(T).name(); +} + template TypeRegistry::TypeId TypeRegistry::GetTypeId() { - static const TypeId id = GetNextInternalTypeId(); - return id; + // We always register internal types in the static type registry, because we + // want to be able to lookup them by name. + static const absl::NoDestructor> id( + AssignExternalTypeId(GetTypeName(), GetTypeInfo())); + return **id; } template diff --git a/third_party/xla/xla/ffi/type_registry_test.cc b/third_party/xla/xla/ffi/type_registry_test.cc index c39110b9389..936b4e40231 100644 --- a/third_party/xla/xla/ffi/type_registry_test.cc +++ b/third_party/xla/xla/ffi/type_registry_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" @@ -53,7 +54,7 @@ TEST(TypeRegistryTest, RegisterExternalTypeId) { // Registered type has a correct type info. TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeInfo foo_info, - TypeRegistry::GetExternalTypeInfo(foo_id)); + TypeRegistry::GetTypeInfo(foo_id)); EXPECT_EQ(foo_info.deleter, type_info.deleter); // It's ok to register a new type with a user-provided type id. @@ -64,14 +65,19 @@ TEST(TypeRegistryTest, RegisterExternalTypeId) { // And a new type has a correct type info. TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeInfo bar_info, - TypeRegistry::GetExternalTypeInfo(bar_id)); + TypeRegistry::GetTypeInfo(bar_id)); EXPECT_EQ(bar_info.deleter, type_info.deleter); } TEST(TypeRegistryTest, RegisterInternalTypeId) { - auto int32_type_id = TypeRegistry::GetTypeId(); - auto int64_type_id = TypeRegistry::GetTypeId(); - EXPECT_NE(int32_type_id, int64_type_id); + auto int32_id = TypeRegistry::GetTypeId(); + auto int64_id = TypeRegistry::GetTypeId(); + EXPECT_NE(int32_id, int64_id); + + absl::string_view int32_name = TypeRegistry::GetTypeName(); + absl::string_view int64_name = TypeRegistry::GetTypeName(); + EXPECT_EQ(*TypeRegistry::GetTypeId(int32_name), int32_id); + EXPECT_EQ(*TypeRegistry::GetTypeId(int64_name), int64_id); } TEST(TypeRegistryTest, InternalTypeInfo) {