[xla:ffi] Keep FFI handler metadata with handler registration

PiperOrigin-RevId: 822741325
This commit is contained in:
Eugene Zhulenev 2025-10-22 14:27:41 -07:00 committed by TensorFlower Gardener
parent 70111bb38f
commit 2cdd8ff5ce
6 changed files with 123 additions and 40 deletions

View File

@ -206,7 +206,7 @@ bool IsConvertible(const CustomCallThunk& custom_call_thunk,
absl::StatusOr<ffi::HandlerRegistration> registration =
ffi::FindHandler(target_name, "gpu");
return registration.ok()
? ffi::IsCommandBufferCompatible(registration->traits)
? ffi::IsCommandBufferCompatible(registration->metadata)
: false;
}

View File

@ -1609,13 +1609,17 @@ class Handler : public Ffi {
return err;
}
// Set the API version to the version of the FFI headers used by a handler.
extension->metadata->api_version = XLA_FFI_Api_Version{
XLA_FFI_Api_Version_STRUCT_SIZE,
/*extension_start=*/nullptr, XLA_FFI_API_MAJOR, XLA_FFI_API_MINOR};
/*extension_start=*/nullptr,
XLA_FFI_API_MAJOR,
XLA_FFI_API_MINOR,
};
// Collect all traits and store them in the metadata.
XLA_FFI_Handler_Traits traits = 0;
for (const auto& trait : traits_) {
for (const Traits& trait : traits_) {
traits |= static_cast<XLA_FFI_Handler_Traits>(trait);
}
extension->metadata->traits = traits;
@ -1737,7 +1741,8 @@ class Handler : public Ffi {
// Find index of every attribute in the sorted attributes vector.
for (size_t i = 0; i < attrs_.size(); ++i) {
attrs_idx_.push_back(std::distance(
sorted.begin(), std::find(sorted.begin(), sorted.end(), attrs_[i])));
sorted.begin(),
std::find(sorted.begin(), sorted.end(), attrs_[i]))); // NOLINT
}
}

View File

@ -116,8 +116,8 @@ static bool IsSupportedApiVersion(const XLA_FFI_Api_Version& api_version) {
version <= kMaxSupportedApiVersion;
}
bool IsCommandBufferCompatible(XLA_FFI_Handler_Traits traits) {
return traits & XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE;
bool IsCommandBufferCompatible(const XLA_FFI_Metadata& metadata) {
return metadata.traits & XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE;
}
static XLA_FFI_ExecutionContext CreateExecutionContext(
@ -175,16 +175,16 @@ tsl::AsyncValueRef<tsl::Chain> TakeFuture(XLA_FFI_Future* future) {
return chain->AsRef();
}
// If the future is already completed, immediately return the underlying async
// value and delete the XLA_FFI_Future.
// If the future is already completed, immediately return the underlying
// async value and delete the XLA_FFI_Future.
if (ABSL_PREDICT_TRUE(future->async_value.IsAvailable())) {
tsl::AsyncValueRef<tsl::Chain> async_value = std::move(future->async_value);
delete future;
return async_value;
}
// If the future is not completed, return a copy of the underlying async value
// and keep XLA_FFI_Future alive until it is completed.
// If the future is not completed, return a copy of the underlying async
// value and keep XLA_FFI_Future alive until it is completed.
tsl::AsyncValueRef<tsl::Chain> async_value = future->async_value;
async_value.AndThen([future] { delete future; });
return async_value;
@ -201,8 +201,8 @@ static absl::StatusOr<XLA_FFI_Future*> Call(Handler& handler,
XLA_FFI_Error* error = nullptr;
// FFI handlers might be defined in external libraries and use exceptions, so
// take extra care to catch them and convert to a status.
// FFI handlers might be defined in external libraries and use exceptions,
// so take extra care to catch them and convert to a status.
try {
if constexpr (std::is_same_v<Handler, Ffi>) {
error = handler.Call(&ffi_call_frame);
@ -386,6 +386,20 @@ static std::vector<std::string> GetHandlerStages(
return stages;
}
static bool CheckMetadata(const XLA_FFI_Metadata& a,
const XLA_FFI_Metadata& b) {
return a.api_version.major_version == b.api_version.major_version &&
a.api_version.minor_version == b.api_version.minor_version &&
a.traits == b.traits &&
a.state_type_id.type_id == b.state_type_id.type_id;
}
static bool CheckHandlerBundle(const XLA_FFI_Handler_Bundle& a,
const XLA_FFI_Handler_Bundle& b) {
return a.instantiate == b.instantiate && a.prepare == b.prepare &&
a.initialize == b.initialize && a.execute == b.execute;
}
static absl::Status RegisterHandler(absl::string_view name,
absl::string_view platform,
XLA_FFI_Handler_Bundle bundle,
@ -405,8 +419,10 @@ static absl::Status RegisterHandler(absl::string_view name,
if (!IsSupportedApiVersion(metadata.api_version)) {
return InvalidArgument(
"XLA FFI handler registration for %s on platform %s (canonical %s) "
"failed because the handler's API version (%d.%d) is incompatible with "
"the framework's API version (%d.%d). Minimum supported API version is "
"failed because the handler's API version (%d.%d) is incompatible "
"with "
"the framework's API version (%d.%d). Minimum supported API version "
"is "
"(%d.%d).",
name, platform, canonical_platform, metadata.api_version.major_version,
metadata.api_version.minor_version, kMaxSupportedApiVersion.first,
@ -414,36 +430,46 @@ static absl::Status RegisterHandler(absl::string_view name,
kMinSupportedApiVersion.second);
}
// Incorporate handler traits.
traits |= metadata.traits;
// Incorporate handler traits passed explicitly via handler registration
// API.
metadata.traits |= traits;
// Incorporate state type id from the instantiate implementation if present.
if (bundle.instantiate) {
TF_ASSIGN_OR_RETURN(XLA_FFI_Metadata instantiate_metadata,
GetMetadata(bundle.instantiate));
metadata.state_type_id = instantiate_metadata.state_type_id;
}
VLOG(2) << absl::StreamFormat(
"Register XLA FFI handler for '%s'; platform=%s (canonical=%s), "
"stages=[%s], command_buffer_compatible=%v",
"stages=[%s], metadata=%v",
name, platform, canonical_platform,
absl::StrJoin(GetHandlerStages(bundle), ", "),
IsCommandBufferCompatible(traits));
absl::StrJoin(GetHandlerStages(bundle), ", "), metadata);
auto emplaced =
GetHandlerRegistry().try_emplace(MakeHandlerKey(name, canonical_platform),
HandlerRegistration{bundle, traits});
if (!emplaced.second) {
auto existing = emplaced.first->second;
if (existing.traits != traits) {
HandlerRegistration registration{metadata, bundle};
auto [it, emplaced] = GetHandlerRegistry().try_emplace(
MakeHandlerKey(name, canonical_platform), registration);
// We might accidentally link the same FFI library multiple times (because
// linking shared libraries is hard), and we choose to ignore this problem as
// long as we register exactly the same handler.
if (!emplaced) {
const HandlerRegistration& existing = it->second;
if (!CheckMetadata(existing.metadata, metadata)) {
return InvalidArgument(
"Duplicate FFI handler registration for %s on platform %s "
"(canonical %s) with different traits",
name, platform, canonical_platform);
"(canonical %s) with different metadata: %v vs %v",
name, platform, canonical_platform, existing.metadata, metadata);
}
if (existing.bundle.prepare != bundle.prepare ||
existing.bundle.initialize != bundle.initialize ||
existing.bundle.execute != bundle.execute) {
if (!CheckHandlerBundle(existing.bundle, bundle)) {
return InvalidArgument(
"Duplicate FFI handler registration for %s on platform %s "
"(canonical %s) with different bundle addresses",
name, platform, canonical_platform);
}
}
return absl::OkStatus();
}
@ -681,8 +707,8 @@ static XLA_FFI_Error* XLA_FFI_Type_Register(XLA_FFI_Type_Register_Args* args) {
TypeRegistry::TypeId type_id(args->type_id->type_id);
TypeRegistry::TypeInfo type_info = {args->type_info->deleter};
// If type_id is unknown, we are registering a new type and XLA will assign a
// unique type id to it.
// If type_id is unknown, we are registering a new type and XLA will assign
// a unique type id to it.
if (type_id == TypeRegistry::kUnknownTypeId) {
auto assigned_type_id =
TypeRegistry::AssignExternalTypeId(type_name, type_info);

View File

@ -19,10 +19,12 @@ limitations under the License.
#include <cstdint>
#include <string>
#include <variant>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "xla/executable_run_options.h"
#include "xla/ffi/api/api.h"
@ -142,11 +144,11 @@ class ScopedExecutionContext {
//===----------------------------------------------------------------------===//
struct HandlerRegistration {
XLA_FFI_Handler_Bundle bundle = {};
XLA_FFI_Handler_Traits traits = {};
XLA_FFI_Metadata metadata;
XLA_FFI_Handler_Bundle bundle;
};
bool IsCommandBufferCompatible(XLA_FFI_Handler_Traits traits);
bool IsCommandBufferCompatible(const XLA_FFI_Metadata& metadata);
// Returns registered FFI handler for a given name and platform, or an error if
// it's not found in the static registry.
@ -163,6 +165,54 @@ StaticRegisteredHandlers(absl::string_view platform);
const XLA_FFI_Api* GetXlaFfiApi();
//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
// Decodes XLA FFI traits packed into a 32-bit integer into a vector of traits.
inline std::vector<Traits> DecodeTraits(XLA_FFI_Handler_Traits traits) {
std::vector<Traits> result;
if (traits & XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE) {
result.push_back(Traits::kCmdBufferCompatible);
}
return result;
}
//===----------------------------------------------------------------------===//
// Pretty printinting for FFI C++ types.
//===----------------------------------------------------------------------===//
template <typename Sink>
static void AbslStringify(Sink& sink, Traits traits) {
switch (traits) {
case Traits::kCmdBufferCompatible:
absl::Format(&sink, "cmd_buffer_compatible");
break;
}
}
} // namespace xla::ffi
//===----------------------------------------------------------------------===//
// Pretty printinting for FFI C types.
//===----------------------------------------------------------------------===//
template <typename Sink>
static void AbslStringify(Sink& sink, const XLA_FFI_TypeId& type_id) {
if (type_id.type_id == XLA_FFI_UNKNOWN_TYPE_ID.type_id) {
absl::Format(&sink, "unknown");
} else {
absl::Format(&sink, "%d", type_id.type_id);
}
}
template <typename Sink>
static void AbslStringify(Sink& sink, const XLA_FFI_Metadata& metadata) {
absl::Format(&sink, "{api_version: %d.%d, traits: [%s], state: %v}",
metadata.api_version.major_version,
metadata.api_version.minor_version,
absl::StrJoin(xla::ffi::DecodeTraits(metadata.traits), ", "),
metadata.state_type_id);
}
#endif // XLA_FFI_FFI_API_H_

View File

@ -105,8 +105,9 @@ TEST(FfiTest, StaticHandlerRegistration) {
TF_ASSERT_OK(handler0.status());
TF_ASSERT_OK(handler1.status());
ASSERT_EQ(handler0->traits, XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE);
ASSERT_EQ(handler1->traits, 0);
ASSERT_EQ(handler0->metadata.traits,
XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE);
ASSERT_EQ(handler1->metadata.traits, 0);
// Check that platform name was canonicalized an we can find handlers
// registered for "Host" platform as "Cpu" handlers.
@ -122,7 +123,8 @@ TEST(FfiTest, RegistrationTraitsBackwardsCompatibility) {
XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE);
auto handler = FindHandler("traits-bwd-compat", "Host");
TF_ASSERT_OK(handler.status());
ASSERT_EQ(handler->traits, XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE);
ASSERT_EQ(handler->metadata.traits,
XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE);
}
// Declare XLA FFI handler as a function (extern "C" declaration).
@ -139,7 +141,7 @@ TEST(FfiTest, StaticHandlerSymbolRegistration) {
auto handler0 = FindHandler("no-op-sym-0", "Cpu");
TF_ASSERT_OK(handler0.status());
ASSERT_EQ(handler0->traits, 0);
ASSERT_EQ(handler0->metadata.traits, 0);
}
TEST(FfiTest, ForwardError) {

View File

@ -258,7 +258,7 @@ static bool IsCommand(const HloCustomCallInstruction* hlo,
// Check if FFI handler is compatible with command buffers.
auto registration = ffi::FindHandler(hlo->custom_call_target(), "gpu");
return registration.ok()
? ffi::IsCommandBufferCompatible(registration->traits)
? ffi::IsCommandBufferCompatible(registration->metadata)
: false;
}