[xla:ffi] Remove unused context decoding for C API internals

PiperOrigin-RevId: 824792000
This commit is contained in:
Eugene Zhulenev 2025-10-27 19:44:03 -07:00 committed by TensorFlower Gardener
parent b6f66e3e01
commit c09d68c588
4 changed files with 6 additions and 73 deletions

View File

@ -1499,35 +1499,6 @@ inline ThreadPool::ThreadPool(const XLA_FFI_Api* api,
DiagnosticEngine& diagnostic)
: api_(api), ctx_(ctx), diagnostic_(diagnostic) {}
//===----------------------------------------------------------------------===//
// Context decoding for FFI internals
//===----------------------------------------------------------------------===//
struct FfiApi {};
struct FfiExecutionContext {};
template <>
struct CtxDecoding<FfiApi> {
using Type = const XLA_FFI_Api*;
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE static std::optional<Type> Decode(
const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx,
DiagnosticEngine& diagnostic) {
return api;
}
};
template <>
struct CtxDecoding<FfiExecutionContext> {
using Type = XLA_FFI_ExecutionContext*;
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE static std::optional<Type> Decode(
const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx,
DiagnosticEngine& diagnostic) {
return ctx;
}
};
//===----------------------------------------------------------------------===//
// Type Registration
//===----------------------------------------------------------------------===//

View File

@ -1427,13 +1427,6 @@ TEST(FfiTest, ScratchAllocatorUnimplemented) {
TF_ASSERT_OK(status);
}
TEST(FfiTest, BindFfiInternals) {
(void)Ffi::Bind().Ctx<FfiApi>().Ctx<FfiExecutionContext>().To(
+[](const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx) {
return Error::Success();
});
}
TEST(FfiTest, ThreadPool) {
tsl::thread::ThreadPool pool(tsl::Env::Default(), "ffi-test", 2);
Eigen::ThreadPoolDevice device(pool.AsEigenThreadPool(), pool.NumThreads());

View File

@ -69,8 +69,6 @@ struct Allocator {}; // binds `se::DeviceMemoryAllocator*`
struct ScratchAllocator {}; // binds `se::OwningScratchAllocator`
struct CalledComputation {}; // binds `HloComputation*`
struct IntraOpThreadPool {}; // binds `const Eigen::ThreadPoolDevice*`
struct FfiApi {}; // binds `const XLA_FFI_Api*`
struct FfiExecutionContext {}; // binds `XLA_FFI_ExecutionContext*`
template <typename T>
struct PlatformStream {}; // binds a platform stream, e.g. `cudaStream_t`
@ -622,28 +620,6 @@ struct CtxDecoding<IntraOpThreadPool> {
}
};
template <>
struct CtxDecoding<FfiApi> {
using Type = const XLA_FFI_Api*;
static std::optional<Type> Decode(const XLA_FFI_Api* api,
XLA_FFI_ExecutionContext* ctx,
DiagnosticEngine&) {
return api;
}
};
template <>
struct CtxDecoding<FfiExecutionContext> {
using Type = XLA_FFI_ExecutionContext*;
static std::optional<Type> Decode(const XLA_FFI_Api* api,
XLA_FFI_ExecutionContext* ctx,
DiagnosticEngine&) {
return ctx;
}
};
template <typename T>
struct CtxDecoding<PlatformStream<T>> {
using Type = T;

View File

@ -1142,13 +1142,6 @@ TEST(FfiTest, PlatformStream) {
(void)Ffi::BindTo(+[](TestStream stream) { return absl::OkStatus(); });
}
TEST(FfiTest, BindFfiInternals) {
(void)Ffi::Bind().Ctx<FfiApi>().Ctx<FfiExecutionContext>().To(
+[](const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx) {
return absl::OkStatus();
});
}
//===----------------------------------------------------------------------===//
// Performance benchmarks are below.
//===----------------------------------------------------------------------===//