mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add 9.0a to cpp_extension supported compute archs (#110587)
There is an extended compute capability, 9.0a, for Hopper that was introduced in CUDA 12.0: https://docs.nvidia.com/cuda/archive/12.0.0/cuda-compiler-driver-nvcc/index.html#gpu-feature-list
For example, CUTLASS makes use of it: 5f13dcad78/python/cutlass/emit/pytorch.py (L684)
This change adds 9.0a to the list of permitted architectures so that it can be used in `cpp_extension` directly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110587
Approved by: https://github.com/ezyang
This commit is contained in:
parent
c89d35adfe
commit
a0cea517e7
|
|
@ -72,6 +72,11 @@ if(NOT CUDA_VERSION VERSION_LESS "11.8")
|
|||
endif()
|
||||
endif()
|
||||
|
||||
if(NOT CUDA_VERSION VERSION_LESS "12.0")
|
||||
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0a")
|
||||
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "9.0a")
|
||||
endif()
|
||||
|
||||
################################################################################################
|
||||
# A function for automatic detection of GPUs installed (if autodetection is enabled)
|
||||
# Usage:
|
||||
|
|
|
|||
|
|
@ -1949,7 +1949,7 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
|
|||
])
|
||||
|
||||
supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2',
|
||||
'7.0', '7.2', '7.5', '8.0', '8.6', '8.7', '8.9', '9.0']
|
||||
'7.0', '7.2', '7.5', '8.0', '8.6', '8.7', '8.9', '9.0', '9.0a']
|
||||
valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]
|
||||
|
||||
# The default is sm_30 for CUDA 9.x and 10.x
|
||||
|
|
@ -1992,7 +1992,7 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
|
|||
if arch not in valid_arch_strings:
|
||||
raise ValueError(f"Unknown CUDA arch ({arch}) or GPU not supported")
|
||||
else:
|
||||
num = arch[0] + arch[2]
|
||||
num = arch[0] + arch[2:].split("+")[0]
|
||||
flags.append(f'-gencode=arch=compute_{num},code=sm_{num}')
|
||||
if arch.endswith('+PTX'):
|
||||
flags.append(f'-gencode=arch=compute_{num},code=compute_{num}')
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user