Rename rail-aligned into world-level in collective_ops_utils.h

Network rail usually refers to a set of  NICs connected by the same fabric/switch, e.g. [Rail-optimized topology](https://developer.nvidia.com/blog/doubling-all2all-performance-with-nvidia-collective-communication-library-2-12/).

PiperOrigin-RevId: 825696577
This commit is contained in:
Felix Wang 2025-10-29 13:59:15 -07:00 committed by TensorFlower Gardener
parent ca3d7d6305
commit cecce70fb2
5 changed files with 85 additions and 85 deletions

View File

@ -151,10 +151,10 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
case GPUCommunicationType::SINGLE_HOST:
iota = IotaReplicaGroupList(num_hosts, kNumGpusPerHost);
break;
case GPUCommunicationType::RAIL_ALIGNED:
case GPUCommunicationType::MULTI_HOST_WORLD_LEVEL:
iota = IotaReplicaGroupList(1, num_hosts * kNumGpusPerHost);
break;
case GPUCommunicationType::NON_RAIL_ALIGNED:
case GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL:
iota = IotaReplicaGroupList(kNumGpusPerHost, num_hosts,
{num_hosts, kNumGpusPerHost}, {1, 0});
break;
@ -169,56 +169,56 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
std::vector<SpaceSpec> test_space_ = {
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/1024,
},
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/2 * 1024,
},
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/4 * 1024,
},
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/5 * 1024,
},
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/512,
},
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/2 * 512,
},
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/4 * 512,
},
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/5 * 512,
@ -253,56 +253,56 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
},
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/1024,
},
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/2 * 1024,
},
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/4 * 1024,
},
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/5 * 1024,
},
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/512,
},
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/2 * 512,
},
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/4 * 512,
},
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/5 * 512,
@ -337,56 +337,56 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
},
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/1024,
},
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/2 * 1024,
},
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/4 * 1024,
},
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/5 * 1024,
},
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/512,
},
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/2 * 512,
},
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/4 * 512,
},
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/5 * 512,
@ -428,14 +428,14 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
},
{
/*opcode=*/HloOpcode::kAllToAll,
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/2048,
},
{
/*opcode=*/HloOpcode::kAllToAll,
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/4096,
@ -459,7 +459,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/
GPUCommunicationType::RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/2,
},
@ -471,7 +471,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/
GPUCommunicationType::RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/8,
},
@ -483,7 +483,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/
GPUCommunicationType::RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/4 * 1024,
/*num_nodes=*/2,
},
@ -495,7 +495,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/
GPUCommunicationType::RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/3,
},
@ -507,7 +507,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/
GPUCommunicationType::RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024 + 256,
/*num_nodes=*/2,
},
@ -519,7 +519,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/
GPUCommunicationType::NON_RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/2,
},
@ -531,7 +531,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/
GPUCommunicationType::NON_RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/8,
},
@ -543,7 +543,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/
GPUCommunicationType::NON_RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/4 * 1024,
/*num_nodes=*/2,
},
@ -555,7 +555,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/
GPUCommunicationType::NON_RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/3,
},
@ -567,7 +567,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/
GPUCommunicationType::NON_RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024 + 256,
/*num_nodes=*/2,
},
@ -639,7 +639,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/
GPUCommunicationType::RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/2,
},
@ -651,7 +651,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/
GPUCommunicationType::RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/8,
},
@ -663,7 +663,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/
GPUCommunicationType::RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/4 * 1024,
/*num_nodes=*/2,
},
@ -675,7 +675,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/
GPUCommunicationType::RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/3,
},
@ -687,7 +687,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/
GPUCommunicationType::RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024 + 256,
/*num_nodes=*/2,
},
@ -699,7 +699,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/
GPUCommunicationType::NON_RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/2,
},
@ -711,7 +711,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/
GPUCommunicationType::NON_RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/8,
},
@ -723,7 +723,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/
GPUCommunicationType::NON_RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/4 * 1024,
/*num_nodes=*/2,
},
@ -735,7 +735,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/
GPUCommunicationType::NON_RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1032,
/*num_nodes=*/3,
},
@ -747,7 +747,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/
GPUCommunicationType::NON_RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024 + 256,
/*num_nodes=*/2,
},
@ -807,7 +807,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/
GPUCommunicationType::RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/2,
},
@ -819,7 +819,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/
GPUCommunicationType::RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/8,
},
@ -831,7 +831,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/
GPUCommunicationType::RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/4 * 1024,
/*num_nodes=*/2,
},
@ -843,7 +843,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/
GPUCommunicationType::RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1056,
/*num_nodes=*/3,
},
@ -855,7 +855,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/
GPUCommunicationType::RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024 + 256,
/*num_nodes=*/2,
},
@ -867,7 +867,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/
GPUCommunicationType::NON_RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/2,
},
@ -879,7 +879,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/
GPUCommunicationType::NON_RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/8,
},
@ -891,7 +891,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/
GPUCommunicationType::NON_RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/4 * 1024,
/*num_nodes=*/2,
},
@ -903,7 +903,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/
GPUCommunicationType::NON_RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1032,
/*num_nodes=*/3,
},
@ -915,7 +915,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/
GPUCommunicationType::NON_RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024 + 256,
/*num_nodes=*/2,
},
@ -986,7 +986,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllToAll,
/*comm=*/
GPUCommunicationType::RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/2,
},
@ -997,7 +997,7 @@ INSTANTIATE_TEST_SUITE_P(
{
/*opcode=*/HloOpcode::kAllToAll,
/*comm=*/
GPUCommunicationType::NON_RAIL_ALIGNED,
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
/*tensor_size=*/1024,
/*num_nodes=*/2,
},

View File

@ -204,13 +204,13 @@ absl::StatusOr<absl::Duration> DispatchEstimation(
GetReplicaGroupCountAndSize(&instr));
switch (comm) {
case GPUCommunicationType::RAIL_ALIGNED: {
case GPUCommunicationType::MULTI_HOST_WORLD_LEVEL: {
return DCNCollectiveDuration(
num_groups_and_devices->second / sol_flags.gpus_per_node,
/*num_communicators=*/num_groups_and_devices->first, instr,
gpu_device_info, sol_flags, analysis, symbolic_expr_context);
}
case GPUCommunicationType::NON_RAIL_ALIGNED: {
case GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL: {
return DCNCollectiveDuration(
num_groups_and_devices->second,
/*num_communicators=*/num_groups_and_devices->first, instr,

View File

@ -115,7 +115,7 @@ bool IsSingleHost(const CommunicationMetadata& pattern) {
pattern.replica_count <= pattern.num_devices_per_host;
}
bool IsRailAligned(const CommunicationMetadata& pattern) {
bool IsWorldLevelCommunication(const CommunicationMetadata& pattern) {
if (!IsSingleHost(pattern) && pattern.node_to_participant_count.empty()) {
return true;
}
@ -126,8 +126,8 @@ bool IsRailAligned(const CommunicationMetadata& pattern) {
});
}
bool IsNonRailAligned(const CommunicationMetadata& pattern) {
return !IsSingleHost(pattern) && !IsRailAligned(pattern);
bool IsNonWorldLevelCommunication(const CommunicationMetadata& pattern) {
return !IsSingleHost(pattern) && !IsWorldLevelCommunication(pattern);
}
} // namespace
@ -152,11 +152,11 @@ absl::StatusOr<GPUCommunicationType> CommunicationType(
if (IsSingleHost(comm)) {
return GPUCommunicationType::SINGLE_HOST;
}
if (IsRailAligned(comm)) {
return GPUCommunicationType::RAIL_ALIGNED;
if (IsWorldLevelCommunication(comm)) {
return GPUCommunicationType::MULTI_HOST_WORLD_LEVEL;
}
if (IsNonRailAligned(comm)) {
return GPUCommunicationType::NON_RAIL_ALIGNED;
if (IsNonWorldLevelCommunication(comm)) {
return GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL;
}
return GPUCommunicationType::UNDEFINED;

View File

@ -29,10 +29,10 @@ enum class GPUCommunicationType {
UNDEFINED = 0,
// Communication involves devices from multiple hosts, and every host
// involved in the communication pattern has all of its devices participating.
RAIL_ALIGNED = 1,
MULTI_HOST_WORLD_LEVEL = 1,
// Communication involves devices from multiple hosts, but at least one of
// the involved hosts has only a subset of its devices participating.
NON_RAIL_ALIGNED = 2,
MULTI_HOST_NON_WORLD_LEVEL = 2,
// All devices participating in the collective operation reside on the same
// host machine.
SINGLE_HOST = 3

View File

@ -111,7 +111,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHost16Devices) {
IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
}
TEST_F(CommunicationTypeTest, DetectRailAlignedAllDevices) {
TEST_F(CommunicationTypeTest, DetectWorldLevelAllDevices) {
absl::string_view kHlo = R"(
HloModule m, num_partitions=16
@ -131,10 +131,10 @@ TEST_F(CommunicationTypeTest, DetectRailAlignedAllDevices) {
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::RAIL_ALIGNED));
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL));
}
TEST_F(CommunicationTypeTest, DetectRailAlignedHalfMesh) {
TEST_F(CommunicationTypeTest, DetectWorldLevelHalfMesh) {
absl::string_view kHlo = R"(
HloModule m, num_partitions=32
@ -157,10 +157,10 @@ TEST_F(CommunicationTypeTest, DetectRailAlignedHalfMesh) {
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::RAIL_ALIGNED));
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL));
}
TEST_F(CommunicationTypeTest, DetectNonRailAligned) {
TEST_F(CommunicationTypeTest, DetectNonWorldLevel) {
absl::string_view kHlo = R"(
HloModule m, num_partitions=16
@ -180,7 +180,7 @@ TEST_F(CommunicationTypeTest, DetectNonRailAligned) {
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::NON_RAIL_ALIGNED));
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL));
}
TEST_F(CommunicationTypeTest, DetectsSingleHost16DevicesForEmptyReplicaGroups) {
@ -204,7 +204,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHost16DevicesForEmptyReplicaGroups) {
IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
}
TEST_F(CommunicationTypeTest, DetectsRailAligned8DevicesForEmptyReplicaGroups) {
TEST_F(CommunicationTypeTest, DetectWorldLevel8DevicesForEmptyReplicaGroups) {
absl::string_view kHlo = R"(
HloModule m, replica_count=16
@ -222,10 +222,10 @@ TEST_F(CommunicationTypeTest, DetectsRailAligned8DevicesForEmptyReplicaGroups) {
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::RAIL_ALIGNED));
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL));
}
TEST_F(CommunicationTypeTest, DetectsNonRailAligned16Devices) {
TEST_F(CommunicationTypeTest, DetectNonWorldLevel16Devices) {
absl::string_view kHlo = R"(
HloModule m, replica_count=16
@ -243,7 +243,7 @@ TEST_F(CommunicationTypeTest, DetectsNonRailAligned16Devices) {
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::NON_RAIL_ALIGNED));
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL));
}
TEST_F(CommunicationTypeTest, DetectsSingleHostCollectivePermute) {
@ -266,7 +266,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHostCollectivePermute) {
IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
}
TEST_F(CommunicationTypeTest, DetectsNonRailAlignedCollectivePermute) {
TEST_F(CommunicationTypeTest, DetectNonWorldLevelCollectivePermute) {
absl::string_view kHlo = R"(
HloModule m, num_partitions=16
@ -284,10 +284,10 @@ TEST_F(CommunicationTypeTest, DetectsNonRailAlignedCollectivePermute) {
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::NON_RAIL_ALIGNED));
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL));
}
TEST_F(CommunicationTypeTest, DetectsRailAlignedCollectivePermute) {
TEST_F(CommunicationTypeTest, DetectWorldLevelCollectivePermute) {
absl::string_view kHlo = R"(
HloModule m, num_partitions=16
@ -304,7 +304,7 @@ TEST_F(CommunicationTypeTest, DetectsRailAlignedCollectivePermute) {
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::RAIL_ALIGNED));
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL));
}
} // namespace