mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[xla:ffi] Add type checking isa<T>() APIs to RemainingArgs, RemainingRets and Dictionary
PiperOrigin-RevId: 823157056
This commit is contained in:
parent
809d5c7895
commit
eb36f8770a
39
third_party/xla/xla/ffi/api/api.h
vendored
39
third_party/xla/xla/ffi/api/api.h
vendored
|
|
@ -903,6 +903,7 @@ struct AttrsBinding<Dictionary> {
|
|||
//
|
||||
// template <>
|
||||
// struct ArgDecoding<MyType> {
|
||||
// static bool Isa(XLA_FFI_ArgType type, void* arg);
|
||||
// static std::optional<MyType> Decode(XLA_FFI_ArgType type, void* arg);
|
||||
// };
|
||||
//
|
||||
|
|
@ -920,6 +921,7 @@ struct ArgDecoding;
|
|||
//
|
||||
// template <>
|
||||
// struct RetDecoding<MyType> {
|
||||
// static bool Isa(XLA_FFI_RetType type, void* ret);
|
||||
// static std::optional<MyType> Decode(XLA_FFI_RetType type, void* ret);
|
||||
// };
|
||||
//
|
||||
|
|
@ -937,9 +939,10 @@ struct RetDecoding;
|
|||
//
|
||||
// template <>
|
||||
// struct AttrDecoding<MyType> {
|
||||
// using Type = <handler argument type for attribute type MyType>
|
||||
// static std::optional<MyType> Decode(XLA_FFI_AttrType type, void* attr,
|
||||
// DiagnosticEngine&);
|
||||
// using Type = <handler argument type for attribute type MyType>
|
||||
// static bool Isa(XLA_FFI_AttrType type, void* attr);
|
||||
// static std::optional<MyType> Decode(XLA_FFI_AttrType type, void* attr,
|
||||
// DiagnosticEngine&);
|
||||
// }
|
||||
//
|
||||
template <typename T>
|
||||
|
|
@ -1205,6 +1208,12 @@ class RemainingArgsBase {
|
|||
assert(offset <= args_->size && "illegal remaining args offset");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool isa(size_t index) const {
|
||||
size_t idx = offset() + index;
|
||||
return ArgDecoding<T>::Isa(args_->types[idx], args_->args[idx]);
|
||||
}
|
||||
|
||||
size_t size() const { return args_->size - offset_; }
|
||||
bool empty() const { return size() == 0; }
|
||||
|
||||
|
|
@ -1232,6 +1241,12 @@ class RemainingRetsBase {
|
|||
assert(offset <= rets_->size && "illegal remaining rets offset");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool isa(size_t index) const {
|
||||
size_t idx = offset_ + index;
|
||||
return RetDecoding<T>::Isa(rets_->types[idx], rets_->rets[idx]);
|
||||
}
|
||||
|
||||
size_t size() const { return rets_->size - offset_; }
|
||||
bool empty() const { return size() == 0; }
|
||||
|
||||
|
|
@ -1264,6 +1279,18 @@ class DictionaryBase {
|
|||
|
||||
bool contains(std::string_view name) const { return Find(name).has_value(); }
|
||||
|
||||
template <typename T>
|
||||
bool contains(std::string_view name) const {
|
||||
std::optional<size_t> idx = Find(name);
|
||||
if (XLA_FFI_PREDICT_FALSE(!idx.has_value())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
XLA_FFI_AttrType attr_type = attrs_->types[*idx];
|
||||
void* attr = attrs_->attrs[*idx];
|
||||
return AttrDecoding<T>::Isa(attr_type, attr);
|
||||
}
|
||||
|
||||
protected:
|
||||
template <typename T, typename... Ts>
|
||||
friend struct DecodeDictionaryAttr;
|
||||
|
|
@ -1768,6 +1795,12 @@ class Handler : public Ffi {
|
|||
template <> \
|
||||
struct AttrDecoding<T> { \
|
||||
using Type = T; \
|
||||
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE static bool Isa(XLA_FFI_AttrType type, \
|
||||
void* attr) { \
|
||||
return type == XLA_FFI_AttrType_SCALAR && \
|
||||
reinterpret_cast<XLA_FFI_Scalar*>(attr)->dtype == TYPE; \
|
||||
} \
|
||||
\
|
||||
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE static std::optional<T> Decode( \
|
||||
XLA_FFI_AttrType type, void* attr, DiagnosticEngine& diagnostic) { \
|
||||
if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_SCALAR)) { \
|
||||
|
|
|
|||
35
third_party/xla/xla/ffi/api/ffi.h
vendored
35
third_party/xla/xla/ffi/api/ffi.h
vendored
|
|
@ -747,6 +747,12 @@ using Token = BufferR0<DataType::TOKEN>; // NOLINT
|
|||
|
||||
namespace internal {
|
||||
|
||||
template <DataType dtype, size_t rank>
|
||||
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE bool IsaBuffer(XLA_FFI_Buffer* buf) {
|
||||
return static_cast<DataType>(buf->dtype) == dtype &&
|
||||
(rank == internal::kDynamicRank || buf->rank == rank);
|
||||
}
|
||||
|
||||
template <DataType dtype, size_t rank>
|
||||
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE std::optional<Buffer<dtype, rank>> DecodeBuffer(
|
||||
XLA_FFI_Buffer* buf, DiagnosticEngine& diagnostic) {
|
||||
|
|
@ -820,6 +826,11 @@ inline std::ostream& operator<<(std::ostream& os, const XLA_FFI_ArgType type) {
|
|||
|
||||
template <>
|
||||
struct ArgDecoding<AnyBuffer> {
|
||||
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
|
||||
static bool Isa(XLA_FFI_ArgType type, void* arg) {
|
||||
return type == XLA_FFI_ArgType_BUFFER;
|
||||
}
|
||||
|
||||
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
|
||||
static std::optional<AnyBuffer> Decode(XLA_FFI_ArgType type, void* arg,
|
||||
DiagnosticEngine& diagnostic) {
|
||||
|
|
@ -833,6 +844,13 @@ struct ArgDecoding<AnyBuffer> {
|
|||
|
||||
template <DataType dtype, size_t rank>
|
||||
struct ArgDecoding<Buffer<dtype, rank>> {
|
||||
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
|
||||
static bool Isa(XLA_FFI_ArgType type, void* arg) {
|
||||
return type == XLA_FFI_ArgType_BUFFER &&
|
||||
internal::IsaBuffer<dtype, rank>(
|
||||
reinterpret_cast<XLA_FFI_Buffer*>(arg));
|
||||
}
|
||||
|
||||
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
|
||||
static std::optional<Buffer<dtype, rank>> Decode(
|
||||
XLA_FFI_ArgType type, void* arg, DiagnosticEngine& diagnostic) {
|
||||
|
|
@ -895,6 +913,11 @@ inline std::ostream& operator<<(std::ostream& os, const XLA_FFI_RetType type) {
|
|||
|
||||
template <>
|
||||
struct RetDecoding<AnyBuffer> {
|
||||
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
|
||||
static bool Isa(XLA_FFI_RetType type, void* ret) {
|
||||
return type == XLA_FFI_RetType_BUFFER;
|
||||
}
|
||||
|
||||
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
|
||||
static std::optional<Result<AnyBuffer>> Decode(XLA_FFI_RetType type,
|
||||
void* ret,
|
||||
|
|
@ -909,6 +932,13 @@ struct RetDecoding<AnyBuffer> {
|
|||
|
||||
template <DataType dtype, size_t rank>
|
||||
struct RetDecoding<Buffer<dtype, rank>> {
|
||||
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
|
||||
static bool Isa(XLA_FFI_RetType type, void* ret) {
|
||||
return type == XLA_FFI_RetType_BUFFER &&
|
||||
internal::IsaBuffer<dtype, rank>(
|
||||
reinterpret_cast<XLA_FFI_Buffer*>(ret));
|
||||
}
|
||||
|
||||
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
|
||||
static std::optional<Result<Buffer<dtype, rank>>> Decode(
|
||||
XLA_FFI_RetType type, void* ret, DiagnosticEngine& diagnostic) {
|
||||
|
|
@ -1001,6 +1031,11 @@ template <>
|
|||
struct AttrDecoding<std::string_view> {
|
||||
using Type = std::string_view;
|
||||
|
||||
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE static bool Isa(XLA_FFI_AttrType type,
|
||||
void* attr) {
|
||||
return type == XLA_FFI_AttrType_STRING;
|
||||
}
|
||||
|
||||
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE static std::optional<std::string_view> Decode(
|
||||
XLA_FFI_AttrType type, void* attr, DiagnosticEngine& diagnostic) {
|
||||
if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_STRING)) {
|
||||
|
|
|
|||
15
third_party/xla/xla/ffi/api/ffi_test.cc
vendored
15
third_party/xla/xla/ffi/api/ffi_test.cc
vendored
|
|
@ -639,6 +639,9 @@ TEST(FfiTest, RemainingArgs) {
|
|||
auto fn = [&](RemainingArgs args) {
|
||||
EXPECT_EQ(args.size(), 1);
|
||||
|
||||
EXPECT_TRUE(args.isa<AnyBuffer>(0));
|
||||
EXPECT_FALSE(args.isa<BufferR2<F64>>(0));
|
||||
|
||||
ErrorOr<AnyBuffer> arg0 = args.get<AnyBuffer>(0);
|
||||
ErrorOr<AnyBuffer> arg1 = args.get<AnyBuffer>(1);
|
||||
|
||||
|
|
@ -666,6 +669,9 @@ TEST(FfiTest, RemainingRets) {
|
|||
auto fn = [&](Result<AnyBuffer> ret, RemainingRets rets) {
|
||||
EXPECT_EQ(rets.size(), 1);
|
||||
|
||||
EXPECT_TRUE(rets.isa<AnyBuffer>(0));
|
||||
EXPECT_FALSE(rets.isa<BufferR2<F64>>(0));
|
||||
|
||||
ErrorOr<Result<AnyBuffer>> ret0 = rets.get<AnyBuffer>(0);
|
||||
ErrorOr<Result<AnyBuffer>> ret1 = rets.get<AnyBuffer>(1);
|
||||
|
||||
|
|
@ -860,8 +866,17 @@ TEST(FfiTest, AutoBindingStructs) {
|
|||
|
||||
TEST(FfiTest, AutoBindingDictionary) {
|
||||
auto handler = Ffi::BindTo(+[](Dictionary attrs) {
|
||||
EXPECT_TRUE(attrs.contains("i32"));
|
||||
EXPECT_TRUE(attrs.contains("f32"));
|
||||
|
||||
EXPECT_TRUE(attrs.contains<int32_t>("i32"));
|
||||
EXPECT_TRUE(attrs.contains<float>("f32"));
|
||||
EXPECT_FALSE(attrs.contains<int64_t>("i32"));
|
||||
EXPECT_FALSE(attrs.contains<int64_t>("f32"));
|
||||
|
||||
EXPECT_EQ(*attrs.get<int32_t>("i32"), 42);
|
||||
EXPECT_EQ(*attrs.get<float>("f32"), 42.0f);
|
||||
|
||||
return Error::Success();
|
||||
});
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user