[xla:ffi] Keep FFI handler metadata with handler registration

PiperOrigin-RevId: 822741325
This commit is contained in:
Eugene Zhulenev 2025-10-22 14:27:41 -07:00 committed by TensorFlower Gardener
parent 70111bb38f
commit 2cdd8ff5ce
6 changed files with 123 additions and 40 deletions

View File

@ -206,7 +206,7 @@ bool IsConvertible(const CustomCallThunk& custom_call_thunk,
absl::StatusOr<ffi::HandlerRegistration> registration = absl::StatusOr<ffi::HandlerRegistration> registration =
ffi::FindHandler(target_name, "gpu"); ffi::FindHandler(target_name, "gpu");
return registration.ok() return registration.ok()
? ffi::IsCommandBufferCompatible(registration->traits) ? ffi::IsCommandBufferCompatible(registration->metadata)
: false; : false;
} }

View File

@ -1609,13 +1609,17 @@ class Handler : public Ffi {
return err; return err;
} }
// Set the API version to the version of the FFI headers used by a handler.
extension->metadata->api_version = XLA_FFI_Api_Version{ extension->metadata->api_version = XLA_FFI_Api_Version{
XLA_FFI_Api_Version_STRUCT_SIZE, XLA_FFI_Api_Version_STRUCT_SIZE,
/*extension_start=*/nullptr, XLA_FFI_API_MAJOR, XLA_FFI_API_MINOR}; /*extension_start=*/nullptr,
XLA_FFI_API_MAJOR,
XLA_FFI_API_MINOR,
};
// Collect all traits and store them in the metadata. // Collect all traits and store them in the metadata.
XLA_FFI_Handler_Traits traits = 0; XLA_FFI_Handler_Traits traits = 0;
for (const auto& trait : traits_) { for (const Traits& trait : traits_) {
traits |= static_cast<XLA_FFI_Handler_Traits>(trait); traits |= static_cast<XLA_FFI_Handler_Traits>(trait);
} }
extension->metadata->traits = traits; extension->metadata->traits = traits;
@ -1737,7 +1741,8 @@ class Handler : public Ffi {
// Find index of every attribute in the sorted attributes vector. // Find index of every attribute in the sorted attributes vector.
for (size_t i = 0; i < attrs_.size(); ++i) { for (size_t i = 0; i < attrs_.size(); ++i) {
attrs_idx_.push_back(std::distance( attrs_idx_.push_back(std::distance(
sorted.begin(), std::find(sorted.begin(), sorted.end(), attrs_[i]))); sorted.begin(),
std::find(sorted.begin(), sorted.end(), attrs_[i]))); // NOLINT
} }
} }

View File

@ -116,8 +116,8 @@ static bool IsSupportedApiVersion(const XLA_FFI_Api_Version& api_version) {
version <= kMaxSupportedApiVersion; version <= kMaxSupportedApiVersion;
} }
bool IsCommandBufferCompatible(XLA_FFI_Handler_Traits traits) { bool IsCommandBufferCompatible(const XLA_FFI_Metadata& metadata) {
return traits & XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE; return metadata.traits & XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE;
} }
static XLA_FFI_ExecutionContext CreateExecutionContext( static XLA_FFI_ExecutionContext CreateExecutionContext(
@ -175,16 +175,16 @@ tsl::AsyncValueRef<tsl::Chain> TakeFuture(XLA_FFI_Future* future) {
return chain->AsRef(); return chain->AsRef();
} }
// If the future is already completed, immediately return the underlying async // If the future is already completed, immediately return the underlying
// value and delete the XLA_FFI_Future. // async value and delete the XLA_FFI_Future.
if (ABSL_PREDICT_TRUE(future->async_value.IsAvailable())) { if (ABSL_PREDICT_TRUE(future->async_value.IsAvailable())) {
tsl::AsyncValueRef<tsl::Chain> async_value = std::move(future->async_value); tsl::AsyncValueRef<tsl::Chain> async_value = std::move(future->async_value);
delete future; delete future;
return async_value; return async_value;
} }
// If the future is not completed, return a copy of the underlying async value // If the future is not completed, return a copy of the underlying async
// and keep XLA_FFI_Future alive until it is completed. // value and keep XLA_FFI_Future alive until it is completed.
tsl::AsyncValueRef<tsl::Chain> async_value = future->async_value; tsl::AsyncValueRef<tsl::Chain> async_value = future->async_value;
async_value.AndThen([future] { delete future; }); async_value.AndThen([future] { delete future; });
return async_value; return async_value;
@ -201,8 +201,8 @@ static absl::StatusOr<XLA_FFI_Future*> Call(Handler& handler,
XLA_FFI_Error* error = nullptr; XLA_FFI_Error* error = nullptr;
// FFI handlers might be defined in external libraries and use exceptions, so // FFI handlers might be defined in external libraries and use exceptions,
// take extra care to catch them and convert to a status. // so take extra care to catch them and convert to a status.
try { try {
if constexpr (std::is_same_v<Handler, Ffi>) { if constexpr (std::is_same_v<Handler, Ffi>) {
error = handler.Call(&ffi_call_frame); error = handler.Call(&ffi_call_frame);
@ -386,6 +386,20 @@ static std::vector<std::string> GetHandlerStages(
return stages; return stages;
} }
// Returns true if two handler metadata records match on every field that is
// compared for duplicate registrations: API major version, API minor version,
// traits bitmask and state type id. Only these four fields are compared.
static bool CheckMetadata(const XLA_FFI_Metadata& a,
                          const XLA_FFI_Metadata& b) {
  if (a.api_version.major_version != b.api_version.major_version) {
    return false;
  }
  if (a.api_version.minor_version != b.api_version.minor_version) {
    return false;
  }
  if (a.traits != b.traits) {
    return false;
  }
  return a.state_type_id.type_id == b.state_type_id.type_id;
}
// Returns true if both bundles point at exactly the same handler for each of
// the four stages (instantiate, prepare, initialize, execute).
static bool CheckHandlerBundle(const XLA_FFI_Handler_Bundle& a,
                               const XLA_FFI_Handler_Bundle& b) {
  if (a.instantiate != b.instantiate) {
    return false;
  }
  if (a.prepare != b.prepare) {
    return false;
  }
  if (a.initialize != b.initialize) {
    return false;
  }
  return a.execute == b.execute;
}
static absl::Status RegisterHandler(absl::string_view name, static absl::Status RegisterHandler(absl::string_view name,
absl::string_view platform, absl::string_view platform,
XLA_FFI_Handler_Bundle bundle, XLA_FFI_Handler_Bundle bundle,
@ -405,8 +419,10 @@ static absl::Status RegisterHandler(absl::string_view name,
if (!IsSupportedApiVersion(metadata.api_version)) { if (!IsSupportedApiVersion(metadata.api_version)) {
return InvalidArgument( return InvalidArgument(
"XLA FFI handler registration for %s on platform %s (canonical %s) " "XLA FFI handler registration for %s on platform %s (canonical %s) "
"failed because the handler's API version (%d.%d) is incompatible with " "failed because the handler's API version (%d.%d) is incompatible "
"the framework's API version (%d.%d). Minimum supported API version is " "with "
"the framework's API version (%d.%d). Minimum supported API version "
"is "
"(%d.%d).", "(%d.%d).",
name, platform, canonical_platform, metadata.api_version.major_version, name, platform, canonical_platform, metadata.api_version.major_version,
metadata.api_version.minor_version, kMaxSupportedApiVersion.first, metadata.api_version.minor_version, kMaxSupportedApiVersion.first,
@ -414,36 +430,46 @@ static absl::Status RegisterHandler(absl::string_view name,
kMinSupportedApiVersion.second); kMinSupportedApiVersion.second);
} }
// Incorporate handler traits. // Incorporate handler traits passed explicitly via handler registration
traits |= metadata.traits; // API.
metadata.traits |= traits;
// Incorporate state type id from the instantiate implementation if present.
if (bundle.instantiate) {
TF_ASSIGN_OR_RETURN(XLA_FFI_Metadata instantiate_metadata,
GetMetadata(bundle.instantiate));
metadata.state_type_id = instantiate_metadata.state_type_id;
}
VLOG(2) << absl::StreamFormat( VLOG(2) << absl::StreamFormat(
"Register XLA FFI handler for '%s'; platform=%s (canonical=%s), " "Register XLA FFI handler for '%s'; platform=%s (canonical=%s), "
"stages=[%s], command_buffer_compatible=%v", "stages=[%s], metadata=%v",
name, platform, canonical_platform, name, platform, canonical_platform,
absl::StrJoin(GetHandlerStages(bundle), ", "), absl::StrJoin(GetHandlerStages(bundle), ", "), metadata);
IsCommandBufferCompatible(traits));
auto emplaced = HandlerRegistration registration{metadata, bundle};
GetHandlerRegistry().try_emplace(MakeHandlerKey(name, canonical_platform), auto [it, emplaced] = GetHandlerRegistry().try_emplace(
HandlerRegistration{bundle, traits}); MakeHandlerKey(name, canonical_platform), registration);
if (!emplaced.second) {
auto existing = emplaced.first->second; // We might accidentally link the same FFI library multiple times (because
if (existing.traits != traits) { // linking shared libraries is hard), and we choose to ignore this problem as
// long as we register exactly the same handler.
if (!emplaced) {
const HandlerRegistration& existing = it->second;
if (!CheckMetadata(existing.metadata, metadata)) {
return InvalidArgument( return InvalidArgument(
"Duplicate FFI handler registration for %s on platform %s " "Duplicate FFI handler registration for %s on platform %s "
"(canonical %s) with different traits", "(canonical %s) with different metadata: %v vs %v",
name, platform, canonical_platform); name, platform, canonical_platform, existing.metadata, metadata);
} }
if (existing.bundle.prepare != bundle.prepare || if (!CheckHandlerBundle(existing.bundle, bundle)) {
existing.bundle.initialize != bundle.initialize ||
existing.bundle.execute != bundle.execute) {
return InvalidArgument( return InvalidArgument(
"Duplicate FFI handler registration for %s on platform %s " "Duplicate FFI handler registration for %s on platform %s "
"(canonical %s) with different bundle addresses", "(canonical %s) with different bundle addresses",
name, platform, canonical_platform); name, platform, canonical_platform);
} }
} }
return absl::OkStatus(); return absl::OkStatus();
} }
@ -681,8 +707,8 @@ static XLA_FFI_Error* XLA_FFI_Type_Register(XLA_FFI_Type_Register_Args* args) {
TypeRegistry::TypeId type_id(args->type_id->type_id); TypeRegistry::TypeId type_id(args->type_id->type_id);
TypeRegistry::TypeInfo type_info = {args->type_info->deleter}; TypeRegistry::TypeInfo type_info = {args->type_info->deleter};
// If type_id is unknown, we are registering a new type and XLA will assign a // If type_id is unknown, we are registering a new type and XLA will assign
// unique type id to it. // a unique type id to it.
if (type_id == TypeRegistry::kUnknownTypeId) { if (type_id == TypeRegistry::kUnknownTypeId) {
auto assigned_type_id = auto assigned_type_id =
TypeRegistry::AssignExternalTypeId(type_name, type_info); TypeRegistry::AssignExternalTypeId(type_name, type_info);

View File

@ -19,10 +19,12 @@ limitations under the License.
#include <cstdint> #include <cstdint>
#include <string> #include <string>
#include <variant> #include <variant>
#include <vector>
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "xla/executable_run_options.h" #include "xla/executable_run_options.h"
#include "xla/ffi/api/api.h" #include "xla/ffi/api/api.h"
@ -142,11 +144,11 @@ class ScopedExecutionContext {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
struct HandlerRegistration { struct HandlerRegistration {
XLA_FFI_Handler_Bundle bundle = {}; XLA_FFI_Metadata metadata;
XLA_FFI_Handler_Traits traits = {}; XLA_FFI_Handler_Bundle bundle;
}; };
bool IsCommandBufferCompatible(XLA_FFI_Handler_Traits traits); bool IsCommandBufferCompatible(const XLA_FFI_Metadata& metadata);
// Returns registered FFI handler for a given name and platform, or an error if // Returns registered FFI handler for a given name and platform, or an error if
// it's not found in the static registry. // it's not found in the static registry.
@ -163,6 +165,54 @@ StaticRegisteredHandlers(absl::string_view platform);
const XLA_FFI_Api* GetXlaFfiApi(); const XLA_FFI_Api* GetXlaFfiApi();
//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
// Expands the XLA FFI traits bitmask (packed into a 32-bit integer) into a
// vector with one `Traits` entry for every bit that is set. The only bit
// decoded here is the command-buffer-compatible one; an empty vector is
// returned when it is not set.
inline std::vector<Traits> DecodeTraits(XLA_FFI_Handler_Traits traits) {
  std::vector<Traits> decoded;
  const bool cmd_buffer_compatible =
      (traits & XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE) != 0;
  if (cmd_buffer_compatible) {
    decoded.push_back(Traits::kCmdBufferCompatible);
  }
  return decoded;
}
//===----------------------------------------------------------------------===//
// Pretty printing for FFI C++ types.
//===----------------------------------------------------------------------===//
// Abseil stringification hook for `Traits`: writes the textual name of the
// trait into `sink`. Values without a matching case write nothing.
template <typename Sink>
static void AbslStringify(Sink& sink, Traits traits) {
  // A switch (rather than an `if`) keeps the statement shape of one case per
  // enumerator.
  switch (traits) {
    case Traits::kCmdBufferCompatible:
      absl::Format(&sink, "cmd_buffer_compatible");
      return;
  }
}
} // namespace xla::ffi } // namespace xla::ffi
//===----------------------------------------------------------------------===//
// Pretty printing for FFI C types.
//===----------------------------------------------------------------------===//
// Abseil stringification hook for `XLA_FFI_TypeId`: writes "unknown" when the
// id equals XLA_FFI_UNKNOWN_TYPE_ID, otherwise writes the numeric id.
template <typename Sink>
static void AbslStringify(Sink& sink, const XLA_FFI_TypeId& type_id) {
  const bool is_unknown = type_id.type_id == XLA_FFI_UNKNOWN_TYPE_ID.type_id;
  if (is_unknown) {
    absl::Format(&sink, "unknown");
    return;
  }
  absl::Format(&sink, "%d", type_id.type_id);
}
// Abseil stringification hook for `XLA_FFI_Metadata`: writes the API version
// as "major.minor", the decoded traits as a comma-separated list and the
// state type id, all wrapped in a single brace-delimited record.
template <typename Sink>
static void AbslStringify(Sink& sink, const XLA_FFI_Metadata& metadata) {
  const auto& version = metadata.api_version;
  // Decode the traits bitmask so that traits are printed by name.
  const std::string traits_list =
      absl::StrJoin(xla::ffi::DecodeTraits(metadata.traits), ", ");
  absl::Format(&sink, "{api_version: %d.%d, traits: [%s], state: %v}",
               version.major_version, version.minor_version, traits_list,
               metadata.state_type_id);
}
#endif // XLA_FFI_FFI_API_H_ #endif // XLA_FFI_FFI_API_H_

View File

@ -105,8 +105,9 @@ TEST(FfiTest, StaticHandlerRegistration) {
TF_ASSERT_OK(handler0.status()); TF_ASSERT_OK(handler0.status());
TF_ASSERT_OK(handler1.status()); TF_ASSERT_OK(handler1.status());
ASSERT_EQ(handler0->traits, XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE); ASSERT_EQ(handler0->metadata.traits,
ASSERT_EQ(handler1->traits, 0); XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE);
ASSERT_EQ(handler1->metadata.traits, 0);
// Check that platform name was canonicalized and we can find handlers // Check that platform name was canonicalized and we can find handlers
// registered for "Host" platform as "Cpu" handlers. // registered for "Host" platform as "Cpu" handlers.
@ -122,7 +123,8 @@ TEST(FfiTest, RegistrationTraitsBackwardsCompatibility) {
XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE); XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE);
auto handler = FindHandler("traits-bwd-compat", "Host"); auto handler = FindHandler("traits-bwd-compat", "Host");
TF_ASSERT_OK(handler.status()); TF_ASSERT_OK(handler.status());
ASSERT_EQ(handler->traits, XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE); ASSERT_EQ(handler->metadata.traits,
XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE);
} }
// Declare XLA FFI handler as a function (extern "C" declaration). // Declare XLA FFI handler as a function (extern "C" declaration).
@ -139,7 +141,7 @@ TEST(FfiTest, StaticHandlerSymbolRegistration) {
auto handler0 = FindHandler("no-op-sym-0", "Cpu"); auto handler0 = FindHandler("no-op-sym-0", "Cpu");
TF_ASSERT_OK(handler0.status()); TF_ASSERT_OK(handler0.status());
ASSERT_EQ(handler0->traits, 0); ASSERT_EQ(handler0->metadata.traits, 0);
} }
TEST(FfiTest, ForwardError) { TEST(FfiTest, ForwardError) {

View File

@ -258,7 +258,7 @@ static bool IsCommand(const HloCustomCallInstruction* hlo,
// Check if FFI handler is compatible with command buffers. // Check if FFI handler is compatible with command buffers.
auto registration = ffi::FindHandler(hlo->custom_call_target(), "gpu"); auto registration = ffi::FindHandler(hlo->custom_call_target(), "gpu");
return registration.ok() return registration.ok()
? ffi::IsCommandBufferCompatible(registration->traits) ? ffi::IsCommandBufferCompatible(registration->metadata)
: false; : false;
} }