Summary: This PR updates PyTorch for the following cub changes:
- Starting with cub 1.13.1, cub requires users to define `CUB_NS_QUALIFIER` if `CUB_NS_PREFIX` is also defined. In addition, a new `CUB_WRAPPED_NAMESPACE` mechanism was added.

It makes the following changes to PyTorch:
- Starting with CUDA 11.5, define `CUB_WRAPPED_NAMESPACE` globally as an nvcc flag.
- Fix caffe2 failures caused by the above change.
- Add `aten/src/ATen/cuda/cub_definitions.cuh`, which defines helper macros about feature availability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66219
Reviewed By: bdhirsh
Differential Revision: D31626931
Pulled By: ngimel
fbshipit-source-id: 97ebf5ef671ade8bf46d0860edc317f22660f26d
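For context, a minimal sketch of the wrapping mechanism (simplified; the namespace name `mylib` is illustrative, not necessarily the value PyTorch's build uses):

    // Build with: nvcc -DCUB_WRAPPED_NAMESPACE=mylib ...
    // cub then wraps itself roughly as:
    //   #define CUB_NS_PREFIX    namespace mylib {
    //   #define CUB_NS_POSTFIX   }
    //   #define CUB_NS_QUALIFIER ::mylib::cub
    // so every cub symbol lives under mylib::cub and callers can write:
    #include <cub/block/block_reduce.cuh>
    using MyBlockReduce = ::mylib::cub::BlockReduce<float, 128>;
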
#include "caffe2/operators/arg_ops.h"
|
|
|
|
#include <limits>
|
|
|
|
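// cub_namespace.cuh is included ahead of the cub headers; when
// CUB_WRAPPED_NAMESPACE is defined (CUDA >= 11.5), it makes the wrapped cub::
// symbols visible to the caffe2 code below (see that header for the exact
// mechanism).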
#include "caffe2/utils/cub_namespace.cuh"
|
|
#include <cub/block/block_reduce.cuh>
|
|
|
|
#include "caffe2/core/common_gpu.h"
|
|
#include "caffe2/core/context_gpu.h"
|
|
#include "caffe2/utils/fixed_divisor.h"
|
|
|
|
namespace caffe2 {

namespace {

template <typename K, typename V>
using KeyValuePair = cub::KeyValuePair<K, V>;

template <typename K, typename V>
using BlockReduce =
    cub::BlockReduce<KeyValuePair<K, V>, CAFFE_CUDA_NUM_THREADS>;

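// Reduces X, viewed as (outer_size / d, inner_size, d) with d = stride.d(),
// along the inner_size axis. Blocks walk the outer_size slots in a
// grid-stride loop; stride.DivMod() decodes each slot into the (i, j) outer
// indices, threads stride over the reduction axis accumulating (index, value)
// pairs, a block-wide cub reduction picks the winner, and thread 0 writes the
// winning index to Y.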
template <typename T, class Reducer>
__global__ void ComputeArgCUDAKernel(
    const int outer_size,
    const int inner_size,
    const FixedDivisor<int> stride,
    const Reducer reducer,
    const T init,
    const T* X,
    int64_t* Y) {
  __shared__ typename BlockReduce<int, T>::TempStorage temp_storage;
  const int d = stride.d();
  for (int idx = blockIdx.x; idx < outer_size; idx += gridDim.x) {
    int i;
    int j;
    stride.DivMod(idx, &i, &j);
    KeyValuePair<int, T> kv = {-1, init};
    for (int k = threadIdx.x; k < inner_size; k += blockDim.x) {
      kv = reducer({k, X[i * inner_size * d + k * d + j]}, kv);
    }
    kv = BlockReduce<int, T>(temp_storage).Reduce(kv, reducer);
    if (threadIdx.x == 0) {
      Y[idx] = static_cast<int64_t>(kv.key);
    }
    __syncthreads();
  }
}

} // namespace

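// ArgMax: launch the kernel with cub::ArgMax and the lowest representable T
// as the initial value; the grid size is capped at CAFFE_MAXIMUM_NUM_BLOCKS.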
template <>
template <typename T>
bool ArgMaxReducer<CUDAContext>::operator()(
    const int prev_size,
    const int next_size,
    const int n,
    const T* X,
    int64_t* Y,
    CUDAContext* context) const {
  const int outer_size = prev_size * next_size;
  const FixedDivisor<int> stride(next_size);
  ComputeArgCUDAKernel<<<
      std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS),
      CAFFE_CUDA_NUM_THREADS,
      0,
      context->cuda_stream()>>>(
      outer_size,
      n,
      stride,
      cub::ArgMax(),
      std::numeric_limits<T>::lowest(),
      X,
      Y);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return true;
}

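// ArgMin: identical launch, but with cub::ArgMin and the largest representable
// T as the initial value.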
template <>
template <typename T>
bool ArgMinReducer<CUDAContext>::operator()(
    const int prev_size,
    const int next_size,
    const int n,
    const T* X,
    int64_t* Y,
    CUDAContext* context) const {
  const int outer_size = prev_size * next_size;
  const FixedDivisor<int> stride(next_size);
  ComputeArgCUDAKernel<<<
      std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS),
      CAFFE_CUDA_NUM_THREADS,
      0,
      context->cuda_stream()>>>(
      outer_size,
      n,
      stride,
      cub::ArgMin(),
      std::numeric_limits<T>::max(),
      X,
      Y);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return true;
}

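// Register the CUDA implementations for the ArgMax / ArgMin operator schemas.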
REGISTER_CUDA_OPERATOR(ArgMax, ArgOp<CUDAContext, ArgMaxReducer<CUDAContext>>);
REGISTER_CUDA_OPERATOR(ArgMin, ArgOp<CUDAContext, ArgMinReducer<CUDAContext>>);

} // namespace caffe2