mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[xla:ffi] Use same id sequence for internal and external types
Add an API to lookup type id and info by type name. We can't rely on type ids for serialization, as they are not stable and assigned at run time depending on the type registration order. Type names on the other hand must be stable. PiperOrigin-RevId: 824512487
This commit is contained in:
parent
0e809d4bc8
commit
5dfa57fd92
1
third_party/xla/xla/ffi/BUILD
vendored
1
third_party/xla/xla/ffi/BUILD
vendored
|
|
@ -278,6 +278,7 @@ xla_cc_test(
|
|||
"//xla/tsl/platform:statusor",
|
||||
"//xla/tsl/platform:test",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
"@com_google_googletest//:gtest",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
|
|
|
|||
3
third_party/xla/xla/ffi/execution_state.cc
vendored
3
third_party/xla/xla/ffi/execution_state.cc
vendored
|
|
@ -35,8 +35,7 @@ ExecutionState::~ExecutionState() {
|
|||
}
|
||||
|
||||
absl::Status ExecutionState::Set(TypeId type_id, void* state) {
|
||||
TF_ASSIGN_OR_RETURN(auto type_info,
|
||||
TypeRegistry::GetExternalTypeInfo(type_id));
|
||||
TF_ASSIGN_OR_RETURN(auto type_info, TypeRegistry::GetTypeInfo(type_id));
|
||||
if (type_info.deleter == nullptr) {
|
||||
return InvalidArgument(
|
||||
"Type id %d does not have a registered type info with a deleter",
|
||||
|
|
|
|||
39
third_party/xla/xla/ffi/type_registry.cc
vendored
39
third_party/xla/xla/ffi/type_registry.cc
vendored
|
|
@ -40,24 +40,19 @@ struct TypeRegistration {
|
|||
TypeRegistry::TypeInfo type_info;
|
||||
};
|
||||
|
||||
using ExternalTypeRegistry = absl::flat_hash_map<std::string, TypeRegistration>;
|
||||
using TypeRegistryMap = absl::flat_hash_map<std::string, TypeRegistration>;
|
||||
|
||||
} // namespace
|
||||
|
||||
ABSL_CONST_INIT absl::Mutex type_registry_mutex(absl::kConstInit);
|
||||
|
||||
static ExternalTypeRegistry& StaticExternalTypeRegistry() {
|
||||
static absl::NoDestructor<ExternalTypeRegistry> registry;
|
||||
static TypeRegistryMap& StaticTypeRegistryMap() {
|
||||
static absl::NoDestructor<TypeRegistryMap> registry;
|
||||
return *registry;
|
||||
}
|
||||
|
||||
TypeRegistry::TypeId TypeRegistry::GetNextInternalTypeId() {
|
||||
static auto* counter = new std::atomic<int64_t>(1);
|
||||
return TypeId(counter->fetch_add(1));
|
||||
}
|
||||
|
||||
TypeRegistry::TypeId TypeRegistry::GetNextExternalTypeId() {
|
||||
static auto* counter = new std::atomic<int64_t>(1);
|
||||
TypeRegistry::TypeId TypeRegistry::GetNextTypeId() {
|
||||
static absl::NoDestructor<std::atomic<int64_t>> counter(1);
|
||||
return TypeId(counter->fetch_add(1));
|
||||
}
|
||||
|
||||
|
|
@ -66,7 +61,7 @@ absl::StatusOr<TypeRegistry::TypeId> TypeRegistry::AssignExternalTypeId(
|
|||
VLOG(3) << absl::StrFormat("Assign external type id: name=%s", name);
|
||||
|
||||
absl::MutexLock lock(type_registry_mutex);
|
||||
auto& registry = StaticExternalTypeRegistry();
|
||||
auto& registry = StaticTypeRegistryMap();
|
||||
|
||||
// Try to emplace with unknow type id and fill it with real type id only if we
|
||||
// successfully acquired an entry for a given name.
|
||||
|
|
@ -84,9 +79,9 @@ absl::StatusOr<TypeRegistry::TypeId> TypeRegistry::AssignExternalTypeId(
|
|||
};
|
||||
|
||||
// Create a new type id that is not already in use.
|
||||
TypeId type_id = GetNextExternalTypeId();
|
||||
TypeId type_id = GetNextTypeId();
|
||||
while (type_id_is_in_use(type_id)) {
|
||||
type_id = GetNextExternalTypeId();
|
||||
type_id = GetNextTypeId();
|
||||
}
|
||||
|
||||
VLOG(3) << absl::StrFormat("Assigned external type id: name=%s type_id=%d",
|
||||
|
|
@ -101,7 +96,7 @@ absl::Status TypeRegistry::RegisterExternalTypeId(absl::string_view name,
|
|||
name, type_id.value());
|
||||
|
||||
absl::MutexLock lock(type_registry_mutex);
|
||||
auto& registry = StaticExternalTypeRegistry();
|
||||
auto& registry = StaticTypeRegistryMap();
|
||||
|
||||
auto emplaced = registry.emplace(name, TypeRegistration{type_id, type_info});
|
||||
if (!emplaced.second && emplaced.first->second.type_id != type_id) {
|
||||
|
|
@ -113,10 +108,22 @@ absl::Status TypeRegistry::RegisterExternalTypeId(absl::string_view name,
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::StatusOr<TypeRegistry::TypeInfo> TypeRegistry::GetExternalTypeInfo(
|
||||
absl::StatusOr<TypeRegistry::TypeId> TypeRegistry::GetTypeId(
|
||||
absl::string_view name) {
|
||||
absl::MutexLock lock(type_registry_mutex);
|
||||
auto& registry = StaticTypeRegistryMap();
|
||||
|
||||
auto it = registry.find(name);
|
||||
if (it == registry.end()) {
|
||||
return Internal("Type name %s is not registered", name);
|
||||
}
|
||||
return it->second.type_id;
|
||||
}
|
||||
|
||||
absl::StatusOr<TypeRegistry::TypeInfo> TypeRegistry::GetTypeInfo(
|
||||
TypeId type_id) {
|
||||
absl::MutexLock lock(type_registry_mutex);
|
||||
auto& registry = StaticExternalTypeRegistry();
|
||||
auto& registry = StaticTypeRegistryMap();
|
||||
|
||||
auto it = absl::c_find_if(registry, [&](const auto& kv) {
|
||||
auto& [name, registration] = kv;
|
||||
|
|
|
|||
32
third_party/xla/xla/ffi/type_registry.h
vendored
32
third_party/xla/xla/ffi/type_registry.h
vendored
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||
|
||||
#include <cstdint>
|
||||
|
||||
#include "absl/base/no_destructor.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
|
|
@ -64,6 +65,14 @@ class TypeRegistry {
|
|||
Deleter deleter = nullptr;
|
||||
};
|
||||
|
||||
// Returns type id for a given type name. Returns an error if type is
|
||||
// not registered. Works for both external and internal type ids.
|
||||
static absl::StatusOr<TypeId> GetTypeId(absl::string_view name);
|
||||
|
||||
// Returns type info for a given type id. Returns an error if type id is not
|
||||
// registered. Works for both external and internal type ids.
|
||||
static absl::StatusOr<TypeInfo> GetTypeInfo(TypeId type_id);
|
||||
|
||||
// Assigns a unique type id to an external type with a given name. Returns an
|
||||
// error if a type with a given name is already registered in the process.
|
||||
static absl::StatusOr<TypeId> AssignExternalTypeId(absl::string_view name,
|
||||
|
|
@ -76,9 +85,9 @@ class TypeRegistry {
|
|||
TypeId type_id,
|
||||
TypeInfo type_info);
|
||||
|
||||
// Returns type info for a given external type id. Returns an error if type
|
||||
// id is not registered.
|
||||
static absl::StatusOr<TypeInfo> GetExternalTypeInfo(TypeId type_id);
|
||||
// Returns a type name for a given type. For internal type ids only.
|
||||
template <typename T>
|
||||
static absl::string_view GetTypeName();
|
||||
|
||||
// Returns a type id for a given type. For internal type ids only.
|
||||
template <typename T>
|
||||
|
|
@ -89,16 +98,21 @@ class TypeRegistry {
|
|||
static TypeInfo GetTypeInfo();
|
||||
|
||||
private:
|
||||
// We never mix external and internal type ids, so we can use different type
|
||||
// id spaces to assign unique ids to each type.
|
||||
static TypeId GetNextInternalTypeId();
|
||||
static TypeId GetNextExternalTypeId();
|
||||
static TypeId GetNextTypeId();
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
absl::string_view TypeRegistry::GetTypeName() {
|
||||
return typeid(T).name();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
TypeRegistry::TypeId TypeRegistry::GetTypeId() {
|
||||
static const TypeId id = GetNextInternalTypeId();
|
||||
return id;
|
||||
// We always register internal types in the static type registry, because we
|
||||
// want to be able to lookup them by name.
|
||||
static const absl::NoDestructor<absl::StatusOr<TypeId>> id(
|
||||
AssignExternalTypeId(GetTypeName<T>(), GetTypeInfo<T>()));
|
||||
return **id;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
16
third_party/xla/xla/ffi/type_registry_test.cc
vendored
16
third_party/xla/xla/ffi/type_registry_test.cc
vendored
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "xla/tsl/lib/core/status_test_util.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
#include "xla/tsl/platform/test.h"
|
||||
|
|
@ -53,7 +54,7 @@ TEST(TypeRegistryTest, RegisterExternalTypeId) {
|
|||
|
||||
// Registered type has a correct type info.
|
||||
TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeInfo foo_info,
|
||||
TypeRegistry::GetExternalTypeInfo(foo_id));
|
||||
TypeRegistry::GetTypeInfo(foo_id));
|
||||
EXPECT_EQ(foo_info.deleter, type_info.deleter);
|
||||
|
||||
// It's ok to register a new type with a user-provided type id.
|
||||
|
|
@ -64,14 +65,19 @@ TEST(TypeRegistryTest, RegisterExternalTypeId) {
|
|||
|
||||
// And a new type has a correct type info.
|
||||
TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeInfo bar_info,
|
||||
TypeRegistry::GetExternalTypeInfo(bar_id));
|
||||
TypeRegistry::GetTypeInfo(bar_id));
|
||||
EXPECT_EQ(bar_info.deleter, type_info.deleter);
|
||||
}
|
||||
|
||||
TEST(TypeRegistryTest, RegisterInternalTypeId) {
|
||||
auto int32_type_id = TypeRegistry::GetTypeId<int32_t>();
|
||||
auto int64_type_id = TypeRegistry::GetTypeId<int64_t>();
|
||||
EXPECT_NE(int32_type_id, int64_type_id);
|
||||
auto int32_id = TypeRegistry::GetTypeId<int32_t>();
|
||||
auto int64_id = TypeRegistry::GetTypeId<int64_t>();
|
||||
EXPECT_NE(int32_id, int64_id);
|
||||
|
||||
absl::string_view int32_name = TypeRegistry::GetTypeName<int32_t>();
|
||||
absl::string_view int64_name = TypeRegistry::GetTypeName<int64_t>();
|
||||
EXPECT_EQ(*TypeRegistry::GetTypeId(int32_name), int32_id);
|
||||
EXPECT_EQ(*TypeRegistry::GetTypeId(int64_name), int64_id);
|
||||
}
|
||||
|
||||
TEST(TypeRegistryTest, InternalTypeInfo) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user