Classify the cross-host collective-permute communication type as MULTI_HOST_NON_WORLD_LEVEL only.

PiperOrigin-RevId: 826327580
This commit is contained in:
Felix Wang 2025-10-30 22:33:10 -07:00 committed by TensorFlower Gardener
parent d90723f48e
commit d9c76aafeb
3 changed files with 122 additions and 50 deletions

View File

@ -105,7 +105,9 @@ cc_library(
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",

View File

@ -18,11 +18,12 @@ limitations under the License.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <variant>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
@ -43,7 +44,24 @@ namespace xla {
namespace gpu {
namespace {
struct CommunicationMetadata {
// Builds, for a collective-permute instruction, a map from each source host
// (node) to the set of hosts its data is sent to. A node ID is derived by
// integer-dividing the device (replica) ID by the number of devices per host.
absl::flat_hash_map<int64_t, absl::flat_hash_set<int64_t>>
GetSourceToTargetsNodeMap(const HloCollectivePermuteInstruction& instr,
                          int num_devices_per_host) {
  absl::flat_hash_map<int64_t, absl::flat_hash_set<int64_t>> node_map;
  for (const auto& [src_device, dst_device] : instr.source_target_pairs()) {
    // Both endpoints of a pair are collapsed to their host (node) IDs.
    const int64_t src_node = src_device / num_devices_per_host;
    const int64_t dst_node = dst_device / num_devices_per_host;
    node_map[src_node].insert(dst_node);
  }
  return node_map;
}
struct CollectiveMetadata {
// map for ops with `replica_groups`, e.g. all-gather.
absl::flat_hash_map<int64_t, size_t> node_to_participant_count;
int num_devices_per_host;
int64_t replica_count;
@ -66,14 +84,12 @@ bool SameParticipantCounts(const absl::flat_hash_map<int64_t, size_t>& lhs,
return lhs_counts == rhs_counts;
}
absl::StatusOr<CommunicationMetadata> CommunicationContext(
const HloChannelInstruction& instr, int num_devices_per_host) {
absl::StatusOr<CollectiveMetadata> CommunicationContext(
const HloCollectiveInstruction& instr, int num_devices_per_host) {
absl::flat_hash_map<int64_t, size_t> node_to_participant_count;
if (const HloCollectiveInstruction* collective =
DynCast<HloCollectiveInstruction>(&instr)) {
for (const ReplicaGroup& replica_group :
collective->device_list().replica_groups()) {
instr.device_list().replica_groups()) {
absl::flat_hash_map<int64_t, size_t> buffer;
for (int64_t rank : replica_group.replica_ids()) {
int64_t node_id = rank / num_devices_per_host;
@ -81,34 +97,18 @@ absl::StatusOr<CommunicationMetadata> CommunicationContext(
}
if (!node_to_participant_count.empty() &&
!SameParticipantCounts(buffer, node_to_participant_count)) {
return absl::FailedPreconditionError(
absl::StrCat("Non homogenous replica group: ",
collective->device_list().ToString()));
return absl::FailedPreconditionError(absl::StrCat(
"Non homogenous replica group: ", instr.device_list().ToString()));
}
if (node_to_participant_count.empty()) {
node_to_participant_count = buffer;
}
}
} else if (const HloCollectivePermuteInstruction* collective_permute =
DynCast<HloCollectivePermuteInstruction>(&instr)) {
for (const auto& [source, target] :
collective_permute->source_target_pairs()) {
int64_t source_node = source / num_devices_per_host;
int64_t target_node = target / num_devices_per_host;
node_to_participant_count[source_node]++;
node_to_participant_count[target_node]++;
}
} else {
return absl::FailedPreconditionError(
"Cannot determine communication context for non-collective channel "
"instruction");
}
return CommunicationMetadata{node_to_participant_count, num_devices_per_host,
return CollectiveMetadata{node_to_participant_count, num_devices_per_host,
instr.GetModule()->config().replica_count()};
}
bool IsSingleHost(const CommunicationMetadata& pattern) {
bool IsSingleHost(const CollectiveMetadata& pattern) {
if (pattern.node_to_participant_count.size() == 1) {
return true;
}
@ -117,7 +117,7 @@ bool IsSingleHost(const CommunicationMetadata& pattern) {
pattern.replica_count <= pattern.num_devices_per_host;
}
bool IsWorldLevelCommunication(const CommunicationMetadata& pattern) {
bool IsWorldLevelCommunication(const CollectiveMetadata& pattern) {
if (!IsSingleHost(pattern) && pattern.node_to_participant_count.empty()) {
return true;
}
@ -128,7 +128,7 @@ bool IsWorldLevelCommunication(const CommunicationMetadata& pattern) {
});
}
bool IsNonWorldLevelCommunication(const CommunicationMetadata& pattern) {
bool IsNonWorldLevelCommunication(const CollectiveMetadata& pattern) {
return !IsSingleHost(pattern) && !IsWorldLevelCommunication(pattern);
}
@ -149,8 +149,10 @@ absl::StatusOr<GPUCommunicationType> CommunicationType(
return absl::FailedPreconditionError("Only CUDA is supported.");
}
TF_ASSIGN_OR_RETURN(CommunicationMetadata comm,
CommunicationContext(instr, num_devices_per_host));
if (const auto* collective = DynCast<HloCollectiveInstruction>(&instr)) {
TF_ASSIGN_OR_RETURN(
CollectiveMetadata comm,
CommunicationContext(*collective, num_devices_per_host));
if (IsSingleHost(comm)) {
return GPUCommunicationType::SINGLE_HOST;
}
@ -160,6 +162,26 @@ absl::StatusOr<GPUCommunicationType> CommunicationType(
if (IsNonWorldLevelCommunication(comm)) {
return GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL;
}
} else if (const auto* collective_permute =
DynCast<HloCollectivePermuteInstruction>(&instr)) {
const auto source_to_targets_node_map =
GetSourceToTargetsNodeMap(*collective_permute, num_devices_per_host);
for (const auto& [source_node, target_node_set] :
source_to_targets_node_map) {
if (target_node_set.size() > 1) {
return GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL;
}
CHECK_EQ(target_node_set.size(), 1);
if (source_node != *target_node_set.begin()) {
return GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL;
}
}
return GPUCommunicationType::SINGLE_HOST;
} else {
return absl::FailedPreconditionError(
"Cannot determine communication type for non-collective channel "
"instruction");
}
return GPUCommunicationType::UNDEFINED;
}

View File

@ -266,6 +266,26 @@ TEST_F(CommunicationTypeTest, DetectsSingleHostCollectivePermute) {
IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
}
// Verifies that a collective-permute whose pairs {0,7} and {7,0} stay within
// a single 8-device host is classified as SINGLE_HOST: every source node
// equals its (sole) target node in the source->targets node map.
TEST_F(CommunicationTypeTest, DetectsSingleHostCollectivePermuteSinglePair) {
absl::string_view kHlo = R"(
HloModule m, num_partitions=8
ENTRY e {
p = f32[128] parameter(0)
ROOT _ = f32[128] collective-permute(p),
source_target_pairs={{0,7},{7,0}}
}
)";
// Unverified parse is sufficient here; only the instruction's
// source_target_pairs and the module config are inspected.
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
HloChannelInstruction* instr = Cast<HloChannelInstruction>(
module->entry_computation()->root_instruction());
// With 8 devices per host, both devices 0 and 7 map to node 0.
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
}
TEST_F(CommunicationTypeTest, DetectNonWorldLevelCollectivePermute) {
absl::string_view kHlo = R"(
HloModule m, num_partitions=16
@ -304,7 +324,35 @@ TEST_F(CommunicationTypeTest, DetectWorldLevelCollectivePermute) {
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL));
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL));
}
// Verifies that a collective-permute mixing an intra-host pair ({0,7}: both
// devices on node 0) with cross-host pairs ({0,8} etc.: node 0 -> node 1) is
// classified as MULTI_HOST_NON_WORLD_LEVEL — source device 0 targets more
// than one node, which the classifier treats as non-world-level.
TEST_F(CommunicationTypeTest, DetectsCrossHostCollectivePermuteMixed) {
absl::string_view kHlo = R"(
HloModule m, num_partitions=16
ENTRY e {
p = f32[128] parameter(0)
ROOT _ = f32[128] collective-permute(p),
source_target_pairs={{0,7},
{0,8},
{1,9},
{2,10},
{3,11},
{4,12},
{5,13},
{6,14},
{7,15}}
}
)";
// Unverified parse is sufficient here; only the instruction's
// source_target_pairs and the module config are inspected.
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
HloChannelInstruction* instr = Cast<HloChannelInstruction>(
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL));
}
} // namespace