First step to introduce GpuComputeCapability custom class instead of std::variant

PiperOrigin-RevId: 820940828
This commit is contained in:
Maxim Ermilov 2025-10-17 21:14:41 -07:00 committed by TensorFlower Gardener
parent 4d358b2bac
commit 4a42fca868
16 changed files with 157 additions and 166 deletions

View File

@ -163,8 +163,10 @@ class GpuKernelToBlobPass
target->Options.AllowFPOpFusion =
llvm::FPOpFusion::FPOpFusionMode::Fast;
};
TF_ASSIGN_OR_RETURN(std::string ptx, xla::gpu::nvptx::CompileToPtx(
llvm_module_copy.get(), cc,
TF_ASSIGN_OR_RETURN(
std::string ptx,
xla::gpu::nvptx::CompileToPtx(
llvm_module_copy.get(), stream_executor::GpuComputeCapability(cc),
options, enable_fusion));
if (print_ptx_) {
llvm::dbgs() << "Generated PTX code for module '"

View File

@ -68,12 +68,12 @@ se::CudaComputeCapability GetComputeCapability() {
}
bool IsRocm() {
return std::holds_alternative<se::RocmComputeCapability>(
se::GPUMachineManager()
return se::GPUMachineManager()
->ExecutorForDevice(0)
.value()
->GetDeviceDescription()
.gpu_compute_capability());
.gpu_compute_capability()
.IsRocm();
}
void ExpectErrorMessageSubstr(const Status& s, StringPiece substr) {

View File

@ -516,7 +516,7 @@ struct LaunchFusedMatMulOp<GPUDevice, T> {
const auto& cc =
stream->parent()->GetDeviceDescription().gpu_compute_capability();
if (auto* procm = std::get_if<se::RocmComputeCapability>(&cc)) {
if (auto* procm = cc.rocm_compute_capability()) {
use_cudnn = !procm->gfx9_mi200_or_later();
}
BlasScratchAllocator scratch_allocator(context);

View File

@ -604,7 +604,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
const auto& cc =
stream->parent()->GetDeviceDescription().gpu_compute_capability();
if (auto* procm = std::get_if<se::RocmComputeCapability>(&cc)) {
if (auto* procm = cc.rocm_compute_capability()) {
bCublasLtSupport = procm->gfx9_mi200_or_later();
}

View File

@ -2000,8 +2000,7 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
mlir::MLIRContext& mlir_context, bool is_xla_fusion, bool emit_kernel) {
const auto& gpu_cc = device_info.gpu_compute_capability();
TF_RETURN_IF_ERROR(CheckAtLeastAmpere(gpu_cc));
std::string arch_name =
std::visit([](auto& cc) { return cc.ToString(); }, gpu_cc);
std::string arch_name = gpu_cc.ToString();
const HloModuleConfig& hlo_config = hlo_module.config();

View File

@ -68,38 +68,25 @@ bool IsTritonSupportedDotOutputType(
case F32:
return true;
case F8E5M2:
return std::visit(
absl::Overload(
[](const se::CudaComputeCapability& cc) {
return cc.IsAtLeastAmpere();
},
[](const se::RocmComputeCapability& cc) { return false; }),
gpu_version);
if (auto ptr = gpu_version.cuda_compute_capability()) {
return ptr->IsAtLeastAmpere();
}
return false;
case F8E4M3FN:
return std::visit(
absl::Overload(
[](const se::CudaComputeCapability& cc) {
return cc.IsAtLeastHopper();
},
[](const se::RocmComputeCapability& cc) { return false; }),
gpu_version);
if (auto ptr = gpu_version.cuda_compute_capability()) {
return ptr->IsAtLeastHopper();
}
return false;
case BF16:
return std::visit(
absl::Overload(
[](const se::CudaComputeCapability& cc) { return true; },
[](const se::RocmComputeCapability& cc) {
return cc.has_bf16_dtype_support();
}),
gpu_version);
if (auto ptr = gpu_version.rocm_compute_capability()) {
return ptr->has_bf16_dtype_support();
}
return true;
case S32:
return std::visit(
absl::Overload(
[](const se::CudaComputeCapability& cc) {
return cc.IsAtLeastAmpere();
},
[](const se::RocmComputeCapability& cc) { return false; }),
gpu_version);
if (auto ptr = gpu_version.cuda_compute_capability()) {
return ptr->IsAtLeastAmpere();
}
return false;
default:
return false;
}

View File

@ -100,19 +100,16 @@ CommandBufferConfig GetCommandBufferConfig(
};
// Check if CUDA/ROCM driver supports required features.
auto erase_cuda = [&](const se::CudaComputeCapability& cuda_comp) {
if (device_info.gpu_compute_capability().IsCuda()) {
if (std::min(device_info.runtime_version(), device_info.driver_version()) <
se::SemanticVersion{12, 3, 0}) {
erase(kRequireTracing); // cuStreamBeginCaptureToGraph
erase(kRequireConditionals); // on-device control flow
}
};
auto erase_rocm = [&](const se::RocmComputeCapability& rocm_comp) {
}
if (device_info.gpu_compute_capability().IsRocm()) {
erase(kRequireConditionals); // on-device control flow
};
std::visit(absl::Overload(erase_cuda, erase_rocm),
device_info.gpu_compute_capability());
}
return config;
}

View File

@ -1388,6 +1388,7 @@ cc_library(
"//xla/hlo/ir:hlo",
"//xla/stream_executor:device_description",
"//xla/stream_executor/cuda:cuda_compute_capability",
"//xla/stream_executor/rocm:rocm_compute_capability",
"@com_google_absl//absl/functional:overload",
],
)

View File

@ -146,16 +146,14 @@ CompileModuleResults InitializeResults(const HloModule* hlo_module,
}
std::string GetDumpName(const se::DeviceDescription& device_desc) {
struct GetCcStr {
std::string operator()(const se::CudaComputeCapability& cc) const {
return absl::StrCat("sm_", cc.ToString());
std::string prefix;
if (auto* cc =
device_desc.gpu_compute_capability().cuda_compute_capability()) {
prefix = absl::StrCat("sm_", cc->ToString());
} else if (auto* cc = device_desc.gpu_compute_capability()
.rocm_compute_capability()) {
prefix = cc->gfx_version();
}
std::string operator()(const se::RocmComputeCapability& cc) const {
return cc.gfx_version();
}
};
std::string prefix =
std::visit(GetCcStr(), device_desc.gpu_compute_capability());
return absl::StrCat(prefix, "_gpu_", kAfterOptimizationsDumpName);
}

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/rocm/rocm_compute_capability.h"
#include "xla/util.h"
namespace xla {
@ -33,26 +34,24 @@ namespace {
bool DimensionRequiresPadding(const int64_t size, const PrimitiveType data_type,
const se::GpuComputeCapability& gpu_cc) {
return std::visit(
absl::Overload(
[&](const se::CudaComputeCapability& cc) {
if (const se::CudaComputeCapability* cc = gpu_cc.cuda_compute_capability()) {
for (const auto& req : CublasPaddingRequirements) {
if (cc.SupportsAllFeaturesOf(req.min_compute_capability) &&
if (cc->SupportsAllFeaturesOf(req.min_compute_capability) &&
data_type == req.data_type && size % req.multiple_of != 0) {
return true;
}
}
return false;
},
[&](const se::RocmComputeCapability& cc) {
}
if (const se::RocmComputeCapability* cc = gpu_cc.rocm_compute_capability()) {
for (const auto& req : HipblasPaddingRequirements) {
if (data_type == req.data_type && size % req.multiple_of != 0) {
return true;
}
}
return false;
}),
gpu_cc);
}
return false;
}
bool ShapeRequiresPadding(const Shape& shape, int batch_dimensions_size,

View File

@ -75,15 +75,9 @@ ENTRY e {
}
TEST_F(FloatSupportTestWithTriton, MixedTypeDotWithBF16IsNotUpcasted) {
bool skip_test =
std::visit(absl::Overload(
[](const se::CudaComputeCapability& cc) {
return !cc.IsAtLeast(se::CudaComputeCapability::kAmpere);
},
[](const se::RocmComputeCapability&) { return true; }),
GetGpuComputeCapability());
if (skip_test) {
if (GetGpuComputeCapability().IsRocm() ||
!GetGpuComputeCapability().cuda_compute_capability()->IsAtLeast(
se::CudaComputeCapability::kAmpere)) {
GTEST_SKIP() << "Not supported on this GPU architecture";
}

View File

@ -116,24 +116,21 @@ class AnalyticalLatencyHidingSchedulerTest : public GpuCodegenTest {
TEST_F(AnalyticalLatencyHidingSchedulerTest, TestAnalyticalLatencyEstimator) {
auto gpu_compute_capability = GetGpuComputeCapability();
auto visitor = [](const auto& c) {
using cc = std::remove_const_t<std::remove_reference_t<decltype(c)>>;
if constexpr (std::is_same_v<stream_executor::CudaComputeCapability, cc>) {
if (!c.IsAtLeast(se::CudaComputeCapability::kPascal)) {
if (gpu_compute_capability.IsRocm()) {
GTEST_SKIP() << "This test is for Pascal+ GPUs.";
}
if (c.major == 12 && c.minor == 1) {
auto* c = gpu_compute_capability.cuda_compute_capability();
if (!c->IsAtLeast(se::CudaComputeCapability::kPascal)) {
GTEST_SKIP() << "This test is for Pascal+ GPUs.";
}
if (c->major == 12 && c->minor == 1) {
// Skip this test for Spark. Because of the AllReduce, the test uses
// gpu_collective_performance_model, which only makes sense in a
// datacenter network setting.
GTEST_SKIP() << "This test is for datacenter GPUs.";
}
} else if (!std::is_same_v<stream_executor::RocmComputeCapability, cc>) {
GTEST_SKIP() << "This test is for Pascal+ GPUs.";
}
};
std::visit(visitor, gpu_compute_capability);
const se::DeviceDescription dev_info =
backend().default_stream_executor()->GetDeviceDescription();

View File

@ -440,11 +440,15 @@ GpuPerformanceWithCollectiveModel::ComputeAllreduceTime(
const se::DeviceDescription& gpu_device_info) {
// We use nccl group call to launch multiple allreduces so launch overhead
// only occurs once.
const auto visitor = [&](const auto& cc) {
if (auto ptr =
gpu_device_info.gpu_compute_capability().cuda_compute_capability()) {
return ComputeAllreduceTimeImpl(instr, cost_analysis, gpu_device_info,
CreateSettings(cc));
};
return std::visit(visitor, gpu_device_info.gpu_compute_capability());
CreateSettings(*ptr));
}
return ComputeAllreduceTimeImpl(
instr, cost_analysis, gpu_device_info,
CreateSettings(
*gpu_device_info.gpu_compute_capability().rocm_compute_capability()));
}
/*static*/ absl::Duration

View File

@ -378,6 +378,7 @@ cc_library(
"//xla/stream_executor:device_description",
"//xla/stream_executor:semantic_version",
"//xla/stream_executor/cuda:cuda_compute_capability",
"//xla/stream_executor/rocm:rocm_compute_capability",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",

View File

@ -56,6 +56,7 @@ limitations under the License.
#include "xla/shape_util.h"
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/rocm/rocm_compute_capability.h"
#include "xla/stream_executor/semantic_version.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
@ -843,20 +844,19 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
};
// Check if CUDA/ROCM driver supports required features.
auto erase_cuda = [&](const se::CudaComputeCapability& cuda_comp) {
if (auto* cuda_comp = device_description_.gpu_compute_capability()
.cuda_compute_capability()) {
if (std::min(device_description_.runtime_version(),
device_description_.driver_version()) <
se::SemanticVersion{12, 3, 0}) {
erase(kRequireTracing); // cuStreamBeginCaptureToGraph
erase(kRequireConditionals); // on-device control flow
}
};
auto erase_rocm = [&](const se::RocmComputeCapability& rocm_comp) {
} else if (const se::RocmComputeCapability* rocm_comp =
device_description_.gpu_compute_capability()
.rocm_compute_capability()) {
erase(kRequireConditionals); // on-device control flow
};
std::visit(absl::Overload(erase_cuda, erase_rocm),
device_description_.gpu_compute_capability());
}
auto order = module->MakeComputationPostOrder();
std::reverse(order.begin(), order.end());

View File

@ -36,8 +36,35 @@ limitations under the License.
namespace stream_executor {
using GpuComputeCapability =
std::variant<CudaComputeCapability, RocmComputeCapability>;
class GpuComputeCapability
: public std::variant<CudaComputeCapability, RocmComputeCapability> {
public:
using std::variant<CudaComputeCapability, RocmComputeCapability>::variant;
using std::variant<CudaComputeCapability, RocmComputeCapability>::operator=;
bool IsCuda() const {
return std::holds_alternative<CudaComputeCapability>(*this);
}
bool IsRocm() const {
return std::holds_alternative<RocmComputeCapability>(*this);
}
const CudaComputeCapability* cuda_compute_capability() const {
return std::get_if<CudaComputeCapability>(this);
}
const RocmComputeCapability* rocm_compute_capability() const {
return std::get_if<RocmComputeCapability>(this);
}
std::string ToString() const {
if (auto ptr = cuda_compute_capability()) {
return ptr->ToString();
}
return rocm_compute_capability()->ToString();
}
};
// Data that describes the execution target of the StreamExecutor, in terms of
// important logical parameters. These include dimensionality limits and
@ -193,32 +220,24 @@ class DeviceDescription {
// also we do not count what occupies cache, but rather claim that what is
// much smaller than the cache size will likely stay in it.
constexpr int64_t l1_cache_size_per_SM() const {
return std::visit(
[](const auto& capability) -> int64_t {
if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
RocmComputeCapability>) {
if (auto* capability = gpu_compute_capability_.rocm_compute_capability()) {
// MI100 and MI200 has 16KB L1 cache per CU.
if (capability.gfx9_mi100() || capability.gfx9_mi200()) {
if (capability->gfx9_mi100() || capability->gfx9_mi200()) {
return 16 * 1024;
}
// MI300 has 32KB L1 cache per CU.
if (capability.gfx9_mi300_series()) {
if (capability->gfx9_mi300_series()) {
return 32 * 1024;
}
}
// Default return for other GPUs (e.g., RTX A6000).
return 2 * 1024;
},
gpu_compute_capability_);
}
constexpr int64_t dram_to_l2_transaction_size_bytes() const {
return std::visit(
[](const auto& capability) -> int {
if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
RocmComputeCapability>) {
if (auto* capability = gpu_compute_capability_.rocm_compute_capability()) {
// DRAM->L2 bus is 128 Byte width for MI300.
if (capability.gfx9_mi300_series()) {
if (capability->gfx9_mi300_series()) {
return 128;
}
}
@ -229,24 +248,17 @@ class DeviceDescription {
// (page 10).
// return 64 Bytes by default.
return 64;
},
gpu_compute_capability_);
}
constexpr int64_t memory_transactions_per_clock() const {
return std::visit(
[](const auto& capability) -> int {
if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
RocmComputeCapability>) {
if (auto* capability = gpu_compute_capability_.rocm_compute_capability()) {
// 16 works well on MI300.
if (capability.gfx9_mi300_series()) {
if (capability->gfx9_mi300_series()) {
return 16;
}
}
// Default return for other GPUs.
return 32;
},
gpu_compute_capability_);
}
GpuDeviceInfoProto ToGpuProto() const;