mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[tsl:concurrency] Add APIs to create futures by running callbacks on executor
PiperOrigin-RevId: 813953773
This commit is contained in:
parent
c1a1cee310
commit
14ad872740
2
third_party/xla/xla/tsl/concurrency/BUILD
vendored
2
third_party/xla/xla/tsl/concurrency/BUILD
vendored
|
|
@ -144,6 +144,7 @@ cc_library(
|
||||||
compatible_with = get_compatible_with_portable(),
|
compatible_with = get_compatible_with_portable(),
|
||||||
deps = [
|
deps = [
|
||||||
":async_value",
|
":async_value",
|
||||||
|
":executor",
|
||||||
":ref_count",
|
":ref_count",
|
||||||
"//xla/tsl/platform:logging",
|
"//xla/tsl/platform:logging",
|
||||||
"@com_google_absl//absl/base:core_headers",
|
"@com_google_absl//absl/base:core_headers",
|
||||||
|
|
@ -162,6 +163,7 @@ tsl_cc_test(
|
||||||
name = "future_test",
|
name = "future_test",
|
||||||
srcs = ["future_test.cc"],
|
srcs = ["future_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":executor",
|
||||||
":future",
|
":future",
|
||||||
"//xla/tsl/platform:test",
|
"//xla/tsl/platform:test",
|
||||||
"//xla/tsl/platform:test_benchmark",
|
"//xla/tsl/platform:test_benchmark",
|
||||||
|
|
|
||||||
24
third_party/xla/xla/tsl/concurrency/future.h
vendored
24
third_party/xla/xla/tsl/concurrency/future.h
vendored
|
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||||
#include "absl/utility/utility.h"
|
#include "absl/utility/utility.h"
|
||||||
#include "xla/tsl/concurrency/async_value.h"
|
#include "xla/tsl/concurrency/async_value.h"
|
||||||
#include "xla/tsl/concurrency/async_value_ref.h"
|
#include "xla/tsl/concurrency/async_value_ref.h"
|
||||||
|
#include "xla/tsl/concurrency/executor.h"
|
||||||
#include "xla/tsl/platform/logging.h"
|
#include "xla/tsl/platform/logging.h"
|
||||||
|
|
||||||
namespace tsl {
|
namespace tsl {
|
||||||
|
|
@ -478,6 +479,18 @@ class Future : public internal::FutureBase<absl::StatusOr<T>> {
|
||||||
return std::make_pair(std::move(promise), std::move(future));
|
return std::make_pair(std::move(promise), std::move(future));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns a future that is constructed from the result of invoking functor
|
||||||
|
// `f` on the given `executor`.
|
||||||
|
template <typename F, typename R = std::invoke_result_t<F>,
|
||||||
|
std::enable_if_t<std::is_constructible_v<absl::StatusOr<T>, R>>* =
|
||||||
|
nullptr>
|
||||||
|
static Future<T> MakeOn(Executor& executor, F&& f) {
|
||||||
|
auto [promise, future] = MakePromise();
|
||||||
|
executor.Execute([promise = std::move(promise),
|
||||||
|
f = std::forward<F>(f)]() mutable { promise.Set(f()); });
|
||||||
|
return std::move(future);
|
||||||
|
}
|
||||||
|
|
||||||
using Base::Await;
|
using Base::Await;
|
||||||
using Base::GetReadyFuture;
|
using Base::GetReadyFuture;
|
||||||
using Base::OnReady;
|
using Base::OnReady;
|
||||||
|
|
@ -768,6 +781,17 @@ class Future<void> : public internal::FutureBase<absl::Status> {
|
||||||
return std::make_pair(std::move(promise), std::move(future));
|
return std::make_pair(std::move(promise), std::move(future));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns a future that is constructed from the result of invoking functor
|
||||||
|
// `f` on the given `executor`.
|
||||||
|
template <typename F, typename R = std::invoke_result_t<F>,
|
||||||
|
std::enable_if_t<std::is_same_v<R, absl::Status>>* = nullptr>
|
||||||
|
static Future<> MakeOn(Executor& executor, F&& f) {
|
||||||
|
auto [promise, future] = MakePromise();
|
||||||
|
executor.Execute([promise = std::move(promise),
|
||||||
|
f = std::forward<F>(f)]() mutable { promise.Set(f()); });
|
||||||
|
return std::move(future);
|
||||||
|
}
|
||||||
|
|
||||||
using Base::Await;
|
using Base::Await;
|
||||||
using Base::BlockUntilReady;
|
using Base::BlockUntilReady;
|
||||||
using Base::OnReady;
|
using Base::OnReady;
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "absl/status/status_matchers.h"
|
#include "absl/status/status_matchers.h"
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
|
#include "xla/tsl/concurrency/executor.h"
|
||||||
#include "xla/tsl/platform/test.h"
|
#include "xla/tsl/platform/test.h"
|
||||||
#include "xla/tsl/platform/test_benchmark.h"
|
#include "xla/tsl/platform/test_benchmark.h"
|
||||||
|
|
||||||
|
|
@ -674,6 +675,62 @@ TEST(FutureTest, MakeSharedPromise) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct InlineExecutor : public Executor {
|
||||||
|
void Execute(Task task) final { std::move(task)(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(FutureTest, MakeOnStateless) {
|
||||||
|
InlineExecutor e;
|
||||||
|
|
||||||
|
{
|
||||||
|
auto future = Future<>::MakeOn(e, [] { return absl::OkStatus(); });
|
||||||
|
EXPECT_TRUE(future.IsReady());
|
||||||
|
EXPECT_EQ(future.Await(), absl::OkStatus());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto future =
|
||||||
|
Future<>::MakeOn(e, [] { return absl::InternalError("test"); });
|
||||||
|
EXPECT_TRUE(future.IsReady());
|
||||||
|
EXPECT_EQ(future.Await(), absl::InternalError("test"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FutureTest, MakeOnStateful) {
|
||||||
|
InlineExecutor executor;
|
||||||
|
|
||||||
|
struct Foo {
|
||||||
|
Foo(int32_t value) : value(value) {} // NOLINT
|
||||||
|
int32_t value;
|
||||||
|
};
|
||||||
|
|
||||||
|
{
|
||||||
|
auto future = Future<int32_t>::MakeOn(executor, [] { return 42; });
|
||||||
|
EXPECT_TRUE(future.IsReady());
|
||||||
|
EXPECT_EQ(*future.Await(), 42);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto future = Future<Foo>::MakeOn(executor, [] { return 42; });
|
||||||
|
EXPECT_TRUE(future.IsReady());
|
||||||
|
EXPECT_EQ(future.Await()->value, 42);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto future = Future<std::unique_ptr<int32_t>>::MakeOn(
|
||||||
|
executor, [] { return std::make_unique<int32_t>(42); });
|
||||||
|
EXPECT_TRUE(future.IsReady());
|
||||||
|
EXPECT_EQ(**future.Await(), 42);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto future = Future<int32_t>::MakeOn(
|
||||||
|
executor, [] { return absl::InternalError("test"); });
|
||||||
|
EXPECT_TRUE(future.IsReady());
|
||||||
|
EXPECT_EQ(future.Await().status(), absl::InternalError("test"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Performance benchmarks.
|
// Performance benchmarks.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user