mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
* Minor fixes for finding CUDNN * Minor fixes for comment * Fix lints * Fix naming conflicts * Fix import name
97 lines
3.4 KiB
Python
97 lines
3.4 KiB
Python
import os
|
|
import glob
|
|
|
|
from .env import IS_WINDOWS, IS_CONDA, CONDA_DIR, check_negative_env_flag, gather_paths, lib_paths_from_base
|
|
from .cuda import USE_CUDA, CUDA_HOME
|
|
|
|
|
|
# Defaults for the cuDNN probe results. The discovery code further down in
# this module overwrites them when a usable cuDNN installation is located.
USE_CUDNN = False
CUDNN_LIB_DIR = CUDNN_INCLUDE_DIR = CUDNN_LIBRARY = None
# Raw value of the USE_STATIC_CUDNN environment variable (None when unset);
# the library search only checks whether it is set at all.
WITH_STATIC_CUDNN = os.environ.get("USE_STATIC_CUDNN")
|
|
|
|
# Probe the filesystem for a cuDNN header and library. This only runs when
# CUDA is enabled and the USE_CUDNN environment flag has not been negated.
# On success USE_CUDNN becomes True and CUDNN_LIBRARY / CUDNN_LIB_DIR /
# CUDNN_INCLUDE_DIR hold the discovered locations; otherwise all three are
# reset to None and USE_CUDNN stays False.
if USE_CUDA and not check_negative_env_flag('USE_CUDNN'):
    # Candidate library directories, highest priority first. filter(bool)
    # drops unset/empty entries such as a missing CUDNN_LIB_DIR variable.
    lib_paths = list(filter(bool, [
        os.getenv('CUDNN_LIB_DIR')
    ] + lib_paths_from_base(CUDA_HOME) + [
        '/usr/lib/x86_64-linux-gnu/',
        '/usr/lib/powerpc64le-linux-gnu/',
        '/usr/lib/aarch64-linux-gnu/',
    ] + gather_paths([
        'LIBRARY_PATH',
    ]) + gather_paths([
        'LD_LIBRARY_PATH',
    ])))
    # Candidate header directories, highest priority first.
    include_paths = list(filter(bool, [
        os.getenv('CUDNN_INCLUDE_DIR'),
        os.path.join(CUDA_HOME, 'include'),
        '/usr/include/',
    ] + gather_paths([
        'CPATH',
        'C_INCLUDE_PATH',
        'CPLUS_INCLUDE_PATH',
    ])))
    # Add Conda related dirs to candidate list
    if IS_CONDA:
        lib_paths.append(os.path.join(CONDA_DIR, 'lib'))
        include_paths.append(os.path.join(CONDA_DIR, 'include'))

    # Bind the version before the loop so it is always defined, even when
    # none of the candidate include directories exist on this machine
    # (previously that case raised NameError further down).
    CUDNN_INCLUDE_VERSION = None
    for path in include_paths:
        if path is None or not os.path.exists(path):
            continue
        include_file_path = os.path.join(path, 'cudnn.h')
        if not os.path.exists(include_file_path):
            continue
        # The first directory that contains cudnn.h wins.
        CUDNN_INCLUDE_DIR = path
        with open(include_file_path) as f:
            for line in f:
                if "#define CUDNN_MAJOR" in line:
                    # The major version is the last token on the line.
                    CUDNN_INCLUDE_VERSION = int(line.split()[-1])
                    break
        if CUDNN_INCLUDE_VERSION is None:
            raise AssertionError("Could not find #define CUDNN_MAJOR in " + include_file_path)
        break

    # Check for standalone cuDNN libraries
    if CUDNN_INCLUDE_DIR is not None:
        # normpath strips any trailing separator first; without it,
        # dirname('/usr/include/') yields '/usr/include' instead of the
        # parent directory that is meant to serve as the base.
        cudnn_path = os.path.dirname(os.path.normpath(CUDNN_INCLUDE_DIR))
        cudnn_lib_paths = lib_paths_from_base(cudnn_path)
        lib_paths.extend(cudnn_lib_paths)

    for path in lib_paths:
        if path is None or not os.path.exists(path):
            continue
        if IS_WINDOWS:
            library = os.path.join(path, 'cudnn.lib')
            if os.path.exists(library):
                CUDNN_LIBRARY = library
                CUDNN_LIB_DIR = path
                break
        else:
            if WITH_STATIC_CUDNN is not None:
                search_name = 'libcudnn_static.a'
            else:
                # Match libraries whose file name contains the header's
                # major version number.
                search_name = 'libcudnn*' + str(CUDNN_INCLUDE_VERSION) + "*"
            # sorted() makes the pick deterministic when several files match.
            libraries = sorted(glob.glob(os.path.join(path, search_name)))
            if libraries:
                CUDNN_LIBRARY = libraries[0]
                CUDNN_LIB_DIR = path
                break
    # Specifying the library directly will overwrite the lib directory
    library = os.getenv('CUDNN_LIBRARY')
    if library is not None and os.path.exists(library):
        CUDNN_LIBRARY = library
        CUDNN_LIB_DIR = os.path.dirname(CUDNN_LIBRARY)

    if not all([CUDNN_LIBRARY, CUDNN_LIB_DIR, CUDNN_INCLUDE_DIR]):
        # A partial result is unusable; report "not found" consistently.
        CUDNN_LIBRARY = CUDNN_LIB_DIR = CUDNN_INCLUDE_DIR = None
    else:
        real_cudnn_library = os.path.realpath(CUDNN_LIBRARY)
        real_cudnn_lib_dir = os.path.realpath(CUDNN_LIB_DIR)
        # Raise explicitly instead of using `assert` so the consistency
        # check still runs when Python is started with -O.
        if os.path.dirname(real_cudnn_library) != real_cudnn_lib_dir:
            raise AssertionError('cudnn library and lib_dir must agree')
        USE_CUDNN = True