mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Port to new GpuComputeCapability API. Last part
PiperOrigin-RevId: 822676102
This commit is contained in:
parent
3503a61282
commit
1b08f96abf
|
|
@ -85,8 +85,8 @@ se::StreamExecutor* GpuExecutor() {
|
|||
|
||||
bool IsAtLeastCuda12900(const se::StreamExecutor* stream_executor) {
|
||||
const auto& device_description = stream_executor->GetDeviceDescription();
|
||||
const auto* cuda_cc = std::get_if<se::CudaComputeCapability>(
|
||||
&device_description.gpu_compute_capability());
|
||||
const auto* cuda_cc =
|
||||
device_description.gpu_compute_capability().cuda_compute_capability();
|
||||
if (cuda_cc != nullptr) {
|
||||
if (device_description.driver_version() >=
|
||||
stream_executor::SemanticVersion(12, 9, 0) &&
|
||||
|
|
|
|||
|
|
@ -55,12 +55,12 @@ std::string AutotuneCacheKey::HloInstructionToCanonicalString(
|
|||
std::string AutotuneCacheKey::DeviceDescriptionToCacheKey(
|
||||
const se::DeviceDescription& device_description) {
|
||||
std::string compute_capability;
|
||||
if (auto* ccc = std::get_if<se::CudaComputeCapability>(
|
||||
&device_description.gpu_compute_capability())) {
|
||||
if (auto* ccc = device_description.gpu_compute_capability()
|
||||
.cuda_compute_capability()) {
|
||||
compute_capability = absl::StrCat("CUDA: ", ccc->major, ".", ccc->minor);
|
||||
} else {
|
||||
auto* rcc = std::get_if<se::RocmComputeCapability>(
|
||||
&device_description.gpu_compute_capability());
|
||||
auto* rcc =
|
||||
device_description.gpu_compute_capability().rocm_compute_capability();
|
||||
CHECK(rcc != nullptr) << "Unknown compute capability type";
|
||||
compute_capability = absl::StrCat("ROCM: ", rcc->gfx_version());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -130,9 +130,8 @@ ENTRY main {
|
|||
// Algorithm 14 is disabled for cuDNN 9 on V100
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto dnn_version, GetDnnVersionInfo(stream_exec));
|
||||
if (dnn_version.major_version() >= 9 && dnn_version.major_version() < 10 &&
|
||||
std::holds_alternative<stream_executor::CudaComputeCapability>(cc) &&
|
||||
std::get<stream_executor::CudaComputeCapability>(cc).major == 7 &&
|
||||
std::get<stream_executor::CudaComputeCapability>(cc).minor == 0) {
|
||||
cc.IsCuda() && cc.cuda_compute_capability()->major == 7 &&
|
||||
cc.cuda_compute_capability()->minor == 0) {
|
||||
EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
|
||||
->has_cudnn_conv_backend_config() &&
|
||||
conv->backend_config<GpuBackendConfig>()
|
||||
|
|
|
|||
|
|
@ -201,11 +201,6 @@ class GemmFusionAutotunerImpl {
|
|||
return config_.GetGpuComputeCapability();
|
||||
}
|
||||
|
||||
bool isRocm() const {
|
||||
return std::holds_alternative<se::RocmComputeCapability>(
|
||||
GetComputeCapability());
|
||||
}
|
||||
|
||||
bool AddLibConfigs(const HloFusionInstruction& fusion,
|
||||
const HloDotInstruction* dot,
|
||||
std::vector<BackendConfig>& configs);
|
||||
|
|
|
|||
|
|
@ -52,7 +52,8 @@ bool GemmFusionAutotunerImpl::AddLibConfigs(
|
|||
const HloFusionInstruction& fusion, const HloDotInstruction* dot,
|
||||
std::vector<BackendConfig>& configs) {
|
||||
// Add cuDNN plans, if available.
|
||||
auto cc = std::get<se::CudaComputeCapability>(GetComputeCapability());
|
||||
stream_executor::CudaComputeCapability cc =
|
||||
*GetComputeCapability().cuda_compute_capability();
|
||||
bool is_cudnn_enabled =
|
||||
!config_.IsDeviceless() &&
|
||||
GetDnnVersionInfoOrDefault(config_.GetExecutor()).major_version() >= 9 &&
|
||||
|
|
@ -81,8 +82,8 @@ bool GemmFusionAutotunerImpl::AddLibConfigs(
|
|||
|
||||
std::vector<TritonGemmConfig> GemmFusionAutotunerImpl::GetDefaultTritonConfigs()
|
||||
const {
|
||||
auto compute_capability =
|
||||
std::get<se::CudaComputeCapability>(GetComputeCapability());
|
||||
stream_executor::CudaComputeCapability compute_capability =
|
||||
*GetComputeCapability().cuda_compute_capability();
|
||||
std::vector<TritonGemmConfig> configs;
|
||||
|
||||
if (compute_capability.IsAtLeastBlackwell()) {
|
||||
|
|
|
|||
|
|
@ -197,13 +197,13 @@ class StatelessAutotunerTest : public HloTestBase {
|
|||
SymbolicExprContext* symbolic_expr_context) {
|
||||
const HloFusionInstruction& fusion = *Cast<HloFusionInstruction>(
|
||||
module.entry_computation()->root_instruction());
|
||||
if (!isRocm()) {
|
||||
auto cu_compute_capability =
|
||||
std::get<se::CudaComputeCapability>(compute_capability);
|
||||
if (GpuComputeComp().IsCuda()) {
|
||||
auto* cu_compute_capability =
|
||||
compute_capability.cuda_compute_capability();
|
||||
se::GpuDeviceInfoProto deviceless_proto;
|
||||
auto ccc = deviceless_proto.mutable_cuda_compute_capability();
|
||||
ccc->set_major(cu_compute_capability.major);
|
||||
ccc->set_minor(cu_compute_capability.minor);
|
||||
ccc->set_major(cu_compute_capability->major);
|
||||
ccc->set_minor(cu_compute_capability->minor);
|
||||
}
|
||||
|
||||
DeviceConfig test_config{backend().default_stream_executor(),
|
||||
|
|
@ -237,10 +237,6 @@ class StatelessAutotunerTest : public HloTestBase {
|
|||
.gpu_compute_capability();
|
||||
}
|
||||
|
||||
bool isRocm() {
|
||||
return std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp());
|
||||
}
|
||||
|
||||
// Returns the config for the current device.
|
||||
absl::StatusOr<std::vector<GemmFusionAutotunerImpl::BackendConfig>>
|
||||
GetPossibleMatmulAutotuneConfigs(const HloModule& module) {
|
||||
|
|
@ -321,7 +317,7 @@ TEST_F(StatelessAutotunerTest, CublasFallbackForBf16Bf16F32Algorithm) {
|
|||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto configs,
|
||||
GetPossibleMatmulAutotuneConfigs(*module));
|
||||
if (!isRocm()) {
|
||||
if (!GpuComputeComp().IsRocm()) {
|
||||
switch (GetCudaComputeCapability().major) {
|
||||
case se::CudaComputeCapability::kAmpere:
|
||||
EXPECT_TRUE(hasCublasConfig(configs))
|
||||
|
|
@ -361,8 +357,8 @@ class GemmFusionAutotunerTest : public StatelessAutotunerTest {
|
|||
}
|
||||
|
||||
stream_executor::GpuComputeCapability CudaAmpereOrRocm() {
|
||||
if (isRocm()) {
|
||||
return GetRocmComputeCapability();
|
||||
if (GpuComputeComp().IsRocm()) {
|
||||
return GpuComputeComp();
|
||||
} else {
|
||||
return stream_executor::GpuComputeCapability{
|
||||
stream_executor::CudaComputeCapability{
|
||||
|
|
@ -427,7 +423,8 @@ GetPossibleMatmulAutotuneTritonConfigs(
|
|||
TF_ASSIGN_OR_RETURN(se::DeviceDescription device_description,
|
||||
se::DeviceDescription::FromProto(
|
||||
se::GpuDeviceInfoProto::default_instance()));
|
||||
device_description.set_gpu_compute_capability(compute_capability);
|
||||
device_description.set_gpu_compute_capability(
|
||||
se::GpuComputeCapability{compute_capability});
|
||||
// Using H100 numbers as the most relevant example here.
|
||||
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications-technical-specifications-per-compute-capability
|
||||
// https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/#nvidia_h100_gpu_architecture_in-depth
|
||||
|
|
@ -446,7 +443,7 @@ GetPossibleMatmulAutotuneTritonConfigs(
|
|||
}
|
||||
|
||||
TEST_F(GemmFusionAutotunerTest, AmpereUsesMoreThanTwoStages) {
|
||||
if (isRocm()) {
|
||||
if (GpuComputeComp().IsRocm()) {
|
||||
GTEST_SKIP() << "Not supported on ROCm.";
|
||||
}
|
||||
std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
|
||||
|
|
@ -625,7 +622,7 @@ ENTRY e {
|
|||
|
||||
// TODO(b/344770374): Make this test not fragile.
|
||||
TEST_F(GemmFusionAutotunerTest, DoNotRunAutotuningKernelSpillingRegisters) {
|
||||
if (isRocm()) {
|
||||
if (GpuComputeComp().IsRocm()) {
|
||||
GTEST_SKIP() << "Not supported on ROCm.";
|
||||
}
|
||||
const std::string kHloText = R"(
|
||||
|
|
@ -808,11 +805,10 @@ ENTRY main {
|
|||
|
||||
TF_EXPECT_OK(HloTestBase::RunHloPass(&pipeline, module.get()));
|
||||
const bool is_at_least_hopper =
|
||||
std::holds_alternative<se::CudaComputeCapability>(
|
||||
autotune_config.GetGpuComputeCapability()) &&
|
||||
std::get<se::CudaComputeCapability>(
|
||||
autotune_config.GetGpuComputeCapability())
|
||||
.IsAtLeastHopper();
|
||||
autotune_config.GetGpuComputeCapability().IsCuda() &&
|
||||
autotune_config.GetGpuComputeCapability()
|
||||
.cuda_compute_capability()
|
||||
->IsAtLeastHopper();
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool filecheck_matches,
|
||||
RunFileCheck(module->ToString(), is_at_least_hopper
|
||||
|
|
@ -822,8 +818,9 @@ ENTRY main {
|
|||
}
|
||||
|
||||
TEST_F(GemmFusionAutotunerDumpTest, DumpingWorks) {
|
||||
if (isRocm() || GetDebugOptionsForTest()
|
||||
.xla_gpu_experimental_disable_binary_libraries()) {
|
||||
if (GpuComputeComp().IsRocm() ||
|
||||
GetDebugOptionsForTest()
|
||||
.xla_gpu_experimental_disable_binary_libraries()) {
|
||||
GTEST_SKIP() << "Not supported on ROCm or with binary libraries disabled.";
|
||||
}
|
||||
HloModuleConfig config;
|
||||
|
|
@ -891,8 +888,9 @@ CHECK: cublas
|
|||
}
|
||||
|
||||
TEST_F(GemmFusionAutotunerTest, AutotuneCuDnnFusion) {
|
||||
if (isRocm() || GetDebugOptionsForTest()
|
||||
.xla_gpu_experimental_disable_binary_libraries()) {
|
||||
if (GpuComputeComp().IsRocm() ||
|
||||
GetDebugOptionsForTest()
|
||||
.xla_gpu_experimental_disable_binary_libraries()) {
|
||||
GTEST_SKIP() << "Not supported on ROCm or with binary libraries disabled.";
|
||||
}
|
||||
const std::string kHlo = R"(
|
||||
|
|
@ -1202,7 +1200,7 @@ ENTRY entry {
|
|||
}
|
||||
|
||||
TEST_F(GemmFusionAutotunerTest, CreatesCustomKernelFusionConfigs) {
|
||||
if (isRocm()) {
|
||||
if (GpuComputeComp().IsRocm()) {
|
||||
GTEST_SKIP() << "Not supported on ROCm.";
|
||||
}
|
||||
const std::string kHlo = R"(
|
||||
|
|
@ -1241,7 +1239,7 @@ TEST_F(GemmFusionAutotunerTest, CreatesCustomKernelFusionConfigs) {
|
|||
}
|
||||
|
||||
TEST_F(GemmFusionAutotunerTest, GeneratesTwoConfigsForUpcastGemmWithPrologue) {
|
||||
if (isRocm()) {
|
||||
if (GpuComputeComp().IsRocm()) {
|
||||
GTEST_SKIP() << "Not supported on ROCm.";
|
||||
}
|
||||
const std::string kHlo = R"(
|
||||
|
|
@ -1273,8 +1271,9 @@ TEST_F(GemmFusionAutotunerTest, GeneratesTwoConfigsForUpcastGemmWithPrologue) {
|
|||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
const std::vector<GemmFusionAutotunerImpl::BackendConfig> configs,
|
||||
GetPossibleMatmulAutotuneConfigs(
|
||||
*module, compute_capability, GetToolkitVersion(),
|
||||
GetDebugOptionsForTest(), &symbolic_expr_context_));
|
||||
*module, se::GpuComputeCapability{compute_capability},
|
||||
GetToolkitVersion(), GetDebugOptionsForTest(),
|
||||
&symbolic_expr_context_));
|
||||
EXPECT_EQ(
|
||||
2, std::count_if(
|
||||
configs.begin(), configs.end(),
|
||||
|
|
@ -1287,7 +1286,7 @@ TEST_F(GemmFusionAutotunerTest, GeneratesTwoConfigsForUpcastGemmWithPrologue) {
|
|||
TEST_F(GemmFusionAutotunerTest, GeneratesOneConfigForUpcastGemmWithPrologue) {
|
||||
// Same as GeneratesTwoConfigsForUpcastGemmWithPrologue, but with contracting
|
||||
// dimension size = 128 which is not supported by the SplitK kernel.
|
||||
if (isRocm()) {
|
||||
if (GpuComputeComp().IsRocm()) {
|
||||
GTEST_SKIP() << "Not supported on ROCm.";
|
||||
}
|
||||
const std::string kHlo = R"(
|
||||
|
|
@ -1319,8 +1318,9 @@ TEST_F(GemmFusionAutotunerTest, GeneratesOneConfigForUpcastGemmWithPrologue) {
|
|||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
const std::vector<GemmFusionAutotunerImpl::BackendConfig> configs,
|
||||
GetPossibleMatmulAutotuneConfigs(
|
||||
*module, compute_capability, GetToolkitVersion(),
|
||||
GetDebugOptionsForTest(), &symbolic_expr_context_));
|
||||
*module, se::GpuComputeCapability{compute_capability},
|
||||
GetToolkitVersion(), GetDebugOptionsForTest(),
|
||||
&symbolic_expr_context_));
|
||||
EXPECT_EQ(
|
||||
1, std::count_if(
|
||||
configs.begin(), configs.end(),
|
||||
|
|
@ -1332,7 +1332,7 @@ TEST_F(GemmFusionAutotunerTest, GeneratesOneConfigForUpcastGemmWithPrologue) {
|
|||
|
||||
TEST_F(GemmFusionAutotunerTest,
|
||||
GeneratesConfigForUpcastGemmWithPrologueAndEpilogue) {
|
||||
if (isRocm()) {
|
||||
if (GpuComputeComp().IsRocm()) {
|
||||
GTEST_SKIP() << "Not supported on ROCm.";
|
||||
}
|
||||
const std::string kHlo = R"(
|
||||
|
|
@ -1368,8 +1368,9 @@ TEST_F(GemmFusionAutotunerTest,
|
|||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
const std::vector<GemmFusionAutotunerImpl::BackendConfig> configs,
|
||||
GetPossibleMatmulAutotuneConfigs(
|
||||
*module, compute_capability, GetToolkitVersion(),
|
||||
GetDebugOptionsForTest(), &symbolic_expr_context_));
|
||||
*module, se::GpuComputeCapability{compute_capability},
|
||||
GetToolkitVersion(), GetDebugOptionsForTest(),
|
||||
&symbolic_expr_context_));
|
||||
EXPECT_EQ(
|
||||
2, std::count_if(
|
||||
configs.begin(), configs.end(),
|
||||
|
|
@ -1486,7 +1487,7 @@ class GemmFusionShardedAutotunerTest : public GemmFusionAutotunerTest {
|
|||
TEST_F(
|
||||
GemmFusionShardedAutotunerTest,
|
||||
AutotuningSucceedsWhenKeyValueStoreAlreadyContainsAutotuningResultsForTheInputModule) { // NOLINT(whitespace/line_length)
|
||||
if (isRocm()) {
|
||||
if (GpuComputeComp().IsRocm()) {
|
||||
GTEST_SKIP() << "Not supported on ROCm.";
|
||||
}
|
||||
const std::string kHlo = R"(
|
||||
|
|
@ -1545,7 +1546,7 @@ TEST_F(
|
|||
TEST_F(
|
||||
GemmFusionShardedAutotunerTest,
|
||||
AutotuningStoresDifferentResultsForTheSameFusionInDifferentModules) { // NOLINT(whitespace/line_length)
|
||||
if (isRocm()) {
|
||||
if (GpuComputeComp().IsRocm()) {
|
||||
GTEST_SKIP() << "Not supported on ROCm.";
|
||||
}
|
||||
const std::string kHlo1 = R"(
|
||||
|
|
@ -1626,7 +1627,7 @@ TEST_F(
|
|||
}
|
||||
|
||||
TEST_F(GemmFusionAutotunerTest, RewritesGemmFusionToCustomKernelFusion) {
|
||||
if (isRocm()) {
|
||||
if (GpuComputeComp().IsRocm()) {
|
||||
GTEST_SKIP() << "Not supported on ROCm.";
|
||||
}
|
||||
const std::string kHlo = R"(
|
||||
|
|
@ -1711,7 +1712,7 @@ ENTRY e {
|
|||
}
|
||||
|
||||
TEST_F(GemmFusionAutotunerTest, VerifyHopperConfigsAreDifferentFromBlackwell) {
|
||||
if (isRocm()) {
|
||||
if (GpuComputeComp().IsRocm()) {
|
||||
GTEST_SKIP() << "Not supported on ROCm.";
|
||||
}
|
||||
|
||||
|
|
@ -1752,7 +1753,7 @@ TEST_F(GemmFusionAutotunerTest, VerifyHopperConfigsAreDifferentFromBlackwell) {
|
|||
}
|
||||
|
||||
TEST_F(GemmFusionAutotunerTest, ScaledDotConfigsAreGenerated) {
|
||||
if (isRocm()) {
|
||||
if (GpuComputeComp().IsRocm()) {
|
||||
GTEST_SKIP() << "Not supported on ROCm.";
|
||||
}
|
||||
|
||||
|
|
@ -1781,7 +1782,7 @@ TEST_F(GemmFusionAutotunerTest, ScaledDotConfigsAreGenerated) {
|
|||
}
|
||||
|
||||
TEST_F(GemmFusionAutotunerTest, ScaledDotConfigsHaveCuBlasFallback) {
|
||||
if (isRocm()) {
|
||||
if (GpuComputeComp().IsRocm()) {
|
||||
GTEST_SKIP() << "Not supported on ROCm.";
|
||||
}
|
||||
|
||||
|
|
@ -1829,7 +1830,7 @@ class GemmFusionAutotunerEnableTma : public GemmFusionAutotunerTest {
|
|||
|
||||
TEST_F(GemmFusionAutotunerEnableTma,
|
||||
TmaConfigsAreGeneratedOnlyForHopperAndWorkCorrectly) {
|
||||
if (isRocm()) {
|
||||
if (GpuComputeComp().IsRocm()) {
|
||||
GTEST_SKIP() << "Not supported on ROCm.";
|
||||
}
|
||||
|
||||
|
|
@ -1885,7 +1886,7 @@ TEST_F(GemmFusionAutotunerEnableTma,
|
|||
|
||||
TEST_F(GemmFusionAutotunerEnableTma,
|
||||
TmaConfigsGeneratedAndRunCorrectlyForDotsOfBroadcasts) {
|
||||
if (isRocm()) {
|
||||
if (GpuComputeComp().IsRocm()) {
|
||||
GTEST_SKIP() << "Not supported on ROCm.";
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -78,11 +78,11 @@ static const std::initializer_list<absl::string_view> kf16f32{"f16", "f32"};
|
|||
class CudnnFusedConvRewriterHloTest : public HloTestBase {
|
||||
public:
|
||||
bool IsCuda() const {
|
||||
return std::holds_alternative<se::CudaComputeCapability>(
|
||||
backend()
|
||||
.default_stream_executor()
|
||||
->GetDeviceDescription()
|
||||
.gpu_compute_capability());
|
||||
return backend()
|
||||
.default_stream_executor()
|
||||
->GetDeviceDescription()
|
||||
.gpu_compute_capability()
|
||||
.IsCuda();
|
||||
}
|
||||
se::CudaComputeCapability GetCudaComputeCapability() const {
|
||||
return backend()
|
||||
|
|
@ -119,11 +119,11 @@ class CudnnFusedConvRewriterHloTest : public HloTestBase {
|
|||
class CudnnFusedConvRewriterTest : public GpuCodegenTest {
|
||||
public:
|
||||
bool IsCuda() const {
|
||||
return std::holds_alternative<se::CudaComputeCapability>(
|
||||
backend()
|
||||
.default_stream_executor()
|
||||
->GetDeviceDescription()
|
||||
.gpu_compute_capability());
|
||||
return backend()
|
||||
.default_stream_executor()
|
||||
->GetDeviceDescription()
|
||||
.gpu_compute_capability()
|
||||
.IsCuda();
|
||||
}
|
||||
se::CudaComputeCapability GetCudaComputeCapability() const {
|
||||
return backend()
|
||||
|
|
|
|||
|
|
@ -646,7 +646,9 @@ ENTRY entry {
|
|||
hlo_module->entry_computation()->ComputeProgramShape());
|
||||
|
||||
GpuLayoutAssignment layout_assignment(
|
||||
&computation_layout, se::RocmComputeCapability::EarliestRDNASupport(),
|
||||
&computation_layout,
|
||||
se::GpuComputeCapability{
|
||||
se::RocmComputeCapability::EarliestRDNASupport()},
|
||||
GetDnnVersion(), GetDeviceDescription());
|
||||
|
||||
EXPECT_THAT(layout_assignment.Run(hlo_module.get()),
|
||||
|
|
@ -683,7 +685,9 @@ ENTRY entry {
|
|||
hlo_module->entry_computation()->ComputeProgramShape());
|
||||
|
||||
GpuLayoutAssignment layout_assignment(
|
||||
&computation_layout, se::RocmComputeCapability::EarliestRDNASupport(),
|
||||
&computation_layout,
|
||||
se::GpuComputeCapability{
|
||||
se::RocmComputeCapability::EarliestRDNASupport()},
|
||||
GetDnnVersion(), GetDeviceDescription());
|
||||
|
||||
EXPECT_THAT(layout_assignment.Run(hlo_module.get()),
|
||||
|
|
@ -723,7 +727,9 @@ ENTRY entry {
|
|||
hlo_module->entry_computation()->ComputeProgramShape());
|
||||
|
||||
GpuLayoutAssignment layout_assignment(
|
||||
&computation_layout, se::RocmComputeCapability::EarliestCDNASupport(),
|
||||
&computation_layout,
|
||||
se::GpuComputeCapability{
|
||||
se::RocmComputeCapability::EarliestCDNASupport()},
|
||||
GetDnnVersion(), GetDeviceDescription());
|
||||
|
||||
EXPECT_THAT(layout_assignment.Run(hlo_module.get()),
|
||||
|
|
@ -763,7 +769,9 @@ ENTRY entry {
|
|||
hlo_module->entry_computation()->ComputeProgramShape());
|
||||
|
||||
GpuLayoutAssignment layout_assignment(
|
||||
&computation_layout, se::RocmComputeCapability::EarliestCDNASupport(),
|
||||
&computation_layout,
|
||||
se::GpuComputeCapability{
|
||||
se::RocmComputeCapability::EarliestCDNASupport()},
|
||||
GetDnnVersion(), GetDeviceDescription());
|
||||
|
||||
EXPECT_THAT(layout_assignment.Run(hlo_module.get()),
|
||||
|
|
|
|||
|
|
@ -69,9 +69,8 @@ TEST(CudaExecutorTest, CreateDeviceDescription) {
|
|||
EXPECT_THAT(result->model_str(), Not(IsEmpty()));
|
||||
EXPECT_THAT(result->device_vendor(), "NVIDIA Corporation");
|
||||
|
||||
EXPECT_THAT(result->gpu_compute_capability(),
|
||||
VariantWith<CudaComputeCapability>(::testing::Field(
|
||||
"major", &CudaComputeCapability::major, Ge(1))));
|
||||
EXPECT_THAT(*result->gpu_compute_capability().cuda_compute_capability(),
|
||||
::testing::Field("major", &CudaComputeCapability::major, Ge(1)));
|
||||
}
|
||||
|
||||
TEST(CudaExecutorTest, GetCudaKernel) {
|
||||
|
|
|
|||
|
|
@ -68,12 +68,10 @@ absl::StatusOr<DeviceDescription> DeviceDescription::FromProto(
|
|||
|
||||
GpuDeviceInfoProto DeviceDescription::ToGpuProto() const {
|
||||
stream_executor::GpuDeviceInfoProto proto;
|
||||
if (auto* ptr = std::get_if<stream_executor::CudaComputeCapability>(
|
||||
&gpu_compute_capability_)) {
|
||||
if (auto* ptr = gpu_compute_capability_.cuda_compute_capability()) {
|
||||
*proto.mutable_cuda_compute_capability() = ptr->ToProto();
|
||||
}
|
||||
if (auto* ptr = std::get_if<stream_executor::RocmComputeCapability>(
|
||||
&gpu_compute_capability_)) {
|
||||
if (auto* ptr = gpu_compute_capability_.rocm_compute_capability()) {
|
||||
*proto.mutable_rocm_compute_capability() = ptr->ToProto();
|
||||
}
|
||||
|
||||
|
|
@ -106,8 +104,7 @@ const GpuComputeCapability &DeviceDescription::gpu_compute_capability() const {
|
|||
}
|
||||
|
||||
CudaComputeCapability DeviceDescription::cuda_compute_capability() const {
|
||||
if (auto *ptr =
|
||||
std::get_if<CudaComputeCapability>(&gpu_compute_capability_)) {
|
||||
if (auto* ptr = gpu_compute_capability_.cuda_compute_capability()) {
|
||||
return *ptr;
|
||||
}
|
||||
// Fallback for backwards compatibility.
|
||||
|
|
@ -115,8 +112,7 @@ CudaComputeCapability DeviceDescription::cuda_compute_capability() const {
|
|||
}
|
||||
|
||||
RocmComputeCapability DeviceDescription::rocm_compute_capability() const {
|
||||
if (auto *ptr =
|
||||
std::get_if<RocmComputeCapability>(&gpu_compute_capability_)) {
|
||||
if (auto* ptr = gpu_compute_capability_.rocm_compute_capability()) {
|
||||
return *ptr;
|
||||
}
|
||||
return RocmComputeCapability{};
|
||||
|
|
|
|||
|
|
@ -462,8 +462,8 @@ absl::StatusOr<TmaMetadata> TmaMetadata::FromProto(
|
|||
|
||||
bool IsTmaAvailableForDevice(
|
||||
const stream_executor::DeviceDescription& device_info) {
|
||||
if (auto* cuda_cc = std::get_if<stream_executor::CudaComputeCapability>(
|
||||
&device_info.gpu_compute_capability())) {
|
||||
if (auto* cuda_cc =
|
||||
device_info.gpu_compute_capability().cuda_compute_capability()) {
|
||||
return cuda_cc->IsAtLeastHopper();
|
||||
}
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ class RocmComputeCapability {
|
|||
return RocmComputeCapability{"gfx1030"};
|
||||
}
|
||||
|
||||
std::string gcn_arch_name() const { return gcn_arch_name_; }
|
||||
const std::string& gcn_arch_name() const { return gcn_arch_name_; }
|
||||
|
||||
std::string ToString() const { return gcn_arch_name(); }
|
||||
|
||||
|
|
@ -198,7 +198,7 @@ class RocmComputeCapability {
|
|||
template <typename... ArrayOfStrings>
|
||||
bool IsThisGfxInAnyList(ArrayOfStrings&&... arr) const {
|
||||
static_assert(sizeof...(arr) >= 1);
|
||||
const auto gfx = gfx_version();
|
||||
const std::string gfx = gfx_version();
|
||||
return (implIsThisGfxInAnyList(std::begin(arr), std::end(arr), gfx) || ...);
|
||||
}
|
||||
|
||||
|
|
@ -206,10 +206,9 @@ class RocmComputeCapability {
|
|||
/// \warning Don't use directly!
|
||||
bool implIsThisGfxInAnyList(const absl::string_view* beg,
|
||||
const absl::string_view* end,
|
||||
const std::string& gfx) const {
|
||||
return std::any_of(beg, end, [&gfx = gfx](const absl::string_view& s) {
|
||||
return gfx == s;
|
||||
});
|
||||
const absl::string_view gfx) const {
|
||||
return std::any_of(
|
||||
beg, end, [&gfx = gfx](const absl::string_view s) { return gfx == s; });
|
||||
}
|
||||
|
||||
std::string gcn_arch_name_{kInvalidGfx}; // default to invalid arch.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user