[xla:ffi] Add type checking isa<T>() APIs to RemainingArgs, RemainingRets and Dictionary

PiperOrigin-RevId: 823157056
This commit is contained in:
Eugene Zhulenev 2025-10-23 13:03:48 -07:00 committed by TensorFlower Gardener
parent 809d5c7895
commit eb36f8770a
3 changed files with 86 additions and 3 deletions

View File

@ -903,6 +903,7 @@ struct AttrsBinding<Dictionary> {
//
// template <>
// struct ArgDecoding<MyType> {
// static bool Isa(XLA_FFI_ArgType type, void* arg);
// static std::optional<MyType> Decode(XLA_FFI_ArgType type, void* arg);
// };
//
@ -920,6 +921,7 @@ struct ArgDecoding;
//
// template <>
// struct RetDecoding<MyType> {
// static bool Isa(XLA_FFI_RetType type, void* ret);
// static std::optional<MyType> Decode(XLA_FFI_RetType type, void* ret);
// };
//
@ -937,9 +939,10 @@ struct RetDecoding;
//
// template <>
// struct AttrDecoding<MyType> {
// using Type = <handler argument type for attribute type MyType>
// static std::optional<MyType> Decode(XLA_FFI_AttrType type, void* attr,
// DiagnosticEngine&);
// using Type = <handler argument type for attribute type MyType>
// static bool Isa(XLA_FFI_AttrType type, void* attr);
// static std::optional<MyType> Decode(XLA_FFI_AttrType type, void* attr,
// DiagnosticEngine&);
// }
//
template <typename T>
@ -1205,6 +1208,12 @@ class RemainingArgsBase {
assert(offset <= args_->size && "illegal remaining args offset");
}
template <typename T>
bool isa(size_t index) const {
size_t idx = offset() + index;
return ArgDecoding<T>::Isa(args_->types[idx], args_->args[idx]);
}
size_t size() const { return args_->size - offset_; }
bool empty() const { return size() == 0; }
@ -1232,6 +1241,12 @@ class RemainingRetsBase {
assert(offset <= rets_->size && "illegal remaining rets offset");
}
template <typename T>
bool isa(size_t index) const {
size_t idx = offset_ + index;
return RetDecoding<T>::Isa(rets_->types[idx], rets_->rets[idx]);
}
size_t size() const { return rets_->size - offset_; }
bool empty() const { return size() == 0; }
@ -1264,6 +1279,18 @@ class DictionaryBase {
bool contains(std::string_view name) const { return Find(name).has_value(); }
template <typename T>
bool contains(std::string_view name) const {
std::optional<size_t> idx = Find(name);
if (XLA_FFI_PREDICT_FALSE(!idx.has_value())) {
return false;
}
XLA_FFI_AttrType attr_type = attrs_->types[*idx];
void* attr = attrs_->attrs[*idx];
return AttrDecoding<T>::Isa(attr_type, attr);
}
protected:
template <typename T, typename... Ts>
friend struct DecodeDictionaryAttr;
@ -1768,6 +1795,12 @@ class Handler : public Ffi {
template <> \
struct AttrDecoding<T> { \
using Type = T; \
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE static bool Isa(XLA_FFI_AttrType type, \
void* attr) { \
return type == XLA_FFI_AttrType_SCALAR && \
reinterpret_cast<XLA_FFI_Scalar*>(attr)->dtype == TYPE; \
} \
\
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE static std::optional<T> Decode( \
XLA_FFI_AttrType type, void* attr, DiagnosticEngine& diagnostic) { \
if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_SCALAR)) { \

View File

@ -747,6 +747,12 @@ using Token = BufferR0<DataType::TOKEN>; // NOLINT
namespace internal {
template <DataType dtype, size_t rank>
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE bool IsaBuffer(XLA_FFI_Buffer* buf) {
return static_cast<DataType>(buf->dtype) == dtype &&
(rank == internal::kDynamicRank || buf->rank == rank);
}
template <DataType dtype, size_t rank>
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE std::optional<Buffer<dtype, rank>> DecodeBuffer(
XLA_FFI_Buffer* buf, DiagnosticEngine& diagnostic) {
@ -820,6 +826,11 @@ inline std::ostream& operator<<(std::ostream& os, const XLA_FFI_ArgType type) {
template <>
struct ArgDecoding<AnyBuffer> {
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
static bool Isa(XLA_FFI_ArgType type, void* arg) {
return type == XLA_FFI_ArgType_BUFFER;
}
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
static std::optional<AnyBuffer> Decode(XLA_FFI_ArgType type, void* arg,
DiagnosticEngine& diagnostic) {
@ -833,6 +844,13 @@ struct ArgDecoding<AnyBuffer> {
template <DataType dtype, size_t rank>
struct ArgDecoding<Buffer<dtype, rank>> {
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
static bool Isa(XLA_FFI_ArgType type, void* arg) {
return type == XLA_FFI_ArgType_BUFFER &&
internal::IsaBuffer<dtype, rank>(
reinterpret_cast<XLA_FFI_Buffer*>(arg));
}
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
static std::optional<Buffer<dtype, rank>> Decode(
XLA_FFI_ArgType type, void* arg, DiagnosticEngine& diagnostic) {
@ -895,6 +913,11 @@ inline std::ostream& operator<<(std::ostream& os, const XLA_FFI_RetType type) {
template <>
struct RetDecoding<AnyBuffer> {
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
static bool Isa(XLA_FFI_RetType type, void* ret) {
return type == XLA_FFI_RetType_BUFFER;
}
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
static std::optional<Result<AnyBuffer>> Decode(XLA_FFI_RetType type,
void* ret,
@ -909,6 +932,13 @@ struct RetDecoding<AnyBuffer> {
template <DataType dtype, size_t rank>
struct RetDecoding<Buffer<dtype, rank>> {
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
static bool Isa(XLA_FFI_RetType type, void* ret) {
return type == XLA_FFI_RetType_BUFFER &&
internal::IsaBuffer<dtype, rank>(
reinterpret_cast<XLA_FFI_Buffer*>(ret));
}
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
static std::optional<Result<Buffer<dtype, rank>>> Decode(
XLA_FFI_RetType type, void* ret, DiagnosticEngine& diagnostic) {
@ -1001,6 +1031,11 @@ template <>
struct AttrDecoding<std::string_view> {
using Type = std::string_view;
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE static bool Isa(XLA_FFI_AttrType type,
void* attr) {
return type == XLA_FFI_AttrType_STRING;
}
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE static std::optional<std::string_view> Decode(
XLA_FFI_AttrType type, void* attr, DiagnosticEngine& diagnostic) {
if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_STRING)) {

View File

@ -639,6 +639,9 @@ TEST(FfiTest, RemainingArgs) {
auto fn = [&](RemainingArgs args) {
EXPECT_EQ(args.size(), 1);
EXPECT_TRUE(args.isa<AnyBuffer>(0));
EXPECT_FALSE(args.isa<BufferR2<F64>>(0));
ErrorOr<AnyBuffer> arg0 = args.get<AnyBuffer>(0);
ErrorOr<AnyBuffer> arg1 = args.get<AnyBuffer>(1);
@ -666,6 +669,9 @@ TEST(FfiTest, RemainingRets) {
auto fn = [&](Result<AnyBuffer> ret, RemainingRets rets) {
EXPECT_EQ(rets.size(), 1);
EXPECT_TRUE(rets.isa<AnyBuffer>(0));
EXPECT_FALSE(rets.isa<BufferR2<F64>>(0));
ErrorOr<Result<AnyBuffer>> ret0 = rets.get<AnyBuffer>(0);
ErrorOr<Result<AnyBuffer>> ret1 = rets.get<AnyBuffer>(1);
@ -860,8 +866,17 @@ TEST(FfiTest, AutoBindingStructs) {
TEST(FfiTest, AutoBindingDictionary) {
auto handler = Ffi::BindTo(+[](Dictionary attrs) {
EXPECT_TRUE(attrs.contains("i32"));
EXPECT_TRUE(attrs.contains("f32"));
EXPECT_TRUE(attrs.contains<int32_t>("i32"));
EXPECT_TRUE(attrs.contains<float>("f32"));
EXPECT_FALSE(attrs.contains<int64_t>("i32"));
EXPECT_FALSE(attrs.contains<int64_t>("f32"));
EXPECT_EQ(*attrs.get<int32_t>("i32"), 42);
EXPECT_EQ(*attrs.get<float>("f32"), 42.0f);
return Error::Success();
});