diff --git a/.gitmodules b/.gitmodules
index 08868cba712..8f870ad4bfa 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -130,6 +130,9 @@
 	ignore = dirty
 	path = third_party/tensorpipe
 	url = https://github.com/pytorch/tensorpipe.git
+[submodule "third_party/cudnn_frontend"]
+	path = third_party/cudnn_frontend
+	url = https://github.com/NVIDIA/cudnn-frontend.git
 [submodule "third_party/kineto"]
 	path = third_party/kineto
 	url = https://github.com/pytorch/kineto
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4ad16a8a301..9786d36b854 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -194,6 +194,9 @@ cmake_dependent_option(
 cmake_dependent_option(
     USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
     "USE_CUDNN" OFF)
+cmake_dependent_option(
+    USE_EXPERIMENTAL_CUDNN_V8_API "Use experimental cuDNN v8 API" OFF
+    "USE_CUDNN" OFF)
 option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
 option(USE_KINETO "Use Kineto profiling library" ON)
 option(USE_CUPTI_SO "Use CUPTI as a shared library" OFF)
diff --git a/aten/src/ATen/native/cudnn/Conv_v7.cpp b/aten/src/ATen/native/cudnn/Conv_v7.cpp
index dc61faea53d..7d16f0a9a91 100644
--- a/aten/src/ATen/native/cudnn/Conv_v7.cpp
+++ b/aten/src/ATen/native/cudnn/Conv_v7.cpp
@@ -2,6 +2,8 @@
 
 #if AT_CUDNN_ENABLED()
 
+#include <ATen/native/cudnn/Macros.h>
+
 #include 
 #include 
 #include 
@@ -614,6 +616,8 @@ if (args.params.dataType == CUDNN_DATA_FLOAT) {
 //
 // ---------------------------------------------------------------------
 
+#if !HAS_CUDNN_V8()
+
 void raw_cudnn_convolution_forward_out_32bit(
     const Tensor& output, const Tensor& input, const Tensor& weight,
     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
@@ -665,6 +669,8 @@ void raw_cudnn_convolution_forward_out(
   split_batch_dim_to_32bit_out(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, 1024 * 1024 * 256, raw_cudnn_convolution_forward_out_32bit);
 }
 
+#endif // !HAS_CUDNN_V8()
+
 // 
---------------------------------------------------------------------
 //
 // Convolution backward / Transposed convolution forward
diff --git a/aten/src/ATen/native/cudnn/Conv_v8.cpp b/aten/src/ATen/native/cudnn/Conv_v8.cpp
index 53f8c37f5e6..9ba1775988b 100644
--- a/aten/src/ATen/native/cudnn/Conv_v8.cpp
+++ b/aten/src/ATen/native/cudnn/Conv_v8.cpp
@@ -1,5 +1,177 @@
 #include <ATen/cuda/CUDAConfig.h>  // for the definition of AT_CUDNN_ENABLED
-#if AT_CUDNN_ENABLED() && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
-// Coming soon
-#endif // AT_CUDNN_ENABLED and CUDNN_VERSION
+#if AT_CUDNN_ENABLED()
+
+#include <ATen/native/cudnn/Macros.h>
+
+#if HAS_CUDNN_V8()
+
+#include <ATen/cudnn/cudnn-wrapper.h>
+#include <cudnn_frontend.h>
+#include <ATen/ATen.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/cuda/Exceptions.h>
+#include <ATen/native/ConvUtils.h>
+#include <ATen/native/cudnn/ConvShared.h>
+#include <ATen/native/utils/ParamsHash.h>
+#include <ATen/cudnn/Handle.h>
+#include <ATen/cudnn/Types.h>
+
+#include <unordered_map>
+
+namespace at { namespace native{
+
+namespace {
+
+uint8_t getAlignment(const Tensor &t) {
+  // alignment are in bytes
+  uint8_t alignment = 1;
+  uint64_t address = reinterpret_cast<uint64_t>(t.data_ptr());
+  while (address % alignment == 0 && alignment < 16) alignment *= 2;
+  return alignment;
+}
+
+cudnn_frontend::Tensor getTensorDescriptor(const Tensor &t, int64_t id, uint8_t alignment) {
+  auto shape = t.sizes();
+  auto strides = t.strides();
+  return cudnn_frontend::TensorBuilder()
+    .setDim(shape.size(), shape.data())
+    .setStrides(strides.size(), strides.data())
+    .setId(id)
+    .setAlignment(alignment)
+    .setDataType(getCudnnDataType(t))
+    .build();
+}
+
+cudnn_frontend::ConvDesc_v8 getConvDescriptor(cudnnDataType_t dataType, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation) {
+  uint64_t convDim = stride.size();
+  return cudnn_frontend::ConvDescBuilder()
+    .setDataType(dataType)
+    .setMathMode(CUDNN_CROSS_CORRELATION)
+    .setNDims(convDim)
+    .setStrides(convDim, stride.data())
+    .setPrePadding(convDim, padding.data())
+    .setPostPadding(convDim, padding.data())
+    .setDilation(convDim, dilation.data())
+    .build();
+}
+
+void filterEngineConfigs(
+  cudnn_frontend::EngineConfigList &from,
+  cudnn_frontend::EngineConfigList &to,
+  bool deterministic, bool allow_tf32, c10::ScalarType scalar_type)
+{
+  auto filter = [=](cudnnBackendDescriptor_t c) {
+    if (deterministic) {
+      if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(c)) return true;
+    }
+    if (scalar_type == kFloat || !allow_tf32) {
+      if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) return true;
+      if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c)) return true;
+    }
+    return false;
+  };
+  cudnn_frontend::filter(from, to, filter);
+}
+
+struct CacheKey {
+  ConvolutionParams params;
+  uint8_t input_alignment;
+  uint8_t weight_alignment;
+  uint8_t output_alignment;
+};
+
+// FIXME: make this thread-safe by reusing the benchmark cache in Conv_v7.cpp
+std::unordered_map<CacheKey, cudnn_frontend::ManagedOpaqueDescriptor, ParamsHash<CacheKey>, ParamsEqual<CacheKey>> engine_cache;
+
+}
+
+void raw_cudnn_convolution_forward_out(
+    const Tensor& output, const Tensor& input, const Tensor& weight,
+    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
+    bool benchmark, bool deterministic, bool allow_tf32)
+{
+  TORCH_CHECK(!benchmark, "not supported yet");
+  if (output.numel() == 0) {
+    return;
+  }
+
+  cudnnHandle_t handle = getCudnnHandle();
+
+  CacheKey key;
+  setConvolutionParams(&key.params, input, weight, padding, stride, dilation, groups, deterministic, allow_tf32);
+  key.input_alignment = getAlignment(input);
+  key.output_alignment = getAlignment(output);
+  key.weight_alignment = getAlignment(weight);
+
+  auto run = [&](cudnn_frontend::ManagedOpaqueDescriptor cfg) {
+    auto plan = cudnn_frontend::ExecutionPlanBuilder()
+        .setHandle(handle)
+        .setEngineConfig(cfg)
+        .build();
+
+    auto workspace_size = plan.getWorkspaceSize();
+    auto workspace = at::empty({workspace_size}, input.options().dtype(kByte));
+    void *data_ptrs[] = {input.data_ptr(), output.data_ptr(), weight.data_ptr()};
+    // std::cout << plan.describe() << " requires workspace " << workspace_size << std::endl;
+    int64_t uids[] = {'x', 'y', 'w'};
+    auto variantPack = cudnn_frontend::VariantPackBuilder()
+        .setWorkspacePointer(workspace.data_ptr())
+        .setDataPointers(3, data_ptrs)
+        .setUids(3, uids)
+        .build();
+    AT_CUDNN_CHECK(cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc()));
+  };
+
+  auto search = engine_cache.find(key);
+  if (search != engine_cache.end()) {
+    run(search->second);
+    return;
+  }
+
+  auto op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
+      .setxDesc(getTensorDescriptor(input, 'x', key.input_alignment))
+      .setyDesc(getTensorDescriptor(output, 'y', key.output_alignment))
+      .setwDesc(getTensorDescriptor(weight, 'w', key.weight_alignment))
+      .setcDesc(getConvDescriptor(key.params.dataType, padding, stride, dilation))
+      .build();
+  // std::cout << op.describe() << std::endl;
+
+  std::array<cudnn_frontend::Operation const *, 1> ops = {&op};
+
+  auto opGraph = cudnn_frontend::OperationGraphBuilder()
+      .setHandle(handle)
+      .setOperationGraph(1, ops.data())
+      .build();
+  // std::cout << opGraph.describe() << std::endl;
+
+  auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
+      .setOperationGraph(opGraph)
+      .setHeurMode(CUDNN_HEUR_MODE_INSTANT)
+      .build();
+  auto fallback = cudnn_frontend::EngineFallbackListBuilder()
+      .setOperationGraph(opGraph)
+      .setOperation(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
+      .build();
+
+  auto& engine_configs = heuristics.getEngineConfig(heuristics.getEngineConfigCount());
+  auto& fallback_list = fallback.getFallbackList();
+
+  cudnn_frontend::EngineConfigList filtered_configs;
+  filterEngineConfigs(engine_configs, filtered_configs, deterministic, allow_tf32, input.scalar_type());
+  filterEngineConfigs(fallback_list, filtered_configs, deterministic, allow_tf32, input.scalar_type());
+
+  for (auto &cfg : filtered_configs) {
+    try {
+      run(cfg);
+      engine_cache[key] = cfg;
+      return;
+    } catch (cudnn_frontend::cudnnException &e) {} catch(CuDNNError &e) {}
+  }
+  TORCH_CHECK(false, "Unable to find an engine to execute this computation");
+}
+
+}} // at::native
+
+#endif // HAS_CUDNN_V8
+#endif // AT_CUDNN_ENABLED
diff --git 
a/aten/src/ATen/native/cudnn/Macros.h b/aten/src/ATen/native/cudnn/Macros.h
new file mode 100644
index 00000000000..fdc65524328
--- /dev/null
+++ b/aten/src/ATen/native/cudnn/Macros.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <ATen/cudnn/cudnn-wrapper.h>
+
+// Note: The version below should not actually be 8000. Instead, it should
+// be whatever version of cuDNN the v8 API works with correctly in PyTorch.
+// The version is set to 8000 today for convenience of debugging.
+#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
+#define HAS_CUDNN_V8() true
+#else
+#define HAS_CUDNN_V8() false
+#endif
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 9b5e4bd285f..2cd1015e84a 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -1304,6 +1304,15 @@ elseif(USE_ROCM)
   target_compile_definitions(torch_hip PRIVATE "-DTORCH_HIP_BUILD_MAIN_LIB")
 endif()
 
+if(USE_EXPERIMENTAL_CUDNN_V8_API)
+  if(BUILD_SPLIT_CUDA)
+    target_compile_definitions(torch_cuda_cu PRIVATE "-DUSE_EXPERIMENTAL_CUDNN_V8_API")
+    target_compile_definitions(torch_cuda_cpp PRIVATE "-DUSE_EXPERIMENTAL_CUDNN_V8_API")
+  elseif(USE_CUDA)
+    target_compile_definitions(torch_cuda PRIVATE "-DUSE_EXPERIMENTAL_CUDNN_V8_API")
+  endif()
+endif()
+
 set(EXPERIMENTAL_SINGLE_THREAD_POOL "0" CACHE STRING
     "Experimental option to use a single thread pool for inter- and intra-op parallelism")
 if("${EXPERIMENTAL_SINGLE_THREAD_POOL}")
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 4a34d669304..c7fe9b7d4bd 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1189,6 +1189,12 @@ if(USE_CUDA)
   endif()
 endif()
 
+# ---[ cuDNN
+if(USE_CUDNN)
+  set(CUDNN_FRONTEND_INCLUDE_DIR ${CMAKE_CURRENT_LIST_DIR}/../third_party/cudnn_frontend/include)
+  include_directories(${CUDNN_FRONTEND_INCLUDE_DIR})
+endif()
+
 # ---[ HIP
 if(USE_ROCM)
   include(${CMAKE_CURRENT_LIST_DIR}/public/LoadHIP.cmake)
diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake
index 4c402894a7a..795da7cc428 
100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -74,6 +74,7 @@ function(caffe2_print_configuration_summary) message(STATUS " Split CUDA : ${BUILD_SPLIT_CUDA}") message(STATUS " CUDA static link : ${CAFFE2_STATIC_LINK_CUDA}") message(STATUS " USE_CUDNN : ${USE_CUDNN}") + message(STATUS " USE_EXPERIMENTAL_CUDNN_V8_API: ${USE_EXPERIMENTAL_CUDNN_V8_API}") message(STATUS " CUDA version : ${CUDA_VERSION}") if(${USE_CUDNN}) message(STATUS " cuDNN version : ${CUDNN_VERSION}") diff --git a/setup.py b/setup.py index d29a6c3cb96..62db0aa863e 100644 --- a/setup.py +++ b/setup.py @@ -332,7 +332,7 @@ def check_submodules(): print('Please run:\n\tgit submodule update --init --recursive') sys.exit(1) for folder in folders: - check_for_files(folder, ["CMakeLists.txt", "Makefile", "setup.py", "LICENSE"]) + check_for_files(folder, ["CMakeLists.txt", "Makefile", "setup.py", "LICENSE", "LICENSE.txt"]) check_for_files(os.path.join(third_party_path, 'fbgemm', 'third_party', 'asmjit'), ['CMakeLists.txt']) check_for_files(os.path.join(third_party_path, 'onnx', 'third_party', diff --git a/third_party/cudnn_frontend b/third_party/cudnn_frontend new file mode 160000 index 00000000000..51e60d891b6 --- /dev/null +++ b/third_party/cudnn_frontend @@ -0,0 +1 @@ +Subproject commit 51e60d891b689d618e7a623509a779c422a420f7