[XLA:GPU] Introduce CollectiveOpsE2ETestBase for common functionality.
We need a proper base class for common functionality. `CollectiveOpsTestE2E` is not a good base class, because it also holds a lot of fp8-specific helpers that are only used in a few tests.

PiperOrigin-RevId: 826017075
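In outline, the hierarchy after this change looks as follows (a sketch distilled from the diff below; includes, member lists, and bodies are elided, so this is not a compilable excerpt of the test file):

// Slim shared base: owns the test and reference HloRunner instances.
class CollectiveOpsE2ETestBase : public HloHardwareIndependentTestBase {
 public:
  CollectiveOpsE2ETestBase();  // Sets up hlo_runner_ and reference_hlo_runner_.

 protected:
  std::unique_ptr<HloRunner> hlo_runner_;
  std::unique_ptr<HloRunner> reference_hlo_runner_;
};

// Keeps the fp8-specific helpers that only a few tests actually use.
class CollectiveOpsTestE2E : public CollectiveOpsE2ETestBase {
 public:
  CollectiveOpsTestE2E();  // Installs per-platform fp8 datatype replacements.
  const se::GpuComputeCapability& Capability();
  bool HasFp8Support();
  void CollectiveOpsVerifyF8Matmul(absl::string_view hlo_text,
                                   const DebugOptions& options);

 protected:
  absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;
};

// Flag-parameterized fixtures need only the runners, so they now derive
// from the slim base rather than the fp8-heavy subclass.
class CollectiveOpsWithFlagsBase : public CollectiveOpsE2ETestBase { /* ... */ };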
parent 689bf5ef28
commit fd71e8be05
third_party/xla/xla/tests/collective_ops_e2e_test.cc (vendored, 123 lines changed)
@@ -112,9 +112,16 @@ Type CheckStatus(absl::StatusOr<Type> result) {
   return *result;
 }
 
-class CollectiveOpsTestE2E : public HloHardwareIndependentTestBase {
+bool IsAsync(const HloInstruction* inst) {
+  return !inst->backend_config<gpu::GpuBackendConfig>()
+              .value()
+              .collective_backend_config()
+              .is_sync();
+}
+
+class CollectiveOpsE2ETestBase : public HloHardwareIndependentTestBase {
  public:
-  CollectiveOpsTestE2E() {
+  CollectiveOpsE2ETestBase() {
     se::Platform* platform = CheckStatus(PlatformUtil::GetPlatform("GPU"));
     se::Platform* reference_platform =
         CheckStatus(PlatformUtil::GetPlatform("GPU"));
@@ -144,54 +151,6 @@ class CollectiveOpsTestE2E : public HloHardwareIndependentTestBase {
         std::move(allocators)));
     reference_hlo_runner_ = std::make_unique<HloRunner>(
         reference_platform, /*intra_op_parallelism_threads=*/0);
-
-    replacements_[kF8E4M3DatatypePlaceholder] =
-        Capability().IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
-    replacements_[kF8E5M2DatatypePlaceholder] =
-        Capability().IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
   }
-
-  const se::GpuComputeCapability& Capability() {
-    return hlo_runner_->backend()
-        .default_stream_executor()
-        ->GetDeviceDescription()
-        .gpu_compute_capability();
-  }
-
-  bool HasFp8Support() {
-    if (Capability().IsCuda()) {
-      return Capability().cuda_compute_capability()->IsAtLeast(8, 9);
-    }
-    return Capability().rocm_compute_capability()->has_fp8_support() &&
-           GetDebugOptionsForTest().xla_gpu_enable_cublaslt();
-  }
-
-  void CollectiveOpsVerifyF8Matmul(absl::string_view hlo_text,
-                                   const DebugOptions& options) {
-    if (!HasFp8Support()) {
-      return;
-    }
-    const int64_t kNumReplicas = 1;
-    const int64_t kNumPartitions = 4;
-
-    HloModuleConfig config =
-        GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
-    config.set_debug_options(options);
-    config.set_num_partitions(kNumPartitions);
-    TF_ASSERT_OK_AND_ASSIGN(auto module,
-                            ParseAndReturnVerifiedModule(hlo_text, config));
-
-    TF_ASSERT_OK_AND_ASSIGN(auto executable, hlo_runner_->CreateExecutable(
-                                                 std::move(module),
-                                                 /*run_hlo_passes=*/true));
-    TF_ASSERT_OK_AND_ASSIGN(
-        const HloModule* const hlo_module,
-        hlo_runner_->HloModuleFromWrapped(executable.get()));
-    std::vector<HloInstruction*> gemm_ops =
-        FindInstructions(hlo_module, HloOpcode::kCustomCall);
-    for (HloInstruction* gemm_op : gemm_ops) {
-      EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8");
-    }
-  }
 
   // TODO(b/449655621) Use absl::AnyInvocable instead of std::function.
@@ -263,17 +222,65 @@ class CollectiveOpsTestE2E : public HloHardwareIndependentTestBase {
         num_replicas, /*run_hlo_passes=*/false, &device_assignment);
   }
 
-  bool IsAsync(const HloInstruction* inst) {
-    return !inst->backend_config<gpu::GpuBackendConfig>()
-                .value()
-                .collective_backend_config()
-                .is_sync();
-  }
-
+ protected:
+  std::unique_ptr<HloRunner> hlo_runner_;
+  std::unique_ptr<HloRunner> reference_hlo_runner_;
+};
+
+class CollectiveOpsTestE2E : public CollectiveOpsE2ETestBase {
+ public:
+  CollectiveOpsTestE2E() {
+    replacements_[kF8E4M3DatatypePlaceholder] =
+        Capability().IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
+    replacements_[kF8E5M2DatatypePlaceholder] =
+        Capability().IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
+  }
+
+  const se::GpuComputeCapability& Capability() {
+    return hlo_runner_->backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .gpu_compute_capability();
+  }
+
+  bool HasFp8Support() {
+    if (Capability().IsCuda()) {
+      return Capability().cuda_compute_capability()->IsAtLeast(8, 9);
+    }
+    return Capability().rocm_compute_capability()->has_fp8_support() &&
+           GetDebugOptionsForTest().xla_gpu_enable_cublaslt();
+  }
+
+  void CollectiveOpsVerifyF8Matmul(absl::string_view hlo_text,
+                                   const DebugOptions& options) {
+    if (!HasFp8Support()) {
+      return;
+    }
+    const int64_t kNumReplicas = 1;
+    const int64_t kNumPartitions = 4;
+
+    HloModuleConfig config =
+        GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
+    config.set_debug_options(options);
+    config.set_num_partitions(kNumPartitions);
+    TF_ASSERT_OK_AND_ASSIGN(auto module,
+                            ParseAndReturnVerifiedModule(hlo_text, config));
+
+    TF_ASSERT_OK_AND_ASSIGN(auto executable, hlo_runner_->CreateExecutable(
+                                                 std::move(module),
+                                                 /*run_hlo_passes=*/true));
+    TF_ASSERT_OK_AND_ASSIGN(
+        const HloModule* const hlo_module,
+        hlo_runner_->HloModuleFromWrapped(executable.get()));
+    std::vector<HloInstruction*> gemm_ops =
+        FindInstructions(hlo_module, HloOpcode::kCustomCall);
+    for (HloInstruction* gemm_op : gemm_ops) {
+      EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8");
+    }
+  }
+
  protected:
   absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;
-  std::unique_ptr<HloRunner> hlo_runner_;
-  std::unique_ptr<HloRunner> reference_hlo_runner_;
 
  private:
   static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
@@ -287,7 +294,7 @@ class CollectiveOpsTestE2E : public HloHardwareIndependentTestBase {
 // E2E test for collectives with flags set. Has constructor arguments specifying
 // whether to enable/disable async collectives, and to set the memcpy_local_p2p
 // flag. Subclasses pass in constructor arguments based on GetParam().
-class CollectiveOpsWithFlagsBase : public CollectiveOpsTestE2E {
+class CollectiveOpsWithFlagsBase : public CollectiveOpsE2ETestBase {
  public:
   CollectiveOpsWithFlagsBase(bool enable_async, bool enable_p2p_memcpy)
       : enable_async_(enable_async), enable_p2p_memcpy_(enable_p2p_memcpy) {