diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_conversion_pass.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_conversion_pass.cc index f3b558de34e..f956644fb32 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_conversion_pass.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_conversion_pass.cc @@ -206,7 +206,7 @@ bool IsConvertible(const CustomCallThunk& custom_call_thunk, absl::StatusOr registration = ffi::FindHandler(target_name, "gpu"); return registration.ok() - ? ffi::IsCommandBufferCompatible(registration->traits) + ? ffi::IsCommandBufferCompatible(registration->metadata) : false; } diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h index 964e0b83117..b7ff8a5b29d 100644 --- a/third_party/xla/xla/ffi/api/api.h +++ b/third_party/xla/xla/ffi/api/api.h @@ -1609,13 +1609,17 @@ class Handler : public Ffi { return err; } + // Set the API version to the version of the FFI headers used by a handler. extension->metadata->api_version = XLA_FFI_Api_Version{ XLA_FFI_Api_Version_STRUCT_SIZE, - /*extension_start=*/nullptr, XLA_FFI_API_MAJOR, XLA_FFI_API_MINOR}; + /*extension_start=*/nullptr, + XLA_FFI_API_MAJOR, + XLA_FFI_API_MINOR, + }; // Collect all traits and store them in the metadata. XLA_FFI_Handler_Traits traits = 0; - for (const auto& trait : traits_) { + for (const Traits& trait : traits_) { traits |= static_cast(trait); } extension->metadata->traits = traits; @@ -1737,7 +1741,8 @@ class Handler : public Ffi { // Find index of every attribute in the sorted attributes vector. 
for (size_t i = 0; i < attrs_.size(); ++i) { attrs_idx_.push_back(std::distance( - sorted.begin(), std::find(sorted.begin(), sorted.end(), attrs_[i]))); + sorted.begin(), + std::find(sorted.begin(), sorted.end(), attrs_[i]))); // NOLINT } } diff --git a/third_party/xla/xla/ffi/ffi_api.cc b/third_party/xla/xla/ffi/ffi_api.cc index 729a7c51521..543a059db95 100644 --- a/third_party/xla/xla/ffi/ffi_api.cc +++ b/third_party/xla/xla/ffi/ffi_api.cc @@ -116,8 +116,8 @@ static bool IsSupportedApiVersion(const XLA_FFI_Api_Version& api_version) { version <= kMaxSupportedApiVersion; } -bool IsCommandBufferCompatible(XLA_FFI_Handler_Traits traits) { - return traits & XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE; +bool IsCommandBufferCompatible(const XLA_FFI_Metadata& metadata) { + return metadata.traits & XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE; } static XLA_FFI_ExecutionContext CreateExecutionContext( @@ -175,16 +175,16 @@ tsl::AsyncValueRef TakeFuture(XLA_FFI_Future* future) { return chain->AsRef(); } - // If the future is already completed, immediately return the underlying async - // value and delete the XLA_FFI_Future. + // If the future is already completed, immediately return the underlying + // async value and delete the XLA_FFI_Future. if (ABSL_PREDICT_TRUE(future->async_value.IsAvailable())) { tsl::AsyncValueRef async_value = std::move(future->async_value); delete future; return async_value; } - // If the future is not completed, return a copy of the underlying async value - // and keep XLA_FFI_Future alive until it is completed. + // If the future is not completed, return a copy of the underlying async + // value and keep XLA_FFI_Future alive until it is completed. 
tsl::AsyncValueRef async_value = future->async_value; async_value.AndThen([future] { delete future; }); return async_value; @@ -201,8 +201,8 @@ static absl::StatusOr Call(Handler& handler, XLA_FFI_Error* error = nullptr; - // FFI handlers might be defined in external libraries and use exceptions, so - // take extra care to catch them and convert to a status. + // FFI handlers might be defined in external libraries and use exceptions, + // so take extra care to catch them and convert to a status. try { if constexpr (std::is_same_v) { error = handler.Call(&ffi_call_frame); @@ -386,6 +386,20 @@ static std::vector GetHandlerStages( return stages; } +static bool CheckMetadata(const XLA_FFI_Metadata& a, + const XLA_FFI_Metadata& b) { + return a.api_version.major_version == b.api_version.major_version && + a.api_version.minor_version == b.api_version.minor_version && + a.traits == b.traits && + a.state_type_id.type_id == b.state_type_id.type_id; +} + +static bool CheckHandlerBundle(const XLA_FFI_Handler_Bundle& a, + const XLA_FFI_Handler_Bundle& b) { + return a.instantiate == b.instantiate && a.prepare == b.prepare && + a.initialize == b.initialize && a.execute == b.execute; +} + static absl::Status RegisterHandler(absl::string_view name, absl::string_view platform, XLA_FFI_Handler_Bundle bundle, @@ -405,8 +419,10 @@ static absl::Status RegisterHandler(absl::string_view name, if (!IsSupportedApiVersion(metadata.api_version)) { return InvalidArgument( "XLA FFI handler registration for %s on platform %s (canonical %s) " - "failed because the handler's API version (%d.%d) is incompatible with " - "the framework's API version (%d.%d). Minimum supported API version is " + "failed because the handler's API version (%d.%d) is incompatible " + "with " + "the framework's API version (%d.%d). 
Minimum supported API version " + "is " "(%d.%d).", name, platform, canonical_platform, metadata.api_version.major_version, metadata.api_version.minor_version, kMaxSupportedApiVersion.first, @@ -414,36 +430,46 @@ static absl::Status RegisterHandler(absl::string_view name, kMinSupportedApiVersion.second); } - // Incorporate handler traits. - traits |= metadata.traits; + // Incorporate handler traits passed explicitly via handler registration + // API. + metadata.traits |= traits; + + // Incorporate state type id from the instantiate implementation if present. + if (bundle.instantiate) { + TF_ASSIGN_OR_RETURN(XLA_FFI_Metadata instantiate_metadata, + GetMetadata(bundle.instantiate)); + metadata.state_type_id = instantiate_metadata.state_type_id; + } VLOG(2) << absl::StreamFormat( "Register XLA FFI handler for '%s'; platform=%s (canonical=%s), " - "stages=[%s], command_buffer_compatible=%v", + "stages=[%s], metadata=%v", name, platform, canonical_platform, - absl::StrJoin(GetHandlerStages(bundle), ", "), - IsCommandBufferCompatible(traits)); + absl::StrJoin(GetHandlerStages(bundle), ", "), metadata); - auto emplaced = - GetHandlerRegistry().try_emplace(MakeHandlerKey(name, canonical_platform), - HandlerRegistration{bundle, traits}); - if (!emplaced.second) { - auto existing = emplaced.first->second; - if (existing.traits != traits) { + HandlerRegistration registration{metadata, bundle}; + auto [it, emplaced] = GetHandlerRegistry().try_emplace( + MakeHandlerKey(name, canonical_platform), registration); + + // We might accidentally link the same FFI library multiple times (because + // linking shared libraries is hard), and we choose to ignore this problem as + // long as we register exactly the same handler. 
+ if (!emplaced) { + const HandlerRegistration& existing = it->second; + if (!CheckMetadata(existing.metadata, metadata)) { return InvalidArgument( "Duplicate FFI handler registration for %s on platform %s " - "(canonical %s) with different traits", - name, platform, canonical_platform); + "(canonical %s) with different metadata: %v vs %v", + name, platform, canonical_platform, existing.metadata, metadata); } - if (existing.bundle.prepare != bundle.prepare || - existing.bundle.initialize != bundle.initialize || - existing.bundle.execute != bundle.execute) { + if (!CheckHandlerBundle(existing.bundle, bundle)) { return InvalidArgument( "Duplicate FFI handler registration for %s on platform %s " "(canonical %s) with different bundle addresses", name, platform, canonical_platform); } } + return absl::OkStatus(); } @@ -681,8 +707,8 @@ static XLA_FFI_Error* XLA_FFI_Type_Register(XLA_FFI_Type_Register_Args* args) { TypeRegistry::TypeId type_id(args->type_id->type_id); TypeRegistry::TypeInfo type_info = {args->type_info->deleter}; - // If type_id is unknown, we are registering a new type and XLA will assign a - // unique type id to it. + // If type_id is unknown, we are registering a new type and XLA will assign + // a unique type id to it. if (type_id == TypeRegistry::kUnknownTypeId) { auto assigned_type_id = TypeRegistry::AssignExternalTypeId(type_name, type_info); diff --git a/third_party/xla/xla/ffi/ffi_api.h b/third_party/xla/xla/ffi/ffi_api.h index ef9f65f4bd0..86c403c10f5 100644 --- a/third_party/xla/xla/ffi/ffi_api.h +++ b/third_party/xla/xla/ffi/ffi_api.h @@ -19,10 +19,12 @@ limitations under the License. 
#include #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "xla/executable_run_options.h" #include "xla/ffi/api/api.h" @@ -142,11 +144,11 @@ class ScopedExecutionContext { //===----------------------------------------------------------------------===// struct HandlerRegistration { - XLA_FFI_Handler_Bundle bundle = {}; - XLA_FFI_Handler_Traits traits = {}; + XLA_FFI_Metadata metadata; + XLA_FFI_Handler_Bundle bundle; }; -bool IsCommandBufferCompatible(XLA_FFI_Handler_Traits traits); +bool IsCommandBufferCompatible(const XLA_FFI_Metadata& metadata); // Returns registered FFI handler for a given name and platform, or an error if // it's not found in the static registry. @@ -163,6 +165,54 @@ StaticRegisteredHandlers(absl::string_view platform); const XLA_FFI_Api* GetXlaFfiApi(); +//===----------------------------------------------------------------------===// +// Helper functions +//===----------------------------------------------------------------------===// + +// Decodes XLA FFI traits packed into a 32-bit integer into a vector of traits. +inline std::vector DecodeTraits(XLA_FFI_Handler_Traits traits) { + std::vector result; + if (traits & XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE) { + result.push_back(Traits::kCmdBufferCompatible); + } + return result; +} + +//===----------------------------------------------------------------------===// +// Pretty printing for FFI C++ types. +//===----------------------------------------------------------------------===// + +template +static void AbslStringify(Sink& sink, Traits traits) { + switch (traits) { + case Traits::kCmdBufferCompatible: + absl::Format(&sink, "cmd_buffer_compatible"); + break; + } +} + } // namespace xla::ffi +//===----------------------------------------------------------------------===// +// Pretty printing for FFI C types. 
+//===----------------------------------------------------------------------===// + +template +static void AbslStringify(Sink& sink, const XLA_FFI_TypeId& type_id) { + if (type_id.type_id == XLA_FFI_UNKNOWN_TYPE_ID.type_id) { + absl::Format(&sink, "unknown"); + } else { + absl::Format(&sink, "%d", type_id.type_id); + } +} + +template +static void AbslStringify(Sink& sink, const XLA_FFI_Metadata& metadata) { + absl::Format(&sink, "{api_version: %d.%d, traits: [%s], state: %v}", + metadata.api_version.major_version, + metadata.api_version.minor_version, + absl::StrJoin(xla::ffi::DecodeTraits(metadata.traits), ", "), + metadata.state_type_id); +} + #endif // XLA_FFI_FFI_API_H_ diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc index 47b3a85fa1a..7f18d80e795 100644 --- a/third_party/xla/xla/ffi/ffi_test.cc +++ b/third_party/xla/xla/ffi/ffi_test.cc @@ -105,8 +105,9 @@ TEST(FfiTest, StaticHandlerRegistration) { TF_ASSERT_OK(handler0.status()); TF_ASSERT_OK(handler1.status()); - ASSERT_EQ(handler0->traits, XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE); - ASSERT_EQ(handler1->traits, 0); + ASSERT_EQ(handler0->metadata.traits, + XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE); + ASSERT_EQ(handler1->metadata.traits, 0); // Check that platform name was canonicalized an we can find handlers // registered for "Host" platform as "Cpu" handlers. @@ -122,7 +123,8 @@ TEST(FfiTest, RegistrationTraitsBackwardsCompatibility) { XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE); auto handler = FindHandler("traits-bwd-compat", "Host"); TF_ASSERT_OK(handler.status()); - ASSERT_EQ(handler->traits, XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE); + ASSERT_EQ(handler->metadata.traits, + XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE); } // Declare XLA FFI handler as a function (extern "C" declaration). 
@@ -139,7 +141,7 @@ TEST(FfiTest, StaticHandlerSymbolRegistration) { auto handler0 = FindHandler("no-op-sym-0", "Cpu"); TF_ASSERT_OK(handler0.status()); - ASSERT_EQ(handler0->traits, 0); + ASSERT_EQ(handler0->metadata.traits, 0); } TEST(FfiTest, ForwardError) { diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc index 96fa801f5b3..606f345abd0 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc @@ -258,7 +258,7 @@ static bool IsCommand(const HloCustomCallInstruction* hlo, // Check if FFI handler is compatible with command buffers. auto registration = ffi::FindHandler(hlo->custom_call_target(), "gpu"); return registration.ok() - ? ffi::IsCommandBufferCompatible(registration->traits) + ? ffi::IsCommandBufferCompatible(registration->metadata) : false; }