#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/cuda/CUDAConfig.h> // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()

#include <ATen/core/Tensor.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#endif

#include <ATen/Config.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/native/cudnn/ConvShared.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <limits>
#include <vector>

#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/native/utils/ParamsHash.h>

#include <ATen/TensorUtils.h>
#include <c10/util/irange.h>

#include <stdint.h>
#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <mutex>
#include <sstream>
#include <unordered_map>

// Note [behavior of cudnnFind and cudnnGet]
// You'll notice that by default, in the ConvolutionDescriptor, we do the
// following:
//
//   AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
//   if (dataType == CUDNN_DATA_HALF)
//     AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
//
// Update: AT_CUDNN_CHECK is replaced with AT_CUDNN_CHECK_WITH_SHAPES, which
// automatically prints tensor shapes and convolution parameters if a cuDNN
// exception is thrown.
//
// When cudnnSetConvolutionMathType is called before cudnnGet/cudnnFind, it
// informs cudnnGet/cudnnFind to iterate/take into account both tensor core and
// non-tensor-core algos. If you don't call cudnnSetConvolutionMathType before
// calling cudnnGet/cudnnFind, cudnnGet/cudnnFind may not pick tensor core
// algos.
//
// After it runs, cudnnGet/cudnnFind comes up with the best pair of
// algo + mathType given all the initial knowledge it is given. It then becomes
// the user's responsibility to update the mathType of the convolution
// descriptor and make the subsequent cuDNN calls with the best algo and the
// updated descriptor. If we don't update the descriptor but just run with the
// best algo, cuDNN will run a slower kernel under the hood, since it sees the
// fastest algorithm combined with a suboptimal mathType.

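// Illustrative sketch of that pattern (comments only; all identifiers in it,
// such as `best` and `cdesc`, are placeholders rather than names from this
// file):
//
//   // 1. ask cudnnGet/cudnnFind for ranked (algo, mathType) candidates
//   // 2. adopt the winner's mathType on the descriptor before running it
//   AT_CUDNN_CHECK(cudnnSetConvolutionMathType(cdesc.mut_desc(), best.mathType));
//   AT_CUDNN_CHECK(cudnnConvolutionForward(
//       handle, &one, idesc.desc(), x, wdesc.desc(), w, cdesc.desc(),
//       best.algo, workspace, best.memory, &zero, odesc.desc(), y));
//
// AlgoIterator::try_all below applies this pattern at the real call sites.
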
constexpr size_t operator"" _TiB(unsigned long long n) {
  return size_t(n) * 1024 * 1024 * 1024 * 1024;
}
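// (1_TiB == 1024^4 == 2^40 bytes; the workspace-size checks below cap at 1_TiB.)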
|
|
|
|
namespace at {
|
|
namespace native {
|
|
|
|
// Convenience struct for passing around descriptors and data
|
|
// pointers
|
|
struct ConvolutionArgs {
|
|
cudnnHandle_t handle;
|
|
ConvolutionParams params;
|
|
TensorDescriptor idesc, odesc;
|
|
FilterDescriptor wdesc;
|
|
const Tensor &input, output, weight;
|
|
ConvolutionDescriptor cdesc;
|
|
|
|
ConvolutionArgs(
|
|
const Tensor& input,
|
|
const Tensor& output,
|
|
const Tensor& weight)
|
|
: input(input), output(output), weight(weight) {}
|
|
};
|
|
|
|
std::ostream& operator<<(std::ostream& out, const ConvolutionArgs& args) {
|
|
out << repro_from_args(args.params) // already has a trailing newline
|
|
<< args.params // already has a trailing newline
|
|
<< "input: " << args.idesc // already has a trailing newline
|
|
<< "output: " << args.odesc // already has a trailing newline
|
|
<< "weight: " << args.wdesc // already has a trailing newline
|
|
<< "Pointer addresses: "
|
|
<< "\n"
|
|
<< " input: " << args.input.const_data_ptr() << "\n"
|
|
<< " output: " << args.output.const_data_ptr() << "\n"
|
|
<< " weight: " << args.weight.const_data_ptr() << "\n";
|
|
|
|
return out;
|
|
}
|
|
|
|
// ---------------------------------------------------------------------
|
|
//
|
|
// Benchmarking
|
|
//
|
|
// ---------------------------------------------------------------------
|
|
|
|
// TODO: Use something less heavy duty than a big honking mutex
template <typename T>
struct BenchmarkCache {
  std::mutex mutex;
  std::unordered_map<
      ConvolutionParams,
      T,
      ParamsHash<ConvolutionParams>,
      ParamsEqual<ConvolutionParams>>
      map;

  bool find(const ConvolutionParams& params, T* results) {
    std::lock_guard<std::mutex> guard(mutex);
    auto it = map.find(params);
    if (it == map.end()) {
      return false;
    }
    *results = it->second;
    return true;
  }

  void insert(const ConvolutionParams& params, const T& results) {
    std::lock_guard<std::mutex> guard(mutex);
    map[params] = results;
  }
};
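
// Usage sketch (hypothetical; this mirrors what AlgoIterator::try_all below
// does with the caches declared next):
//
//   cudnnConvolutionFwdAlgoPerf_t perf;
//   if (!fwd_algos.find(args.params, &perf)) {
//     perf = /* run cudnnGet/cudnnFind and take the first valid result */;
//     fwd_algos.insert(args.params, perf);
//   }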
|
|
|
|
BenchmarkCache<cudnnConvolutionFwdAlgoPerf_t> fwd_algos;
|
|
BenchmarkCache<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_algos;
|
|
BenchmarkCache<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filter_algos;
|
|
|
|
// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
struct Workspace {
  Workspace(size_t size) : size(size), data(NULL) {
    // Sometimes cuDNN returns a workspace size > 2^63; this can make the
    // workspace allocation fail with a 64-bit indexing error instead of an
    // OOM error. In that case, we manually fail with OOM.
    TORCH_CHECK_WITH(
        OutOfMemoryError, size < 1_TiB, "Not enough memory for workspace!");
    data = c10::cuda::CUDACachingAllocator::raw_alloc(size);
  }
  Workspace(const Workspace&) = delete;
  Workspace(Workspace&&) = default;
  Workspace& operator=(Workspace&&) = default;
  ~Workspace() {
    if (data) {
      c10::cuda::CUDACachingAllocator::raw_delete(data);
    }
  }

  size_t size;
  void* data;
};
|
|
|
|
template <typename perf_t>
|
|
struct algorithm_search {};
|
|
|
|
cudnnStatus_t getWorkspaceSize(
|
|
const ConvolutionArgs& args,
|
|
cudnnConvolutionFwdAlgo_t algo,
|
|
size_t* sz) {
|
|
return cudnnGetConvolutionForwardWorkspaceSize(
|
|
args.handle,
|
|
args.idesc.desc(),
|
|
args.wdesc.desc(),
|
|
args.cdesc.desc(),
|
|
args.odesc.desc(),
|
|
algo,
|
|
sz);
|
|
}
|
|
cudnnStatus_t getWorkspaceSize(
|
|
const ConvolutionArgs& args,
|
|
cudnnConvolutionBwdDataAlgo_t algo,
|
|
size_t* sz) {
|
|
return cudnnGetConvolutionBackwardDataWorkspaceSize(
|
|
args.handle,
|
|
args.wdesc.desc(),
|
|
args.odesc.desc(),
|
|
args.cdesc.desc(),
|
|
args.idesc.desc(),
|
|
algo,
|
|
sz);
|
|
}
|
|
cudnnStatus_t getWorkspaceSize(
|
|
const ConvolutionArgs& args,
|
|
cudnnConvolutionBwdFilterAlgo_t algo,
|
|
size_t* sz) {
|
|
return cudnnGetConvolutionBackwardFilterWorkspaceSize(
|
|
args.handle,
|
|
args.idesc.desc(),
|
|
args.odesc.desc(),
|
|
args.cdesc.desc(),
|
|
args.wdesc.desc(),
|
|
algo,
|
|
sz);
|
|
}
|
|
|
|
template <typename algo_t>
|
|
size_t getMaxWorkspaceSize(
|
|
const ConvolutionArgs& args,
|
|
const algo_t* algo,
|
|
int n_algo) {
|
|
size_t max_ws_size = 0;
|
|
size_t max_block_size = 0;
|
|
|
|
const auto device = c10::cuda::current_device();
|
|
// For the native allocator, retrieves the size of the largest unused block.
|
|
// For cudaMallocAsync, see c10/cuda/CUDAMallocAsync.cpp:cacheInfo for
|
|
// details.
|
|
c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);
|
|
|
|
for (const auto i : c10::irange(n_algo)) {
|
|
cudnnStatus_t err;
|
|
size_t sz;
|
|
err = getWorkspaceSize(args, algo[i], &sz);
|
|
if (CUDNN_STATUS_SUCCESS != err || sz == 0 || sz < max_ws_size ||
|
|
sz > max_block_size)
|
|
continue;
|
|
max_ws_size = sz;
|
|
}
|
|
return max_ws_size;
|
|
}
|
|
|
|
template <typename perf_t>
|
|
std::vector<perf_t> getValidAlgorithms(
|
|
perf_t* perfResults,
|
|
const ConvolutionArgs& args,
|
|
int n_algo) {
|
|
std::vector<perf_t> result;
|
|
result.reserve(n_algo);
|
|
for (const auto i : c10::irange(n_algo)) {
|
|
perf_t perf = perfResults[i];
|
|
|
|
// TODO: Shouldn't all returned results be successful?
|
|
// Double check documentation for cudnnFindConvolutionForwardAlgorithmEx
|
|
if (perf.status == CUDNN_STATUS_SUCCESS) {
|
|
if (!args.params.deterministic ||
|
|
perf.determinism == CUDNN_DETERMINISTIC) {
|
|
result.push_back(perf);
|
|
}
|
|
}
|
|
}
|
|
TORCH_CHECK(
|
|
result.size() > 0, "no valid convolution algorithms available in CuDNN");
|
|
return result;
|
|
}
|
|
|
|
template <>
|
|
struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
|
|
using perf_t = cudnnConvolutionFwdAlgoPerf_t;
|
|
using algo_t = cudnnConvolutionFwdAlgo_t;
|
|
|
|
static constexpr auto DEFAULT_ALGO =
|
|
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
|
|
static BenchmarkCache<perf_t>& cache() {
|
|
return fwd_algos;
|
|
}
|
|
|
|
static std::vector<perf_t> findAlgorithms(
|
|
const ConvolutionArgs& args,
|
|
bool benchmark) {
|
|
static const algo_t algos[] = {
|
|
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
|
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
|
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
|
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
|
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
|
|
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
|
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
|
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
|
|
};
|
|
static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
|
|
static_assert(
|
|
sizeof(algos) / sizeof(algos[0]) == num_algos,
|
|
"Missing cuDNN convolution forward algorithms");
|
|
int perf_count;
|
|
c10::SmallVector<perf_t, CUDNN_CONVOLUTION_FWD_ALGO_COUNT> perf_results;
|
|
if (!benchmark) {
|
|
AT_CUDNN_CHECK_WITH_SHAPES(
|
|
cudnnGetConvolutionForwardAlgorithm_v7(
|
|
args.handle,
|
|
args.idesc.desc(),
|
|
args.wdesc.desc(),
|
|
args.cdesc.desc(),
|
|
args.odesc.desc(),
|
|
num_algos,
|
|
&perf_count,
|
|
perf_results.data()),
|
|
args);
|
|
} else {
|
|
size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
|
|
Workspace ws(max_ws_size);
|
|
at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
|
|
AT_CUDNN_CHECK_WITH_SHAPES(
|
|
cudnnFindConvolutionForwardAlgorithmEx(
|
|
args.handle,
|
|
args.idesc.desc(),
|
|
args.input.const_data_ptr(),
|
|
args.wdesc.desc(),
|
|
args.weight.const_data_ptr(),
|
|
args.cdesc.desc(),
|
|
args.odesc.desc(),
|
|
args.output.data_ptr(),
|
|
num_algos,
|
|
&perf_count,
|
|
perf_results.data(),
|
|
ws.data,
|
|
ws.size),
|
|
args);
|
|
|
|
// Free the cached blocks in our caching allocator. This is needed here
// because the above benchmarking uses a huge amount of memory, e.g. a few
// GBs.
|
|
c10::cuda::CUDACachingAllocator::emptyCache();
|
|
}
|
|
return getValidAlgorithms<perf_t>(perf_results.data(), args, perf_count);
|
|
}
|
|
|
|
static void getWorkspaceSize(
|
|
const ConvolutionArgs& args,
|
|
algo_t algo,
|
|
size_t* workspaceSize) {
|
|
AT_CUDNN_CHECK_WITH_SHAPES(
|
|
cudnnGetConvolutionForwardWorkspaceSize(
|
|
args.handle,
|
|
args.idesc.desc(),
|
|
args.wdesc.desc(),
|
|
args.cdesc.desc(),
|
|
args.odesc.desc(),
|
|
algo,
|
|
workspaceSize),
|
|
args);
|
|
}
|
|
};
|
|
|
|
template <>
|
|
struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
|
|
using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;
|
|
using algo_t = cudnnConvolutionBwdDataAlgo_t;
|
|
|
|
static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
|
static BenchmarkCache<perf_t>& cache() {
|
|
return bwd_data_algos;
|
|
}
|
|
|
|
static std::vector<perf_t> findAlgorithms(
|
|
const ConvolutionArgs& args,
|
|
bool benchmark) {
|
|
static const algo_t algos[] = {
|
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
|
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
|
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
|
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
|
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
|
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED};
|
|
static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
|
|
static_assert(
|
|
sizeof(algos) / sizeof(algos[0]) == num_algos,
|
|
"Missing cuDNN convolution backward data algorithms.");
|
|
int perf_count;
|
|
c10::SmallVector<perf_t, CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT>
|
|
perf_results;
|
|
if (!benchmark) {
|
|
AT_CUDNN_CHECK_WITH_SHAPES(
|
|
cudnnGetConvolutionBackwardDataAlgorithm_v7(
|
|
args.handle,
|
|
args.wdesc.desc(),
|
|
args.odesc.desc(),
|
|
args.cdesc.desc(),
|
|
args.idesc.desc(),
|
|
num_algos,
|
|
&perf_count,
|
|
perf_results.data()),
|
|
args);
|
|
} else {
|
|
size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
|
|
Workspace ws(max_ws_size);
|
|
at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
|
|
AT_CUDNN_CHECK_WITH_SHAPES(
|
|
cudnnFindConvolutionBackwardDataAlgorithmEx(
|
|
args.handle,
|
|
args.wdesc.desc(),
|
|
args.weight.const_data_ptr(),
|
|
args.odesc.desc(),
|
|
args.output.const_data_ptr(),
|
|
args.cdesc.desc(),
|
|
args.idesc.desc(),
|
|
args.input.data_ptr(),
|
|
num_algos,
|
|
&perf_count,
|
|
perf_results.data(),
|
|
ws.data,
|
|
ws.size),
|
|
args);
|
|
|
|
// Free the cached blocks in our caching allocator. This is needed here
// because the above benchmarking uses a huge amount of memory, e.g. a few
// GBs.
|
|
c10::cuda::CUDACachingAllocator::emptyCache();
|
|
}
|
|
return getValidAlgorithms<perf_t>(perf_results.data(), args, perf_count);
|
|
}
|
|
|
|
static void getWorkspaceSize(
|
|
const ConvolutionArgs& args,
|
|
cudnnConvolutionBwdDataAlgo_t algo,
|
|
size_t* workspaceSize) {
|
|
AT_CUDNN_CHECK_WITH_SHAPES(
|
|
cudnnGetConvolutionBackwardDataWorkspaceSize(
|
|
args.handle,
|
|
args.wdesc.desc(),
|
|
args.odesc.desc(),
|
|
args.cdesc.desc(),
|
|
args.idesc.desc(),
|
|
algo,
|
|
workspaceSize),
|
|
args);
|
|
}
|
|
};
|
|
|
|
template <>
|
|
struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
|
|
using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
|
|
using algo_t = cudnnConvolutionBwdFilterAlgo_t;
|
|
|
|
static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
|
|
|
|
static BenchmarkCache<perf_t>& cache() {
|
|
return bwd_filter_algos;
|
|
}
|
|
|
|
static std::vector<perf_t> findAlgorithms(
|
|
const ConvolutionArgs& args,
|
|
bool benchmark) {
|
|
static const algo_t algos[] = {
|
|
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
|
|
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
|
|
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
|
|
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
|
|
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,
|
|
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
|
|
};
|
|
// NOTE: - 1 because ALGO_WINOGRAD is not implemented
|
|
static constexpr int num_algos =
|
|
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1;
|
|
static_assert(
|
|
sizeof(algos) / sizeof(algos[0]) == num_algos,
|
|
"Missing cuDNN convolution backward filter algorithms.");
|
|
c10::SmallVector<perf_t, CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT>
|
|
perf_results;
|
|
int perf_count;
|
|
if (!benchmark) {
|
|
AT_CUDNN_CHECK_WITH_SHAPES(
|
|
cudnnGetConvolutionBackwardFilterAlgorithm_v7(
|
|
args.handle,
|
|
args.idesc.desc(),
|
|
args.odesc.desc(),
|
|
args.cdesc.desc(),
|
|
args.wdesc.desc(),
|
|
num_algos,
|
|
&perf_count,
|
|
perf_results.data()),
|
|
args);
|
|
} else {
|
|
size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
|
|
Workspace ws(max_ws_size);
|
|
at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
|
|
AT_CUDNN_CHECK_WITH_SHAPES(
|
|
cudnnFindConvolutionBackwardFilterAlgorithmEx(
|
|
args.handle,
|
|
args.idesc.desc(),
|
|
args.input.const_data_ptr(),
|
|
args.odesc.desc(),
|
|
args.output.const_data_ptr(),
|
|
args.cdesc.desc(),
|
|
args.wdesc.desc(),
|
|
args.weight.data_ptr(),
|
|
num_algos,
|
|
&perf_count,
|
|
perf_results.data(),
|
|
ws.data,
|
|
ws.size),
|
|
args);
|
|
|
|
// Free the cached blocks in our caching allocator. This is needed here
// because the above benchmarking uses a huge amount of memory, e.g. a few
// GBs.
|
|
c10::cuda::CUDACachingAllocator::emptyCache();
|
|
}
|
|
return getValidAlgorithms<perf_t>(perf_results.data(), args, perf_count);
|
|
}
|
|
|
|
static void getWorkspaceSize(
|
|
const ConvolutionArgs& args,
|
|
algo_t algo,
|
|
size_t* workspaceSize) {
|
|
AT_CUDNN_CHECK_WITH_SHAPES(
|
|
cudnnGetConvolutionBackwardFilterWorkspaceSize(
|
|
args.handle,
|
|
args.idesc.desc(),
|
|
args.odesc.desc(),
|
|
args.cdesc.desc(),
|
|
args.wdesc.desc(),
|
|
algo,
|
|
workspaceSize),
|
|
args);
|
|
}
|
|
};
|
|
|
|
template <typename perf_t>
|
|
class AlgoIterator {
|
|
using search = algorithm_search<perf_t>;
|
|
const ConvolutionArgs& args;
|
|
bool benchmark;
|
|
|
|
public:
|
|
AlgoIterator(const ConvolutionArgs& args, bool benchmark)
|
|
: args(args), benchmark(benchmark) {}
|
|
|
|
static std::vector<perf_t> onlyDefaultAlgorithm(const ConvolutionArgs& args) {
|
|
std::vector<perf_t> perfResults(1);
|
|
perfResults[0].algo = search::DEFAULT_ALGO;
|
|
if (args.params.dataType == CUDNN_DATA_HALF) {
|
|
perfResults[0].mathType = CUDNN_TENSOR_OP_MATH;
|
|
} else {
|
|
perfResults[0].mathType = CUDNN_DEFAULT_MATH;
|
|
if (args.params.dataType == CUDNN_DATA_FLOAT && !args.params.allow_tf32) {
|
|
perfResults[0].mathType = CUDNN_FMA_MATH;
|
|
}
|
|
}
|
|
search::getWorkspaceSize(
|
|
args, perfResults[0].algo, &(perfResults[0].memory));
|
|
return perfResults;
|
|
}
|
|
|
|
void try_all(std::function<void(const perf_t& perf)> f) {
|
|
bool only_use_default = args.params.deterministic && !benchmark;
|
|
|
|
auto& cache = search::cache();
|
|
perf_t algoPerf;
|
|
if (!only_use_default && cache.find(args.params, &algoPerf)) {
|
|
try {
|
|
f(algoPerf);
|
|
return;
|
|
} catch (c10::OutOfMemoryError&) {
|
|
std::ignore = cudaGetLastError(); // clear CUDA error
|
|
}
|
|
}
|
|
|
|
auto perfResults = only_use_default
|
|
? onlyDefaultAlgorithm(args)
|
|
: search::findAlgorithms(args, benchmark);
|
|
for (auto& algoPerf : perfResults) {
|
|
try {
|
|
f(algoPerf);
|
|
cache.insert(args.params, algoPerf);
|
|
return;
|
|
} catch (c10::OutOfMemoryError&) {
|
|
std::ignore = cudaGetLastError(); // clear CUDA error
|
|
} catch (c10::CuDNNError&) {
|
|
std::ignore = cudaGetLastError(); // clear CUDA error
|
|
}
|
|
}
|
|
TORCH_CHECK(
|
|
false, "Unable to find a valid cuDNN algorithm to run convolution");
|
|
}
|
|
};
|
|
|
|
inline Tensor allocate_workspace(size_t size, const Tensor& other) {
  // Sometimes cuDNN returns a workspace size > 2^63; this can make the
  // workspace allocation fail with a 64-bit indexing error instead of an
  // OOM error. In that case, we manually fail with OOM.
  TORCH_CHECK_WITH(
      OutOfMemoryError, size < 1_TiB, "Not enough memory for workspace!");
  return at::empty({static_cast<int64_t>(size)}, other.options().dtype(kByte));
}
|
|
|
|
// NOTE [ raw_cudnn_convolution_forward_out ]
//
// - raw_cudnn_convolution_forward_out (Tensor)
//   Function that handles tensors that are too large to use 32-bit indexing.
//   It just splits the tensor and dispatches to
//   `raw_cudnn_convolution_forward_out_32bit`.
//
// - raw_cudnn_convolution_forward_out_32bit (Tensor)
//   Low level function which invokes CuDNN, and takes an output
//   tensor which is directly written to (thus _out).
//
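// For example (hypothetical shape): a contiguous float tensor of shape
// 64 x 3 x 4096 x 4096 holds 64 * 3 * 4096 * 4096 ~= 3.2e9 elements, which is
// larger than 2^31 - 1 ~= 2.1e9, so the batch dimension has to be split before
// the 32-bit path can be used.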
|
|
|
|
// ---------------------------------------------------------------------
|
|
//
|
|
// Splitting to 32bit
|
|
//
|
|
// ---------------------------------------------------------------------
|
|
|
|
template <typename func_t>
|
|
static inline void split_batch_dim_to_32bit_out(
|
|
const at::Tensor& output,
|
|
const at::Tensor& input,
|
|
const at::Tensor& weight,
|
|
IntArrayRef padding,
|
|
IntArrayRef stride,
|
|
IntArrayRef dilation,
|
|
int64_t groups,
|
|
bool benchmark,
|
|
bool deterministic,
|
|
bool allow_tf32,
|
|
int64_t max_worksize,
|
|
func_t func_32bit) {
|
|
constexpr int64_t int_max = std::numeric_limits<int>::max();
|
|
const int64_t ni = input.numel();
|
|
const int64_t no = output.numel();
|
|
// Assume the shape of the tensor is (N, C, D1, D2, ...)
|
|
// if N * C * D1 * D2 * ... <= int_max, then no need to split at all
|
|
if (ni <= int_max && no <= int_max) {
|
|
func_32bit(
|
|
output,
|
|
input,
|
|
weight,
|
|
padding,
|
|
stride,
|
|
dilation,
|
|
groups,
|
|
benchmark,
|
|
deterministic,
|
|
allow_tf32);
|
|
return;
|
|
}
|
|
// else, if C * D1 * D2 * ... <= int_max, then we just need to split across
// the N dimension
//
// Here we use a simple heuristic to determine the size of each split.
// We don't max out the 2^31 address space because a split that large is
// very likely to hit an OOM.
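// For example, with max_worksize = 1024 * 1024 * 256 (the value used for the
// forward pass below) and an inner size C * D1 * D2 * ... of 2^20 elements
// per sample, split_size = 2^28 / 2^20 = 256 samples per chunk.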
|
|
int64_t n = output.size(0);
|
|
int64_t max_inner_size = std::max<int64_t>(ni, no) / n;
|
|
int64_t split_size = std::max<int64_t>(max_worksize / max_inner_size, 1L);
|
|
int64_t num_splits = (n + split_size - 1) / split_size;
|
|
if (split_size * max_inner_size < int_max) {
|
|
for (const auto i : c10::irange(num_splits)) {
|
|
int64_t start = split_size * i;
|
|
int64_t split_size_ = std::min<int64_t>(split_size, n - start);
|
|
Tensor input_ = input.narrow(0, start, split_size_);
|
|
Tensor output_ = output.narrow(0, start, split_size_);
|
|
func_32bit(
|
|
output_,
|
|
input_,
|
|
weight,
|
|
padding,
|
|
stride,
|
|
dilation,
|
|
groups,
|
|
benchmark,
|
|
deterministic,
|
|
allow_tf32);
|
|
}
|
|
return;
|
|
}
|
|
// If control flow reaches here, it means that even splitting N is not enough;
// things start to become complicated. For example, for conv2d, the following
// questions need to be considered:
// - Is the memory layout NCHW or NHWC?
// - If the conv is NCHW -> NC'H'W', then should we
//   - split only NC?
//   - split only N'C'?
//   - split both?
// - If the conv is NHWC, then we need to split across H, and we need to be
//   very careful about the boundary condition to make sure that the boundary
//   is handled correctly.
// - If we decide to make these splits, is the memory contiguous? Do we need
//   to copy the memory?
// Considering the complexity of this issue, it is better not to use cuDNN for
// this case.
|
|
TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN.");
|
|
}
|
|
|
|
#define ASSERT_CORRECT_PRECISION(math_type)                      \
  if (args.params.dataType == CUDNN_DATA_FLOAT) {                \
    TORCH_INTERNAL_ASSERT(                                       \
        args.params.allow_tf32 || math_type == CUDNN_FMA_MATH);  \
  }
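
// For example, for a CUDNN_DATA_FLOAT convolution requested with
// allow_tf32 == false, the algorithm search must have settled on
// CUDNN_FMA_MATH; any other mathType could silently use TF32 and violate the
// requested precision.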
|
|
|
|
// ---------------------------------------------------------------------
|
|
//
|
|
// Convolution forward / Transposed convolution backward
|
|
//
|
|
// ---------------------------------------------------------------------
|
|
|
|
void raw_cudnn_convolution_forward_out_32bit(
|
|
const Tensor& output,
|
|
const Tensor& input,
|
|
const Tensor& weight,
|
|
IntArrayRef padding,
|
|
IntArrayRef stride,
|
|
IntArrayRef dilation,
|
|
int64_t groups,
|
|
bool benchmark,
|
|
bool deterministic,
|
|
bool allow_tf32) {
|
|
auto dataType = getCudnnDataType(input);
|
|
|
|
ConvolutionArgs args{input, output, weight};
|
|
args.handle = getCudnnHandle();
|
|
at::MemoryFormat memory_format =
|
|
cudnn_conv_suggest_memory_format(input, weight);
|
|
setConvolutionParams(
|
|
&args.params,
|
|
input,
|
|
weight,
|
|
padding,
|
|
stride,
|
|
dilation,
|
|
groups,
|
|
deterministic,
|
|
allow_tf32,
|
|
memory_format);
|
|
args.idesc.set(input, memory_format);
|
|
args.wdesc.set(weight, memory_format, 0);
|
|
args.odesc.set(output, memory_format);
|
|
args.cdesc.set(
|
|
dataType,
|
|
input.dim() - 2,
|
|
args.params.padding,
|
|
args.params.stride,
|
|
args.params.dilation,
|
|
args.params.groups,
|
|
args.params.allow_tf32);
|
|
|
|
// TODO: when we do legacy group convolution support, we'll repeatedly
|
|
// reinitialize the workspace for each convolution we do. This is
|
|
// wasteful; we'd rather reuse the workspace. OTOH, legacy group
|
|
// convolution support is already pretty slow, so this might not
|
|
// matter. (This applies to raw_cudnn_convolution_backward_input as well.)
|
|
AlgoIterator<cudnnConvolutionFwdAlgoPerf_t>(args, benchmark)
|
|
.try_all([&](const cudnnConvolutionFwdAlgoPerf_t& fwdAlgPerf) {
|
|
Tensor workspace = allocate_workspace(fwdAlgPerf.memory, input);
|
|
|
|
// update convDesc mathType since cudnn 7.4+ now requires both algo +
// mathType to figure out whether to use Tensor core kernels or not.
// See Note [behavior of cudnnFind and cudnnGet]
|
|
ASSERT_CORRECT_PRECISION(fwdAlgPerf.mathType);
|
|
AT_CUDNN_CHECK_WITH_SHAPES(
|
|
cudnnSetConvolutionMathType(
|
|
args.cdesc.mut_desc(), fwdAlgPerf.mathType),
|
|
args);
|
|
|
|
Constant one(dataType, 1);
|
|
Constant zero(dataType, 0);
|
|
|
|
AT_CUDNN_CHECK_WITH_SHAPES(
|
|
cudnnConvolutionForward(
|
|
args.handle,
|
|
&one,
|
|
args.idesc.desc(),
|
|
input.const_data_ptr(),
|
|
args.wdesc.desc(),
|
|
weight.const_data_ptr(),
|
|
args.cdesc.desc(),
|
|
fwdAlgPerf.algo,
|
|
workspace.data_ptr(),
|
|
fwdAlgPerf.memory,
|
|
&zero,
|
|
args.odesc.desc(),
|
|
output.data_ptr()),
|
|
args,
|
|
"Forward algorithm: ",
|
|
static_cast<int>(fwdAlgPerf.algo),
|
|
"\n");
|
|
});
|
|
}
|
|
|
|
void raw_cudnn_convolution_forward_out_v7(
|
|
const Tensor& output,
|
|
const Tensor& input,
|
|
const Tensor& weight,
|
|
IntArrayRef padding,
|
|
IntArrayRef stride,
|
|
IntArrayRef dilation,
|
|
int64_t groups,
|
|
bool benchmark,
|
|
bool deterministic,
|
|
bool allow_tf32) {
|
|
split_batch_dim_to_32bit_out(
|
|
output,
|
|
input,
|
|
weight,
|
|
padding,
|
|
stride,
|
|
dilation,
|
|
groups,
|
|
benchmark,
|
|
deterministic,
|
|
allow_tf32,
|
|
1024 * 1024 * 256,
|
|
raw_cudnn_convolution_forward_out_32bit);
|
|
}
|
|
|
|
// ---------------------------------------------------------------------
|
|
//
|
|
// Convolution backward / Transposed convolution forward
|
|
//
|
|
// ---------------------------------------------------------------------
|
|
|
|
void raw_cudnn_convolution_backward_input_out_32bit(
|
|
const at::Tensor& grad_input,
|
|
const at::Tensor& grad_output,
|
|
const at::Tensor& weight,
|
|
IntArrayRef padding,
|
|
IntArrayRef stride,
|
|
IntArrayRef dilation,
|
|
int64_t groups,
|
|
bool benchmark,
|
|
bool deterministic,
|
|
bool allow_tf32) {
|
|
auto dataType = getCudnnDataType(grad_output);
|
|
|
|
ConvolutionArgs args{grad_input, grad_output, weight};
|
|
args.handle = getCudnnHandle();
|
|
at::MemoryFormat memory_format =
|
|
cudnn_conv_suggest_memory_format(grad_input, weight);
|
|
setConvolutionParams(
|
|
&args.params,
|
|
grad_input,
|
|
weight,
|
|
padding,
|
|
stride,
|
|
dilation,
|
|
groups,
|
|
deterministic,
|
|
allow_tf32,
|
|
memory_format);
|
|
args.idesc.set(grad_input, memory_format);
|
|
args.wdesc.set(weight, memory_format, 0);
|
|
args.odesc.set(grad_output, memory_format);
|
|
args.cdesc.set(
|
|
dataType,
|
|
grad_output.dim() - 2,
|
|
args.params.padding,
|
|
args.params.stride,
|
|
args.params.dilation,
|
|
args.params.groups,
|
|
args.params.allow_tf32);
|
|
|
|
AlgoIterator<cudnnConvolutionBwdDataAlgoPerf_t>(args, benchmark)
|
|
.try_all([&](const cudnnConvolutionBwdDataAlgoPerf_t& bwdDataAlgPerf) {
|
|
Tensor workspace =
|
|
allocate_workspace(bwdDataAlgPerf.memory, grad_output);
|
|
|
|
// update convDesc mathType since cudnn 7.4+ now requires both algo +
// mathType to figure out whether to use Tensor core kernels or not.
// See Note [behavior of cudnnFind and cudnnGet]
|
|
ASSERT_CORRECT_PRECISION(bwdDataAlgPerf.mathType);
|
|
AT_CUDNN_CHECK_WITH_SHAPES(
|
|
cudnnSetConvolutionMathType(
|
|
args.cdesc.mut_desc(), bwdDataAlgPerf.mathType),
|
|
args);
|
|
|
|
Constant one(dataType, 1);
|
|
Constant zero(dataType, 0);
|
|
|
|
AT_CUDNN_CHECK_WITH_SHAPES(
|
|
cudnnConvolutionBackwardData(
|
|
args.handle,
|
|
&one,
|
|
args.wdesc.desc(),
|
|
weight.const_data_ptr(),
|
|
args.odesc.desc(),
|
|
grad_output.const_data_ptr(),
|
|
args.cdesc.desc(),
|
|
bwdDataAlgPerf.algo,
|
|
workspace.data_ptr(),
|
|
bwdDataAlgPerf.memory,
|
|
&zero,
|
|
args.idesc.desc(),
|
|
grad_input.mutable_data_ptr()),
|
|
args,
|
|
"Additional pointer addresses: \n",
|
|
" grad_output: ",
|
|
grad_output.const_data_ptr(),
|
|
"\n",
|
|
" grad_input: ",
|
|
grad_input.mutable_data_ptr(),
|
|
"\n",
|
|
"Backward data algorithm: ",
|
|
static_cast<int>(bwdDataAlgPerf.algo),
|
|
"\n");
|
|
});
|
|
}
|
|
|
|
void raw_cudnn_convolution_backward_input_out_v7(
|
|
const at::Tensor& grad_input,
|
|
const at::Tensor& grad_output,
|
|
const at::Tensor& weight,
|
|
IntArrayRef padding,
|
|
IntArrayRef stride,
|
|
IntArrayRef dilation,
|
|
int64_t groups,
|
|
bool benchmark,
|
|
bool deterministic,
|
|
bool allow_tf32) {
|
|
split_batch_dim_to_32bit_out(
|
|
grad_input,
|
|
grad_output,
|
|
weight,
|
|
padding,
|
|
stride,
|
|
dilation,
|
|
groups,
|
|
benchmark,
|
|
deterministic,
|
|
allow_tf32,
|
|
1024 * 1024 * 128,
|
|
raw_cudnn_convolution_backward_input_out_32bit);
|
|
}
|
|
|
|
// ---------------------------------------------------------------------
|
|
//
|
|
// Convolution backward (weight)
|
|
//
|
|
// ---------------------------------------------------------------------
|
|
|
|
void raw_cudnn_convolution_backward_weight_out_32bit(
|
|
const Tensor& grad_weight,
|
|
const Tensor& grad_output,
|
|
const Tensor& input,
|
|
IntArrayRef padding,
|
|
IntArrayRef stride,
|
|
IntArrayRef dilation,
|
|
int64_t groups,
|
|
bool benchmark,
|
|
bool deterministic,
|
|
bool allow_tf32) {
|
|
auto dataType = getCudnnDataType(input);
|
|
|
|
ConvolutionArgs args{input, grad_output, grad_weight};
|
|
args.handle = getCudnnHandle();
|
|
at::MemoryFormat memory_format =
|
|
cudnn_conv_suggest_memory_format(input, grad_weight);
|
|
setConvolutionParams(
|
|
&args.params,
|
|
input,
|
|
grad_weight,
|
|
padding,
|
|
stride,
|
|
dilation,
|
|
groups,
|
|
deterministic,
|
|
allow_tf32,
|
|
memory_format);
|
|
args.idesc.set(input, memory_format);
|
|
args.wdesc.set(grad_weight, memory_format, 0);
|
|
args.odesc.set(grad_output, memory_format);
|
|
args.cdesc.set(
|
|
dataType,
|
|
input.dim() - 2,
|
|
args.params.padding,
|
|
args.params.stride,
|
|
args.params.dilation,
|
|
args.params.groups,
|
|
args.params.allow_tf32);
|
|
|
|
AlgoIterator<cudnnConvolutionBwdFilterAlgoPerf_t>(args, benchmark)
|
|
.try_all(
|
|
[&](const cudnnConvolutionBwdFilterAlgoPerf_t& bwdFilterAlgPerf) {
|
|
Tensor workspace =
|
|
allocate_workspace(bwdFilterAlgPerf.memory, input);
|
|
|
|
// update convDesc mathType since cudnn 7.4+ now requires both algo +
// mathType to figure out whether to use Tensor core kernels or not.
// See Note [behavior of cudnnFind and cudnnGet]
|
|
ASSERT_CORRECT_PRECISION(bwdFilterAlgPerf.mathType);
|
|
AT_CUDNN_CHECK_WITH_SHAPES(
|
|
cudnnSetConvolutionMathType(
|
|
args.cdesc.mut_desc(), bwdFilterAlgPerf.mathType),
|
|
args);
|
|
|
|
Constant one(dataType, 1);
|
|
Constant zero(dataType, 0);
|
|
|
|
AT_CUDNN_CHECK_WITH_SHAPES(
|
|
cudnnConvolutionBackwardFilter(
|
|
args.handle,
|
|
&one,
|
|
args.idesc.desc(),
|
|
input.const_data_ptr(),
|
|
args.odesc.desc(),
|
|
grad_output.const_data_ptr(),
|
|
args.cdesc.desc(),
|
|
bwdFilterAlgPerf.algo,
|
|
workspace.data_ptr(),
|
|
bwdFilterAlgPerf.memory,
|
|
&zero,
|
|
args.wdesc.desc(),
|
|
grad_weight.data_ptr()),
|
|
args,
|
|
"Additional pointer addresses: \n",
|
|
" grad_output: ",
|
|
grad_output.const_data_ptr(),
|
|
"\n",
|
|
" grad_weight: ",
|
|
grad_weight.data_ptr(),
|
|
"\n",
|
|
"Backward filter algorithm: ",
|
|
static_cast<int>(bwdFilterAlgPerf.algo),
|
|
"\n");
|
|
});
|
|
}
|
|
|
|
void raw_cudnn_convolution_backward_weight_out_v7(
|
|
const Tensor& grad_weight,
|
|
const Tensor& grad_output,
|
|
const Tensor& input,
|
|
IntArrayRef padding,
|
|
IntArrayRef stride,
|
|
IntArrayRef dilation,
|
|
int64_t groups,
|
|
bool benchmark,
|
|
bool deterministic,
|
|
bool allow_tf32) {
|
|
constexpr int64_t int_max = std::numeric_limits<int>::max();
|
|
const int64_t ni = input.numel();
|
|
const int64_t no = grad_output.numel();
|
|
// Assume the shape of the tensor is (N, C, D1, D2, ...)
|
|
// if N * C * D1 * D2 * ... <= int_max, then no need to split at all
|
|
if (ni <= int_max && no <= int_max) {
|
|
raw_cudnn_convolution_backward_weight_out_32bit(
|
|
grad_weight,
|
|
grad_output,
|
|
input,
|
|
padding,
|
|
stride,
|
|
dilation,
|
|
groups,
|
|
benchmark,
|
|
deterministic,
|
|
allow_tf32);
|
|
return;
|
|
}
|
|
// else, if C * D1 * D2 * ... <= int_max, then we just need to split across
// the N dimension
//
// Here we use a simple heuristic to determine the size of each split.
// We don't max out the 2^31 address space because a split that large is
// very likely to hit an OOM.
|
|
int64_t n = grad_output.size(0);
|
|
int64_t max_inner_size = std::max<int64_t>(ni, no) / n;
|
|
int64_t split_size =
|
|
std::max<int64_t>(1024 * 1024 * 512 / max_inner_size, 1L);
|
|
int64_t num_splits = (n + split_size - 1) / split_size;
|
|
if (split_size * max_inner_size < int_max) {
|
|
const auto kAccType = (grad_weight.scalar_type() == kHalf ||
|
|
grad_weight.scalar_type() == kBFloat16)
|
|
? kFloat
|
|
: grad_weight.scalar_type();
|
|
Tensor grad_weight_accumulator =
|
|
at::zeros(grad_weight.sizes(), grad_weight.options().dtype(kAccType));
|
|
for (const auto i : c10::irange(num_splits)) {
|
|
int64_t start = split_size * i;
|
|
int64_t split_size_ = std::min<int64_t>(split_size, n - start);
|
|
Tensor input_ = input.narrow(0, start, split_size_);
|
|
Tensor grad_output_ = grad_output.narrow(0, start, split_size_);
|
|
Tensor grad_weight_ = at::empty_like(grad_weight);
|
|
raw_cudnn_convolution_backward_weight_out_32bit(
|
|
grad_weight_,
|
|
grad_output_,
|
|
input_,
|
|
padding,
|
|
stride,
|
|
dilation,
|
|
groups,
|
|
benchmark,
|
|
deterministic,
|
|
allow_tf32);
|
|
grad_weight_accumulator.add_(grad_weight_);
|
|
}
|
|
grad_weight.copy_(grad_weight_accumulator);
|
|
return;
|
|
}
|
|
// If control flow reaches here, it means that even splitting N is not enough;
// things start to become complicated. For example, for conv2d, the following
// questions need to be considered:
// - Is the memory layout NCHW or NHWC?
// - If the conv is NCHW -> NC'H'W', then should we
//   - split only NC?
//   - split only N'C'?
//   - split both?
// - If the conv is NHWC, then we need to split across H, and we need to be
//   very careful about the boundary condition to make sure that the boundary
//   is handled correctly.
// - If we decide to make these splits, is the memory contiguous? Do we need
//   to copy the memory?
// Considering the complexity of this issue, it is better not to use cuDNN for
// this case.
|
|
TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN.");
|
|
}
|
|
|
|
void raw_cudnn_convolution_add_relu_out_v7(
|
|
const Tensor& output,
|
|
const Tensor& input,
|
|
const Tensor& weight,
|
|
const Tensor& z,
|
|
float alpha,
|
|
const Tensor& bias,
|
|
IntArrayRef stride,
|
|
IntArrayRef padding,
|
|
IntArrayRef dilation,
|
|
int64_t groups,
|
|
bool benchmark,
|
|
bool deterministic,
|
|
bool allow_tf32) {
|
|
auto dataType = getCudnnDataType(input);
|
|
ConvolutionArgs args{input, output, weight};
|
|
args.handle = getCudnnHandle();
|
|
at::MemoryFormat memory_format =
|
|
cudnn_conv_suggest_memory_format(input, weight);
|
|
setConvolutionParams(
|
|
&args.params,
|
|
input,
|
|
weight,
|
|
padding,
|
|
stride,
|
|
dilation,
|
|
groups,
|
|
deterministic,
|
|
allow_tf32,
|
|
memory_format);
|
|
args.idesc.set(input, memory_format);
|
|
args.wdesc.set(weight, memory_format, 0);
|
|
args.odesc.set(output, memory_format);
|
|
args.cdesc.set(
|
|
dataType,
|
|
input.dim() - 2,
|
|
args.params.padding,
|
|
args.params.stride,
|
|
args.params.dilation,
|
|
args.params.groups,
|
|
args.params.allow_tf32);
|
|
|
|
TensorDescriptor zdesc;
|
|
zdesc.set(z, memory_format);
|
|
|
|
TensorDescriptor bdesc;
|
|
bdesc.set(bias.expand({1, bias.size(0)}), memory_format, output.dim());
|
|
|
|
ActivationDescriptor adesc;
|
|
adesc.set(CUDNN_ACTIVATION_RELU);
|
|
|
|
AlgoIterator<cudnnConvolutionFwdAlgoPerf_t>(args, benchmark)
|
|
.try_all([&](const cudnnConvolutionFwdAlgoPerf_t& fwdAlgPerf) {
|
|
Tensor workspace = allocate_workspace(fwdAlgPerf.memory, input);
|
|
|
|
// update convDesc mathType since cudnn 7.4+ now requires both algo +
// mathType to figure out whether to use Tensor core kernels or not.
// See Note [behavior of cudnnFind and cudnnGet]
|
|
ASSERT_CORRECT_PRECISION(fwdAlgPerf.mathType);
|
|
AT_CUDNN_CHECK_WITH_SHAPES(
|
|
cudnnSetConvolutionMathType(
|
|
args.cdesc.mut_desc(), fwdAlgPerf.mathType),
|
|
args);
|
|
|
|
Constant one(dataType, 1);
|
|
Constant alpha_(dataType, alpha);
|
|
|
|
AT_CUDNN_CHECK_WITH_SHAPES(
|
|
cudnnConvolutionBiasActivationForward(
|
|
args.handle,
|
|
&one,
|
|
args.idesc.desc(),
|
|
input.const_data_ptr(),
|
|
args.wdesc.desc(),
|
|
weight.const_data_ptr(),
|
|
args.cdesc.desc(),
|
|
fwdAlgPerf.algo,
|
|
workspace.data_ptr(),
|
|
fwdAlgPerf.memory,
|
|
&alpha_,
|
|
zdesc.desc(),
|
|
z.const_data_ptr(),
|
|
bdesc.desc(),
|
|
bias.const_data_ptr(),
|
|
adesc.desc(),
|
|
args.odesc.desc(),
|
|
output.data_ptr()),
|
|
args,
|
|
"zdesc: ",
|
|
zdesc,
|
|
"bdesc: ",
|
|
bdesc,
|
|
"cudnnConvolutionBiasActivationForward: ",
|
|
static_cast<int>(fwdAlgPerf.algo),
|
|
"\n");
|
|
});
|
|
}
|
|
|
|
void raw_cudnn_convolution_add_relu_fallback_out(
|
|
const Tensor& output,
|
|
const Tensor& input,
|
|
const Tensor& weight,
|
|
const Tensor& z,
|
|
float alpha,
|
|
const Tensor& bias,
|
|
IntArrayRef stride,
|
|
IntArrayRef padding,
|
|
IntArrayRef dilation,
|
|
int64_t groups,
|
|
bool benchmark,
|
|
bool deterministic,
|
|
bool allow_tf32) {
|
|
// cuDNN Conv-Bias-Activation:
//   y = act ( alpha1 * conv(x) + alpha2 * z + bias )
// In the PyTorch function `raw_cudnn_convolution_add_relu_out`, alpha1 is 1
// and alpha2 is `float alpha`.
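// The fallback below therefore computes the same thing step by step:
//   output = relu(conv(input) + alpha * z + bias)
// via a plain forward convolution followed by an in-place add_ and relu_.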
|
|
|
|
raw_cudnn_convolution_forward_out(
|
|
output,
|
|
input,
|
|
weight,
|
|
padding,
|
|
stride,
|
|
dilation,
|
|
groups,
|
|
benchmark,
|
|
deterministic,
|
|
allow_tf32);
|
|
at::Tensor alpha_mul_z_add_bias =
|
|
at::native::reshape_bias(input.dim(), bias).add(z, alpha);
|
|
output.add_(alpha_mul_z_add_bias);
|
|
output.relu_();
|
|
}
|
|
|
|
} // namespace native
|
|
} // namespace at
|
|
|
|
#endif
|