From 4a42fca8689cc30fb1f10fdb45cd604e13065cae Mon Sep 17 00:00:00 2001
From: Maxim Ermilov
Date: Fri, 17 Oct 2025 21:14:41 -0700
Subject: [PATCH] First step to introduce GpuComputeCapability custom class
 instead of std::variant

PiperOrigin-RevId: 820940828
---
 .../transforms/gpu_kernel_to_blob_pass.cc     |   8 +-
 .../common_runtime/gpu/gpu_device_test.cc     |  12 +-
 tensorflow/core/kernels/matmul_op_fused.cc    |   2 +-
 tensorflow/core/kernels/matmul_op_impl.h      |   2 +-
 .../gpu/codegen/triton/fusion_emitter.cc      |   3 +-
 .../gpu/codegen/triton/support_legacy.cc      |  45 +++----
 .../runtime/command_buffer_conversion_pass.cc |  11 +-
 third_party/xla/xla/service/gpu/BUILD         |   1 +
 .../service/gpu/compile_module_to_llvm_ir.cc  |  18 ++-
 .../gpu/cublas_padding_requirements.cc        |  39 +++---
 .../xla/xla/service/gpu/float_support_test.cc |  12 +-
 .../analytical_latency_estimator_test.cc      |  31 +++--
 .../model/gpu_collective_performance_model.cc |  12 +-
 .../xla/xla/service/gpu/transforms/BUILD      |   1 +
 .../transforms/command_buffer_scheduling.cc   |  14 +--
 .../xla/stream_executor/device_description.h  | 112 ++++++++++--------
 16 files changed, 157 insertions(+), 166 deletions(-)

diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc
index 7b3625103ef..079500b8cd1 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc
@@ -163,9 +163,11 @@ class GpuKernelToBlobPass
           target->Options.AllowFPOpFusion =
               llvm::FPOpFusion::FPOpFusionMode::Fast;
         };
-        TF_ASSIGN_OR_RETURN(std::string ptx, xla::gpu::nvptx::CompileToPtx(
-                                                 llvm_module_copy.get(), cc,
-                                                 options, enable_fusion));
+        TF_ASSIGN_OR_RETURN(
+            std::string ptx,
+            xla::gpu::nvptx::CompileToPtx(
+                llvm_module_copy.get(), stream_executor::GpuComputeCapability(cc),
+                options, enable_fusion));
         if (print_ptx_) {
           llvm::dbgs() << "Generated PTX code for module '"
                        << gpu_module.getName() << "' on architecture sm_" << arch
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
index 8eba1fbb914..3aa8fa1003f 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
@@ -68,12 +68,12 @@ se::CudaComputeCapability GetComputeCapability() {
 }
 
 bool IsRocm() {
-  return std::holds_alternative<se::RocmComputeCapability>(
-      se::GPUMachineManager()
-          ->ExecutorForDevice(0)
-          .value()
-          ->GetDeviceDescription()
-          .gpu_compute_capability());
+  return se::GPUMachineManager()
+      ->ExecutorForDevice(0)
+      .value()
+      ->GetDeviceDescription()
+      .gpu_compute_capability()
+      .IsRocm();
 }
 
 void ExpectErrorMessageSubstr(const Status& s, StringPiece substr) {
diff --git a/tensorflow/core/kernels/matmul_op_fused.cc b/tensorflow/core/kernels/matmul_op_fused.cc
index 502654ca097..4e6a8d52666 100644
--- a/tensorflow/core/kernels/matmul_op_fused.cc
+++ b/tensorflow/core/kernels/matmul_op_fused.cc
@@ -516,7 +516,7 @@ struct LaunchFusedMatMulOp<GPUDevice, T> {
     const auto& cc =
         stream->parent()->GetDeviceDescription().gpu_compute_capability();
 
-    if (auto* procm = std::get_if<se::RocmComputeCapability>(&cc)) {
+    if (auto* procm = cc.rocm_compute_capability()) {
      use_cudnn = !procm->gfx9_mi200_or_later();
     }
     BlasScratchAllocator scratch_allocator(context);
diff --git a/tensorflow/core/kernels/matmul_op_impl.h b/tensorflow/core/kernels/matmul_op_impl.h
index 50517dc9a2c..e81cb4e2b8c 100644
--- a/tensorflow/core/kernels/matmul_op_impl.h
+++ b/tensorflow/core/kernels/matmul_op_impl.h
@@ -604,7 +604,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
     const auto& cc =
         stream->parent()->GetDeviceDescription().gpu_compute_capability();
 
-    if (auto* procm = std::get_if<se::RocmComputeCapability>(&cc)) {
+    if (auto* procm = cc.rocm_compute_capability()) {
       bCublasLtSupport = procm->gfx9_mi200_or_later();
     }
 
diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc
index 681486d275b..3513014133f 100644
--- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter.cc
@@ -2000,8 +2000,7 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
     mlir::MLIRContext& mlir_context, bool is_xla_fusion, bool emit_kernel) {
   const auto& gpu_cc = device_info.gpu_compute_capability();
   TF_RETURN_IF_ERROR(CheckAtLeastAmpere(gpu_cc));
-  std::string arch_name =
-      std::visit([](auto& cc) { return cc.ToString(); }, gpu_cc);
+  std::string arch_name = gpu_cc.ToString();
 
   const HloModuleConfig& hlo_config = hlo_module.config();
 
diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy.cc
index b0d6c79274e..726530cf917 100644
--- a/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy.cc
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_legacy.cc
@@ -68,38 +68,25 @@ bool IsTritonSupportedDotOutputType(
     case F32:
       return true;
     case F8E5M2:
-      return std::visit(
-          absl::Overload(
-              [](const se::CudaComputeCapability& cc) {
-                return cc.IsAtLeastAmpere();
-              },
-              [](const se::RocmComputeCapability& cc) { return false; }),
-          gpu_version);
-
+      if (auto ptr = gpu_version.cuda_compute_capability()) {
+        return ptr->IsAtLeastAmpere();
+      }
+      return false;
     case F8E4M3FN:
-      return std::visit(
-          absl::Overload(
-              [](const se::CudaComputeCapability& cc) {
-                return cc.IsAtLeastHopper();
-              },
-              [](const se::RocmComputeCapability& cc) { return false; }),
-          gpu_version);
+      if (auto ptr = gpu_version.cuda_compute_capability()) {
+        return ptr->IsAtLeastHopper();
+      }
+      return false;
     case BF16:
-      return std::visit(
-          absl::Overload(
-              [](const se::CudaComputeCapability& cc) { return true; },
-              [](const se::RocmComputeCapability& cc) {
-                return cc.has_bf16_dtype_support();
-              }),
-          gpu_version);
+      if (auto ptr = gpu_version.rocm_compute_capability()) {
+        return ptr->has_bf16_dtype_support();
+      }
+      return true;
     case S32:
-      return std::visit(
-          absl::Overload(
-              [](const se::CudaComputeCapability& cc) {
-                return cc.IsAtLeastAmpere();
-              },
-              [](const se::RocmComputeCapability& cc) { return false; }),
-          gpu_version);
+      if (auto ptr = gpu_version.cuda_compute_capability()) {
+        return ptr->IsAtLeastAmpere();
+      }
+      return false;
     default:
       return false;
   }
diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_conversion_pass.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_conversion_pass.cc
index 952673948d8..f3b558de34e 100644
--- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_conversion_pass.cc
+++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_conversion_pass.cc
@@ -100,19 +100,16 @@ CommandBufferConfig GetCommandBufferConfig(
   };
 
   // Check if CUDA/ROCM driver supports required features.
-  auto erase_cuda = [&](const se::CudaComputeCapability& cuda_comp) {
+  if (device_info.gpu_compute_capability().IsCuda()) {
     if (std::min(device_info.runtime_version(), device_info.driver_version()) <
         se::SemanticVersion{12, 3, 0}) {
       erase(kRequireTracing);       // cuStreamBeginCaptureToGraph
       erase(kRequireConditionals);  // on-device control flow
     }
-  };
-  auto erase_rocm = [&](const se::RocmComputeCapability& rocm_comp) {
+  }
+  if (device_info.gpu_compute_capability().IsRocm()) {
     erase(kRequireConditionals);  // on-device control flow
-  };
-
-  std::visit(absl::Overload(erase_cuda, erase_rocm),
-             device_info.gpu_compute_capability());
+  }
 
   return config;
 }
diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD
index 1b98aca38c5..305b9bcb723 100644
--- a/third_party/xla/xla/service/gpu/BUILD
+++ b/third_party/xla/xla/service/gpu/BUILD
@@ -1388,6 +1388,7 @@ cc_library(
         "//xla/hlo/ir:hlo",
         "//xla/stream_executor:device_description",
         "//xla/stream_executor/cuda:cuda_compute_capability",
+        "//xla/stream_executor/rocm:rocm_compute_capability",
         "@com_google_absl//absl/functional:overload",
     ],
 )
diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc
index f6755f94fb4..fcab22e65cc 100644
--- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc
+++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc
@@ -146,16 +146,14 @@ CompileModuleResults InitializeResults(const HloModule* hlo_module,
 }
 
 std::string GetDumpName(const se::DeviceDescription& device_desc) {
-  struct GetCcStr {
-    std::string operator()(const se::CudaComputeCapability& cc) const {
-      return absl::StrCat("sm_", cc.ToString());
-    }
-    std::string operator()(const se::RocmComputeCapability& cc) const {
-      return cc.gfx_version();
-    }
-  };
-  std::string prefix =
-      std::visit(GetCcStr(), device_desc.gpu_compute_capability());
+  std::string prefix;
+  if (auto* cc =
+          device_desc.gpu_compute_capability().cuda_compute_capability()) {
+    prefix = absl::StrCat("sm_", cc->ToString());
+  } else if (auto* cc = device_desc.gpu_compute_capability()
+                            .rocm_compute_capability()) {
+    prefix = cc->gfx_version();
+  }
   return absl::StrCat(prefix, "_gpu_", kAfterOptimizationsDumpName);
 }
 
diff --git a/third_party/xla/xla/service/gpu/cublas_padding_requirements.cc b/third_party/xla/xla/service/gpu/cublas_padding_requirements.cc
index 5e7800be917..7f538a63124 100644
--- a/third_party/xla/xla/service/gpu/cublas_padding_requirements.cc
+++ b/third_party/xla/xla/service/gpu/cublas_padding_requirements.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "xla/shape.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/rocm/rocm_compute_capability.h" #include "xla/util.h" namespace xla { @@ -33,26 +34,24 @@ namespace { bool DimensionRequiresPadding(const int64_t size, const PrimitiveType data_type, const se::GpuComputeCapability& gpu_cc) { - return std::visit( - absl::Overload( - [&](const se::CudaComputeCapability& cc) { - for (const auto& req : CublasPaddingRequirements) { - if (cc.SupportsAllFeaturesOf(req.min_compute_capability) && - data_type == req.data_type && size % req.multiple_of != 0) { - return true; - } - } - return false; - }, - [&](const se::RocmComputeCapability& cc) { - for (const auto& req : HipblasPaddingRequirements) { - if (data_type == req.data_type && size % req.multiple_of != 0) { - return true; - } - } - return false; - }), - gpu_cc); + if (const se::CudaComputeCapability* cc = gpu_cc.cuda_compute_capability()) { + for (const auto& req : CublasPaddingRequirements) { + if (cc->SupportsAllFeaturesOf(req.min_compute_capability) && + data_type == req.data_type && size % req.multiple_of != 0) { + return true; + } + } + return false; + } + if (const se::RocmComputeCapability* cc = gpu_cc.rocm_compute_capability()) { + for (const auto& req : HipblasPaddingRequirements) { + if (data_type == req.data_type && size % req.multiple_of != 0) { + return true; + } + } + return false; + } + return false; } bool ShapeRequiresPadding(const Shape& shape, int batch_dimensions_size, diff --git a/third_party/xla/xla/service/gpu/float_support_test.cc b/third_party/xla/xla/service/gpu/float_support_test.cc index 4824af66d08..4fc8c1c87d2 100644 --- a/third_party/xla/xla/service/gpu/float_support_test.cc +++ b/third_party/xla/xla/service/gpu/float_support_test.cc @@ -75,15 +75,9 @@ ENTRY e { } TEST_F(FloatSupportTestWithTriton, MixedTypeDotWithBF16IsNotUpcasted) { - bool skip_test = - std::visit(absl::Overload( - [](const se::CudaComputeCapability& cc) { - return !cc.IsAtLeast(se::CudaComputeCapability::kAmpere); - }, - [](const se::RocmComputeCapability&) { return true; }), - GetGpuComputeCapability()); - - if (skip_test) { + if (GetGpuComputeCapability().IsRocm() || + !GetGpuComputeCapability().cuda_compute_capability()->IsAtLeast( + se::CudaComputeCapability::kAmpere)) { GTEST_SKIP() << "Not supported on this GPU architecture"; } diff --git a/third_party/xla/xla/service/gpu/model/analytical_latency_estimator_test.cc b/third_party/xla/xla/service/gpu/model/analytical_latency_estimator_test.cc index 88ebc1ba977..defcce91630 100644 --- a/third_party/xla/xla/service/gpu/model/analytical_latency_estimator_test.cc +++ b/third_party/xla/xla/service/gpu/model/analytical_latency_estimator_test.cc @@ -116,24 +116,21 @@ class AnalyticalLatencyHidingSchedulerTest : public GpuCodegenTest { TEST_F(AnalyticalLatencyHidingSchedulerTest, TestAnalyticalLatencyEstimator) { auto gpu_compute_capability = GetGpuComputeCapability(); - auto visitor = [](const auto& c) { - using cc = std::remove_const_t>; - if constexpr (std::is_same_v) { - if (!c.IsAtLeast(se::CudaComputeCapability::kPascal)) { - GTEST_SKIP() << "This test is for Pascal+ GPUs."; - } - if (c.major == 12 && c.minor == 1) { - // Skip this test for Spark. Because of the AllReduce, the test uses - // gpu_collective_performance_model, which only makes sense in a - // datacenter network setting. 
-        GTEST_SKIP() << "This test is for datacenter GPUs.";
-      }
-    } else if (!std::is_same_v<cc, se::CudaComputeCapability>) {
-      GTEST_SKIP() << "This test is for Pascal+ GPUs.";
-    }
-  };
+  if (gpu_compute_capability.IsRocm()) {
+    GTEST_SKIP() << "This test is for Pascal+ GPUs.";
+  }
+
+  auto* c = gpu_compute_capability.cuda_compute_capability();
+  if (!c->IsAtLeast(se::CudaComputeCapability::kPascal)) {
+    GTEST_SKIP() << "This test is for Pascal+ GPUs.";
+  }
+  if (c->major == 12 && c->minor == 1) {
+    // Skip this test for Spark. Because of the AllReduce, the test uses
+    // gpu_collective_performance_model, which only makes sense in a
+    // datacenter network setting.
+    GTEST_SKIP() << "This test is for datacenter GPUs.";
+  }
 
-  std::visit(visitor, gpu_compute_capability);
   const se::DeviceDescription dev_info =
       backend().default_stream_executor()->GetDeviceDescription();
diff --git a/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.cc
index aa43729fe20..5c5bdeeb3f1 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.cc
@@ -440,11 +440,15 @@ GpuPerformanceWithCollectiveModel::ComputeAllreduceTime(
     const se::DeviceDescription& gpu_device_info) {
   // We use nccl group call to launch multiple allreduces so launch overhead
   // only occurs once.
-  const auto visitor = [&](const auto& cc) {
+  if (auto ptr =
+          gpu_device_info.gpu_compute_capability().cuda_compute_capability()) {
     return ComputeAllreduceTimeImpl(instr, cost_analysis, gpu_device_info,
-                                    CreateSettings(cc));
-  };
-  return std::visit(visitor, gpu_device_info.gpu_compute_capability());
+                                    CreateSettings(*ptr));
+  }
+  return ComputeAllreduceTimeImpl(
+      instr, cost_analysis, gpu_device_info,
+      CreateSettings(
+          *gpu_device_info.gpu_compute_capability().rocm_compute_capability()));
 }
 
 /*static*/ absl::Duration
diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD
index b4e2a37172c..31d5409410a 100644
--- a/third_party/xla/xla/service/gpu/transforms/BUILD
+++ b/third_party/xla/xla/service/gpu/transforms/BUILD
@@ -378,6 +378,7 @@ cc_library(
         "//xla/stream_executor:device_description",
         "//xla/stream_executor:semantic_version",
         "//xla/stream_executor/cuda:cuda_compute_capability",
+        "//xla/stream_executor/rocm:rocm_compute_capability",
         "//xla/tsl/platform:errors",
         "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/algorithm:container",
diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc
index 672b1248753..96fa801f5b3 100644
--- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc
+++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc
@@ -56,6 +56,7 @@ limitations under the License.
 #include "xla/shape_util.h"
 #include "xla/stream_executor/cuda/cuda_compute_capability.h"
 #include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/rocm/rocm_compute_capability.h"
 #include "xla/stream_executor/semantic_version.h"
 #include "xla/tsl/platform/errors.h"
 #include "xla/tsl/platform/statusor.h"
@@ -843,20 +844,19 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
   };
 
   // Check if CUDA/ROCM driver supports required features.
-  auto erase_cuda = [&](const se::CudaComputeCapability& cuda_comp) {
+  if (auto* cuda_comp = device_description_.gpu_compute_capability()
+                            .cuda_compute_capability()) {
     if (std::min(device_description_.runtime_version(),
                  device_description_.driver_version()) <
         se::SemanticVersion{12, 3, 0}) {
       erase(kRequireTracing);       // cuStreamBeginCaptureToGraph
       erase(kRequireConditionals);  // on-device control flow
     }
-  };
-  auto erase_rocm = [&](const se::RocmComputeCapability& rocm_comp) {
+  } else if (const se::RocmComputeCapability* rocm_comp =
+                 device_description_.gpu_compute_capability()
+                     .rocm_compute_capability()) {
     erase(kRequireConditionals);  // on-device control flow
-  };
-
-  std::visit(absl::Overload(erase_cuda, erase_rocm),
-             device_description_.gpu_compute_capability());
+  }
 
   auto order = module->MakeComputationPostOrder();
   std::reverse(order.begin(), order.end());
diff --git a/third_party/xla/xla/stream_executor/device_description.h b/third_party/xla/xla/stream_executor/device_description.h
index fa394c2ee23..391cad3f7eb 100644
--- a/third_party/xla/xla/stream_executor/device_description.h
+++ b/third_party/xla/xla/stream_executor/device_description.h
@@ -36,8 +36,35 @@ limitations under the License.
 
 namespace stream_executor {
 
-using GpuComputeCapability =
-    std::variant<CudaComputeCapability, RocmComputeCapability>;
+class GpuComputeCapability
+    : public std::variant<CudaComputeCapability, RocmComputeCapability> {
+ public:
+  using std::variant<CudaComputeCapability, RocmComputeCapability>::variant;
+  using std::variant<CudaComputeCapability, RocmComputeCapability>::operator=;
+
+  bool IsCuda() const {
+    return std::holds_alternative<CudaComputeCapability>(*this);
+  }
+
+  bool IsRocm() const {
+    return std::holds_alternative<RocmComputeCapability>(*this);
+  }
+
+  const CudaComputeCapability* cuda_compute_capability() const {
+    return std::get_if<CudaComputeCapability>(this);
+  }
+
+  const RocmComputeCapability* rocm_compute_capability() const {
+    return std::get_if<RocmComputeCapability>(this);
+  }
+
+  std::string ToString() const {
+    if (auto ptr = cuda_compute_capability()) {
+      return ptr->ToString();
+    }
+    return rocm_compute_capability()->ToString();
+  }
+};
 
 // Data that describes the execution target of the StreamExecutor, in terms of
 // important logical parameters. These include dimensionality limits and
@@ -193,60 +220,45 @@ class DeviceDescription {
   // also we do not count what occupies cache, but rather claim that what is
   // much smaller than the cache size will likely stay in it.
   constexpr int64_t l1_cache_size_per_SM() const {
-    return std::visit(
-        [](const auto& capability) -> int64_t {
-          if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
-                                       RocmComputeCapability>) {
-            // MI100 and MI200 has 16KB L1 cache per CU.
-            if (capability.gfx9_mi100() || capability.gfx9_mi200()) {
-              return 16 * 1024;
-            }
-            // MI300 has 32KB L1 cache per CU.
-            if (capability.gfx9_mi300_series()) {
-              return 32 * 1024;
-            }
-          }
-          // Default return for other GPUs (e.g., RTX A6000).
-          return 2 * 1024;
-        },
-        gpu_compute_capability_);
+    if (auto* capability = gpu_compute_capability_.rocm_compute_capability()) {
+      // MI100 and MI200 has 16KB L1 cache per CU.
+      if (capability->gfx9_mi100() || capability->gfx9_mi200()) {
+        return 16 * 1024;
+      }
+      // MI300 has 32KB L1 cache per CU.
+      if (capability->gfx9_mi300_series()) {
+        return 32 * 1024;
+      }
+    }
+    // Default return for other GPUs (e.g., RTX A6000).
+    return 2 * 1024;
   }
 
   constexpr int64_t dram_to_l2_transaction_size_bytes() const {
-    return std::visit(
-        [](const auto& capability) -> int {
-          if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
-                                       RocmComputeCapability>) {
-            // DRAM->L2 bus is 128 Byte width for MI300.
-            if (capability.gfx9_mi300_series()) {
-              return 128;
-            }
-          }
-          // Cache line is 128B that is split into 4 sectors of 32B. Default
Default - // transaction size from DRAM -> L2 = 64 Bytes = 2 sectors, since - // V100, but it can be also configured. - // https://developer.download.nvidia.com/video/gputechconf/gtc/2020/presentations/s21819-optimizing-applications-for-nvidia-ampere-gpu-architecture.pdf - // (page 10). - // return 64 Bytes by default. - return 64; - }, - gpu_compute_capability_); + if (auto* capability = gpu_compute_capability_.rocm_compute_capability()) { + // DRAM->L2 bus is 128 Byte width for MI300. + if (capability->gfx9_mi300_series()) { + return 128; + } + } + // Cache line is 128B that is split into 4 sectors of 32B. Default + // transaction size from DRAM -> L2 = 64 Bytes = 2 sectors, since + // V100, but it can be also configured. + // https://developer.download.nvidia.com/video/gputechconf/gtc/2020/presentations/s21819-optimizing-applications-for-nvidia-ampere-gpu-architecture.pdf + // (page 10). + // return 64 Bytes by default. + return 64; } constexpr int64_t memory_transactions_per_clock() const { - return std::visit( - [](const auto& capability) -> int { - if constexpr (std::is_same_v, - RocmComputeCapability>) { - // 16 works well on MI300. - if (capability.gfx9_mi300_series()) { - return 16; - } - } - // Default return for other GPUs. - return 32; - }, - gpu_compute_capability_); + if (auto* capability = gpu_compute_capability_.rocm_compute_capability()) { + // 16 works well on MI300. + if (capability->gfx9_mi300_series()) { + return 16; + } + } + // Default return for other GPUs. + return 32; } GpuDeviceInfoProto ToGpuProto() const;