mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[xla:ffi] Keep FFI handler metadata with handler registration
PiperOrigin-RevId: 822741325
This commit is contained in:
parent
70111bb38f
commit
2cdd8ff5ce
|
|
@ -206,7 +206,7 @@ bool IsConvertible(const CustomCallThunk& custom_call_thunk,
|
||||||
absl::StatusOr<ffi::HandlerRegistration> registration =
|
absl::StatusOr<ffi::HandlerRegistration> registration =
|
||||||
ffi::FindHandler(target_name, "gpu");
|
ffi::FindHandler(target_name, "gpu");
|
||||||
return registration.ok()
|
return registration.ok()
|
||||||
? ffi::IsCommandBufferCompatible(registration->traits)
|
? ffi::IsCommandBufferCompatible(registration->metadata)
|
||||||
: false;
|
: false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
11
third_party/xla/xla/ffi/api/api.h
vendored
11
third_party/xla/xla/ffi/api/api.h
vendored
|
|
@ -1609,13 +1609,17 @@ class Handler : public Ffi {
|
||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set the API version to the version of the FFI headers used by a handler.
|
||||||
extension->metadata->api_version = XLA_FFI_Api_Version{
|
extension->metadata->api_version = XLA_FFI_Api_Version{
|
||||||
XLA_FFI_Api_Version_STRUCT_SIZE,
|
XLA_FFI_Api_Version_STRUCT_SIZE,
|
||||||
/*extension_start=*/nullptr, XLA_FFI_API_MAJOR, XLA_FFI_API_MINOR};
|
/*extension_start=*/nullptr,
|
||||||
|
XLA_FFI_API_MAJOR,
|
||||||
|
XLA_FFI_API_MINOR,
|
||||||
|
};
|
||||||
|
|
||||||
// Collect all traits and store them in the metadata.
|
// Collect all traits and store them in the metadata.
|
||||||
XLA_FFI_Handler_Traits traits = 0;
|
XLA_FFI_Handler_Traits traits = 0;
|
||||||
for (const auto& trait : traits_) {
|
for (const Traits& trait : traits_) {
|
||||||
traits |= static_cast<XLA_FFI_Handler_Traits>(trait);
|
traits |= static_cast<XLA_FFI_Handler_Traits>(trait);
|
||||||
}
|
}
|
||||||
extension->metadata->traits = traits;
|
extension->metadata->traits = traits;
|
||||||
|
|
@ -1737,7 +1741,8 @@ class Handler : public Ffi {
|
||||||
// Find index of every attribute in the sorted attributes vector.
|
// Find index of every attribute in the sorted attributes vector.
|
||||||
for (size_t i = 0; i < attrs_.size(); ++i) {
|
for (size_t i = 0; i < attrs_.size(); ++i) {
|
||||||
attrs_idx_.push_back(std::distance(
|
attrs_idx_.push_back(std::distance(
|
||||||
sorted.begin(), std::find(sorted.begin(), sorted.end(), attrs_[i])));
|
sorted.begin(),
|
||||||
|
std::find(sorted.begin(), sorted.end(), attrs_[i]))); // NOLINT
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
82
third_party/xla/xla/ffi/ffi_api.cc
vendored
82
third_party/xla/xla/ffi/ffi_api.cc
vendored
|
|
@ -116,8 +116,8 @@ static bool IsSupportedApiVersion(const XLA_FFI_Api_Version& api_version) {
|
||||||
version <= kMaxSupportedApiVersion;
|
version <= kMaxSupportedApiVersion;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsCommandBufferCompatible(XLA_FFI_Handler_Traits traits) {
|
bool IsCommandBufferCompatible(const XLA_FFI_Metadata& metadata) {
|
||||||
return traits & XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE;
|
return metadata.traits & XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE;
|
||||||
}
|
}
|
||||||
|
|
||||||
static XLA_FFI_ExecutionContext CreateExecutionContext(
|
static XLA_FFI_ExecutionContext CreateExecutionContext(
|
||||||
|
|
@ -175,16 +175,16 @@ tsl::AsyncValueRef<tsl::Chain> TakeFuture(XLA_FFI_Future* future) {
|
||||||
return chain->AsRef();
|
return chain->AsRef();
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the future is already completed, immediately return the underlying async
|
// If the future is already completed, immediately return the underlying
|
||||||
// value and delete the XLA_FFI_Future.
|
// async value and delete the XLA_FFI_Future.
|
||||||
if (ABSL_PREDICT_TRUE(future->async_value.IsAvailable())) {
|
if (ABSL_PREDICT_TRUE(future->async_value.IsAvailable())) {
|
||||||
tsl::AsyncValueRef<tsl::Chain> async_value = std::move(future->async_value);
|
tsl::AsyncValueRef<tsl::Chain> async_value = std::move(future->async_value);
|
||||||
delete future;
|
delete future;
|
||||||
return async_value;
|
return async_value;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the future is not completed, return a copy of the underlying async value
|
// If the future is not completed, return a copy of the underlying async
|
||||||
// and keep XLA_FFI_Future alive until it is completed.
|
// value and keep XLA_FFI_Future alive until it is completed.
|
||||||
tsl::AsyncValueRef<tsl::Chain> async_value = future->async_value;
|
tsl::AsyncValueRef<tsl::Chain> async_value = future->async_value;
|
||||||
async_value.AndThen([future] { delete future; });
|
async_value.AndThen([future] { delete future; });
|
||||||
return async_value;
|
return async_value;
|
||||||
|
|
@ -201,8 +201,8 @@ static absl::StatusOr<XLA_FFI_Future*> Call(Handler& handler,
|
||||||
|
|
||||||
XLA_FFI_Error* error = nullptr;
|
XLA_FFI_Error* error = nullptr;
|
||||||
|
|
||||||
// FFI handlers might be defined in external libraries and use exceptions, so
|
// FFI handlers might be defined in external libraries and use exceptions,
|
||||||
// take extra care to catch them and convert to a status.
|
// so take extra care to catch them and convert to a status.
|
||||||
try {
|
try {
|
||||||
if constexpr (std::is_same_v<Handler, Ffi>) {
|
if constexpr (std::is_same_v<Handler, Ffi>) {
|
||||||
error = handler.Call(&ffi_call_frame);
|
error = handler.Call(&ffi_call_frame);
|
||||||
|
|
@ -386,6 +386,20 @@ static std::vector<std::string> GetHandlerStages(
|
||||||
return stages;
|
return stages;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool CheckMetadata(const XLA_FFI_Metadata& a,
|
||||||
|
const XLA_FFI_Metadata& b) {
|
||||||
|
return a.api_version.major_version == b.api_version.major_version &&
|
||||||
|
a.api_version.minor_version == b.api_version.minor_version &&
|
||||||
|
a.traits == b.traits &&
|
||||||
|
a.state_type_id.type_id == b.state_type_id.type_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool CheckHandlerBundle(const XLA_FFI_Handler_Bundle& a,
|
||||||
|
const XLA_FFI_Handler_Bundle& b) {
|
||||||
|
return a.instantiate == b.instantiate && a.prepare == b.prepare &&
|
||||||
|
a.initialize == b.initialize && a.execute == b.execute;
|
||||||
|
}
|
||||||
|
|
||||||
static absl::Status RegisterHandler(absl::string_view name,
|
static absl::Status RegisterHandler(absl::string_view name,
|
||||||
absl::string_view platform,
|
absl::string_view platform,
|
||||||
XLA_FFI_Handler_Bundle bundle,
|
XLA_FFI_Handler_Bundle bundle,
|
||||||
|
|
@ -405,8 +419,10 @@ static absl::Status RegisterHandler(absl::string_view name,
|
||||||
if (!IsSupportedApiVersion(metadata.api_version)) {
|
if (!IsSupportedApiVersion(metadata.api_version)) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"XLA FFI handler registration for %s on platform %s (canonical %s) "
|
"XLA FFI handler registration for %s on platform %s (canonical %s) "
|
||||||
"failed because the handler's API version (%d.%d) is incompatible with "
|
"failed because the handler's API version (%d.%d) is incompatible "
|
||||||
"the framework's API version (%d.%d). Minimum supported API version is "
|
"with "
|
||||||
|
"the framework's API version (%d.%d). Minimum supported API version "
|
||||||
|
"is "
|
||||||
"(%d.%d).",
|
"(%d.%d).",
|
||||||
name, platform, canonical_platform, metadata.api_version.major_version,
|
name, platform, canonical_platform, metadata.api_version.major_version,
|
||||||
metadata.api_version.minor_version, kMaxSupportedApiVersion.first,
|
metadata.api_version.minor_version, kMaxSupportedApiVersion.first,
|
||||||
|
|
@ -414,36 +430,46 @@ static absl::Status RegisterHandler(absl::string_view name,
|
||||||
kMinSupportedApiVersion.second);
|
kMinSupportedApiVersion.second);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Incorporate handler traits.
|
// Incorporate handler traits passed explicitly via handler registration
|
||||||
traits |= metadata.traits;
|
// API.
|
||||||
|
metadata.traits |= traits;
|
||||||
|
|
||||||
|
// Incorporate state type id from the instantiate implementation if present.
|
||||||
|
if (bundle.instantiate) {
|
||||||
|
TF_ASSIGN_OR_RETURN(XLA_FFI_Metadata instantiate_metadata,
|
||||||
|
GetMetadata(bundle.instantiate));
|
||||||
|
metadata.state_type_id = instantiate_metadata.state_type_id;
|
||||||
|
}
|
||||||
|
|
||||||
VLOG(2) << absl::StreamFormat(
|
VLOG(2) << absl::StreamFormat(
|
||||||
"Register XLA FFI handler for '%s'; platform=%s (canonical=%s), "
|
"Register XLA FFI handler for '%s'; platform=%s (canonical=%s), "
|
||||||
"stages=[%s], command_buffer_compatible=%v",
|
"stages=[%s], metadata=%v",
|
||||||
name, platform, canonical_platform,
|
name, platform, canonical_platform,
|
||||||
absl::StrJoin(GetHandlerStages(bundle), ", "),
|
absl::StrJoin(GetHandlerStages(bundle), ", "), metadata);
|
||||||
IsCommandBufferCompatible(traits));
|
|
||||||
|
|
||||||
auto emplaced =
|
HandlerRegistration registration{metadata, bundle};
|
||||||
GetHandlerRegistry().try_emplace(MakeHandlerKey(name, canonical_platform),
|
auto [it, emplaced] = GetHandlerRegistry().try_emplace(
|
||||||
HandlerRegistration{bundle, traits});
|
MakeHandlerKey(name, canonical_platform), registration);
|
||||||
if (!emplaced.second) {
|
|
||||||
auto existing = emplaced.first->second;
|
// We might accidentally link the same FFI library multiple times (because
|
||||||
if (existing.traits != traits) {
|
// linking shared libraries is hard), and we choose to ignore this problem as
|
||||||
|
// long as we register exactly the same handler.
|
||||||
|
if (!emplaced) {
|
||||||
|
const HandlerRegistration& existing = it->second;
|
||||||
|
if (!CheckMetadata(existing.metadata, metadata)) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"Duplicate FFI handler registration for %s on platform %s "
|
"Duplicate FFI handler registration for %s on platform %s "
|
||||||
"(canonical %s) with different traits",
|
"(canonical %s) with different metadata: %v vs %v",
|
||||||
name, platform, canonical_platform);
|
name, platform, canonical_platform, existing.metadata, metadata);
|
||||||
}
|
}
|
||||||
if (existing.bundle.prepare != bundle.prepare ||
|
if (!CheckHandlerBundle(existing.bundle, bundle)) {
|
||||||
existing.bundle.initialize != bundle.initialize ||
|
|
||||||
existing.bundle.execute != bundle.execute) {
|
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"Duplicate FFI handler registration for %s on platform %s "
|
"Duplicate FFI handler registration for %s on platform %s "
|
||||||
"(canonical %s) with different bundle addresses",
|
"(canonical %s) with different bundle addresses",
|
||||||
name, platform, canonical_platform);
|
name, platform, canonical_platform);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -681,8 +707,8 @@ static XLA_FFI_Error* XLA_FFI_Type_Register(XLA_FFI_Type_Register_Args* args) {
|
||||||
TypeRegistry::TypeId type_id(args->type_id->type_id);
|
TypeRegistry::TypeId type_id(args->type_id->type_id);
|
||||||
TypeRegistry::TypeInfo type_info = {args->type_info->deleter};
|
TypeRegistry::TypeInfo type_info = {args->type_info->deleter};
|
||||||
|
|
||||||
// If type_id is unknown, we are registering a new type and XLA will assign a
|
// If type_id is unknown, we are registering a new type and XLA will assign
|
||||||
// unique type id to it.
|
// a unique type id to it.
|
||||||
if (type_id == TypeRegistry::kUnknownTypeId) {
|
if (type_id == TypeRegistry::kUnknownTypeId) {
|
||||||
auto assigned_type_id =
|
auto assigned_type_id =
|
||||||
TypeRegistry::AssignExternalTypeId(type_name, type_info);
|
TypeRegistry::AssignExternalTypeId(type_name, type_info);
|
||||||
|
|
|
||||||
56
third_party/xla/xla/ffi/ffi_api.h
vendored
56
third_party/xla/xla/ffi/ffi_api.h
vendored
|
|
@ -19,10 +19,12 @@ limitations under the License.
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <variant>
|
#include <variant>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/str_join.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "xla/executable_run_options.h"
|
#include "xla/executable_run_options.h"
|
||||||
#include "xla/ffi/api/api.h"
|
#include "xla/ffi/api/api.h"
|
||||||
|
|
@ -142,11 +144,11 @@ class ScopedExecutionContext {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
struct HandlerRegistration {
|
struct HandlerRegistration {
|
||||||
XLA_FFI_Handler_Bundle bundle = {};
|
XLA_FFI_Metadata metadata;
|
||||||
XLA_FFI_Handler_Traits traits = {};
|
XLA_FFI_Handler_Bundle bundle;
|
||||||
};
|
};
|
||||||
|
|
||||||
bool IsCommandBufferCompatible(XLA_FFI_Handler_Traits traits);
|
bool IsCommandBufferCompatible(const XLA_FFI_Metadata& metadata);
|
||||||
|
|
||||||
// Returns registered FFI handler for a given name and platform, or an error if
|
// Returns registered FFI handler for a given name and platform, or an error if
|
||||||
// it's not found in the static registry.
|
// it's not found in the static registry.
|
||||||
|
|
@ -163,6 +165,54 @@ StaticRegisteredHandlers(absl::string_view platform);
|
||||||
|
|
||||||
const XLA_FFI_Api* GetXlaFfiApi();
|
const XLA_FFI_Api* GetXlaFfiApi();
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Helper functions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Decodes XLA FFI traits packed into a 32-bit integer into a vector of traits.
|
||||||
|
inline std::vector<Traits> DecodeTraits(XLA_FFI_Handler_Traits traits) {
|
||||||
|
std::vector<Traits> result;
|
||||||
|
if (traits & XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE) {
|
||||||
|
result.push_back(Traits::kCmdBufferCompatible);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pretty printinting for FFI C++ types.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
template <typename Sink>
|
||||||
|
static void AbslStringify(Sink& sink, Traits traits) {
|
||||||
|
switch (traits) {
|
||||||
|
case Traits::kCmdBufferCompatible:
|
||||||
|
absl::Format(&sink, "cmd_buffer_compatible");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla::ffi
|
} // namespace xla::ffi
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pretty printinting for FFI C types.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
template <typename Sink>
|
||||||
|
static void AbslStringify(Sink& sink, const XLA_FFI_TypeId& type_id) {
|
||||||
|
if (type_id.type_id == XLA_FFI_UNKNOWN_TYPE_ID.type_id) {
|
||||||
|
absl::Format(&sink, "unknown");
|
||||||
|
} else {
|
||||||
|
absl::Format(&sink, "%d", type_id.type_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Sink>
|
||||||
|
static void AbslStringify(Sink& sink, const XLA_FFI_Metadata& metadata) {
|
||||||
|
absl::Format(&sink, "{api_version: %d.%d, traits: [%s], state: %v}",
|
||||||
|
metadata.api_version.major_version,
|
||||||
|
metadata.api_version.minor_version,
|
||||||
|
absl::StrJoin(xla::ffi::DecodeTraits(metadata.traits), ", "),
|
||||||
|
metadata.state_type_id);
|
||||||
|
}
|
||||||
|
|
||||||
#endif // XLA_FFI_FFI_API_H_
|
#endif // XLA_FFI_FFI_API_H_
|
||||||
|
|
|
||||||
10
third_party/xla/xla/ffi/ffi_test.cc
vendored
10
third_party/xla/xla/ffi/ffi_test.cc
vendored
|
|
@ -105,8 +105,9 @@ TEST(FfiTest, StaticHandlerRegistration) {
|
||||||
TF_ASSERT_OK(handler0.status());
|
TF_ASSERT_OK(handler0.status());
|
||||||
TF_ASSERT_OK(handler1.status());
|
TF_ASSERT_OK(handler1.status());
|
||||||
|
|
||||||
ASSERT_EQ(handler0->traits, XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE);
|
ASSERT_EQ(handler0->metadata.traits,
|
||||||
ASSERT_EQ(handler1->traits, 0);
|
XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE);
|
||||||
|
ASSERT_EQ(handler1->metadata.traits, 0);
|
||||||
|
|
||||||
// Check that platform name was canonicalized an we can find handlers
|
// Check that platform name was canonicalized an we can find handlers
|
||||||
// registered for "Host" platform as "Cpu" handlers.
|
// registered for "Host" platform as "Cpu" handlers.
|
||||||
|
|
@ -122,7 +123,8 @@ TEST(FfiTest, RegistrationTraitsBackwardsCompatibility) {
|
||||||
XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE);
|
XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE);
|
||||||
auto handler = FindHandler("traits-bwd-compat", "Host");
|
auto handler = FindHandler("traits-bwd-compat", "Host");
|
||||||
TF_ASSERT_OK(handler.status());
|
TF_ASSERT_OK(handler.status());
|
||||||
ASSERT_EQ(handler->traits, XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE);
|
ASSERT_EQ(handler->metadata.traits,
|
||||||
|
XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Declare XLA FFI handler as a function (extern "C" declaration).
|
// Declare XLA FFI handler as a function (extern "C" declaration).
|
||||||
|
|
@ -139,7 +141,7 @@ TEST(FfiTest, StaticHandlerSymbolRegistration) {
|
||||||
auto handler0 = FindHandler("no-op-sym-0", "Cpu");
|
auto handler0 = FindHandler("no-op-sym-0", "Cpu");
|
||||||
|
|
||||||
TF_ASSERT_OK(handler0.status());
|
TF_ASSERT_OK(handler0.status());
|
||||||
ASSERT_EQ(handler0->traits, 0);
|
ASSERT_EQ(handler0->metadata.traits, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(FfiTest, ForwardError) {
|
TEST(FfiTest, ForwardError) {
|
||||||
|
|
|
||||||
|
|
@ -258,7 +258,7 @@ static bool IsCommand(const HloCustomCallInstruction* hlo,
|
||||||
// Check if FFI handler is compatible with command buffers.
|
// Check if FFI handler is compatible with command buffers.
|
||||||
auto registration = ffi::FindHandler(hlo->custom_call_target(), "gpu");
|
auto registration = ffi::FindHandler(hlo->custom_call_target(), "gpu");
|
||||||
return registration.ok()
|
return registration.ok()
|
||||||
? ffi::IsCommandBufferCompatible(registration->traits)
|
? ffi::IsCommandBufferCompatible(registration->metadata)
|
||||||
: false;
|
: false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user