mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[xla:ffi] Add support for binding Context object to the handler
PiperOrigin-RevId: 824278531
This commit is contained in:
parent
2de4be94aa
commit
3630944d0f
41
third_party/xla/xla/ffi/api/api.h
vendored
41
third_party/xla/xla/ffi/api/api.h
vendored
|
|
@ -463,9 +463,14 @@ inline XLA_FFI_Error* Ffi::StructSizeIsGreaterOrEqual(
|
|||
// Type tags for distinguishing handler argument types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Forward declare.
|
||||
// Dictionary gives type-safe run time access to all attributes. Concrete
|
||||
// implementation is provided by the `ffi.h` header.
|
||||
class Dictionary;
|
||||
|
||||
// Context gives run time access to the execution context. Concrete
|
||||
// implementation is provided by the `ffi.h` header.
|
||||
class Context;
|
||||
|
||||
namespace internal {
|
||||
|
||||
// WARNING: A lot of template metaprogramming on top of C++ variadic templates
|
||||
|
|
@ -500,7 +505,7 @@ struct AttrTag {};
|
|||
|
||||
// A type tag to forward all attributes as `Dictionary` (and optionally decode
|
||||
// it into a custom struct).
|
||||
template <typename T = Dictionary>
|
||||
template <typename T>
|
||||
struct AttrsTag {};
|
||||
|
||||
// A type tag to distinguish parameter extracted from an execution context.
|
||||
|
|
@ -655,6 +660,10 @@ class Binding {
|
|||
return {std::move(*this)};
|
||||
}
|
||||
|
||||
Binding<stage, Ts..., internal::CtxTag<Context>> Ctx() && {
|
||||
return {std::move(*this)};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Binding<stage, Ts..., internal::AttrTag<T>> Attr(std::string attr) && {
|
||||
static_assert(internal::NumTagged<internal::AttrsTag, Ts...>::value == 0,
|
||||
|
|
@ -1402,6 +1411,34 @@ struct internal::Decode<internal::AttrsTag<T>> {
|
|||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type-safe wrapper for accessing context.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace internal {
|
||||
|
||||
class ContextBase {
|
||||
public:
|
||||
ContextBase(const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx)
|
||||
: api_(api), ctx_(ctx) {}
|
||||
|
||||
const XLA_FFI_Api* api() const { return api_; }
|
||||
XLA_FFI_ExecutionContext* ctx() const { return ctx_; }
|
||||
|
||||
protected:
|
||||
template <typename T>
|
||||
std::optional<typename CtxDecoding<T>::Type> get(
|
||||
DiagnosticEngine& diagnostic) const {
|
||||
return CtxDecoding<T>::Decode(api_, ctx_, diagnostic);
|
||||
}
|
||||
|
||||
private:
|
||||
const XLA_FFI_Api* api_;
|
||||
XLA_FFI_ExecutionContext* ctx_;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Template metaprogramming for decoding handler signature
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
|||
36
third_party/xla/xla/ffi/api/ffi.h
vendored
36
third_party/xla/xla/ffi/api/ffi.h
vendored
|
|
@ -1082,8 +1082,8 @@ class Dictionary : public internal::DictionaryBase {
|
|||
template <typename T>
|
||||
ErrorOr<T> get(std::string_view name) const {
|
||||
DiagnosticEngine diagnostic;
|
||||
std::optional<T> value = internal::DictionaryBase::get<T>(name, diagnostic);
|
||||
if (!value.has_value()) {
|
||||
auto value = internal::DictionaryBase::get<T>(name, diagnostic);
|
||||
if (XLA_FFI_PREDICT_FALSE(!value.has_value())) {
|
||||
return Unexpected(Error::Internal(diagnostic.Result()));
|
||||
}
|
||||
return *value;
|
||||
|
|
@ -1114,6 +1114,38 @@ struct AttrDecoding<Dictionary> {
|
|||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type-safe wrapper for accessing context.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class Context : public internal::ContextBase {
|
||||
public:
|
||||
using internal::ContextBase::ContextBase;
|
||||
|
||||
template <typename T>
|
||||
ErrorOr<typename CtxDecoding<T>::Type> get() const {
|
||||
DiagnosticEngine diagnostic;
|
||||
auto value = internal::ContextBase::get<T>(diagnostic);
|
||||
if (XLA_FFI_PREDICT_FALSE(!value.has_value())) {
|
||||
return Unexpected(Error::Internal(diagnostic.Result()));
|
||||
}
|
||||
return *value;
|
||||
}
|
||||
};
|
||||
|
||||
// Context decoding for catch-all `Context` type.
|
||||
template <>
|
||||
struct CtxDecoding<Context> {
|
||||
using Type = Context;
|
||||
|
||||
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
|
||||
static std::optional<Context> Decode(const XLA_FFI_Api* api,
|
||||
XLA_FFI_ExecutionContext* ctx,
|
||||
DiagnosticEngine&) {
|
||||
return Context(api, ctx);
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Error helpers
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
|||
30
third_party/xla/xla/ffi/api/ffi_test.cc
vendored
30
third_party/xla/xla/ffi/api/ffi_test.cc
vendored
|
|
@ -465,13 +465,39 @@ TEST(FfiTest, RunId) {
|
|||
TF_ASSERT_OK(status);
|
||||
}
|
||||
|
||||
TEST(FfiTest, RunIdViaContext) {
|
||||
CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
|
||||
auto call_frame = builder.Build();
|
||||
|
||||
auto handler = Ffi::Bind().Ctx().To([&](Context ctx) {
|
||||
ErrorOr<RunId> run_id = ctx.get<RunId>();
|
||||
EXPECT_TRUE(run_id.has_value());
|
||||
EXPECT_EQ(run_id->run_id, 42);
|
||||
return Error::Success();
|
||||
});
|
||||
|
||||
CallOptions options;
|
||||
options.run_id = xla::RunId{42};
|
||||
|
||||
auto status = Call(*handler, call_frame, options);
|
||||
|
||||
TF_ASSERT_OK(status);
|
||||
}
|
||||
|
||||
TEST(FfiTest, DeviceOrdinal) {
|
||||
CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
|
||||
auto call_frame = builder.Build();
|
||||
|
||||
auto handler =
|
||||
Ffi::Bind().Ctx<DeviceOrdinal>().To([&](int32_t device_ordinal) {
|
||||
auto handler = Ffi::Bind().Ctx<DeviceOrdinal>().Ctx().To(
|
||||
[&](int32_t device_ordinal, Context ctx) {
|
||||
// Get device ordinal from the argument.
|
||||
EXPECT_EQ(device_ordinal, 42);
|
||||
|
||||
// Get device ordinal from the context.
|
||||
ErrorOr<int32_t> device_ordinal_or_error = ctx.get<DeviceOrdinal>();
|
||||
EXPECT_TRUE(device_ordinal_or_error.has_value());
|
||||
EXPECT_EQ(*device_ordinal_or_error, 42);
|
||||
|
||||
return Error::Success();
|
||||
});
|
||||
|
||||
|
|
|
|||
40
third_party/xla/xla/ffi/ffi.h
vendored
40
third_party/xla/xla/ffi/ffi.h
vendored
|
|
@ -428,7 +428,7 @@ struct AttrDecoding<absl::string_view> {
|
|||
static std::optional<absl::string_view> Decode(XLA_FFI_AttrType type,
|
||||
void* attr,
|
||||
DiagnosticEngine& diagnostic) {
|
||||
if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_STRING)) {
|
||||
if (ABSL_PREDICT_FALSE(type != XLA_FFI_AttrType_STRING)) {
|
||||
return diagnostic.Emit("Wrong attribute type: expected ")
|
||||
<< XLA_FFI_AttrType_STRING << " but got " << type;
|
||||
}
|
||||
|
|
@ -472,8 +472,8 @@ class Dictionary : public internal::DictionaryBase {
|
|||
template <typename T>
|
||||
absl::StatusOr<T> get(absl::string_view name) const {
|
||||
DiagnosticEngine diagnostic;
|
||||
std::optional<T> value = internal::DictionaryBase::get<T>(name, diagnostic);
|
||||
if (!value.has_value()) {
|
||||
auto value = internal::DictionaryBase::get<T>(name, diagnostic);
|
||||
if (ABSL_PREDICT_FALSE(!value.has_value())) {
|
||||
return Internal("%s", diagnostic.Result());
|
||||
}
|
||||
return *value;
|
||||
|
|
@ -496,7 +496,7 @@ struct AttrDecoding<Dictionary> {
|
|||
using Type = Dictionary;
|
||||
static std::optional<Dictionary> Decode(XLA_FFI_AttrType type, void* attr,
|
||||
DiagnosticEngine& diagnostic) {
|
||||
if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) {
|
||||
if (ABSL_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) {
|
||||
return diagnostic.Emit("Wrong attribute type: expected ")
|
||||
<< XLA_FFI_AttrType_DICTIONARY << " but got " << type;
|
||||
}
|
||||
|
|
@ -504,6 +504,38 @@ struct AttrDecoding<Dictionary> {
|
|||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type-safe wrapper for accessing context.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class Context : public internal::ContextBase {
|
||||
public:
|
||||
using internal::ContextBase::ContextBase;
|
||||
|
||||
template <typename T>
|
||||
absl::StatusOr<typename CtxDecoding<T>::Type> get() const {
|
||||
DiagnosticEngine diagnostic;
|
||||
auto value = internal::ContextBase::get<T>(diagnostic);
|
||||
if (ABSL_PREDICT_FALSE(!value.has_value())) {
|
||||
return Internal("%s", diagnostic.Result());
|
||||
}
|
||||
return *value;
|
||||
}
|
||||
};
|
||||
|
||||
// Context decoding for catch-all `Context` type.
|
||||
template <>
|
||||
struct CtxDecoding<Context> {
|
||||
using Type = Context;
|
||||
|
||||
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
|
||||
static std::optional<Context> Decode(const XLA_FFI_Api* api,
|
||||
XLA_FFI_ExecutionContext* ctx,
|
||||
DiagnosticEngine&) {
|
||||
return Context(api, ctx);
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Context decoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
|||
5
third_party/xla/xla/ffi/ffi_test.cc
vendored
5
third_party/xla/xla/ffi/ffi_test.cc
vendored
|
|
@ -213,8 +213,11 @@ TEST(FfiTest, RunId) {
|
|||
CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
|
||||
auto call_frame = builder.Build();
|
||||
|
||||
auto handler = Ffi::Bind().Ctx<RunId>().To([&](RunId run_id) {
|
||||
auto handler = Ffi::Bind().Ctx<RunId>().Ctx().To(
|
||||
[&](RunId run_id, Context context) -> absl::Status {
|
||||
EXPECT_EQ(run_id.ToInt(), 42);
|
||||
TF_ASSIGN_OR_RETURN(RunId run_id_from_context, context.get<RunId>());
|
||||
EXPECT_EQ(run_id_from_context.ToInt(), 42);
|
||||
return absl::OkStatus();
|
||||
});
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user