PR #31375: [XLA:GPU] Add NVLink domain check to CollectiveBackendAssigner

Imported from GitHub PR https://github.com/openxla/xla/pull/31375

📝 Summary of Changes
This PR updates the CollectiveBackendAssigner pass to account for NVLink domain connectivity when deciding between the NVSHMEM and DEFAULT backends. It does this by plumbing a slice_size parameter through the compilation pipeline into the pass and introducing an IsIntraNVLinkDomain check.

🎯 Justification
The CollectiveBackendAssigner now uses NVSHMEM not only for single-host scenarios, but also when all participating devices are within the same NVLink domain. NVLink domains (e.g. GB200 NVL systems) can span multiple hosts, so the previous single-host restriction excluded cases where NVSHMEM is still profitable.
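
For orientation, here is a minimal, self-contained sketch of the resulting selection logic. It mirrors the collective_backend_assigner.cc hunk in the diff below; the UseNvshmem helper and the stand-in enum are illustrative only, and device_count stands for num_partitions * replica_count:

```cpp
#include <cstdint>

// Illustrative stand-in; the real enum lives in XLA's collective_ops_utils.
enum class GPUCommunicationType { SINGLE_HOST, MULTI_HOST, UNDEFINED };

// Hypothetical helper mirroring the updated condition in
// CollectiveBackendAssigner::Run (see the diff below), not the actual pass code.
bool UseNvshmem(int num_visible_devices_per_process,
                GPUCommunicationType comm_type, bool is_all_reduce,
                int64_t shape_size, int64_t threshold_in_bytes,
                int64_t slice_size, int64_t device_count) {
  // New in this PR: the collective also qualifies when all participating
  // devices fit in one NVLink domain, i.e. device_count <= slice_size.
  const bool intra_nvlink_domain =
      slice_size > 0 && device_count <= slice_size;
  const bool topology_ok = num_visible_devices_per_process == 1 ||
                           comm_type == GPUCommunicationType::SINGLE_HOST ||
                           intra_nvlink_domain;
  // Large all-reduces still go to the DEFAULT backend.
  const bool size_ok = !is_all_reduce || shape_size < threshold_in_bytes;
  return topology_ok && size_ok;
}
```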

🚀 Kind of Contribution
Performance Improvement, 🧪 Tests

📊 Benchmark (for Performance Improvements)
H100:
| Benchmark | NVSHMEM enabled | NVSHMEM disabled |
|----------------------|-----------------|------------------|
| llama31_8b_fp8_1x8   | 1095330 us      | 1093816 us       |
| llama31_8b_bf16_2x8  | 1368948 us      | 1370896 us       |
| llama31_8b_fp8_2x8   | 1096447 us      | 1092437 us       |
| llama31_70b_fp8_16x8 | 9723821 us      | 9707544 us       |

🧪 Unit Tests:
Added unit tests to xla/service/gpu/transforms/collectives/collective_backend_assigner_test.cc

🧪 Execution Tests:
Tested with llama3-8b on 2 GB200 nodes (fsdp = 8). The average step time in the NVSHMEM case was 3.69 s (vs. 3.76 s in the default case).

Copybara import of the project:

--
a02b77cec9622314af01ae481d0fb28b149f1b45 by Sevin Varoglu <svaroglu@nvidia.com>:

Add NVLink domain check to CollectiveBackendAssigner

Merging this change closes #31375

PiperOrigin-RevId: 826649437
Sevin Fide Varoglu authored on 2025-10-31 15:24:38 -07:00; committed by TensorFlower Gardener
parent bf84442f21
commit c655468288
8 changed files with 143 additions and 23 deletions


@@ -975,7 +975,7 @@ absl::Status RunCollectiveOptimizationPasses(
   if (debug_options.xla_gpu_experimental_enable_nvshmem()) {
     collectives_pipeline.AddPass<CollectiveBackendAssigner>(
-        gpu_version, num_visible_devices_per_process);
+        gpu_version, num_visible_devices_per_process, options.slice_size);
   }
   if (debug_options.xla_gpu_unsupported_enable_ragged_all_to_all_decomposer()) {


@@ -136,6 +136,7 @@ cc_library(
     srcs = ["gpu_collective_combiner_utils.cc"],
     hdrs = ["gpu_collective_combiner_utils.h"],
     deps = [
+        ":collective_ops_utils",
         "//xla:util",
         "//xla/hlo/ir:hlo",
         "//xla/service:collective_ops_utils",


@@ -92,14 +92,17 @@ absl::StatusOr<bool> CollectiveBackendAssigner::Run(
         GPUCommunicationType comm_type,
         GetCommunicationType(instr, num_visible_devices_per_process_,
                              gpu_version_));
+    int64_t shape_size = GetShapeSize(instr->shape());
     VLOG(1) << "CollectiveBackendAssigner: comm_type="
-            << static_cast<int>(comm_type)
-            << " shape_size=" << GetShapeSize(instr->shape())
-            << " threshold_in_bytes_=" << threshold_in_bytes_;
-    bool use_nvshmem = (num_visible_devices_per_process_ == 1 ||
-                        comm_type == GPUCommunicationType::SINGLE_HOST) &&
-                       (!IsAllReduceOp(instr) ||
-                        GetShapeSize(instr->shape()) < threshold_in_bytes_);
+            << static_cast<int>(comm_type) << " shape_size=" << shape_size
+            << " threshold_in_bytes_=" << threshold_in_bytes_
+            << " slice_size_=" << slice_size_;
+    bool use_nvshmem =
+        (num_visible_devices_per_process_ == 1 ||
+         comm_type == GPUCommunicationType::SINGLE_HOST ||
+         (slice_size_ > 0 &&
+          IsIntraNVLinkDomain(module->config(), slice_size_))) &&
+        (!IsAllReduceOp(instr) || shape_size < threshold_in_bytes_);
     if (!use_nvshmem) {
       continue;
     }


@@ -43,11 +43,12 @@ class CollectiveBackendAssigner : public HloModulePass {
  public:
   explicit CollectiveBackendAssigner(
       const se::GpuComputeCapability& gpu_version,
-      int num_visible_devices_per_process,
+      int num_visible_devices_per_process, int64_t slice_size = 0,
       int64_t threshold_in_bytes = kDefaultThresholdInBytes)
       : gpu_version_(gpu_version),
         num_visible_devices_per_process_(num_visible_devices_per_process),
-        threshold_in_bytes_(threshold_in_bytes) {}
+        threshold_in_bytes_(threshold_in_bytes),
+        slice_size_(slice_size) {}

   absl::string_view name() const override {
     return "collective-backend-assigner";
@@ -61,6 +62,7 @@ class CollectiveBackendAssigner : public HloModulePass {
   se::GpuComputeCapability gpu_version_;
   int num_visible_devices_per_process_;
   int64_t threshold_in_bytes_;
+  int64_t slice_size_;
 };

 }  // namespace gpu


@@ -38,10 +38,12 @@ using ::tsl::testing::IsOkAndHolds;
 class CollectiveBackendAssignerTest : public HloHardwareIndependentTestBase {
  protected:
-  absl::StatusOr<bool> RunCollectiveBackendAssigner(HloModule* module) {
+  absl::StatusOr<bool> RunCollectiveBackendAssigner(HloModule* module,
+                                                    int num_devices_per_host,
+                                                    int64_t slice_size = 0) {
     se::GpuComputeCapability gpu_version = se::CudaComputeCapability(8, 0);
-    return RunHloPass(
-        CollectiveBackendAssigner(gpu_version, /*num_devices_per_host=*/1),
+    return RunHloPass(CollectiveBackendAssigner(
+                          gpu_version, num_devices_per_host, slice_size),
                       module);
   }
@@ -70,7 +72,9 @@ TEST_F(CollectiveBackendAssignerTest, SmallAllReduceUsesNvshmem) {
 )";

   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
-  EXPECT_THAT(RunCollectiveBackendAssigner(module.get()),
+  EXPECT_THAT(RunCollectiveBackendAssigner(
+                  module.get(), /*num_devices_per_host=*/1, /*slice_size=*/0),
               absl_testing::IsOkAndHolds(true));

   const HloInstruction* all_reduce =
@@ -96,7 +100,9 @@ TEST_F(CollectiveBackendAssignerTest, LargeAllReduceUsesDefault) {
 )";

   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
-  EXPECT_THAT(RunCollectiveBackendAssigner(module.get()),
+  EXPECT_THAT(RunCollectiveBackendAssigner(
+                  module.get(), /*num_devices_per_host=*/1, /*slice_size=*/0),
               absl_testing::IsOkAndHolds(false));

   const HloInstruction* all_reduce =
@@ -117,7 +123,9 @@ TEST_F(CollectiveBackendAssignerTest, SmallCollectivePermuteUsesNvshmem) {
 )";

   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
-  EXPECT_THAT(RunCollectiveBackendAssigner(module.get()),
+  EXPECT_THAT(RunCollectiveBackendAssigner(
+                  module.get(), /*num_devices_per_host=*/1, /*slice_size=*/0),
               absl_testing::IsOkAndHolds(true));

   const HloInstruction* permute =
@@ -138,7 +146,9 @@ TEST_F(CollectiveBackendAssignerTest, LargeCollectivePermuteUsesNvshmem) {
 )";

   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
-  EXPECT_THAT(RunCollectiveBackendAssigner(module.get()),
+  EXPECT_THAT(RunCollectiveBackendAssigner(
+                  module.get(), /*num_devices_per_host=*/1, /*slice_size=*/0),
               absl_testing::IsOkAndHolds(true));

   const HloInstruction* permute =
@@ -147,6 +157,97 @@ TEST_F(CollectiveBackendAssignerTest, LargeCollectivePermuteUsesNvshmem) {
               absl_testing::IsOkAndHolds(CollectiveBackendConfig::NVSHMEM));
 }

+TEST_F(CollectiveBackendAssignerTest, IntraNvlinkDomainUsesNvshmem) {
+  absl::string_view kHloText = R"(
+    HloModule m
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    ENTRY main {
+      p0 = f32[1024,1024] parameter(0)
+      ROOT result = f32[1024,1024] all-reduce(p0), to_apply=add, replica_groups={{0,1}}, channel_id=5
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
+  module->mutable_config().set_num_partitions(2);
+  module->mutable_config().set_replica_count(2);
+
+  EXPECT_THAT(RunCollectiveBackendAssigner(
+                  module.get(), /*num_devices_per_host=*/2, /*slice_size=*/4),
+              absl_testing::IsOkAndHolds(true));
+
+  const HloInstruction* all_reduce =
+      module->entry_computation()->root_instruction();
+  EXPECT_THAT(GetCollectiveBackendConfig(all_reduce),
+              absl_testing::IsOkAndHolds(CollectiveBackendConfig::NVSHMEM));
+}
+
+TEST_F(CollectiveBackendAssignerTest,
+       IntraNvlinkDomainLargeAllReduceUsesDefault) {
+  absl::string_view kHloText = R"(
+    HloModule m
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    ENTRY main {
+      p0 = f32[8192,8192] parameter(0)
+      ROOT result = f32[8192,8192] all-reduce(p0), to_apply=add, replica_groups={{0,1}}, channel_id=8
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
+  module->mutable_config().set_num_partitions(2);
+  module->mutable_config().set_replica_count(2);
+
+  EXPECT_THAT(RunCollectiveBackendAssigner(
+                  module.get(), /*num_devices_per_host=*/2, /*slice_size=*/4),
+              absl_testing::IsOkAndHolds(false));
+
+  const HloInstruction* all_reduce =
+      module->entry_computation()->root_instruction();
+  EXPECT_THAT(GetCollectiveBackendConfig(all_reduce),
+              absl_testing::IsOkAndHolds(CollectiveBackendConfig::DEFAULT));
+}
+
+TEST_F(CollectiveBackendAssignerTest, NonIntraNvlinkDomainUsesDefault) {
+  absl::string_view kHloText = R"(
+    HloModule m
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    ENTRY main {
+      p0 = f32[1024,1024] parameter(0)
+      ROOT result = f32[1024,1024] all-reduce(p0), to_apply=add, channel_id=13
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
+  module->mutable_config().set_num_partitions(1);
+  module->mutable_config().set_replica_count(4);
+
+  EXPECT_THAT(RunCollectiveBackendAssigner(
+                  module.get(), /*num_devices_per_host=*/2, /*slice_size=*/2),
+              absl_testing::IsOkAndHolds(false));
+
+  const HloInstruction* all_reduce =
+      module->entry_computation()->root_instruction();
+  EXPECT_THAT(GetCollectiveBackendConfig(all_reduce),
+              absl_testing::IsOkAndHolds(CollectiveBackendConfig::DEFAULT));
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla


@@ -186,5 +186,13 @@ absl::StatusOr<GPUCommunicationType> CommunicationType(
   return GPUCommunicationType::UNDEFINED;
 }

+bool IsIntraNVLinkDomain(const HloModuleConfig& config, int64_t slice_size) {
+  int device_count = config.num_partitions() * config.replica_count();
+  bool is_intra = device_count <= slice_size;
+  VLOG(1) << "IsIntraNVLinkDomain: device_count=" << device_count
+          << " slice_size=" << slice_size << " is_intra=" << is_intra;
+  return is_intra;
+}
+
 }  // namespace gpu
 }  // namespace xla


@@ -19,6 +19,7 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/hlo_module_config.h"
 #include "xla/stream_executor/device_description.h"

 namespace xla {
@@ -46,6 +47,9 @@ absl::StatusOr<GPUCommunicationType> CommunicationType(
 // Returns true if instruction is a synchronous collective op.
 bool IsGPUSyncCollective(const HloInstruction& instr);

+// Returns true if all devices are within the same NVLink domain (slice).
+bool IsIntraNVLinkDomain(const HloModuleConfig& config, int64_t slice_size);
+
 }  // namespace gpu
 }  // namespace xla


@@ -24,6 +24,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_schedule.h"
 #include "xla/service/collective_ops_utils.h"
 #include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/transforms/collectives/collective_ops_utils.h"
 #include "xla/service/hlo_module_config.h"
 #include "xla/stream_executor/cuda/cuda_compute_capability.h"
 #include "xla/stream_executor/device_description.h"
@@ -76,17 +77,17 @@ bool EnableHeuristicCollectiveCombining(
   if (!cc.IsAtLeastAmpere()) {
     return false;
   }

-  int hlo_device_count = config.num_partitions() * config.replica_count();
-  if (hlo_device_count <= nvlink_slice_size) {
+  if (IsIntraNVLinkDomain(config, nvlink_slice_size)) {
     VLOG(1) << "Disabled heuristic collective combining for intra-NVLink "
                "domain communication: HLO device count "
-            << hlo_device_count << " <= NVLink slice size "
-            << nvlink_slice_size;
+            << (config.num_partitions() * config.replica_count())
+            << " <= NVLink slice size " << nvlink_slice_size;
     return false;
   }

   VLOG(1) << "Enabled heuristic collective combining for inter-NVLink domain "
              "communication: HLO device count "
-          << hlo_device_count << " > NVLink slice size " << nvlink_slice_size;
+          << (config.num_partitions() * config.replica_count())
+          << " > NVLink slice size " << nvlink_slice_size;
   return true;
 }