First step to introduce GpuComputeCapability custom class instead of std::variant
PiperOrigin-RevId: 820940828
commit 4a42fca868
parent 4d358b2bac
@@ -163,8 +163,10 @@ class GpuKernelToBlobPass
         target->Options.AllowFPOpFusion =
             llvm::FPOpFusion::FPOpFusionMode::Fast;
       };
-      TF_ASSIGN_OR_RETURN(std::string ptx, xla::gpu::nvptx::CompileToPtx(
-                                               llvm_module_copy.get(), cc,
-                                               options, enable_fusion));
+      TF_ASSIGN_OR_RETURN(
+          std::string ptx,
+          xla::gpu::nvptx::CompileToPtx(
+              llvm_module_copy.get(), stream_executor::GpuComputeCapability(cc),
+              options, enable_fusion));
       if (print_ptx_) {
         llvm::dbgs() << "Generated PTX code for module '"

@@ -68,12 +68,12 @@ se::CudaComputeCapability GetComputeCapability() {
 }
 
 bool IsRocm() {
-  return std::holds_alternative<se::RocmComputeCapability>(
-      se::GPUMachineManager()
-          ->ExecutorForDevice(0)
-          .value()
-          ->GetDeviceDescription()
-          .gpu_compute_capability());
+  return se::GPUMachineManager()
+      ->ExecutorForDevice(0)
+      .value()
+      ->GetDeviceDescription()
+      .gpu_compute_capability()
+      .IsRocm();
 }
 
 void ExpectErrorMessageSubstr(const Status& s, StringPiece substr) {

@@ -516,7 +516,7 @@ struct LaunchFusedMatMulOp<GPUDevice, T> {
 
     const auto& cc =
         stream->parent()->GetDeviceDescription().gpu_compute_capability();
-    if (auto* procm = std::get_if<se::RocmComputeCapability>(&cc)) {
+    if (auto* procm = cc.rocm_compute_capability()) {
       use_cudnn = !procm->gfx9_mi200_or_later();
     }
     BlasScratchAllocator scratch_allocator(context);

@@ -604,7 +604,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
 
     const auto& cc =
         stream->parent()->GetDeviceDescription().gpu_compute_capability();
-    if (auto* procm = std::get_if<se::RocmComputeCapability>(&cc)) {
+    if (auto* procm = cc.rocm_compute_capability()) {
       bCublasLtSupport = procm->gfx9_mi200_or_later();
     }
 
@@ -2000,8 +2000,7 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
     mlir::MLIRContext& mlir_context, bool is_xla_fusion, bool emit_kernel) {
   const auto& gpu_cc = device_info.gpu_compute_capability();
   TF_RETURN_IF_ERROR(CheckAtLeastAmpere(gpu_cc));
-  std::string arch_name =
-      std::visit([](auto& cc) { return cc.ToString(); }, gpu_cc);
+  std::string arch_name = gpu_cc.ToString();
 
   const HloModuleConfig& hlo_config = hlo_module.config();
 
@@ -68,38 +68,25 @@ bool IsTritonSupportedDotOutputType(
     case F32:
       return true;
     case F8E5M2:
-      return std::visit(
-          absl::Overload(
-              [](const se::CudaComputeCapability& cc) {
-                return cc.IsAtLeastAmpere();
-              },
-              [](const se::RocmComputeCapability& cc) { return false; }),
-          gpu_version);
+      if (auto ptr = gpu_version.cuda_compute_capability()) {
+        return ptr->IsAtLeastAmpere();
+      }
+      return false;
     case F8E4M3FN:
-      return std::visit(
-          absl::Overload(
-              [](const se::CudaComputeCapability& cc) {
-                return cc.IsAtLeastHopper();
-              },
-              [](const se::RocmComputeCapability& cc) { return false; }),
-          gpu_version);
+      if (auto ptr = gpu_version.cuda_compute_capability()) {
+        return ptr->IsAtLeastHopper();
+      }
+      return false;
     case BF16:
-      return std::visit(
-          absl::Overload(
-              [](const se::CudaComputeCapability& cc) { return true; },
-              [](const se::RocmComputeCapability& cc) {
-                return cc.has_bf16_dtype_support();
-              }),
-          gpu_version);
+      if (auto ptr = gpu_version.rocm_compute_capability()) {
+        return ptr->has_bf16_dtype_support();
+      }
+      return true;
     case S32:
-      return std::visit(
-          absl::Overload(
-              [](const se::CudaComputeCapability& cc) {
-                return cc.IsAtLeastAmpere();
-              },
-              [](const se::RocmComputeCapability& cc) { return false; }),
-          gpu_version);
+      if (auto ptr = gpu_version.cuda_compute_capability()) {
+        return ptr->IsAtLeastAmpere();
+      }
+      return false;
     default:
       return false;
   }

@@ -100,19 +100,16 @@ CommandBufferConfig GetCommandBufferConfig(
   };
 
   // Check if CUDA/ROCM driver supports required features.
-  auto erase_cuda = [&](const se::CudaComputeCapability& cuda_comp) {
+  if (device_info.gpu_compute_capability().IsCuda()) {
     if (std::min(device_info.runtime_version(), device_info.driver_version()) <
         se::SemanticVersion{12, 3, 0}) {
       erase(kRequireTracing);       // cuStreamBeginCaptureToGraph
       erase(kRequireConditionals);  // on-device control flow
     }
-  };
-  auto erase_rocm = [&](const se::RocmComputeCapability& rocm_comp) {
+  }
+  if (device_info.gpu_compute_capability().IsRocm()) {
     erase(kRequireConditionals);  // on-device control flow
-  };
-
-  std::visit(absl::Overload(erase_cuda, erase_rocm),
-             device_info.gpu_compute_capability());
+  }
 
   return config;
 }

third_party/xla/xla/service/gpu/BUILD

@@ -1388,6 +1388,7 @@ cc_library(
         "//xla/hlo/ir:hlo",
         "//xla/stream_executor:device_description",
         "//xla/stream_executor/cuda:cuda_compute_capability",
+        "//xla/stream_executor/rocm:rocm_compute_capability",
         "@com_google_absl//absl/functional:overload",
     ],
 )

@@ -146,16 +146,14 @@ CompileModuleResults InitializeResults(const HloModule* hlo_module,
 }
 
 std::string GetDumpName(const se::DeviceDescription& device_desc) {
-  struct GetCcStr {
-    std::string operator()(const se::CudaComputeCapability& cc) const {
-      return absl::StrCat("sm_", cc.ToString());
-    }
-    std::string operator()(const se::RocmComputeCapability& cc) const {
-      return cc.gfx_version();
-    }
-  };
-  std::string prefix =
-      std::visit(GetCcStr(), device_desc.gpu_compute_capability());
+  std::string prefix;
+  if (auto* cc =
+          device_desc.gpu_compute_capability().cuda_compute_capability()) {
+    prefix = absl::StrCat("sm_", cc->ToString());
+  } else if (auto* cc = device_desc.gpu_compute_capability()
+                            .rocm_compute_capability()) {
+    prefix = cc->gfx_version();
+  }
   return absl::StrCat(prefix, "_gpu_", kAfterOptimizationsDumpName);
 }
 
@@ -24,6 +24,7 @@ limitations under the License.
 #include "xla/shape.h"
 #include "xla/stream_executor/cuda/cuda_compute_capability.h"
 #include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/rocm/rocm_compute_capability.h"
 #include "xla/util.h"
 
 namespace xla {

@@ -33,26 +34,24 @@ namespace {
 
 bool DimensionRequiresPadding(const int64_t size, const PrimitiveType data_type,
                               const se::GpuComputeCapability& gpu_cc) {
-  return std::visit(
-      absl::Overload(
-          [&](const se::CudaComputeCapability& cc) {
-            for (const auto& req : CublasPaddingRequirements) {
-              if (cc.SupportsAllFeaturesOf(req.min_compute_capability) &&
-                  data_type == req.data_type && size % req.multiple_of != 0) {
-                return true;
-              }
-            }
-            return false;
-          },
-          [&](const se::RocmComputeCapability& cc) {
-            for (const auto& req : HipblasPaddingRequirements) {
-              if (data_type == req.data_type && size % req.multiple_of != 0) {
-                return true;
-              }
-            }
-            return false;
-          }),
-      gpu_cc);
+  if (const se::CudaComputeCapability* cc = gpu_cc.cuda_compute_capability()) {
+    for (const auto& req : CublasPaddingRequirements) {
+      if (cc->SupportsAllFeaturesOf(req.min_compute_capability) &&
+          data_type == req.data_type && size % req.multiple_of != 0) {
+        return true;
+      }
+    }
+    return false;
+  }
+  if (const se::RocmComputeCapability* cc = gpu_cc.rocm_compute_capability()) {
+    for (const auto& req : HipblasPaddingRequirements) {
+      if (data_type == req.data_type && size % req.multiple_of != 0) {
+        return true;
+      }
+    }
+    return false;
+  }
+  return false;
 }
 
 bool ShapeRequiresPadding(const Shape& shape, int batch_dimensions_size,

@@ -75,15 +75,9 @@ ENTRY e {
 }
 
 TEST_F(FloatSupportTestWithTriton, MixedTypeDotWithBF16IsNotUpcasted) {
-  bool skip_test =
-      std::visit(absl::Overload(
-                     [](const se::CudaComputeCapability& cc) {
-                       return !cc.IsAtLeast(se::CudaComputeCapability::kAmpere);
-                     },
-                     [](const se::RocmComputeCapability&) { return true; }),
-                 GetGpuComputeCapability());
-
-  if (skip_test) {
+  if (GetGpuComputeCapability().IsRocm() ||
+      !GetGpuComputeCapability().cuda_compute_capability()->IsAtLeast(
+          se::CudaComputeCapability::kAmpere)) {
     GTEST_SKIP() << "Not supported on this GPU architecture";
   }
 
@@ -116,24 +116,21 @@ class AnalyticalLatencyHidingSchedulerTest : public GpuCodegenTest {
 
 TEST_F(AnalyticalLatencyHidingSchedulerTest, TestAnalyticalLatencyEstimator) {
   auto gpu_compute_capability = GetGpuComputeCapability();
-  auto visitor = [](const auto& c) {
-    using cc = std::remove_const_t<std::remove_reference_t<decltype(c)>>;
-    if constexpr (std::is_same_v<stream_executor::CudaComputeCapability, cc>) {
-      if (!c.IsAtLeast(se::CudaComputeCapability::kPascal)) {
-        GTEST_SKIP() << "This test is for Pascal+ GPUs.";
-      }
-      if (c.major == 12 && c.minor == 1) {
-        // Skip this test for Spark. Because of the AllReduce, the test uses
-        // gpu_collective_performance_model, which only makes sense in a
-        // datacenter network setting.
-        GTEST_SKIP() << "This test is for datacenter GPUs.";
-      }
-    } else if (!std::is_same_v<stream_executor::RocmComputeCapability, cc>) {
-      GTEST_SKIP() << "This test is for Pascal+ GPUs.";
-    }
-  };
-
-  std::visit(visitor, gpu_compute_capability);
+  if (gpu_compute_capability.IsRocm()) {
+    GTEST_SKIP() << "This test is for Pascal+ GPUs.";
+  }
+  auto* c = gpu_compute_capability.cuda_compute_capability();
+  if (!c->IsAtLeast(se::CudaComputeCapability::kPascal)) {
+    GTEST_SKIP() << "This test is for Pascal+ GPUs.";
+  }
+  if (c->major == 12 && c->minor == 1) {
+    // Skip this test for Spark. Because of the AllReduce, the test uses
+    // gpu_collective_performance_model, which only makes sense in a
+    // datacenter network setting.
+    GTEST_SKIP() << "This test is for datacenter GPUs.";
+  }
   const se::DeviceDescription dev_info =
       backend().default_stream_executor()->GetDeviceDescription();

@@ -440,11 +440,15 @@ GpuPerformanceWithCollectiveModel::ComputeAllreduceTime(
     const se::DeviceDescription& gpu_device_info) {
   // We use nccl group call to launch multiple allreduces so launch overhead
   // only occurs once.
-  const auto visitor = [&](const auto& cc) {
-    return ComputeAllreduceTimeImpl(instr, cost_analysis, gpu_device_info,
-                                    CreateSettings(cc));
-  };
-  return std::visit(visitor, gpu_device_info.gpu_compute_capability());
+  if (auto ptr =
+          gpu_device_info.gpu_compute_capability().cuda_compute_capability()) {
+    return ComputeAllreduceTimeImpl(instr, cost_analysis, gpu_device_info,
+                                    CreateSettings(*ptr));
+  }
+  return ComputeAllreduceTimeImpl(
+      instr, cost_analysis, gpu_device_info,
+      CreateSettings(
+          *gpu_device_info.gpu_compute_capability().rocm_compute_capability()));
 }
 
 /*static*/ absl::Duration

@@ -378,6 +378,7 @@ cc_library(
         "//xla/stream_executor:device_description",
         "//xla/stream_executor:semantic_version",
         "//xla/stream_executor/cuda:cuda_compute_capability",
+        "//xla/stream_executor/rocm:rocm_compute_capability",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",

@@ -56,6 +56,7 @@ limitations under the License.
 #include "xla/shape_util.h"
 #include "xla/stream_executor/cuda/cuda_compute_capability.h"
 #include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/rocm/rocm_compute_capability.h"
 #include "xla/stream_executor/semantic_version.h"
 #include "xla/tsl/platform/errors.h"
 #include "xla/tsl/platform/statusor.h"

@@ -843,20 +844,19 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
   };
 
   // Check if CUDA/ROCM driver supports required features.
-  auto erase_cuda = [&](const se::CudaComputeCapability& cuda_comp) {
+  if (auto* cuda_comp = device_description_.gpu_compute_capability()
+                            .cuda_compute_capability()) {
     if (std::min(device_description_.runtime_version(),
                  device_description_.driver_version()) <
         se::SemanticVersion{12, 3, 0}) {
       erase(kRequireTracing);       // cuStreamBeginCaptureToGraph
       erase(kRequireConditionals);  // on-device control flow
     }
-  };
-  auto erase_rocm = [&](const se::RocmComputeCapability& rocm_comp) {
+  } else if (const se::RocmComputeCapability* rocm_comp =
+                 device_description_.gpu_compute_capability()
+                     .rocm_compute_capability()) {
     erase(kRequireConditionals);  // on-device control flow
-  };
-
-  std::visit(absl::Overload(erase_cuda, erase_rocm),
-             device_description_.gpu_compute_capability());
+  }
 
   auto order = module->MakeComputationPostOrder();
   std::reverse(order.begin(), order.end());

@@ -36,8 +36,35 @@ limitations under the License.
 
 namespace stream_executor {
 
-using GpuComputeCapability =
-    std::variant<CudaComputeCapability, RocmComputeCapability>;
+class GpuComputeCapability
+    : public std::variant<CudaComputeCapability, RocmComputeCapability> {
+ public:
+  using std::variant<CudaComputeCapability, RocmComputeCapability>::variant;
+  using std::variant<CudaComputeCapability, RocmComputeCapability>::operator=;
+
+  bool IsCuda() const {
+    return std::holds_alternative<CudaComputeCapability>(*this);
+  }
+
+  bool IsRocm() const {
+    return std::holds_alternative<RocmComputeCapability>(*this);
+  }
+
+  const CudaComputeCapability* cuda_compute_capability() const {
+    return std::get_if<CudaComputeCapability>(this);
+  }
+
+  const RocmComputeCapability* rocm_compute_capability() const {
+    return std::get_if<RocmComputeCapability>(this);
+  }
+
+  std::string ToString() const {
+    if (auto ptr = cuda_compute_capability()) {
+      return ptr->ToString();
+    }
+    return rocm_compute_capability()->ToString();
+  }
+};
 
 // Data that describes the execution target of the StreamExecutor, in terms of
 // important logical parameters. These include dimensionality limits and

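For reference, here is a minimal, self-contained sketch of the call-site migration this commit applies around the class defined above. The CudaComputeCapability and RocmComputeCapability structs below are simplified stand-ins for illustration, not the real stream_executor types, and the wrapper merely mirrors the shape of the class in the hunk above:

#include <iostream>
#include <string>
#include <variant>

// Simplified stand-ins for the real stream_executor capability types.
struct CudaComputeCapability {
  int major = 8;
  int minor = 0;
  bool IsAtLeastAmpere() const { return major >= 8; }
  std::string ToString() const {
    return std::to_string(major) + "." + std::to_string(minor);
  }
};
struct RocmComputeCapability {
  std::string gfx_version() const { return "gfx90a"; }
  std::string ToString() const { return gfx_version(); }
};

// Same shape as the class introduced above: a std::variant with query helpers.
class GpuComputeCapability
    : public std::variant<CudaComputeCapability, RocmComputeCapability> {
 public:
  using std::variant<CudaComputeCapability, RocmComputeCapability>::variant;

  bool IsCuda() const {
    return std::holds_alternative<CudaComputeCapability>(*this);
  }
  bool IsRocm() const {
    return std::holds_alternative<RocmComputeCapability>(*this);
  }
  const CudaComputeCapability* cuda_compute_capability() const {
    return std::get_if<CudaComputeCapability>(this);
  }
  const RocmComputeCapability* rocm_compute_capability() const {
    return std::get_if<RocmComputeCapability>(this);
  }
  std::string ToString() const {
    if (auto ptr = cuda_compute_capability()) return ptr->ToString();
    return rocm_compute_capability()->ToString();
  }
};

int main() {
  GpuComputeCapability cc = CudaComputeCapability{9, 0};

  // Old call-site style: dispatch on the bare std::variant.
  bool ampere_old = false;
  if (auto* cuda = std::get_if<CudaComputeCapability>(&cc)) {
    ampere_old = cuda->IsAtLeastAmpere();
  }

  // New call-site style: ask the wrapper directly.
  bool ampere_new = false;
  if (auto* cuda = cc.cuda_compute_capability()) {
    ampere_new = cuda->IsAtLeastAmpere();
  }

  std::cout << cc.ToString() << " IsCuda=" << cc.IsCuda()
            << " ampere_old=" << ampere_old
            << " ampere_new=" << ampere_new << "\n";
  // Prints: 9.0 IsCuda=1 ampere_old=1 ampere_new=1
}

Because the class publicly derives from the variant and re-exports its constructors, existing std::visit/std::get_if call sites keep compiling during the migration, which is why this commit can be a "first step".
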
@@ -193,32 +220,24 @@ class DeviceDescription {
   // also we do not count what occupies cache, but rather claim that what is
   // much smaller than the cache size will likely stay in it.
   constexpr int64_t l1_cache_size_per_SM() const {
-    return std::visit(
-        [](const auto& capability) -> int64_t {
-          if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
-                                       RocmComputeCapability>) {
-            // MI100 and MI200 has 16KB L1 cache per CU.
-            if (capability.gfx9_mi100() || capability.gfx9_mi200()) {
-              return 16 * 1024;
-            }
-            // MI300 has 32KB L1 cache per CU.
-            if (capability.gfx9_mi300_series()) {
-              return 32 * 1024;
-            }
-          }
-          // Default return for other GPUs (e.g., RTX A6000).
-          return 2 * 1024;
-        },
-        gpu_compute_capability_);
+    if (auto* capability = gpu_compute_capability_.rocm_compute_capability()) {
+      // MI100 and MI200 has 16KB L1 cache per CU.
+      if (capability->gfx9_mi100() || capability->gfx9_mi200()) {
+        return 16 * 1024;
+      }
+      // MI300 has 32KB L1 cache per CU.
+      if (capability->gfx9_mi300_series()) {
+        return 32 * 1024;
+      }
+    }
+    // Default return for other GPUs (e.g., RTX A6000).
+    return 2 * 1024;
   }
 
   constexpr int64_t dram_to_l2_transaction_size_bytes() const {
-    return std::visit(
-        [](const auto& capability) -> int {
-          if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
-                                       RocmComputeCapability>) {
-            // DRAM->L2 bus is 128 Byte width for MI300.
-            if (capability.gfx9_mi300_series()) {
-              return 128;
-            }
-          }
+    if (auto* capability = gpu_compute_capability_.rocm_compute_capability()) {
+      // DRAM->L2 bus is 128 Byte width for MI300.
+      if (capability->gfx9_mi300_series()) {
+        return 128;
+      }
+    }

@@ -229,24 +248,17 @@ class DeviceDescription {
     // (page 10).
     // return 64 Bytes by default.
     return 64;
-        },
-        gpu_compute_capability_);
   }
 
   constexpr int64_t memory_transactions_per_clock() const {
-    return std::visit(
-        [](const auto& capability) -> int {
-          if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
-                                       RocmComputeCapability>) {
-            // 16 works well on MI300.
-            if (capability.gfx9_mi300_series()) {
-              return 16;
-            }
-          }
-          // Default return for other GPUs.
-          return 32;
-        },
-        gpu_compute_capability_);
+    if (auto* capability = gpu_compute_capability_.rocm_compute_capability()) {
+      // 16 works well on MI300.
+      if (capability->gfx9_mi300_series()) {
+        return 16;
+      }
+    }
+    // Default return for other GPUs.
+    return 32;
   }
 
   GpuDeviceInfoProto ToGpuProto() const;