mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Pipe incarnations to jax.live_devices.
PiperOrigin-RevId: 822250955
This commit is contained in:
parent
47cd01d4a5
commit
5776d2771c
4
third_party/xla/xla/pjrt/distributed/BUILD
vendored
4
third_party/xla/xla/pjrt/distributed/BUILD
vendored
|
|
@ -71,6 +71,7 @@ cc_library(
|
|||
deps = [
|
||||
":key_value_store_interface",
|
||||
":util",
|
||||
"//xla/service:global_device_id",
|
||||
"//xla/tsl/distributed_runtime/coordination:coordination_client",
|
||||
"//xla/tsl/distributed_runtime/coordination:coordination_service_agent",
|
||||
"//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_client",
|
||||
|
|
@ -78,6 +79,7 @@ cc_library(
|
|||
"//xla/tsl/platform:statusor",
|
||||
"//xla/tsl/protobuf:coordination_config_proto_cc",
|
||||
"//xla/tsl/protobuf:coordination_service_proto_cc",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
|
|
@ -147,10 +149,12 @@ xla_cc_test(
|
|||
":service",
|
||||
":topology_util",
|
||||
"//xla:status_macros",
|
||||
"//xla/service:global_device_id",
|
||||
"//xla/tsl/distributed_runtime/coordination:coordination_service_agent",
|
||||
"//xla/tsl/lib/core:status_test_util",
|
||||
"//xla/tsl/platform:env",
|
||||
"//xla/tsl/util/proto:proto_matchers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
|
|
|||
34
third_party/xla/xla/pjrt/distributed/client.cc
vendored
34
third_party/xla/xla/pjrt/distributed/client.cc
vendored
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
|
|
@ -31,6 +32,7 @@ limitations under the License.
|
|||
#include "absl/types/span.h"
|
||||
#include "grpcpp/channel.h"
|
||||
#include "xla/pjrt/distributed/key_value_store_interface.h"
|
||||
#include "xla/service/global_device_id.h"
|
||||
#include "xla/tsl/distributed_runtime/coordination/coordination_client.h"
|
||||
#include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h"
|
||||
#include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.h"
|
||||
|
|
@ -67,6 +69,8 @@ class DistributedRuntimeCoordinationServiceClient
|
|||
absl::Status WaitAtBarrier(
|
||||
std::string barrier_id, absl::Duration timeout,
|
||||
std::optional<absl::Span<const int32_t>> process_ids) override;
|
||||
absl::StatusOr<absl::flat_hash_map<int32_t, IncarnationId>>
|
||||
GetLiveNodesWithIncarnations(absl::Span<const int32_t> nodes) override;
|
||||
absl::StatusOr<std::vector<int32_t>> GetLiveNodes(
|
||||
absl::Span<const int32_t> nodes) override;
|
||||
absl::StatusOr<tsl::CoordinationServiceAgent*> GetCoordinationServiceAgent()
|
||||
|
|
@ -208,8 +212,8 @@ absl::Status DistributedRuntimeCoordinationServiceClient::WaitAtBarrier(
|
|||
return coord_agent_->WaitAtBarrier(barrier_id, timeout, tasks);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::vector<int32_t>>
|
||||
DistributedRuntimeCoordinationServiceClient::GetLiveNodes(
|
||||
absl::StatusOr<absl::flat_hash_map<int32_t, IncarnationId>>
|
||||
DistributedRuntimeCoordinationServiceClient::GetLiveNodesWithIncarnations(
|
||||
absl::Span<const int32_t> nodes) {
|
||||
// Note that jax.distributed uses terms "process" and "node", and the
|
||||
// coordination service uses the term "task". These all refer to the same
|
||||
|
|
@ -227,13 +231,29 @@ DistributedRuntimeCoordinationServiceClient::GetLiveNodes(
|
|||
}
|
||||
|
||||
// Get the set of live tasks.
|
||||
TF_ASSIGN_OR_RETURN(const std::vector<tensorflow::CoordinatedTask> live_tasks,
|
||||
coord_agent_->GetAliveTasks(tasks));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
const std::vector<tsl::CoordinationServiceAgent::AliveTask> live_tasks,
|
||||
coord_agent_->GetAliveTasks(tasks));
|
||||
|
||||
// Extract the node ids from the live tasks.
|
||||
std::vector<int32_t> live_nodes(live_tasks.size());
|
||||
for (int i = 0; i < live_tasks.size(); ++i) {
|
||||
live_nodes[i] = live_tasks[i].task_id();
|
||||
absl::flat_hash_map<int32_t, IncarnationId> live_nodes;
|
||||
for (const tsl::CoordinationServiceAgent::AliveTask& task : live_tasks) {
|
||||
live_nodes[task.task_id] = task.incarnation_id;
|
||||
}
|
||||
return live_nodes;
|
||||
}
|
||||
|
||||
// Returns the ids of the live nodes among `nodes`.
//
// This is a thin wrapper around GetLiveNodesWithIncarnations that drops the
// incarnation ids and keeps only the node ids. Any error returned by
// GetLiveNodesWithIncarnations is propagated unchanged.
//
// The ids are collected by iterating over an absl::flat_hash_map, so the order
// of the returned vector is unspecified; callers must not rely on it.
absl::StatusOr<std::vector<int32_t>>
DistributedRuntimeCoordinationServiceClient::GetLiveNodes(
    absl::Span<const int32_t> nodes) {
  absl::StatusOr<absl::flat_hash_map<int32_t, IncarnationId>>
      live_nodes_with_incarnations = GetLiveNodesWithIncarnations(nodes);
  if (!live_nodes_with_incarnations.ok()) {
    return live_nodes_with_incarnations.status();
  }

  std::vector<int32_t> live_nodes;
  // The final size is known up front, so allocate once instead of growing the
  // vector incrementally inside the loop.
  live_nodes.reserve(live_nodes_with_incarnations->size());
  for (const auto& [task_id, unused] : *live_nodes_with_incarnations) {
    live_nodes.push_back(task_id);
  }
  return live_nodes;
}
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
|
|
@ -32,6 +33,7 @@ limitations under the License.
|
|||
#include "absl/types/span.h"
|
||||
#include "grpcpp/channel.h"
|
||||
#include "xla/pjrt/distributed/key_value_store_interface.h"
|
||||
#include "xla/service/global_device_id.h"
|
||||
#include "xla/tsl/platform/env.h"
|
||||
|
||||
namespace tsl {
|
||||
|
|
@ -149,8 +151,15 @@ class DistributedRuntimeClient {
|
|||
std::string barrier_id, absl::Duration timeout,
|
||||
std::optional<absl::Span<const int32_t>> nodes) = 0;
|
||||
|
||||
// Returns the subset of live nodes, along with their incarnations. See
|
||||
// CoordinationService.GetAliveTasks for detailed semantics.
|
||||
virtual absl::StatusOr<absl::flat_hash_map<int32_t, IncarnationId>>
|
||||
GetLiveNodesWithIncarnations(absl::Span<const int32_t> nodes) = 0;
|
||||
|
||||
// Returns the subset of live nodes. See CoordinationService.GetAliveTasks for
|
||||
// detailed semantics.
|
||||
//
|
||||
// TODO: mwhittaker - Remove this function.
|
||||
virtual absl::StatusOr<std::vector<int32_t>> GetLiveNodes(
|
||||
absl::Span<const int32_t> nodes) = 0;
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||
#include <vector>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
|
|
@ -42,6 +43,7 @@ limitations under the License.
|
|||
#include "xla/pjrt/distributed/protocol.pb.h"
|
||||
#include "xla/pjrt/distributed/service.h"
|
||||
#include "xla/pjrt/distributed/topology_util.h"
|
||||
#include "xla/service/global_device_id.h"
|
||||
#include "xla/status_macros.h"
|
||||
#include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h"
|
||||
#include "xla/tsl/lib/core/status_test_util.h"
|
||||
|
|
@ -57,6 +59,7 @@ namespace xla {
|
|||
namespace {
|
||||
|
||||
using ::testing::IsEmpty;
|
||||
using ::testing::Key;
|
||||
using ::testing::Matches;
|
||||
using ::testing::Pair;
|
||||
using ::testing::UnorderedElementsAre;
|
||||
|
|
@ -1001,10 +1004,10 @@ TEST_F(ClientServerTest, GetLiveTasksSucceeds) {
|
|||
TF_ASSERT_OK(client->Connect());
|
||||
|
||||
// Get the set of live nodes. All three nodes should be live.
|
||||
absl::StatusOr<std::vector<int32_t>> live_nodes =
|
||||
client->GetLiveNodes(std::vector<int>{0, 1, 2});
|
||||
absl::StatusOr<absl::flat_hash_map<int32_t, IncarnationId>> live_nodes =
|
||||
client->GetLiveNodesWithIncarnations(std::vector<int>{0, 1, 2});
|
||||
TF_ASSERT_OK(live_nodes.status());
|
||||
EXPECT_THAT(*live_nodes, UnorderedElementsAre(0, 1, 2));
|
||||
EXPECT_THAT(*live_nodes, UnorderedElementsAre(Key(0), Key(1), Key(2)));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -1023,7 +1026,7 @@ TEST_F(ClientServerTest, GetLiveTasksWithoutBeingAMember) {
|
|||
// Get the set of live nodes but don't include ourselves.
|
||||
std::vector<int> nodes{0, 1, 2};
|
||||
nodes.erase(nodes.begin() + i);
|
||||
EXPECT_THAT(client->GetLiveNodes(nodes),
|
||||
EXPECT_THAT(client->GetLiveNodesWithIncarnations(nodes),
|
||||
absl_testing::StatusIs(absl::StatusCode::kInvalidArgument));
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1062,8 +1062,8 @@ TEST_F(ClientServerTest, GetAliveTasks_Succeed) {
|
|||
auto thread_fn = [&](int node_id) -> absl::Status {
|
||||
auto client = GetClient(node_id);
|
||||
TF_RETURN_IF_ERROR(client->Connect());
|
||||
absl::StatusOr<std::vector<tensorflow::CoordinatedTask>> alive_tasks =
|
||||
client->GetAliveTasks({GetTask(0), GetTask(1)});
|
||||
absl::StatusOr<std::vector<CoordinationServiceAgent::AliveTask>>
|
||||
alive_tasks = client->GetAliveTasks({GetTask(0), GetTask(1)});
|
||||
if (!alive_tasks.ok()) {
|
||||
return alive_tasks.status();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1566,7 +1566,6 @@ void CoordinationService::RefreshAliveness() {
|
|||
// the same set of alive tasks (alive_tasks) to every task in the barrier.
|
||||
std::vector<CoordinatedTask> v{alive_tasks.begin(), alive_tasks.end()};
|
||||
std::vector<IncarnationId> incarnation_ids = IncarnationIds(v);
|
||||
absl::c_sort(incarnation_ids);
|
||||
for (const GetAliveTasksCallback& done : it->dones) {
|
||||
done(absl::OkStatus(), v, incarnation_ids);
|
||||
}
|
||||
|
|
@ -1618,7 +1617,6 @@ void CoordinationService::GetAliveTasksAsync(
|
|||
if (TaskSetSubset(alive_tasks, it->in_barrier)) {
|
||||
std::vector<CoordinatedTask> v{alive_tasks.begin(), alive_tasks.end()};
|
||||
std::vector<IncarnationId> incarnation_ids = IncarnationIds(v);
|
||||
absl::c_sort(incarnation_ids);
|
||||
for (const GetAliveTasksCallback& done : it->dones) {
|
||||
done(absl::OkStatus(), v, incarnation_ids);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1012,7 +1012,7 @@ void CoordinationServiceAgent::CancelBarrierAsync(absl::string_view barrier_id,
|
|||
});
|
||||
}
|
||||
|
||||
absl::StatusOr<std::vector<tensorflow::CoordinatedTask>>
|
||||
absl::StatusOr<std::vector<CoordinationServiceAgent::AliveTask>>
|
||||
CoordinationServiceAgent::GetAliveTasks(
|
||||
const std::vector<CoordinatedTask>& tasks) {
|
||||
// Validate the agent.
|
||||
|
|
@ -1036,20 +1036,21 @@ CoordinationServiceAgent::GetAliveTasks(
|
|||
};
|
||||
leader_client_->GetAliveTasksAsync(request.get(), response.get(), done);
|
||||
n.WaitForNotification();
|
||||
|
||||
// Parse the response.
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
{
|
||||
absl::MutexLock lock(incarnations_mu_);
|
||||
for (int i = 0; i < response->alive_tasks_size(); ++i) {
|
||||
incarnations_[response->alive_tasks(i).task_id()] =
|
||||
response->incarnations(i);
|
||||
}
|
||||
|
||||
// Parse the response.
|
||||
absl::MutexLock lock(incarnations_mu_);
|
||||
std::vector<AliveTask> alive_tasks;
|
||||
for (int i = 0; i < response->alive_tasks_size(); ++i) {
|
||||
int task_id = response->alive_tasks(i).task_id();
|
||||
IncarnationId incarnation_id(response->incarnations(i));
|
||||
|
||||
alive_tasks.push_back(AliveTask{task_id, incarnation_id});
|
||||
incarnations_[task_id] = incarnation_id;
|
||||
}
|
||||
return std::vector<tensorflow::CoordinatedTask>(
|
||||
response->alive_tasks().begin(), response->alive_tasks().end());
|
||||
return alive_tasks;
|
||||
}
|
||||
|
||||
// Returns an error if agent is not running.
|
||||
|
|
|
|||
|
|
@ -320,7 +320,11 @@ class CoordinationServiceAgent {
|
|||
// has failed and that every task calls GetAliveTasks([A, B, C, D]). The
|
||||
// invocation will return tasks [A, B, C]. The GetAliveTasks call acts as a
|
||||
// barrier across tasks A, B, and C. Task D, which failed, is ignored.
|
||||
absl::StatusOr<std::vector<tensorflow::CoordinatedTask>> GetAliveTasks(
|
||||
// Element type of the vector returned by GetAliveTasks: one alive task,
// identified by its task id, together with the incarnation id reported for it.
struct AliveTask {
|
||||
int task_id;  // Id of the alive task.
|
||||
IncarnationId incarnation_id;  // Incarnation id reported for `task_id`.
|
||||
};
|
||||
absl::StatusOr<std::vector<AliveTask>> GetAliveTasks(
|
||||
const std::vector<tensorflow::CoordinatedTask>& tasks);
|
||||
|
||||
// Returns the latest known set of incarnation ids for every task. Incarnation
|
||||
|
|
|
|||
|
|
@ -2562,9 +2562,9 @@ TEST_F(GetAliveTasksTest, SuccessfulGetAliveTasks) {
|
|||
const std::vector<IncarnationId>& incarnations) {
|
||||
EXPECT_OK(status);
|
||||
EXPECT_THAT(alive_tasks, UnorderedElementsAreArray(GetTaskMatchers()));
|
||||
EXPECT_EQ(incarnations,
|
||||
(std::vector<IncarnationId>{IncarnationId(0), IncarnationId(1),
|
||||
IncarnationId(2)}));
|
||||
EXPECT_THAT(incarnations,
|
||||
UnorderedElementsAre(IncarnationId(0), IncarnationId(1),
|
||||
IncarnationId(2)));
|
||||
finished.DecrementCount();
|
||||
};
|
||||
GetCoordinationService()->GetAliveTasksAsync(GetTask(0), GetTasks(), done);
|
||||
|
|
@ -2583,8 +2583,8 @@ TEST_F(GetAliveTasksTest, FailedTaskBeforeCallingGetAliveTasks) {
|
|||
EXPECT_OK(status);
|
||||
EXPECT_THAT(alive_tasks, UnorderedElementsAre(EqualsProto(GetTask(0)),
|
||||
EqualsProto(GetTask(1))));
|
||||
EXPECT_EQ(incarnations,
|
||||
(std::vector<IncarnationId>{IncarnationId(0), IncarnationId(1)}));
|
||||
EXPECT_THAT(incarnations,
|
||||
UnorderedElementsAre(IncarnationId(0), IncarnationId(1)));
|
||||
finished.DecrementCount();
|
||||
};
|
||||
ASSERT_OK(GetCoordinationService()->ReportTaskError(
|
||||
|
|
@ -2605,8 +2605,8 @@ TEST_F(GetAliveTasksTest, FailedTaskAfterCallingGetAliveTasks) {
|
|||
EXPECT_OK(status);
|
||||
EXPECT_THAT(alive_tasks, UnorderedElementsAre(EqualsProto(GetTask(0)),
|
||||
EqualsProto(GetTask(1))));
|
||||
EXPECT_EQ(incarnations,
|
||||
(std::vector<IncarnationId>{IncarnationId(0), IncarnationId(1)}));
|
||||
EXPECT_THAT(incarnations,
|
||||
UnorderedElementsAre(IncarnationId(0), IncarnationId(1)));
|
||||
finished.DecrementCount();
|
||||
};
|
||||
GetCoordinationService()->GetAliveTasksAsync(GetTask(0), GetTasks(), done);
|
||||
|
|
@ -2630,8 +2630,8 @@ TEST_F(GetAliveTasksTest, ConcurrentGetAliveTasks) {
|
|||
EXPECT_OK(status);
|
||||
EXPECT_THAT(alive_tasks, UnorderedElementsAre(EqualsProto(tasks_01[0]),
|
||||
EqualsProto(tasks_01[1])));
|
||||
EXPECT_EQ(incarnations,
|
||||
(std::vector<IncarnationId>{IncarnationId(0), IncarnationId(1)}));
|
||||
EXPECT_THAT(incarnations,
|
||||
UnorderedElementsAre(IncarnationId(0), IncarnationId(1)));
|
||||
finished_01.DecrementCount();
|
||||
};
|
||||
|
||||
|
|
@ -2644,8 +2644,8 @@ TEST_F(GetAliveTasksTest, ConcurrentGetAliveTasks) {
|
|||
EXPECT_OK(status);
|
||||
EXPECT_THAT(alive_tasks, UnorderedElementsAre(EqualsProto(tasks_12[0]),
|
||||
EqualsProto(tasks_12[1])));
|
||||
EXPECT_EQ(incarnations,
|
||||
(std::vector<IncarnationId>{IncarnationId(1), IncarnationId(2)}));
|
||||
EXPECT_THAT(incarnations,
|
||||
UnorderedElementsAre(IncarnationId(1), IncarnationId(2)));
|
||||
finished_12.DecrementCount();
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user