[XLA:GPU] Introduce CollectiveOpsE2ETestBase for common functionality.

We need a proper base class for the functionality shared by all collective-ops
E2E tests. `CollectiveOpsTestE2E` is not a good base class because it also
holds a number of fp8-specific helpers that only a few tests use.

PiperOrigin-RevId: 826017075
commit fd71e8be05 (parent 689bf5ef28)

third_party/xla/xla/tests/collective_ops_e2e_test.cc (vendored): 123 changed lines
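The split described above leaves the fixtures arranged roughly as sketched below. This is an illustration of the hierarchy implied by the diff, not part of the commit; `MyAllReduceTest` and `MyFp8GemmTest` are hypothetical names showing which base a new test would pick.

// Sketch of the fixture hierarchy after this commit (illustrative only):
//
//   HloHardwareIndependentTestBase
//     └── CollectiveOpsE2ETestBase         // runners + platform setup
//           ├── CollectiveOpsTestE2E       // adds the fp8-specific helpers
//           └── CollectiveOpsWithFlagsBase // adds async/p2p-memcpy flags

// A test that only needs the shared runners derives from the new base:
class MyAllReduceTest : public CollectiveOpsE2ETestBase {};

// A test exercising fp8 GEMM rewrites keeps the fp8-aware fixture:
class MyFp8GemmTest : public CollectiveOpsTestE2E {};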
--- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc
+++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc
@@ -112,9 +112,16 @@ Type CheckStatus(absl::StatusOr<Type> result) {
   return *result;
 }
 
-class CollectiveOpsTestE2E : public HloHardwareIndependentTestBase {
+bool IsAsync(const HloInstruction* inst) {
+  return !inst->backend_config<gpu::GpuBackendConfig>()
+              .value()
+              .collective_backend_config()
+              .is_sync();
+}
+
+class CollectiveOpsE2ETestBase : public HloHardwareIndependentTestBase {
  public:
-  CollectiveOpsTestE2E() {
+  CollectiveOpsE2ETestBase() {
     se::Platform* platform = CheckStatus(PlatformUtil::GetPlatform("GPU"));
     se::Platform* reference_platform =
         CheckStatus(PlatformUtil::GetPlatform("GPU"));
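With `IsAsync` promoted to a free function above, a check like the one below no longer needs to live inside a fixture. A minimal sketch, assuming an optimized `HloModule* module` obtained elsewhere and the test-base `FindInstruction` helper; the instruction name `all-reduce-start` is illustrative:

// Sketch: assert that the compiled all-reduce runs asynchronously.
HloInstruction* start = FindInstruction(module, "all-reduce-start");
ASSERT_NE(start, nullptr);
EXPECT_TRUE(IsAsync(start));  // reads collective_backend_config().is_sync()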
@@ -144,54 +151,6 @@ class CollectiveOpsTestE2E : public HloHardwareIndependentTestBase {
         std::move(allocators)));
     reference_hlo_runner_ = std::make_unique<HloRunner>(
         reference_platform, /*intra_op_parallelism_threads=*/0);
-
-    replacements_[kF8E4M3DatatypePlaceholder] =
-        Capability().IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
-    replacements_[kF8E5M2DatatypePlaceholder] =
-        Capability().IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
-  }
-
-  const se::GpuComputeCapability& Capability() {
-    return hlo_runner_->backend()
-        .default_stream_executor()
-        ->GetDeviceDescription()
-        .gpu_compute_capability();
-  }
-
-  bool HasFp8Support() {
-    if (Capability().IsCuda()) {
-      return Capability().cuda_compute_capability()->IsAtLeast(8, 9);
-    }
-    return Capability().rocm_compute_capability()->has_fp8_support() &&
-           GetDebugOptionsForTest().xla_gpu_enable_cublaslt();
-  }
-
-  void CollectiveOpsVerifyF8Matmul(absl::string_view hlo_text,
-                                   const DebugOptions& options) {
-    if (!HasFp8Support()) {
-      return;
-    }
-    const int64_t kNumReplicas = 1;
-    const int64_t kNumPartitions = 4;
-
-    HloModuleConfig config =
-        GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
-    config.set_debug_options(options);
-    config.set_num_partitions(kNumPartitions);
-    TF_ASSERT_OK_AND_ASSIGN(auto module,
-                            ParseAndReturnVerifiedModule(hlo_text, config));
-
-    TF_ASSERT_OK_AND_ASSIGN(auto executable, hlo_runner_->CreateExecutable(
-                                                 std::move(module),
-                                                 /*run_hlo_passes=*/true));
-    TF_ASSERT_OK_AND_ASSIGN(
-        const HloModule* const hlo_module,
-        hlo_runner_->HloModuleFromWrapped(executable.get()));
-    std::vector<HloInstruction*> gemm_ops =
-        FindInstructions(hlo_module, HloOpcode::kCustomCall);
-    for (HloInstruction* gemm_op : gemm_ops) {
-      EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8");
-    }
   }
 
   // TODO(b/449655621) Use absl::AnyInvocable instead of std::function.
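These helpers are removed here and re-added verbatim in the next hunk; only their home changes. For orientation, the CUDA branch of `HasFp8Support` gates on compute capability 8.9 (Ada generation) or newer. A standalone sketch of the same predicate, with hypothetical `major`/`minor` inputs:

// Sketch of the fp8 gate used above: cc >= 8.9 on CUDA (Ada/Hopper and newer).
bool HasCudaFp8Support(int major, int minor) {
  return major > 8 || (major == 8 && minor >= 9);
}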
@@ -263,17 +222,65 @@ class CollectiveOpsTestE2E : public HloHardwareIndependentTestBase {
         num_replicas, /*run_hlo_passes=*/false, &device_assignment);
   }
 
-  bool IsAsync(const HloInstruction* inst) {
-    return !inst->backend_config<gpu::GpuBackendConfig>()
-                .value()
-                .collective_backend_config()
-                .is_sync();
+ protected:
+  std::unique_ptr<HloRunner> hlo_runner_;
+  std::unique_ptr<HloRunner> reference_hlo_runner_;
+};
+
+class CollectiveOpsTestE2E : public CollectiveOpsE2ETestBase {
+ public:
+  CollectiveOpsTestE2E() {
+    replacements_[kF8E4M3DatatypePlaceholder] =
+        Capability().IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
+    replacements_[kF8E5M2DatatypePlaceholder] =
+        Capability().IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
+  }
+
+  const se::GpuComputeCapability& Capability() {
+    return hlo_runner_->backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .gpu_compute_capability();
+  }
+
+  bool HasFp8Support() {
+    if (Capability().IsCuda()) {
+      return Capability().cuda_compute_capability()->IsAtLeast(8, 9);
+    }
+    return Capability().rocm_compute_capability()->has_fp8_support() &&
+           GetDebugOptionsForTest().xla_gpu_enable_cublaslt();
+  }
+
+  void CollectiveOpsVerifyF8Matmul(absl::string_view hlo_text,
+                                   const DebugOptions& options) {
+    if (!HasFp8Support()) {
+      return;
+    }
+    const int64_t kNumReplicas = 1;
+    const int64_t kNumPartitions = 4;
+
+    HloModuleConfig config =
+        GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
+    config.set_debug_options(options);
+    config.set_num_partitions(kNumPartitions);
+    TF_ASSERT_OK_AND_ASSIGN(auto module,
+                            ParseAndReturnVerifiedModule(hlo_text, config));
+
+    TF_ASSERT_OK_AND_ASSIGN(auto executable, hlo_runner_->CreateExecutable(
+                                                 std::move(module),
+                                                 /*run_hlo_passes=*/true));
+    TF_ASSERT_OK_AND_ASSIGN(
+        const HloModule* const hlo_module,
+        hlo_runner_->HloModuleFromWrapped(executable.get()));
+    std::vector<HloInstruction*> gemm_ops =
+        FindInstructions(hlo_module, HloOpcode::kCustomCall);
+    for (HloInstruction* gemm_op : gemm_ops) {
+      EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8");
+    }
   }
 
  protected:
   absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;
-  std::unique_ptr<HloRunner> hlo_runner_;
-  std::unique_ptr<HloRunner> reference_hlo_runner_;
 
  private:
   static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
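The fp8 helpers keep their signatures across the move, so existing call sites compile unchanged. A minimal sketch of a caller from a fixture deriving `CollectiveOpsTestE2E`; the test name, the `kSomeF8ModuleText` HLO string, and the flag choice are illustrative, not from this commit:

TEST_F(CollectiveOpsTestE2E, VerifyF8MatmulExample) {  // name is illustrative
  // Hypothetical module text; callers would expand the <<F8E4M3>> placeholder
  // through replacements_ (e.g. with absl::StrReplaceAll) before use.
  std::string hlo = absl::StrReplaceAll(kSomeF8ModuleText, replacements_);
  DebugOptions options = GetDebugOptionsForTest();
  options.set_xla_gpu_enable_cublaslt(true);  // illustrative flag choice
  CollectiveOpsVerifyF8Matmul(hlo, options);  // no-op without fp8 hardware
}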
@@ -287,7 +294,7 @@ class CollectiveOpsTestE2E : public HloHardwareIndependentTestBase {
 // E2E test for collectives with flags set. Has constructor arguments specifying
 // whether to enable/disable async collectives, and to set the memcpy_local_p2p
 // flag. Subclasses pass in constructor arguments based on GetParam().
-class CollectiveOpsWithFlagsBase : public CollectiveOpsTestE2E {
+class CollectiveOpsWithFlagsBase : public CollectiveOpsE2ETestBase {
  public:
   CollectiveOpsWithFlagsBase(bool enable_async, bool enable_p2p_memcpy)
       : enable_async_(enable_async), enable_p2p_memcpy_(enable_p2p_memcpy) {
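As the comment in this hunk says, subclasses feed GetParam() into the two constructor flags (googletest makes the parameter available in the fixture constructor). A sketch of such a subclass; the class name and parameter shape are illustrative:

// Sketch: a value-parameterized fixture toggling both flags.
class CollectiveOpsWithFlags
    : public CollectiveOpsWithFlagsBase,
      public ::testing::WithParamInterface<std::tuple<bool, bool>> {
 public:
  CollectiveOpsWithFlags()
      : CollectiveOpsWithFlagsBase(
            /*enable_async=*/std::get<0>(GetParam()),
            /*enable_p2p_memcpy=*/std::get<1>(GetParam())) {}
};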