mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Raise runtime error in setup.py if cudnn version is not supported
This commit is contained in:
parent
1322f9a272
commit
8666be05f5
|
|
@ -159,7 +159,7 @@ Once you have [Anaconda](https://www.continuum.io/downloads) installed, here are
|
||||||
|
|
||||||
If you want to compile with CUDA support, install
|
If you want to compile with CUDA support, install
|
||||||
- [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 7.5 or above
|
- [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 7.5 or above
|
||||||
- [NVIDIA cuDNN](https://developer.nvidia.com/cudnn) v5.x or above
|
- [NVIDIA cuDNN](https://developer.nvidia.com/cudnn) v6.x or above
|
||||||
|
|
||||||
If you want to disable CUDA support, export environment variable `NO_CUDA=1`.
|
If you want to disable CUDA support, export environment variable `NO_CUDA=1`.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,20 +17,20 @@ def find_cudnn_version(cudnn_lib_dir):
|
||||||
candidate_names = [os.path.basename(c) for c in candidate_names]
|
candidate_names = [os.path.basename(c) for c in candidate_names]
|
||||||
|
|
||||||
# suppose version is MAJOR.MINOR.PATCH, all numbers
|
# suppose version is MAJOR.MINOR.PATCH, all numbers
|
||||||
version_regex = re.compile('[0-9]+\.[0-9]+\.[0-9]+')
|
version_regex = re.compile('\d+\.\d+\.\d+')
|
||||||
candidates = [c.group() for c in map(version_regex.search, candidate_names) if c]
|
candidates = [c.group() for c in map(version_regex.search, candidate_names) if c]
|
||||||
if len(candidates) > 0:
|
if len(candidates) > 0:
|
||||||
# normally only one will be retrieved, take the first result
|
# normally only one will be retrieved, take the first result
|
||||||
return candidates[0]
|
return candidates[0]
|
||||||
|
|
||||||
# if no candidates were found, try MAJOR.MINOR
|
# if no candidates were found, try MAJOR.MINOR
|
||||||
version_regex = re.compile('[0-9]+\.[0-9]+')
|
version_regex = re.compile('\d+\.\d+')
|
||||||
candidates = [c.group() for c in map(version_regex.search, candidate_names) if c]
|
candidates = [c.group() for c in map(version_regex.search, candidate_names) if c]
|
||||||
if len(candidates) > 0:
|
if len(candidates) > 0:
|
||||||
return candidates[0]
|
return candidates[0]
|
||||||
|
|
||||||
# if no candidates were found, try MAJOR
|
# if no candidates were found, try MAJOR
|
||||||
version_regex = re.compile('[0-9]+')
|
version_regex = re.compile('\d+')
|
||||||
candidates = [c.group() for c in map(version_regex.search, candidate_names) if c]
|
candidates = [c.group() for c in map(version_regex.search, candidate_names) if c]
|
||||||
if len(candidates) > 0:
|
if len(candidates) > 0:
|
||||||
return candidates[0]
|
return candidates[0]
|
||||||
|
|
@ -38,6 +38,20 @@ def find_cudnn_version(cudnn_lib_dir):
|
||||||
return 'unknown'
|
return 'unknown'
|
||||||
|
|
||||||
|
|
||||||
|
def check_cudnn_version(cudnn_version_string):
|
||||||
|
if cudnn_version_string is 'unknown':
|
||||||
|
return # Assume version is OK and let compilation continue
|
||||||
|
|
||||||
|
cudnn_min_version = 6
|
||||||
|
cudnn_version = int(cudnn_version_string.split('.')[0])
|
||||||
|
if cudnn_version < cudnn_min_version:
|
||||||
|
raise RuntimeError(
|
||||||
|
'CuDNN v%s found, but need at least CuDNN v%s. '
|
||||||
|
'You can get the latest version of CuDNN from '
|
||||||
|
'https://developer.nvidia.com/cudnn' %
|
||||||
|
(cudnn_version_string, cudnn_min_version))
|
||||||
|
|
||||||
|
|
||||||
is_conda = 'conda' in sys.version or 'Continuum' in sys.version
|
is_conda = 'conda' in sys.version or 'Continuum' in sys.version
|
||||||
conda_dir = os.path.join(os.path.dirname(sys.executable), '..')
|
conda_dir = os.path.join(os.path.dirname(sys.executable), '..')
|
||||||
|
|
||||||
|
|
@ -86,4 +100,5 @@ if WITH_CUDA and not check_env_flag('NO_CUDNN'):
|
||||||
CUDNN_LIB_DIR = CUDNN_INCLUDE_DIR = None
|
CUDNN_LIB_DIR = CUDNN_INCLUDE_DIR = None
|
||||||
else:
|
else:
|
||||||
CUDNN_VERSION = find_cudnn_version(CUDNN_LIB_DIR)
|
CUDNN_VERSION = find_cudnn_version(CUDNN_LIB_DIR)
|
||||||
|
check_cudnn_version(CUDNN_VERSION)
|
||||||
WITH_CUDNN = True
|
WITH_CUDNN = True
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user