First step to introduce a GpuComputeCapability custom class instead of std::variant

PiperOrigin-RevId: 820940828
Maxim Ermilov 2025-10-17 21:14:41 -07:00 committed by TensorFlower Gardener
parent 4d358b2bac
commit 4a42fca868
16 changed files with 157 additions and 166 deletions
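
Note (not part of the commit): the heart of this change is the new wrapper class added to xla/stream_executor/device_description.h in the last file below. The following is a minimal, self-contained sketch of that pattern for illustration only. CudaCC and RocmCC are hypothetical stand-ins for se::CudaComputeCapability and se::RocmComputeCapability; the sketch shows how call sites move from std::visit / std::get_if on the raw variant to the wrapper's accessors.

// Illustrative sketch only; compiles standalone with C++17.
#include <cassert>
#include <string>
#include <variant>

struct CudaCC {
  int major = 8, minor = 0;
  bool IsAtLeastAmpere() const { return major >= 8; }
  std::string ToString() const {
    return std::to_string(major) + "." + std::to_string(minor);
  }
};

struct RocmCC {
  std::string gfx = "gfx90a";
  std::string ToString() const { return gfx; }
};

// Same shape as the class introduced by this commit, with stand-in alternatives.
class GpuComputeCapability : public std::variant<CudaCC, RocmCC> {
 public:
  using std::variant<CudaCC, RocmCC>::variant;

  bool IsCuda() const { return std::holds_alternative<CudaCC>(*this); }
  bool IsRocm() const { return std::holds_alternative<RocmCC>(*this); }

  const CudaCC* cuda_compute_capability() const {
    return std::get_if<CudaCC>(this);
  }
  const RocmCC* rocm_compute_capability() const {
    return std::get_if<RocmCC>(this);
  }

  std::string ToString() const {
    if (auto* cc = cuda_compute_capability()) return cc->ToString();
    return rocm_compute_capability()->ToString();
  }
};

int main() {
  GpuComputeCapability cc = CudaCC{9, 0};
  // Before this change callers dispatched with std::visit / std::get_if on the
  // variant; now they query the wrapper directly.
  assert(cc.IsCuda());
  if (auto* cuda = cc.cuda_compute_capability()) {
    assert(cuda->IsAtLeastAmpere());
  }
  assert(cc.ToString() == "9.0");
}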

View File

@@ -163,8 +163,10 @@ class GpuKernelToBlobPass
       target->Options.AllowFPOpFusion =
           llvm::FPOpFusion::FPOpFusionMode::Fast;
     };
-    TF_ASSIGN_OR_RETURN(std::string ptx, xla::gpu::nvptx::CompileToPtx(
-                                             llvm_module_copy.get(), cc,
-                                             options, enable_fusion));
+    TF_ASSIGN_OR_RETURN(
+        std::string ptx,
+        xla::gpu::nvptx::CompileToPtx(
+            llvm_module_copy.get(), stream_executor::GpuComputeCapability(cc),
+            options, enable_fusion));
     if (print_ptx_) {
       llvm::dbgs() << "Generated PTX code for module '"

View File

@@ -68,12 +68,12 @@ se::CudaComputeCapability GetComputeCapability() {
 }
 bool IsRocm() {
-  return std::holds_alternative<se::RocmComputeCapability>(
-      se::GPUMachineManager()
-          ->ExecutorForDevice(0)
-          .value()
-          ->GetDeviceDescription()
-          .gpu_compute_capability());
+  return se::GPUMachineManager()
+      ->ExecutorForDevice(0)
+      .value()
+      ->GetDeviceDescription()
+      .gpu_compute_capability()
+      .IsRocm();
 }
 void ExpectErrorMessageSubstr(const Status& s, StringPiece substr) {

View File

@@ -516,7 +516,7 @@ struct LaunchFusedMatMulOp<GPUDevice, T> {
     const auto& cc =
         stream->parent()->GetDeviceDescription().gpu_compute_capability();
-    if (auto* procm = std::get_if<se::RocmComputeCapability>(&cc)) {
+    if (auto* procm = cc.rocm_compute_capability()) {
       use_cudnn = !procm->gfx9_mi200_or_later();
     }
     BlasScratchAllocator scratch_allocator(context);

View File

@@ -604,7 +604,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
     const auto& cc =
         stream->parent()->GetDeviceDescription().gpu_compute_capability();
-    if (auto* procm = std::get_if<se::RocmComputeCapability>(&cc)) {
+    if (auto* procm = cc.rocm_compute_capability()) {
       bCublasLtSupport = procm->gfx9_mi200_or_later();
     }

View File

@@ -2000,8 +2000,7 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
     mlir::MLIRContext& mlir_context, bool is_xla_fusion, bool emit_kernel) {
   const auto& gpu_cc = device_info.gpu_compute_capability();
   TF_RETURN_IF_ERROR(CheckAtLeastAmpere(gpu_cc));
-  std::string arch_name =
-      std::visit([](auto& cc) { return cc.ToString(); }, gpu_cc);
+  std::string arch_name = gpu_cc.ToString();
   const HloModuleConfig& hlo_config = hlo_module.config();

View File

@@ -68,38 +68,25 @@ bool IsTritonSupportedDotOutputType(
     case F32:
       return true;
     case F8E5M2:
-      return std::visit(
-          absl::Overload(
-              [](const se::CudaComputeCapability& cc) {
-                return cc.IsAtLeastAmpere();
-              },
-              [](const se::RocmComputeCapability& cc) { return false; }),
-          gpu_version);
+      if (auto ptr = gpu_version.cuda_compute_capability()) {
+        return ptr->IsAtLeastAmpere();
+      }
+      return false;
     case F8E4M3FN:
-      return std::visit(
-          absl::Overload(
-              [](const se::CudaComputeCapability& cc) {
-                return cc.IsAtLeastHopper();
-              },
-              [](const se::RocmComputeCapability& cc) { return false; }),
-          gpu_version);
+      if (auto ptr = gpu_version.cuda_compute_capability()) {
+        return ptr->IsAtLeastHopper();
+      }
+      return false;
     case BF16:
-      return std::visit(
-          absl::Overload(
-              [](const se::CudaComputeCapability& cc) { return true; },
-              [](const se::RocmComputeCapability& cc) {
-                return cc.has_bf16_dtype_support();
-              }),
-          gpu_version);
+      if (auto ptr = gpu_version.rocm_compute_capability()) {
+        return ptr->has_bf16_dtype_support();
+      }
+      return true;
     case S32:
-      return std::visit(
-          absl::Overload(
-              [](const se::CudaComputeCapability& cc) {
-                return cc.IsAtLeastAmpere();
-              },
-              [](const se::RocmComputeCapability& cc) { return false; }),
-          gpu_version);
+      if (auto ptr = gpu_version.cuda_compute_capability()) {
+        return ptr->IsAtLeastAmpere();
+      }
+      return false;
     default:
       return false;
   }

View File

@@ -100,19 +100,16 @@ CommandBufferConfig GetCommandBufferConfig(
   };
   // Check if CUDA/ROCM driver supports required features.
-  auto erase_cuda = [&](const se::CudaComputeCapability& cuda_comp) {
+  if (device_info.gpu_compute_capability().IsCuda()) {
     if (std::min(device_info.runtime_version(), device_info.driver_version()) <
         se::SemanticVersion{12, 3, 0}) {
       erase(kRequireTracing);       // cuStreamBeginCaptureToGraph
       erase(kRequireConditionals);  // on-device control flow
     }
-  };
-  auto erase_rocm = [&](const se::RocmComputeCapability& rocm_comp) {
+  }
+  if (device_info.gpu_compute_capability().IsRocm()) {
     erase(kRequireConditionals);  // on-device control flow
-  };
-  std::visit(absl::Overload(erase_cuda, erase_rocm),
-             device_info.gpu_compute_capability());
+  }
   return config;
 }

View File

@@ -1388,6 +1388,7 @@ cc_library(
         "//xla/hlo/ir:hlo",
         "//xla/stream_executor:device_description",
         "//xla/stream_executor/cuda:cuda_compute_capability",
+        "//xla/stream_executor/rocm:rocm_compute_capability",
         "@com_google_absl//absl/functional:overload",
     ],
 )

View File

@@ -146,16 +146,14 @@ CompileModuleResults InitializeResults(const HloModule* hlo_module,
 }
 std::string GetDumpName(const se::DeviceDescription& device_desc) {
-  struct GetCcStr {
-    std::string operator()(const se::CudaComputeCapability& cc) const {
-      return absl::StrCat("sm_", cc.ToString());
-    }
-    std::string operator()(const se::RocmComputeCapability& cc) const {
-      return cc.gfx_version();
-    }
-  };
-  std::string prefix =
-      std::visit(GetCcStr(), device_desc.gpu_compute_capability());
+  std::string prefix;
+  if (auto* cc =
+          device_desc.gpu_compute_capability().cuda_compute_capability()) {
+    prefix = absl::StrCat("sm_", cc->ToString());
+  } else if (auto* cc = device_desc.gpu_compute_capability()
+                            .rocm_compute_capability()) {
+    prefix = cc->gfx_version();
+  }
   return absl::StrCat(prefix, "_gpu_", kAfterOptimizationsDumpName);
 }

View File

@@ -24,6 +24,7 @@ limitations under the License.
 #include "xla/shape.h"
 #include "xla/stream_executor/cuda/cuda_compute_capability.h"
 #include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/rocm/rocm_compute_capability.h"
 #include "xla/util.h"
 namespace xla {
@@ -33,26 +34,24 @@ namespace {
 bool DimensionRequiresPadding(const int64_t size, const PrimitiveType data_type,
                               const se::GpuComputeCapability& gpu_cc) {
-  return std::visit(
-      absl::Overload(
-          [&](const se::CudaComputeCapability& cc) {
-            for (const auto& req : CublasPaddingRequirements) {
-              if (cc.SupportsAllFeaturesOf(req.min_compute_capability) &&
-                  data_type == req.data_type && size % req.multiple_of != 0) {
-                return true;
-              }
-            }
-            return false;
-          },
-          [&](const se::RocmComputeCapability& cc) {
-            for (const auto& req : HipblasPaddingRequirements) {
-              if (data_type == req.data_type && size % req.multiple_of != 0) {
-                return true;
-              }
-            }
-            return false;
-          }),
-      gpu_cc);
+  if (const se::CudaComputeCapability* cc = gpu_cc.cuda_compute_capability()) {
+    for (const auto& req : CublasPaddingRequirements) {
+      if (cc->SupportsAllFeaturesOf(req.min_compute_capability) &&
+          data_type == req.data_type && size % req.multiple_of != 0) {
+        return true;
+      }
+    }
+    return false;
+  }
+  if (const se::RocmComputeCapability* cc = gpu_cc.rocm_compute_capability()) {
+    for (const auto& req : HipblasPaddingRequirements) {
+      if (data_type == req.data_type && size % req.multiple_of != 0) {
+        return true;
+      }
+    }
+    return false;
+  }
+  return false;
 }
 bool ShapeRequiresPadding(const Shape& shape, int batch_dimensions_size,

View File

@@ -75,15 +75,9 @@ ENTRY e {
 }
 TEST_F(FloatSupportTestWithTriton, MixedTypeDotWithBF16IsNotUpcasted) {
-  bool skip_test =
-      std::visit(absl::Overload(
-                     [](const se::CudaComputeCapability& cc) {
-                       return !cc.IsAtLeast(se::CudaComputeCapability::kAmpere);
-                     },
-                     [](const se::RocmComputeCapability&) { return true; }),
-                 GetGpuComputeCapability());
-  if (skip_test) {
+  if (GetGpuComputeCapability().IsRocm() ||
+      !GetGpuComputeCapability().cuda_compute_capability()->IsAtLeast(
+          se::CudaComputeCapability::kAmpere)) {
     GTEST_SKIP() << "Not supported on this GPU architecture";
   }

View File

@@ -116,24 +116,21 @@ class AnalyticalLatencyHidingSchedulerTest : public GpuCodegenTest {
 TEST_F(AnalyticalLatencyHidingSchedulerTest, TestAnalyticalLatencyEstimator) {
   auto gpu_compute_capability = GetGpuComputeCapability();
-  auto visitor = [](const auto& c) {
-    using cc = std::remove_const_t<std::remove_reference_t<decltype(c)>>;
-    if constexpr (std::is_same_v<stream_executor::CudaComputeCapability, cc>) {
-      if (!c.IsAtLeast(se::CudaComputeCapability::kPascal)) {
-        GTEST_SKIP() << "This test is for Pascal+ GPUs.";
-      }
-      if (c.major == 12 && c.minor == 1) {
-        // Skip this test for Spark. Because of the AllReduce, the test uses
-        // gpu_collective_performance_model, which only makes sense in a
-        // datacenter network setting.
-        GTEST_SKIP() << "This test is for datacenter GPUs.";
-      }
-    } else if (!std::is_same_v<stream_executor::RocmComputeCapability, cc>) {
-      GTEST_SKIP() << "This test is for Pascal+ GPUs.";
-    }
-  };
-  std::visit(visitor, gpu_compute_capability);
+  if (gpu_compute_capability.IsRocm()) {
+    GTEST_SKIP() << "This test is for Pascal+ GPUs.";
+  }
+  auto* c = gpu_compute_capability.cuda_compute_capability();
+  if (!c->IsAtLeast(se::CudaComputeCapability::kPascal)) {
+    GTEST_SKIP() << "This test is for Pascal+ GPUs.";
+  }
+  if (c->major == 12 && c->minor == 1) {
+    // Skip this test for Spark. Because of the AllReduce, the test uses
+    // gpu_collective_performance_model, which only makes sense in a
+    // datacenter network setting.
+    GTEST_SKIP() << "This test is for datacenter GPUs.";
+  }
   const se::DeviceDescription dev_info =
       backend().default_stream_executor()->GetDeviceDescription();

View File

@@ -440,11 +440,15 @@ GpuPerformanceWithCollectiveModel::ComputeAllreduceTime(
     const se::DeviceDescription& gpu_device_info) {
   // We use nccl group call to launch multiple allreduces so launch overhead
   // only occurs once.
-  const auto visitor = [&](const auto& cc) {
+  if (auto ptr =
+          gpu_device_info.gpu_compute_capability().cuda_compute_capability()) {
     return ComputeAllreduceTimeImpl(instr, cost_analysis, gpu_device_info,
-                                    CreateSettings(cc));
-  };
-  return std::visit(visitor, gpu_device_info.gpu_compute_capability());
+                                    CreateSettings(*ptr));
+  }
+  return ComputeAllreduceTimeImpl(
+      instr, cost_analysis, gpu_device_info,
+      CreateSettings(
+          *gpu_device_info.gpu_compute_capability().rocm_compute_capability()));
 }
 /*static*/ absl::Duration

View File

@@ -378,6 +378,7 @@ cc_library(
         "//xla/stream_executor:device_description",
         "//xla/stream_executor:semantic_version",
         "//xla/stream_executor/cuda:cuda_compute_capability",
+        "//xla/stream_executor/rocm:rocm_compute_capability",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",

View File

@@ -56,6 +56,7 @@ limitations under the License.
 #include "xla/shape_util.h"
 #include "xla/stream_executor/cuda/cuda_compute_capability.h"
 #include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/rocm/rocm_compute_capability.h"
 #include "xla/stream_executor/semantic_version.h"
 #include "xla/tsl/platform/errors.h"
 #include "xla/tsl/platform/statusor.h"
@@ -843,20 +844,19 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
   };
   // Check if CUDA/ROCM driver supports required features.
-  auto erase_cuda = [&](const se::CudaComputeCapability& cuda_comp) {
+  if (auto* cuda_comp = device_description_.gpu_compute_capability()
+                            .cuda_compute_capability()) {
     if (std::min(device_description_.runtime_version(),
                  device_description_.driver_version()) <
         se::SemanticVersion{12, 3, 0}) {
       erase(kRequireTracing);       // cuStreamBeginCaptureToGraph
       erase(kRequireConditionals);  // on-device control flow
     }
-  };
-  auto erase_rocm = [&](const se::RocmComputeCapability& rocm_comp) {
+  } else if (const se::RocmComputeCapability* rocm_comp =
+                 device_description_.gpu_compute_capability()
+                     .rocm_compute_capability()) {
     erase(kRequireConditionals);  // on-device control flow
-  };
-  std::visit(absl::Overload(erase_cuda, erase_rocm),
-             device_description_.gpu_compute_capability());
+  }
   auto order = module->MakeComputationPostOrder();
   std::reverse(order.begin(), order.end());

View File

@@ -36,8 +36,35 @@ limitations under the License.
 namespace stream_executor {
-using GpuComputeCapability =
-    std::variant<CudaComputeCapability, RocmComputeCapability>;
+class GpuComputeCapability
+    : public std::variant<CudaComputeCapability, RocmComputeCapability> {
+ public:
+  using std::variant<CudaComputeCapability, RocmComputeCapability>::variant;
+  using std::variant<CudaComputeCapability, RocmComputeCapability>::operator=;
+
+  bool IsCuda() const {
+    return std::holds_alternative<CudaComputeCapability>(*this);
+  }
+
+  bool IsRocm() const {
+    return std::holds_alternative<RocmComputeCapability>(*this);
+  }
+
+  const CudaComputeCapability* cuda_compute_capability() const {
+    return std::get_if<CudaComputeCapability>(this);
+  }
+
+  const RocmComputeCapability* rocm_compute_capability() const {
+    return std::get_if<RocmComputeCapability>(this);
+  }
+
+  std::string ToString() const {
+    if (auto ptr = cuda_compute_capability()) {
+      return ptr->ToString();
+    }
+    return rocm_compute_capability()->ToString();
+  }
+};
 // Data that describes the execution target of the StreamExecutor, in terms of
 // important logical parameters. These include dimensionality limits and
@@ -193,32 +220,24 @@ class DeviceDescription {
   // also we do not count what occupies cache, but rather claim that what is
   // much smaller than the cache size will likely stay in it.
   constexpr int64_t l1_cache_size_per_SM() const {
-    return std::visit(
-        [](const auto& capability) -> int64_t {
-          if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
-                                       RocmComputeCapability>) {
-            // MI100 and MI200 has 16KB L1 cache per CU.
-            if (capability.gfx9_mi100() || capability.gfx9_mi200()) {
-              return 16 * 1024;
-            }
-            // MI300 has 32KB L1 cache per CU.
-            if (capability.gfx9_mi300_series()) {
-              return 32 * 1024;
-            }
-          }
-          // Default return for other GPUs (e.g., RTX A6000).
-          return 2 * 1024;
-        },
-        gpu_compute_capability_);
+    if (auto* capability = gpu_compute_capability_.rocm_compute_capability()) {
+      // MI100 and MI200 has 16KB L1 cache per CU.
+      if (capability->gfx9_mi100() || capability->gfx9_mi200()) {
+        return 16 * 1024;
+      }
+      // MI300 has 32KB L1 cache per CU.
+      if (capability->gfx9_mi300_series()) {
+        return 32 * 1024;
+      }
+    }
+    // Default return for other GPUs (e.g., RTX A6000).
+    return 2 * 1024;
   }
   constexpr int64_t dram_to_l2_transaction_size_bytes() const {
-    return std::visit(
-        [](const auto& capability) -> int {
-          if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
-                                       RocmComputeCapability>) {
-            // DRAM->L2 bus is 128 Byte width for MI300.
-            if (capability.gfx9_mi300_series()) {
-              return 128;
-            }
-          }
+    if (auto* capability = gpu_compute_capability_.rocm_compute_capability()) {
+      // DRAM->L2 bus is 128 Byte width for MI300.
+      if (capability->gfx9_mi300_series()) {
+        return 128;
+      }
+    }
@@ -229,24 +248,17 @@ class DeviceDescription {
     // (page 10).
     // return 64 Bytes by default.
-          return 64;
-        },
-        gpu_compute_capability_);
+    return 64;
   }
   constexpr int64_t memory_transactions_per_clock() const {
-    return std::visit(
-        [](const auto& capability) -> int {
-          if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
-                                       RocmComputeCapability>) {
-            // 16 works well on MI300.
-            if (capability.gfx9_mi300_series()) {
-              return 16;
-            }
-          }
-          // Default return for other GPUs.
-          return 32;
-        },
-        gpu_compute_capability_);
+    if (auto* capability = gpu_compute_capability_.rocm_compute_capability()) {
+      // 16 works well on MI300.
+      if (capability->gfx9_mi300_series()) {
+        return 16;
+      }
+    }
+    // Default return for other GPUs.
+    return 32;
   }
   GpuDeviceInfoProto ToGpuProto() const;
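
Closing note (not part of the commit): the call-site migration applied throughout the files above follows roughly this pattern. The snippet assumes an XLA build environment with the headers shown in the diffs; Describe() is a hypothetical helper introduced only for illustration.

// Rough sketch of the new accessor-based dispatch, mirroring GetDumpName above.
#include <string>
#include "absl/strings/str_cat.h"
#include "xla/stream_executor/device_description.h"

std::string Describe(const stream_executor::GpuComputeCapability& cc) {
  // Before: std::visit(absl::Overload(...), cc) with one lambda per backend.
  // After: query the wrapper directly and branch on the non-null alternative.
  if (const auto* cuda = cc.cuda_compute_capability()) {
    return absl::StrCat("sm_", cuda->ToString());
  }
  if (const auto* rocm = cc.rocm_compute_capability()) {
    return rocm->gfx_version();
  }
  return "unknown";
}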