pytorch/caffe2/operators/reduction_ops.cu
Will Constable 4f34cd6d1e Replace all CHECK_ and DCHECK_ with TORCH_* macros (#82032)
Avoid exposing defines that conflict with google logging, since this blocks external usage of libtorch in certain cases.

All the 'interesting' changes should be in these two files, and the rest should just be mechanical changes via sed.
c10/util/logging_is_not_google_glog.h
c10/util/logging_is_google_glog.h

Fixes https://github.com/pytorch/pytorch/issues/81415

cc @miladm @malfet
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82032
Approved by: https://github.com/soumith, https://github.com/miladm
2022-07-26 01:20:44 +00:00

144 lines
3.8 KiB
Plaintext

#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/reduction_ops.h"
#include "caffe2/utils/conversions.h"
#include "caffe2/utils/cub_namespace.cuh"
namespace caffe2 {
// Operator registrations for the CUDA backend.
REGISTER_CUDA_OPERATOR(SumElements, SumElementsOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(SumElementsInt, SumElementsIntOp<int, CUDAContext>);
REGISTER_CUDA_OPERATOR(SumSqrElements, SumSqrElementsOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(RowwiseMax, MaxReductionOp<float, CUDAContext, true>);
REGISTER_CUDA_OPERATOR(ColwiseMax, MaxReductionOp<float, CUDAContext, false>);
// Consistency fix: the two max-gradient registrations were missing the
// terminating semicolon that every other registration in this file uses.
REGISTER_CUDA_OPERATOR(
    RowwiseMaxGradient,
    MaxReductionGradientOp<float, CUDAContext, true>);
REGISTER_CUDA_OPERATOR(
    ColwiseMaxGradient,
    MaxReductionGradientOp<float, CUDAContext, false>);
REGISTER_CUDA_OPERATOR(
    SumElementsGradient,
    SumElementsGradientOp<float, CUDAContext>);
// Broadcasts the scalar upstream gradient *dY to all N elements of dX.
// When `average` is set (SumElements was run in mean mode), the scalar is
// divided by N first so each element receives dY/N.
template <typename T>
__global__ void
SumElementsGradientKernel(bool average, const int N, const T* dY, T* dX) {
  T fill;
  if (average) {
    fill = (*dY) / N;
  } else {
    fill = *dY;
  }
  CUDA_1D_KERNEL_LOOP(i, N) {
    dX[i] = fill;
  }
}
// Backward pass of RowwiseMax. X/dX are (batch_size, M, N); Y/dY are
// (batch_size, M), where Y[b][m] holds the max over n of X[b][m][n].
// The upstream gradient is routed to every element equal to its row max
// (ties each receive the full gradient); all other positions get zero.
__global__ void rowwise_max_gradient_kernel(
    const int batch_size,
    const int M,
    const int N,
    const float* X,
    const float* Y,
    const float* dY,
    float* dX) {
  const int input_size = M * N;
  CUDA_1D_KERNEL_LOOP(i, batch_size * M * N) {
    const int b_i = i / input_size;
    // BUG FIX: the original computed `i / input_size / N` (== b_i / N),
    // which is a batch-derived quantity, not the row index. The row index
    // m must come from the offset *within* the (M, N) slice, mirroring the
    // colwise kernel's use of `i % input_size`.
    const int b_n = (i % input_size) / N;
    const int y_index = b_i * M + b_n;
    if (X[i] == Y[y_index]) {
      dX[i] = dY[y_index];
    } else {
      dX[i] = 0.0f;
    }
  }
}
// Dispatches SumSqrElements to the implementation matching the input
// tensor's dtype; only float and at::Half are supported on CUDA.
template <>
bool SumSqrElementsOp<CUDAContext>::RunOnDevice() {
return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
}
// Backward pass of ColwiseMax. X/dX are (batch_size, M, N); Y/dY are
// (batch_size, N), where Y[b][n] holds the max over m of X[b][m][n].
// Gradient flows only to elements equal to their column max (ties each
// receive the full upstream gradient); everything else is zeroed.
__global__ void colwise_max_gradient_kernel(
    const int batch_size,
    const int M,
    const int N,
    const float* X,
    const float* Y,
    const float* dY,
    float* dX) {
  const int slice_size = M * N;
  CUDA_1D_KERNEL_LOOP(idx, batch_size * M * N) {
    const int batch = idx / slice_size;
    const int col = idx % slice_size % N;  // column n within the slice
    const int y_idx = batch * N + col;
    dX[idx] = (X[idx] == Y[y_idx]) ? dY[y_idx] : 0.0f;
  }
}
// Gradient of SumElements: broadcasts the scalar upstream gradient dY
// (divided by the element count when average_ is set) into an output
// tensor shaped like the forward input X.
template <>
bool SumElementsGradientOp<float, CUDAContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& dY = Input(1);
  // The upstream gradient of a scalar reduction must itself be a scalar.
  TORCH_DCHECK_EQ(dY.numel(), 1);
  auto* dX = Output(0, X.sizes(), at::dtype<float>());
  const int num_elems = X.numel();
  SumElementsGradientKernel<float>
      <<<CAFFE_GET_BLOCKS(num_elems),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(
          average_,
          num_elems,
          dY.data<float>(),
          dX->template mutable_data<float>());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return true;
}
// Gradient of the row-/column-wise max reduction. Expects the forward
// input X as a 3D tensor (batch, M, N), the forward output Y, and the
// upstream gradient dY; produces dX shaped like X by routing dY to the
// positions that attained the max.
template <typename T, class Context, bool ROWWISE>
bool MaxReductionGradientOp<T, Context, ROWWISE>::RunOnDevice() {
  auto& X = Input(0);
  auto& Y = Input(1);
  auto& dY = Input(2);
  auto* dX = Output(0, X.sizes(), at::dtype<T>());
  CAFFE_ENFORCE_EQ(X.dim(), 3);
  const int batch_size = X.dim32(0);
  const int M = X.dim32(1);
  const int N = X.dim32(2);
  const T* x_data = X.template data<T>();
  const T* y_data = Y.template data<T>();
  const T* dy_data = dY.template data<T>();
  T* dx_data = dX->template mutable_data<T>();
  const int total = batch_size * M * N;
  // ROWWISE is a compile-time constant, so only one branch survives.
  if (ROWWISE) {
    rowwise_max_gradient_kernel<<<
        CAFFE_GET_BLOCKS(total),
        CAFFE_CUDA_NUM_THREADS,
        0,
        context_.cuda_stream()>>>(
        batch_size, M, N, x_data, y_data, dy_data, dx_data);
  } else {
    colwise_max_gradient_kernel<<<
        CAFFE_GET_BLOCKS(total),
        CAFFE_CUDA_NUM_THREADS,
        0,
        context_.cuda_stream()>>>(
        batch_size, M, N, x_data, y_data, dy_data, dx_data);
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return true;
}
} // namespace caffe2