mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
[tsl:concurrency] Add APIs to create futures by running callbacks on executor
PiperOrigin-RevId: 813953773
This commit is contained in:
parent
c1a1cee310
commit
14ad872740
2
third_party/xla/xla/tsl/concurrency/BUILD
vendored
2
third_party/xla/xla/tsl/concurrency/BUILD
vendored
|
|
@ -144,6 +144,7 @@ cc_library(
|
|||
compatible_with = get_compatible_with_portable(),
|
||||
deps = [
|
||||
":async_value",
|
||||
":executor",
|
||||
":ref_count",
|
||||
"//xla/tsl/platform:logging",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
|
|
@ -162,6 +163,7 @@ tsl_cc_test(
|
|||
name = "future_test",
|
||||
srcs = ["future_test.cc"],
|
||||
deps = [
|
||||
":executor",
|
||||
":future",
|
||||
"//xla/tsl/platform:test",
|
||||
"//xla/tsl/platform:test_benchmark",
|
||||
|
|
|
|||
24
third_party/xla/xla/tsl/concurrency/future.h
vendored
24
third_party/xla/xla/tsl/concurrency/future.h
vendored
|
|
@ -31,6 +31,7 @@ limitations under the License.
|
|||
#include "absl/utility/utility.h"
|
||||
#include "xla/tsl/concurrency/async_value.h"
|
||||
#include "xla/tsl/concurrency/async_value_ref.h"
|
||||
#include "xla/tsl/concurrency/executor.h"
|
||||
#include "xla/tsl/platform/logging.h"
|
||||
|
||||
namespace tsl {
|
||||
|
|
@ -478,6 +479,18 @@ class Future : public internal::FutureBase<absl::StatusOr<T>> {
|
|||
return std::make_pair(std::move(promise), std::move(future));
|
||||
}
|
||||
|
||||
// Returns a future that is constructed from the result of invoking functor
|
||||
// `f` on the given `executor`.
|
||||
template <typename F, typename R = std::invoke_result_t<F>,
|
||||
std::enable_if_t<std::is_constructible_v<absl::StatusOr<T>, R>>* =
|
||||
nullptr>
|
||||
static Future<T> MakeOn(Executor& executor, F&& f) {
|
||||
auto [promise, future] = MakePromise();
|
||||
executor.Execute([promise = std::move(promise),
|
||||
f = std::forward<F>(f)]() mutable { promise.Set(f()); });
|
||||
return std::move(future);
|
||||
}
|
||||
|
||||
using Base::Await;
|
||||
using Base::GetReadyFuture;
|
||||
using Base::OnReady;
|
||||
|
|
@ -768,6 +781,17 @@ class Future<void> : public internal::FutureBase<absl::Status> {
|
|||
return std::make_pair(std::move(promise), std::move(future));
|
||||
}
|
||||
|
||||
// Returns a future that is constructed from the result of invoking functor
|
||||
// `f` on the given `executor`.
|
||||
template <typename F, typename R = std::invoke_result_t<F>,
|
||||
std::enable_if_t<std::is_same_v<R, absl::Status>>* = nullptr>
|
||||
static Future<> MakeOn(Executor& executor, F&& f) {
|
||||
auto [promise, future] = MakePromise();
|
||||
executor.Execute([promise = std::move(promise),
|
||||
f = std::forward<F>(f)]() mutable { promise.Set(f()); });
|
||||
return std::move(future);
|
||||
}
|
||||
|
||||
using Base::Await;
|
||||
using Base::BlockUntilReady;
|
||||
using Base::OnReady;
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||
#include "absl/status/status.h"
|
||||
#include "absl/status/status_matchers.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "xla/tsl/concurrency/executor.h"
|
||||
#include "xla/tsl/platform/test.h"
|
||||
#include "xla/tsl/platform/test_benchmark.h"
|
||||
|
||||
|
|
@ -674,6 +675,62 @@ TEST(FutureTest, MakeSharedPromise) {
|
|||
}
|
||||
}
|
||||
|
||||
struct InlineExecutor : public Executor {
|
||||
void Execute(Task task) final { std::move(task)(); }
|
||||
};
|
||||
|
||||
TEST(FutureTest, MakeOnStateless) {
|
||||
InlineExecutor e;
|
||||
|
||||
{
|
||||
auto future = Future<>::MakeOn(e, [] { return absl::OkStatus(); });
|
||||
EXPECT_TRUE(future.IsReady());
|
||||
EXPECT_EQ(future.Await(), absl::OkStatus());
|
||||
}
|
||||
|
||||
{
|
||||
auto future =
|
||||
Future<>::MakeOn(e, [] { return absl::InternalError("test"); });
|
||||
EXPECT_TRUE(future.IsReady());
|
||||
EXPECT_EQ(future.Await(), absl::InternalError("test"));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FutureTest, MakeOnStateful) {
|
||||
InlineExecutor executor;
|
||||
|
||||
struct Foo {
|
||||
Foo(int32_t value) : value(value) {} // NOLINT
|
||||
int32_t value;
|
||||
};
|
||||
|
||||
{
|
||||
auto future = Future<int32_t>::MakeOn(executor, [] { return 42; });
|
||||
EXPECT_TRUE(future.IsReady());
|
||||
EXPECT_EQ(*future.Await(), 42);
|
||||
}
|
||||
|
||||
{
|
||||
auto future = Future<Foo>::MakeOn(executor, [] { return 42; });
|
||||
EXPECT_TRUE(future.IsReady());
|
||||
EXPECT_EQ(future.Await()->value, 42);
|
||||
}
|
||||
|
||||
{
|
||||
auto future = Future<std::unique_ptr<int32_t>>::MakeOn(
|
||||
executor, [] { return std::make_unique<int32_t>(42); });
|
||||
EXPECT_TRUE(future.IsReady());
|
||||
EXPECT_EQ(**future.Await(), 42);
|
||||
}
|
||||
|
||||
{
|
||||
auto future = Future<int32_t>::MakeOn(
|
||||
executor, [] { return absl::InternalError("test"); });
|
||||
EXPECT_TRUE(future.IsReady());
|
||||
EXPECT_EQ(future.Await().status(), absl::InternalError("test"));
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Performance benchmarks.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user