diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc index 4761af16f6e..c2aba137cc5 100644 --- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc +++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc @@ -112,9 +112,16 @@ Type CheckStatus(absl::StatusOr result) { return *result; } -class CollectiveOpsTestE2E : public HloHardwareIndependentTestBase { +bool IsAsync(const HloInstruction* inst) { + return !inst->backend_config() + .value() + .collective_backend_config() + .is_sync(); +} + +class CollectiveOpsE2ETestBase : public HloHardwareIndependentTestBase { public: - CollectiveOpsTestE2E() { + CollectiveOpsE2ETestBase() { se::Platform* platform = CheckStatus(PlatformUtil::GetPlatform("GPU")); se::Platform* reference_platform = CheckStatus(PlatformUtil::GetPlatform("GPU")); @@ -144,54 +151,6 @@ class CollectiveOpsTestE2E : public HloHardwareIndependentTestBase { std::move(allocators))); reference_hlo_runner_ = std::make_unique( reference_platform, /*intra_op_parallelism_threads=*/0); - - replacements_[kF8E4M3DatatypePlaceholder] = - Capability().IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz"; - replacements_[kF8E5M2DatatypePlaceholder] = - Capability().IsCuda() ? "f8e5m2" : "f8e5m2fnuz"; - } - - const se::GpuComputeCapability& Capability() { - return hlo_runner_->backend() - .default_stream_executor() - ->GetDeviceDescription() - .gpu_compute_capability(); - } - - bool HasFp8Support() { - if (Capability().IsCuda()) { - return Capability().cuda_compute_capability()->IsAtLeast(8, 9); - } - return Capability().rocm_compute_capability()->has_fp8_support() && - GetDebugOptionsForTest().xla_gpu_enable_cublaslt(); - } - - void CollectiveOpsVerifyF8Matmul(absl::string_view hlo_text, - const DebugOptions& options) { - if (!HasFp8Support()) { - return; - } - const int64_t kNumReplicas = 1; - const int64_t kNumPartitions = 4; - - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - config.set_debug_options(options); - config.set_num_partitions(kNumPartitions); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_text, config)); - - TF_ASSERT_OK_AND_ASSIGN(auto executable, hlo_runner_->CreateExecutable( - std::move(module), - /*run_hlo_passes=*/true)); - TF_ASSERT_OK_AND_ASSIGN( - const HloModule* const hlo_module, - hlo_runner_->HloModuleFromWrapped(executable.get())); - std::vector gemm_ops = - FindInstructions(hlo_module, HloOpcode::kCustomCall); - for (HloInstruction* gemm_op : gemm_ops) { - EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8"); - } } // TODO(b/449655621) Use absl::AnyInvocable instead of std::function. @@ -263,17 +222,65 @@ class CollectiveOpsTestE2E : public HloHardwareIndependentTestBase { num_replicas, /*run_hlo_passes=*/false, &device_assignment); } - bool IsAsync(const HloInstruction* inst) { - return !inst->backend_config() - .value() - .collective_backend_config() - .is_sync(); + protected: + std::unique_ptr hlo_runner_; + std::unique_ptr reference_hlo_runner_; +}; + +class CollectiveOpsTestE2E : public CollectiveOpsE2ETestBase { + public: + CollectiveOpsTestE2E() { + replacements_[kF8E4M3DatatypePlaceholder] = + Capability().IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz"; + replacements_[kF8E5M2DatatypePlaceholder] = + Capability().IsCuda() ? "f8e5m2" : "f8e5m2fnuz"; + } + + const se::GpuComputeCapability& Capability() { + return hlo_runner_->backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } + + bool HasFp8Support() { + if (Capability().IsCuda()) { + return Capability().cuda_compute_capability()->IsAtLeast(8, 9); + } + return Capability().rocm_compute_capability()->has_fp8_support() && + GetDebugOptionsForTest().xla_gpu_enable_cublaslt(); + } + + void CollectiveOpsVerifyF8Matmul(absl::string_view hlo_text, + const DebugOptions& options) { + if (!HasFp8Support()) { + return; + } + const int64_t kNumReplicas = 1; + const int64_t kNumPartitions = 4; + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + config.set_debug_options(options); + config.set_num_partitions(kNumPartitions); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_text, config)); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, hlo_runner_->CreateExecutable( + std::move(module), + /*run_hlo_passes=*/true)); + TF_ASSERT_OK_AND_ASSIGN( + const HloModule* const hlo_module, + hlo_runner_->HloModuleFromWrapped(executable.get())); + std::vector gemm_ops = + FindInstructions(hlo_module, HloOpcode::kCustomCall); + for (HloInstruction* gemm_op : gemm_ops) { + EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8"); + } } protected: absl::flat_hash_map replacements_; - std::unique_ptr hlo_runner_; - std::unique_ptr reference_hlo_runner_; private: static constexpr const char* kF8E4M3DatatypePlaceholder{"<>"}; @@ -287,7 +294,7 @@ class CollectiveOpsTestE2E : public HloHardwareIndependentTestBase { // E2E test for collectives with flags set. Has constructor arguments specifying // whether to enable/disable async collectives, and to set the memcpy_local_p2p // flag. Subclasses pass in constructor arguments based on GetParam(). -class CollectiveOpsWithFlagsBase : public CollectiveOpsTestE2E { +class CollectiveOpsWithFlagsBase : public CollectiveOpsE2ETestBase { public: CollectiveOpsWithFlagsBase(bool enable_async, bool enable_p2p_memcpy) : enable_async_(enable_async), enable_p2p_memcpy_(enable_p2p_memcpy) {