Remove NO_* and WITH_* across codebase, except in setup.py (#8555)
* remove legacy options from CMakeLists
* codemod WITH_ to USE_ for WITH_CUDA, WITH_CUDNN, WITH_DISTRIBUTED, WITH_DISTRIBUTED_MW, WITH_GLOO_IBVERBS, WITH_NCCL, WITH_ROCM, WITH_NUMPY
* cover SYSTEM_NCCL, MKLDNN, NNPACK, C10D, NINJA
* removed NO_* variables and hotpatch them only in setup.py
* fix lint
parent d7690742d5
commit dc186cc9fe
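Before the per-file hunks, a quick orientation: user-facing NO_*/WITH_* environment variables keep working because setup.py now rewrites them into their USE_* equivalents before any setup helper reads the environment. The sketch below is an illustrative condensation of that shim, using the hotpatch_var, check_env_flag and check_negative_env_flag definitions that appear verbatim later in this diff; it is not additional code from the commit.

import os

def check_env_flag(name, default=''):
    return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']

def check_negative_env_flag(name, default=''):
    return os.getenv(name, default).upper() in ['OFF', '0', 'NO', 'FALSE', 'N']

def hotpatch_var(var):
    # Legacy spellings: NO_<VAR> is the inverse of USE_<VAR>, WITH_<VAR> maps directly.
    if check_env_flag('NO_' + var):
        os.environ['USE_' + var] = '0'
    elif check_negative_env_flag('NO_' + var):
        os.environ['USE_' + var] = '1'
    elif check_env_flag('WITH_' + var):
        os.environ['USE_' + var] = '1'
    elif check_negative_env_flag('WITH_' + var):
        os.environ['USE_' + var] = '0'

# Example: a legacy `NO_CUDA=1 python setup.py install` still disables CUDA,
# because NO_CUDA=1 is rewritten to USE_CUDA=0 before the setup helpers run.
os.environ['NO_CUDA'] = '1'
hotpatch_var('CUDA')
print(os.environ['USE_CUDA'])  # prints '0'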
@ -36,7 +36,7 @@ if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
|
|||
sudo chown -R jenkins:jenkins /usr/local
|
||||
rm -rf "$(dirname "${BASH_SOURCE[0]}")/../../../pytorch_amd/" || true
|
||||
python "$(dirname "${BASH_SOURCE[0]}")/../../tools/amd_build/build_pytorch_amd.py"
|
||||
WITH_ROCM=1 python setup.py install
|
||||
USE_ROCM=1 python setup.py install
|
||||
exit
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -138,31 +138,6 @@ cmake_dependent_option(
|
|||
option(USE_DISTRIBUTED "Use THD (distributed)" OFF)
|
||||
option(USE_DISTRIBUTED_MW "Use THD (distributed) master worker" OFF)
|
||||
|
||||
# Legacy options, which we will eventually remove
|
||||
cmake_dependent_option(
|
||||
WITH_CUDA "Legacy CUDA" ON
|
||||
"USE_CUDA" OFF)
|
||||
cmake_dependent_option(
|
||||
NO_PYTHON "Legacy Python" OFF
|
||||
"BUILD_PYTHON" ON)
|
||||
cmake_dependent_option(
|
||||
WITH_CUDNN "Legacy cuDNN" ON
|
||||
"USE_CUDNN" OFF)
|
||||
cmake_dependent_option(
|
||||
WITH_NCCL "Legacy NCCL" ON
|
||||
"USE_NCCL" OFF)
|
||||
cmake_dependent_option(
|
||||
NO_MKLDNN "Legacy no MKLDNN" OFF
|
||||
"USE_MKLDNN" ON)
|
||||
cmake_dependent_option(
|
||||
WITH_DISTRIBUTED "Legacy THD (distributed)" ON
|
||||
"USE_DISTRIBUTED" OFF)
|
||||
cmake_dependent_option(
|
||||
WITH_DISTRIBUTED_MW "Legacy THD (distributed) MW" ON
|
||||
"USE_DISTRIBUTED_MW" OFF)
|
||||
cmake_dependent_option(
|
||||
WITH_GLOO_IBVERBS "Legacy Gloo IB verbs for distributed support" ON
|
||||
"USE_GLOO_IBVERBS" OFF)
|
||||
if (USE_ATEN)
|
||||
set(BUILD_ATEN ${USE_ATEN})
|
||||
endif()
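To summarize what the deleted block above provided: each legacy CMake option was a compatibility alias for a canonical option, with the NO_* forms inverted. The mapping below is reconstructed from the deleted lines for reference only; after this commit, CMake recognizes just the canonical names.

# Retired legacy CMake option -> (canonical option, relationship), per the deleted block above.
LEGACY_CMAKE_OPTIONS = {
    'WITH_CUDA':           ('USE_CUDA', 'same meaning'),
    'NO_PYTHON':           ('BUILD_PYTHON', 'inverted'),
    'WITH_CUDNN':          ('USE_CUDNN', 'same meaning'),
    'WITH_NCCL':           ('USE_NCCL', 'same meaning'),
    'NO_MKLDNN':           ('USE_MKLDNN', 'inverted'),
    'WITH_DISTRIBUTED':    ('USE_DISTRIBUTED', 'same meaning'),
    'WITH_DISTRIBUTED_MW': ('USE_DISTRIBUTED_MW', 'same meaning'),
    'WITH_GLOO_IBVERBS':   ('USE_GLOO_IBVERBS', 'same meaning'),
}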
|
||||
|
|
|
|||
|
|
@ -14,14 +14,6 @@ else()
|
|||
"USE_CUDA" OFF)
|
||||
option(ATEN_NO_TEST "Do not build ATen test binaries" OFF)
|
||||
|
||||
# Legacy options, which we will eventually remove
|
||||
cmake_dependent_option(
|
||||
WITH_CUDNN "Legacy cuDNN" ON
|
||||
"USE_CUDNN" OFF)
|
||||
cmake_dependent_option(
|
||||
NO_MKLDNN "Legacy no MKLDNN" OFF
|
||||
"USE_MKLDNN" ON)
|
||||
|
||||
# Flag for shared dependencies
|
||||
set(BUILD_ATEN ON)
|
||||
endif()
|
||||
|
|
|
151 setup.py
@ -40,7 +40,7 @@
|
|||
# disables use of system-wide nccl (we will use our submoduled
|
||||
# copy in third_party/nccl)
|
||||
#
|
||||
# WITH_GLOO_IBVERBS
|
||||
# USE_GLOO_IBVERBS
|
||||
# toggle features related to distributed support
|
||||
#
|
||||
# PYTORCH_BUILD_VERSION
|
||||
|
|
@ -106,22 +106,39 @@ import json
|
|||
import glob
|
||||
import importlib
|
||||
|
||||
from tools.setup_helpers.env import check_env_flag
|
||||
from tools.setup_helpers.cuda import WITH_CUDA, CUDA_HOME, CUDA_VERSION
|
||||
from tools.setup_helpers.rocm import WITH_ROCM, ROCM_HOME, ROCM_VERSION
|
||||
from tools.setup_helpers.cudnn import (WITH_CUDNN, CUDNN_LIBRARY,
|
||||
from tools.setup_helpers.env import check_env_flag, check_negative_env_flag
|
||||
|
||||
# Before we run the setup_helpers, let's look for NO_* and WITH_*
|
||||
# variables and hotpatch the environment with the USE_* equivalent
|
||||
config_env_vars = ['CUDA', 'CUDNN', 'MKLDNN', 'NNPACK', 'DISTRIBUTED', 'DISTRIBUTED_MW',
|
||||
'SYSTEM_NCCL', 'GLOO_IBVERBS']
|
||||
|
||||
def hotpatch_var(var):
|
||||
if check_env_flag('NO_' + var):
|
||||
os.environ['USE_' + var] = '0'
|
||||
elif check_negative_env_flag('NO_' + var):
|
||||
os.environ['USE_' + var] = '1'
|
||||
elif check_env_flag('WITH_' + var):
|
||||
os.environ['USE_' + var] = '1'
|
||||
elif check_negative_env_flag('WITH_' + var):
|
||||
os.environ['USE_' + var] = '0'
|
||||
|
||||
list(map(hotpatch_var, config_env_vars))
|
||||
|
||||
from tools.setup_helpers.cuda import USE_CUDA, CUDA_HOME, CUDA_VERSION
|
||||
from tools.setup_helpers.rocm import USE_ROCM, ROCM_HOME, ROCM_VERSION
|
||||
from tools.setup_helpers.cudnn import (USE_CUDNN, CUDNN_LIBRARY,
|
||||
CUDNN_LIB_DIR, CUDNN_INCLUDE_DIR)
|
||||
from tools.setup_helpers.nccl import WITH_NCCL, WITH_SYSTEM_NCCL, NCCL_LIB_DIR, \
|
||||
from tools.setup_helpers.nccl import USE_NCCL, USE_SYSTEM_NCCL, NCCL_LIB_DIR, \
|
||||
NCCL_INCLUDE_DIR, NCCL_ROOT_DIR, NCCL_SYSTEM_LIB
|
||||
from tools.setup_helpers.mkldnn import (WITH_MKLDNN, MKLDNN_LIBRARY,
|
||||
from tools.setup_helpers.mkldnn import (USE_MKLDNN, MKLDNN_LIBRARY,
|
||||
MKLDNN_LIB_DIR, MKLDNN_INCLUDE_DIR)
|
||||
from tools.setup_helpers.nnpack import WITH_NNPACK
|
||||
from tools.setup_helpers.nnpack import USE_NNPACK
|
||||
from tools.setup_helpers.nvtoolext import NVTOOLEXT_HOME
|
||||
from tools.setup_helpers.generate_code import generate_code
|
||||
from tools.setup_helpers.ninja_builder import NinjaBuilder, ninja_build_ext
|
||||
from tools.setup_helpers.dist_check import WITH_DISTRIBUTED, \
|
||||
WITH_DISTRIBUTED_MW, WITH_GLOO_IBVERBS, WITH_C10D
|
||||
|
||||
from tools.setup_helpers.dist_check import USE_DISTRIBUTED, \
|
||||
USE_DISTRIBUTED_MW, USE_GLOO_IBVERBS, USE_C10D
|
||||
|
||||
################################################################################
|
||||
# Parameters parsed from environment
|
||||
|
|
@ -147,10 +164,10 @@ if not ONNX_NAMESPACE:
|
|||
# Ninja
|
||||
try:
|
||||
import ninja
|
||||
WITH_NINJA = True
|
||||
USE_NINJA = True
|
||||
ninja_global = NinjaBuilder('global')
|
||||
except ImportError:
|
||||
WITH_NINJA = False
|
||||
USE_NINJA = False
|
||||
ninja_global = None
|
||||
|
||||
# Constant known variables used throughout this file
|
||||
|
|
@ -178,7 +195,7 @@ class PytorchCommand(setuptools.Command):
|
|||
# Patches and workarounds
|
||||
################################################################################
|
||||
# Monkey-patch setuptools to compile in parallel
|
||||
if not WITH_NINJA:
|
||||
if not USE_NINJA:
|
||||
def parallelCCompile(self, sources, output_dir=None, macros=None,
|
||||
include_dirs=None, debug=0, extra_preargs=None,
|
||||
extra_postargs=None, depends=None):
|
||||
|
|
@ -287,34 +304,34 @@ def build_libs(libs):
|
|||
my_env["NUM_JOBS"] = str(NUM_JOBS)
|
||||
my_env["ONNX_NAMESPACE"] = ONNX_NAMESPACE
|
||||
if not IS_WINDOWS:
|
||||
if WITH_NINJA:
|
||||
if USE_NINJA:
|
||||
my_env["CMAKE_GENERATOR"] = '-GNinja'
|
||||
my_env["CMAKE_INSTALL"] = 'ninja install'
|
||||
else:
|
||||
my_env['CMAKE_GENERATOR'] = ''
|
||||
my_env['CMAKE_INSTALL'] = 'make install'
|
||||
if WITH_SYSTEM_NCCL:
|
||||
if USE_SYSTEM_NCCL:
|
||||
my_env["NCCL_ROOT_DIR"] = NCCL_ROOT_DIR
|
||||
if WITH_CUDA:
|
||||
if USE_CUDA:
|
||||
my_env["CUDA_BIN_PATH"] = CUDA_HOME
|
||||
build_libs_cmd += ['--with-cuda']
|
||||
if WITH_ROCM:
|
||||
build_libs_cmd += ['--with-rocm']
|
||||
if WITH_NNPACK:
|
||||
build_libs_cmd += ['--with-nnpack']
|
||||
if WITH_CUDNN:
|
||||
build_libs_cmd += ['--use-cuda']
|
||||
if USE_ROCM:
|
||||
build_libs_cmd += ['--use-rocm']
|
||||
if USE_NNPACK:
|
||||
build_libs_cmd += ['--use-nnpack']
|
||||
if USE_CUDNN:
|
||||
my_env["CUDNN_LIB_DIR"] = CUDNN_LIB_DIR
|
||||
my_env["CUDNN_LIBRARY"] = CUDNN_LIBRARY
|
||||
my_env["CUDNN_INCLUDE_DIR"] = CUDNN_INCLUDE_DIR
|
||||
if WITH_MKLDNN:
|
||||
if USE_MKLDNN:
|
||||
my_env["MKLDNN_LIB_DIR"] = MKLDNN_LIB_DIR
|
||||
my_env["MKLDNN_LIBRARY"] = MKLDNN_LIBRARY
|
||||
my_env["MKLDNN_INCLUDE_DIR"] = MKLDNN_INCLUDE_DIR
|
||||
build_libs_cmd += ['--with-mkldnn']
|
||||
if WITH_GLOO_IBVERBS:
|
||||
build_libs_cmd += ['--with-gloo-ibverbs']
|
||||
if WITH_DISTRIBUTED_MW:
|
||||
build_libs_cmd += ['--with-distributed-mw']
|
||||
build_libs_cmd += ['--use-mkldnn']
|
||||
if USE_GLOO_IBVERBS:
|
||||
build_libs_cmd += ['--use-gloo-ibverbs']
|
||||
if USE_DISTRIBUTED_MW:
|
||||
build_libs_cmd += ['--use-distributed-mw']
|
||||
|
||||
if FULL_CAFFE2:
|
||||
build_libs_cmd += ['--full-caffe2']
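The flags that setup.py forwards to the library build script were renamed in lockstep with the environment variables. The function below is a condensed, illustrative sketch of the mapping visible in this hunk (the real code also exports CUDNN_*/MKLDNN_* paths into the build environment); it is not the commit's code.

# Illustrative only: how USE_* settings now become --use-* arguments
# (each flag previously used the --with-* spelling).
def make_build_libs_cmd(use_cuda, use_rocm, use_nnpack, use_mkldnn,
                        use_gloo_ibverbs, use_distributed_mw):
    cmd = []
    if use_cuda:
        cmd += ['--use-cuda']            # was --with-cuda
    if use_rocm:
        cmd += ['--use-rocm']            # was --with-rocm
    if use_nnpack:
        cmd += ['--use-nnpack']          # was --with-nnpack
    if use_mkldnn:
        cmd += ['--use-mkldnn']          # was --with-mkldnn
    if use_gloo_ibverbs:
        cmd += ['--use-gloo-ibverbs']    # was --with-gloo-ibverbs
    if use_distributed_mw:
        cmd += ['--use-distributed-mw']  # was --with-distributed-mw
    return cmd

print(make_build_libs_cmd(True, False, True, False, False, False))
# ['--use-cuda', '--use-nnpack']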
|
||||
|
|
@ -345,18 +362,18 @@ class build_deps(PytorchCommand):
|
|||
check_pydep('typing', 'typing')
|
||||
|
||||
libs = []
|
||||
if WITH_NCCL and not WITH_SYSTEM_NCCL:
|
||||
if USE_NCCL and not USE_SYSTEM_NCCL:
|
||||
libs += ['nccl']
|
||||
libs += ['caffe2', 'nanopb']
|
||||
if IS_WINDOWS:
|
||||
libs += ['libshm_windows']
|
||||
else:
|
||||
libs += ['libshm']
|
||||
if WITH_DISTRIBUTED:
|
||||
if USE_DISTRIBUTED:
|
||||
if sys.platform.startswith('linux'):
|
||||
libs += ['gloo']
|
||||
libs += ['THD']
|
||||
if WITH_C10D:
|
||||
if USE_C10D:
|
||||
libs += ['c10d']
|
||||
build_libs(libs)
|
||||
|
||||
|
|
@ -427,7 +444,7 @@ class develop(setuptools.command.develop.develop):
|
|||
for entry in load(f)]
|
||||
with open('compile_commands.json', 'w') as f:
|
||||
json.dump(all_commands, f, indent=2)
|
||||
if not WITH_NINJA:
|
||||
if not USE_NINJA:
|
||||
print("WARNING: 'develop' is not building C++ code incrementally")
|
||||
print("because ninja is not installed. Run this to enable it:")
|
||||
print(" > pip install ninja")
|
||||
|
|
@ -450,7 +467,7 @@ def monkey_patch_THD_link_flags():
|
|||
C.extra_link_args += thd_deps
|
||||
|
||||
|
||||
build_ext_parent = ninja_build_ext if WITH_NINJA \
|
||||
build_ext_parent = ninja_build_ext if USE_NINJA \
|
||||
else setuptools.command.build_ext.build_ext
|
||||
|
||||
|
||||
|
|
@ -459,30 +476,30 @@ class build_ext(build_ext_parent):
|
|||
def run(self):
|
||||
|
||||
# Print build options
|
||||
if WITH_NUMPY:
|
||||
if USE_NUMPY:
|
||||
print('-- Building with NumPy bindings')
|
||||
else:
|
||||
print('-- NumPy not found')
|
||||
if WITH_CUDNN:
|
||||
if USE_CUDNN:
|
||||
print('-- Detected cuDNN at ' + CUDNN_LIBRARY + ', ' + CUDNN_INCLUDE_DIR)
|
||||
else:
|
||||
print('-- Not using cuDNN')
|
||||
if WITH_CUDA:
|
||||
if USE_CUDA:
|
||||
print('-- Detected CUDA at ' + CUDA_HOME)
|
||||
else:
|
||||
print('-- Not using CUDA')
|
||||
if WITH_MKLDNN:
|
||||
if USE_MKLDNN:
|
||||
print('-- Detected MKLDNN at ' + MKLDNN_LIBRARY + ', ' + MKLDNN_INCLUDE_DIR)
|
||||
else:
|
||||
print('-- Not using MKLDNN')
|
||||
if WITH_NCCL and WITH_SYSTEM_NCCL:
|
||||
if USE_NCCL and USE_SYSTEM_NCCL:
|
||||
print('-- Using system provided NCCL library at ' +
|
||||
NCCL_SYSTEM_LIB + ', ' + NCCL_INCLUDE_DIR)
|
||||
elif WITH_NCCL:
|
||||
elif USE_NCCL:
|
||||
print('-- Building NCCL library')
|
||||
else:
|
||||
print('-- Not using NCCL')
|
||||
if WITH_DISTRIBUTED:
|
||||
if USE_DISTRIBUTED:
|
||||
print('-- Building with distributed package ')
|
||||
monkey_patch_THD_link_flags()
|
||||
else:
|
||||
|
|
@ -490,7 +507,7 @@ class build_ext(build_ext_parent):
|
|||
|
||||
generate_code(ninja_global)
|
||||
|
||||
if WITH_NINJA:
|
||||
if USE_NINJA:
|
||||
# before we start the normal build make sure all generated code
|
||||
# gets built
|
||||
ninja_global.run()
|
||||
|
|
@ -648,9 +665,9 @@ library_dirs.append(lib_path)
|
|||
|
||||
# we specify exact lib names to avoid conflict with lua-torch installs
|
||||
CAFFE2_LIBS = [os.path.join(lib_path, 'libcaffe2.so')]
|
||||
if WITH_CUDA:
|
||||
if USE_CUDA:
|
||||
CAFFE2_LIBS.extend(['-Wl,--no-as-needed', os.path.join(lib_path, 'libcaffe2_gpu.so'), '-Wl,--as-needed'])
|
||||
if WITH_ROCM:
|
||||
if USE_ROCM:
|
||||
CAFFE2_LIBS.extend(['-Wl,--no-as-needed', os.path.join(lib_path, 'libcaffe2_hip.so'), '-Wl,--as-needed'])
|
||||
THD_LIB = os.path.join(lib_path, 'libTHD.a')
|
||||
NCCL_LIB = os.path.join(lib_path, 'libnccl.so.1')
|
||||
|
|
@ -666,17 +683,17 @@ else:
|
|||
|
||||
if IS_DARWIN:
|
||||
CAFFE2_LIBS = [os.path.join(lib_path, 'libcaffe2.dylib')]
|
||||
if WITH_CUDA:
|
||||
if USE_CUDA:
|
||||
CAFFE2_LIBS.append(os.path.join(lib_path, 'libcaffe2_gpu.dylib'))
|
||||
if WITH_ROCM:
|
||||
if USE_ROCM:
|
||||
CAFFE2_LIBS.append(os.path.join(lib_path, 'libcaffe2_hip.dylib'))
|
||||
NCCL_LIB = os.path.join(lib_path, 'libnccl.1.dylib')
|
||||
|
||||
if IS_WINDOWS:
|
||||
CAFFE2_LIBS = [os.path.join(lib_path, 'caffe2.lib')]
|
||||
if WITH_CUDA:
|
||||
if USE_CUDA:
|
||||
CAFFE2_LIBS.append(os.path.join(lib_path, 'caffe2_gpu.lib'))
|
||||
if WITH_ROCM:
|
||||
if USE_ROCM:
|
||||
CAFFE2_LIBS.append(os.path.join(lib_path, 'caffe2_hip.lib'))
|
||||
# Windows needs direct access to ONNX libraries as well
|
||||
# as through Caffe2 library
|
||||
|
|
@ -809,31 +826,31 @@ main_sources = [
|
|||
try:
|
||||
import numpy as np
|
||||
include_dirs.append(np.get_include())
|
||||
extra_compile_args.append('-DWITH_NUMPY')
|
||||
WITH_NUMPY = True
|
||||
extra_compile_args.append('-DUSE_NUMPY')
|
||||
USE_NUMPY = True
|
||||
except ImportError:
|
||||
WITH_NUMPY = False
|
||||
USE_NUMPY = False
|
||||
|
||||
if WITH_DISTRIBUTED:
|
||||
extra_compile_args += ['-DWITH_DISTRIBUTED']
|
||||
if USE_DISTRIBUTED:
|
||||
extra_compile_args += ['-DUSE_DISTRIBUTED']
|
||||
main_sources += [
|
||||
"torch/csrc/distributed/Module.cpp",
|
||||
]
|
||||
if WITH_DISTRIBUTED_MW:
|
||||
if USE_DISTRIBUTED_MW:
|
||||
main_sources += [
|
||||
"torch/csrc/distributed/Tensor.cpp",
|
||||
"torch/csrc/distributed/Storage.cpp",
|
||||
]
|
||||
extra_compile_args += ['-DWITH_DISTRIBUTED_MW']
|
||||
extra_compile_args += ['-DUSE_DISTRIBUTED_MW']
|
||||
include_dirs += [tmp_install_path + "/include/THD"]
|
||||
main_link_args += [THD_LIB]
|
||||
|
||||
if WITH_C10D:
|
||||
extra_compile_args += ['-DWITH_C10D']
|
||||
if USE_C10D:
|
||||
extra_compile_args += ['-DUSE_C10D']
|
||||
main_sources += ['torch/csrc/distributed/c10d/init.cpp']
|
||||
main_link_args += [C10D_GLOO_LIB, C10D_LIB]
|
||||
|
||||
if WITH_CUDA:
|
||||
if USE_CUDA:
|
||||
nvtoolext_lib_name = None
|
||||
if IS_WINDOWS:
|
||||
cuda_lib_path = CUDA_HOME + '/lib/x64/'
|
||||
|
|
@ -862,7 +879,7 @@ if WITH_CUDA:
|
|||
cuda_include_path = os.path.join(CUDA_HOME, 'include')
|
||||
include_dirs.append(cuda_include_path)
|
||||
include_dirs.append(tmp_install_path + "/include/THCUNN")
|
||||
extra_compile_args += ['-DWITH_CUDA']
|
||||
extra_compile_args += ['-DUSE_CUDA']
|
||||
extra_compile_args += ['-DCUDA_LIB_PATH=' + cuda_lib_path]
|
||||
main_libraries += ['cudart', nvtoolext_lib_name]
|
||||
main_sources += [
|
||||
|
|
@ -876,7 +893,7 @@ if WITH_CUDA:
|
|||
"torch/csrc/nn/THCUNN.cpp",
|
||||
]
|
||||
|
||||
if WITH_ROCM:
|
||||
if USE_ROCM:
|
||||
rocm_include_path = '/opt/rocm/include'
|
||||
hcc_include_path = '/opt/rocm/hcc/include'
|
||||
hipblas_include_path = '/opt/rocm/hipblas/include'
|
||||
|
|
@ -890,7 +907,7 @@ if WITH_ROCM:
|
|||
include_dirs.append(tmp_install_path + "/include/THCUNN")
|
||||
extra_link_args.append('-L' + hip_lib_path)
|
||||
extra_link_args.append('-Wl,-rpath,' + hip_lib_path)
|
||||
extra_compile_args += ['-DWITH_ROCM']
|
||||
extra_compile_args += ['-DUSE_ROCM']
|
||||
extra_compile_args += ['-D__HIP_PLATFORM_HCC__']
|
||||
|
||||
main_sources += [
|
||||
|
|
@ -904,24 +921,24 @@ if WITH_ROCM:
|
|||
"torch/csrc/nn/THCUNN.cpp",
|
||||
]
|
||||
|
||||
if WITH_NCCL:
|
||||
if WITH_SYSTEM_NCCL:
|
||||
if USE_NCCL:
|
||||
if USE_SYSTEM_NCCL:
|
||||
main_link_args += [NCCL_SYSTEM_LIB]
|
||||
include_dirs.append(NCCL_INCLUDE_DIR)
|
||||
else:
|
||||
main_link_args += [NCCL_LIB]
|
||||
extra_compile_args += ['-DWITH_NCCL']
|
||||
extra_compile_args += ['-DUSE_NCCL']
|
||||
main_sources += [
|
||||
"torch/csrc/cuda/nccl.cpp",
|
||||
"torch/csrc/cuda/python_nccl.cpp",
|
||||
]
|
||||
if WITH_CUDNN:
|
||||
if USE_CUDNN:
|
||||
main_libraries += [CUDNN_LIBRARY]
|
||||
# NOTE: these are at the front, in case there's another cuDNN in CUDA path
|
||||
include_dirs.insert(0, CUDNN_INCLUDE_DIR)
|
||||
if not IS_WINDOWS:
|
||||
extra_link_args.insert(0, '-Wl,-rpath,' + CUDNN_LIB_DIR)
|
||||
extra_compile_args += ['-DWITH_CUDNN']
|
||||
extra_compile_args += ['-DUSE_CUDNN']
|
||||
|
||||
if DEBUG:
|
||||
if IS_WINDOWS:
|
||||
|
|
@ -964,7 +981,7 @@ if not IS_WINDOWS:
|
|||
extensions.append(DL)
|
||||
|
||||
|
||||
if WITH_CUDA:
|
||||
if USE_CUDA:
|
||||
thnvrtc_link_flags = extra_link_args + [make_relative_rpath('lib')]
|
||||
if IS_LINUX:
|
||||
thnvrtc_link_flags = thnvrtc_link_flags + ['-Wl,--no-as-needed']
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ for root, _directories, files in os.walk(os.path.join(proj_dir, "torch")):
|
|||
# Update contents.
|
||||
with open(source, "r+") as f:
|
||||
contents = f.read()
|
||||
contents = contents.replace("WITH_CUDA", "WITH_ROCM")
|
||||
contents = contents.replace("USE_CUDA", "USE_ROCM")
|
||||
contents = contents.replace("CUDA_VERSION", "0")
|
||||
f.seek(0)
|
||||
f.write(contents)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
#include "torch/csrc/autograd/utils/wrap_outputs.h"
|
||||
#include "torch/csrc/autograd/utils/python_arg_parsing.h"
|
||||
#include "torch/csrc/jit/tracer.h"
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include "torch/csrc/cuda/Stream.h"
|
||||
#endif
|
||||
#include "torch/csrc/utils/cuda_lazy_init.h"
|
||||
|
|
@ -399,7 +399,7 @@ static PyObject * THPVariable_numpy(PyObject* self, PyObject* arg)
|
|||
static PyObject * THPVariable_record_stream(PyObject* self, PyObject* arg)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
|
||||
if (!THCPStream_Check(arg)) {
|
||||
return PyErr_Format(PyExc_TypeError, "expected Stream object");
|
||||
|
|
|
|||
|
|
@ -17,21 +17,21 @@ set LINK_FLAGS=/DEBUG:FULL
|
|||
|
||||
mkdir torch/lib/tmp_install
|
||||
|
||||
IF "%~1"=="--with-cuda" (
|
||||
IF "%~1"=="--use-cuda" (
|
||||
set /a USE_CUDA=1
|
||||
shift
|
||||
) ELSE (
|
||||
set /a USE_CUDA=0
|
||||
)
|
||||
|
||||
IF "%~1"=="--with-rocm" (
|
||||
set /a WITH_ROCM=1
|
||||
IF "%~1"=="--use-rocm" (
|
||||
set /a USE_ROCM=1
|
||||
shift
|
||||
) ELSE (
|
||||
set /a WITH_ROCM=0
|
||||
set /a USE_ROCM=0
|
||||
)
|
||||
|
||||
IF "%~1"=="--with-nnpack" (
|
||||
IF "%~1"=="--use-nnpack" (
|
||||
set /a NO_NNPACK=0
|
||||
set /a USE_NNPACK=1
|
||||
shift
|
||||
|
|
@ -40,27 +40,27 @@ IF "%~1"=="--with-nnpack" (
|
|||
set /a USE_NNPACK=0
|
||||
)
|
||||
|
||||
IF "%~1"=="--with-mkldnn" (
|
||||
IF "%~1"=="--use-mkldnn" (
|
||||
set /a NO_MKLDNN=0
|
||||
shift
|
||||
) ELSE (
|
||||
set /a NO_MKLDNN=1
|
||||
)
|
||||
|
||||
IF "%~1"=="--with-gloo-ibverbs" (
|
||||
set /a WITH_GLOO_IBVERBS=1
|
||||
IF "%~1"=="--use-gloo-ibverbs" (
|
||||
set /a USE_GLOO_IBVERBS=1
|
||||
echo Warning: gloo iverbs is enabled but build is not yet implemented 1>&2
|
||||
shift
|
||||
) ELSE (
|
||||
set /a WITH_GLOO_IBVERBS=0
|
||||
set /a USE_GLOO_IBVERBS=0
|
||||
)
|
||||
|
||||
IF "%~1"=="--with-distributed-mw" (
|
||||
set /a WITH_DISTRIBUTED_MW=1
|
||||
IF "%~1"=="--use-distributed-mw" (
|
||||
set /a USE_DISTRIBUTED_MW=1
|
||||
echo Warning: distributed mw is enabled but build is not yet implemented 1>&2
|
||||
shift
|
||||
) ELSE (
|
||||
set /a WITH_DISTRIBUTED_MW=0
|
||||
set /a USE_DISTRIBUTED_MW=0
|
||||
)
|
||||
|
||||
set BUILD_TYPE=Release
|
||||
|
|
@ -195,7 +195,7 @@ goto:eof
|
|||
-DCMAKE_CXX_FLAGS="/EHa %USER_CFLAGS%" ^
|
||||
-DCMAKE_EXE_LINKER_FLAGS="%USER_LDFLAGS%" ^
|
||||
-DCMAKE_SHARED_LINKER_FLAGS="%USER_LDFLAGS%" ^
|
||||
-DWITH_ROCM=%WITH_ROCM%
|
||||
-DUSE_ROCM=%USE_ROCM%
|
||||
|
||||
%MAKE_COMMAND%
|
||||
IF ERRORLEVEL 1 exit 1
|
||||
|
|
|
|||
|
|
@ -11,32 +11,32 @@
|
|||
set -ex
|
||||
|
||||
# Options for building only a subset of the libraries
|
||||
WITH_CUDA=0
|
||||
WITH_ROCM=0
|
||||
WITH_NNPACK=0
|
||||
WITH_MKLDNN=0
|
||||
WITH_GLOO_IBVERBS=0
|
||||
WITH_DISTRIBUTED_MW=0
|
||||
USE_CUDA=0
|
||||
USE_ROCM=0
|
||||
USE_NNPACK=0
|
||||
USE_MKLDNN=0
|
||||
USE_GLOO_IBVERBS=0
|
||||
USE_DISTRIBUTED_MW=0
|
||||
FULL_CAFFE2=0
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--with-cuda)
|
||||
WITH_CUDA=1
|
||||
--use-cuda)
|
||||
USE_CUDA=1
|
||||
;;
|
||||
--with-rocm)
|
||||
WITH_ROCM=1
|
||||
--use-rocm)
|
||||
USE_ROCM=1
|
||||
;;
|
||||
--with-nnpack)
|
||||
WITH_NNPACK=1
|
||||
--use-nnpack)
|
||||
USE_NNPACK=1
|
||||
;;
|
||||
--with-mkldnn)
|
||||
WITH_MKLDNN=1
|
||||
--use-mkldnn)
|
||||
USE_MKLDNN=1
|
||||
;;
|
||||
--with-gloo-ibverbs)
|
||||
WITH_GLOO_IBVERBS=1
|
||||
--use-gloo-ibverbs)
|
||||
USE_GLOO_IBVERBS=1
|
||||
;;
|
||||
--with-distributed-mw)
|
||||
WITH_DISTRIBUTED_MW=1
|
||||
--use-distributed-mw)
|
||||
USE_DISTRIBUTED_MW=1
|
||||
;;
|
||||
--full-caffe2)
|
||||
FULL_CAFFE2=1
|
||||
|
|
@ -96,16 +96,16 @@ CPP_FLAGS=" -std=c++11 "
|
|||
GLOO_FLAGS=""
|
||||
THD_FLAGS=""
|
||||
NCCL_ROOT_DIR=${NCCL_ROOT_DIR:-$INSTALL_DIR}
|
||||
if [[ $WITH_CUDA -eq 1 ]]; then
|
||||
if [[ $USE_CUDA -eq 1 ]]; then
|
||||
GLOO_FLAGS="-DUSE_CUDA=1 -DNCCL_ROOT_DIR=$NCCL_ROOT_DIR"
|
||||
fi
|
||||
# Gloo infiniband support
|
||||
if [[ $WITH_GLOO_IBVERBS -eq 1 ]]; then
|
||||
if [[ $USE_GLOO_IBVERBS -eq 1 ]]; then
|
||||
GLOO_FLAGS+=" -DUSE_IBVERBS=1 -DBUILD_SHARED_LIBS=1"
|
||||
THD_FLAGS="-DWITH_GLOO_IBVERBS=1"
|
||||
THD_FLAGS="-DUSE_GLOO_IBVERBS=1"
|
||||
fi
|
||||
if [[ $WITH_DISTRIBUTED_MW -eq 1 ]]; then
|
||||
THD_FLAGS+="-DWITH_DISTRIBUTED_MW=1"
|
||||
if [[ $USE_DISTRIBUTED_MW -eq 1 ]]; then
|
||||
THD_FLAGS+="-DUSE_DISTRIBUTED_MW=1"
|
||||
fi
|
||||
CWRAP_FILES="\
|
||||
$BASE_DIR/torch/lib/ATen/Declarations.cwrap;\
|
||||
|
|
@ -171,8 +171,8 @@ function build() {
|
|||
-DTHNN_SO_VERSION=1 \
|
||||
-DTHCUNN_SO_VERSION=1 \
|
||||
-DTHD_SO_VERSION=1 \
|
||||
-DUSE_CUDA=$WITH_CUDA \
|
||||
-DNO_NNPACK=$((1-$WITH_NNPACK)) \
|
||||
-DUSE_CUDA=$USE_CUDA \
|
||||
-DNO_NNPACK=$((1-$USE_NNPACK)) \
|
||||
-DNCCL_EXTERNAL=1 \
|
||||
-Dnanopb_BUILD_GENERATOR=0 \
|
||||
-DCMAKE_DEBUG_POSTFIX="" \
|
||||
|
|
@ -233,13 +233,13 @@ function build_caffe2() {
|
|||
-DBUILD_BINARY=OFF \
|
||||
-DBUILD_SHARED_LIBS=ON \
|
||||
-DONNX_NAMESPACE=$ONNX_NAMESPACE \
|
||||
-DUSE_CUDA=$WITH_CUDA \
|
||||
-DUSE_ROCM=$WITH_ROCM \
|
||||
-DUSE_NNPACK=$WITH_NNPACK \
|
||||
-DUSE_CUDA=$USE_CUDA \
|
||||
-DUSE_ROCM=$USE_ROCM \
|
||||
-DUSE_NNPACK=$USE_NNPACK \
|
||||
-DCUDNN_INCLUDE_DIR=$CUDNN_INCLUDE_DIR \
|
||||
-DCUDNN_LIB_DIR=$CUDNN_LIB_DIR \
|
||||
-DCUDNN_LIBRARY=$CUDNN_LIBRARY \
|
||||
-DUSE_MKLDNN=$WITH_MKLDNN \
|
||||
-DUSE_MKLDNN=$USE_MKLDNN \
|
||||
-DMKLDNN_INCLUDE_DIR=$MKLDNN_INCLUDE_DIR \
|
||||
-DMKLDNN_LIB_DIR=$MKLDNN_LIB_DIR \
|
||||
-DMKLDNN_LIBRARY=$MKLDNN_LIBRARY \
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import re
|
|||
import ctypes.util
|
||||
from subprocess import Popen, PIPE
|
||||
|
||||
from .env import IS_WINDOWS, IS_LINUX, IS_DARWIN, check_env_flag
|
||||
from .env import IS_WINDOWS, IS_LINUX, IS_DARWIN, check_env_flag, check_negative_env_flag
|
||||
|
||||
LINUX_HOME = '/usr/local/cuda'
|
||||
WINDOWS_HOME = glob.glob('C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
|
||||
|
|
@ -58,8 +58,8 @@ def find_cuda_version(cuda_home):
|
|||
if len(candidates) > 0:
|
||||
return candidates[0]
|
||||
|
||||
if check_env_flag('NO_CUDA') or check_env_flag('WITH_ROCM'):
|
||||
WITH_CUDA = False
|
||||
if check_negative_env_flag('USE_CUDA') or check_env_flag('USE_ROCM'):
|
||||
USE_CUDA = False
|
||||
CUDA_HOME = None
|
||||
CUDA_VERSION = None
|
||||
else:
|
||||
|
|
@ -84,4 +84,4 @@ else:
|
|||
else:
|
||||
CUDA_HOME = None
|
||||
CUDA_VERSION = find_cuda_version(CUDA_HOME)
|
||||
WITH_CUDA = CUDA_HOME is not None
|
||||
USE_CUDA = CUDA_HOME is not None
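In practical terms for the cuda helper above: CUDA is now opted out of with USE_CUDA=0 (or implied off by USE_ROCM=1) rather than NO_CUDA=1, and leaving USE_CUDA unset still means auto-detection from CUDA_HOME. A small hedged example of the new check, assuming it is run from a PyTorch source tree so the tools package is importable:

import os
from tools.setup_helpers.env import check_env_flag, check_negative_env_flag

os.environ['USE_CUDA'] = '0'   # new spelling; the old NO_CUDA=1 is translated by setup.py
cuda_explicitly_disabled = (check_negative_env_flag('USE_CUDA')
                            or check_env_flag('USE_ROCM'))
print(cuda_explicitly_disabled)  # True -> CUDA_HOME and CUDA_VERSION stay None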
|
||||
|
|
|
|||
|
|
@ -1,17 +1,17 @@
|
|||
import os
|
||||
import glob
|
||||
|
||||
from .env import IS_WINDOWS, IS_CONDA, CONDA_DIR, check_env_flag, gather_paths
|
||||
from .cuda import WITH_CUDA, CUDA_HOME
|
||||
from .env import IS_WINDOWS, IS_CONDA, CONDA_DIR, check_negative_env_flag, gather_paths
|
||||
from .cuda import USE_CUDA, CUDA_HOME
|
||||
|
||||
|
||||
WITH_CUDNN = False
|
||||
USE_CUDNN = False
|
||||
CUDNN_LIB_DIR = None
|
||||
CUDNN_INCLUDE_DIR = None
|
||||
CUDNN_LIBRARY = None
|
||||
WITH_STATIC_CUDNN = os.getenv("USE_STATIC_CUDNN")
|
||||
|
||||
if WITH_CUDA and not check_env_flag('NO_CUDNN'):
|
||||
if USE_CUDA and not check_negative_env_flag('USE_CUDNN'):
|
||||
lib_paths = list(filter(bool, [
|
||||
os.getenv('CUDNN_LIB_DIR'),
|
||||
os.path.join(CUDA_HOME, 'lib/x64'),
|
||||
|
|
@ -87,4 +87,4 @@ if WITH_CUDA and not check_env_flag('NO_CUDNN'):
|
|||
real_cudnn_lib_dir = os.path.realpath(CUDNN_LIB_DIR)
|
||||
assert os.path.dirname(real_cudnn_library) == real_cudnn_lib_dir, (
|
||||
'cudnn library and lib_dir must agree')
|
||||
WITH_CUDNN = True
|
||||
USE_CUDNN = True
|
||||
|
|
|
|||
|
|
@ -2,14 +2,14 @@ import os
|
|||
import subprocess
|
||||
import glob
|
||||
|
||||
from .env import IS_CONDA, IS_LINUX, IS_WINDOWS, CONDA_DIR, check_env_flag, gather_paths
|
||||
from .cuda import WITH_CUDA
|
||||
from .env import IS_CONDA, IS_LINUX, IS_WINDOWS, CONDA_DIR, check_env_flag, check_negative_env_flag, gather_paths
|
||||
from .cuda import USE_CUDA
|
||||
|
||||
# On ROCm, RCCL development isn't complete. https://github.com/ROCmSoftwarePlatform/rccl
|
||||
WITH_DISTRIBUTED = not check_env_flag("NO_DISTRIBUTED") and not IS_WINDOWS and not check_env_flag("WITH_ROCM")
|
||||
WITH_DISTRIBUTED_MW = WITH_DISTRIBUTED and check_env_flag("WITH_DISTRIBUTED_MW")
|
||||
WITH_GLOO_IBVERBS = False
|
||||
WITH_C10D = WITH_DISTRIBUTED and WITH_CUDA and IS_LINUX
|
||||
USE_DISTRIBUTED = not check_negative_env_flag("USE_DISTRIBUTED") and not IS_WINDOWS and not check_env_flag("USE_ROCM")
|
||||
USE_DISTRIBUTED_MW = USE_DISTRIBUTED and check_env_flag("USE_DISTRIBUTED_MW")
|
||||
USE_GLOO_IBVERBS = False
|
||||
USE_C10D = USE_DISTRIBUTED and USE_CUDA and IS_LINUX
|
||||
|
||||
IB_DEVINFO_CMD = "ibv_devinfo"
|
||||
|
||||
|
|
@ -104,10 +104,10 @@ def should_build_ib():
|
|||
|
||||
return ib_util_found and ib_lib_found and ib_lib_found
|
||||
|
||||
if WITH_DISTRIBUTED:
|
||||
if USE_DISTRIBUTED:
|
||||
# If the env variable is specified, use the value,
|
||||
# otherwise only build with IB when IB support is detected on the system
|
||||
if "WITH_GLOO_IBVERBS" in os.environ:
|
||||
WITH_GLOO_IBVERBS = check_env_flag("WITH_GLOO_IBVERBS")
|
||||
if "USE_GLOO_IBVERBS" in os.environ:
|
||||
USE_GLOO_IBVERBS = check_env_flag("USE_GLOO_IBVERBS")
|
||||
else:
|
||||
WITH_GLOO_IBVERBS = should_build_ib()
|
||||
USE_GLOO_IBVERBS = should_build_ib()
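One subtlety worth noting in the block above: the mere presence of USE_GLOO_IBVERBS in the environment decides whether auto-detection runs at all, so setting it to 0 explicitly disables InfiniBand support even on a machine where should_build_ib() would return True. A minimal illustration of that precedence (illustrative, not part of the diff):

import os

def resolve_gloo_ibverbs(env, auto_detected):
    # Mirrors the precedence above: an explicit USE_GLOO_IBVERBS wins over detection.
    if 'USE_GLOO_IBVERBS' in env:
        return env['USE_GLOO_IBVERBS'].upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
    return auto_detected

print(resolve_gloo_ibverbs({'USE_GLOO_IBVERBS': '0'}, auto_detected=True))  # False
print(resolve_gloo_ibverbs({}, auto_detected=True))                         # True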
|
||||
|
|
|
|||
|
|
@ -17,5 +17,9 @@ def check_env_flag(name, default=''):
|
|||
return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
|
||||
|
||||
|
||||
def check_negative_env_flag(name, default=''):
|
||||
return os.getenv(name, default).upper() in ['OFF', '0', 'NO', 'FALSE', 'N']
|
||||
|
||||
|
||||
def gather_paths(env_vars):
|
||||
return list(chain(*(os.getenv(v, '').split(':') for v in env_vars)))
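The reason env.py grows a second helper instead of reusing `not check_env_flag(...)`: an unset variable should count as neither affirmatively on nor affirmatively off, which is what lets the new USE_* flags default to auto-detection. A short demonstration, assuming it is run from a PyTorch checkout so the helpers above are importable:

import os
from tools.setup_helpers.env import check_env_flag, check_negative_env_flag

for value in ('1', '0', None):           # explicitly on, explicitly off, unset
    os.environ.pop('USE_CUDA', None)
    if value is not None:
        os.environ['USE_CUDA'] = value
    print(value, check_env_flag('USE_CUDA'), check_negative_env_flag('USE_CUDA'))
# 1    True  False   -> feature forced on
# 0    False True    -> feature forced off
# None False False   -> unset, so the helper modules fall back to auto-detection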
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ def gather_paths(env_vars):
|
|||
|
||||
MKLDNN_HOME = os.getenv('MKLDNN_HOME', '/usr/local/mkl-dnn')
|
||||
|
||||
WITH_MKLDNN = False
|
||||
USE_MKLDNN = False
|
||||
MKLDNN_LIB_DIR = None
|
||||
MKLDNN_INCLUDE_DIR = None
|
||||
MKLDNN_LIBRARY = None
|
||||
|
|
@ -84,4 +84,4 @@ if (IS_LINUX or IS_WINDOWS) and not check_env_flag('NO_MKLDNN'):
|
|||
real_mkldnn_lib_dir = os.path.realpath(MKLDNN_LIB_DIR)
|
||||
assert os.path.dirname(real_mkldnn_library) == real_mkldnn_lib_dir, (
|
||||
'mkldnn library and lib_dir must agree')
|
||||
WITH_MKLDNN = True
|
||||
USE_MKLDNN = True
|
||||
|
|
|
|||
|
|
@ -3,24 +3,24 @@ import glob
|
|||
import warnings
|
||||
from itertools import chain
|
||||
|
||||
from .env import IS_WINDOWS, IS_DARWIN, IS_CONDA, CONDA_DIR, check_env_flag, \
|
||||
from .env import IS_WINDOWS, IS_DARWIN, IS_CONDA, CONDA_DIR, check_negative_env_flag, \
|
||||
gather_paths
|
||||
|
||||
from .cuda import WITH_CUDA, CUDA_HOME
|
||||
from .cuda import USE_CUDA, CUDA_HOME
|
||||
|
||||
|
||||
WITH_NCCL = WITH_CUDA and not IS_DARWIN and not IS_WINDOWS
|
||||
WITH_SYSTEM_NCCL = False
|
||||
USE_NCCL = USE_CUDA and not IS_DARWIN and not IS_WINDOWS
|
||||
USE_SYSTEM_NCCL = False
|
||||
NCCL_LIB_DIR = None
|
||||
NCCL_SYSTEM_LIB = None
|
||||
NCCL_INCLUDE_DIR = None
|
||||
NCCL_ROOT_DIR = None
|
||||
WITH_STATIC_NCCL = os.getenv("USE_STATIC_NCCL")
|
||||
USE_STATIC_NCCL = os.getenv("USE_STATIC_NCCL")
|
||||
LIBNCCL_PREFIX = "libnccl"
|
||||
if WITH_STATIC_NCCL is not None:
|
||||
if USE_STATIC_NCCL is not None:
|
||||
LIBNCCL_PREFIX = "libnccl_static"
|
||||
|
||||
if WITH_CUDA and not check_env_flag('NO_SYSTEM_NCCL'):
|
||||
if USE_CUDA and not check_negative_env_flag('USE_SYSTEM_NCCL'):
|
||||
ENV_ROOT = os.getenv('NCCL_ROOT_DIR', None)
|
||||
LIB_DIR = os.getenv('NCCL_LIB_DIR', None)
|
||||
INCLUDE_DIR = os.getenv('NCCL_INCLUDE_DIR', None)
|
||||
|
|
@ -71,5 +71,5 @@ if WITH_CUDA and not check_env_flag('NO_SYSTEM_NCCL'):
|
|||
NCCL_INCLUDE_DIR = path
|
||||
break
|
||||
if NCCL_LIB_DIR is not None and NCCL_INCLUDE_DIR is not None:
|
||||
WITH_SYSTEM_NCCL = True
|
||||
USE_SYSTEM_NCCL = True
|
||||
NCCL_ROOT_DIR = os.path.commonprefix((NCCL_LIB_DIR, NCCL_INCLUDE_DIR))
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from .env import check_env_flag
|
||||
|
||||
if check_env_flag('NO_NNPACK'):
|
||||
WITH_NNPACK = False
|
||||
USE_NNPACK = False
|
||||
else:
|
||||
WITH_NNPACK = True
|
||||
USE_NNPACK = True
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@ import platform
|
|||
import ctypes.util
|
||||
from subprocess import Popen, PIPE
|
||||
|
||||
from .cuda import WITH_CUDA
|
||||
from .cuda import USE_CUDA
|
||||
|
||||
WINDOWS_HOME = 'C:/Program Files/NVIDIA Corporation/NvToolsExt'
|
||||
|
||||
if not WITH_CUDA:
|
||||
if not USE_CUDA:
|
||||
NVTOOLEXT_HOME = None
|
||||
else:
|
||||
# We use nvcc path on Linux and cudart path on macOS
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from .env import check_env_flag
|
||||
# Check if ROCM is enabled
|
||||
WITH_ROCM = check_env_flag('WITH_ROCM')
|
||||
USE_ROCM = check_env_flag('USE_ROCM')
|
||||
ROCM_HOME = "/opt/rocm"
|
||||
ROCM_VERSION = ""
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ if(NOT TORCH_INSTALL_LIB_DIR)
|
|||
endif()
|
||||
|
||||
if(USE_CUDA)
|
||||
add_definitions(-DWITH_CUDA)
|
||||
add_definitions(-DUSE_CUDA)
|
||||
|
||||
set(TORCH_CUDA_LIBRARIES
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libcuda.so
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
#include <unordered_map>
|
||||
#include <sstream>
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include <THC/THC.h>
|
||||
#include <THCS/THCS.h>
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -38,11 +38,11 @@
|
|||
#include "torch/csrc/jit/python_ir.h"
|
||||
#include "torch/csrc/onnx/init.h"
|
||||
|
||||
#ifdef WITH_CUDNN
|
||||
#ifdef USE_CUDNN
|
||||
#include "cudnn.h"
|
||||
#endif
|
||||
|
||||
#ifdef WITH_C10D
|
||||
#ifdef USE_C10D
|
||||
#include "torch/csrc/distributed/c10d/c10d.h"
|
||||
#endif
|
||||
|
||||
|
|
@ -267,7 +267,7 @@ static PyObject *THPModule_getBackcompatKeepdimWarn(PyObject *module)
|
|||
|
||||
PyObject *THPModule_hasDistributed(PyObject *_unused)
|
||||
{
|
||||
#ifdef WITH_DISTRIBUTED
|
||||
#ifdef USE_DISTRIBUTED
|
||||
Py_RETURN_TRUE;
|
||||
#else
|
||||
Py_RETURN_FALSE;
|
||||
|
|
@ -423,7 +423,7 @@ bool THCPByteStorage_init(PyObject *module);
|
|||
|
||||
bool THCPStream_init(PyObject *module);
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
PyMethodDef* THCPModule_methods();
|
||||
namespace torch { namespace cuda {
|
||||
|
||||
|
|
@ -435,7 +435,7 @@ void initModule(PyObject *module);
|
|||
namespace torch { namespace nn {
|
||||
|
||||
void init__THNN(PyObject*);
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
void init__THCUNN(PyObject*);
|
||||
#endif
|
||||
|
||||
|
|
@ -452,12 +452,12 @@ bool THDPByteStorage_init(PyObject *module);
|
|||
|
||||
static std::vector<PyMethodDef> methods;
|
||||
|
||||
#ifdef WITH_DISTRIBUTED
|
||||
#ifdef USE_DISTRIBUTED
|
||||
PyMethodDef* THDPModule_methods();
|
||||
#endif
|
||||
|
||||
// TODO: Refactor this in some less manual way
|
||||
#ifdef WITH_CUDNN
|
||||
#ifdef USE_CUDNN
|
||||
static PyObject * THCUDNN_cudnn_version(PyObject *self, PyObject *args)
|
||||
{
|
||||
return PyLong_FromLong(CUDNN_VERSION);
|
||||
|
|
@ -482,16 +482,16 @@ static PyObject* initModule() {
|
|||
THPUtils_addPyMethodDefs(methods, TorchMethods);
|
||||
THPUtils_addPyMethodDefs(methods, DataLoaderMethods);
|
||||
THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
THPUtils_addPyMethodDefs(methods, THCPModule_methods());
|
||||
#endif
|
||||
#ifdef WITH_CUDNN
|
||||
#ifdef USE_CUDNN
|
||||
THPUtils_addPyMethodDefs(methods, THCUDNN_methods());
|
||||
#endif
|
||||
#ifdef WITH_DISTRIBUTED
|
||||
#ifdef USE_DISTRIBUTED
|
||||
THPUtils_addPyMethodDefs(methods, THDPModule_methods());
|
||||
#endif
|
||||
#ifdef WITH_C10D
|
||||
#ifdef USE_C10D
|
||||
THPUtils_addPyMethodDefs(methods, torch::distributed::c10d::python_functions());
|
||||
#endif
|
||||
|
||||
|
|
@ -524,7 +524,7 @@ static PyObject* initModule() {
|
|||
torch::jit::initJITBindings(module);
|
||||
torch::autograd::initNNFunctions(module);
|
||||
torch::autograd::init_legacy_variable(module);
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
torch::cuda::initModule(module);
|
||||
#endif
|
||||
ASSERT_TRUE(THPDoubleStorage_init(module));
|
||||
|
|
@ -536,7 +536,7 @@ static PyObject* initModule() {
|
|||
ASSERT_TRUE(THPCharStorage_init(module));
|
||||
ASSERT_TRUE(THPByteStorage_init(module));
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
// This will only initialise base classes and attach them to library namespace
|
||||
// They won't be ready for real usage until importing cuda module, that will
|
||||
// complete the process (but it defines Python classes before calling back into
|
||||
|
|
@ -553,7 +553,7 @@ static PyObject* initModule() {
|
|||
ASSERT_TRUE(THCPStream_init(module));
|
||||
#endif
|
||||
|
||||
#ifdef WITH_CUDNN
|
||||
#ifdef USE_CUDNN
|
||||
PyObject *has_cudnn = Py_True;
|
||||
#else
|
||||
PyObject *has_cudnn = Py_False;
|
||||
|
|
@ -561,7 +561,7 @@ static PyObject* initModule() {
|
|||
Py_INCREF(has_cudnn);
|
||||
ASSERT_TRUE(PyModule_AddObject(module, "has_cudnn", has_cudnn) == 0);
|
||||
|
||||
#ifdef WITH_DISTRIBUTED_MW
|
||||
#ifdef USE_DISTRIBUTED_MW
|
||||
// See comment on CUDA objects
|
||||
ASSERT_TRUE(THDPDoubleStorage_init(module));
|
||||
ASSERT_TRUE(THDPFloatStorage_init(module));
|
||||
|
|
@ -584,12 +584,12 @@ static PyObject* initModule() {
|
|||
defaultGenerator);
|
||||
ASSERT_TRUE(PyModule_AddObject(module, "default_generator", (PyObject*)THPDefaultGenerator) == 0);
|
||||
|
||||
#ifdef WITH_NUMPY
|
||||
#ifdef USE_NUMPY
|
||||
if (_import_array() < 0) return NULL;
|
||||
#endif
|
||||
|
||||
torch::nn::init__THNN(module);
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
torch::nn::init__THCUNN(module);
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
be defined only when compiling the core torch package.
|
||||
#endif
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include "cuda/THCP.h"
|
||||
#include "cuda/undef_macros.h"
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ THAllocator THStorageWeakRefAllocator = {
|
|||
free_wrapper<StorageWeakRefAllocator>,
|
||||
};
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
cudaError_t CudaStorageWeakRefAllocator::malloc(void** ptr, size_t size, cudaStream_t stream) {
|
||||
THError("CudaStorageWeakRefAllocator: malloc not supported");
|
||||
return cudaSuccess;
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
#include <memory>
|
||||
|
||||
#include <TH/TH.h>
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include <THC/THC.h>
|
||||
#endif
|
||||
|
||||
|
|
@ -41,7 +41,7 @@ public:
|
|||
void free(void* ptr);
|
||||
};
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
class CudaStorageWeakRefAllocator {
|
||||
public:
|
||||
CudaStorageWeakRefAllocator(PyObject *wrapped_object, THCDeviceAllocator *alloc, void *ctx) {
|
||||
|
|
@ -63,6 +63,6 @@ public:
|
|||
|
||||
extern THAllocator THObjectPtrAllocator;
|
||||
extern THAllocator THStorageWeakRefAllocator;
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
extern THCDeviceAllocator THCStorageWeakRefAllocator;
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
#include <queue>
|
||||
#include <TH/TH.h>
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include <cuda.h>
|
||||
#include <THC/THC.h>
|
||||
#endif
|
||||
|
|
@ -575,7 +575,7 @@ auto Engine::ready_queue(int device) -> ReadyQueue& {
|
|||
|
||||
auto Engine::start_threads() -> void {
|
||||
int num_devices = 0;
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
// check for case of compiled with CUDA but no available devices
|
||||
if (cudaGetDeviceCount(&num_devices) != cudaSuccess) {
|
||||
cudaGetLastError();
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ void RecordFunction::pushFunctionRange(Function* fn) {
|
|||
pushRange(fn->name());
|
||||
}
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
static void onEachDevice(std::function<void(int)> op) {
|
||||
AutoGPU gpu_guard;
|
||||
int count;
|
||||
|
|
@ -28,7 +28,7 @@ static void onEachDevice(std::function<void(int)> op) {
|
|||
|
||||
void enableProfiler(ProfilerState new_state) {
|
||||
TORCH_ASSERT(new_state != ProfilerState::Disabled);
|
||||
#ifndef WITH_CUDA
|
||||
#ifndef USE_CUDA
|
||||
if (new_state == ProfilerState::NVTX)
|
||||
throw std::runtime_error("Can't use NVTX profiler - PyTorch was compiled without CUDA");
|
||||
#endif
|
||||
|
|
@ -37,7 +37,7 @@ void enableProfiler(ProfilerState new_state) {
|
|||
}
|
||||
state = new_state;
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
if(state == ProfilerState::CUDA) {
|
||||
// event recording appears to have some startup overhead, so we need to
|
||||
// to generate some dummy events first before recording syncrhonization events
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include <nvToolsExt.h>
|
||||
#endif
|
||||
#include <thread>
|
||||
|
|
@ -16,7 +16,7 @@
|
|||
#include <tuple>
|
||||
#include "ATen/ATen.h"
|
||||
#include "torch/csrc/cuda/cuda_check.h"
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#endif
|
||||
|
||||
|
|
@ -47,7 +47,7 @@ struct Event {
|
|||
: kind_(kind)
|
||||
, name_(std::move(name))
|
||||
, thread_id_(thread_id) {
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
if(record_cuda) {
|
||||
TORCH_CUDA_CHECK(cudaGetDevice(&device_));
|
||||
TORCH_CUDA_CHECK(cudaEventCreate(&event));
|
||||
|
|
@ -79,7 +79,7 @@ struct Event {
|
|||
return (e.cpu_ns_ - cpu_ns_)/(1000.0);
|
||||
}
|
||||
double cuda_elapsed_us(const Event & e) {
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
if(!e.has_cuda() || !has_cuda()) {
|
||||
throw std::logic_error("Events were not recorded for CUDA");
|
||||
}
|
||||
|
|
@ -96,7 +96,7 @@ struct Event {
|
|||
#endif
|
||||
}
|
||||
bool has_cuda() const {
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
return event != nullptr;
|
||||
#else
|
||||
return false;
|
||||
|
|
@ -110,7 +110,7 @@ private:
|
|||
std::string name_;
|
||||
uint32_t thread_id_;
|
||||
int64_t cpu_ns_; // signed to allow for negative intervals
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
cudaEvent_t event = nullptr;
|
||||
#endif
|
||||
int device_ = -1;
|
||||
|
|
@ -182,7 +182,7 @@ inline RangeEventList& getEventList() {
|
|||
|
||||
inline void mark(std::string name, bool include_cuda = true) {
|
||||
if (state == ProfilerState::NVTX) {
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
nvtxMarkA(name.c_str());
|
||||
#else
|
||||
throw std::logic_error("mark called with NVTX tracing, but compiled without CUDA");
|
||||
|
|
@ -194,7 +194,7 @@ inline void mark(std::string name, bool include_cuda = true) {
|
|||
|
||||
inline void pushRange(std::string name) {
|
||||
if (state == ProfilerState::NVTX) {
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
nvtxRangePushA(name.c_str());
|
||||
#else
|
||||
throw std::logic_error("pushRange called with NVTX tracing, but compiled without CUDA");
|
||||
|
|
@ -206,7 +206,7 @@ inline void pushRange(std::string name) {
|
|||
|
||||
inline void popRange() {
|
||||
if (state == ProfilerState::NVTX) {
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
nvtxRangePop();
|
||||
#else
|
||||
throw std::logic_error("popRange called with NVTX tracing, but compiled without CUDA");
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
#include <TH/TH.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <THC/THCCachingAllocator.h>
|
||||
#ifdef WITH_NCCL
|
||||
#ifdef USE_NCCL
|
||||
#include <nccl.h>
|
||||
#endif
|
||||
|
||||
|
|
@ -365,7 +365,7 @@ static PyObject * THCPModule_initExtension(PyObject *self)
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
#ifdef WITH_NCCL
|
||||
#ifdef USE_NCCL
|
||||
#include "python_nccl.h"
|
||||
|
||||
void THCPModule_useNccl()
|
||||
|
|
@ -412,7 +412,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
|
|||
{"_cuda_sleep", (PyCFunction)THCPModule_cudaSleep, METH_O, NULL},
|
||||
{"_cuda_lock_mutex", (PyCFunction)THCPModule_cudaLockMutex, METH_NOARGS, NULL},
|
||||
{"_cuda_unlock_mutex", (PyCFunction)THCPModule_cudaUnlockMutex, METH_NOARGS, NULL},
|
||||
#ifdef WITH_NCCL
|
||||
#ifdef USE_NCCL
|
||||
{"_nccl_version", (PyCFunction)THCPModule_nccl_version, METH_NOARGS, NULL},
|
||||
{"_nccl_unique_id", (PyCFunction)THCPModule_nccl_unique_id, METH_NOARGS, NULL},
|
||||
{"_nccl_init_rank", (PyCFunction)THCPModule_nccl_init_rank, METH_VARARGS, NULL},
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
#include "torch/csrc/utils/tensor_flatten.h"
|
||||
#include "torch/csrc/utils/auto_gpu.h"
|
||||
#include "torch/csrc/cuda/device_set.h"
|
||||
#ifdef WITH_NCCL
|
||||
#ifdef USE_NCCL
|
||||
#include "torch/csrc/cuda/nccl.h"
|
||||
#endif
|
||||
|
||||
|
|
@ -34,7 +34,7 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntList devices) {
|
|||
"first on devices list");
|
||||
std::vector<Tensor> tensors;
|
||||
tensors.reserve(devices.size());
|
||||
#ifdef WITH_NCCL
|
||||
#ifdef USE_NCCL
|
||||
if (nccl::is_available({tensor})) {
|
||||
tensors.push_back(tensor);
|
||||
for (auto device : devices.slice(1)) {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <nvrtc.h>
|
||||
|
|
|
|||
|
|
@ -152,7 +152,7 @@ void _check_inputs(TensorList inputs, TensorList outputs, int input_multiplier,
|
|||
} // namespace detail
|
||||
|
||||
bool is_available(TensorList tensors) {
|
||||
#ifdef WITH_NCCL
|
||||
#ifdef USE_NCCL
|
||||
device_set devices;
|
||||
for (auto & tensor : tensors) {
|
||||
auto & type = tensor.type();
|
||||
|
|
@ -174,7 +174,7 @@ bool is_available(TensorList tensors) {
|
|||
std::uint64_t version() {
|
||||
#if defined(NCCL_MAJOR)
|
||||
return NCCL_MAJOR * 1000 + NCCL_MINOR * 100 + NCCL_PATCH;
|
||||
#elif defined(WITH_NCCL)
|
||||
#elif defined(USE_NCCL)
|
||||
return 1000;
|
||||
#else
|
||||
return 0;
|
||||
|
|
@ -182,7 +182,7 @@ std::uint64_t version() {
|
|||
}
|
||||
|
||||
void broadcast(TensorList tensors, const stream_list& streams, const comm_list& user_comms) {
|
||||
#ifdef WITH_NCCL
|
||||
#ifdef USE_NCCL
|
||||
using namespace torch::cuda::nccl::detail;
|
||||
_check_inputs(tensors, tensors, 1, 1);
|
||||
ncclDataType_t data_type = _get_data_type(tensors[0].type());
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
|
||||
#include <THC/THCGenerateAllTypes.h>
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
std::vector <THCStream*> THPUtils_PySequence_to_THCStreamList(PyObject *obj) {
|
||||
if (!PySequence_Check(obj)) {
|
||||
throw std::runtime_error("Expected a sequence in THPUtils_PySequence_to_THCStreamList");
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
#include "torch/csrc/PythonTypes.h"
|
||||
#include "torch/csrc/autograd/python_variable.h"
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include "torch/csrc/cuda/Stream.h"
|
||||
#endif
|
||||
|
||||
|
|
@ -23,7 +23,7 @@ static std::unordered_map<std::string, THDChannelType> name2channel_type = {
|
|||
|
||||
static bool THDPModule_loadClasses(PyObject *self)
|
||||
{
|
||||
#ifdef WITH_DISTRIBUTED_MW
|
||||
#ifdef USE_DISTRIBUTED_MW
|
||||
#define ASSERT_NOT_NULL(ptr) if (!(ptr)) { THPUtils_setError("couldn't load classes"); return false; }
|
||||
PyObject *torch_module = PyImport_ImportModule("torch.distributed");
|
||||
if (!torch_module) {
|
||||
|
|
@ -56,7 +56,7 @@ static bool THDPModule_loadClasses(PyObject *self)
|
|||
|
||||
static bool THDPModule_assignStateless(PyObject *self)
|
||||
{
|
||||
#ifdef WITH_DISTRIBUTED_MW
|
||||
#ifdef USE_DISTRIBUTED_MW
|
||||
#define INIT_STATELESS(type) \
|
||||
stateless = PyObject_CallFunctionObjArgs((PyObject*)&TH_CONCAT_3(THDP, type, TensorStatelessType), NULL); \
|
||||
if (!stateless) { \
|
||||
|
|
@ -82,7 +82,7 @@ static bool THDPModule_assignStateless(PyObject *self)
|
|||
static std::unordered_map<PyObject*, THDReduceOp> obj2reduceop;
|
||||
static std::unordered_map<PyObject*, THDGroup> obj2group;
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
extern THCState* state;
|
||||
#endif
|
||||
|
||||
|
|
@ -109,7 +109,7 @@ PyObject* THDPModule_initProcessGroup(PyObject *_unused, PyObject *args)
|
|||
AutoNoGIL nogil;
|
||||
THDProcessGroupInit(channel_type, init_method, world_size, group_name, rank);
|
||||
}
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
THDSetCudaStatePtr(&state);
|
||||
#endif
|
||||
Py_RETURN_NONE;
|
||||
|
|
@ -126,7 +126,7 @@ PyObject* THDPModule_destroyProcessGroup(PyObject *_unused) {
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
#ifdef WITH_DISTRIBUTED_MW
|
||||
#ifdef USE_DISTRIBUTED_MW
|
||||
PyObject* THDPModule_initMasterWorker(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
|
|
@ -150,7 +150,7 @@ PyObject* THDPModule_initMasterWorker(PyObject *_unused, PyObject *args)
|
|||
AutoNoGIL nogil;
|
||||
THDMasterWorkerInit(channel_type, init_method, world_size, group_name, rank);
|
||||
}
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
THDSetCudaStatePtr(&state);
|
||||
#endif
|
||||
Py_RETURN_NONE;
|
||||
|
|
@ -158,7 +158,7 @@ PyObject* THDPModule_initMasterWorker(PyObject *_unused, PyObject *args)
|
|||
}
|
||||
#endif
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
PyObject* THDPModule_registerStream(PyObject *_unused, PyObject *_stream)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
|
|
@ -185,7 +185,7 @@ PyObject* THDPModule_getNumProcesses(PyObject *_unused)
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
extern PyObject* THCPDoubleTensorClass;
|
||||
extern PyObject* THCPFloatTensorClass;
|
||||
extern PyObject* THCPHalfTensorClass;
|
||||
|
|
@ -982,10 +982,10 @@ static struct PyMethodDef _THDPModule_methods[] = {
|
|||
{"_dist_init_process_group", (PyCFunction)THDPModule_initProcessGroup, METH_VARARGS, NULL},
|
||||
{"_dist_destroy_process_group", (PyCFunction)THDPModule_destroyProcessGroup, METH_NOARGS, NULL},
|
||||
{"_dist_clear_group_cache", (PyCFunction)THDPModule_clearGroupCache, METH_VARARGS, NULL},
|
||||
#ifdef WITH_DISTRIBUTED_MW
|
||||
#ifdef USE_DISTRIBUTED_MW
|
||||
{"_dist_init_master_worker", (PyCFunction)THDPModule_initMasterWorker, METH_VARARGS, NULL},
|
||||
#endif
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
{"_dist_register_stream", (PyCFunction)THDPModule_registerStream, METH_O, NULL},
|
||||
#endif
|
||||
{"_dist_get_rank", (PyCFunction)THDPModule_getRank, METH_NOARGS, NULL},
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
#include "torch/csrc/THP.h"
|
||||
#include "Module.h"
|
||||
#ifdef WITH_DISTRIBUTED_MW
|
||||
#ifdef USE_DISTRIBUTED_MW
|
||||
#include "Storage.h"
|
||||
#include "../PtrWrapper.h"
|
||||
#ifdef _THP_CORE
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#endif
|
||||
|
||||
|
|
@ -29,7 +29,7 @@ static PyObject * THPStorage_(copy_)(PyObject *self, PyObject *args, PyObject *k
|
|||
static PyObject * THPStorage_(isPinned)(THPStorage *self)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
#if defined(WITH_CUDA)
|
||||
#if defined(USE_CUDA)
|
||||
cudaPointerAttributes attr;
|
||||
cudaError_t err = cudaPointerGetAttributes(&attr, THWStorage_(data)(LIBRARY_STATE self->cdata));
|
||||
if (err != cudaSuccess) {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
#include "torch/csrc/variable_tensor_functions.h"
|
||||
|
||||
#include "ATen/ATen.h"
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include "THC/THC.h"
|
||||
#include "torch/csrc/cuda/cuda_check.h"
|
||||
#include <nvrtc.h>
|
||||
|
|
@ -39,7 +39,7 @@ std::vector<bool> TensorDesc::findContiguous(
|
|||
|
||||
namespace {
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
|
||||
static int ceilDiv(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
|
|
@ -492,7 +492,7 @@ void CompiledFusionFunction::launch(at::ArrayRef<at::Tensor> inputs, std::vector
|
|||
launch_with_tensors(inputs, outputs);
|
||||
}
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
|
||||
void checkCUDAVersion(const cudaDeviceProp & prop) {
|
||||
if ((prop.major >= 6 && CUDA_VERSION < 8000) ||
|
||||
|
|
@ -747,7 +747,7 @@ std::shared_ptr<CompiledFusionFunction> FusionCompiler::getOrCompile(AnnotatedGr
|
|||
std::string name = "kernel_" + std::to_string(cache.size());
|
||||
CompiledFusionFunction * raw_func;
|
||||
if(agraph.device != kCPUDevice) {
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
raw_func = new CUDAFusionFunction(name, agraph);
|
||||
#else
|
||||
throw std::runtime_error("cannot compile a CUDA fusion group, CUDA is not enabled.");
|
||||
|
|
@ -834,7 +834,7 @@ FusionCompiler & sharedFusionCompiler() {
|
|||
#include "torch/csrc/jit/resource_guard.h"
|
||||
#include "torch/csrc/utils/disallow_copy.h"
|
||||
#include "ATen/ATen.h"
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include "torch/csrc/cuda/cuda_check.h"
|
||||
#include <nvrtc.h>
|
||||
#include <cuda.h>
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ static inline THIntTensor* THNN_IntTensor_Unpack(PyObject* obj) {
|
|||
return torch::nn::unpack<THIntTensor>(obj);
|
||||
}
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
|
||||
static inline bool THNN_CudaHalfTensor_Check(PyObject* obj) {
|
||||
return torch::nn::check_type(obj, at::TypeID::CUDAHalf);
|
||||
|
|
@ -102,4 +102,4 @@ static inline THCudaLongTensor* THNN_CudaLongTensor_Unpack(PyObject* obj) {
|
|||
return torch::nn::unpack<THCudaLongTensor>(obj);
|
||||
}
|
||||
|
||||
#endif // WITH_CUDA
|
||||
#endif // USE_CUDA
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
#include "torch/csrc/utils/python_numbers.h"
|
||||
#include "torch/csrc/utils/python_compat.h"
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include <THC/THC.h>
|
||||
#endif
|
||||
|
||||
|
|
@ -176,7 +176,7 @@ void setBackCompatKeepdimWarn(bool warn);
|
|||
bool getBackCompatKeepdimWarn();
|
||||
bool maybeThrowBackCompatKeepdimWarn(char *func);
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
std::vector <THCStream*> THPUtils_PySequence_to_THCStreamList(PyObject *obj);
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#endif
|
||||
|
|
@ -29,7 +29,7 @@ struct AutoGPU {
|
|||
}
|
||||
|
||||
~AutoGPU() {
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
if (original_device != -1) {
|
||||
cudaSetDevice(original_device);
|
||||
}
|
||||
|
|
@ -37,7 +37,7 @@ struct AutoGPU {
|
|||
}
|
||||
|
||||
inline void setDevice(int device) {
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
if (device == -1) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -55,7 +55,7 @@ struct AutoGPU {
|
|||
int original_device = -1;
|
||||
|
||||
private:
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
static void cudaCheck(cudaError_t err) {
|
||||
if (err != cudaSuccess) {
|
||||
std::string msg = "CUDA error (";
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@
|
|||
|
||||
// RAII structs to set CUDA stream
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include <THC/THC.h>
|
||||
extern THCState* state;
|
||||
#endif
|
||||
|
||||
struct AutoStream {
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
explicit AutoStream(THCStream* stream)
|
||||
: original_stream(THCState_getStream(state))
|
||||
{
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ namespace torch {
|
|||
namespace utils {
|
||||
|
||||
static inline bool cuda_enabled() {
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
#include "torch/csrc/python_headers.h"
|
||||
|
||||
#ifdef WITH_NUMPY
|
||||
#ifdef USE_NUMPY
|
||||
|
||||
#if !defined(NO_IMPORT_ARRAY) && !defined(WITH_NUMPY_IMPORT_ARRAY)
|
||||
#define NO_IMPORT_ARRAY
|
||||
|
|
@ -18,4 +18,4 @@
|
|||
|
||||
#include <numpy/arrayobject.h>
|
||||
|
||||
#endif // WITH_NUMPY
|
||||
#endif // USE_NUMPY
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ static ScalarType infer_scalar_type(PyObject *obj) {
|
|||
auto var = reinterpret_cast<THPVariable*>(obj)->cdata;
|
||||
return var.type().scalarType();
|
||||
}
|
||||
#ifdef WITH_NUMPY
|
||||
#ifdef USE_NUMPY
|
||||
if (PyArray_Check(obj)) {
|
||||
auto array = (PyArrayObject*)obj;
|
||||
return numpy_dtype_to_aten(PyArray_TYPE(array));
|
||||
|
|
@ -198,7 +198,7 @@ static Tensor internal_new_from_data(const Type & type, at::optional<Device> dev
|
|||
new_with_type_conversion(type_to_use, var, device);
|
||||
}
|
||||
|
||||
#ifdef WITH_NUMPY
|
||||
#ifdef USE_NUMPY
|
||||
if (PyArray_Check(data)) {
|
||||
auto tensor = autograd::make_variable(tensor_from_numpy(data), /*requires_grad=*/false);
|
||||
const auto& type_to_use = type_inference ? type.toScalarType(tensor.type().scalarType()) : type;
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
#include "torch/csrc/utils/numpy_stub.h"
|
||||
|
||||
#ifndef WITH_NUMPY
|
||||
#ifndef USE_NUMPY
|
||||
namespace torch { namespace utils {
|
||||
PyObject* tensor_to_numpy(const at::Tensor& tensor) {
|
||||
throw std::runtime_error("PyTorch was compiled without NumPy support");
|
||||
|
|
@ -176,4 +176,4 @@ ScalarType numpy_dtype_to_aten(int dtype) {
|
|||
|
||||
}} // namespace torch::utils
|
||||
|
||||
#endif // WITH_NUMPY
|
||||
#endif // USE_NUMPY
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ ELSE()
|
|||
INCLUDE_DIRECTORIES(${CUDA_INCLUDE_DIRS})
|
||||
LINK_DIRECTORIES("${CUDA_TOOLKIT_ROOT_DIR}/lib" "${CUDA_TOOLKIT_ROOT_DIR}/lib64")
|
||||
|
||||
ADD_DEFINITIONS(-DWITH_CUDA=1)
|
||||
ADD_DEFINITIONS(-DUSE_CUDA=1)
|
||||
ENDIF()
|
||||
|
||||
FIND_PACKAGE(NCCL)
|
||||
|
|
@ -70,9 +70,9 @@ ENDIF()
|
|||
IF(GLOO_FOUND)
|
||||
ADD_DEFINITIONS(-DWITH_GLOO=1)
|
||||
MESSAGE(STATUS "Found Gloo, will compile with Gloo distributed backend")
|
||||
IF(WITH_GLOO_IBVERBS)
|
||||
IF(USE_GLOO_IBVERBS)
|
||||
MESSAGE(STATUS "Building the gloo backend with both TCP and infiniband support")
|
||||
ADD_DEFINITIONS(-DWITH_GLOO_IBVERBS=1)
|
||||
ADD_DEFINITIONS(-DUSE_GLOO_IBVERBS=1)
|
||||
ELSE()
|
||||
MESSAGE(STATUS "Building the gloo backend with TCP support only")
|
||||
ENDIF()
|
||||
|
|
@ -84,7 +84,7 @@ IF(NCCL_FOUND)
|
|||
MESSAGE(STATUS "NCCL Version 2 or higher found, will "
|
||||
"compile with NCCL distributed backend")
|
||||
SET(DISTRIBUTED_NCCL_FOUND TRUE)
|
||||
ADD_DEFINITIONS(-DWITH_DISTRIBUTED_NCCL=1)
|
||||
ADD_DEFINITIONS(-DUSE_DISTRIBUTED_NCCL=1)
|
||||
ELSE()
|
||||
MESSAGE(STATUS "Found NCCL, but the NCCL version is either not 2+ or not "
|
||||
"determinable, will not compile with NCCL distributed "
|
||||
|
|
@ -133,8 +133,8 @@ EXCLUDE_DIR(master_worker_cpp ".*/dispatch/.*\\.cpp$")
|
|||
SET(all_cpp ${base_cpp} ${process_group_cpp})
|
||||
SET(all_h THD.h ${base_h} ${process_group_h})
|
||||
|
||||
IF(WITH_DISTRIBUTED_MW)
|
||||
ADD_DEFINITIONS(-DWITH_DISTRIBUTED_MW=1)
|
||||
IF(USE_DISTRIBUTED_MW)
|
||||
ADD_DEFINITIONS(-DUSE_DISTRIBUTED_MW=1)
|
||||
SET(all_cpp ${all_cpp} ${master_worker_cpp})
|
||||
SET(all_h THD.h ${all_h} ${master_worker_h})
|
||||
ENDIF()
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@
|
|||
#include "process_group/General.h"
|
||||
#include "process_group/Collectives.h"
|
||||
|
||||
#ifdef WITH_DISTRIBUTED_MW
|
||||
#ifdef USE_DISTRIBUTED_MW
|
||||
#include "master_worker/master/Master.h"
|
||||
#include "master_worker/master/State.h"
|
||||
#include "master_worker/master/THDRandom.h"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
#include "Cuda.hpp"
|
||||
#include <unordered_map>
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
THCState** _THDCudaState;
|
||||
|
||||
void THDSetCudaStatePtr(THCState **state) {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include "../THD.h"
|
||||
|
||||
#include <THC/THC.h>
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include <THC/THC.h>
|
||||
#include "Cuda.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@
|
|||
#ifdef WITH_MPI
|
||||
#include "data_channels/DataChannelMPI.hpp"
|
||||
#endif // WITH_MPI
|
||||
#if defined(WITH_CUDA) && defined(WITH_DISTRIBUTED_NCCL)
|
||||
#if defined(USE_CUDA) && defined(USE_DISTRIBUTED_NCCL)
|
||||
#include "data_channels/DataChannelNccl.hpp"
|
||||
#endif // WITH_DISTRIBUTED_NCCL
|
||||
#endif // USE_DISTRIBUTED_NCCL
|
||||
#include "data_channels/DataChannelTCP.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
|
|
@ -43,7 +43,7 @@ DataChannel* DataChannel::newChannel(THDChannelType type, std::string init_metho
|
|||
);
|
||||
|
||||
case THDChannelNccl:
|
||||
#if defined(WITH_CUDA) && defined(WITH_DISTRIBUTED_NCCL)
|
||||
#if defined(USE_CUDA) && defined(USE_DISTRIBUTED_NCCL)
|
||||
return new DataChannelNccl(GET_CONFIG);
|
||||
#endif
|
||||
throw std::runtime_error(
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
#include "../THD.h"
|
||||
#include <TH/TH.h>
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include <THC/THC.h>
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
#include "GlooCache.hpp"
|
||||
#include "Store.hpp"
|
||||
|
||||
#if defined(WITH_GLOO_IBVERBS) && WITH_GLOO_IBVERBS
|
||||
#if defined(USE_GLOO_IBVERBS) && USE_GLOO_IBVERBS
|
||||
#include "gloo/transport/ibverbs/device.h"
|
||||
#endif
|
||||
|
||||
|
|
@ -81,7 +81,7 @@ DataChannelGloo::DataChannelGloo(InitMethod::Config config)
|
|||
{
|
||||
_num_processes = config.world_size;
|
||||
|
||||
#if defined(WITH_GLOO_IBVERBS) && WITH_GLOO_IBVERBS
|
||||
#if defined(USE_GLOO_IBVERBS) && USE_GLOO_IBVERBS
|
||||
|
||||
// This helper function automatically detects the IB device in the system
|
||||
auto ibDeviceNames = ::gloo::transport::ibverbs::getDeviceNames();
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
#include <unordered_map>
|
||||
#include <iostream>
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#endif
|
||||
|
||||
|
|
@ -138,7 +138,7 @@ rank_type DataChannelMPI::getNumProcesses() {
|
|||
struct AutoGPU {
|
||||
AutoGPU(int new_device) {
|
||||
if (new_device == -1) return;
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
cudaGetDevice(&device_);
|
||||
cudaSetDevice(new_device);
|
||||
#endif
|
||||
|
|
@ -146,7 +146,7 @@ struct AutoGPU {
|
|||
|
||||
~AutoGPU() {
|
||||
if (device_ == -1) return;
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
cudaSetDevice(device_);
|
||||
#endif
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
#include "gloo/allreduce_ring.h"
|
||||
#include "gloo/barrier_all_to_all.h"
|
||||
#include "gloo/broadcast_one_to_all.h"
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include "gloo/cuda_allreduce_ring.h"
|
||||
#include "gloo/cuda_allreduce_halving_doubling.h"
|
||||
#include "gloo/cuda_allreduce_halving_doubling_pipelined.h"
|
||||
|
|
@ -19,7 +19,7 @@
|
|||
#include "gloo/rendezvous/store.h"
|
||||
#include "gloo/rendezvous/prefix_store.h"
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
#include <cuda.h>
|
||||
#include <THC/THC.h>
|
||||
#endif
|
||||
|
|
@ -141,7 +141,7 @@ struct GlooCache {
|
|||
if (device == DeviceType::CPU) {
|
||||
return std::shared_ptr<buffer_type>(new char[bytes],
|
||||
std::default_delete<char[]>());
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
} else if (device == DeviceType::CUDA) {
|
||||
buffer_type *buf;
|
||||
THCudaCheck(THCudaMalloc(THDGetCudaState(), (void**)&buf, bytes));
|
||||
|
|
@ -184,7 +184,7 @@ struct GlooCache {
|
|||
|
||||
if (t_dev == DeviceType::CPU) {
|
||||
std::memcpy(input_buffer, t.data_ptr(), tensor_bytes);
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
} else if (t_dev == DeviceType::CUDA) {
|
||||
auto stream = THCState_getCurrentStream(THDGetCudaState());
|
||||
THCudaCheck(cudaMemcpyAsync(input_buffer, t.data_ptr(), tensor_bytes,
|
||||
|
|
@ -202,7 +202,7 @@ struct GlooCache {
|
|||
|
||||
if (t_dev == DeviceType::CPU) {
|
||||
std::memcpy(t.data_ptr(), output_buffer, tensor_bytes);
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
} else if (t_dev == DeviceType::CUDA) {
|
||||
auto stream = THCState_getCurrentStream(THDGetCudaState());
|
||||
THCudaCheck(cudaMemcpyAsync(t.data_ptr(), output_buffer, tensor_bytes,
|
||||
|
|
@ -318,14 +318,14 @@ struct algorithm_spec<CollectiveType::ALL_REDUCE, T> {
|
|||
std::initializer_list<T*>{reinterpret_cast<T*>(input_buffer.get())},
|
||||
count,
|
||||
THDToGlooReduceOp<T>(op));
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
} else if (device == DeviceType::CUDA) {
|
||||
if (op != THDReduceSUM) {
|
||||
throw std::runtime_error("Gloo backend only supports sum op for CUDA all reduce");
|
||||
}
|
||||
auto stream = THCState_getCurrentStream(THDGetCudaState());
|
||||
|
||||
#if defined(WITH_GLOO_IBVERBS) && WITH_GLOO_IBVERBS
|
||||
#if defined(USE_GLOO_IBVERBS) && USE_GLOO_IBVERBS
|
||||
// Only enable GPU direct if the device supports it
|
||||
if (context->getDevice()->hasGPUDirect()) {
|
||||
algo = std::make_shared<::gloo::CudaAllreduceHalvingDoublingPipelined<T,
|
||||
|
|
@ -388,11 +388,11 @@ struct algorithm_spec<CollectiveType::BROADCAST, T> {
|
|||
std::initializer_list<T*>{reinterpret_cast<T*>(input_buffer.get())},
|
||||
count,
|
||||
src_rank);
|
||||
#ifdef WITH_CUDA
|
||||
#ifdef USE_CUDA
|
||||
} else if (device == DeviceType::CUDA) {
|
||||
auto stream = THCState_getCurrentStream(THDGetCudaState());
|
||||
|
||||
#if defined(WITH_GLOO_IBVERBS) && WITH_GLOO_IBVERBS
|
||||
#if defined(USE_GLOO_IBVERBS) && USE_GLOO_IBVERBS
|
||||
// Only enable GPU direct if the device supports it
|
||||
if (context->getDevice()->hasGPUDirect()) {
|
||||
algo = std::make_shared<::gloo::CudaBroadcastOneToAll<T,
|
||||
|
|
|
|||