Resolves ptxas warnings when compiling for CUDA_ARCH 750 and a memoryType deprecation warning (#15461)

Summary: When compiling with `TORCH_CUDA_ARCH_LIST=7.5` we were getting ptxas warnings (https://github.com/pytorch/pytorch/issues/14310) because several kernels hardcoded their `__launch_bounds__` values. The maximum number of threads per multiprocessor is 1024 on the Turing architecture (7.5) but 2048 on earlier architectures, so the hardcoded bounds effectively requested 2048 resident threads per SM when compiling for Turing, which triggered the warning. This PR adds macros that bound the values supplied to launch bounds: the maximum number of threads per block is 1024 across all architectures, and if a caller asks for more, the macro falls back to 256 (`CUDA_THREADS_PER_BLOCK_FALLBACK`); the minimum number of blocks per SM is then derived from the threads-per-block value so that their product never exceeds the per-SM limit. This should resolve https://github.com/pytorch/pytorch/issues/14310; the incorrect gradient computation reported in that issue is probably due to the faulty card.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/15461

Differential Revision: D13633952

Pulled By: soumith

fbshipit-source-id: 795aa151109f343ab5433bf3cb070cb6ec896fff
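To make the clamping concrete, here is a minimal, self-contained CUDA sketch (not part of the patch; the kernel and host code are illustrative) whose constants and macros mirror the ones this commit adds to c10/macros/Macros.h:

```cuda
#include <cstdint>
#include <cstdio>

// Mirror of the constants added to c10/macros/Macros.h in this commit.
#if __CUDA_ARCH__ >= 750
constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024;  // Turing
#else
constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048;  // earlier architectures
#endif
constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024;
constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;

// Clamp the requested threads per block; derive a blocks-per-SM value whose
// product with threads-per-block never exceeds the per-SM thread limit.
#define C10_MAX_THREADS_PER_BLOCK(val) \
  (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) : CUDA_THREADS_PER_BLOCK_FALLBACK)
#define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm)        \
  ((((threads_per_block) * (blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) \
        ? (blocks_per_sm)                                              \
        : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block) - 1) / (threads_per_block))))

// Illustrative kernel: requesting 512 threads/block and 4 blocks/SM emits
// __launch_bounds__(512, 4) for pre-Turing targets, but __launch_bounds__(512, 2)
// for sm_75 (512 * 4 = 2048 > 1024), so ptxas no longer warns.
__global__ void
__launch_bounds__(C10_MAX_THREADS_PER_BLOCK(512), C10_MIN_BLOCKS_PER_SM(512, 4))
scale_kernel(float* data, float alpha, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= alpha;
}

int main() {
  // Host pass: __CUDA_ARCH__ is undefined, so the pre-Turing limit (2048) applies here.
  std::printf("threads/block = %u, blocks/SM = %u\n",
              (unsigned)C10_MAX_THREADS_PER_BLOCK(512),
              (unsigned)C10_MIN_BLOCKS_PER_SM(512, 4));
  return 0;
}
```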
commit 86af14b0c7
parent 07ea3e035e
@@ -4,6 +4,7 @@
 #include <ATen/TensorUtils.h>
 #include <THC/THCAtomics.cuh>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/macros/Macros.h>
 
 #include <math.h>
 

@@ -198,8 +199,8 @@ inline void rearrangeDims(detail::TensorInfo<T1, IndexType>* aInfo,
 
 // Threads per block for our apply kernel
 // FIXME: use occupancy calculator instead
-#define AT_APPLY_THREADS_PER_BLOCK 32 * 16
-#define AT_APPLY_BLOCKS_PER_SM 4
+constexpr uint32_t AT_APPLY_THREADS_PER_BLOCK = 512;
+constexpr uint32_t AT_APPLY_BLOCKS_PER_SM = 4;
 
 // The `remaining_steps` argument is used to support Op that operates on
 // multiple elements at the same time. Generally, the strategy of ApplyOpN is to

@@ -272,7 +273,7 @@ template <typename Op,
 int ADims,
 int step>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
+C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void kernelPointwiseApply1(detail::TensorInfo<scalar, IndexType> a,
 IndexType totalElements, const Op op) {

@@ -356,7 +357,7 @@ template <typename Op,
 int ADims, int BDims,
 int step>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
+C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void
 kernelPointwiseApply2(detail::TensorInfo<scalar1, IndexType> a,

@@ -465,7 +466,7 @@ template <typename Op,
 int ADims, int BDims, int CDims,
 int step>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
+C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void
 kernelPointwiseApply3(detail::TensorInfo<scalar1, IndexType> a,

@@ -588,7 +589,7 @@ template <typename Op,
 int ADims, int BDims, int CDims, int DDims,
 int step>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
+C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void
 kernelPointwiseApply4(detail::TensorInfo<scalar1, IndexType> a,

@@ -3,6 +3,7 @@
 #include <ATen/cuda/CUDAApplyUtils.cuh>
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/TensorInfo.cuh>
+#include <c10/macros/Macros.h>
 #include <curand_kernel.h>
 
 #include <THC/THCGeneral.h>

@@ -34,7 +35,7 @@ template <
 typename IndexType,
 int ADims>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(256,8)
+C10_LAUNCH_BOUNDS(256, 8)
 #endif
 __global__ void
 fused_dropout_kernel(cuda::detail::TensorInfo<scalar_t, IndexType> a,

@@ -5,6 +5,7 @@
 #include <ATen/cuda/detail/TensorInfo.cuh>
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/KernelUtils.h>
+#include <c10/macros/Macros.h>
 
 namespace at { namespace native {
 

@@ -119,7 +120,7 @@ namespace {
 }
 
 template <typename scalar_t>
-__launch_bounds__(1024)
+C10_LAUNCH_BOUNDS(1024)
 __global__ void grid_sampler_2d_kernel(
 const int nthreads,
 TensorInfo<scalar_t, int> input,

@@ -227,7 +228,7 @@ namespace {
 }
 
 template <typename scalar_t>
-__launch_bounds__(1024)
+C10_LAUNCH_BOUNDS(1024)
 __global__ void grid_sampler_3d_kernel(
 const int nthreads,
 TensorInfo<scalar_t, int> input,

@@ -391,7 +392,7 @@ namespace {
 }
 
 template <typename scalar_t>
-__launch_bounds__(1024)
+C10_LAUNCH_BOUNDS(1024)
 __global__ void grid_sampler_2d_backward_kernel(
 const int nthreads,
 TensorInfo<scalar_t, int> grad_output,

@@ -546,7 +547,7 @@ namespace {
 }
 
 template <typename scalar_t>
-__launch_bounds__(1024)
+C10_LAUNCH_BOUNDS(1024)
 __global__ void grid_sampler_3d_backward_kernel(
 const int nthreads,
 TensorInfo<scalar_t, int> grad_output,

@@ -5,7 +5,7 @@
 #include <ATen/cuda/detail/OffsetCalculator.cuh>
 #include <ATen/detail/FunctionTraits.h>
 #include <ATen/native/TensorIterator.h>
-
+#include <c10/macros/Macros.h>
 
 // Marks a lambda as executable on both the host and device. The __host__
 // attribute is important so that we can access static type information from

@@ -26,7 +26,7 @@
 namespace at { namespace native {
 
 template<int nt, int vt, typename func_t>
-__launch_bounds__(nt, 4)
+C10_LAUNCH_BOUNDS(nt, 4)
 __global__ void elementwise_kernel(int N, func_t f) {
 int tid = threadIdx.x;
 int nv = nt * vt;

@@ -10,6 +10,7 @@
 
 #include <ATen/TensorUtils.h>
 #include <c10/util/Exception.h>
+#include <c10/macros/Macros.h>
 
 #include <ATen/ATen.h>
 #include <ATen/Dispatch.h>

@@ -46,7 +47,7 @@ __device__ static inline int64_t get_target_prime(const target_t* __restrict__ t
 template<typename scalar_t, typename target_t>
 __global__ void
 #if defined (__HIP_PLATFORM_HCC__)
-__launch_bounds__((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+C10_LAUNCH_BOUNDS((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
 #endif
 ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
 const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,

@@ -259,7 +260,7 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
 // alpha kernel above. (As mentioned above, it might make sense do the calculation in the alpha kernel.)
 template<typename scalar_t, typename target_t>
 __global__ void
-__launch_bounds__((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+C10_LAUNCH_BOUNDS((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
 ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
 const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
 const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,

@@ -365,7 +366,7 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
 template<typename scalar_t, typename target_t>
 __global__ void
 #if defined (__HIP_PLATFORM_HCC__)
-__launch_bounds__((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+C10_LAUNCH_BOUNDS((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
 #endif
 ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_data,
 const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,

@@ -414,7 +415,7 @@ ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_da
 template<typename scalar_t, typename target_t>
 __global__ void
 #if defined (__HIP_PLATFORM_HCC__)
-__launch_bounds__((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+C10_LAUNCH_BOUNDS((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
 #endif
 ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
 const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,

@@ -4,6 +4,7 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <c10/macros/Macros.h>
 
 namespace at { namespace native {
 

@@ -81,7 +82,7 @@ namespace kernel {
 
 template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(32 * 16, 4)
+C10_LAUNCH_BOUNDS(512, 4)
 #endif
 __global__ void lstm_cell_forward(
 TensorInfo<scalar_t, index_type> input,

@@ -168,7 +169,7 @@ __global__ void lstm_cell_forward(
 
 template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(32 * 16, 4)
+C10_LAUNCH_BOUNDS(512, 4)
 #endif
 __global__ void lstm_cell_backward(
 TensorInfo<scalar_t, index_type> storage,

@@ -233,7 +234,7 @@ __global__ void lstm_cell_backward(
 
 template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(32 * 16, 4)
+C10_LAUNCH_BOUNDS(512, 4)
 #endif
 __global__ void gru_cell_forward(
 TensorInfo<scalar_t, index_type> Input,

@@ -303,7 +304,7 @@ __global__ void gru_cell_forward(
 
 template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(32 * 16, 4)
+C10_LAUNCH_BOUNDS(512, 4)
 #endif
 __global__ void gru_cell_backward(
 TensorInfo<scalar_t, index_type> gradInInput,

@@ -10,6 +10,7 @@
 #include <THC/THCGeneral.hpp>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cuda/Loops.cuh>
+#include <c10/macros/Macros.h>
 #include <functional>
 #include <iosfwd>
 #include <tuple>

@@ -146,7 +147,7 @@ struct ReduceConfig {
 std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);
 
 template<int nt, typename R>
-__launch_bounds__(nt, 4)
+C10_LAUNCH_BOUNDS(nt, 4)
 __global__ void reduce_kernel(R reduction) {
 reduction.run();
 }

@@ -4,6 +4,7 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/macros/Macros.h>
 
 #include <cstddef>
 #include <vector>

@@ -11,12 +12,12 @@
 namespace at {
 namespace native {
 
-#define AT_APPLY_THREADS_PER_BLOCK 32 * 16
-#define AT_APPLY_BLOCKS_PER_SM 4
+constexpr uint32_t AT_APPLY_THREADS_PER_BLOCK = 512;
+constexpr uint32_t AT_APPLY_BLOCKS_PER_SM = 4;
 
 template <typename scalar_t, typename IndexType>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
+C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void
 kernel_pointwise_flip_apply2(const cuda::detail::TensorInfo<scalar_t, IndexType> in_tensor_info,

@@ -11,6 +11,7 @@
 #include <THC/THCTensorTypeUtils.cuh>
 #include <THC/THCReduceApplyUtils.cuh>
 #include <THC/THCNumerics.cuh>
+#include <c10/macros/Macros.h>
 
 // Threads per thread block
 #define THC_NONCONTIG_REDUCE_BLOCK_SIZE 32 * 16

@@ -140,7 +141,7 @@ template
 typename FinalizeOp,
 int ADims, int BDims>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(32 * 16, 4)
+C10_LAUNCH_BOUNDS(512, 4)
 #endif
 __global__ void kernelReduceNoncontigDim_shared
 (TensorInfo<T, IndexType> out,

@@ -255,7 +256,7 @@ template <typename T,
 typename FinalizeOp,
 int ADims, int BDims>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(32 * 16, 4)
+C10_LAUNCH_BOUNDS(512, 4)
 #endif
 __global__ void
 kernelReduceNoncontigDim(TensorInfo<T, IndexType> out,

@@ -10,6 +10,7 @@
 //
 
 #include <THC/THCReduceApplyUtils.cuh>
+#include <c10/macros/Macros.h>
 
 // Size per each reduction block
 #define THC_REDUCE_ALL_BLOCK_SIZE 1024L

@@ -26,7 +27,7 @@ template <typename T,
 int ADims>
 __global__ void
 #if defined(__HIP_PLATFORM_HCC__)
-__launch_bounds__(THC_REDUCE_ALL_BLOCK_SIZE)
+C10_LAUNCH_BOUNDS(THC_REDUCE_ALL_BLOCK_SIZE)
 #endif
 kernelReduceAll(TensorInfo<T, IndexType> in,
 IndexType totalElements,

@@ -4,6 +4,7 @@
 #include <THC/THCReduceApplyUtils.cuh>
 #include <THC/THCTensorTypeUtils.cuh>
 #include <THC/THCNumerics.cuh>
+#include <c10/macros/Macros.h>
 
 // Collection of kernel sort routines
 template <typename T>

@@ -134,7 +135,7 @@ __device__ inline void bitonicSortKeys(K keys[Power2SortSize],
 template <typename K, typename V,
 int KeyDims, int ValueDims,
 typename Comparator, typename IndexType, int Power2SortSize>
-__launch_bounds__(1024)
+C10_LAUNCH_BOUNDS(1024)
 __global__ void
 bitonicSortKVInPlace(TensorInfo<K, IndexType> keys,
 IndexType keySlices,

@@ -4,6 +4,7 @@
 #include <THC/THCReduceApplyUtils.cuh>
 #include <TH/THHalf.h>
 #include <THCUNN/THCHalfAutoNumerics.cuh>
+#include <c10/macros/Macros.h>
 
 #include <thrust/functional.h>
 

@@ -11,7 +12,7 @@
 
 template <typename Dtype, typename Acctype>
 #if defined(__HIP_PLATFORM_HCC__)
-__launch_bounds__(MULTILABELMARGIN_THREADS)
+C10_LAUNCH_BOUNDS(MULTILABELMARGIN_THREADS)
 #endif
 __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(Dtype *output,
 Dtype *input,

@@ -81,7 +82,7 @@ __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(Dtype *output
 
 template <typename Dtype, typename Acctype>
 #if defined(__HIP_PLATFORM_HCC__)
-__launch_bounds__(MULTILABELMARGIN_THREADS)
+C10_LAUNCH_BOUNDS(MULTILABELMARGIN_THREADS)
 #endif
 __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gradInput,
 Dtype *gradOutput,

@@ -7,6 +7,7 @@
 #include <THC/THCDeviceTensorUtils.cuh>
 #include <THC/THCDeviceUtils.cuh>
 #include <THC/THCApply.cuh>
+#include <c10/macros/Macros.h>
 
 #include <thrust/functional.h>
 

@@ -68,7 +69,7 @@ __global__ void SpatialClassNLLCriterion_updateGradInput_no_reduce_kernel(
 
 template <typename T, typename AccumT>
 #if defined(__HIP_PLATFORM_HCC__)
-__launch_bounds__(1024)
+C10_LAUNCH_BOUNDS(1024)
 #endif
 __global__ void cunn_SpatialClassNLLCriterion_updateOutput_kernel(
 T *output,

@@ -4,11 +4,12 @@
 #include <THC/THCTensor.hpp>
 #include <THC/THCStorage.hpp>
 #include <THCUNN/common.h>
+#include <c10/macros/Macros.h>
 
 template <typename Dtype, typename Acctype>
 __global__ void
 #if __CUDA_ARCH__ >= 320 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(CUDA_NUM_THREADS)
+C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS)
 #endif
 LRNFillScale(const int nthreads, const Dtype* const in,
 const int num, const int channels, const int height,

@@ -4,6 +4,7 @@
 #include <THCUNN/THCHalfAutoNumerics.cuh>
 #include <THC/THCNumerics.cuh>
 #include <THCUNN/common.h>
+#include <c10/macros/Macros.h>
 
 // kernels borrowed from Caffe
 template <typename Dtype, typename AccType>

@@ -47,7 +48,7 @@ __global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
 const int BACKWARD_THREADS = 256;
 
 template <typename Dtype, typename AccType>
-__launch_bounds__(BACKWARD_THREADS,2048/BACKWARD_THREADS)
+C10_LAUNCH_BOUNDS(BACKWARD_THREADS, 8)
 __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
 const int64_t* top_mask, const int num, const int channels,
 const int height, const int width, const int pooled_height,

@@ -3,12 +3,13 @@
 #include <THCUNN/common.h>
 #include <TH/THHalf.h>
 #include <THCUNN/THCHalfAutoNumerics.cuh>
+#include <c10/macros/Macros.h>
 
 // Kernel for fast unfold+copy
 // Borrowed from Theano
 // Authors: Arjun Jain, Frédéric Bastien, Jan Schlüter, Nicolas Ballas
 template <typename Dtype>
-__global__ void __launch_bounds__(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
+__global__ void C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
 im3d2col_kernel(const int64_t n, const Dtype* data_im,
 const int64_t height, const int64_t width, const int64_t depth,
 const int64_t kernel_h, const int64_t kernel_w, const int64_t kernel_d,

@@ -87,7 +88,7 @@ void im3d2col(cudaStream_t stream, const Dtype* data_im, const int64_t channels,
 }
 
 template <typename Dtype, typename Acctype>
-__global__ void __launch_bounds__(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
+__global__ void C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
 col2im3d_kernel(const int64_t n, const Dtype* data_col,
 const int64_t height, const int64_t width, const int64_t depth,
 const int64_t channels,

@@ -10,9 +10,10 @@
 #include <TH/THHalf.h>
 #include <THCUNN/THCHalfAutoNumerics.cuh>
 #include <THC/THCAtomics.cuh>
+#include <c10/macros/Macros.h>
 
 template<typename Dtype, typename Acctype>
-__launch_bounds__(1024)
+C10_LAUNCH_BOUNDS(1024)
 __global__ void caffe_gpu_interp2_kernel(const int n,
 const Acctype rdepth, const Acctype rheight, const Acctype rwidth, const bool align_corners,
 const THCDeviceTensor<Dtype, 5> data1, THCDeviceTensor<Dtype, 5> data2) {

@@ -80,7 +81,7 @@ __global__ void caffe_gpu_interp2_kernel(const int n,
 
 // Backward (adjoint) operation 1 <- 2 (accumulates)
 template <typename Dtype, typename Acctype>
-__launch_bounds__(1024)
+C10_LAUNCH_BOUNDS(1024)
 __global__ void caffe_gpu_interp2_kernel_backward(const int n,
 const Acctype rdepth, const Acctype rheight, const Acctype rwidth, const bool align_corners,
 THCDeviceTensor<Dtype, 5> data1, const THCDeviceTensor<Dtype, 5> data2){

@@ -3,11 +3,12 @@
 
 #include <THCUNN/common.h>
 #include <THC/THCNumerics.cuh>
+#include <c10/macros/Macros.h>
 
 // Kernel for fast unfold+copy
 // (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu)
 template <typename Dtype>
-__launch_bounds__(CUDA_NUM_THREADS)
+C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS)
 __global__ void im2col_kernel(const int64_t n, const Dtype* data_im,
 const int64_t height, const int64_t width,
 const int64_t ksize_h, const int64_t ksize_w,

@@ -59,7 +60,7 @@ void im2col(cudaStream_t stream, const Dtype* data_im, const int64_t channels,
 }
 
 template <typename Dtype, typename Acctype>
-__launch_bounds__(CUDA_NUM_THREADS)
+C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS)
 __global__ void col2im_kernel(const int64_t n, const Dtype* data_col,
 const int64_t height, const int64_t width, const int64_t channels,
 const int64_t kernel_h, const int64_t kernel_w,

@@ -109,6 +109,39 @@ namespace at { namespace cuda { using namespace c10::hip; }}
 #define C10_HOST_DEVICE __host__ __device__
 #define C10_DEVICE __device__
 #define C10_HOST __host__
+// constants from (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications)
+// The maximum number of threads per multiprocessor is 1024 for Turing architecture (7.5)
+// but 2048 for previous architectures. You'll get warnings if you exceed these constants.
+// Hence, the following macros adjust the input values from the user to resolve potential warnings.
+#if __CUDA_ARCH__ >= 750
+constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024;
+#else
+constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048;
+#endif
+// CUDA_MAX_THREADS_PER_BLOCK is same for all architectures currently
+constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024;
+// CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block size.
+// 256 is a good number for this fallback and should give good occupancy and
+// versatility across all architectures.
+constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
+// NOTE: if you are thinking of constexpr-ify the inputs to launch bounds, it
+// turns out that although __launch_bounds__ can take constexpr, it
+// can't take a constexpr that has anything to do with templates.
+// Currently we use launch_bounds that depend on template arguments in
+// Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, C10_MAX_THREADS_PER_BLOCK and
+// C10_MIN_BLOCKS_PER_SM are kept as macros.
+// Suppose you were planning to write __launch_bounds__(a, b), based on your performance tuning on a modern GPU.
+// Instead, you should write __launch_bounds__(C10_MAX_THREADS_PER_BLOCK(a), C10_MIN_BLOCKS_PER_SM(a, b)),
+// which will also properly respect limits on old architectures.
+#define C10_MAX_THREADS_PER_BLOCK(val) (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) : CUDA_THREADS_PER_BLOCK_FALLBACK)
+#define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) ((((threads_per_block)*(blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) ? (blocks_per_sm) : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block) - 1) / (threads_per_block))))
+// C10_LAUNCH_BOUNDS is analogous to __launch_bounds__
+// https://stackoverflow.com/a/8814003 snippet to have macro with an optional argument
+#define C10_LAUNCH_BOUNDS_0 __launch_bounds__(256, 4) // default launch bounds that should give good occupancy and versatility across all architectures.
+#define C10_LAUNCH_BOUNDS_1(max_threads_per_block) __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))))
+#define C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))), (C10_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm))))
+#define C10_LAUNCH_BOUNDS_X(x,max_threads_per_block,min_blocks_per_sm,FUNC, ...) FUNC
+#define C10_LAUNCH_BOUNDS(...) C10_LAUNCH_BOUNDS_X(,##__VA_ARGS__, C10_LAUNCH_BOUNDS_2(__VA_ARGS__), C10_LAUNCH_BOUNDS_1(__VA_ARGS__), C10_LAUNCH_BOUNDS_0(__VA_ARGS__))
 #else
 #define C10_HOST_DEVICE
 #define C10_HOST

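As a usage sketch, assuming the new header is on the include path (the kernel below is hypothetical, not one of the kernels touched by this commit): instead of hardcoding `__launch_bounds__(512, 4)`, a kernel is annotated with the macro and the per-architecture clamping happens automatically.

```cuda
#include <cuda_runtime.h>
#include <c10/macros/Macros.h>  // assumes the PyTorch include path

// C10_LAUNCH_BOUNDS(512, 4) expands to __launch_bounds__(512, 4) on targets
// with 2048 threads/SM and to __launch_bounds__(512, 2) on sm_75.
template <typename T>
C10_LAUNCH_BOUNDS(512, 4)
__global__ void axpy_kernel(int n, T a, const T* x, T* y) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    y[i] = a * x[i] + y[i];
  }
}

// Host-side launch; grid sizing is independent of the launch_bounds annotation.
template <typename T>
void launch_axpy(int n, T a, const T* x, T* y, cudaStream_t stream) {
  const int threads = 512;
  const int blocks = (n + threads - 1) / threads;
  axpy_kernel<T><<<blocks, threads, 0, stream>>>(n, a, x, y);
}
```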
@@ -36,7 +36,11 @@ static PyObject * THPStorage_(isPinned)(THPStorage *self)
 cudaGetLastError();
 Py_RETURN_FALSE;
 }
+#if CUDA_VERSION >= 10000
+return PyBool_FromLong(attr.type == cudaMemoryTypeHost);
+#else
 return PyBool_FromLong(attr.memoryType == cudaMemoryTypeHost);
+#endif
 #else
 Py_RETURN_FALSE;
 #endif
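For the memoryType part of the fix: `cudaPointerAttributes::memoryType` was deprecated in CUDA 10 in favor of `::type`, hence the version guard above. A standalone sketch of the same check (the helper name is illustrative; only the guarded attribute access mirrors the patch):

```cuda
#include <cuda.h>          // CUDA_VERSION, matching the guard used in the patch
#include <cuda_runtime.h>

// Returns true if `ptr` points to page-locked (pinned) host memory.
static bool is_pinned_host_ptr(const void* ptr) {
  cudaPointerAttributes attr;
  cudaError_t err = cudaPointerGetAttributes(&attr, ptr);
  if (err != cudaSuccess) {
    cudaGetLastError();  // clear the sticky error, as the original code does
    return false;
  }
#if CUDA_VERSION >= 10000
  return attr.type == cudaMemoryTypeHost;
#else
  return attr.memoryType == cudaMemoryTypeHost;
#endif
}
```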