mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[xla:ffi] Use same id sequence for internal and external types
Add an API to lookup type id and info by type name. We can't rely on type ids for serialization, as they are not stable and assigned at run time depending on the type registration order. Type names on the other hand must be stable. PiperOrigin-RevId: 824512487
This commit is contained in:
parent
0e809d4bc8
commit
5dfa57fd92
1
third_party/xla/xla/ffi/BUILD
vendored
1
third_party/xla/xla/ffi/BUILD
vendored
|
|
@ -278,6 +278,7 @@ xla_cc_test(
|
||||||
"//xla/tsl/platform:statusor",
|
"//xla/tsl/platform:statusor",
|
||||||
"//xla/tsl/platform:test",
|
"//xla/tsl/platform:test",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/strings:string_view",
|
||||||
"@com_google_googletest//:gtest",
|
"@com_google_googletest//:gtest",
|
||||||
"@com_google_googletest//:gtest_main",
|
"@com_google_googletest//:gtest_main",
|
||||||
],
|
],
|
||||||
|
|
|
||||||
3
third_party/xla/xla/ffi/execution_state.cc
vendored
3
third_party/xla/xla/ffi/execution_state.cc
vendored
|
|
@ -35,8 +35,7 @@ ExecutionState::~ExecutionState() {
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status ExecutionState::Set(TypeId type_id, void* state) {
|
absl::Status ExecutionState::Set(TypeId type_id, void* state) {
|
||||||
TF_ASSIGN_OR_RETURN(auto type_info,
|
TF_ASSIGN_OR_RETURN(auto type_info, TypeRegistry::GetTypeInfo(type_id));
|
||||||
TypeRegistry::GetExternalTypeInfo(type_id));
|
|
||||||
if (type_info.deleter == nullptr) {
|
if (type_info.deleter == nullptr) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"Type id %d does not have a registered type info with a deleter",
|
"Type id %d does not have a registered type info with a deleter",
|
||||||
|
|
|
||||||
39
third_party/xla/xla/ffi/type_registry.cc
vendored
39
third_party/xla/xla/ffi/type_registry.cc
vendored
|
|
@ -40,24 +40,19 @@ struct TypeRegistration {
|
||||||
TypeRegistry::TypeInfo type_info;
|
TypeRegistry::TypeInfo type_info;
|
||||||
};
|
};
|
||||||
|
|
||||||
using ExternalTypeRegistry = absl::flat_hash_map<std::string, TypeRegistration>;
|
using TypeRegistryMap = absl::flat_hash_map<std::string, TypeRegistration>;
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
ABSL_CONST_INIT absl::Mutex type_registry_mutex(absl::kConstInit);
|
ABSL_CONST_INIT absl::Mutex type_registry_mutex(absl::kConstInit);
|
||||||
|
|
||||||
static ExternalTypeRegistry& StaticExternalTypeRegistry() {
|
static TypeRegistryMap& StaticTypeRegistryMap() {
|
||||||
static absl::NoDestructor<ExternalTypeRegistry> registry;
|
static absl::NoDestructor<TypeRegistryMap> registry;
|
||||||
return *registry;
|
return *registry;
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeRegistry::TypeId TypeRegistry::GetNextInternalTypeId() {
|
TypeRegistry::TypeId TypeRegistry::GetNextTypeId() {
|
||||||
static auto* counter = new std::atomic<int64_t>(1);
|
static absl::NoDestructor<std::atomic<int64_t>> counter(1);
|
||||||
return TypeId(counter->fetch_add(1));
|
|
||||||
}
|
|
||||||
|
|
||||||
TypeRegistry::TypeId TypeRegistry::GetNextExternalTypeId() {
|
|
||||||
static auto* counter = new std::atomic<int64_t>(1);
|
|
||||||
return TypeId(counter->fetch_add(1));
|
return TypeId(counter->fetch_add(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -66,7 +61,7 @@ absl::StatusOr<TypeRegistry::TypeId> TypeRegistry::AssignExternalTypeId(
|
||||||
VLOG(3) << absl::StrFormat("Assign external type id: name=%s", name);
|
VLOG(3) << absl::StrFormat("Assign external type id: name=%s", name);
|
||||||
|
|
||||||
absl::MutexLock lock(type_registry_mutex);
|
absl::MutexLock lock(type_registry_mutex);
|
||||||
auto& registry = StaticExternalTypeRegistry();
|
auto& registry = StaticTypeRegistryMap();
|
||||||
|
|
||||||
// Try to emplace with unknow type id and fill it with real type id only if we
|
// Try to emplace with unknow type id and fill it with real type id only if we
|
||||||
// successfully acquired an entry for a given name.
|
// successfully acquired an entry for a given name.
|
||||||
|
|
@ -84,9 +79,9 @@ absl::StatusOr<TypeRegistry::TypeId> TypeRegistry::AssignExternalTypeId(
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create a new type id that is not already in use.
|
// Create a new type id that is not already in use.
|
||||||
TypeId type_id = GetNextExternalTypeId();
|
TypeId type_id = GetNextTypeId();
|
||||||
while (type_id_is_in_use(type_id)) {
|
while (type_id_is_in_use(type_id)) {
|
||||||
type_id = GetNextExternalTypeId();
|
type_id = GetNextTypeId();
|
||||||
}
|
}
|
||||||
|
|
||||||
VLOG(3) << absl::StrFormat("Assigned external type id: name=%s type_id=%d",
|
VLOG(3) << absl::StrFormat("Assigned external type id: name=%s type_id=%d",
|
||||||
|
|
@ -101,7 +96,7 @@ absl::Status TypeRegistry::RegisterExternalTypeId(absl::string_view name,
|
||||||
name, type_id.value());
|
name, type_id.value());
|
||||||
|
|
||||||
absl::MutexLock lock(type_registry_mutex);
|
absl::MutexLock lock(type_registry_mutex);
|
||||||
auto& registry = StaticExternalTypeRegistry();
|
auto& registry = StaticTypeRegistryMap();
|
||||||
|
|
||||||
auto emplaced = registry.emplace(name, TypeRegistration{type_id, type_info});
|
auto emplaced = registry.emplace(name, TypeRegistration{type_id, type_info});
|
||||||
if (!emplaced.second && emplaced.first->second.type_id != type_id) {
|
if (!emplaced.second && emplaced.first->second.type_id != type_id) {
|
||||||
|
|
@ -113,10 +108,22 @@ absl::Status TypeRegistry::RegisterExternalTypeId(absl::string_view name,
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<TypeRegistry::TypeInfo> TypeRegistry::GetExternalTypeInfo(
|
absl::StatusOr<TypeRegistry::TypeId> TypeRegistry::GetTypeId(
|
||||||
|
absl::string_view name) {
|
||||||
|
absl::MutexLock lock(type_registry_mutex);
|
||||||
|
auto& registry = StaticTypeRegistryMap();
|
||||||
|
|
||||||
|
auto it = registry.find(name);
|
||||||
|
if (it == registry.end()) {
|
||||||
|
return Internal("Type name %s is not registered", name);
|
||||||
|
}
|
||||||
|
return it->second.type_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<TypeRegistry::TypeInfo> TypeRegistry::GetTypeInfo(
|
||||||
TypeId type_id) {
|
TypeId type_id) {
|
||||||
absl::MutexLock lock(type_registry_mutex);
|
absl::MutexLock lock(type_registry_mutex);
|
||||||
auto& registry = StaticExternalTypeRegistry();
|
auto& registry = StaticTypeRegistryMap();
|
||||||
|
|
||||||
auto it = absl::c_find_if(registry, [&](const auto& kv) {
|
auto it = absl::c_find_if(registry, [&](const auto& kv) {
|
||||||
auto& [name, registration] = kv;
|
auto& [name, registration] = kv;
|
||||||
|
|
|
||||||
32
third_party/xla/xla/ffi/type_registry.h
vendored
32
third_party/xla/xla/ffi/type_registry.h
vendored
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
|
#include "absl/base/no_destructor.h"
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
|
@ -64,6 +65,14 @@ class TypeRegistry {
|
||||||
Deleter deleter = nullptr;
|
Deleter deleter = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Returns type id for a given type name. Returns an error if type is
|
||||||
|
// not registered. Works for both external and internal type ids.
|
||||||
|
static absl::StatusOr<TypeId> GetTypeId(absl::string_view name);
|
||||||
|
|
||||||
|
// Returns type info for a given type id. Returns an error if type id is not
|
||||||
|
// registered. Works for both external and internal type ids.
|
||||||
|
static absl::StatusOr<TypeInfo> GetTypeInfo(TypeId type_id);
|
||||||
|
|
||||||
// Assigns a unique type id to an external type with a given name. Returns an
|
// Assigns a unique type id to an external type with a given name. Returns an
|
||||||
// error if a type with a given name is already registered in the process.
|
// error if a type with a given name is already registered in the process.
|
||||||
static absl::StatusOr<TypeId> AssignExternalTypeId(absl::string_view name,
|
static absl::StatusOr<TypeId> AssignExternalTypeId(absl::string_view name,
|
||||||
|
|
@ -76,9 +85,9 @@ class TypeRegistry {
|
||||||
TypeId type_id,
|
TypeId type_id,
|
||||||
TypeInfo type_info);
|
TypeInfo type_info);
|
||||||
|
|
||||||
// Returns type info for a given external type id. Returns an error if type
|
// Returns a type name for a given type. For internal type ids only.
|
||||||
// id is not registered.
|
template <typename T>
|
||||||
static absl::StatusOr<TypeInfo> GetExternalTypeInfo(TypeId type_id);
|
static absl::string_view GetTypeName();
|
||||||
|
|
||||||
// Returns a type id for a given type. For internal type ids only.
|
// Returns a type id for a given type. For internal type ids only.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
@ -89,16 +98,21 @@ class TypeRegistry {
|
||||||
static TypeInfo GetTypeInfo();
|
static TypeInfo GetTypeInfo();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// We never mix external and internal type ids, so we can use different type
|
static TypeId GetNextTypeId();
|
||||||
// id spaces to assign unique ids to each type.
|
|
||||||
static TypeId GetNextInternalTypeId();
|
|
||||||
static TypeId GetNextExternalTypeId();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
absl::string_view TypeRegistry::GetTypeName() {
|
||||||
|
return typeid(T).name();
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
TypeRegistry::TypeId TypeRegistry::GetTypeId() {
|
TypeRegistry::TypeId TypeRegistry::GetTypeId() {
|
||||||
static const TypeId id = GetNextInternalTypeId();
|
// We always register internal types in the static type registry, because we
|
||||||
return id;
|
// want to be able to lookup them by name.
|
||||||
|
static const absl::NoDestructor<absl::StatusOr<TypeId>> id(
|
||||||
|
AssignExternalTypeId(GetTypeName<T>(), GetTypeInfo<T>()));
|
||||||
|
return **id;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
|
||||||
16
third_party/xla/xla/ffi/type_registry_test.cc
vendored
16
third_party/xla/xla/ffi/type_registry_test.cc
vendored
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||||
#include <gmock/gmock.h>
|
#include <gmock/gmock.h>
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
#include "xla/tsl/lib/core/status_test_util.h"
|
#include "xla/tsl/lib/core/status_test_util.h"
|
||||||
#include "xla/tsl/platform/statusor.h"
|
#include "xla/tsl/platform/statusor.h"
|
||||||
#include "xla/tsl/platform/test.h"
|
#include "xla/tsl/platform/test.h"
|
||||||
|
|
@ -53,7 +54,7 @@ TEST(TypeRegistryTest, RegisterExternalTypeId) {
|
||||||
|
|
||||||
// Registered type has a correct type info.
|
// Registered type has a correct type info.
|
||||||
TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeInfo foo_info,
|
TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeInfo foo_info,
|
||||||
TypeRegistry::GetExternalTypeInfo(foo_id));
|
TypeRegistry::GetTypeInfo(foo_id));
|
||||||
EXPECT_EQ(foo_info.deleter, type_info.deleter);
|
EXPECT_EQ(foo_info.deleter, type_info.deleter);
|
||||||
|
|
||||||
// It's ok to register a new type with a user-provided type id.
|
// It's ok to register a new type with a user-provided type id.
|
||||||
|
|
@ -64,14 +65,19 @@ TEST(TypeRegistryTest, RegisterExternalTypeId) {
|
||||||
|
|
||||||
// And a new type has a correct type info.
|
// And a new type has a correct type info.
|
||||||
TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeInfo bar_info,
|
TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeInfo bar_info,
|
||||||
TypeRegistry::GetExternalTypeInfo(bar_id));
|
TypeRegistry::GetTypeInfo(bar_id));
|
||||||
EXPECT_EQ(bar_info.deleter, type_info.deleter);
|
EXPECT_EQ(bar_info.deleter, type_info.deleter);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TypeRegistryTest, RegisterInternalTypeId) {
|
TEST(TypeRegistryTest, RegisterInternalTypeId) {
|
||||||
auto int32_type_id = TypeRegistry::GetTypeId<int32_t>();
|
auto int32_id = TypeRegistry::GetTypeId<int32_t>();
|
||||||
auto int64_type_id = TypeRegistry::GetTypeId<int64_t>();
|
auto int64_id = TypeRegistry::GetTypeId<int64_t>();
|
||||||
EXPECT_NE(int32_type_id, int64_type_id);
|
EXPECT_NE(int32_id, int64_id);
|
||||||
|
|
||||||
|
absl::string_view int32_name = TypeRegistry::GetTypeName<int32_t>();
|
||||||
|
absl::string_view int64_name = TypeRegistry::GetTypeName<int64_t>();
|
||||||
|
EXPECT_EQ(*TypeRegistry::GetTypeId(int32_name), int32_id);
|
||||||
|
EXPECT_EQ(*TypeRegistry::GetTypeId(int64_name), int64_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TypeRegistryTest, InternalTypeInfo) {
|
TEST(TypeRegistryTest, InternalTypeInfo) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user