Add 9.0a to cpp_extension supported compute archs (#110587)

There's an extended compute capability 9.0a for Hopper that was introduced in Cuda 12.0: https://docs.nvidia.com/cuda/archive/12.0.0/cuda-compiler-driver-nvcc/index.html#gpu-feature-list

E.g. Cutlass leverages it: 5f13dcad78/python/cutlass/emit/pytorch.py (L684)

This adds it to the list of permitted architectures to use in `cpp_extension` directly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110587
Approved by: https://github.com/ezyang
This commit is contained in:
Dmytro Dzhulgakov 2023-10-05 17:41:03 +00:00 committed by PyTorch MergeBot
parent c89d35adfe
commit a0cea517e7
2 changed files with 7 additions and 2 deletions

View File

@ -72,6 +72,11 @@ if(NOT CUDA_VERSION VERSION_LESS "11.8")
endif()
endif()
if(NOT CUDA_VERSION VERSION_LESS "12.0")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0a")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "9.0a")
endif()
################################################################################################
# A function for automatic detection of GPUs installed (if autodetection is enabled)
# Usage:

View File

@ -1949,7 +1949,7 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
])
supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2',
'7.0', '7.2', '7.5', '8.0', '8.6', '8.7', '8.9', '9.0']
'7.0', '7.2', '7.5', '8.0', '8.6', '8.7', '8.9', '9.0', '9.0a']
valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]
# The default is sm_30 for CUDA 9.x and 10.x
@ -1992,7 +1992,7 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
if arch not in valid_arch_strings:
raise ValueError(f"Unknown CUDA arch ({arch}) or GPU not supported")
else:
num = arch[0] + arch[2]
num = arch[0] + arch[2:].split("+")[0]
flags.append(f'-gencode=arch=compute_{num},code=sm_{num}')
if arch.endswith('+PTX'):
flags.append(f'-gencode=arch=compute_{num},code=compute_{num}')