Pipe incarnations to jax.live_devices.

PiperOrigin-RevId: 822250955
This commit is contained in:
Michael Whittaker 2025-10-21 13:21:41 -07:00 committed by TensorFlower Gardener
parent 47cd01d4a5
commit 5776d2771c
9 changed files with 77 additions and 38 deletions

View File

@ -71,6 +71,7 @@ cc_library(
deps = [
":key_value_store_interface",
":util",
"//xla/service:global_device_id",
"//xla/tsl/distributed_runtime/coordination:coordination_client",
"//xla/tsl/distributed_runtime/coordination:coordination_service_agent",
"//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_client",
@ -78,6 +79,7 @@ cc_library(
"//xla/tsl/platform:statusor",
"//xla/tsl/protobuf:coordination_config_proto_cc",
"//xla/tsl/protobuf:coordination_service_proto_cc",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
@ -147,10 +149,12 @@ xla_cc_test(
":service",
":topology_util",
"//xla:status_macros",
"//xla/service:global_device_id",
"//xla/tsl/distributed_runtime/coordination:coordination_service_agent",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:env",
"//xla/tsl/util/proto:proto_matchers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
@ -31,6 +32,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "grpcpp/channel.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/service/global_device_id.h"
#include "xla/tsl/distributed_runtime/coordination/coordination_client.h"
#include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h"
#include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.h"
@ -67,6 +69,8 @@ class DistributedRuntimeCoordinationServiceClient
absl::Status WaitAtBarrier(
std::string barrier_id, absl::Duration timeout,
std::optional<absl::Span<const int32_t>> process_ids) override;
absl::StatusOr<absl::flat_hash_map<int32_t, IncarnationId>>
GetLiveNodesWithIncarnations(absl::Span<const int32_t> nodes) override;
absl::StatusOr<std::vector<int32_t>> GetLiveNodes(
absl::Span<const int32_t> nodes) override;
absl::StatusOr<tsl::CoordinationServiceAgent*> GetCoordinationServiceAgent()
@ -208,8 +212,8 @@ absl::Status DistributedRuntimeCoordinationServiceClient::WaitAtBarrier(
return coord_agent_->WaitAtBarrier(barrier_id, timeout, tasks);
}
absl::StatusOr<std::vector<int32_t>>
DistributedRuntimeCoordinationServiceClient::GetLiveNodes(
absl::StatusOr<absl::flat_hash_map<int32_t, IncarnationId>>
DistributedRuntimeCoordinationServiceClient::GetLiveNodesWithIncarnations(
absl::Span<const int32_t> nodes) {
// Note that jax.distributed uses terms "process" and "node", and the
// coordination service uses the term "task". These all refer to the same
@ -227,13 +231,29 @@ DistributedRuntimeCoordinationServiceClient::GetLiveNodes(
}
// Get the set of live tasks.
TF_ASSIGN_OR_RETURN(const std::vector<tensorflow::CoordinatedTask> live_tasks,
coord_agent_->GetAliveTasks(tasks));
TF_ASSIGN_OR_RETURN(
const std::vector<tsl::CoordinationServiceAgent::AliveTask> live_tasks,
coord_agent_->GetAliveTasks(tasks));
// Extract the node ids from the live tasks.
std::vector<int32_t> live_nodes(live_tasks.size());
for (int i = 0; i < live_tasks.size(); ++i) {
live_nodes[i] = live_tasks[i].task_id();
absl::flat_hash_map<int32_t, IncarnationId> live_nodes;
for (const tsl::CoordinationServiceAgent::AliveTask& task : live_tasks) {
live_nodes[task.task_id] = task.incarnation_id;
}
return live_nodes;
}
// Returns the ids of the live nodes among `nodes`, dropping the incarnation
// ids that GetLiveNodesWithIncarnations reports alongside them. Any error from
// GetLiveNodesWithIncarnations is returned unchanged.
//
// The ids are gathered by iterating an absl::flat_hash_map, so callers must
// not rely on the order of the returned vector.
absl::StatusOr<std::vector<int32_t>>
DistributedRuntimeCoordinationServiceClient::GetLiveNodes(
    absl::Span<const int32_t> nodes) {
  absl::StatusOr<absl::flat_hash_map<int32_t, IncarnationId>>
      live_nodes_with_incarnations = GetLiveNodesWithIncarnations(nodes);
  if (!live_nodes_with_incarnations.ok()) {
    return live_nodes_with_incarnations.status();
  }

  std::vector<int32_t> live_nodes;
  // The final size is known up front; reserve once instead of regrowing the
  // vector while the keys are copied over.
  live_nodes.reserve(live_nodes_with_incarnations->size());
  for (const auto& [task_id, unused] : *live_nodes_with_incarnations) {
    live_nodes.push_back(task_id);
  }
  return live_nodes;
}

View File

@ -24,6 +24,7 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
@ -32,6 +33,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "grpcpp/channel.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/service/global_device_id.h"
#include "xla/tsl/platform/env.h"
namespace tsl {
@ -149,8 +151,15 @@ class DistributedRuntimeClient {
std::string barrier_id, absl::Duration timeout,
std::optional<absl::Span<const int32_t>> nodes) = 0;
// Returns the subset of live nodes, along with their incarnations, as a map
// from node id to that node's incarnation id. See
// CoordinationService.GetAliveTasks for detailed semantics.
virtual absl::StatusOr<absl::flat_hash_map<int32_t, IncarnationId>>
GetLiveNodesWithIncarnations(absl::Span<const int32_t> nodes) = 0;
// Returns the subset of live nodes. See CoordinationService.GetAliveTasks for
// detailed semantics.
//
// Unlike GetLiveNodesWithIncarnations, this returns only the node ids, without
// their incarnation ids.
//
// TODO(mwhittaker): Remove this function.
virtual absl::StatusOr<std::vector<int32_t>> GetLiveNodes(
absl::Span<const int32_t> nodes) = 0;

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include <gmock/gmock.h>
#include "absl/container/flat_hash_map.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
@ -42,6 +43,7 @@ limitations under the License.
#include "xla/pjrt/distributed/protocol.pb.h"
#include "xla/pjrt/distributed/service.h"
#include "xla/pjrt/distributed/topology_util.h"
#include "xla/service/global_device_id.h"
#include "xla/status_macros.h"
#include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h"
#include "xla/tsl/lib/core/status_test_util.h"
@ -57,6 +59,7 @@ namespace xla {
namespace {
using ::testing::IsEmpty;
using ::testing::Key;
using ::testing::Matches;
using ::testing::Pair;
using ::testing::UnorderedElementsAre;
@ -1001,10 +1004,10 @@ TEST_F(ClientServerTest, GetLiveTasksSucceeds) {
TF_ASSERT_OK(client->Connect());
// Get the set of live nodes. All three nodes should be live.
absl::StatusOr<std::vector<int32_t>> live_nodes =
client->GetLiveNodes(std::vector<int>{0, 1, 2});
absl::StatusOr<absl::flat_hash_map<int32_t, IncarnationId>> live_nodes =
client->GetLiveNodesWithIncarnations(std::vector<int>{0, 1, 2});
TF_ASSERT_OK(live_nodes.status());
EXPECT_THAT(*live_nodes, UnorderedElementsAre(0, 1, 2));
EXPECT_THAT(*live_nodes, UnorderedElementsAre(Key(0), Key(1), Key(2)));
});
}
}
@ -1023,7 +1026,7 @@ TEST_F(ClientServerTest, GetLiveTasksWithoutBeingAMember) {
// Get the set of live nodes but don't include ourselves.
std::vector<int> nodes{0, 1, 2};
nodes.erase(nodes.begin() + i);
EXPECT_THAT(client->GetLiveNodes(nodes),
EXPECT_THAT(client->GetLiveNodesWithIncarnations(nodes),
absl_testing::StatusIs(absl::StatusCode::kInvalidArgument));
});
}

View File

@ -1062,8 +1062,8 @@ TEST_F(ClientServerTest, GetAliveTasks_Succeed) {
auto thread_fn = [&](int node_id) -> absl::Status {
auto client = GetClient(node_id);
TF_RETURN_IF_ERROR(client->Connect());
absl::StatusOr<std::vector<tensorflow::CoordinatedTask>> alive_tasks =
client->GetAliveTasks({GetTask(0), GetTask(1)});
absl::StatusOr<std::vector<CoordinationServiceAgent::AliveTask>>
alive_tasks = client->GetAliveTasks({GetTask(0), GetTask(1)});
if (!alive_tasks.ok()) {
return alive_tasks.status();
}

View File

@ -1566,7 +1566,6 @@ void CoordinationService::RefreshAliveness() {
// the same set of alive tasks (alive_tasks) to every task in the barrier.
std::vector<CoordinatedTask> v{alive_tasks.begin(), alive_tasks.end()};
std::vector<IncarnationId> incarnation_ids = IncarnationIds(v);
absl::c_sort(incarnation_ids);
for (const GetAliveTasksCallback& done : it->dones) {
done(absl::OkStatus(), v, incarnation_ids);
}
@ -1618,7 +1617,6 @@ void CoordinationService::GetAliveTasksAsync(
if (TaskSetSubset(alive_tasks, it->in_barrier)) {
std::vector<CoordinatedTask> v{alive_tasks.begin(), alive_tasks.end()};
std::vector<IncarnationId> incarnation_ids = IncarnationIds(v);
absl::c_sort(incarnation_ids);
for (const GetAliveTasksCallback& done : it->dones) {
done(absl::OkStatus(), v, incarnation_ids);
}

View File

@ -1012,7 +1012,7 @@ void CoordinationServiceAgent::CancelBarrierAsync(absl::string_view barrier_id,
});
}
absl::StatusOr<std::vector<tensorflow::CoordinatedTask>>
absl::StatusOr<std::vector<CoordinationServiceAgent::AliveTask>>
CoordinationServiceAgent::GetAliveTasks(
const std::vector<CoordinatedTask>& tasks) {
// Validate the agent.
@ -1036,20 +1036,21 @@ CoordinationServiceAgent::GetAliveTasks(
};
leader_client_->GetAliveTasksAsync(request.get(), response.get(), done);
n.WaitForNotification();
// Parse the response.
if (!status.ok()) {
return status;
}
{
absl::MutexLock lock(incarnations_mu_);
for (int i = 0; i < response->alive_tasks_size(); ++i) {
incarnations_[response->alive_tasks(i).task_id()] =
response->incarnations(i);
}
// Parse the response.
absl::MutexLock lock(incarnations_mu_);
std::vector<AliveTask> alive_tasks;
for (int i = 0; i < response->alive_tasks_size(); ++i) {
int task_id = response->alive_tasks(i).task_id();
IncarnationId incarnation_id(response->incarnations(i));
alive_tasks.push_back(AliveTask{task_id, incarnation_id});
incarnations_[task_id] = incarnation_id;
}
return std::vector<tensorflow::CoordinatedTask>(
response->alive_tasks().begin(), response->alive_tasks().end());
return alive_tasks;
}
// Returns an error if agent is not running.

View File

@ -320,7 +320,11 @@ class CoordinationServiceAgent {
// has failed and that every task calls GetAliveTasks([A, B, C, D]). The
// invocation will return tasks [A, B, C]. The GetAliveTasks call acts as a
// barrier across tasks A, B, and C. Task D, which failed, is ignored.
absl::StatusOr<std::vector<tensorflow::CoordinatedTask>> GetAliveTasks(
// Element type of the GetAliveTasks result: one alive task together with its
// incarnation id.
struct AliveTask {
// Id of the alive task.
int task_id;
// Incarnation id reported for that task.
IncarnationId incarnation_id;
};
absl::StatusOr<std::vector<AliveTask>> GetAliveTasks(
const std::vector<tensorflow::CoordinatedTask>& tasks);
// Returns the latest known set of incarnation ids for every task. Incarnation

View File

@ -2562,9 +2562,9 @@ TEST_F(GetAliveTasksTest, SuccessfulGetAliveTasks) {
const std::vector<IncarnationId>& incarnations) {
EXPECT_OK(status);
EXPECT_THAT(alive_tasks, UnorderedElementsAreArray(GetTaskMatchers()));
EXPECT_EQ(incarnations,
(std::vector<IncarnationId>{IncarnationId(0), IncarnationId(1),
IncarnationId(2)}));
EXPECT_THAT(incarnations,
UnorderedElementsAre(IncarnationId(0), IncarnationId(1),
IncarnationId(2)));
finished.DecrementCount();
};
GetCoordinationService()->GetAliveTasksAsync(GetTask(0), GetTasks(), done);
@ -2583,8 +2583,8 @@ TEST_F(GetAliveTasksTest, FailedTaskBeforeCallingGetAliveTasks) {
EXPECT_OK(status);
EXPECT_THAT(alive_tasks, UnorderedElementsAre(EqualsProto(GetTask(0)),
EqualsProto(GetTask(1))));
EXPECT_EQ(incarnations,
(std::vector<IncarnationId>{IncarnationId(0), IncarnationId(1)}));
EXPECT_THAT(incarnations,
UnorderedElementsAre(IncarnationId(0), IncarnationId(1)));
finished.DecrementCount();
};
ASSERT_OK(GetCoordinationService()->ReportTaskError(
@ -2605,8 +2605,8 @@ TEST_F(GetAliveTasksTest, FailedTaskAfterCallingGetAliveTasks) {
EXPECT_OK(status);
EXPECT_THAT(alive_tasks, UnorderedElementsAre(EqualsProto(GetTask(0)),
EqualsProto(GetTask(1))));
EXPECT_EQ(incarnations,
(std::vector<IncarnationId>{IncarnationId(0), IncarnationId(1)}));
EXPECT_THAT(incarnations,
UnorderedElementsAre(IncarnationId(0), IncarnationId(1)));
finished.DecrementCount();
};
GetCoordinationService()->GetAliveTasksAsync(GetTask(0), GetTasks(), done);
@ -2630,8 +2630,8 @@ TEST_F(GetAliveTasksTest, ConcurrentGetAliveTasks) {
EXPECT_OK(status);
EXPECT_THAT(alive_tasks, UnorderedElementsAre(EqualsProto(tasks_01[0]),
EqualsProto(tasks_01[1])));
EXPECT_EQ(incarnations,
(std::vector<IncarnationId>{IncarnationId(0), IncarnationId(1)}));
EXPECT_THAT(incarnations,
UnorderedElementsAre(IncarnationId(0), IncarnationId(1)));
finished_01.DecrementCount();
};
@ -2644,8 +2644,8 @@ TEST_F(GetAliveTasksTest, ConcurrentGetAliveTasks) {
EXPECT_OK(status);
EXPECT_THAT(alive_tasks, UnorderedElementsAre(EqualsProto(tasks_12[0]),
EqualsProto(tasks_12[1])));
EXPECT_EQ(incarnations,
(std::vector<IncarnationId>{IncarnationId(1), IncarnationId(2)}));
EXPECT_THAT(incarnations,
UnorderedElementsAre(IncarnationId(1), IncarnationId(2)));
finished_12.DecrementCount();
};