[tsl:concurrency] Add APIs to create futures by running callbacks on executor

PiperOrigin-RevId: 813953773
This commit is contained in:
Eugene Zhulenev 2025-10-01 16:04:19 -07:00 committed by TensorFlower Gardener
parent c1a1cee310
commit 14ad872740
3 changed files with 83 additions and 0 deletions

View File

@ -144,6 +144,7 @@ cc_library(
compatible_with = get_compatible_with_portable(), compatible_with = get_compatible_with_portable(),
deps = [ deps = [
":async_value", ":async_value",
":executor",
":ref_count", ":ref_count",
"//xla/tsl/platform:logging", "//xla/tsl/platform:logging",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
@ -162,6 +163,7 @@ tsl_cc_test(
name = "future_test", name = "future_test",
srcs = ["future_test.cc"], srcs = ["future_test.cc"],
deps = [ deps = [
":executor",
":future", ":future",
"//xla/tsl/platform:test", "//xla/tsl/platform:test",
"//xla/tsl/platform:test_benchmark", "//xla/tsl/platform:test_benchmark",

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "absl/utility/utility.h" #include "absl/utility/utility.h"
#include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/async_value.h"
#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/concurrency/executor.h"
#include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/logging.h"
namespace tsl { namespace tsl {
@ -478,6 +479,18 @@ class Future : public internal::FutureBase<absl::StatusOr<T>> {
return std::make_pair(std::move(promise), std::move(future)); return std::make_pair(std::move(promise), std::move(future));
} }
// Returns a future that is constructed from the result of invoking functor
// `f` on the given `executor`.
template <typename F, typename R = std::invoke_result_t<F>,
std::enable_if_t<std::is_constructible_v<absl::StatusOr<T>, R>>* =
nullptr>
static Future<T> MakeOn(Executor& executor, F&& f) {
auto [promise, future] = MakePromise();
executor.Execute([promise = std::move(promise),
f = std::forward<F>(f)]() mutable { promise.Set(f()); });
return std::move(future);
}
using Base::Await; using Base::Await;
using Base::GetReadyFuture; using Base::GetReadyFuture;
using Base::OnReady; using Base::OnReady;
@ -768,6 +781,17 @@ class Future<void> : public internal::FutureBase<absl::Status> {
return std::make_pair(std::move(promise), std::move(future)); return std::make_pair(std::move(promise), std::move(future));
} }
// Returns a future that is constructed from the result of invoking functor
// `f` on the given `executor`.
template <typename F, typename R = std::invoke_result_t<F>,
std::enable_if_t<std::is_same_v<R, absl::Status>>* = nullptr>
static Future<> MakeOn(Executor& executor, F&& f) {
auto [promise, future] = MakePromise();
executor.Execute([promise = std::move(promise),
f = std::forward<F>(f)]() mutable { promise.Set(f()); });
return std::move(future);
}
using Base::Await; using Base::Await;
using Base::BlockUntilReady; using Base::BlockUntilReady;
using Base::OnReady; using Base::OnReady;

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/status_matchers.h" #include "absl/status/status_matchers.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "xla/tsl/concurrency/executor.h"
#include "xla/tsl/platform/test.h" #include "xla/tsl/platform/test.h"
#include "xla/tsl/platform/test_benchmark.h" #include "xla/tsl/platform/test_benchmark.h"
@ -674,6 +675,62 @@ TEST(FutureTest, MakeSharedPromise) {
} }
} }
struct InlineExecutor : public Executor {
void Execute(Task task) final { std::move(task)(); }
};
TEST(FutureTest, MakeOnStateless) {
InlineExecutor e;
{
auto future = Future<>::MakeOn(e, [] { return absl::OkStatus(); });
EXPECT_TRUE(future.IsReady());
EXPECT_EQ(future.Await(), absl::OkStatus());
}
{
auto future =
Future<>::MakeOn(e, [] { return absl::InternalError("test"); });
EXPECT_TRUE(future.IsReady());
EXPECT_EQ(future.Await(), absl::InternalError("test"));
}
}
TEST(FutureTest, MakeOnStateful) {
InlineExecutor executor;
struct Foo {
Foo(int32_t value) : value(value) {} // NOLINT
int32_t value;
};
{
auto future = Future<int32_t>::MakeOn(executor, [] { return 42; });
EXPECT_TRUE(future.IsReady());
EXPECT_EQ(*future.Await(), 42);
}
{
auto future = Future<Foo>::MakeOn(executor, [] { return 42; });
EXPECT_TRUE(future.IsReady());
EXPECT_EQ(future.Await()->value, 42);
}
{
auto future = Future<std::unique_ptr<int32_t>>::MakeOn(
executor, [] { return std::make_unique<int32_t>(42); });
EXPECT_TRUE(future.IsReady());
EXPECT_EQ(**future.Await(), 42);
}
{
auto future = Future<int32_t>::MakeOn(
executor, [] { return absl::InternalError("test"); });
EXPECT_TRUE(future.IsReady());
EXPECT_EQ(future.Await().status(), absl::InternalError("test"));
}
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Performance benchmarks. // Performance benchmarks.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//