mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
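Rename rail-aligned to world-level in collective_ops_utils.h
A network rail usually refers to a set of NICs connected by the same fabric/switch, e.g. [Rail-optimized topology](https://developer.nvidia.com/blog/doubling-all2all-performance-with-nvidia-collective-communication-library-2-12/). The old RAIL_ALIGNED / NON_RAIL_ALIGNED names were therefore misleading: the enum only distinguishes whether every host involved in a multi-host collective has all of its devices participating (now MULTI_HOST_WORLD_LEVEL) or at least one host has only a subset participating (now MULTI_HOST_NON_WORLD_LEVEL). PiperOrigin-RevId: 825696577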
This commit is contained in:
parent
ca3d7d6305
commit
cecce70fb2
|
|
@ -151,10 +151,10 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
|
|||
case GPUCommunicationType::SINGLE_HOST:
|
||||
iota = IotaReplicaGroupList(num_hosts, kNumGpusPerHost);
|
||||
break;
|
||||
case GPUCommunicationType::RAIL_ALIGNED:
|
||||
case GPUCommunicationType::MULTI_HOST_WORLD_LEVEL:
|
||||
iota = IotaReplicaGroupList(1, num_hosts * kNumGpusPerHost);
|
||||
break;
|
||||
case GPUCommunicationType::NON_RAIL_ALIGNED:
|
||||
case GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL:
|
||||
iota = IotaReplicaGroupList(kNumGpusPerHost, num_hosts,
|
||||
{num_hosts, kNumGpusPerHost}, {1, 0});
|
||||
break;
|
||||
|
|
@ -169,56 +169,56 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
|
|||
std::vector<SpaceSpec> test_space_ = {
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/2,
|
||||
/*network_througput_bytes=*/1024,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/2 * 1024,
|
||||
/*num_nodes=*/2,
|
||||
/*network_througput_bytes=*/2 * 1024,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/4,
|
||||
/*network_througput_bytes=*/4 * 1024,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/2 * 1024,
|
||||
/*num_nodes=*/4,
|
||||
/*network_througput_bytes=*/5 * 1024,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/2,
|
||||
/*network_througput_bytes=*/512,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/2 * 1024,
|
||||
/*num_nodes=*/2,
|
||||
/*network_througput_bytes=*/2 * 512,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/4,
|
||||
/*network_througput_bytes=*/4 * 512,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/2 * 1024,
|
||||
/*num_nodes=*/4,
|
||||
/*network_througput_bytes=*/5 * 512,
|
||||
|
|
@ -253,56 +253,56 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
|
|||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/2,
|
||||
/*network_througput_bytes=*/1024,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/2 * 1024,
|
||||
/*num_nodes=*/2,
|
||||
/*network_througput_bytes=*/2 * 1024,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/4,
|
||||
/*network_througput_bytes=*/4 * 1024,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/2 * 1024,
|
||||
/*num_nodes=*/4,
|
||||
/*network_througput_bytes=*/5 * 1024,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/2,
|
||||
/*network_througput_bytes=*/512,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/2 * 1024,
|
||||
/*num_nodes=*/2,
|
||||
/*network_througput_bytes=*/2 * 512,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/4,
|
||||
/*network_througput_bytes=*/4 * 512,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/2 * 1024,
|
||||
/*num_nodes=*/4,
|
||||
/*network_througput_bytes=*/5 * 512,
|
||||
|
|
@ -337,56 +337,56 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
|
|||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/2,
|
||||
/*network_througput_bytes=*/1024,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/2 * 1024,
|
||||
/*num_nodes=*/2,
|
||||
/*network_througput_bytes=*/2 * 1024,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/4,
|
||||
/*network_througput_bytes=*/4 * 1024,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/2 * 1024,
|
||||
/*num_nodes=*/4,
|
||||
/*network_througput_bytes=*/5 * 1024,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/2,
|
||||
/*network_througput_bytes=*/512,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/2 * 1024,
|
||||
/*num_nodes=*/2,
|
||||
/*network_througput_bytes=*/2 * 512,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/4,
|
||||
/*network_througput_bytes=*/4 * 512,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/2 * 1024,
|
||||
/*num_nodes=*/4,
|
||||
/*network_througput_bytes=*/5 * 512,
|
||||
|
|
@ -428,14 +428,14 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
|
|||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllToAll,
|
||||
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/2,
|
||||
/*network_througput_bytes=*/2048,
|
||||
},
|
||||
{
|
||||
/*opcode=*/HloOpcode::kAllToAll,
|
||||
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
/*comm=*/GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/2,
|
||||
/*network_througput_bytes=*/4096,
|
||||
|
|
@ -459,7 +459,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -471,7 +471,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/8,
|
||||
},
|
||||
|
|
@ -483,7 +483,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/4 * 1024,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -495,7 +495,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/3,
|
||||
},
|
||||
|
|
@ -507,7 +507,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024 + 256,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -519,7 +519,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -531,7 +531,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/8,
|
||||
},
|
||||
|
|
@ -543,7 +543,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/4 * 1024,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -555,7 +555,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/3,
|
||||
},
|
||||
|
|
@ -567,7 +567,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllReduce,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024 + 256,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -639,7 +639,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -651,7 +651,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/8,
|
||||
},
|
||||
|
|
@ -663,7 +663,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/4 * 1024,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -675,7 +675,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/3,
|
||||
},
|
||||
|
|
@ -687,7 +687,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024 + 256,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -699,7 +699,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -711,7 +711,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/8,
|
||||
},
|
||||
|
|
@ -723,7 +723,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/4 * 1024,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -735,7 +735,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1032,
|
||||
/*num_nodes=*/3,
|
||||
},
|
||||
|
|
@ -747,7 +747,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kReduceScatter,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024 + 256,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -807,7 +807,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -819,7 +819,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/8,
|
||||
},
|
||||
|
|
@ -831,7 +831,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/4 * 1024,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -843,7 +843,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1056,
|
||||
/*num_nodes=*/3,
|
||||
},
|
||||
|
|
@ -855,7 +855,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024 + 256,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -867,7 +867,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -879,7 +879,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/8,
|
||||
},
|
||||
|
|
@ -891,7 +891,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/4 * 1024,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -903,7 +903,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1032,
|
||||
/*num_nodes=*/3,
|
||||
},
|
||||
|
|
@ -915,7 +915,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllGather,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024 + 256,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -986,7 +986,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllToAll,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
@ -997,7 +997,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
{
|
||||
/*opcode=*/HloOpcode::kAllToAll,
|
||||
/*comm=*/
|
||||
GPUCommunicationType::NON_RAIL_ALIGNED,
|
||||
GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL,
|
||||
/*tensor_size=*/1024,
|
||||
/*num_nodes=*/2,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -204,13 +204,13 @@ absl::StatusOr<absl::Duration> DispatchEstimation(
|
|||
GetReplicaGroupCountAndSize(&instr));
|
||||
|
||||
switch (comm) {
|
||||
case GPUCommunicationType::RAIL_ALIGNED: {
|
||||
case GPUCommunicationType::MULTI_HOST_WORLD_LEVEL: {
|
||||
return DCNCollectiveDuration(
|
||||
num_groups_and_devices->second / sol_flags.gpus_per_node,
|
||||
/*num_communicators=*/num_groups_and_devices->first, instr,
|
||||
gpu_device_info, sol_flags, analysis, symbolic_expr_context);
|
||||
}
|
||||
case GPUCommunicationType::NON_RAIL_ALIGNED: {
|
||||
case GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL: {
|
||||
return DCNCollectiveDuration(
|
||||
num_groups_and_devices->second,
|
||||
/*num_communicators=*/num_groups_and_devices->first, instr,
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ bool IsSingleHost(const CommunicationMetadata& pattern) {
|
|||
pattern.replica_count <= pattern.num_devices_per_host;
|
||||
}
|
||||
|
||||
bool IsRailAligned(const CommunicationMetadata& pattern) {
|
||||
bool IsWorldLevelCommunication(const CommunicationMetadata& pattern) {
|
||||
if (!IsSingleHost(pattern) && pattern.node_to_participant_count.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
|
@ -126,8 +126,8 @@ bool IsRailAligned(const CommunicationMetadata& pattern) {
|
|||
});
|
||||
}
|
||||
|
||||
bool IsNonRailAligned(const CommunicationMetadata& pattern) {
|
||||
return !IsSingleHost(pattern) && !IsRailAligned(pattern);
|
||||
bool IsNonWorldLevelCommunication(const CommunicationMetadata& pattern) {
|
||||
return !IsSingleHost(pattern) && !IsWorldLevelCommunication(pattern);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
@ -152,11 +152,11 @@ absl::StatusOr<GPUCommunicationType> CommunicationType(
|
|||
if (IsSingleHost(comm)) {
|
||||
return GPUCommunicationType::SINGLE_HOST;
|
||||
}
|
||||
if (IsRailAligned(comm)) {
|
||||
return GPUCommunicationType::RAIL_ALIGNED;
|
||||
if (IsWorldLevelCommunication(comm)) {
|
||||
return GPUCommunicationType::MULTI_HOST_WORLD_LEVEL;
|
||||
}
|
||||
if (IsNonRailAligned(comm)) {
|
||||
return GPUCommunicationType::NON_RAIL_ALIGNED;
|
||||
if (IsNonWorldLevelCommunication(comm)) {
|
||||
return GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL;
|
||||
}
|
||||
|
||||
return GPUCommunicationType::UNDEFINED;
|
||||
|
|
|
|||
|
|
@ -29,10 +29,10 @@ enum class GPUCommunicationType {
|
|||
UNDEFINED = 0,
|
||||
// Communication involves devices from multiple hosts, and every host
|
||||
// involved in the communication pattern has all of its devices participating.
|
||||
RAIL_ALIGNED = 1,
|
||||
MULTI_HOST_WORLD_LEVEL = 1,
|
||||
// Communication involves devices from multiple hosts, but at least one of
|
||||
// the involved hosts has only a subset of its devices participating.
|
||||
NON_RAIL_ALIGNED = 2,
|
||||
MULTI_HOST_NON_WORLD_LEVEL = 2,
|
||||
// All devices participating in the collective operation reside on the same
|
||||
// host machine.
|
||||
SINGLE_HOST = 3
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHost16Devices) {
|
|||
IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
|
||||
}
|
||||
|
||||
TEST_F(CommunicationTypeTest, DetectRailAlignedAllDevices) {
|
||||
TEST_F(CommunicationTypeTest, DetectWorldLevelAllDevices) {
|
||||
absl::string_view kHlo = R"(
|
||||
HloModule m, num_partitions=16
|
||||
|
||||
|
|
@ -131,10 +131,10 @@ TEST_F(CommunicationTypeTest, DetectRailAlignedAllDevices) {
|
|||
module->entry_computation()->root_instruction());
|
||||
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
|
||||
device_info().gpu_compute_capability()),
|
||||
IsOkAndHolds(GPUCommunicationType::RAIL_ALIGNED));
|
||||
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL));
|
||||
}
|
||||
|
||||
TEST_F(CommunicationTypeTest, DetectRailAlignedHalfMesh) {
|
||||
TEST_F(CommunicationTypeTest, DetectWorldLevelHalfMesh) {
|
||||
absl::string_view kHlo = R"(
|
||||
HloModule m, num_partitions=32
|
||||
|
||||
|
|
@ -157,10 +157,10 @@ TEST_F(CommunicationTypeTest, DetectRailAlignedHalfMesh) {
|
|||
module->entry_computation()->root_instruction());
|
||||
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
|
||||
device_info().gpu_compute_capability()),
|
||||
IsOkAndHolds(GPUCommunicationType::RAIL_ALIGNED));
|
||||
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL));
|
||||
}
|
||||
|
||||
TEST_F(CommunicationTypeTest, DetectNonRailAligned) {
|
||||
TEST_F(CommunicationTypeTest, DetectNonWorldLevel) {
|
||||
absl::string_view kHlo = R"(
|
||||
HloModule m, num_partitions=16
|
||||
|
||||
|
|
@ -180,7 +180,7 @@ TEST_F(CommunicationTypeTest, DetectNonRailAligned) {
|
|||
module->entry_computation()->root_instruction());
|
||||
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
|
||||
device_info().gpu_compute_capability()),
|
||||
IsOkAndHolds(GPUCommunicationType::NON_RAIL_ALIGNED));
|
||||
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL));
|
||||
}
|
||||
|
||||
TEST_F(CommunicationTypeTest, DetectsSingleHost16DevicesForEmptyReplicaGroups) {
|
||||
|
|
@ -204,7 +204,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHost16DevicesForEmptyReplicaGroups) {
|
|||
IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
|
||||
}
|
||||
|
||||
TEST_F(CommunicationTypeTest, DetectsRailAligned8DevicesForEmptyReplicaGroups) {
|
||||
TEST_F(CommunicationTypeTest, DetectWorldLevel8DevicesForEmptyReplicaGroups) {
|
||||
absl::string_view kHlo = R"(
|
||||
HloModule m, replica_count=16
|
||||
|
||||
|
|
@ -222,10 +222,10 @@ TEST_F(CommunicationTypeTest, DetectsRailAligned8DevicesForEmptyReplicaGroups) {
|
|||
module->entry_computation()->root_instruction());
|
||||
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
|
||||
device_info().gpu_compute_capability()),
|
||||
IsOkAndHolds(GPUCommunicationType::RAIL_ALIGNED));
|
||||
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL));
|
||||
}
|
||||
|
||||
TEST_F(CommunicationTypeTest, DetectsNonRailAligned16Devices) {
|
||||
TEST_F(CommunicationTypeTest, DetectNonWorldLevel16Devices) {
|
||||
absl::string_view kHlo = R"(
|
||||
HloModule m, replica_count=16
|
||||
|
||||
|
|
@ -243,7 +243,7 @@ TEST_F(CommunicationTypeTest, DetectsNonRailAligned16Devices) {
|
|||
module->entry_computation()->root_instruction());
|
||||
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
|
||||
device_info().gpu_compute_capability()),
|
||||
IsOkAndHolds(GPUCommunicationType::NON_RAIL_ALIGNED));
|
||||
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL));
|
||||
}
|
||||
|
||||
TEST_F(CommunicationTypeTest, DetectsSingleHostCollectivePermute) {
|
||||
|
|
@ -266,7 +266,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHostCollectivePermute) {
|
|||
IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
|
||||
}
|
||||
|
||||
TEST_F(CommunicationTypeTest, DetectsNonRailAlignedCollectivePermute) {
|
||||
TEST_F(CommunicationTypeTest, DetectNonWorldLevelCollectivePermute) {
|
||||
absl::string_view kHlo = R"(
|
||||
HloModule m, num_partitions=16
|
||||
|
||||
|
|
@ -284,10 +284,10 @@ TEST_F(CommunicationTypeTest, DetectsNonRailAlignedCollectivePermute) {
|
|||
module->entry_computation()->root_instruction());
|
||||
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
|
||||
device_info().gpu_compute_capability()),
|
||||
IsOkAndHolds(GPUCommunicationType::NON_RAIL_ALIGNED));
|
||||
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL));
|
||||
}
|
||||
|
||||
TEST_F(CommunicationTypeTest, DetectsRailAlignedCollectivePermute) {
|
||||
TEST_F(CommunicationTypeTest, DetectWorldLevelCollectivePermute) {
|
||||
absl::string_view kHlo = R"(
|
||||
HloModule m, num_partitions=16
|
||||
|
||||
|
|
@ -304,7 +304,7 @@ TEST_F(CommunicationTypeTest, DetectsRailAlignedCollectivePermute) {
|
|||
module->entry_computation()->root_instruction());
|
||||
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
|
||||
device_info().gpu_compute_capability()),
|
||||
IsOkAndHolds(GPUCommunicationType::RAIL_ALIGNED));
|
||||
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user