diff --git a/aten/CMakeLists.txt b/aten/CMakeLists.txt index 0908265c2f7..77ed7da4381 100644 --- a/aten/CMakeLists.txt +++ b/aten/CMakeLists.txt @@ -24,6 +24,7 @@ set(ATen_THIRD_PARTY_INCLUDE) set(ATen_CUDA_SRCS) set(ATen_CUDA_TEST_SRCS) set(ATen_CUDA_INCLUDE) +set(ATen_NVRTC_STUB_SRCS) set(ATen_HIP_SRCS) set(ATen_HIP_TEST_SRCS) set(ATen_HIP_INCLUDE) @@ -101,6 +102,7 @@ add_subdirectory(src/ATen) set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE) set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} PARENT_SCOPE) set(ATen_HIP_SRCS ${ATen_HIP_SRCS} PARENT_SCOPE) +set(ATen_NVRTC_STUB_SRCS ${ATen_NVRTC_STUB_SRCS} PARENT_SCOPE) set(ATen_CPU_TEST_SRCS ${ATen_CPU_TEST_SRCS} PARENT_SCOPE) set(ATen_CUDA_TEST_SRCS ${ATen_CUDA_TEST_SRCS} PARENT_SCOPE) set(ATen_HIP_TEST_SRCS ${ATen_HIP_TEST_SRCS} PARENT_SCOPE) diff --git a/aten/src/ATen/ATen.h b/aten/src/ATen/ATen.h index 5fd4588390f..b741126ca16 100644 --- a/aten/src/ATen/ATen.h +++ b/aten/src/ATen/ATen.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #ifdef BUILD_NAMEDTENSOR diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 4f118a40c63..6824b4d5bcc 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -39,6 +39,8 @@ FILE(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp") add_subdirectory(core) FILE(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh") FILE(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp") +FILE(GLOB cuda_nvrtc_stub_h "cuda/nvrtc_stub/*.h") +FILE(GLOB cuda_nvrtc_stub_cpp "cuda/nvrtc_stub/*.cpp") FILE(GLOB cuda_cu "cuda/*.cu" "cuda/detail/*.cu") FILE(GLOB cudnn_h "cudnn/*.h" "cudnn/*.cuh") FILE(GLOB cudnn_cpp "cudnn/*.cpp") @@ -46,6 +48,8 @@ FILE(GLOB cudnn_cpp "cudnn/*.cpp") FILE(GLOB hip_h "hip/*.h" "hip/detail/*.h" "hip/*.cuh" "hip/detail/*.cuh") FILE(GLOB hip_cpp "hip/*.cpp" "hip/detail/*.cpp" "hip/impl/*.cpp") FILE(GLOB hip_hip "hip/*.hip" "hip/detail/*.hip" "hip/impl/*.hip") +FILE(GLOB hip_nvrtc_stub_h "hip/nvrtc_stub/*.h") +FILE(GLOB hip_nvrtc_stub_cpp "hip/nvrtc_stub/*.cpp") FILE(GLOB miopen_h "miopen/*.h") FILE(GLOB miopen_cpp "miopen/*.cpp") @@ -356,6 +360,7 @@ endif() if(USE_CUDA) set(ATen_CUDA_SRCS ${all_cuda_cpp}) + set(ATen_NVRTC_STUB_SRCS ${cuda_nvrtc_stub_cpp}) if(AT_LINK_STYLE STREQUAL "INTERFACE") # Source code can't be added to an interface library, so it is # passed back to be compiled into the containing library @@ -368,6 +373,9 @@ endif() if(USE_ROCM) set(ATen_HIP_SRCS ${all_hip_cpp}) + # caffe2_nvrtc's stubs to driver APIs are useful for HIP. + # See NOTE [ ATen NVRTC Stub and HIP ] + set(ATen_NVRTC_STUB_SRCS ${hip_nvrtc_stub_cpp}) if(AT_LINK_STYLE STREQUAL "INTERFACE") # Source code can't be added to an interface library, so it is # passed back to be compiled into the containing library @@ -439,6 +447,7 @@ endif() set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE) set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE) set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} PARENT_SCOPE) +set(ATen_NVRTC_STUB_SRCS ${ATen_NVRTC_STUB_SRCS} PARENT_SCOPE) set(ATen_HIP_SRCS ${ATen_HIP_SRCS} PARENT_SCOPE) set(ATen_QUANTIZED_SRCS ${ATen_QUANTIZED_SRCS} PARENT_SCOPE) set(ATen_CPU_TEST_SRCS ${ATen_CPU_TEST_SRCS} PARENT_SCOPE) diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index a564b729088..086ed73641e 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -77,7 +77,9 @@ class CAFFE2_API Context { }); return thh_state.get(); } - + const at::cuda::NVRTC& getNVRTC() { + return detail::getCUDAHooks().nvrtc(); + } THCState* getTHCState() { // AT_ASSERT(thc_state); return thc_state.get(); diff --git a/aten/src/ATen/DynamicLibrary.cpp b/aten/src/ATen/DynamicLibrary.cpp new file mode 100644 index 00000000000..72b1a934a1c --- /dev/null +++ b/aten/src/ATen/DynamicLibrary.cpp @@ -0,0 +1,74 @@ +#include +#include +#include + +#ifndef _WIN32 +#include +#include +#else +#include +#endif + +namespace at { + + +#ifndef _WIN32 + +// Unix + +static void* checkDL(void* x) { + if (!x) { + AT_ERROR("Error in dlopen or dlsym: ", dlerror()); + } + + return x; +} +DynamicLibrary::DynamicLibrary(const char* name) { + // NOLINTNEXTLINE(hicpp-signed-bitwise) + handle = checkDL(dlopen(name, RTLD_LOCAL | RTLD_NOW)); +} + +void* DynamicLibrary::sym(const char* name) { + AT_ASSERT(handle); + return checkDL(dlsym(handle, name)); +} + +DynamicLibrary::~DynamicLibrary() { + if (!handle) + return; + dlclose(handle); +} + +#else + +// Windows + +DynamicLibrary::DynamicLibrary(const char* name) { + // NOLINTNEXTLINE(hicpp-signed-bitwise) + HMODULE theModule = LoadLibraryA(name); + if (theModule) { + handle = theModule; + } else { + AT_ERROR("error in LoadLibraryA"); + } +} + +void* DynamicLibrary::sym(const char* name) { + AT_ASSERT(handle); + FARPROC procAddress = GetProcAddress((HMODULE)handle, name); + if (!procAddress) { + AT_ERROR("error in GetProcAddress"); + } + return (void*)procAddress; +} + +DynamicLibrary::~DynamicLibrary() { + if (!handle) { + return; + } + FreeLibrary((HMODULE)handle); +} + +#endif + +} // namespace at diff --git a/aten/src/ATen/DynamicLibrary.h b/aten/src/ATen/DynamicLibrary.h new file mode 100644 index 00000000000..ea919a79d31 --- /dev/null +++ b/aten/src/ATen/DynamicLibrary.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +namespace at { + +struct DynamicLibrary { + AT_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary); + + CAFFE2_API DynamicLibrary(const char* name); + + CAFFE2_API void* sym(const char* name); + + CAFFE2_API ~DynamicLibrary(); + + private: + void* handle = nullptr; +}; + +} // namespace at diff --git a/aten/src/ATen/Utils.h b/aten/src/ATen/Utils.h index 63ed9ea5ea5..be47638394b 100644 --- a/aten/src/ATen/Utils.h +++ b/aten/src/ATen/Utils.h @@ -24,6 +24,10 @@ #define __ubsan_ignore_vptr__ #endif +#define AT_DISALLOW_COPY_AND_ASSIGN(TypeName) \ + TypeName(const TypeName&) = delete; \ + void operator=(const TypeName&) = delete + namespace at { CAFFE2_API int _crash_if_asan(int); diff --git a/aten/src/ATen/cuda/CUDAContext.h b/aten/src/ATen/cuda/CUDAContext.h index 718f5f94382..5afcac93a23 100644 --- a/aten/src/ATen/cuda/CUDAContext.h +++ b/aten/src/ATen/cuda/CUDAContext.h @@ -3,8 +3,8 @@ #include #include #include -#include #include +#include #include diff --git a/aten/src/ATen/cuda/Exceptions.h b/aten/src/ATen/cuda/Exceptions.h index ee90859dcb3..c8e2f2f325e 100644 --- a/aten/src/ATen/cuda/Exceptions.h +++ b/aten/src/ATen/cuda/Exceptions.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -20,3 +21,57 @@ } while (0) #define AT_CUDA_CHECK(EXPR) C10_CUDA_CHECK(EXPR) + +// For CUDA Driver API +// +// This is here instead of in c10 because NVRTC is loaded dynamically via a stub +// in ATen, and we need to use its nvrtcGetErrorString. +// See NOTE [ USE OF NVRTC AND DRIVER API ]. +#ifndef __HIP_PLATFORM_HCC__ + +#define AT_CUDA_DRIVER_CHECK(EXPR) \ + do { \ + CUresult __err = EXPR; \ + if (__err != CUDA_SUCCESS) { \ + const char* err_str; \ + CUresult get_error_str_err C10_UNUSED = at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str); \ + if (get_error_str_err != CUDA_SUCCESS) { \ + AT_ERROR("CUDA driver error: unknown error"); \ + } else { \ + AT_ERROR("CUDA driver error: ", err_str); \ + } \ + } \ + } while (0) + +#else + +#define AT_CUDA_DRIVER_CHECK(EXPR) \ + do { \ + CUresult __err = EXPR; \ + if (__err != CUDA_SUCCESS) { \ + AT_ERROR("CUDA driver error: ", static_cast(__err)); \ + } \ + } while (0) + +#endif + +// For CUDA NVRTC +// +// Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE, +// incorrectly produces the error string "NVRTC unknown error." +// The following maps it correctly. +// +// This is here instead of in c10 because NVRTC is loaded dynamically via a stub +// in ATen, and we need to use its nvrtcGetErrorString. +// See NOTE [ USE OF NVRTC AND DRIVER API ]. +#define AT_CUDA_NVRTC_CHECK(EXPR) \ + do { \ + nvrtcResult __err = EXPR; \ + if (__err != NVRTC_SUCCESS) { \ + if (static_cast(__err) != 7) { \ + AT_ERROR("CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err)); \ + } else { \ + AT_ERROR("CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE"); \ + } \ + } \ + } while (0) diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index ea5c661c271..23dbc9713be 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -2,9 +2,11 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -77,6 +79,32 @@ bool CUDAHooks::hasCuDNN() const { return AT_CUDNN_ENABLED(); } +#ifdef USE_DIRECT_NVRTC +static std::pair, at::cuda::NVRTC*> load_nvrtc() { + return std::make_pair(nullptr, at::cuda::load_nvrtc()); +} +#else +static std::pair, at::cuda::NVRTC*> load_nvrtc() { +#if defined(_WIN32) + std::string libcaffe2_nvrtc = "caffe2_nvrtc.dll"; +#elif defined(__APPLE__) + std::string libcaffe2_nvrtc = "libcaffe2_nvrtc.dylib"; +#else + std::string libcaffe2_nvrtc = "libcaffe2_nvrtc.so"; +#endif + std::unique_ptr libnvrtc_stub( + new at::DynamicLibrary(libcaffe2_nvrtc.c_str())); + auto fn = (at::cuda::NVRTC * (*)()) libnvrtc_stub->sym("load_nvrtc"); + return std::make_pair(std::move(libnvrtc_stub), fn()); +} +#endif + +const at::cuda::NVRTC& CUDAHooks::nvrtc() const { + // must hold onto DynamicLibrary otherwise it will unload + static auto handle = load_nvrtc(); + return *handle.second; +} + int64_t CUDAHooks::current_device() const { int device; cudaError_t err = cudaGetDevice(&device); diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h index f9ef501469d..49a358ff6d1 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.h +++ b/aten/src/ATen/cuda/detail/CUDAHooks.h @@ -16,6 +16,7 @@ struct CUDAHooks : public at::CUDAHooksInterface { bool hasCUDA() const override; bool hasMAGMA() const override; bool hasCuDNN() const override; + const at::cuda::NVRTC& nvrtc() const override; int64_t current_device() const override; Allocator* getPinnedMemoryAllocator() const override; bool compiledWithCuDNN() const override; diff --git a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.cpp b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.cpp new file mode 100644 index 00000000000..5ecf47a41f8 --- /dev/null +++ b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.cpp @@ -0,0 +1,13 @@ +#include +#include + +namespace at { namespace cuda { + +NVRTC* load_nvrtc() { + auto self = new NVRTC(); +#define CREATE_ASSIGN(name) self->name = name; + AT_FORALL_NVRTC(CREATE_ASSIGN) + return self; +} + +}} // at::cuda diff --git a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h new file mode 100644 index 00000000000..118e40e6720 --- /dev/null +++ b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h @@ -0,0 +1,88 @@ +#pragma once + +#include +#include + +#ifndef __HIP_PLATFORM_HCC__ +#include +#endif + +namespace at { namespace cuda { + + +// NOTE [ USE OF NVRTC AND DRIVER API ] +// +// ATen does not directly link to either libnvrtc or libcuda because they +// require libcuda to be installed, yet we want our GPU build to work on CPU +// machines as long as CUDA is not initialized. +// +// Normal CUDA code in torch uses the cuda runtime libraries which can be +// installed even if the driver is not installed, but sometimes we specifically +// need to use the driver API (e.g., to load JIT compiled code). +// To accomplish this, we lazily link libcaffe2_nvrtc which provides a struct +// at::cuda::NVRTC that contains function pointers to all of the apis we need. +// +// IT IS AN ERROR TO TRY TO CALL ANY nvrtc* or cu* FUNCTION DIRECTLY. +// INSTEAD USE, e.g. +// detail::getCUDAHooks().nvrtc().cuLoadModule(...) +// oe +// globalContext().getNVRTC().cuLoadModule(...) +// +// If a function is missing add it to the list in ATen/cuda/nvrtc_stub/ATenNVRTC.h. + +#ifndef __HIP_PLATFORM_HCC__ + +#define AT_FORALL_NVRTC(_) \ + _(nvrtcVersion) \ + _(nvrtcCreateProgram) \ + _(nvrtcDestroyProgram) \ + _(nvrtcGetPTXSize) \ + _(nvrtcGetPTX) \ + _(nvrtcCompileProgram) \ + _(nvrtcGetErrorString) \ + _(nvrtcGetProgramLogSize) \ + _(nvrtcGetProgramLog) \ + _(cuModuleLoadData) \ + _(cuModuleGetFunction) \ + _(cuOccupancyMaxActiveBlocksPerMultiprocessor) \ + _(cuGetErrorString) \ + _(cuLaunchKernel) \ + _(cuCtxGetCurrent) \ + _(cuModuleUnload) \ + _(cuDevicePrimaryCtxGetState) + +#else + +// NOTE [ ATen NVRTC Stub and HIP ] +// +// ATen's NVRTC stub library, caffe2_nvrtc, provides dynamic loading of both +// NVRTC and driver APIs. While the former is not yet suppoted for HIP, the +// later is supported and needed. +// +// The macro below strips out certain unsupported operations on HIP from the full +// list above. +// +// HIP doesn't have +// nvrtc* +// cuOccupancyMaxActiveBlocksPerMultiprocessor +// cuGetErrorString (maps to non-functional hipGetErrorString___) + +#define AT_FORALL_NVRTC(_) \ + _(cuModuleLoadData) \ + _(cuModuleGetFunction) \ + _(cuLaunchKernel) \ + _(cuCtxGetCurrent) \ + _(cuModuleUnload) \ + _(cuDevicePrimaryCtxGetState) + +#endif + +extern "C" typedef struct NVRTC { +#define CREATE_MEMBER(name) decltype(&name) name; + AT_FORALL_NVRTC(CREATE_MEMBER) +#undef CREATE_MEMBER +} NVRTC; + +extern "C" AT_CUDA_API NVRTC* load_nvrtc(); + +}} // at::cuda diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h index 083b32d62a0..e4d5999f7ad 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.h +++ b/aten/src/ATen/detail/CUDAHooksInterface.h @@ -13,6 +13,11 @@ // Forward-declares THCState struct THCState; +// Forward-declares at::cuda::NVRTC +namespace at { namespace cuda { +struct NVRTC; +}} // at::cuda + namespace at { class Context; } @@ -78,6 +83,10 @@ struct CAFFE2_API CUDAHooksInterface { return false; } + virtual const at::cuda::NVRTC& nvrtc() const { + AT_ERROR("NVRTC requires CUDA. ", CUDA_HELP); + } + virtual int64_t current_device() const { return -1; } diff --git a/c10/cuda/CUDAException.h b/c10/cuda/CUDAException.h index 9dfdd890012..d7e7ec3f5c1 100644 --- a/c10/cuda/CUDAException.h +++ b/c10/cuda/CUDAException.h @@ -11,6 +11,7 @@ // macro and a function implementation if we pass along __LINE__ // and __FILE__, but no one has found this worth doing. +// For CUDA Runtime API #define C10_CUDA_CHECK(EXPR) \ do { \ cudaError_t __err = EXPR; \ diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index e638f8dff7c..acd13deab6c 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -453,42 +453,27 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_ROOT}/test/cpp/jit/test.cpp ) - if (WIN32) + if (NOT WIN32) list(APPEND TORCH_SRCS - ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/dynamic_library_win.cpp - ) - if (USE_CUDA AND NOT USE_ROCM) - list(APPEND Caffe2_GPU_SRCS - ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/fused_kernel.cpp - ) - add_library(thnvrtc SHARED ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/thnvrtc.cpp) - target_link_libraries(thnvrtc ${CUDA_NVRTC} ${CUDA_CUDA_LIB} ${CUDA_NVRTC_LIB}) - target_include_directories(thnvrtc PRIVATE ${CUDA_INCLUDE_DIRS}) - install(TARGETS thnvrtc DESTINATION "${TORCH_INSTALL_LIB_DIR}") - endif() - else () - list(APPEND TORCH_SRCS - ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/dynamic_library_unix.cpp ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/fused_kernel.cpp ) - if (USE_CUDA AND NOT USE_ROCM) - list(APPEND Caffe2_GPU_SRCS - ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/fused_kernel.cpp - ) - add_library(thnvrtc SHARED ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/thnvrtc.cpp) - target_link_libraries(thnvrtc ${CUDA_NVRTC} ${CUDA_CUDA_LIB} ${CUDA_NVRTC_LIB}) - target_include_directories(thnvrtc PRIVATE ${CUDA_INCLUDE_DIRS}) - install(TARGETS thnvrtc DESTINATION "${TORCH_INSTALL_LIB_DIR}") - - endif() endif () if (USE_CUDA) + if (NOT USE_ROCM) + list(APPEND Caffe2_GPU_SRCS + ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/fused_kernel.cpp + ) + endif() list(APPEND Caffe2_GPU_SRCS ${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp ) + add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS}) + target_link_libraries(caffe2_nvrtc ${CUDA_NVRTC} ${CUDA_CUDA_LIB} ${CUDA_NVRTC_LIB}) + target_include_directories(caffe2_nvrtc PRIVATE ${CUDA_INCLUDE_DIRS}) + install(TARGETS caffe2_nvrtc DESTINATION "${TORCH_INSTALL_LIB_DIR}") endif() if (USE_ROCM) @@ -496,6 +481,13 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp ) + # caffe2_nvrtc's stubs to driver APIs are useful for HIP. + # See NOTE [ ATen NVRTC Stub and HIP ] + add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS}) + target_link_libraries(caffe2_nvrtc ${CUDA_NVRTC} ${CUDA_CUDA_LIB} ${CUDA_NVRTC_LIB}) + target_include_directories(caffe2_nvrtc PRIVATE ${CUDA_INCLUDE_DIRS}) + target_compile_definitions(caffe2_nvrtc PRIVATE USE_ROCM __HIP_PLATFORM_HCC__) + install(TARGETS caffe2_nvrtc DESTINATION "${TORCH_INSTALL_LIB_DIR}") endif() if (NOT NO_API) diff --git a/cmake/TorchConfig.cmake.in b/cmake/TorchConfig.cmake.in index 2c10d0a4ae3..3937de7840b 100644 --- a/cmake/TorchConfig.cmake.in +++ b/cmake/TorchConfig.cmake.in @@ -65,8 +65,8 @@ if (@USE_CUDA@) ${NVTOOLEXT_HOME}/lib/x64/nvToolsExt64_1.lib ${CUDA_LIBRARIES}) list(APPEND TORCH_INCLUDE_DIRS ${NVTOOLEXT_HOME}/include) - find_library(THNVRTC_LIBRARY thnvrtc PATHS "${TORCH_INSTALL_PREFIX}/lib") - list(APPEND TORCH_CUDA_LIBRARIES ${THNVRTC_LIBRARY}) + find_library(CAFFE2_NVRTC_LIBRARY caffe2_nvrtc PATHS "${TORCH_INSTALL_PREFIX}/lib") + list(APPEND TORCH_CUDA_LIBRARIES ${CAFFE2_NVRTC_LIBRARY}) elseif(APPLE) set(TORCH_CUDA_LIBRARIES ${CUDA_TOOLKIT_ROOT_DIR}/lib/libcudart.dylib diff --git a/test/test_cuda_primary_ctx.py b/test/test_cuda_primary_ctx.py index 2211a134871..db2023d8702 100644 --- a/test/test_cuda_primary_ctx.py +++ b/test/test_cuda_primary_ctx.py @@ -19,17 +19,17 @@ if not TEST_CUDA: TestCase = object # noqa: F811 -_thnvrtc = None +_caffe2_nvrtc = None def get_is_primary_context_created(device): flags = ctypes.cast((ctypes.c_uint * 1)(), ctypes.POINTER(ctypes.c_uint)) active = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int)) - global _thnvrtc - if _thnvrtc is None: - path = glob.glob('{}/lib/libthnvrtc.*'.format(os.path.dirname(torch.__file__)))[0] - _thnvrtc = ctypes.cdll.LoadLibrary(path) - result = _thnvrtc.cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active) + global _caffe2_nvrtc + if _caffe2_nvrtc is None: + path = glob.glob('{}/lib/libcaffe2_nvrtc.*'.format(os.path.dirname(torch.__file__)))[0] + _caffe2_nvrtc = ctypes.cdll.LoadLibrary(path) + result = _caffe2_nvrtc.cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active) assert result == 0, 'cuDevicePrimaryCtxGetState failed' return bool(active[0]) diff --git a/tools/build_variables.py b/tools/build_variables.py index 1ea5dbe6258..8ff5d25e9a3 100644 --- a/tools/build_variables.py +++ b/tools/build_variables.py @@ -137,7 +137,6 @@ libtorch_sources = [ "torch/csrc/jit/fuser/codegen.cpp", "torch/csrc/jit/fuser/fallback.cpp", "torch/csrc/jit/fuser/cpu/fused_kernel.cpp", - "torch/csrc/jit/fuser/cpu/dynamic_library_unix.cpp", "torch/csrc/jit/fuser/interface.cpp", "torch/csrc/jit/function.cpp", "test/cpp/jit/test.cpp", @@ -147,7 +146,6 @@ libtorch_cuda_sources = [ "torch/csrc/cuda/comm.cpp", "torch/csrc/cuda/nccl.cpp", "torch/csrc/jit/fuser/cuda/fused_kernel.cpp", - "torch/csrc/jit/fuser/cuda/thnvrtc.cpp", "torch/csrc/autograd/profiler_cuda.cpp", "torch/csrc/autograd/functions/comm.cpp" ] @@ -350,7 +348,6 @@ def add_torch_libs(): # TODO: putting USE_CUDA in propagated_pp_flags is error-prone propagated_pp_flags=propagated_pp_flags + [ "-DUSE_CUDA", - "-DUSE_DIRECT_NVRTC", ], deps=[ ":generated-autograd-headers", diff --git a/torch/csrc/jit/fuser/cpu/dynamic_library.h b/torch/csrc/jit/fuser/cpu/dynamic_library.h deleted file mode 100644 index 1841a805654..00000000000 --- a/torch/csrc/jit/fuser/cpu/dynamic_library.h +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include -#include - -namespace torch { -namespace jit { -namespace fuser { -namespace cpu { - -struct DynamicLibrary { - TH_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary); - - TORCH_API DynamicLibrary(const char* name); - - TORCH_API void* sym(const char* name); - - TORCH_API ~DynamicLibrary(); - - private: - void* handle = nullptr; -}; - -} // namespace cpu -} // namespace fuser -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/fuser/cpu/dynamic_library_unix.cpp b/torch/csrc/jit/fuser/cpu/dynamic_library_unix.cpp deleted file mode 100644 index 84d08e30d73..00000000000 --- a/torch/csrc/jit/fuser/cpu/dynamic_library_unix.cpp +++ /dev/null @@ -1,39 +0,0 @@ -#include -#include -#include - -#include -#include - -namespace torch { -namespace jit { -namespace fuser { -namespace cpu { - -static void* checkDL(void* x) { - if (!x) { - AT_ERROR("error in dlopen or dlsym: ", dlerror()); - } - - return x; -} -DynamicLibrary::DynamicLibrary(const char* name) { - // NOLINTNEXTLINE(hicpp-signed-bitwise) - handle = checkDL(dlopen(name, RTLD_LOCAL | RTLD_NOW)); -} - -void* DynamicLibrary::sym(const char* name) { - AT_ASSERT(handle); - return checkDL(dlsym(handle, name)); -} - -DynamicLibrary::~DynamicLibrary() { - if (!handle) - return; - dlclose(handle); -} - -} // namespace cpu -} // namespace fuser -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/fuser/cpu/dynamic_library_win.cpp b/torch/csrc/jit/fuser/cpu/dynamic_library_win.cpp deleted file mode 100644 index 7fb1483fee8..00000000000 --- a/torch/csrc/jit/fuser/cpu/dynamic_library_win.cpp +++ /dev/null @@ -1,41 +0,0 @@ -#include -#include -#include -#include - -namespace torch { -namespace jit { -namespace fuser { -namespace cpu { - - -DynamicLibrary::DynamicLibrary(const char* name) { - // NOLINTNEXTLINE(hicpp-signed-bitwise) - HMODULE theModule = LoadLibraryA(name); - if (theModule) { - handle = theModule; - } else { - AT_ERROR("error in LoadLibraryA"); - } -} - -void* DynamicLibrary::sym(const char* name) { - AT_ASSERT(handle); - FARPROC procAddress = GetProcAddress((HMODULE)handle, name); - if (!procAddress) { - AT_ERROR("error in GetProcAddress"); - } - return (void*)procAddress; -} - -DynamicLibrary::~DynamicLibrary() { - if (!handle) { - return; - } - FreeLibrary((HMODULE)handle); -} - -} // namespace cpu -} // namespace fuser -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/fuser/cpu/fused_kernel.cpp b/torch/csrc/jit/fuser/cpu/fused_kernel.cpp index c9603f5db3a..ae662ad6a2e 100644 --- a/torch/csrc/jit/fuser/cpu/fused_kernel.cpp +++ b/torch/csrc/jit/fuser/cpu/fused_kernel.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #include #include @@ -123,7 +122,7 @@ FusedKernelCPU::FusedKernelCPU( runCompiler(cpp_file.name(), so_file.name()); if (debugFuser() >= 2) disas(so_file.name()); - so_lib = make_unique(so_file.name().c_str()); + so_lib = make_unique(so_file.name().c_str()); #pragma GCC diagnostic ignored "-Wpedantic" kernel = reinterpret_cast(so_lib->sym(name_.c_str())); diff --git a/torch/csrc/jit/fuser/cpu/fused_kernel.h b/torch/csrc/jit/fuser/cpu/fused_kernel.h index 3116f3c27a9..3ce556740af 100644 --- a/torch/csrc/jit/fuser/cpu/fused_kernel.h +++ b/torch/csrc/jit/fuser/cpu/fused_kernel.h @@ -2,7 +2,6 @@ #include #include -#include #include #include @@ -36,7 +35,7 @@ struct TORCH_API FusedKernelCPU : public ::torch::jit::fuser::FusedKernel { } private: - std::unique_ptr so_lib; + std::unique_ptr so_lib; void (*kernel)(uint32_t, void**) = nullptr; }; diff --git a/torch/csrc/jit/fuser/cuda/fused_kernel.cpp b/torch/csrc/jit/fuser/cuda/fused_kernel.cpp index e4382a70c03..1d7ca3f2e4a 100644 --- a/torch/csrc/jit/fuser/cuda/fused_kernel.cpp +++ b/torch/csrc/jit/fuser/cuda/fused_kernel.cpp @@ -1,12 +1,12 @@ #include #include +#include #include +#include #include #include #include -#include -#include #include #include @@ -23,77 +23,17 @@ namespace jit { namespace fuser { namespace cuda { -// [USE OF NVRTC AND DRIVER API] -// libtorch does not directly link to either libnvrtc or libcuda because -// they require libcuda to be installed. Normal CUDA code in torch uses the cuda -// runtime libraries which can be installed even if the driver is not installed, -// but here we specifically need to use the driver API to load JIT compiled -// code. To accomplish this, we lazily link libthnvrtc which provides a struct -// THNVRTC that contains function pointers to all of the apis we need. -// -// IT IS AN ERROR TO TRY TO CALL ANY nvrtc* or cu* FUNCTION DIRECTLY. -// INSTEAD USE, e.g. nvrtc().cuLoadModule(...) -// If a function is missing add it to the list in thnvrtc. - -#ifdef USE_DIRECT_NVRTC -std::pair, THNVRTC*> loadNVRTC() { - return std::make_pair(nullptr, torch_load_nvrtc()); +// See NOTE [ USE OF NVRTC AND DRIVER API ] +const at::cuda::NVRTC& nvrtc() { + return at::globalContext().getNVRTC(); } -#else -std::pair, THNVRTC*> loadNVRTC() { -#if defined(_WIN32) - std::string libthnvrtc = "thnvrtc.dll"; -#elif defined(__APPLE__) - std::string libthnvrtc = "libthnvrtc.dylib"; -#else - std::string libthnvrtc = "libthnvrtc.so"; -#endif - std::unique_ptr libnvrtc_stub( - new cpu::DynamicLibrary(libthnvrtc.c_str())); - auto fn = (THNVRTC * (*)()) libnvrtc_stub->sym("torch_load_nvrtc"); - return std::make_pair(std::move(libnvrtc_stub), fn()); -} -#endif - -const THNVRTC& nvrtc() { - // must hold onto DynamicLibrary otherwise it will unload - static auto handle = loadNVRTC(); - return *handle.second; -} - -// We're using three CUDA APIs, so define a few helpers for error handling -// Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE, incorrectly produces the error string -// "NVRTC unknown error." The following maps it correctly. -static inline void nvrtcCheck(nvrtcResult result, const char* file, int line) { - if (result != NVRTC_SUCCESS) { - std::stringstream ss; - ss << file << ":" << line << ": "; - if (static_cast(result) != 7) - ss << nvrtc().nvrtcGetErrorString(result); - else - ss << "NVRTC_ERROR_BUILTIN_OPERATION_FAILURE"; - throw std::runtime_error(ss.str()); - } -} -#define TORCH_NVRTC_CHECK(result) nvrtcCheck(result, __FILE__, __LINE__); - -static inline void cuCheck(CUresult result, const char* file, int line) { - if (result != CUDA_SUCCESS) { - const char* str; - nvrtc().cuGetErrorString(result, &str); - std::stringstream ss; - ss << file << ":" << line << ": " << str; - throw std::runtime_error(ss.str()); - } -} -#define TORCH_CU_CHECK(result) cuCheck(result, __FILE__, __LINE__); static void getMajorMinor( const cudaDeviceProp* const prop, int& major, int& minor) { int nvrtc_major, nvrtc_minor; - TORCH_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor)); + AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor)); // Short-circuits if NVRTC version too low AT_ASSERT(nvrtc_major >= 6); @@ -145,7 +85,7 @@ FusedKernelCUDA::FusedKernelCUDA( device_(device) { // Initializes driver's API context (if necessary) CUcontext pctx = 0; - TORCH_CU_CHECK(nvrtc().cuCtxGetCurrent(&pctx)); + AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx)); if (!pctx) { std::unique_lock cudaFreeMutexLock( *(c10::cuda::CUDACachingAllocator::getFreeMutex())); @@ -165,7 +105,7 @@ FusedKernelCUDA::FusedKernelCUDA( // Creates the NVRTC program nvrtcProgram program; - TORCH_NVRTC_CHECK(nvrtc().nvrtcCreateProgram( + AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram( &program, code_.c_str(), nullptr, 0, nullptr, nullptr)); const std::string compute = "--gpu-architecture=compute_" + @@ -184,19 +124,19 @@ FusedKernelCUDA::FusedKernelCUDA( throw std::runtime_error(cu.str()); } ResourceGuard holdProgram( - [&] { TORCH_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); }); - TORCH_NVRTC_CHECK(result); + [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); }); + AT_CUDA_NVRTC_CHECK(result); size_t ptx_size; - TORCH_NVRTC_CHECK(nvrtc().nvrtcGetPTXSize(program, &ptx_size)); + AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTXSize(program, &ptx_size)); ptx_.resize(ptx_size); - TORCH_NVRTC_CHECK(nvrtc().nvrtcGetPTX(program, ptx_.data())); + AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTX(program, ptx_.data())); - TORCH_CU_CHECK(nvrtc().cuModuleLoadData(&module_, ptx_.data())); - TORCH_CU_CHECK( + AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&module_, ptx_.data())); + AT_CUDA_DRIVER_CHECK( nvrtc().cuModuleGetFunction(&function_, module_, name_.c_str())); // Computes max blocks - TORCH_CU_CHECK(nvrtc().cuOccupancyMaxActiveBlocksPerMultiprocessor( + AT_CUDA_DRIVER_CHECK(nvrtc().cuOccupancyMaxActiveBlocksPerMultiprocessor( &maxBlocks_, function_, 128, 0)); maxBlocks_ *= prop_->multiProcessorCount; @@ -236,7 +176,7 @@ void FusedKernelCUDA::launch_raw( // Launches kernel on current stream (device was set by executor) auto stream = at::cuda::getCurrentCUDAStream(); - TORCH_CU_CHECK(nvrtc().cuLaunchKernel( + AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel( function_, nBlocks, 1, diff --git a/torch/csrc/jit/fuser/cuda/thnvrtc.cpp b/torch/csrc/jit/fuser/cuda/thnvrtc.cpp deleted file mode 100644 index c1f75258d59..00000000000 --- a/torch/csrc/jit/fuser/cuda/thnvrtc.cpp +++ /dev/null @@ -1,9 +0,0 @@ -#include -#include - -THNVRTC* torch_load_nvrtc() { - auto self = new THNVRTC(); -#define CREATE_ASSIGN(name) self->name = name; - TORCH_FORALL_NVRTC(CREATE_ASSIGN) - return self; -} diff --git a/torch/csrc/jit/fuser/cuda/thnvrtc.h b/torch/csrc/jit/fuser/cuda/thnvrtc.h deleted file mode 100644 index db7aa8d1198..00000000000 --- a/torch/csrc/jit/fuser/cuda/thnvrtc.h +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#include -#include -#include - -// See [USE OF NVRTC AND DRIVER API] - -#define TORCH_FORALL_NVRTC(_) \ - _(nvrtcVersion) \ - _(nvrtcCreateProgram) \ - _(nvrtcDestroyProgram) \ - _(nvrtcGetPTXSize) \ - _(nvrtcGetPTX) \ - _(cuModuleLoadData) \ - _(cuModuleGetFunction) \ - _(cuOccupancyMaxActiveBlocksPerMultiprocessor) \ - _(cuGetErrorString) \ - _(nvrtcGetErrorString) \ - _(nvrtcGetProgramLogSize) \ - _(nvrtcGetProgramLog) \ - _(cuLaunchKernel) \ - _(nvrtcCompileProgram) \ - _(cuCtxGetCurrent) \ - _(cuModuleUnload) \ - _(cuDevicePrimaryCtxGetState) - -extern "C" typedef struct THNVRTC { -#define CREATE_MEMBER(name) decltype(&name) name; - TORCH_FORALL_NVRTC(CREATE_MEMBER) -#undef CREATE_MEMBER -} THNVRTC; - -extern "C" TORCH_API THNVRTC* torch_load_nvrtc(); diff --git a/torch/csrc/utils/disallow_copy.h b/torch/csrc/utils/disallow_copy.h index 786b8b7c7f8..5960421d3a4 100644 --- a/torch/csrc/utils/disallow_copy.h +++ b/torch/csrc/utils/disallow_copy.h @@ -1,4 +1,5 @@ #pragma once -#define TH_DISALLOW_COPY_AND_ASSIGN(TypeName) \ - TypeName(const TypeName&) = delete; \ - void operator=(const TypeName&) = delete + +#include + +#define TH_DISALLOW_COPY_AND_ASSIGN AT_DISALLOW_COPY_AND_ASSIGN