Raise runtime error in setup.py if cudnn version is not supported

This commit is contained in:
Richard Zou 2017-10-13 14:08:13 -07:00 committed by Soumith Chintala
parent 1322f9a272
commit 8666be05f5
2 changed files with 19 additions and 4 deletions

View File

@ -159,7 +159,7 @@ Once you have [Anaconda](https://www.continuum.io/downloads) installed, here are
If you want to compile with CUDA support, install
- [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 7.5 or above
- [NVIDIA cuDNN](https://developer.nvidia.com/cudnn) v5.x or above
- [NVIDIA cuDNN](https://developer.nvidia.com/cudnn) v6.x or above
If you want to disable CUDA support, export environment variable `NO_CUDA=1`.

View File

@ -17,20 +17,20 @@ def find_cudnn_version(cudnn_lib_dir):
candidate_names = [os.path.basename(c) for c in candidate_names]
# suppose version is MAJOR.MINOR.PATCH, all numbers
version_regex = re.compile('[0-9]+\.[0-9]+\.[0-9]+')
version_regex = re.compile('\d+\.\d+\.\d+')
candidates = [c.group() for c in map(version_regex.search, candidate_names) if c]
if len(candidates) > 0:
# normally only one will be retrieved, take the first result
return candidates[0]
# if no candidates were found, try MAJOR.MINOR
version_regex = re.compile('[0-9]+\.[0-9]+')
version_regex = re.compile('\d+\.\d+')
candidates = [c.group() for c in map(version_regex.search, candidate_names) if c]
if len(candidates) > 0:
return candidates[0]
# if no candidates were found, try MAJOR
version_regex = re.compile('[0-9]+')
version_regex = re.compile('\d+')
candidates = [c.group() for c in map(version_regex.search, candidate_names) if c]
if len(candidates) > 0:
return candidates[0]
@ -38,6 +38,20 @@ def find_cudnn_version(cudnn_lib_dir):
return 'unknown'
def check_cudnn_version(cudnn_version_string):
if cudnn_version_string is 'unknown':
return # Assume version is OK and let compilation continue
cudnn_min_version = 6
cudnn_version = int(cudnn_version_string.split('.')[0])
if cudnn_version < cudnn_min_version:
raise RuntimeError(
'CuDNN v%s found, but need at least CuDNN v%s. '
'You can get the latest version of CuDNN from '
'https://developer.nvidia.com/cudnn' %
(cudnn_version_string, cudnn_min_version))
is_conda = 'conda' in sys.version or 'Continuum' in sys.version
conda_dir = os.path.join(os.path.dirname(sys.executable), '..')
@ -86,4 +100,5 @@ if WITH_CUDA and not check_env_flag('NO_CUDNN'):
CUDNN_LIB_DIR = CUDNN_INCLUDE_DIR = None
else:
CUDNN_VERSION = find_cudnn_version(CUDNN_LIB_DIR)
check_cudnn_version(CUDNN_VERSION)
WITH_CUDNN = True