PR #32836: [GPU] Dispatch S-curve model to single-partition multi-host topology

Imported from GitHub PR https://github.com/openxla/xla/pull/32836

📝 Summary of Changes
Updated the SINGLE_HOST communication type to SINGLE_PARTITION (a fast-interconnect domain) to support multi-node NVLink (MNNVL) topologies. Piped the auto-detected partition size into communication-type determination, and exposed the partition size in SolGPUCostModel::Config for AOT compilation.

🎯 Justification
The S-curve model cannot capture NVLink latency, so collectives confined to a single fast-interconnect domain, including MNNVL topologies, should use the latency-table model instead. This PR updates the routing mechanism so that an MNNVL domain is treated as a single partition; previously a host was assumed to be equivalent to a partition.
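Conceptually, the estimator now classifies a collective by partitions rather than hosts before choosing a model: each participating device maps to a partition by integer division with the partition size, and a collective that stays within one partition is routed to the latency-table interpolator, while cross-partition collectives keep using the analytical (S-curve/DCN) path. A minimal sketch of that classification, using illustrative names rather than the exact XLA signatures (the real logic lives in CommunicationType() and DispatchEstimation() in the diff below):

    // Sketch only; names are illustrative, not the XLA API.
    #include <cstdint>
    #include <set>
    #include <vector>

    enum class CostModel { kLatencyTable, kAnalyticalSCurve };

    CostModel PickCostModel(const std::vector<int64_t>& participating_devices,
                            int64_t partition_size) {
      std::set<int64_t> partitions;
      for (int64_t device : participating_devices) {
        // Map each device id to its fast-interconnect domain.
        partitions.insert(device / partition_size);
      }
      // One partition (e.g. a single MNNVL/NVLink domain): use the measured
      // latency tables, since the S-curve model does not capture NVLink latency.
      return partitions.size() == 1 ? CostModel::kLatencyTable
                                    : CostModel::kAnalyticalSCurve;
    }

With partition_size equal to the host size, this reduces to the previous per-host behavior, which is why non-MNNVL topologies are unaffected.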

🚀 Kind of Contribution
 New Feature

📊 Benchmark (for Performance Improvements)
N/A

🧪 Unit Tests:
Added unit tests for the model-dispatching mechanism.

🧪 Execution Tests:
Behavior is unchanged for non-MNNVL topologies; N/A.

Copybara import of the project:

--
a9544375934873f7b888fdb5ff6c9dc6ee8b0e6c by Terry Sun <tesun@nvidia.com>:

use partition size for static model dispatching

--
e3445a5deb8da10146e90c50da5598f91cfe0a69 by Terry Sun <tesun@nvidia.com>:

expose partition size to config

--
212535ce891b8eb96ebb3c1e215a91d2b5035594 by Terry Sun <tesun@nvidia.com>:

better modularity

--
a9fe8a0f89dea9e2811d76a3570c7398df8dd756 by Terry Sun <tesun@nvidia.com>:

better code structure and doc string

--
a64a2b5ed1d45d815c6a2c47628b4d9ebb8368bd by Terry Sun <tesun@nvidia.com>:

update naming

Merging this change closes #32836

PiperOrigin-RevId: 826697791
Author: Terry Sun, 2025-10-31 18:19:29 -07:00 (committed by TensorFlower Gardener)
Parent: dad4fb74cd
Commit: 8134117476
13 changed files with 272 additions and 98 deletions


@ -70,6 +70,11 @@ constexpr char kSolChunkSizeBytes[] = "chunk_size_bytes";
// cost model.
constexpr char kSolGpusPerNode[] = "gpus_per_node";
// Defines the partition size (number of devices per fast-interconnect domain)
// used by the SoL cost model. This is necessary for AOT compilation when the
// partition is larger than a node.
constexpr char kSolPartitionSize[] = "partition_size";
} // namespace xla
#endif // XLA_SERVICE_COLLECTIVE_UTILS_H_
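For AOT compilation, where the partition cannot be auto-detected from the runtime topology, the new key can be supplied through the same cost-model options string that already carries gpus_per_node and chunk_size_bytes. A hedged example, assuming the colon-delimited key/value pairs of XLA's analytical-latency-estimator options flag (verify the exact flag name and delimiter against your build):

    --xla_gpu_analytical_latency_estimator_options="gpus_per_node:4,partition_size:72"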


@ -2012,6 +2012,11 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
absl::StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
const CompileOptions& options) {
// TODO rename slice_size to partition_size in CompileOptions
if (options.slice_size > 0) {
module->mutable_config().set_partition_size(options.slice_size);
}
const DebugOptions debug_opts = module->config().debug_options();
TF_RETURN_IF_ERROR(LoadAutotuneResultsFromFile(debug_opts));
bool is_deviceless = options.target_config.has_value() ||
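For ahead-of-time compiles the value arrives through CompileOptions. A minimal sketch under the assumption that the slice_size field referenced in this hunk is set by the caller (illustrative value; the TODO above notes it will eventually be renamed to partition_size):

    // Sketch only: the caller sets slice_size, and the RunHloPasses change above
    // copies it into the module config, where GetPartitionSize() later reads it.
    Compiler::CompileOptions options;
    options.slice_size = 72;  // e.g. one NVL72-style fast-interconnect domain
    // options is then passed to GpuCompiler::RunHloPasses(std::move(module), stream_exec, options);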


@ -148,7 +148,7 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
int num_hosts) {
IotaReplicaGroupList iota(1, 1);
switch (comm) {
case GPUCommunicationType::SINGLE_HOST:
case GPUCommunicationType::SINGLE_PARTITION:
iota = IotaReplicaGroupList(num_hosts, kNumGpusPerHost);
break;
case GPUCommunicationType::MULTI_HOST_WORLD_LEVEL:
@ -225,28 +225,28 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
},
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/GPUCommunicationType::SINGLE_HOST,
/*comm=*/GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/2048,
},
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/GPUCommunicationType::SINGLE_HOST,
/*comm=*/GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/2 * 2048,
},
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/GPUCommunicationType::SINGLE_HOST,
/*comm=*/GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/2048,
},
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/GPUCommunicationType::SINGLE_HOST,
/*comm=*/GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/2 * 2048,
@ -309,28 +309,28 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
},
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/GPUCommunicationType::SINGLE_HOST,
/*comm=*/GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/2048,
},
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/GPUCommunicationType::SINGLE_HOST,
/*comm=*/GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/2 * 2048,
},
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/GPUCommunicationType::SINGLE_HOST,
/*comm=*/GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/2048,
},
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/GPUCommunicationType::SINGLE_HOST,
/*comm=*/GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/2 * 2048,
@ -393,35 +393,35 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
},
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/GPUCommunicationType::SINGLE_HOST,
/*comm=*/GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/2048,
},
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/GPUCommunicationType::SINGLE_HOST,
/*comm=*/GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/2,
/*network_througput_bytes=*/2 * 2048,
},
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/GPUCommunicationType::SINGLE_HOST,
/*comm=*/GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/2048,
},
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/GPUCommunicationType::SINGLE_HOST,
/*comm=*/GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/2 * 1024,
/*num_nodes=*/4,
/*network_througput_bytes=*/2 * 2048,
},
{
/*opcode=*/HloOpcode::kAllToAll,
/*comm=*/GPUCommunicationType::SINGLE_HOST,
/*comm=*/GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024,
/*num_nodes=*/1,
/*network_througput_bytes=*/1024,
@ -574,60 +574,61 @@ INSTANTIATE_TEST_SUITE_P(
/*expected_duration=*/absl::Milliseconds(2500),
},
{
/*test_name=*/"AR_single_host_aligned_extrapolate_nodes",
/*test_name=*/"AR_SINGLE_PARTITION_aligned_extrapolate_nodes",
/*spec=*/
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/
GPUCommunicationType::SINGLE_HOST,
GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024,
/*num_nodes=*/8,
},
/*expected_duration=*/absl::Milliseconds(500),
},
{
/*test_name=*/"AR_single_host_aligned_extrapolate_tensor_size",
/*test_name=*/"AR_SINGLE_PARTITION_aligned_extrapolate_tensor_size",
/*spec=*/
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/
GPUCommunicationType::SINGLE_HOST,
GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/4 * 1024,
/*num_nodes=*/2,
},
/*expected_duration=*/absl::Seconds(1),
},
{
/*test_name=*/"AR_single_host_aligned_interpolate_nodes",
/*test_name=*/"AR_SINGLE_PARTITION_aligned_interpolate_nodes",
/*spec=*/
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/
GPUCommunicationType::SINGLE_HOST,
GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024,
/*num_nodes=*/3,
},
/*expected_duration=*/absl::Milliseconds(500),
},
{
/*test_name=*/"AR_single_host_aligned_interpolate_tensor_size",
/*test_name=*/"AR_SINGLE_PARTITION_aligned_interpolate_tensor_size",
/*spec=*/
{
/*opcode=*/HloOpcode::kAllReduce,
/*comm=*/
GPUCommunicationType::SINGLE_HOST,
GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024 + 256,
/*num_nodes=*/2,
},
/*expected_duration=*/absl::Milliseconds(625),
},
{
/*test_name=*/"ARS_single_host_aligned_interpolate_tensor_size",
/*test_name=*/"ARS_SINGLE_PARTITION_aligned_interpolate_tensor_"
"size",
/*spec=*/
{
/*opcode=*/HloOpcode::kAllReduceStart,
/*comm=*/
GPUCommunicationType::SINGLE_HOST,
GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024 + 256,
/*num_nodes=*/2,
},
@ -754,48 +755,48 @@ INSTANTIATE_TEST_SUITE_P(
/*expected_duration=*/absl::Milliseconds(2500),
},
{
/*test_name=*/"RS_single_host_aligned_extrapolate_nodes",
/*test_name=*/"RS_SINGLE_PARTITION_aligned_extrapolate_nodes",
/*spec=*/
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/
GPUCommunicationType::SINGLE_HOST,
GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024,
/*num_nodes=*/8,
},
/*expected_duration=*/absl::Milliseconds(500),
},
{
/*test_name=*/"RS_single_host_aligned_extrapolate_tensor_size",
/*test_name=*/"RS_SINGLE_PARTITION_aligned_extrapolate_tensor_size",
/*spec=*/
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/
GPUCommunicationType::SINGLE_HOST,
GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/4 * 1024,
/*num_nodes=*/2,
},
/*expected_duration=*/absl::Seconds(1),
},
{
/*test_name=*/"RS_single_host_aligned_interpolate_nodes",
/*test_name=*/"RS_SINGLE_PARTITION_aligned_interpolate_nodes",
/*spec=*/
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/
GPUCommunicationType::SINGLE_HOST,
GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024,
/*num_nodes=*/3,
},
/*expected_duration=*/absl::Milliseconds(500),
},
{
/*test_name=*/"RS_single_host_aligned_interpolate_tensor_size",
/*test_name=*/"RS_SINGLE_PARTITION_aligned_interpolate_tensor_size",
/*spec=*/
{
/*opcode=*/HloOpcode::kReduceScatter,
/*comm=*/
GPUCommunicationType::SINGLE_HOST,
GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024 + 256,
/*num_nodes=*/2,
},
@ -922,60 +923,61 @@ INSTANTIATE_TEST_SUITE_P(
/*expected_duration=*/absl::Milliseconds(2500),
},
{
/*test_name=*/"AG_single_host_aligned_extrapolate_nodes",
/*test_name=*/"AG_SINGLE_PARTITION_aligned_extrapolate_nodes",
/*spec=*/
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/
GPUCommunicationType::SINGLE_HOST,
GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024,
/*num_nodes=*/8,
},
/*expected_duration=*/absl::Milliseconds(500),
},
{
/*test_name=*/"AG_single_host_aligned_extrapolate_tensor_size",
/*test_name=*/"AG_SINGLE_PARTITION_aligned_extrapolate_tensor_size",
/*spec=*/
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/
GPUCommunicationType::SINGLE_HOST,
GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/4 * 1024,
/*num_nodes=*/2,
},
/*expected_duration=*/absl::Seconds(1),
},
{
/*test_name=*/"AG_single_host_aligned_interpolate_nodes",
/*test_name=*/"AG_SINGLE_PARTITION_aligned_interpolate_nodes",
/*spec=*/
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/
GPUCommunicationType::SINGLE_HOST,
GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024,
/*num_nodes=*/3,
},
/*expected_duration=*/absl::Milliseconds(500),
},
{
/*test_name=*/"AG_single_host_aligned_interpolate_tensor_size",
/*test_name=*/"AG_SINGLE_PARTITION_aligned_interpolate_tensor_size",
/*spec=*/
{
/*opcode=*/HloOpcode::kAllGather,
/*comm=*/
GPUCommunicationType::SINGLE_HOST,
GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024 + 256,
/*num_nodes=*/2,
},
/*expected_duration=*/absl::Milliseconds(625),
},
{
/*test_name=*/"AGS_single_host_aligned_interpolate_tensor_size",
/*test_name=*/"AGS_SINGLE_PARTITION_aligned_interpolate_tensor_"
"size",
/*spec=*/
{
/*opcode=*/HloOpcode::kAllGatherStart,
/*comm=*/
GPUCommunicationType::SINGLE_HOST,
GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024 + 256,
/*num_nodes=*/2,
},
@ -1004,11 +1006,11 @@ INSTANTIATE_TEST_SUITE_P(
/*expected_duration=*/absl::Milliseconds(250),
},
{
/*test_name=*/"A2A_single_host_exact_match",
/*test_name=*/"A2A_SINGLE_PARTITION_exact_match",
{
/*opcode=*/HloOpcode::kAllToAll,
/*comm=*/
GPUCommunicationType::SINGLE_HOST,
GPUCommunicationType::SINGLE_PARTITION,
/*tensor_size=*/1024,
/*num_nodes=*/1,
},


@ -138,6 +138,9 @@ SolGPUCostModel::Config GetPlatformConfig(
} else if (option_name == kSolChunkSizeBytes &&
absl::SimpleAtoi(option_value, &value) && value > 0) {
config.chunk_size_bytes = value;
} else if (option_name == kSolPartitionSize &&
absl::SimpleAtoi(option_value, &value) && value > 0) {
config.partition_size = value;
}
}
return config;


@ -40,6 +40,8 @@ class SolGPUCostModel {
absl::Duration rtt;
int64_t gpus_per_node;
int64_t chunk_size_bytes;
// Partition size (devices per fast-interconnect domain). 0 means unset.
int64_t partition_size;
};
enum CollectiveAlgorithmType {


@ -189,6 +189,17 @@ absl::StatusOr<absl::Duration> DCNCollectiveDuration(
return result;
}
int64_t GetPartitionSize(const HloInstruction& instr,
const SolGPUCostModel::Config& sol_flags) {
if (sol_flags.partition_size > 0) {
return sol_flags.partition_size;
}
if (instr.GetModule()->config().partition_size() > 0) {
return instr.GetModule()->config().partition_size();
}
return sol_flags.gpus_per_node;
}
absl::StatusOr<absl::Duration> DispatchEstimation(
const absl::StatusOr<GPUCommunicationType>& communication_type,
const HloCollectiveInstruction& instr,
@ -202,11 +213,12 @@ absl::StatusOr<absl::Duration> DispatchEstimation(
GPUCommunicationType comm = *communication_type;
TF_ASSIGN_OR_RETURN(auto num_groups_and_devices,
GetReplicaGroupCountAndSize(&instr));
int64_t partition_size = GetPartitionSize(instr, sol_flags);
switch (comm) {
case GPUCommunicationType::MULTI_HOST_WORLD_LEVEL: {
return DCNCollectiveDuration(
num_groups_and_devices->second / sol_flags.gpus_per_node,
num_groups_and_devices->second / partition_size,
/*num_communicators=*/num_groups_and_devices->first, instr,
gpu_device_info, sol_flags, analysis, symbolic_expr_context);
}
@ -216,10 +228,11 @@ absl::StatusOr<absl::Duration> DispatchEstimation(
/*num_communicators=*/num_groups_and_devices->first, instr,
gpu_device_info, sol_flags, analysis, symbolic_expr_context);
}
case GPUCommunicationType::SINGLE_HOST: {
case GPUCommunicationType::SINGLE_PARTITION: {
if (collective_interpolator == nullptr) {
return absl::InvalidArgumentError(
"Collective interpolator is required for single host collectives");
"Collective interpolator is required for single partition "
"collectives");
}
return collective_interpolator->EstimatedRuntime(instr);
}
@ -309,9 +322,10 @@ SolLatencyEstimator::ComputeCollectiveTime(
absl::StrCat("Unsupported collective instruction: ", instr.ToString()));
}
int64_t partition_size = GetPartitionSize(*collective_instr, sol_flags);
TF_ASSIGN_OR_RETURN(
GPUCommunicationType communication_type,
CommunicationType(sol_flags.gpus_per_node, *collective_instr,
CommunicationType(partition_size, *collective_instr,
gpu_device_info.gpu_compute_capability()));
TF_ASSIGN_OR_RETURN(
absl::Duration result,


@ -99,7 +99,7 @@ absl::StatusOr<bool> CollectiveBackendAssigner::Run(
<< " slice_size_=" << slice_size_;
bool use_nvshmem =
(num_visible_devices_per_process_ == 1 ||
comm_type == GPUCommunicationType::SINGLE_HOST ||
comm_type == GPUCommunicationType::SINGLE_PARTITION ||
(slice_size_ > 0 &&
IsIntraNVLinkDomain(module->config(), slice_size_))) &&
(!IsAllReduceOp(instr) || shape_size < threshold_in_bytes_);


@ -44,26 +44,26 @@ namespace xla {
namespace gpu {
namespace {
// Computes a map from source node ID to a set of target node IDs for a
// collective-permute instruction. A node ID is computed by dividing the device
// (replica) ID by the number of devices per host.
// Computes a map from source partition ID to a set of target partition IDs for
// a collective-permute instruction. A partition ID is computed by dividing the
// device (replica) ID by the number of devices per host.
absl::flat_hash_map<int64_t, absl::flat_hash_set<int64_t>>
GetSourceToTargetsNodeMap(const HloCollectivePermuteInstruction& instr,
int num_devices_per_host) {
int num_devices_per_partition) {
absl::flat_hash_map<int64_t, absl::flat_hash_set<int64_t>>
source_to_targets_node_map;
source_to_targets_partition_map;
for (const auto& [source, target] : instr.source_target_pairs()) {
int64_t source_node = source / num_devices_per_host;
int64_t target_node = target / num_devices_per_host;
source_to_targets_node_map[source_node].insert(target_node);
int64_t source_partition = source / num_devices_per_partition;
int64_t target_partition = target / num_devices_per_partition;
source_to_targets_partition_map[source_partition].insert(target_partition);
}
return source_to_targets_node_map;
return source_to_targets_partition_map;
}
struct CollectiveMetadata {
// map for ops with `replica_groups`, e.g. all-gather.
absl::flat_hash_map<int64_t, size_t> node_to_participant_count;
int num_devices_per_host;
absl::flat_hash_map<int64_t, size_t> partition_to_participant_count;
int num_devices_per_partition;
int64_t replica_count;
};
@ -85,46 +85,48 @@ bool SameParticipantCounts(const absl::flat_hash_map<int64_t, size_t>& lhs,
}
absl::StatusOr<CollectiveMetadata> CommunicationContext(
const HloCollectiveInstruction& instr, int num_devices_per_host) {
absl::flat_hash_map<int64_t, size_t> node_to_participant_count;
const HloCollectiveInstruction& instr, int num_devices_per_partition) {
absl::flat_hash_map<int64_t, size_t> partition_to_participant_count;
for (const ReplicaGroup& replica_group :
instr.device_list().replica_groups()) {
absl::flat_hash_map<int64_t, size_t> buffer;
for (int64_t rank : replica_group.replica_ids()) {
int64_t node_id = rank / num_devices_per_host;
buffer[node_id]++;
int64_t partition_id = rank / num_devices_per_partition;
buffer[partition_id]++;
}
if (!node_to_participant_count.empty() &&
!SameParticipantCounts(buffer, node_to_participant_count)) {
if (!partition_to_participant_count.empty() &&
!SameParticipantCounts(buffer, partition_to_participant_count)) {
return absl::FailedPreconditionError(absl::StrCat(
"Non homogenous replica group: ", instr.device_list().ToString()));
}
if (node_to_participant_count.empty()) {
node_to_participant_count = buffer;
if (partition_to_participant_count.empty()) {
partition_to_participant_count = buffer;
}
}
return CollectiveMetadata{node_to_participant_count, num_devices_per_host,
return CollectiveMetadata{partition_to_participant_count,
num_devices_per_partition,
instr.GetModule()->config().replica_count()};
}
bool IsSingleHost(const CollectiveMetadata& pattern) {
if (pattern.node_to_participant_count.size() == 1) {
if (pattern.partition_to_participant_count.size() == 1) {
return true;
}
return pattern.replica_count > 0 &&
pattern.node_to_participant_count.empty() &&
pattern.replica_count <= pattern.num_devices_per_host;
pattern.partition_to_participant_count.empty() &&
pattern.replica_count <= pattern.num_devices_per_partition;
}
bool IsWorldLevelCommunication(const CollectiveMetadata& pattern) {
if (!IsSingleHost(pattern) && pattern.node_to_participant_count.empty()) {
if (!IsSingleHost(pattern) &&
pattern.partition_to_participant_count.empty()) {
return true;
}
return absl::c_all_of(
pattern.node_to_participant_count, [&pattern](const auto& elem) {
const auto& [node_id, participant_count] = elem;
return participant_count == pattern.num_devices_per_host;
pattern.partition_to_participant_count, [&pattern](const auto& elem) {
const auto& [partition_id, participant_count] = elem;
return participant_count == pattern.num_devices_per_partition;
});
}
@ -143,7 +145,7 @@ bool IsGPUSyncCollective(const HloInstruction& instr) {
}
absl::StatusOr<GPUCommunicationType> CommunicationType(
int num_devices_per_host, const HloChannelInstruction& instr,
int num_devices_per_partition, const HloChannelInstruction& instr,
const se::GpuComputeCapability& gpu_version) {
if (!gpu_version.IsCuda()) {
return absl::FailedPreconditionError("Only CUDA is supported.");
@ -152,9 +154,9 @@ absl::StatusOr<GPUCommunicationType> CommunicationType(
if (const auto* collective = DynCast<HloCollectiveInstruction>(&instr)) {
TF_ASSIGN_OR_RETURN(
CollectiveMetadata comm,
CommunicationContext(*collective, num_devices_per_host));
CommunicationContext(*collective, num_devices_per_partition));
if (IsSingleHost(comm)) {
return GPUCommunicationType::SINGLE_HOST;
return GPUCommunicationType::SINGLE_PARTITION;
}
if (IsWorldLevelCommunication(comm)) {
return GPUCommunicationType::MULTI_HOST_WORLD_LEVEL;
@ -164,19 +166,19 @@ absl::StatusOr<GPUCommunicationType> CommunicationType(
}
} else if (const auto* collective_permute =
DynCast<HloCollectivePermuteInstruction>(&instr)) {
const auto source_to_targets_node_map =
GetSourceToTargetsNodeMap(*collective_permute, num_devices_per_host);
for (const auto& [source_node, target_node_set] :
source_to_targets_node_map) {
if (target_node_set.size() > 1) {
const auto source_to_targets_partition_map = GetSourceToTargetsNodeMap(
*collective_permute, num_devices_per_partition);
for (const auto& [source_partition, target_partition_set] :
source_to_targets_partition_map) {
if (target_partition_set.size() > 1) {
return GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL;
}
CHECK_EQ(target_node_set.size(), 1);
if (source_node != *target_node_set.begin()) {
CHECK_EQ(target_partition_set.size(), 1);
if (source_partition != *target_partition_set.begin()) {
return GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL;
}
}
return GPUCommunicationType::SINGLE_HOST;
return GPUCommunicationType::SINGLE_PARTITION;
} else {
return absl::FailedPreconditionError(
"Cannot determine communication type for non-collective channel "


@ -35,13 +35,13 @@ enum class GPUCommunicationType {
// the involved hosts has only a subset of its devices participating.
MULTI_HOST_NON_WORLD_LEVEL = 2,
// All devices participating in the collective operation reside on the same
// host machine.
SINGLE_HOST = 3
// fast-interconnect domain.
SINGLE_PARTITION = 3
};
// Returns the type of communication pattern for a channel instruction.
absl::StatusOr<GPUCommunicationType> CommunicationType(
int num_devices_per_host, const HloChannelInstruction& instr,
int partition_size, const HloChannelInstruction& instr,
const se::GpuComputeCapability& gpu_version);
// Returns true if instruction is a synchronous collective op.


@ -62,7 +62,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHost8Devices) {
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
IsOkAndHolds(GPUCommunicationType::SINGLE_PARTITION));
}
TEST_F(CommunicationTypeTest, DetectsSingleHost4Devices) {
@ -85,7 +85,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHost4Devices) {
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
IsOkAndHolds(GPUCommunicationType::SINGLE_PARTITION));
}
TEST_F(CommunicationTypeTest, DetectsSingleHost16Devices) {
@ -106,9 +106,9 @@ TEST_F(CommunicationTypeTest, DetectsSingleHost16Devices) {
HloCollectiveInstruction* instr = Cast<HloCollectiveInstruction>(
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
EXPECT_THAT(CommunicationType(/*partition_size=*/8, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
IsOkAndHolds(GPUCommunicationType::SINGLE_PARTITION));
}
TEST_F(CommunicationTypeTest, DetectWorldLevelAllDevices) {
@ -201,7 +201,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHost16DevicesForEmptyReplicaGroups) {
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/16, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
IsOkAndHolds(GPUCommunicationType::SINGLE_PARTITION));
}
TEST_F(CommunicationTypeTest, DetectWorldLevel8DevicesForEmptyReplicaGroups) {
@ -263,7 +263,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHostCollectivePermute) {
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
IsOkAndHolds(GPUCommunicationType::SINGLE_PARTITION));
}
TEST_F(CommunicationTypeTest, DetectsSingleHostCollectivePermuteSinglePair) {
@ -283,7 +283,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHostCollectivePermuteSinglePair) {
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
IsOkAndHolds(GPUCommunicationType::SINGLE_PARTITION));
}
TEST_F(CommunicationTypeTest, DetectNonWorldLevelCollectivePermute) {
@ -355,5 +355,130 @@ TEST_F(CommunicationTypeTest, DetectsCrossHostCollectivePermuteMixed) {
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL));
}
TEST_F(CommunicationTypeTest, DetectsSinglePartitionMultiHost) {
// 16 devices across 2 hosts with partition_size=16 (single partition spanning
// 2 hosts)
absl::string_view kHlo = R"(
HloModule m, num_partitions=16
ENTRY e {
p = f32[128] parameter(0)
ROOT _ = f32[2048] all-gather(p),
dimensions={0},
use_global_device_ids=true,
channel_id=1,
replica_groups=[1,16]<=[16]
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
HloCollectiveInstruction* instr = Cast<HloCollectiveInstruction>(
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*partition_size=*/16, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::SINGLE_PARTITION));
}
TEST_F(CommunicationTypeTest, DetectsMultiPartitionWith8DevicePartitions) {
// 64 devices across 2 partitions with partition_size=32
absl::string_view kHlo = R"(
HloModule m, num_partitions=16
ENTRY e {
p = f32[128] parameter(0)
ROOT _ = f32[2048] all-gather(p),
dimensions={0},
use_global_device_ids=true,
channel_id=1,
replica_groups=[1, 64]<=[64]
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
HloCollectiveInstruction* instr = Cast<HloCollectiveInstruction>(
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*partition_size=*/32, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL));
}
TEST_F(CommunicationTypeTest, DetectsMultiPartitionNonRailAligned) {
// 64 devices with partition_size=36: partition 0 has 36 devices, partition 1
// has 28 devices
absl::string_view kHlo = R"(
HloModule m, num_partitions=12
ENTRY e {
p = f32[128] parameter(0)
ROOT _ = f32[1536] all-gather(p),
dimensions={0},
use_global_device_ids=true,
channel_id=1,
replica_groups=[1, 64]<=[64]
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
HloCollectiveInstruction* instr = Cast<HloCollectiveInstruction>(
module->entry_computation()->root_instruction());
// With partition_size=8, spans 2 partitions but not rail-aligned (8 and 4
// devices)
EXPECT_THAT(CommunicationType(/*partition_size=*/36, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL));
}
TEST_F(CommunicationTypeTest, DetectsSinglePartitionSubset) {
// 6 devices within a single partition (partition_size=36)
absl::string_view kHlo = R"(
HloModule m, num_partitions=4
ENTRY e {
p = f32[128] parameter(0)
ROOT _ = f32[512] all-gather(p),
dimensions={0},
use_global_device_ids=true,
channel_id=1,
replica_groups={{0,1,2,3,4,5}}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
HloCollectiveInstruction* instr = Cast<HloCollectiveInstruction>(
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*partition_size=*/36, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::SINGLE_PARTITION));
}
TEST_F(CommunicationTypeTest, DetectsRailAlignedMultiPartition) {
// 128 devices across 2 partitions with partition_size=8 (rail-aligned: 64
// devices per partition)
absl::string_view kHlo = R"(
HloModule m, num_partitions=32
ENTRY e {
p = f32[128] parameter(0)
ROOT _ = f32[4096] all-gather(p),
dimensions={0},
use_global_device_ids=true,
channel_id=1,
replica_groups=[1,128]<=[128]
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
HloCollectiveInstruction* instr = Cast<HloCollectiveInstruction>(
module->entry_computation()->root_instruction());
EXPECT_THAT(CommunicationType(/*partition_size=*/64, *instr,
device_info().gpu_compute_capability()),
IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL));
}
} // namespace
} // namespace xla::gpu


@ -110,6 +110,9 @@ std::string HloModuleConfig::compilation_cache_key() const {
StrAppend(&key, "::device_memory_size=", device_memory_size());
}
StrAppend(&key, "::use_shardy_partitioner=", use_shardy_partitioner());
if (partition_size() != 0) {
StrAppend(&key, "::partition_size=", partition_size());
}
return key;
}
@ -339,6 +342,7 @@ HloModuleConfigProto HloModuleConfig::ToProto() const {
proto.set_fdo_profile(fdo_profile_);
proto.set_device_memory_size(device_memory_size_);
proto.set_use_shardy_partitioner(use_shardy_partitioner_);
proto.set_partition_size(partition_size_);
*proto.mutable_sharding_config() = ShardingConfig::ToProto(sharding_config_);
*proto.mutable_schedule_config() = ScheduleConfig::ToProto(schedule_config_);
return proto;
@ -418,6 +422,7 @@ HloModuleConfig::CreateFromProto(const HloModuleConfigProto& proto) {
config->fdo_profile_ = proto.fdo_profile();
config->device_memory_size_ = proto.device_memory_size();
config->use_shardy_partitioner_ = proto.use_shardy_partitioner();
config->partition_size_ = proto.partition_size();
config->sharding_config_ = ShardingConfig::FromProto(proto.sharding_config());
config->schedule_config_ = ScheduleConfig::FromProto(proto.schedule_config());
return std::move(config);


@ -451,6 +451,12 @@ class HloModuleConfig {
use_shardy_partitioner_ = use_shardy_partitioner;
}
// Number of devices in a fast-interconnect domain.
int64_t partition_size() const { return partition_size_; }
void set_partition_size(int64_t partition_size) {
partition_size_ = partition_size;
}
// Do channel IDs in this module carry semantic information.
bool ChannelIdSensitive() const {
// TODO(b/430952564): Base this on num_partitions / num_replicas instead
@ -625,6 +631,9 @@ class HloModuleConfig {
bool use_shardy_partitioner_ = false;
// Number of devices in a fast-interconnect domain.
int64_t partition_size_ = 0;
// Sharding configuration, where sharding_config_.nodes[v] controls the
// sharding of operation v.
ShardingConfig sharding_config_;


@ -1606,7 +1606,7 @@ message ExecutionOptions {
// Serialization of HloModuleConfig. See the C++ class definition for
// descriptions of each field.
// There are no guarantees of backwards or forwards compatibility.
// Next id: 42.
// Next id: 43.
message HloModuleConfigProto {
enum FusionConfigCollection {
OFF = 0; // Do not collect configuration.
@ -1671,6 +1671,8 @@ message HloModuleConfigProto {
bool use_shardy_partitioner = 34;
ShardingConfigProto sharding_config = 38;
ScheduleConfigProto schedule_config = 41;
// Number of devices in a fast-interconnect domain.
int64 partition_size = 42;
}
message HloModuleProtoWithConfig {