[XLA:GPU] Introduce CollectiveOpsE2ETestBase for common functionality.

We need a proper base class for common functionality. `CollectiveOpsTestE2E` is not a good base class because it also holds a number of fp8-specific helpers that only a few tests use.

PiperOrigin-RevId: 826017075
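
For orientation, a minimal sketch of the class layout this change introduces, distilled from the diff below; method bodies are elided, only names that appear in the diff are used, and exact member order may differ:

// IsAsync(const HloInstruction*) is hoisted out of the test class into a
// free function in front of the class declarations.

class CollectiveOpsE2ETestBase : public HloHardwareIndependentTestBase {
 public:
  // Sets up hlo_runner_ and reference_hlo_runner_ on the GPU platform.
  CollectiveOpsE2ETestBase();

 protected:
  std::unique_ptr<HloRunner> hlo_runner_;
  std::unique_ptr<HloRunner> reference_hlo_runner_;
};

// The fp8-specific helpers stay in the derived class, so tests that do not
// need them can depend on the leaner base.
class CollectiveOpsTestE2E : public CollectiveOpsE2ETestBase {
 public:
  // Fills replacements_ with the CUDA or ROCm fp8 type names.
  CollectiveOpsTestE2E();

  const se::GpuComputeCapability& Capability();
  bool HasFp8Support();
  void CollectiveOpsVerifyF8Matmul(absl::string_view hlo_text,
                                   const DebugOptions& options);

 protected:
  absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;
};

// Subclasses that only need the runners now derive from the new base.
class CollectiveOpsWithFlagsBase : public CollectiveOpsE2ETestBase { /* ... */ };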
Authored by Oleg Shyshkov on 2025-10-30 07:09:28 -07:00; committed by TensorFlower Gardener
parent 689bf5ef28
commit fd71e8be05


@@ -112,9 +112,16 @@ Type CheckStatus(absl::StatusOr<Type> result) {
return *result;
}
class CollectiveOpsTestE2E : public HloHardwareIndependentTestBase {
bool IsAsync(const HloInstruction* inst) {
return !inst->backend_config<gpu::GpuBackendConfig>()
.value()
.collective_backend_config()
.is_sync();
}
class CollectiveOpsE2ETestBase : public HloHardwareIndependentTestBase {
public:
CollectiveOpsTestE2E() {
CollectiveOpsE2ETestBase() {
se::Platform* platform = CheckStatus(PlatformUtil::GetPlatform("GPU"));
se::Platform* reference_platform =
CheckStatus(PlatformUtil::GetPlatform("GPU"));
@@ -144,54 +151,6 @@ class CollectiveOpsTestE2E : public HloHardwareIndependentTestBase {
std::move(allocators)));
reference_hlo_runner_ = std::make_unique<HloRunner>(
reference_platform, /*intra_op_parallelism_threads=*/0);
replacements_[kF8E4M3DatatypePlaceholder] =
Capability().IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
replacements_[kF8E5M2DatatypePlaceholder] =
Capability().IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
}
const se::GpuComputeCapability& Capability() {
return hlo_runner_->backend()
.default_stream_executor()
->GetDeviceDescription()
.gpu_compute_capability();
}
bool HasFp8Support() {
if (Capability().IsCuda()) {
return Capability().cuda_compute_capability()->IsAtLeast(8, 9);
}
return Capability().rocm_compute_capability()->has_fp8_support() &&
GetDebugOptionsForTest().xla_gpu_enable_cublaslt();
}
void CollectiveOpsVerifyF8Matmul(absl::string_view hlo_text,
const DebugOptions& options) {
if (!HasFp8Support()) {
return;
}
const int64_t kNumReplicas = 1;
const int64_t kNumPartitions = 4;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
config.set_debug_options(options);
config.set_num_partitions(kNumPartitions);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_text, config));
TF_ASSERT_OK_AND_ASSIGN(auto executable, hlo_runner_->CreateExecutable(
std::move(module),
/*run_hlo_passes=*/true));
TF_ASSERT_OK_AND_ASSIGN(
const HloModule* const hlo_module,
hlo_runner_->HloModuleFromWrapped(executable.get()));
std::vector<HloInstruction*> gemm_ops =
FindInstructions(hlo_module, HloOpcode::kCustomCall);
for (HloInstruction* gemm_op : gemm_ops) {
EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8");
}
}
// TODO(b/449655621) Use absl::AnyInvocable instead of std::function.
@@ -263,17 +222,65 @@ class CollectiveOpsTestE2E : public HloHardwareIndependentTestBase {
num_replicas, /*run_hlo_passes=*/false, &device_assignment);
}
bool IsAsync(const HloInstruction* inst) {
return !inst->backend_config<gpu::GpuBackendConfig>()
.value()
.collective_backend_config()
.is_sync();
}
protected:
std::unique_ptr<HloRunner> hlo_runner_;
std::unique_ptr<HloRunner> reference_hlo_runner_;
};
class CollectiveOpsTestE2E : public CollectiveOpsE2ETestBase {
public:
CollectiveOpsTestE2E() {
replacements_[kF8E4M3DatatypePlaceholder] =
Capability().IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
replacements_[kF8E5M2DatatypePlaceholder] =
Capability().IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
}
const se::GpuComputeCapability& Capability() {
return hlo_runner_->backend()
.default_stream_executor()
->GetDeviceDescription()
.gpu_compute_capability();
}
bool HasFp8Support() {
if (Capability().IsCuda()) {
return Capability().cuda_compute_capability()->IsAtLeast(8, 9);
}
return Capability().rocm_compute_capability()->has_fp8_support() &&
GetDebugOptionsForTest().xla_gpu_enable_cublaslt();
}
void CollectiveOpsVerifyF8Matmul(absl::string_view hlo_text,
const DebugOptions& options) {
if (!HasFp8Support()) {
return;
}
const int64_t kNumReplicas = 1;
const int64_t kNumPartitions = 4;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
config.set_debug_options(options);
config.set_num_partitions(kNumPartitions);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_text, config));
TF_ASSERT_OK_AND_ASSIGN(auto executable, hlo_runner_->CreateExecutable(
std::move(module),
/*run_hlo_passes=*/true));
TF_ASSERT_OK_AND_ASSIGN(
const HloModule* const hlo_module,
hlo_runner_->HloModuleFromWrapped(executable.get()));
std::vector<HloInstruction*> gemm_ops =
FindInstructions(hlo_module, HloOpcode::kCustomCall);
for (HloInstruction* gemm_op : gemm_ops) {
EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8");
}
}
protected:
absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;
std::unique_ptr<HloRunner> hlo_runner_;
std::unique_ptr<HloRunner> reference_hlo_runner_;
private:
static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
@@ -287,7 +294,7 @@ class CollectiveOpsTestE2E : public HloHardwareIndependentTestBase {
// E2E test for collectives with flags set. Has constructor arguments specifying
// whether to enable/disable async collectives, and to set the memcpy_local_p2p
// flag. Subclasses pass in constructor arguments based on GetParam().
class CollectiveOpsWithFlagsBase : public CollectiveOpsTestE2E {
class CollectiveOpsWithFlagsBase : public CollectiveOpsE2ETestBase {
public:
CollectiveOpsWithFlagsBase(bool enable_async, bool enable_p2p_memcpy)
: enable_async_(enable_async), enable_p2p_memcpy_(enable_p2p_memcpy) {