mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm] gfx940 and gfx941 cleanup (#147394)
Removing gfx architectures not supported by ROCm. NOTE: For users wanting to build PyTorch for gfx archs that are *not* supported by the official wheels on download.pytorch.org, you can build PyTorch from source for your desired gfx arch [using the PYTORCH_ROCM_ARCH env var](https://github.com/pytorch/pytorch/blob/main/README.md#amd-rocm-support). Pull Request resolved: https://github.com/pytorch/pytorch/pull/147394 Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
This commit is contained in:
parent
c0ee62573a
commit
fd8ae1aa04
|
|
@ -330,7 +330,7 @@ at::BlasBackend Context::blasPreferredBackend() {
|
||||||
if (blas_preferred_backend == at::BlasBackend::Cublaslt) {
|
if (blas_preferred_backend == at::BlasBackend::Cublaslt) {
|
||||||
static const bool hipblaslt_unsupported = []() {
|
static const bool hipblaslt_unsupported = []() {
|
||||||
static const std::vector<std::string> archs = {
|
static const std::vector<std::string> archs = {
|
||||||
"gfx90a", "gfx940", "gfx941", "gfx942",
|
"gfx90a", "gfx942",
|
||||||
#if ROCM_VERSION >= 60300
|
#if ROCM_VERSION >= 60300
|
||||||
"gfx1100", "gfx1101"
|
"gfx1100", "gfx1101"
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -191,7 +191,7 @@ static bool isSupportedHipLtROCmArch(int index) {
|
||||||
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
|
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
|
||||||
std::string device_arch = prop->gcnArchName;
|
std::string device_arch = prop->gcnArchName;
|
||||||
static const std::vector<std::string> archs = {
|
static const std::vector<std::string> archs = {
|
||||||
"gfx90a", "gfx940", "gfx941", "gfx942",
|
"gfx90a", "gfx942",
|
||||||
#if ROCM_VERSION >= 60300
|
#if ROCM_VERSION >= 60300
|
||||||
"gfx1100", "gfx1101"
|
"gfx1100", "gfx1101"
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -862,7 +862,7 @@ static bool _scaled_mm_allowed_device() {
|
||||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
std::string device_arch = dprops->gcnArchName;
|
std::string device_arch = dprops->gcnArchName;
|
||||||
static const std::vector<std::string> archs = {"gfx940", "gfx941", "gfx942"};
|
static const std::vector<std::string> archs = {"gfx942"};
|
||||||
for (std::string arch : archs) {
|
for (std::string arch : archs) {
|
||||||
size_t substring = device_arch.find(arch);
|
size_t substring = device_arch.find(arch);
|
||||||
if (substring != std::string::npos) {
|
if (substring != std::string::npos) {
|
||||||
|
|
@ -879,7 +879,7 @@ static bool _scaled_mm_allowed_device() {
|
||||||
static bool _scaled_mm_is_fnuz() {
|
static bool _scaled_mm_is_fnuz() {
|
||||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||||
std::string device_arch = dprops->gcnArchName;
|
std::string device_arch = dprops->gcnArchName;
|
||||||
static const std::vector<std::string> archs = {"gfx940", "gfx941", "gfx942"};
|
static const std::vector<std::string> archs = {"gfx942"};
|
||||||
for (std::string arch : archs) {
|
for (std::string arch : archs) {
|
||||||
size_t substring = device_arch.find(arch);
|
size_t substring = device_arch.find(arch);
|
||||||
if (substring != std::string::npos) {
|
if (substring != std::string::npos) {
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@
|
||||||
#include <hip/hip_bf16.h>
|
#include <hip/hip_bf16.h>
|
||||||
|
|
||||||
__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {
|
__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {
|
||||||
#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && \
|
#if (defined(__gfx942__)) && \
|
||||||
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16)
|
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16)
|
||||||
typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
|
typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
|
||||||
static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw));
|
static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw));
|
||||||
|
|
@ -39,7 +39,7 @@ __device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* addre
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) {
|
__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) {
|
||||||
#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && \
|
#if (defined(__gfx942__)) && \
|
||||||
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16)
|
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16)
|
||||||
// The api expects an ext_vector_type of half
|
// The api expects an ext_vector_type of half
|
||||||
typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162;
|
typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162;
|
||||||
|
|
|
||||||
|
|
@ -137,7 +137,7 @@ using VecT = T __attribute__((ext_vector_type(Rank)));
|
||||||
static bool isCDNA2orLater(int index) {
|
static bool isCDNA2orLater(int index) {
|
||||||
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
|
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
|
||||||
std::string device_arch = prop->gcnArchName;
|
std::string device_arch = prop->gcnArchName;
|
||||||
static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
|
static const std::vector<std::string> archs = {"gfx90a", "gfx942"};
|
||||||
for (std::string arch : archs) {
|
for (std::string arch : archs) {
|
||||||
size_t substring = device_arch.find(arch);
|
size_t substring = device_arch.find(arch);
|
||||||
if (substring != std::string::npos) {
|
if (substring != std::string::npos) {
|
||||||
|
|
@ -151,7 +151,7 @@ static bool isCDNA2orLater(int index) {
|
||||||
constexpr int32_t kWarpSize = 32;
|
constexpr int32_t kWarpSize = 32;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined (__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
#if defined (__gfx90a__) || defined(__gfx942__)
|
||||||
#define CDNA2_OR_LATER 1
|
#define CDNA2_OR_LATER 1
|
||||||
#else
|
#else
|
||||||
#define CDNA2_OR_LATER 0
|
#define CDNA2_OR_LATER 0
|
||||||
|
|
|
||||||
|
|
@ -1308,13 +1308,13 @@ class cuda:
|
||||||
|
|
||||||
|
|
||||||
class rocm:
|
class rocm:
|
||||||
# Offload arch list for device code compilation, e.g. ["gfx941", "gfx942"].
|
# Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"].
|
||||||
# If empty, the `native` arch is used
|
# If empty, the `native` arch is used
|
||||||
arch: list[str] = []
|
arch: list[str] = []
|
||||||
|
|
||||||
# Enable the CK backend for CDNA2 and CDNA3 only (for now)
|
# Enable the CK backend for CDNA2 and CDNA3 only (for now)
|
||||||
# Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors
|
# Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors
|
||||||
ck_supported_arch: list[str] = ["gfx90a", "gfx940", "gfx941", "gfx942"]
|
ck_supported_arch: list[str] = ["gfx90a", "gfx942"]
|
||||||
|
|
||||||
# Optimization level, use to balance compilation speed and runtime performance.
|
# Optimization level, use to balance compilation speed and runtime performance.
|
||||||
# The type will not necessarily be comprehensive and won't be enforced at runtime.
|
# The type will not necessarily be comprehensive and won't be enforced at runtime.
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ IS_SM89 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_ca
|
||||||
def CDNA2OrLater():
|
def CDNA2OrLater():
|
||||||
if TEST_WITH_ROCM:
|
if TEST_WITH_ROCM:
|
||||||
gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
|
gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
|
||||||
return any(arch in gcn_arch_name for arch in {"gfx90a", "gfx940", "gfx941", "gfx942"})
|
return any(arch in gcn_arch_name for arch in {"gfx90a", "gfx942"})
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def evaluate_gfx_arch_exact(matching_arch):
|
def evaluate_gfx_arch_exact(matching_arch):
|
||||||
|
|
|
||||||
|
|
@ -102,7 +102,7 @@ except ImportError:
|
||||||
has_pytest = False
|
has_pytest = False
|
||||||
|
|
||||||
|
|
||||||
MI300_ARCH = ("gfx940", "gfx941", "gfx942")
|
MI300_ARCH = ("gfx942",)
|
||||||
|
|
||||||
|
|
||||||
def freeze_rng_state(*args, **kwargs):
|
def freeze_rng_state(*args, **kwargs):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user