From a0cea517e770fa04f1f33bf6847c6852ca9abc06 Mon Sep 17 00:00:00 2001 From: Dmytro Dzhulgakov Date: Thu, 5 Oct 2023 17:41:03 +0000 Subject: [PATCH] Add 9.0a to cpp_extension supported compute archs (#110587) There's an extended compute capability 9.0a for Hopper that was introduced in Cuda 12.0: https://docs.nvidia.com/cuda/archive/12.0.0/cuda-compiler-driver-nvcc/index.html#gpu-feature-list E.g. Cutlass leverages it: https://github.com/NVIDIA/cutlass/blob/5f13dcad781284678edafa3b8d108120cfc6a6e4/python/cutlass/emit/pytorch.py#L684 This adds it to the list of permitted architectures to use in `cpp_extension` directly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/110587 Approved by: https://github.com/ezyang --- .../upstream/FindCUDA/select_compute_arch.cmake | 5 +++++ torch/utils/cpp_extension.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake b/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake index d917738a5c7..769ddacfcf2 100644 --- a/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake +++ b/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake @@ -72,6 +72,11 @@ if(NOT CUDA_VERSION VERSION_LESS "11.8") endif() endif() +if(NOT CUDA_VERSION VERSION_LESS "12.0") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0a") + list(APPEND CUDA_ALL_GPU_ARCHITECTURES "9.0a") +endif() + ################################################################################################ # A function for automatic detection of GPUs installed (if autodetection is enabled) # Usage: diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 5fe55256cfe..6ec39b6817d 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -1949,7 +1949,7 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]: ]) supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2', - '7.0', '7.2', '7.5', '8.0', '8.6', '8.7', '8.9', '9.0'] + '7.0', '7.2', '7.5', '8.0', '8.6', '8.7', '8.9', '9.0', '9.0a'] valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches] # The default is sm_30 for CUDA 9.x and 10.x @@ -1992,7 +1992,7 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]: if arch not in valid_arch_strings: raise ValueError(f"Unknown CUDA arch ({arch}) or GPU not supported") else: - num = arch[0] + arch[2] + num = arch[0] + arch[2:].split("+")[0] flags.append(f'-gencode=arch=compute_{num},code=sm_{num}') if arch.endswith('+PTX'): flags.append(f'-gencode=arch=compute_{num},code=compute_{num}')