diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc
index 44cdb0899e8..fee384e0dab 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler.cc
+++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc
@@ -975,7 +975,7 @@ absl::Status RunCollectiveOptimizationPasses(
   if (debug_options.xla_gpu_experimental_enable_nvshmem()) {
     collectives_pipeline.AddPass<CollectiveBackendAssigner>(
-        gpu_version, num_visible_devices_per_process);
+        gpu_version, num_visible_devices_per_process, options.slice_size);
   }
   if (debug_options.xla_gpu_unsupported_enable_ragged_all_to_all_decomposer()) {
diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/BUILD b/third_party/xla/xla/service/gpu/transforms/collectives/BUILD
index f088b3fce0f..caab9b91f0e 100644
--- a/third_party/xla/xla/service/gpu/transforms/collectives/BUILD
+++ b/third_party/xla/xla/service/gpu/transforms/collectives/BUILD
@@ -136,6 +136,7 @@ cc_library(
     srcs = ["gpu_collective_combiner_utils.cc"],
    hdrs = ["gpu_collective_combiner_utils.h"],
    deps = [
+        ":collective_ops_utils",
        "//xla:util",
        "//xla/hlo/ir:hlo",
        "//xla/service:collective_ops_utils",
diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/collective_backend_assigner.cc b/third_party/xla/xla/service/gpu/transforms/collectives/collective_backend_assigner.cc
index 344e399448c..6cb55d45d12 100644
--- a/third_party/xla/xla/service/gpu/transforms/collectives/collective_backend_assigner.cc
+++ b/third_party/xla/xla/service/gpu/transforms/collectives/collective_backend_assigner.cc
@@ -92,14 +92,17 @@ absl::StatusOr<bool> CollectiveBackendAssigner::Run(
         GPUCommunicationType comm_type,
         GetCommunicationType(instr, num_visible_devices_per_process_,
                              gpu_version_));
+    int64_t shape_size = GetShapeSize(instr->shape());
     VLOG(1) << "CollectiveBackendAssigner: comm_type="
-            << static_cast<int>(comm_type)
-            << " shape_size=" << GetShapeSize(instr->shape())
-            << " threshold_in_bytes_=" << threshold_in_bytes_;
-    bool use_nvshmem = (num_visible_devices_per_process_ == 1 ||
-                        comm_type == GPUCommunicationType::SINGLE_HOST) &&
-                       (!IsAllReduceOp(instr) ||
-                        GetShapeSize(instr->shape()) < threshold_in_bytes_);
+            << static_cast<int>(comm_type) << " shape_size=" << shape_size
+            << " threshold_in_bytes_=" << threshold_in_bytes_
+            << " slice_size_=" << slice_size_;
+    bool use_nvshmem =
+        (num_visible_devices_per_process_ == 1 ||
+         comm_type == GPUCommunicationType::SINGLE_HOST ||
+         (slice_size_ > 0 &&
+          IsIntraNVLinkDomain(module->config(), slice_size_))) &&
+        (!IsAllReduceOp(instr) || shape_size < threshold_in_bytes_);
     if (!use_nvshmem) {
       continue;
     }
diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/collective_backend_assigner.h b/third_party/xla/xla/service/gpu/transforms/collectives/collective_backend_assigner.h
index 2f0838ce535..a47a00e4cbb 100644
--- a/third_party/xla/xla/service/gpu/transforms/collectives/collective_backend_assigner.h
+++ b/third_party/xla/xla/service/gpu/transforms/collectives/collective_backend_assigner.h
@@ -43,11 +43,12 @@ class CollectiveBackendAssigner : public HloModulePass {
  public:
   explicit CollectiveBackendAssigner(
       const se::GpuComputeCapability& gpu_version,
-      int num_visible_devices_per_process,
+      int num_visible_devices_per_process, int64_t slice_size = 0,
       int64_t threshold_in_bytes = kDefaultThresholdInBytes)
       : gpu_version_(gpu_version),
         num_visible_devices_per_process_(num_visible_devices_per_process),
-        threshold_in_bytes_(threshold_in_bytes) {}
+        threshold_in_bytes_(threshold_in_bytes),
+        slice_size_(slice_size) {}
 
   absl::string_view name() const override {
     return "collective-backend-assigner";
@@ -61,6 +62,7 @@ class CollectiveBackendAssigner : public HloModulePass {
   se::GpuComputeCapability gpu_version_;
   int num_visible_devices_per_process_;
   int64_t threshold_in_bytes_;
+  int64_t slice_size_;
 };
 
 }  // namespace gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/collective_backend_assigner_test.cc b/third_party/xla/xla/service/gpu/transforms/collectives/collective_backend_assigner_test.cc
index fef57081238..3a8b20d1d68 100644
--- a/third_party/xla/xla/service/gpu/transforms/collectives/collective_backend_assigner_test.cc
+++ b/third_party/xla/xla/service/gpu/transforms/collectives/collective_backend_assigner_test.cc
@@ -38,11 +38,13 @@ using ::tsl::testing::IsOkAndHolds;
 class CollectiveBackendAssignerTest : public HloHardwareIndependentTestBase {
  protected:
-  absl::StatusOr<bool> RunCollectiveBackendAssigner(HloModule* module) {
+  absl::StatusOr<bool> RunCollectiveBackendAssigner(HloModule* module,
+                                                    int num_devices_per_host,
+                                                    int64_t slice_size = 0) {
     se::GpuComputeCapability gpu_version = se::CudaComputeCapability(8, 0);
-    return RunHloPass(
-        CollectiveBackendAssigner(gpu_version, /*num_devices_per_host=*/1),
-        module);
+    return RunHloPass(CollectiveBackendAssigner(
+                          gpu_version, num_devices_per_host, slice_size),
+                      module);
   }
 
   absl::StatusOr<CollectiveBackendConfig::CollectiveBackend>
@@ -70,7 +72,9 @@ TEST_F(CollectiveBackendAssignerTest, SmallAllReduceUsesNvshmem) {
   )";
 
   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
-  EXPECT_THAT(RunCollectiveBackendAssigner(module.get()),
+
+  EXPECT_THAT(RunCollectiveBackendAssigner(
+                  module.get(), /*num_devices_per_host=*/1, /*slice_size=*/0),
               absl_testing::IsOkAndHolds(true));
 
   const HloInstruction* all_reduce =
@@ -96,7 +100,9 @@ TEST_F(CollectiveBackendAssignerTest, LargeAllReduceUsesDefault) {
   )";
 
   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
-  EXPECT_THAT(RunCollectiveBackendAssigner(module.get()),
+
+  EXPECT_THAT(RunCollectiveBackendAssigner(
+                  module.get(), /*num_devices_per_host=*/1, /*slice_size=*/0),
               absl_testing::IsOkAndHolds(false));
 
   const HloInstruction* all_reduce =
@@ -117,7 +123,9 @@ TEST_F(CollectiveBackendAssignerTest, SmallCollectivePermuteUsesNvshmem) {
   )";
 
   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
-  EXPECT_THAT(RunCollectiveBackendAssigner(module.get()),
+
+  EXPECT_THAT(RunCollectiveBackendAssigner(
+                  module.get(), /*num_devices_per_host=*/1, /*slice_size=*/0),
               absl_testing::IsOkAndHolds(true));
 
   const HloInstruction* permute =
@@ -138,7 +146,9 @@ TEST_F(CollectiveBackendAssignerTest, LargeCollectivePermuteUsesNvshmem) {
   )";
 
   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
-  EXPECT_THAT(RunCollectiveBackendAssigner(module.get()),
+
+  EXPECT_THAT(RunCollectiveBackendAssigner(
+                  module.get(), /*num_devices_per_host=*/1, /*slice_size=*/0),
               absl_testing::IsOkAndHolds(true));
 
   const HloInstruction* permute =
@@ -147,6 +157,97 @@ TEST_F(CollectiveBackendAssignerTest, LargeCollectivePermuteUsesNvshmem) {
               absl_testing::IsOkAndHolds(CollectiveBackendConfig::NVSHMEM));
 }
 
+TEST_F(CollectiveBackendAssignerTest, IntraNvlinkDomainUsesNvshmem) {
+  absl::string_view kHloText = R"(
+    HloModule m
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    ENTRY main {
+      p0 = f32[1024,1024] parameter(0)
+      ROOT result = f32[1024,1024] all-reduce(p0), to_apply=add, replica_groups={{0,1}}, channel_id=5
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
+  module->mutable_config().set_num_partitions(2);
+  module->mutable_config().set_replica_count(2);
+
+  EXPECT_THAT(RunCollectiveBackendAssigner(
+                  module.get(), /*num_devices_per_host=*/2, /*slice_size=*/4),
+              absl_testing::IsOkAndHolds(true));
+
+  const HloInstruction* all_reduce =
+      module->entry_computation()->root_instruction();
+  EXPECT_THAT(GetCollectiveBackendConfig(all_reduce),
+              absl_testing::IsOkAndHolds(CollectiveBackendConfig::NVSHMEM));
+}
+
+TEST_F(CollectiveBackendAssignerTest,
+       IntraNvlinkDomainLargeAllReduceUsesDefault) {
+  absl::string_view kHloText = R"(
+    HloModule m
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    ENTRY main {
+      p0 = f32[8192,8192] parameter(0)
+      ROOT result = f32[8192,8192] all-reduce(p0), to_apply=add, replica_groups={{0,1}}, channel_id=8
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
+  module->mutable_config().set_num_partitions(2);
+  module->mutable_config().set_replica_count(2);
+
+  EXPECT_THAT(RunCollectiveBackendAssigner(
+                  module.get(), /*num_devices_per_host=*/2, /*slice_size=*/4),
+              absl_testing::IsOkAndHolds(false));
+
+  const HloInstruction* all_reduce =
+      module->entry_computation()->root_instruction();
+  EXPECT_THAT(GetCollectiveBackendConfig(all_reduce),
+              absl_testing::IsOkAndHolds(CollectiveBackendConfig::DEFAULT));
+}
+
+TEST_F(CollectiveBackendAssignerTest, NonIntraNvlinkDomainUsesDefault) {
+  absl::string_view kHloText = R"(
+    HloModule m
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    ENTRY main {
+      p0 = f32[1024,1024] parameter(0)
+      ROOT result = f32[1024,1024] all-reduce(p0), to_apply=add, channel_id=13
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
+  module->mutable_config().set_num_partitions(1);
+  module->mutable_config().set_replica_count(4);
+
+  EXPECT_THAT(RunCollectiveBackendAssigner(
+                  module.get(), /*num_devices_per_host=*/2, /*slice_size=*/2),
+              absl_testing::IsOkAndHolds(false));
+
+  const HloInstruction* all_reduce =
+      module->entry_computation()->root_instruction();
+  EXPECT_THAT(GetCollectiveBackendConfig(all_reduce),
+              absl_testing::IsOkAndHolds(CollectiveBackendConfig::DEFAULT));
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc
index 54fc006c3ad..f4bdf504e8e 100644
--- a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc
+++ b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.cc
@@ -186,5 +186,13 @@ absl::StatusOr<GPUCommunicationType> CommunicationType(
   return GPUCommunicationType::UNDEFINED;
 }
 
+bool IsIntraNVLinkDomain(const HloModuleConfig& config, int64_t slice_size) {
+  int device_count = config.num_partitions() * config.replica_count();
+  bool is_intra = device_count <= slice_size;
+  VLOG(1) << "IsIntraNVLinkDomain: device_count=" << device_count
+          << " slice_size=" << slice_size << " is_intra=" << is_intra;
+  return is_intra;
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.h b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.h
index 719940c9e4a..c4ed3803a57 100644
--- a/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.h
+++ b/third_party/xla/xla/service/gpu/transforms/collectives/collective_ops_utils.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/hlo_module_config.h"
 #include "xla/stream_executor/device_description.h"
 
 namespace xla {
@@ -46,6 +47,9 @@ absl::StatusOr<GPUCommunicationType> CommunicationType(
 // Returns true if instruction is a synchronous collective op.
 bool IsGPUSyncCollective(const HloInstruction& instr);
 
+// Returns true if all devices are within the same NVLink domain (slice).
+bool IsIntraNVLinkDomain(const HloModuleConfig& config, int64_t slice_size);
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.cc b/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.cc
index 1c99bd414fb..52ea29f9a7a 100644
--- a/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.cc
+++ b/third_party/xla/xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_schedule.h"
 #include "xla/service/collective_ops_utils.h"
 #include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/transforms/collectives/collective_ops_utils.h"
 #include "xla/service/hlo_module_config.h"
 #include "xla/stream_executor/cuda/cuda_compute_capability.h"
 #include "xla/stream_executor/device_description.h"
@@ -76,17 +77,17 @@ bool EnableHeuristicCollectiveCombining(
   if (!cc.IsAtLeastAmpere()) {
     return false;
   }
-  int hlo_device_count = config.num_partitions() * config.replica_count();
-  if (hlo_device_count <= nvlink_slice_size) {
+  if (IsIntraNVLinkDomain(config, nvlink_slice_size)) {
     VLOG(1) << "Disabled heuristic collective combining for intra-NVLink "
               "domain communication: HLO device count "
-            << hlo_device_count << " <= NVLink slice size "
-            << nvlink_slice_size;
+            << (config.num_partitions() * config.replica_count())
+            << " <= NVLink slice size " << nvlink_slice_size;
     return false;
   }
   VLOG(1) << "Enabled heuristic collective combining for inter-NVLink domain "
             "communication: HLO device count "
-          << hlo_device_count << " > NVLink slice size " << nvlink_slice_size;
+          << (config.num_partitions() * config.replica_count())
+          << " > NVLink slice size " << nvlink_slice_size;
   return true;
 }
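
The sketch below distills the backend-selection rule this patch introduces into a minimal standalone program. It is an illustration under stated assumptions, not XLA code: the function and parameter names here are hypothetical, the 8 MiB threshold is an arbitrary example value (the real default is kDefaultThresholdInBytes), and the actual pass derives the single-host signal from device topology via GetCommunicationType rather than taking it as a flag.

#include <cstdint>
#include <iostream>

// Mirrors the patched IsIntraNVLinkDomain: the job fits in one NVLink slice
// when the total device count (partitions x replicas) is at most slice_size.
bool IsIntraNVLinkDomain(int64_t num_partitions, int64_t replica_count,
                         int64_t slice_size) {
  return num_partitions * replica_count <= slice_size;
}

// Mirrors the shape of the use_nvshmem predicate in
// CollectiveBackendAssigner::Run: NVSHMEM is eligible when traffic stays
// within one process, one host, or (new in this patch) one NVLink slice, and
// an all-reduce must additionally be smaller than the size threshold.
bool UseNvshmem(int num_visible_devices_per_process, bool single_host,
                int64_t num_partitions, int64_t replica_count,
                int64_t slice_size, bool is_all_reduce, int64_t shape_size,
                int64_t threshold_in_bytes) {
  bool local =
      num_visible_devices_per_process == 1 || single_host ||
      (slice_size > 0 &&
       IsIntraNVLinkDomain(num_partitions, replica_count, slice_size));
  return local && (!is_all_reduce || shape_size < threshold_in_bytes);
}

int main() {
  // 2 partitions x 2 replicas = 4 devices inside a 4-GPU NVLink slice, with a
  // 4 MiB all-reduce under an example 8 MiB threshold: NVSHMEM is selected
  // even though more than one device is visible per process.
  std::cout << UseNvshmem(/*num_visible_devices_per_process=*/2,
                          /*single_host=*/false,
                          /*num_partitions=*/2, /*replica_count=*/2,
                          /*slice_size=*/4, /*is_all_reduce=*/true,
                          /*shape_size=*/4 << 20,
                          /*threshold_in_bytes=*/8 << 20)
            << "\n";  // prints 1
  return 0;
}

This matches the three new tests above: 4 devices within a slice of 4 picks NVSHMEM for a small all-reduce, a large all-reduce falls back to DEFAULT regardless of locality, and 4 devices with a slice of 2 are not intra-slice, so DEFAULT is kept.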