[xla:ffi] Add support for binding Context object to the handler

PiperOrigin-RevId: 824278531
This commit is contained in:
Eugene Zhulenev 2025-10-26 16:38:14 -07:00 committed by TensorFlower Gardener
parent 2de4be94aa
commit 3630944d0f
5 changed files with 144 additions and 14 deletions

View File

@ -463,9 +463,14 @@ inline XLA_FFI_Error* Ffi::StructSizeIsGreaterOrEqual(
// Type tags for distinguishing handler argument types
//===----------------------------------------------------------------------===//
// Forward declare.
// Dictionary gives type-safe run time access to all attributes. Concrete
// implementation is provided by the `ffi.h` header.
class Dictionary;
// Context gives run time access to the execution context. Concrete
// implementation is provided by the `ffi.h` header.
class Context;
namespace internal {
// WARNING: A lot of template metaprogramming on top of C++ variadic templates
@ -500,7 +505,7 @@ struct AttrTag {};
// A type tag to forward all attributes as `Dictionary` (and optionally decode
// it into a custom struct).
template <typename T = Dictionary>
template <typename T>
struct AttrsTag {};
// A type tag to distinguish parameter extracted from an execution context.
@ -655,6 +660,10 @@ class Binding {
return {std::move(*this)};
}
Binding<stage, Ts..., internal::CtxTag<Context>> Ctx() && {
return {std::move(*this)};
}
template <typename T>
Binding<stage, Ts..., internal::AttrTag<T>> Attr(std::string attr) && {
static_assert(internal::NumTagged<internal::AttrsTag, Ts...>::value == 0,
@ -1402,6 +1411,34 @@ struct internal::Decode<internal::AttrsTag<T>> {
}
};
//===----------------------------------------------------------------------===//
// Type-safe wrapper for accessing context.
//===----------------------------------------------------------------------===//
namespace internal {
class ContextBase {
public:
ContextBase(const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx)
: api_(api), ctx_(ctx) {}
const XLA_FFI_Api* api() const { return api_; }
XLA_FFI_ExecutionContext* ctx() const { return ctx_; }
protected:
template <typename T>
std::optional<typename CtxDecoding<T>::Type> get(
DiagnosticEngine& diagnostic) const {
return CtxDecoding<T>::Decode(api_, ctx_, diagnostic);
}
private:
const XLA_FFI_Api* api_;
XLA_FFI_ExecutionContext* ctx_;
};
} // namespace internal
//===----------------------------------------------------------------------===//
// Template metaprogramming for decoding handler signature
//===----------------------------------------------------------------------===//

View File

@ -1082,8 +1082,8 @@ class Dictionary : public internal::DictionaryBase {
template <typename T>
ErrorOr<T> get(std::string_view name) const {
DiagnosticEngine diagnostic;
std::optional<T> value = internal::DictionaryBase::get<T>(name, diagnostic);
if (!value.has_value()) {
auto value = internal::DictionaryBase::get<T>(name, diagnostic);
if (XLA_FFI_PREDICT_FALSE(!value.has_value())) {
return Unexpected(Error::Internal(diagnostic.Result()));
}
return *value;
@ -1114,6 +1114,38 @@ struct AttrDecoding<Dictionary> {
}
};
//===----------------------------------------------------------------------===//
// Type-safe wrapper for accessing context.
//===----------------------------------------------------------------------===//
class Context : public internal::ContextBase {
public:
using internal::ContextBase::ContextBase;
template <typename T>
ErrorOr<typename CtxDecoding<T>::Type> get() const {
DiagnosticEngine diagnostic;
auto value = internal::ContextBase::get<T>(diagnostic);
if (XLA_FFI_PREDICT_FALSE(!value.has_value())) {
return Unexpected(Error::Internal(diagnostic.Result()));
}
return *value;
}
};
// Context decoding for catch-all `Context` type.
template <>
struct CtxDecoding<Context> {
using Type = Context;
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
static std::optional<Context> Decode(const XLA_FFI_Api* api,
XLA_FFI_ExecutionContext* ctx,
DiagnosticEngine&) {
return Context(api, ctx);
}
};
//===----------------------------------------------------------------------===//
// Error helpers
//===----------------------------------------------------------------------===//

View File

@ -465,13 +465,39 @@ TEST(FfiTest, RunId) {
TF_ASSERT_OK(status);
}
TEST(FfiTest, RunIdViaContext) {
CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
auto call_frame = builder.Build();
auto handler = Ffi::Bind().Ctx().To([&](Context ctx) {
ErrorOr<RunId> run_id = ctx.get<RunId>();
EXPECT_TRUE(run_id.has_value());
EXPECT_EQ(run_id->run_id, 42);
return Error::Success();
});
CallOptions options;
options.run_id = xla::RunId{42};
auto status = Call(*handler, call_frame, options);
TF_ASSERT_OK(status);
}
TEST(FfiTest, DeviceOrdinal) {
CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
auto call_frame = builder.Build();
auto handler =
Ffi::Bind().Ctx<DeviceOrdinal>().To([&](int32_t device_ordinal) {
auto handler = Ffi::Bind().Ctx<DeviceOrdinal>().Ctx().To(
[&](int32_t device_ordinal, Context ctx) {
// Get device ordinal from the argument.
EXPECT_EQ(device_ordinal, 42);
// Get device ordinal from the context.
ErrorOr<int32_t> device_ordinal_or_error = ctx.get<DeviceOrdinal>();
EXPECT_TRUE(device_ordinal_or_error.has_value());
EXPECT_EQ(*device_ordinal_or_error, 42);
return Error::Success();
});

View File

@ -428,7 +428,7 @@ struct AttrDecoding<absl::string_view> {
static std::optional<absl::string_view> Decode(XLA_FFI_AttrType type,
void* attr,
DiagnosticEngine& diagnostic) {
if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_STRING)) {
if (ABSL_PREDICT_FALSE(type != XLA_FFI_AttrType_STRING)) {
return diagnostic.Emit("Wrong attribute type: expected ")
<< XLA_FFI_AttrType_STRING << " but got " << type;
}
@ -472,8 +472,8 @@ class Dictionary : public internal::DictionaryBase {
template <typename T>
absl::StatusOr<T> get(absl::string_view name) const {
DiagnosticEngine diagnostic;
std::optional<T> value = internal::DictionaryBase::get<T>(name, diagnostic);
if (!value.has_value()) {
auto value = internal::DictionaryBase::get<T>(name, diagnostic);
if (ABSL_PREDICT_FALSE(!value.has_value())) {
return Internal("%s", diagnostic.Result());
}
return *value;
@ -496,7 +496,7 @@ struct AttrDecoding<Dictionary> {
using Type = Dictionary;
static std::optional<Dictionary> Decode(XLA_FFI_AttrType type, void* attr,
DiagnosticEngine& diagnostic) {
if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) {
if (ABSL_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) {
return diagnostic.Emit("Wrong attribute type: expected ")
<< XLA_FFI_AttrType_DICTIONARY << " but got " << type;
}
@ -504,6 +504,38 @@ struct AttrDecoding<Dictionary> {
}
};
//===----------------------------------------------------------------------===//
// Type-safe wrapper for accessing context.
//===----------------------------------------------------------------------===//
class Context : public internal::ContextBase {
public:
using internal::ContextBase::ContextBase;
template <typename T>
absl::StatusOr<typename CtxDecoding<T>::Type> get() const {
DiagnosticEngine diagnostic;
auto value = internal::ContextBase::get<T>(diagnostic);
if (ABSL_PREDICT_FALSE(!value.has_value())) {
return Internal("%s", diagnostic.Result());
}
return *value;
}
};
// Context decoding for catch-all `Context` type.
template <>
struct CtxDecoding<Context> {
using Type = Context;
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
static std::optional<Context> Decode(const XLA_FFI_Api* api,
XLA_FFI_ExecutionContext* ctx,
DiagnosticEngine&) {
return Context(api, ctx);
}
};
//===----------------------------------------------------------------------===//
// Context decoding
//===----------------------------------------------------------------------===//

View File

@ -213,10 +213,13 @@ TEST(FfiTest, RunId) {
CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
auto call_frame = builder.Build();
auto handler = Ffi::Bind().Ctx<RunId>().To([&](RunId run_id) {
EXPECT_EQ(run_id.ToInt(), 42);
return absl::OkStatus();
});
auto handler = Ffi::Bind().Ctx<RunId>().Ctx().To(
[&](RunId run_id, Context context) -> absl::Status {
EXPECT_EQ(run_id.ToInt(), 42);
TF_ASSIGN_OR_RETURN(RunId run_id_from_context, context.get<RunId>());
EXPECT_EQ(run_id_from_context.ToInt(), 42);
return absl::OkStatus();
});
CallOptions options;
options.run_id = RunId{42};