diff --git a/third_party/xla/xla/service/gpu/model/collective_interpolator_test.cc b/third_party/xla/xla/service/gpu/model/collective_interpolator_test.cc index 26cfc6a2bbb..8a4348ecf31 100644 --- a/third_party/xla/xla/service/gpu/model/collective_interpolator_test.cc +++ b/third_party/xla/xla/service/gpu/model/collective_interpolator_test.cc @@ -151,10 +151,10 @@ class CollectiveInterpolationTest : public TestWithParam { case GPUCommunicationType::SINGLE_HOST: iota = IotaReplicaGroupList(num_hosts, kNumGpusPerHost); break; - case GPUCommunicationType::RAIL_ALIGNED: + case GPUCommunicationType::MULTI_HOST_WORLD_LEVEL: iota = IotaReplicaGroupList(1, num_hosts * kNumGpusPerHost); break; - case GPUCommunicationType::NON_RAIL_ALIGNED: + case GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL: iota = IotaReplicaGroupList(kNumGpusPerHost, num_hosts, {num_hosts, kNumGpusPerHost}, {1, 0}); break; @@ -169,56 +169,56 @@ class CollectiveInterpolationTest : public TestWithParam { std::vector test_space_ = { { /*opcode=*/HloOpcode::kAllReduce, - /*comm=*/GPUCommunicationType::RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/2, /*network_througput_bytes=*/1024, }, { /*opcode=*/HloOpcode::kAllReduce, - /*comm=*/GPUCommunicationType::RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/2 * 1024, /*num_nodes=*/2, /*network_througput_bytes=*/2 * 1024, }, { /*opcode=*/HloOpcode::kAllReduce, - /*comm=*/GPUCommunicationType::RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/4, /*network_througput_bytes=*/4 * 1024, }, { /*opcode=*/HloOpcode::kAllReduce, - /*comm=*/GPUCommunicationType::RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/2 * 1024, /*num_nodes=*/4, /*network_througput_bytes=*/5 * 1024, }, { /*opcode=*/HloOpcode::kAllReduce, - /*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/2, /*network_througput_bytes=*/512, }, { /*opcode=*/HloOpcode::kAllReduce, - /*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/2 * 1024, /*num_nodes=*/2, /*network_througput_bytes=*/2 * 512, }, { /*opcode=*/HloOpcode::kAllReduce, - /*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/4, /*network_througput_bytes=*/4 * 512, }, { /*opcode=*/HloOpcode::kAllReduce, - /*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/2 * 1024, /*num_nodes=*/4, /*network_througput_bytes=*/5 * 512, @@ -253,56 +253,56 @@ class CollectiveInterpolationTest : public TestWithParam { }, { /*opcode=*/HloOpcode::kReduceScatter, - /*comm=*/GPUCommunicationType::RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/2, /*network_througput_bytes=*/1024, }, { /*opcode=*/HloOpcode::kReduceScatter, - /*comm=*/GPUCommunicationType::RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/2 * 1024, /*num_nodes=*/2, /*network_througput_bytes=*/2 * 1024, }, { /*opcode=*/HloOpcode::kReduceScatter, - /*comm=*/GPUCommunicationType::RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/4, /*network_througput_bytes=*/4 * 1024, }, { /*opcode=*/HloOpcode::kReduceScatter, - /*comm=*/GPUCommunicationType::RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/2 * 1024, /*num_nodes=*/4, /*network_througput_bytes=*/5 * 1024, }, { /*opcode=*/HloOpcode::kReduceScatter, - /*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/2, /*network_througput_bytes=*/512, }, { /*opcode=*/HloOpcode::kReduceScatter, - /*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/2 * 1024, /*num_nodes=*/2, /*network_througput_bytes=*/2 * 512, }, { /*opcode=*/HloOpcode::kReduceScatter, - /*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/4, /*network_througput_bytes=*/4 * 512, }, { /*opcode=*/HloOpcode::kReduceScatter, - /*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/2 * 1024, /*num_nodes=*/4, /*network_througput_bytes=*/5 * 512, @@ -337,56 +337,56 @@ class CollectiveInterpolationTest : public TestWithParam { }, { /*opcode=*/HloOpcode::kAllGather, - /*comm=*/GPUCommunicationType::RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/2, /*network_througput_bytes=*/1024, }, { /*opcode=*/HloOpcode::kAllGather, - /*comm=*/GPUCommunicationType::RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/2 * 1024, /*num_nodes=*/2, /*network_througput_bytes=*/2 * 1024, }, { /*opcode=*/HloOpcode::kAllGather, - /*comm=*/GPUCommunicationType::RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/4, /*network_througput_bytes=*/4 * 1024, }, { /*opcode=*/HloOpcode::kAllGather, - /*comm=*/GPUCommunicationType::RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/2 * 1024, /*num_nodes=*/4, /*network_througput_bytes=*/5 * 1024, }, { /*opcode=*/HloOpcode::kAllGather, - /*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/2, /*network_througput_bytes=*/512, }, { /*opcode=*/HloOpcode::kAllGather, - /*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/2 * 1024, /*num_nodes=*/2, /*network_througput_bytes=*/2 * 512, }, { /*opcode=*/HloOpcode::kAllGather, - /*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/4, /*network_througput_bytes=*/4 * 512, }, { /*opcode=*/HloOpcode::kAllGather, - /*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/2 * 1024, /*num_nodes=*/4, /*network_througput_bytes=*/5 * 512, @@ -428,14 +428,14 @@ class CollectiveInterpolationTest : public TestWithParam { }, { /*opcode=*/HloOpcode::kAllToAll, - /*comm=*/GPUCommunicationType::RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/2, /*network_througput_bytes=*/2048, }, { /*opcode=*/HloOpcode::kAllToAll, - /*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED, + /*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/2, /*network_througput_bytes=*/4096, @@ -459,7 +459,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllReduce, /*comm=*/ - GPUCommunicationType::RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/2, }, @@ -471,7 +471,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllReduce, /*comm=*/ - GPUCommunicationType::RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/8, }, @@ -483,7 +483,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllReduce, /*comm=*/ - GPUCommunicationType::RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/4 * 1024, /*num_nodes=*/2, }, @@ -495,7 +495,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllReduce, /*comm=*/ - GPUCommunicationType::RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/3, }, @@ -507,7 +507,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllReduce, /*comm=*/ - GPUCommunicationType::RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024 + 256, /*num_nodes=*/2, }, @@ -519,7 +519,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllReduce, /*comm=*/ - GPUCommunicationType::NON_RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/2, }, @@ -531,7 +531,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllReduce, /*comm=*/ - GPUCommunicationType::NON_RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/8, }, @@ -543,7 +543,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllReduce, /*comm=*/ - GPUCommunicationType::NON_RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/4 * 1024, /*num_nodes=*/2, }, @@ -555,7 +555,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllReduce, /*comm=*/ - GPUCommunicationType::NON_RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/3, }, @@ -567,7 +567,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllReduce, /*comm=*/ - GPUCommunicationType::NON_RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024 + 256, /*num_nodes=*/2, }, @@ -639,7 +639,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kReduceScatter, /*comm=*/ - GPUCommunicationType::RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/2, }, @@ -651,7 +651,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kReduceScatter, /*comm=*/ - GPUCommunicationType::RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/8, }, @@ -663,7 +663,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kReduceScatter, /*comm=*/ - GPUCommunicationType::RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/4 * 1024, /*num_nodes=*/2, }, @@ -675,7 +675,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kReduceScatter, /*comm=*/ - GPUCommunicationType::RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/3, }, @@ -687,7 +687,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kReduceScatter, /*comm=*/ - GPUCommunicationType::RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024 + 256, /*num_nodes=*/2, }, @@ -699,7 +699,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kReduceScatter, /*comm=*/ - GPUCommunicationType::NON_RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/2, }, @@ -711,7 +711,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kReduceScatter, /*comm=*/ - GPUCommunicationType::NON_RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/8, }, @@ -723,7 +723,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kReduceScatter, /*comm=*/ - GPUCommunicationType::NON_RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/4 * 1024, /*num_nodes=*/2, }, @@ -735,7 +735,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kReduceScatter, /*comm=*/ - GPUCommunicationType::NON_RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1032, /*num_nodes=*/3, }, @@ -747,7 +747,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kReduceScatter, /*comm=*/ - GPUCommunicationType::NON_RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024 + 256, /*num_nodes=*/2, }, @@ -807,7 +807,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllGather, /*comm=*/ - GPUCommunicationType::RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/2, }, @@ -819,7 +819,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllGather, /*comm=*/ - GPUCommunicationType::RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/8, }, @@ -831,7 +831,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllGather, /*comm=*/ - GPUCommunicationType::RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/4 * 1024, /*num_nodes=*/2, }, @@ -843,7 +843,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllGather, /*comm=*/ - GPUCommunicationType::RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1056, /*num_nodes=*/3, }, @@ -855,7 +855,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllGather, /*comm=*/ - GPUCommunicationType::RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024 + 256, /*num_nodes=*/2, }, @@ -867,7 +867,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllGather, /*comm=*/ - GPUCommunicationType::NON_RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/2, }, @@ -879,7 +879,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllGather, /*comm=*/ - GPUCommunicationType::NON_RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/8, }, @@ -891,7 +891,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllGather, /*comm=*/ - GPUCommunicationType::NON_RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/4 * 1024, /*num_nodes=*/2, }, @@ -903,7 +903,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllGather, /*comm=*/ - GPUCommunicationType::NON_RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1032, /*num_nodes=*/3, }, @@ -915,7 +915,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllGather, /*comm=*/ - GPUCommunicationType::NON_RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024 + 256, /*num_nodes=*/2, }, @@ -986,7 +986,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllToAll, /*comm=*/ - GPUCommunicationType::RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/2, }, @@ -997,7 +997,7 @@ INSTANTIATE_TEST_SUITE_P( { /*opcode=*/HloOpcode::kAllToAll, /*comm=*/ - GPUCommunicationType::NON_RAIL_ALIGNED, + GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL, /*tensor_size=*/1024, /*num_nodes=*/2, }, diff --git a/third_party/xla/xla/service/gpu/model/sol_latency_estimator.cc b/third_party/xla/xla/service/gpu/model/sol_latency_estimator.cc index b6c91f08816..ad9a657549c 100644 --- a/third_party/xla/xla/service/gpu/model/sol_latency_estimator.cc +++ b/third_party/xla/xla/service/gpu/model/sol_latency_estimator.cc @@ -204,13 +204,13 @@ absl::StatusOr DispatchEstimation( GetReplicaGroupCountAndSize(&instr)); switch (comm) { - case GPUCommunicationType::RAIL_ALIGNED: { + case GPUCommunicationType::MULTI_HOST_WORLD_LEVEL: { return DCNCollectiveDuration( num_groups_and_devices->second / sol_flags.gpus_per_node, /*num_communicators=*/num_groups_and_devices->first, instr, gpu_device_info, sol_flags, analysis, symbolic_expr_context); } - case GPUCommunicationType::NON_RAIL_ALIGNED: { + case GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL: { return DCNCollectiveDuration( num_groups_and_devices->second, /*num_communicators=*/num_groups_and_devices->first, instr, diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc index 309498780a4..cd4b5e1ef57 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc +++ b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc @@ -115,7 +115,7 @@ bool IsSingleHost(const CommunicationMetadata& pattern) { pattern.replica_count <= pattern.num_devices_per_host; } -bool IsRailAligned(const CommunicationMetadata& pattern) { +bool IsWorldLevelCommunication(const CommunicationMetadata& pattern) { if (!IsSingleHost(pattern) && pattern.node_to_participant_count.empty()) { return true; } @@ -126,8 +126,8 @@ bool IsRailAligned(const CommunicationMetadata& pattern) { }); } -bool IsNonRailAligned(const CommunicationMetadata& pattern) { - return !IsSingleHost(pattern) && !IsRailAligned(pattern); +bool IsNonWorldLevelCommunication(const CommunicationMetadata& pattern) { + return !IsSingleHost(pattern) && !IsWorldLevelCommunication(pattern); } } // namespace @@ -152,11 +152,11 @@ absl::StatusOr CommunicationType( if (IsSingleHost(comm)) { return GPUCommunicationType::SINGLE_HOST; } - if (IsRailAligned(comm)) { - return GPUCommunicationType::RAIL_ALIGNED; + if (IsWorldLevelCommunication(comm)) { + return GPUCommunicationType::MULTI_HOST_WORLD_LEVEL; } - if (IsNonRailAligned(comm)) { - return GPUCommunicationType::NON_RAIL_ALIGNED; + if (IsNonWorldLevelCommunication(comm)) { + return GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL; } return GPUCommunicationType::UNDEFINED; diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.h b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.h index 62161101f5d..719940c9e4a 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.h +++ b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.h @@ -29,10 +29,10 @@ enum class GPUCommunicationType { UNDEFINED = 0, // Communication involves devices from multiple hosts, and every host // involved in the communication pattern has all of its devices participating. - RAIL_ALIGNED = 1, + MULTI_HOST_WORLD_LEVEL = 1, // Communication involves devices from multiple hosts, but at least one of // the involved hosts has only a subset of its devices participating. - NON_RAIL_ALIGNED = 2, + MULTI_HOST_NON_WORLD_LEVEL = 2, // All devices participating in the collective operation reside on the same // host machine. SINGLE_HOST = 3 diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils_test.cc b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils_test.cc index 7ec6d51709d..35b9f1f5fc6 100644 --- a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils_test.cc @@ -111,7 +111,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHost16Devices) { IsOkAndHolds(GPUCommunicationType::SINGLE_HOST)); } -TEST_F(CommunicationTypeTest, DetectRailAlignedAllDevices) { +TEST_F(CommunicationTypeTest, DetectWorldLevelAllDevices) { absl::string_view kHlo = R"( HloModule m, num_partitions=16 @@ -131,10 +131,10 @@ TEST_F(CommunicationTypeTest, DetectRailAlignedAllDevices) { module->entry_computation()->root_instruction()); EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr, device_info().gpu_compute_capability()), - IsOkAndHolds(GPUCommunicationType::RAIL_ALIGNED)); + IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL)); } -TEST_F(CommunicationTypeTest, DetectRailAlignedHalfMesh) { +TEST_F(CommunicationTypeTest, DetectWorldLevelHalfMesh) { absl::string_view kHlo = R"( HloModule m, num_partitions=32 @@ -157,10 +157,10 @@ TEST_F(CommunicationTypeTest, DetectRailAlignedHalfMesh) { module->entry_computation()->root_instruction()); EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr, device_info().gpu_compute_capability()), - IsOkAndHolds(GPUCommunicationType::RAIL_ALIGNED)); + IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL)); } -TEST_F(CommunicationTypeTest, DetectNonRailAligned) { +TEST_F(CommunicationTypeTest, DetectNonWorldLevel) { absl::string_view kHlo = R"( HloModule m, num_partitions=16 @@ -180,7 +180,7 @@ TEST_F(CommunicationTypeTest, DetectNonRailAligned) { module->entry_computation()->root_instruction()); EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr, device_info().gpu_compute_capability()), - IsOkAndHolds(GPUCommunicationType::NON_RAIL_ALIGNED)); + IsOkAndHolds(GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL)); } TEST_F(CommunicationTypeTest, DetectsSingleHost16DevicesForEmptyReplicaGroups) { @@ -204,7 +204,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHost16DevicesForEmptyReplicaGroups) { IsOkAndHolds(GPUCommunicationType::SINGLE_HOST)); } -TEST_F(CommunicationTypeTest, DetectsRailAligned8DevicesForEmptyReplicaGroups) { +TEST_F(CommunicationTypeTest, DetectWorldLevel8DevicesForEmptyReplicaGroups) { absl::string_view kHlo = R"( HloModule m, replica_count=16 @@ -222,10 +222,10 @@ TEST_F(CommunicationTypeTest, DetectsRailAligned8DevicesForEmptyReplicaGroups) { module->entry_computation()->root_instruction()); EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr, device_info().gpu_compute_capability()), - IsOkAndHolds(GPUCommunicationType::RAIL_ALIGNED)); + IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL)); } -TEST_F(CommunicationTypeTest, DetectsNonRailAligned16Devices) { +TEST_F(CommunicationTypeTest, DetectNonWorldLevel16Devices) { absl::string_view kHlo = R"( HloModule m, replica_count=16 @@ -243,7 +243,7 @@ TEST_F(CommunicationTypeTest, DetectsNonRailAligned16Devices) { module->entry_computation()->root_instruction()); EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr, device_info().gpu_compute_capability()), - IsOkAndHolds(GPUCommunicationType::NON_RAIL_ALIGNED)); + IsOkAndHolds(GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL)); } TEST_F(CommunicationTypeTest, DetectsSingleHostCollectivePermute) { @@ -266,7 +266,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHostCollectivePermute) { IsOkAndHolds(GPUCommunicationType::SINGLE_HOST)); } -TEST_F(CommunicationTypeTest, DetectsNonRailAlignedCollectivePermute) { +TEST_F(CommunicationTypeTest, DetectNonWorldLevelCollectivePermute) { absl::string_view kHlo = R"( HloModule m, num_partitions=16 @@ -284,10 +284,10 @@ TEST_F(CommunicationTypeTest, DetectsNonRailAlignedCollectivePermute) { module->entry_computation()->root_instruction()); EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr, device_info().gpu_compute_capability()), - IsOkAndHolds(GPUCommunicationType::NON_RAIL_ALIGNED)); + IsOkAndHolds(GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL)); } -TEST_F(CommunicationTypeTest, DetectsRailAlignedCollectivePermute) { +TEST_F(CommunicationTypeTest, DetectWorldLevelCollectivePermute) { absl::string_view kHlo = R"( HloModule m, num_partitions=16 @@ -304,7 +304,7 @@ TEST_F(CommunicationTypeTest, DetectsRailAlignedCollectivePermute) { module->entry_computation()->root_instruction()); EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr, device_info().gpu_compute_capability()), - IsOkAndHolds(GPUCommunicationType::RAIL_ALIGNED)); + IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL)); } } // namespace