diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h index af4132b7eec..0be6b39fb84 100644 --- a/third_party/xla/xla/ffi/api/api.h +++ b/third_party/xla/xla/ffi/api/api.h @@ -463,9 +463,14 @@ inline XLA_FFI_Error* Ffi::StructSizeIsGreaterOrEqual( // Type tags for distinguishing handler argument types //===----------------------------------------------------------------------===// -// Forward declare. +// Dictionary gives type-safe run time access to all attributes. Concrete +// implementation is provided by the `ffi.h` header. class Dictionary; +// Context gives run time access to the execution context. Concrete +// implementation is provided by the `ffi.h` header. +class Context; + namespace internal { // WARNING: A lot of template metaprogramming on top of C++ variadic templates @@ -500,7 +505,7 @@ struct AttrTag {}; // A type tag to forward all attributes as `Dictionary` (and optionally decode // it into a custom struct). -template +template struct AttrsTag {}; // A type tag to distinguish parameter extracted from an execution context. @@ -655,6 +660,10 @@ class Binding { return {std::move(*this)}; } + Binding> Ctx() && { + return {std::move(*this)}; + } + template Binding> Attr(std::string attr) && { static_assert(internal::NumTagged::value == 0, @@ -1402,6 +1411,34 @@ struct internal::Decode> { } }; +//===----------------------------------------------------------------------===// +// Type-safe wrapper for accessing context. +//===----------------------------------------------------------------------===// + +namespace internal { + +class ContextBase { + public: + ContextBase(const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx) + : api_(api), ctx_(ctx) {} + + const XLA_FFI_Api* api() const { return api_; } + XLA_FFI_ExecutionContext* ctx() const { return ctx_; } + + protected: + template + std::optional::Type> get( + DiagnosticEngine& diagnostic) const { + return CtxDecoding::Decode(api_, ctx_, diagnostic); + } + + private: + const XLA_FFI_Api* api_; + XLA_FFI_ExecutionContext* ctx_; +}; + +} // namespace internal + //===----------------------------------------------------------------------===// // Template metaprogramming for decoding handler signature //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index 4687b2354fa..183645788d7 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -1082,8 +1082,8 @@ class Dictionary : public internal::DictionaryBase { template ErrorOr get(std::string_view name) const { DiagnosticEngine diagnostic; - std::optional value = internal::DictionaryBase::get(name, diagnostic); - if (!value.has_value()) { + auto value = internal::DictionaryBase::get(name, diagnostic); + if (XLA_FFI_PREDICT_FALSE(!value.has_value())) { return Unexpected(Error::Internal(diagnostic.Result())); } return *value; @@ -1114,6 +1114,38 @@ struct AttrDecoding { } }; +//===----------------------------------------------------------------------===// +// Type-safe wrapper for accessing context. +//===----------------------------------------------------------------------===// + +class Context : public internal::ContextBase { + public: + using internal::ContextBase::ContextBase; + + template + ErrorOr::Type> get() const { + DiagnosticEngine diagnostic; + auto value = internal::ContextBase::get(diagnostic); + if (XLA_FFI_PREDICT_FALSE(!value.has_value())) { + return Unexpected(Error::Internal(diagnostic.Result())); + } + return *value; + } +}; + +// Context decoding for catch-all `Context` type. +template <> +struct CtxDecoding { + using Type = Context; + + XLA_FFI_ATTRIBUTE_ALWAYS_INLINE + static std::optional Decode(const XLA_FFI_Api* api, + XLA_FFI_ExecutionContext* ctx, + DiagnosticEngine&) { + return Context(api, ctx); + } +}; + //===----------------------------------------------------------------------===// // Error helpers //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc index 94f03ee763c..3f7077c2158 100644 --- a/third_party/xla/xla/ffi/api/ffi_test.cc +++ b/third_party/xla/xla/ffi/api/ffi_test.cc @@ -465,13 +465,39 @@ TEST(FfiTest, RunId) { TF_ASSERT_OK(status); } +TEST(FfiTest, RunIdViaContext) { + CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0); + auto call_frame = builder.Build(); + + auto handler = Ffi::Bind().Ctx().To([&](Context ctx) { + ErrorOr run_id = ctx.get(); + EXPECT_TRUE(run_id.has_value()); + EXPECT_EQ(run_id->run_id, 42); + return Error::Success(); + }); + + CallOptions options; + options.run_id = xla::RunId{42}; + + auto status = Call(*handler, call_frame, options); + + TF_ASSERT_OK(status); +} + TEST(FfiTest, DeviceOrdinal) { CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0); auto call_frame = builder.Build(); - auto handler = - Ffi::Bind().Ctx().To([&](int32_t device_ordinal) { + auto handler = Ffi::Bind().Ctx().Ctx().To( + [&](int32_t device_ordinal, Context ctx) { + // Get device ordinal from the argument. EXPECT_EQ(device_ordinal, 42); + + // Get device ordinal from the context. + ErrorOr device_ordinal_or_error = ctx.get(); + EXPECT_TRUE(device_ordinal_or_error.has_value()); + EXPECT_EQ(*device_ordinal_or_error, 42); + return Error::Success(); }); diff --git a/third_party/xla/xla/ffi/ffi.h b/third_party/xla/xla/ffi/ffi.h index 39eab7724e7..874fdfb8762 100644 --- a/third_party/xla/xla/ffi/ffi.h +++ b/third_party/xla/xla/ffi/ffi.h @@ -428,7 +428,7 @@ struct AttrDecoding { static std::optional Decode(XLA_FFI_AttrType type, void* attr, DiagnosticEngine& diagnostic) { - if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_STRING)) { + if (ABSL_PREDICT_FALSE(type != XLA_FFI_AttrType_STRING)) { return diagnostic.Emit("Wrong attribute type: expected ") << XLA_FFI_AttrType_STRING << " but got " << type; } @@ -472,8 +472,8 @@ class Dictionary : public internal::DictionaryBase { template absl::StatusOr get(absl::string_view name) const { DiagnosticEngine diagnostic; - std::optional value = internal::DictionaryBase::get(name, diagnostic); - if (!value.has_value()) { + auto value = internal::DictionaryBase::get(name, diagnostic); + if (ABSL_PREDICT_FALSE(!value.has_value())) { return Internal("%s", diagnostic.Result()); } return *value; @@ -496,7 +496,7 @@ struct AttrDecoding { using Type = Dictionary; static std::optional Decode(XLA_FFI_AttrType type, void* attr, DiagnosticEngine& diagnostic) { - if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { + if (ABSL_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { return diagnostic.Emit("Wrong attribute type: expected ") << XLA_FFI_AttrType_DICTIONARY << " but got " << type; } @@ -504,6 +504,38 @@ struct AttrDecoding { } }; +//===----------------------------------------------------------------------===// +// Type-safe wrapper for accessing context. +//===----------------------------------------------------------------------===// + +class Context : public internal::ContextBase { + public: + using internal::ContextBase::ContextBase; + + template + absl::StatusOr::Type> get() const { + DiagnosticEngine diagnostic; + auto value = internal::ContextBase::get(diagnostic); + if (ABSL_PREDICT_FALSE(!value.has_value())) { + return Internal("%s", diagnostic.Result()); + } + return *value; + } +}; + +// Context decoding for catch-all `Context` type. +template <> +struct CtxDecoding { + using Type = Context; + + XLA_FFI_ATTRIBUTE_ALWAYS_INLINE + static std::optional Decode(const XLA_FFI_Api* api, + XLA_FFI_ExecutionContext* ctx, + DiagnosticEngine&) { + return Context(api, ctx); + } +}; + //===----------------------------------------------------------------------===// // Context decoding //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc index 7f18d80e795..7558b1db0a7 100644 --- a/third_party/xla/xla/ffi/ffi_test.cc +++ b/third_party/xla/xla/ffi/ffi_test.cc @@ -213,10 +213,13 @@ TEST(FfiTest, RunId) { CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0); auto call_frame = builder.Build(); - auto handler = Ffi::Bind().Ctx().To([&](RunId run_id) { - EXPECT_EQ(run_id.ToInt(), 42); - return absl::OkStatus(); - }); + auto handler = Ffi::Bind().Ctx().Ctx().To( + [&](RunId run_id, Context context) -> absl::Status { + EXPECT_EQ(run_id.ToInt(), 42); + TF_ASSIGN_OR_RETURN(RunId run_id_from_context, context.get()); + EXPECT_EQ(run_id_from_context.ToInt(), 42); + return absl::OkStatus(); + }); CallOptions options; options.run_id = RunId{42};