Port to new GpuComputeCapability API. Last part

PiperOrigin-RevId: 822676102
Maxim Ermilov, 2025-10-22 11:41:58 -07:00, committed by TensorFlower Gardener
parent 3503a61282
commit 1b08f96abf
12 changed files with 90 additions and 92 deletions
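
For context, the pattern applied throughout this change: GpuComputeCapability used to be a bare std::variant<CudaComputeCapability, RocmComputeCapability> queried with std::holds_alternative/std::get_if/std::get; the new class hides the variant behind IsCuda()/IsRocm() predicates and cuda_compute_capability()/rocm_compute_capability() accessors that return a pointer, null when the device is the other vendor. A minimal sketch of the new style, modeled on the AutotuneCacheKey hunk below (DescribeCapability is a hypothetical helper, not part of this commit):

#include <string>

#include "absl/strings/str_cat.h"
#include "xla/stream_executor/device_description.h"

namespace se = stream_executor;

// Hypothetical helper: dispatches on the new accessors instead of
// calling std::get_if on the raw variant.
std::string DescribeCapability(const se::DeviceDescription& device_description) {
  const se::GpuComputeCapability& gpu_cc =
      device_description.gpu_compute_capability();
  if (const se::CudaComputeCapability* cc = gpu_cc.cuda_compute_capability()) {
    // Non-null only on CUDA devices; replaces
    // std::get_if<se::CudaComputeCapability>(&gpu_cc).
    return absl::StrCat("CUDA ", cc->major, ".", cc->minor);
  }
  if (const se::RocmComputeCapability* rcc = gpu_cc.rocm_compute_capability()) {
    return absl::StrCat("ROCM ", rcc->gfx_version());
  }
  return "unknown";
}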

View File

@@ -85,8 +85,8 @@ se::StreamExecutor* GpuExecutor() {
 bool IsAtLeastCuda12900(const se::StreamExecutor* stream_executor) {
   const auto& device_description = stream_executor->GetDeviceDescription();
-  const auto* cuda_cc = std::get_if<se::CudaComputeCapability>(
-      &device_description.gpu_compute_capability());
+  const auto* cuda_cc =
+      device_description.gpu_compute_capability().cuda_compute_capability();
   if (cuda_cc != nullptr) {
     if (device_description.driver_version() >=
             stream_executor::SemanticVersion(12, 9, 0) &&

View File

@@ -55,12 +55,12 @@ std::string AutotuneCacheKey::HloInstructionToCanonicalString(
 std::string AutotuneCacheKey::DeviceDescriptionToCacheKey(
     const se::DeviceDescription& device_description) {
   std::string compute_capability;
-  if (auto* ccc = std::get_if<se::CudaComputeCapability>(
-          &device_description.gpu_compute_capability())) {
+  if (auto* ccc = device_description.gpu_compute_capability()
+                      .cuda_compute_capability()) {
     compute_capability = absl::StrCat("CUDA: ", ccc->major, ".", ccc->minor);
   } else {
-    auto* rcc = std::get_if<se::RocmComputeCapability>(
-        &device_description.gpu_compute_capability());
+    auto* rcc =
+        device_description.gpu_compute_capability().rocm_compute_capability();
     CHECK(rcc != nullptr) << "Unknown compute capability type";
     compute_capability = absl::StrCat("ROCM: ", rcc->gfx_version());
   }

View File

@@ -130,9 +130,8 @@ ENTRY main {
   // Algorithm 14 is disabled for cuDNN 9 on V100
   TF_ASSERT_OK_AND_ASSIGN(auto dnn_version, GetDnnVersionInfo(stream_exec));
   if (dnn_version.major_version() >= 9 && dnn_version.major_version() < 10 &&
-      std::holds_alternative<stream_executor::CudaComputeCapability>(cc) &&
-      std::get<stream_executor::CudaComputeCapability>(cc).major == 7 &&
-      std::get<stream_executor::CudaComputeCapability>(cc).minor == 0) {
+      cc.IsCuda() && cc.cuda_compute_capability()->major == 7 &&
+      cc.cuda_compute_capability()->minor == 0) {
     EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
                     ->has_cudnn_conv_backend_config() &&
                 conv->backend_config<GpuBackendConfig>()

View File

@@ -201,11 +201,6 @@ class GemmFusionAutotunerImpl {
     return config_.GetGpuComputeCapability();
   }
-  bool isRocm() const {
-    return std::holds_alternative<se::RocmComputeCapability>(
-        GetComputeCapability());
-  }
   bool AddLibConfigs(const HloFusionInstruction& fusion,
                      const HloDotInstruction* dot,
                      std::vector<BackendConfig>& configs);

View File

@@ -52,7 +52,8 @@ bool GemmFusionAutotunerImpl::AddLibConfigs(
     const HloFusionInstruction& fusion, const HloDotInstruction* dot,
     std::vector<BackendConfig>& configs) {
   // Add cuDNN plans, if available.
-  auto cc = std::get<se::CudaComputeCapability>(GetComputeCapability());
+  stream_executor::CudaComputeCapability cc =
+      *GetComputeCapability().cuda_compute_capability();
   bool is_cudnn_enabled =
       !config_.IsDeviceless() &&
       GetDnnVersionInfoOrDefault(config_.GetExecutor()).major_version() >= 9 &&
@@ -81,8 +82,8 @@ bool GemmFusionAutotunerImpl::AddLibConfigs(
 std::vector<TritonGemmConfig> GemmFusionAutotunerImpl::GetDefaultTritonConfigs()
     const {
-  auto compute_capability =
-      std::get<se::CudaComputeCapability>(GetComputeCapability());
+  stream_executor::CudaComputeCapability compute_capability =
+      *GetComputeCapability().cuda_compute_capability();
   std::vector<TritonGemmConfig> configs;
   if (compute_capability.IsAtLeastBlackwell()) {

View File

@@ -197,13 +197,13 @@ class StatelessAutotunerTest : public HloTestBase {
       SymbolicExprContext* symbolic_expr_context) {
     const HloFusionInstruction& fusion = *Cast<HloFusionInstruction>(
         module.entry_computation()->root_instruction());
-    if (!isRocm()) {
-      auto cu_compute_capability =
-          std::get<se::CudaComputeCapability>(compute_capability);
+    if (GpuComputeComp().IsCuda()) {
+      auto* cu_compute_capability =
+          compute_capability.cuda_compute_capability();
       se::GpuDeviceInfoProto deviceless_proto;
       auto ccc = deviceless_proto.mutable_cuda_compute_capability();
-      ccc->set_major(cu_compute_capability.major);
-      ccc->set_minor(cu_compute_capability.minor);
+      ccc->set_major(cu_compute_capability->major);
+      ccc->set_minor(cu_compute_capability->minor);
     }
     DeviceConfig test_config{backend().default_stream_executor(),
@@ -237,10 +237,6 @@ class StatelessAutotunerTest : public HloTestBase {
         .gpu_compute_capability();
   }
-  bool isRocm() {
-    return std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp());
-  }
   // Returns the config for the current device.
   absl::StatusOr<std::vector<GemmFusionAutotunerImpl::BackendConfig>>
   GetPossibleMatmulAutotuneConfigs(const HloModule& module) {
@@ -321,7 +317,7 @@ TEST_F(StatelessAutotunerTest, CublasFallbackForBf16Bf16F32Algorithm) {
   TF_ASSERT_OK_AND_ASSIGN(auto configs,
                           GetPossibleMatmulAutotuneConfigs(*module));
-  if (!isRocm()) {
+  if (!GpuComputeComp().IsRocm()) {
     switch (GetCudaComputeCapability().major) {
       case se::CudaComputeCapability::kAmpere:
         EXPECT_TRUE(hasCublasConfig(configs))
@@ -361,8 +357,8 @@ class GemmFusionAutotunerTest : public StatelessAutotunerTest {
   }
   stream_executor::GpuComputeCapability CudaAmpereOrRocm() {
-    if (isRocm()) {
-      return GetRocmComputeCapability();
+    if (GpuComputeComp().IsRocm()) {
+      return GpuComputeComp();
     } else {
       return stream_executor::GpuComputeCapability{
           stream_executor::CudaComputeCapability{
@@ -427,7 +423,8 @@ GetPossibleMatmulAutotuneTritonConfigs(
   TF_ASSIGN_OR_RETURN(se::DeviceDescription device_description,
                       se::DeviceDescription::FromProto(
                           se::GpuDeviceInfoProto::default_instance()));
-  device_description.set_gpu_compute_capability(compute_capability);
+  device_description.set_gpu_compute_capability(
+      se::GpuComputeCapability{compute_capability});
   // Using H100 numbers as the most relevant example here.
   // https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications-technical-specifications-per-compute-capability
   // https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/#nvidia_h100_gpu_architecture_in-depth
@@ -446,7 +443,7 @@ GetPossibleMatmulAutotuneTritonConfigs(
 }
 TEST_F(GemmFusionAutotunerTest, AmpereUsesMoreThanTwoStages) {
-  if (isRocm()) {
+  if (GpuComputeComp().IsRocm()) {
     GTEST_SKIP() << "Not supported on ROCm.";
   }
   std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
@@ -625,7 +622,7 @@ ENTRY e {
 // TODO(b/344770374): Make this test not fragile.
 TEST_F(GemmFusionAutotunerTest, DoNotRunAutotuningKernelSpillingRegisters) {
-  if (isRocm()) {
+  if (GpuComputeComp().IsRocm()) {
     GTEST_SKIP() << "Not supported on ROCm.";
   }
   const std::string kHloText = R"(
@@ -808,11 +805,10 @@ ENTRY main {
   TF_EXPECT_OK(HloTestBase::RunHloPass(&pipeline, module.get()));
   const bool is_at_least_hopper =
-      std::holds_alternative<se::CudaComputeCapability>(
-          autotune_config.GetGpuComputeCapability()) &&
-      std::get<se::CudaComputeCapability>(
-          autotune_config.GetGpuComputeCapability())
-          .IsAtLeastHopper();
+      autotune_config.GetGpuComputeCapability().IsCuda() &&
+      autotune_config.GetGpuComputeCapability()
+          .cuda_compute_capability()
+          ->IsAtLeastHopper();
   TF_ASSERT_OK_AND_ASSIGN(
       bool filecheck_matches,
       RunFileCheck(module->ToString(), is_at_least_hopper
@@ -822,8 +818,9 @@ ENTRY main {
 }
 TEST_F(GemmFusionAutotunerDumpTest, DumpingWorks) {
-  if (isRocm() || GetDebugOptionsForTest()
-                      .xla_gpu_experimental_disable_binary_libraries()) {
+  if (GpuComputeComp().IsRocm() ||
+      GetDebugOptionsForTest()
+          .xla_gpu_experimental_disable_binary_libraries()) {
     GTEST_SKIP() << "Not supported on ROCm or with binary libraries disabled.";
   }
   HloModuleConfig config;
@@ -891,8 +888,9 @@ CHECK: cublas
 }
 TEST_F(GemmFusionAutotunerTest, AutotuneCuDnnFusion) {
-  if (isRocm() || GetDebugOptionsForTest()
-                      .xla_gpu_experimental_disable_binary_libraries()) {
+  if (GpuComputeComp().IsRocm() ||
+      GetDebugOptionsForTest()
+          .xla_gpu_experimental_disable_binary_libraries()) {
     GTEST_SKIP() << "Not supported on ROCm or with binary libraries disabled.";
   }
   const std::string kHlo = R"(
@@ -1202,7 +1200,7 @@ ENTRY entry {
 }
 TEST_F(GemmFusionAutotunerTest, CreatesCustomKernelFusionConfigs) {
-  if (isRocm()) {
+  if (GpuComputeComp().IsRocm()) {
     GTEST_SKIP() << "Not supported on ROCm.";
   }
   const std::string kHlo = R"(
@@ -1241,7 +1239,7 @@ TEST_F(GemmFusionAutotunerTest, CreatesCustomKernelFusionConfigs) {
 }
 TEST_F(GemmFusionAutotunerTest, GeneratesTwoConfigsForUpcastGemmWithPrologue) {
-  if (isRocm()) {
+  if (GpuComputeComp().IsRocm()) {
     GTEST_SKIP() << "Not supported on ROCm.";
   }
   const std::string kHlo = R"(
@@ -1273,8 +1271,9 @@ TEST_F(GemmFusionAutotunerTest, GeneratesTwoConfigsForUpcastGemmWithPrologue) {
   TF_ASSERT_OK_AND_ASSIGN(
       const std::vector<GemmFusionAutotunerImpl::BackendConfig> configs,
       GetPossibleMatmulAutotuneConfigs(
-          *module, compute_capability, GetToolkitVersion(),
-          GetDebugOptionsForTest(), &symbolic_expr_context_));
+          *module, se::GpuComputeCapability{compute_capability},
+          GetToolkitVersion(), GetDebugOptionsForTest(),
+          &symbolic_expr_context_));
   EXPECT_EQ(
       2, std::count_if(
              configs.begin(), configs.end(),
@@ -1287,7 +1286,7 @@ TEST_F(GemmFusionAutotunerTest, GeneratesTwoConfigsForUpcastGemmWithPrologue) {
 TEST_F(GemmFusionAutotunerTest, GeneratesOneConfigForUpcastGemmWithPrologue) {
   // Same as GeneratesTwoConfigsForUpcastGemmWithPrologue, but with contracting
   // dimension size = 128 which is not supported by the SplitK kernel.
-  if (isRocm()) {
+  if (GpuComputeComp().IsRocm()) {
     GTEST_SKIP() << "Not supported on ROCm.";
   }
   const std::string kHlo = R"(
@@ -1319,8 +1318,9 @@ TEST_F(GemmFusionAutotunerTest, GeneratesOneConfigForUpcastGemmWithPrologue) {
   TF_ASSERT_OK_AND_ASSIGN(
       const std::vector<GemmFusionAutotunerImpl::BackendConfig> configs,
       GetPossibleMatmulAutotuneConfigs(
-          *module, compute_capability, GetToolkitVersion(),
-          GetDebugOptionsForTest(), &symbolic_expr_context_));
+          *module, se::GpuComputeCapability{compute_capability},
+          GetToolkitVersion(), GetDebugOptionsForTest(),
+          &symbolic_expr_context_));
   EXPECT_EQ(
       1, std::count_if(
             configs.begin(), configs.end(),
@@ -1332,7 +1332,7 @@ TEST_F(GemmFusionAutotunerTest, GeneratesOneConfigForUpcastGemmWithPrologue) {
 TEST_F(GemmFusionAutotunerTest,
        GeneratesConfigForUpcastGemmWithPrologueAndEpilogue) {
-  if (isRocm()) {
+  if (GpuComputeComp().IsRocm()) {
     GTEST_SKIP() << "Not supported on ROCm.";
   }
   const std::string kHlo = R"(
@@ -1368,8 +1368,9 @@ TEST_F(GemmFusionAutotunerTest,
   TF_ASSERT_OK_AND_ASSIGN(
       const std::vector<GemmFusionAutotunerImpl::BackendConfig> configs,
       GetPossibleMatmulAutotuneConfigs(
-          *module, compute_capability, GetToolkitVersion(),
-          GetDebugOptionsForTest(), &symbolic_expr_context_));
+          *module, se::GpuComputeCapability{compute_capability},
+          GetToolkitVersion(), GetDebugOptionsForTest(),
+          &symbolic_expr_context_));
   EXPECT_EQ(
       2, std::count_if(
             configs.begin(), configs.end(),
@@ -1486,7 +1487,7 @@ class GemmFusionShardedAutotunerTest : public GemmFusionAutotunerTest {
 TEST_F(
     GemmFusionShardedAutotunerTest,
     AutotuningSucceedsWhenKeyValueStoreAlreadyContainsAutotuningResultsForTheInputModule) {  // NOLINT(whitespace/line_length)
-  if (isRocm()) {
+  if (GpuComputeComp().IsRocm()) {
     GTEST_SKIP() << "Not supported on ROCm.";
   }
   const std::string kHlo = R"(
@@ -1545,7 +1546,7 @@ TEST_F(
 TEST_F(
     GemmFusionShardedAutotunerTest,
     AutotuningStoresDifferentResultsForTheSameFusionInDifferentModules) {  // NOLINT(whitespace/line_length)
-  if (isRocm()) {
+  if (GpuComputeComp().IsRocm()) {
     GTEST_SKIP() << "Not supported on ROCm.";
   }
   const std::string kHlo1 = R"(
@@ -1626,7 +1627,7 @@ TEST_F(
 }
 TEST_F(GemmFusionAutotunerTest, RewritesGemmFusionToCustomKernelFusion) {
-  if (isRocm()) {
+  if (GpuComputeComp().IsRocm()) {
     GTEST_SKIP() << "Not supported on ROCm.";
   }
   const std::string kHlo = R"(
@@ -1711,7 +1712,7 @@ ENTRY e {
 }
 TEST_F(GemmFusionAutotunerTest, VerifyHopperConfigsAreDifferentFromBlackwell) {
-  if (isRocm()) {
+  if (GpuComputeComp().IsRocm()) {
     GTEST_SKIP() << "Not supported on ROCm.";
   }
@@ -1752,7 +1753,7 @@ TEST_F(GemmFusionAutotunerTest, VerifyHopperConfigsAreDifferentFromBlackwell) {
 }
 TEST_F(GemmFusionAutotunerTest, ScaledDotConfigsAreGenerated) {
-  if (isRocm()) {
+  if (GpuComputeComp().IsRocm()) {
     GTEST_SKIP() << "Not supported on ROCm.";
   }
@@ -1781,7 +1782,7 @@ TEST_F(GemmFusionAutotunerTest, ScaledDotConfigsAreGenerated) {
 }
 TEST_F(GemmFusionAutotunerTest, ScaledDotConfigsHaveCuBlasFallback) {
-  if (isRocm()) {
+  if (GpuComputeComp().IsRocm()) {
     GTEST_SKIP() << "Not supported on ROCm.";
   }
@@ -1829,7 +1830,7 @@ class GemmFusionAutotunerEnableTma : public GemmFusionAutotunerTest {
 TEST_F(GemmFusionAutotunerEnableTma,
        TmaConfigsAreGeneratedOnlyForHopperAndWorkCorrectly) {
-  if (isRocm()) {
+  if (GpuComputeComp().IsRocm()) {
     GTEST_SKIP() << "Not supported on ROCm.";
   }
@@ -1885,7 +1886,7 @@ TEST_F(GemmFusionAutotunerEnableTma,
 TEST_F(GemmFusionAutotunerEnableTma,
        TmaConfigsGeneratedAndRunCorrectlyForDotsOfBroadcasts) {
-  if (isRocm()) {
+  if (GpuComputeComp().IsRocm()) {
     GTEST_SKIP() << "Not supported on ROCm.";
   }

View File

@@ -78,11 +78,11 @@ static const std::initializer_list<absl::string_view> kf16f32{"f16", "f32"};
 class CudnnFusedConvRewriterHloTest : public HloTestBase {
  public:
   bool IsCuda() const {
-    return std::holds_alternative<se::CudaComputeCapability>(
-        backend()
-            .default_stream_executor()
-            ->GetDeviceDescription()
-            .gpu_compute_capability());
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .gpu_compute_capability()
+        .IsCuda();
   }
   se::CudaComputeCapability GetCudaComputeCapability() const {
     return backend()
@@ -119,11 +119,11 @@ class CudnnFusedConvRewriterHloTest : public HloTestBase {
 class CudnnFusedConvRewriterTest : public GpuCodegenTest {
  public:
   bool IsCuda() const {
-    return std::holds_alternative<se::CudaComputeCapability>(
-        backend()
-            .default_stream_executor()
-            ->GetDeviceDescription()
-            .gpu_compute_capability());
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .gpu_compute_capability()
+        .IsCuda();
   }
   se::CudaComputeCapability GetCudaComputeCapability() const {
     return backend()

View File

@@ -646,7 +646,9 @@ ENTRY entry {
       hlo_module->entry_computation()->ComputeProgramShape());
   GpuLayoutAssignment layout_assignment(
-      &computation_layout, se::RocmComputeCapability::EarliestRDNASupport(),
+      &computation_layout,
+      se::GpuComputeCapability{
+          se::RocmComputeCapability::EarliestRDNASupport()},
       GetDnnVersion(), GetDeviceDescription());
   EXPECT_THAT(layout_assignment.Run(hlo_module.get()),
@@ -683,7 +685,9 @@ ENTRY entry {
       hlo_module->entry_computation()->ComputeProgramShape());
   GpuLayoutAssignment layout_assignment(
-      &computation_layout, se::RocmComputeCapability::EarliestRDNASupport(),
+      &computation_layout,
+      se::GpuComputeCapability{
+          se::RocmComputeCapability::EarliestRDNASupport()},
       GetDnnVersion(), GetDeviceDescription());
   EXPECT_THAT(layout_assignment.Run(hlo_module.get()),
@@ -723,7 +727,9 @@ ENTRY entry {
       hlo_module->entry_computation()->ComputeProgramShape());
   GpuLayoutAssignment layout_assignment(
-      &computation_layout, se::RocmComputeCapability::EarliestCDNASupport(),
+      &computation_layout,
+      se::GpuComputeCapability{
+          se::RocmComputeCapability::EarliestCDNASupport()},
       GetDnnVersion(), GetDeviceDescription());
   EXPECT_THAT(layout_assignment.Run(hlo_module.get()),
@@ -763,7 +769,9 @@ ENTRY entry {
      hlo_module->entry_computation()->ComputeProgramShape());
   GpuLayoutAssignment layout_assignment(
-      &computation_layout, se::RocmComputeCapability::EarliestCDNASupport(),
+      &computation_layout,
+      se::GpuComputeCapability{
+          se::RocmComputeCapability::EarliestCDNASupport()},
       GetDnnVersion(), GetDeviceDescription());
   EXPECT_THAT(layout_assignment.Run(hlo_module.get()),

View File

@@ -69,9 +69,8 @@ TEST(CudaExecutorTest, CreateDeviceDescription) {
   EXPECT_THAT(result->model_str(), Not(IsEmpty()));
   EXPECT_THAT(result->device_vendor(), "NVIDIA Corporation");
-  EXPECT_THAT(result->gpu_compute_capability(),
-              VariantWith<CudaComputeCapability>(::testing::Field(
-                  "major", &CudaComputeCapability::major, Ge(1))));
+  EXPECT_THAT(*result->gpu_compute_capability().cuda_compute_capability(),
+              ::testing::Field("major", &CudaComputeCapability::major, Ge(1)));
 }
 TEST(CudaExecutorTest, GetCudaKernel) {

View File

@@ -68,12 +68,10 @@ absl::StatusOr<DeviceDescription> DeviceDescription::FromProto(
 GpuDeviceInfoProto DeviceDescription::ToGpuProto() const {
   stream_executor::GpuDeviceInfoProto proto;
-  if (auto* ptr = std::get_if<stream_executor::CudaComputeCapability>(
-          &gpu_compute_capability_)) {
+  if (auto* ptr = gpu_compute_capability_.cuda_compute_capability()) {
     *proto.mutable_cuda_compute_capability() = ptr->ToProto();
   }
-  if (auto* ptr = std::get_if<stream_executor::RocmComputeCapability>(
-          &gpu_compute_capability_)) {
+  if (auto* ptr = gpu_compute_capability_.rocm_compute_capability()) {
     *proto.mutable_rocm_compute_capability() = ptr->ToProto();
   }
@@ -106,8 +104,7 @@ const GpuComputeCapability &DeviceDescription::gpu_compute_capability() const {
 }
 CudaComputeCapability DeviceDescription::cuda_compute_capability() const {
-  if (auto *ptr =
-          std::get_if<CudaComputeCapability>(&gpu_compute_capability_)) {
+  if (auto* ptr = gpu_compute_capability_.cuda_compute_capability()) {
     return *ptr;
   }
   // Fallback for backwards compatibility.
@@ -115,8 +112,7 @@ CudaComputeCapability DeviceDescription::cuda_compute_capability() const {
 }
 RocmComputeCapability DeviceDescription::rocm_compute_capability() const {
-  if (auto *ptr =
-          std::get_if<RocmComputeCapability>(&gpu_compute_capability_)) {
+  if (auto* ptr = gpu_compute_capability_.rocm_compute_capability()) {
     return *ptr;
   }
   return RocmComputeCapability{};
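
A note on the fallback in the two getters above: DeviceDescription::cuda_compute_capability() and rocm_compute_capability() return a default-constructed capability when the variant holds the other vendor, so callers that must distinguish "not a CUDA device" from "compute capability 0.0" should use the pointer-returning accessor and null-check it, as IsTmaAvailableForDevice in the next file does. A hypothetical caller in that defensive style (not part of this commit):

// Hypothetical caller: null-checks the accessor instead of relying on the
// defaulting convenience getter.
bool IsHopperOrNewer(const se::DeviceDescription& desc) {
  const se::CudaComputeCapability* cc =
      desc.gpu_compute_capability().cuda_compute_capability();
  return cc != nullptr && cc->IsAtLeastHopper();
}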

View File

@@ -462,8 +462,8 @@ absl::StatusOr<TmaMetadata> TmaMetadata::FromProto(
 bool IsTmaAvailableForDevice(
     const stream_executor::DeviceDescription& device_info) {
-  if (auto* cuda_cc = std::get_if<stream_executor::CudaComputeCapability>(
-          &device_info.gpu_compute_capability())) {
+  if (auto* cuda_cc =
+          device_info.gpu_compute_capability().cuda_compute_capability()) {
     return cuda_cc->IsAtLeastHopper();
   }
   return false;

View File

@@ -51,7 +51,7 @@ class RocmComputeCapability {
     return RocmComputeCapability{"gfx1030"};
   }
-  std::string gcn_arch_name() const { return gcn_arch_name_; }
+  const std::string& gcn_arch_name() const { return gcn_arch_name_; }
   std::string ToString() const { return gcn_arch_name(); }
@@ -198,7 +198,7 @@ class RocmComputeCapability {
   template <typename... ArrayOfStrings>
   bool IsThisGfxInAnyList(ArrayOfStrings&&... arr) const {
     static_assert(sizeof...(arr) >= 1);
-    const auto gfx = gfx_version();
+    const std::string gfx = gfx_version();
     return (implIsThisGfxInAnyList(std::begin(arr), std::end(arr), gfx) || ...);
   }
@@ -206,10 +206,9 @@ class RocmComputeCapability {
   /// \warning Don't use directly!
   bool implIsThisGfxInAnyList(const absl::string_view* beg,
                               const absl::string_view* end,
-                              const std::string& gfx) const {
-    return std::any_of(beg, end, [&gfx = gfx](const absl::string_view& s) {
-      return gfx == s;
-    });
+                              const absl::string_view gfx) const {
+    return std::any_of(
+        beg, end, [&gfx = gfx](const absl::string_view s) { return gfx == s; });
   }
   std::string gcn_arch_name_{kInvalidGfx};  // default to invalid arch.
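
The signature change above follows the usual absl::string_view convention: a string_view is two words (pointer and length), so it is passed by value rather than by const std::string&, which also avoids constructing a temporary std::string at each call site. A standalone reduction of the helper (Contains is a hypothetical name; the real method is implIsThisGfxInAnyList above):

#include <algorithm>

#include "absl/strings/string_view.h"

// Hypothetical reduction: string_view parameters are taken by value.
bool Contains(const absl::string_view* beg, const absl::string_view* end,
              absl::string_view gfx) {
  return std::any_of(beg, end,
                     [gfx](absl::string_view s) { return gfx == s; });
}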