Optimize reduce ops for 2d and 3d (#9992)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9992

Optimize reduce ops for 2d and 3d

Reviewed By: houseroad

Differential Revision: D9042505

fbshipit-source-id: 62af2125aa6439106293e59bdf6a2b920792fd2d
Xiaomeng Yang 2018-08-04 13:46:48 -07:00 committed by Facebook Github Bot
parent 29406a2c4c
commit 57d2d4bcff
34 changed files with 1643 additions and 701 deletions
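At a glance, the diff migrates math::Scale from Scale<T, Context> to Scale<TAlpha, TData, Context>, so the scalar type can differ from the data type, and it threads a new alpha argument through the Reduce* helpers so results are scaled in the same pass as the reduction. A minimal sketch of what call sites look like after the change, based on the declarations in caffe2/utils/math.h further down and assuming the float CPU specializations from the suppressed math_cpu.cc diff (Example, x, y, and ctx are illustrative names, not part of this commit):

#include "caffe2/core/context.h"
#include "caffe2/utils/math.h"

void Example(const float* x, float* y, caffe2::CPUContext* ctx) {
  // y[i] = 0.5f * x[i]; the alpha type is now independent of the data type.
  caffe2::math::Scale<float, float, caffe2::CPUContext>(8, 0.5f, x, y, ctx);

  // Reduce a 2x4 tensor over axis 1, applying alpha in the same kernel
  // (Y = alpha * sum(X) along the reduced axes).
  const int dims[2] = {2, 4};
  const int axes[1] = {1};
  caffe2::math::ReduceSum<float, caffe2::CPUContext>(
      2, dims, 1, axes, /* alpha */ 1.0f, x, y, ctx);
}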

View File

@ -181,8 +181,11 @@ namespace caffe2 {
T* comp_data = Comp_rate->template mutable_data<T>();
math::Sum<T, Context>(
Mask.size(), Mask.template data<T>(), comp_data, &context_);
math::Scale<T, Context>(
1, static_cast<T>(1.) / Mask.size(), comp_data, comp_data,
math::Scale<float, T, Context>(
1,
static_cast<T>(1.) / Mask.size(),
comp_data,
comp_data,
&context_);
}
return true;
@ -263,8 +266,11 @@ namespace caffe2 {
T* comp_data = comp_r_buf_.template mutable_data<T>();
math::Sum<T, Context>(
Mask.size(), Mask.template data<T>(), comp_data, &context_);
math::Scale<T, Context>(
1, static_cast<T>(1.) / Mask.size(), comp_data, comp_data,
math::Scale<float, T, Context>(
1,
static_cast<T>(1.) / Mask.size(),
comp_data,
comp_data,
&context_);
// update W size window
// Notice here we need to maintain state in OP.

View File

@ -29,8 +29,8 @@ bool BatchMomentsOp<float, CPUContext>::ComputeBatchMomentsNCHW(
X_ptr += stride;
}
const float scale = 1.0f / static_cast<float>(N * HxW);
math::Scale<float, CPUContext>(C, scale, mu, mu, &context_);
math::Scale<float, CPUContext>(C, scale, var, var, &context_);
math::Scale<float, float, CPUContext>(C, scale, mu, mu, &context_);
math::Scale<float, float, CPUContext>(C, scale, var, var, &context_);
return true;
}
@ -71,7 +71,7 @@ bool BatchMomentsGradientOp<float, CPUContext>::ComputeBatchMomentsGradientNCHW(
dX_ptr += stride;
}
const float scale = 1.0f / static_cast<float>(N * HxW);
math::Scale<float, CPUContext>(N * C * HxW, scale, dX, dX, &context_);
math::Scale<float, float, CPUContext>(N * C * HxW, scale, dX, dX, &context_);
return true;
}
@ -89,7 +89,7 @@ bool BatchMomentsGradientOp<float, CPUContext>::ComputeBatchMomentsGradientNHWC(
dX_arr = ConstEigenArrayMap<float>(X, C, N * HxW).colwise() *
ConstEigenVectorArrayMap<float>(dvar, C) * 2.0f;
dX_arr.colwise() += ConstEigenVectorArrayMap<float>(dmu, C);
math::Scale<float, CPUContext>(N * C * HxW, scale, dX, dX, &context_);
math::Scale<float, float, CPUContext>(N * C * HxW, scale, dX, dX, &context_);
return true;
}

View File

@ -177,7 +177,7 @@ bool CosineSimilarityGradientOp<float, CPUContext>::RunOnDevice() {
math::Dot<float, CPUContext>(
D, X_data + offset, Y_data + offset, &XY, &context_);
math::Scale<float, CPUContext>(
math::Scale<float, float, CPUContext>(
D, dCos_data[i] / XYN, Y_data + offset, dX_data + offset, &context_);
math::Axpy(
D,
@ -186,7 +186,7 @@ bool CosineSimilarityGradientOp<float, CPUContext>::RunOnDevice() {
dX_data + offset,
&context_);
math::Scale<float, CPUContext>(
math::Scale<float, float, CPUContext>(
D, dCos_data[i] / XYN, X_data + offset, dY_data + offset, &context_);
math::Axpy(
D,
@ -282,9 +282,9 @@ bool DotProductGradientOp<float, CPUContext>::RunOnDevice() {
auto* dY_data = dY->template mutable_data<float>();
for (int i = 0; i < N; ++i) { // TODO: multithreading
auto offset = i * D;
math::Scale<float, CPUContext>(
math::Scale<float, float, CPUContext>(
D, dDot_data[i], X_data + offset, dY_data + offset, &context_);
math::Scale<float, CPUContext>(
math::Scale<float, float, CPUContext>(
D, dDot_data[i], Y_data + offset, dX_data + offset, &context_);
}
return true;

View File

@ -116,7 +116,7 @@ bool SquaredL2DistanceGradientOp<float, CUDAContext>::RunOnDevice() {
dX->template mutable_data<float>());
// The gradient of the other side is basically the negative.
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
X.size(),
-1,
dX->data<float>(),

View File

@ -50,7 +50,7 @@ class SquaredL2DistanceGradientOp final : public Operator<Context> {
dX->template mutable_data<T>(),
&context_);
for (int i = 0; i < N; ++i) {
math::Scale<T, Context>(
math::Scale<T, T, Context>(
D,
dDistance.template data<T>() + i,
dX->template data<T>() + i * D,
@ -58,7 +58,7 @@ class SquaredL2DistanceGradientOp final : public Operator<Context> {
&context_);
}
// The gradient of the other side is basically the negative.
math::Scale<T, Context>(
math::Scale<T, T, Context>(
X.size(),
-1,
dX->template data<T>(),
@ -245,16 +245,16 @@ class DotProductWithPaddingGradientOp final : public Operator<Context> {
std::vector<T> tmp_data(DS);
math::Set<T, Context>(DS, 0.0, dS_data, &context_);
for (int j = 0; j < DL / DS; j++) {
math::Scale<T, Context>(
math::Scale<T, T, Context>(
DS, dDot_data[i], S_data, dL_data + j * DS, &context_);
math::Scale<T, Context>(
math::Scale<T, T, Context>(
DS, dDot_data[i], L_data + j * DS, tmp_data.data(), &context_);
math::Axpy<T, Context>(DS, 1.0, tmp_data.data(), dS_data, &context_);
}
} else {
math::Scale<T, Context>(
math::Scale<T, T, Context>(
D, dDot_data[i], X_data + offsetX, dY_data + offsetY, &context_);
math::Scale<T, Context>(
math::Scale<T, T, Context>(
D, dDot_data[i], Y_data + offsetY, dX_data + offsetX, &context_);
}

View File

@ -56,6 +56,7 @@ struct AddFunctor {
C_dims.data(),
A_axes.size(),
A_axes.data(),
TGrad(1),
dC,
dA,
context);
@ -64,6 +65,7 @@ struct AddFunctor {
C_dims.data(),
B_axes.size(),
B_axes.data(),
TGrad(1),
dC,
dB,
context);

View File

@ -56,6 +56,7 @@ struct SubFunctor {
C_dims.data(),
A_axes.size(),
A_axes.data(),
TGrad(1),
dC,
dA,
context);
@ -64,12 +65,10 @@ struct SubFunctor {
C_dims.data(),
B_axes.size(),
B_axes.data(),
TGrad(-1),
dC,
dB,
context);
const int size = std::accumulate(
B_dims.cbegin(), B_dims.cend(), 1, std::multiplies<int>());
math::Neg(size, dB, dB, context);
return true;
}
};
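Note the optimization in the SubFunctor gradient above: the gradient with respect to B is no longer negated in a separate math::Neg pass; the sign is folded into the new alpha argument as TGrad(-1). A minimal sketch of the dB path, assuming ReduceSum computes Y = alpha * sum(X) over the reduced axes (SubGradB and ctx are illustrative names):

#include <vector>
#include "caffe2/utils/math.h"

// dB = -sum(dC) over the axes along which B was broadcast, in a single pass.
template <typename TGrad, class Context>
void SubGradB(const std::vector<int>& C_dims,
              const std::vector<int>& B_axes,
              const TGrad* dC,
              TGrad* dB,
              Context* ctx) {
  caffe2::math::ReduceSum<TGrad, Context>(
      static_cast<int>(C_dims.size()),
      C_dims.data(),
      static_cast<int>(B_axes.size()),
      B_axes.data(),
      TGrad(-1),  // alpha carries the negation
      dC,
      dB,
      ctx);
}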

View File

@ -96,6 +96,7 @@ class ExpandGradientOp final : public Operator<Context> {
dY_dims.data(),
axes.size(),
axes.data(),
T(1),
dY.template data<T>(),
dX->template mutable_data<T>(),
&context_);

View File

@ -23,7 +23,7 @@ void averaged_loss_op_cpu_impl(
caffe2::math::Sum<T, Context>(
X.size(), X.template data<T>(), data, static_cast<Context*>(context), &state->scratch);
if (X.size() > 0) {
caffe2::math::Scale<T, Context>(
caffe2::math::Scale<T, T, Context>(
1,
static_cast<T>(1.) / X.size(),
sum->template data<T>(),

View File

@ -129,7 +129,7 @@ bool LayerNormOp<CUDAContext>::DoRunWithType<float>() {
context_.cuda_stream());
// Second stage: Normalize by feature vector dim
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
left,
1.0f / right,
mean->mutable_data<float>(),
@ -151,7 +151,7 @@ bool LayerNormOp<CUDAContext>::DoRunWithType<float>() {
context_.cuda_stream());
// Second stage: Normalize by feature vector dim
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
left,
1.0f / right,
stdev->mutable_data<float>(),

View File

@ -137,6 +137,7 @@ struct MinReducer {
dims.data(),
axes.size(),
axes.data(),
T(1),
X_data,
Y_data,
context);
@ -168,6 +169,7 @@ struct MaxReducer {
dims.data(),
axes.size(),
axes.data(),
T(1),
X_data,
Y_data,
context);
@ -199,6 +201,7 @@ struct SumReducer {
dims.data(),
axes.size(),
axes.data(),
T(1),
X_data,
Y_data,
context);
@ -240,6 +243,7 @@ struct MeanReducer {
dims.data(),
axes.size(),
axes.data(),
T(1),
X_data,
Y_data,
context);
@ -267,7 +271,7 @@ struct MeanReducer {
dY_dims.cbegin(), dY_dims.cend(), 1, std::multiplies<int>());
const int dX_size = std::accumulate(
dX_dims.cbegin(), dX_dims.cend(), 1, std::multiplies<int>());
math::Scale<T, Context>(
math::Scale<T, T, Context>(
dX_size,
static_cast<float>(dY_size) / static_cast<float>(dX_size),
dX_data,
@ -291,6 +295,7 @@ struct L1Reducer {
dims.data(),
axes.size(),
axes.data(),
T(1),
X_data,
Y_data,
context);
@ -322,6 +327,7 @@ struct L2Reducer {
dims.data(),
axes.size(),
axes.data(),
T(1),
X_data,
Y_data,
context);

View File

@ -31,7 +31,7 @@ class SumElementsOp : public Operator<Context> {
math::Sum<T, Context>(
X.size(), X.template data<T>(), data, &context_, &scratch_);
if (average_ && X.size() > 0) {
math::Scale<T, Context>(
math::Scale<float, T, Context>(
1,
static_cast<T>(1.) / X.size(),
sum->template data<T>(),
@ -113,7 +113,7 @@ class SumSqrElementsOp : public Operator<Context> {
&context_,
&scratch_);
if (average && X.size() > 0) {
math::Scale<T, Context>(
math::Scale<float, T, Context>(
1,
float(1.) / X.size(),
sum->template data<T>(),

View File

@ -20,7 +20,7 @@ class ScaleOp final : public Operator<Context> {
auto& X = Input(0);
auto* Y = Output(0);
Y->ResizeLike(X);
math::Scale<T, Context>(
math::Scale<float, T, Context>(
X.size(),
scale_,
X.template data<T>(),

View File

@ -380,7 +380,7 @@ bool SoftmaxWithLossOp<float, CUDAContext>::RunOnDevice() {
losses_.size(), losses_.data<float>(), avg_loss_data, &context_, &scratch_);
// Average of input batch size
if (total_weight > 0) {
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
1, scale_ / total_weight, avg_loss_data, avg_loss_data, &context_);
}
@ -466,7 +466,7 @@ bool SpatialSoftmaxWithLossOp<float, CUDAContext>::RunOnDevice() {
// Final scaling
if (h_total_weight > 0) {
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
1, scale_ / h_total_weight, avg_loss_data, avg_loss_data, &context_);
}
@ -571,14 +571,14 @@ bool SoftmaxWithLossGradientOp<float, CUDAContext>::RunOnDevice() {
// Scale by d_avg_loss / N
if (total_weight > 0) {
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
dX->size(),
scale_ / total_weight,
dX->data<float>(),
dX->template mutable_data<float>(),
&context_);
}
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
dX->size(),
d_avg_loss.data<float>(),
dX->data<float>(),
@ -661,14 +661,14 @@ bool SpatialSoftmaxWithLossGradientOp<float, CUDAContext>::RunOnDevice() {
// Final scaling
if (h_total_weight > 0) {
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
dX->size(),
scale_ / h_total_weight,
dX->data<float>(),
dX->template mutable_data<float>(),
&context_);
}
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
dX->size(),
d_avg_loss.data<float>(),
dX->data<float>(),

View File

@ -348,7 +348,7 @@ bool SoftmaxWithLossGradientOp<float, CPUContext>::RunOnDevice() {
// Scale by d_avg_loss / N
if (total_weight > 0) {
math::Scale<float, CPUContext>(
math::Scale<float, float, CPUContext>(
dX->size(),
scale_ / total_weight * d_avg_loss.data<float>()[0],
dX->data<float>(),

View File

@ -217,14 +217,14 @@ bool SpatialSoftmaxWithLossGradientOp<float, CPUContext>::RunOnDevice() {
}
if (total_weight > 0) {
math::Scale<float, CPUContext>(
math::Scale<float, float, CPUContext>(
dX->size(),
scale_ / total_weight,
dX->data<float>(),
dX_data,
&context_);
}
math::Scale<float, CPUContext>(
math::Scale<float, float, CPUContext>(
dX->size(),
d_avg_loss.data<float>(),
dX->data<float>(),

View File

@ -43,7 +43,7 @@ class SquareRootDivideOp final : public Operator<Context> {
auto scale = scalePtr[i];
CAFFE_ENFORCE(scale >= 0, scale, " < 0");
auto multiplier = scale == 0 ? 1.0 : 1 / std::sqrt(scale);
math::Scale<TData, Context>(
math::Scale<float, TData, Context>(
exampleSize,
multiplier,
dataPtr + i * exampleSize,

View File

@ -1,7 +1,9 @@
#ifndef CAFFE2_OPERATORS_UTILITY_OPS_H_
#define CAFFE2_OPERATORS_UTILITY_OPS_H_
#include <math.h>
#include <cmath>
#include <map>
#include <utility>
#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
@ -9,11 +11,9 @@
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"
#include "caffe2/operators/gather_op.h"
#include "caffe2/utils/conversions.h"
#include "caffe2/utils/math.h"
#include <map>
#include <utility>
namespace caffe2 {
template <class Context>
@ -331,7 +331,7 @@ class WeightedSumOp : public Operator<Context> {
int size = X0.size();
auto* output = Output(0);
output->ResizeLike(X0);
math::Scale<DstType, Context>(
math::Scale<float, DstType, Context>(
size,
weight0.template data<float>(),
X0.template data<DstType>(),
@ -388,7 +388,7 @@ class WeightedSumGradientOp : public Operator<Context> {
auto* cur_dX = Output(i);
cur_dX->ResizeLike(dY);
math::Scale<DstType, Context>(
math::Scale<float, DstType, Context>(
size,
cur_w.template data<float>(),
dY_data,
@ -521,8 +521,8 @@ class ScatterWeightedSumOp : public Operator<Context> {
for (int i = 0; i < K; ++i) {
Index idx = idxs[i];
// double-checking the indices, but it's fine as it's DCHECK only
DCHECK(0 <= idx && idx < N) << "Index out of bounds: " << idx
<< ", range 0 to " << N;
DCHECK(0 <= idx && idx < N)
<< "Index out of bounds: " << idx << ", range 0 to " << N;
math::AxpyFixedSize<T, Context, FixedSize>(
block_size,
w,
@ -662,8 +662,8 @@ class ScatterAssignOp : public Operator<Context> {
for (int i = 0; i < K; ++i) {
Index idx = idxs[i];
// double-checking the indices, but it's fine as it's DCHECK only
DCHECK(0 <= idx && idx < N) << "Index out of bounds: " << idx
<< ", range 0 to " << N;
DCHECK(0 <= idx && idx < N)
<< "Index out of bounds: " << idx << ", range 0 to " << N;
context_.template CopySameDevice<T>(
block_size, slicesData + block_size * i, data + block_size * idx);
}

View File

@ -69,7 +69,7 @@ static void EmbeddingLookupGenericSlow(
}
if (normalize_by_lengths && lengths[m]) {
// hack: context is not really used
math::Scale<OutType, CPUContext>(
math::Scale<float, OutType, CPUContext>(
block_size, 1.f / lengths[m], out, out, nullptr);
}
out += block_size;

View File

@ -72,7 +72,7 @@ static void Fused8BitRowwiseEmbeddingLookupGenericSlow(
}
if (normalize_by_lengths && lengths[m]) {
// hack: context is not really used
math::Scale<OutType, CPUContext>(
math::Scale<float, OutType, CPUContext>(
block_size, 1.f / lengths[m], out, out, nullptr);
}
out += block_size;

View File

@ -43,7 +43,7 @@ class ClipTensorByScalingOp final : public Operator<Context> {
if (*val_data > threshold_) {
float ratio = threshold_ / *val_data;
math::Scale<float, Context>(
math::Scale<float, float, Context>(
clipped->size(),
ratio,
input_tensor_data,

View File

@ -1728,32 +1728,63 @@ void Select<float16, HIPContext>(
}
namespace {
template <typename T>
__global__ void ScaleKernel(const int n, const float alpha, const T* x, T* y) {
HIP_1D_KERNEL_LOOP(i, n) {
// y[i] = convert::To<float,T>(convert::To<T, float>(x[i]) * alpha);
y[i] = convert::Get<T>(convert::Get<float>(x[i]) * alpha);
}
}
template <typename T>
template <typename TAlpha, typename TData>
__global__ void
ScaleKernelDeviceAlpha(const int n, const float* alpha, const T* x, T* y) {
ScaleKernel(const int n, const TAlpha alpha, const TData* x, TData* y) {
HIP_1D_KERNEL_LOOP(i, n) {
y[i] = x[i] * (*alpha);
y[i] = x[i] * static_cast<TData>(alpha);
}
}
template <typename T>
__global__ void PowKernel(const int n, const T* x, const T exponent, T* y) {
template <typename TAlpha, typename TData>
__global__ void
ScaleKernel(const int n, const TAlpha* alpha, const TData* x, TData* y) {
HIP_1D_KERNEL_LOOP(i, n) {
y[i] = powf(x[i], exponent);
y[i] = x[i] * static_cast<TData>(*alpha);
}
}
template <>
__global__ void ScaleKernel<float16, float16>(
const int n,
const float16 alpha,
const float16* x,
float16* y) {
HIP_1D_KERNEL_LOOP(i, n) {
y[i] = convert::To<float, float16>(
convert::To<float16, float>(x[i]) * convert::To<float16, float>(alpha));
}
}
template <>
__global__ void ScaleKernel<float16, float16>(
const int n,
const float16* alpha,
const float16* x,
float16* y) {
HIP_1D_KERNEL_LOOP(i, n) {
y[i] = convert::To<float, float16>(
convert::To<float16, float>(x[i]) *
convert::To<float16, float>(*alpha));
}
}
// fp16 specialization
template <>
__global__ void ScaleKernelDeviceAlpha(
__global__ void ScaleKernel<float, float16>(
const int n,
const float alpha,
const float16* x,
float16* y) {
HIP_1D_KERNEL_LOOP(i, n) {
y[i] =
convert::To<float, float16>(convert::To<float16, float>(x[i]) * alpha);
}
}
template <>
__global__ void ScaleKernel<float, float16>(
const int n,
const float* alpha,
const float16* x,
@ -1764,6 +1795,13 @@ __global__ void ScaleKernelDeviceAlpha(
}
}
template <typename T>
__global__ void PowKernel(const int n, const T* x, const T exponent, T* y) {
HIP_1D_KERNEL_LOOP(i, n) {
y[i] = powf(x[i], exponent);
}
}
} // namespace
template <>
@ -1785,81 +1823,50 @@ void Powx<float, HIPContext>(
y);
}
template <>
void Scale<float, HIPContext>(
const int n,
const float alpha,
const float* x,
float* y,
HIPContext* context) {
hipLaunchKernelGGL(
(ScaleKernel<float>),
dim3(CAFFE_GET_BLOCKS(n)),
dim3(CAFFE_HIP_NUM_THREADS),
0,
context->hip_stream(),
n,
alpha,
x,
y);
}
template <>
void Scale<float16, HIPContext>(
const int n,
const float alpha,
const float16* x,
float16* y,
HIPContext* context) {
hipLaunchKernelGGL(
(ScaleKernel<float16>),
dim3(CAFFE_GET_BLOCKS(n)),
dim3(CAFFE_HIP_NUM_THREADS),
0,
context->hip_stream(),
n,
alpha,
x,
y);
}
template <>
void Scale<float, HIPContext>(
const int n,
const float* alpha,
const float* x,
float* y,
HIPContext* context) {
hipLaunchKernelGGL(
(ScaleKernelDeviceAlpha<float>),
dim3(CAFFE_GET_BLOCKS(n)),
dim3(CAFFE_HIP_NUM_THREADS),
0,
context->hip_stream(),
n,
alpha,
x,
y);
}
template <>
void Scale<float16, HIPContext>(
const int n,
const float* alpha,
const float16* x,
float16* y,
HIPContext* context) {
hipLaunchKernelGGL(
(ScaleKernelDeviceAlpha<float16>),
dim3(CAFFE_GET_BLOCKS(n)),
dim3(CAFFE_HIP_NUM_THREADS),
0,
context->hip_stream(),
n,
alpha,
x,
y);
#define CAFFE2_SPECIALIZED_HIP_SCALE(TAlpha, TData) \
template <> \
void Scale<TAlpha, TData, HIPContext>( \
const int n, \
const TAlpha alpha, \
const TData* x, \
TData* y, \
HIPContext* context) { \
hipLaunchKernelGGL( \
(ScaleKernel<TAlpha, TData>), \
dim3(CAFFE_GET_BLOCKS(n)), \
dim3(CAFFE_HIP_NUM_THREADS), \
0, \
context->hip_stream(), \
n, \
alpha, \
x, \
y); \
} \
template <> \
void Scale<TAlpha, TData, HIPContext>( \
const int n, \
const TAlpha* alpha, \
const TData* x, \
TData* y, \
HIPContext* context) { \
hipLaunchKernelGGL( \
(ScaleKernel<TAlpha, TData>), \
dim3(CAFFE_GET_BLOCKS(n)), \
dim3(CAFFE_HIP_NUM_THREADS), \
0, \
context->hip_stream(), \
n, \
alpha, \
x, \
y); \
}
CAFFE2_SPECIALIZED_HIP_SCALE(float, float)
CAFFE2_SPECIALIZED_HIP_SCALE(float16, float16)
CAFFE2_SPECIALIZED_HIP_SCALE(float, float16)
CAFFE2_SPECIALIZED_HIP_SCALE(double, double)
CAFFE2_SPECIALIZED_HIP_SCALE(std::int32_t, std::int32_t)
CAFFE2_SPECIALIZED_HIP_SCALE(std::int64_t, std::int64_t)
#undef CAFFE2_SPECIALIZED_HIP_SCALE
template <>
void Axpy<float, HIPContext>(
@ -2622,6 +2629,7 @@ __global__ void RowwiseReduceKernel(
const int cols,
const Reducer reducer,
const T init,
const T alpha,
const T* X,
T* Y) {
__shared__ typename BlockReduce<T>::TempStorage temp_storage;
@ -2632,7 +2640,7 @@ __global__ void RowwiseReduceKernel(
}
val = BlockReduce<T>(temp_storage).Reduce(val, reducer);
if (threadIdx.x == 0) {
Y[i] = val;
Y[i] = val * alpha;
}
__syncthreads();
}
@ -2644,6 +2652,7 @@ __global__ void ColwiseReduceKernel(
const int cols,
const Reducer reducer,
const T init,
const T alpha,
const T* X,
T* Y) {
__shared__ typename BlockReduce<T>::TempStorage temp_storage;
@ -2654,7 +2663,7 @@ __global__ void ColwiseReduceKernel(
}
val = BlockReduce<T>(temp_storage).Reduce(val, reducer);
if (threadIdx.x == 0) {
Y[i] = val;
Y[i] = val * alpha;
}
__syncthreads();
}
@ -2676,6 +2685,7 @@ __global__ void ColwiseReduceKernel(
D, \
cub::Max(), \
std::numeric_limits<T>::lowest(), \
T(1), \
x, \
y); \
}
@ -2696,6 +2706,7 @@ CAFFE2_SPECIALIZED_HIP_ROWWISE_MAX(float)
D, \
cub::Max(), \
std::numeric_limits<T>::lowest(), \
T(1), \
x, \
y); \
}
@ -2740,6 +2751,7 @@ __global__ void ReduceTensorHIPKernel(
SimpleArray<FixedDivisor<int>, D> Y_dims,
const Reducer reducer,
const T init,
const T alpha,
const T* X,
T* Y) {
__shared__ typename BlockReduce<T>::TempStorage temp_storage;
@ -2762,7 +2774,7 @@ __global__ void ReduceTensorHIPKernel(
}
val = BlockReduce<T>(temp_storage).Reduce(val, reducer);
if (threadIdx.x == 0) {
Y[i] = val;
Y[i] = val * alpha;
}
__syncthreads();
}
@ -2775,7 +2787,8 @@ void ReduceTensorHIPImpl(
const int* dims,
const int* axes,
const Reducer& reducer,
const T& init,
const T init,
const T alpha,
const T* X,
T* Y,
HIPContext* context) {
@ -2797,6 +2810,7 @@ void ReduceTensorHIPImpl(
Y_dims,
reducer,
init,
alpha,
X,
Y);
}
@ -2808,14 +2822,12 @@ void ReduceTensorHIP(
const int num_axes,
const int* axes,
const Reducer& reducer,
const T& init,
const T init,
const T alpha,
const T* X,
T* Y,
HIPContext* context) {
CAFFE_ENFORCE_LE(num_axes, num_dims);
if (X == Y) {
return;
}
std::vector<int> transpose_axes(num_dims);
utils::ComputeTransposeAxesForReduceOp(
num_dims, num_axes, axes, transpose_axes.data());
@ -2828,7 +2840,17 @@ void ReduceTensorHIP(
for (int i = pivot; i < num_dims; ++i) {
inner_size *= dims[transpose_axes[i]];
}
if (outer_size > 0 && inner_size > 0) {
if (outer_size == 0) {
return;
}
if (inner_size == 0) {
Set<T, HIPContext>(outer_size, alpha * init, Y, context);
return;
}
if (inner_size == 1) {
Scale<T, T, HIPContext>(outer_size, alpha, X, Y, context);
return;
}
if (transpose_axes[pivot] == pivot) {
hipLaunchKernelGGL(
(RowwiseReduceKernel<T>),
@ -2840,6 +2862,7 @@ void ReduceTensorHIP(
inner_size,
reducer,
init,
alpha,
X,
Y);
return;
@ -2855,12 +2878,10 @@ void ReduceTensorHIP(
transpose_axes.data(),
reducer,
init,
alpha,
X,
Y,
context);
} else if (outer_size > 0) {
math::Set<T, HIPContext>(outer_size, init, Y, context);
}
}
template <typename T>
@ -2869,19 +2890,27 @@ void ReduceMeanHIPImpl(
const int* dims,
const int num_axes,
const int* axes,
const T alpha,
const T* X,
T* Y,
HIPContext* context) {
ReduceTensorHIP(
num_dims, dims, num_axes, axes, cub::Sum(), T(0), X, Y, context);
const int X_size =
std::accumulate(dims, dims + num_dims, 1, std::multiplies<int>());
int scale = 1;
for (int i = 0; i < num_axes; ++i) {
scale *= dims[axes[i]];
}
const int Y_size = X_size / scale;
Scale<T, HIPContext>(Y_size, 1.0f / static_cast<float>(scale), Y, Y, context);
ReduceTensorHIP(
num_dims,
dims,
num_axes,
axes,
cub::Sum(),
T(0),
alpha / static_cast<T>(scale),
X,
Y,
context);
}
} // namespace
@ -2893,6 +2922,7 @@ void ReduceMeanHIPImpl(
const int* dims, \
const int num_axes, \
const int* axes, \
const T alpha, \
const T* X, \
T* Y, \
HIPContext* context) { \
@ -2903,6 +2933,7 @@ void ReduceMeanHIPImpl(
axes, \
cub::Min(), \
std::numeric_limits<T>::max(), \
alpha, \
X, \
Y, \
context); \
@ -2920,6 +2951,7 @@ CAFFE2_SPECIALIZED_HIP_REDUCE_MIN(double)
const int* dims, \
const int num_axes, \
const int* axes, \
const T alpha, \
const T* X, \
T* Y, \
HIPContext* context) { \
@ -2930,6 +2962,7 @@ CAFFE2_SPECIALIZED_HIP_REDUCE_MIN(double)
axes, \
cub::Max(), \
std::numeric_limits<T>::lowest(), \
alpha, \
X, \
Y, \
context); \
@ -2947,11 +2980,21 @@ CAFFE2_SPECIALIZED_HIP_REDUCE_MAX(double)
const int* dims, \
const int num_axes, \
const int* axes, \
const T alpha, \
const T* X, \
T* Y, \
HIPContext* context) { \
ReduceTensorHIP( \
num_dims, dims, num_axes, axes, cub::Sum(), T(0), X, Y, context); \
num_dims, \
dims, \
num_axes, \
axes, \
cub::Sum(), \
T(0), \
alpha, \
X, \
Y, \
context); \
}
CAFFE2_SPECIALIZED_HIP_REDUCE_SUM(std::int32_t)
CAFFE2_SPECIALIZED_HIP_REDUCE_SUM(std::int64_t)
@ -2966,10 +3009,12 @@ CAFFE2_SPECIALIZED_HIP_REDUCE_SUM(double)
const int* dims, \
const int num_axes, \
const int* axes, \
const T alpha, \
const T* X, \
T* Y, \
HIPContext* context) { \
ReduceMeanHIPImpl<T>(num_dims, dims, num_axes, axes, X, Y, context); \
ReduceMeanHIPImpl<T>( \
num_dims, dims, num_axes, axes, alpha, X, Y, context); \
}
CAFFE2_SPECIALIZED_HIP_REDUCE_MEAN(float)
#undef CAFFE2_SPECIALIZED_HIP_REDUCE_MEAN
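Two details in the HIP reduce path above are worth spelling out. ReduceMeanHIPImpl no longer issues a second Scale pass over Y: the product of the reduced dimensions is folded into alpha, so the kernel writes Y[i] = (alpha / scale) * sum, which is already alpha * mean; for dims = {2, 3, 4} reduced over axes = {1, 2}, scale = 3 * 4 = 12 and the kernel receives alpha / 12. ReduceTensorHIP also handles degenerate shapes explicitly instead of silently returning when X == Y: an empty output returns early, an empty reduced extent fills Y with alpha * init, and a reduced extent of one degenerates to a plain Scale. A condensed restatement of the mean folding (same names as the code above, not a standalone implementation):

int scale = 1;
for (int i = 0; i < num_axes; ++i) {
  scale *= dims[axes[i]];  // product of the reduced dimensions
}
// Single pass: Y = (alpha / scale) * sum(X) == alpha * mean(X).
ReduceTensorHIP(
    num_dims, dims, num_axes, axes, cub::Sum(), T(0),
    alpha / static_cast<T>(scale), X, Y, context);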

View File

@ -186,6 +186,7 @@ void ReduceMin(
const int* dims,
const int num_axes,
const int* axes,
const T alpha,
const T* X,
T* Y,
Context* context);
@ -196,6 +197,7 @@ void ReduceMax(
const int* dims,
const int num_axes,
const int* axes,
const T alpha,
const T* X,
T* Y,
Context* context);
@ -206,6 +208,7 @@ void ReduceSum(
const int* dims,
const int num_axes,
const int* axes,
const T alpha,
const T* X,
T* Y,
Context* context);
@ -216,6 +219,7 @@ void ReduceMean(
const int* dims,
const int num_axes,
const int* axes,
const T alpha,
const T* X,
T* Y,
Context* context);
@ -226,6 +230,7 @@ void ReduceL1(
const int* dims,
const int num_axes,
const int* axes,
const T alpha,
const T* X,
T* Y,
Context* context);
@ -236,6 +241,7 @@ void ReduceL2(
const int* dims,
const int num_axes,
const int* axes,
const T alpha,
const T* X,
T* Y,
Context* context);
@ -464,14 +470,24 @@ void Select(
T* y,
Context* context);
template <typename T, class Context>
void Scale(const int N, const float alpha, const T* x, T* y, Context* context);
template <typename TAlpha, typename TData, class Context>
void Scale(
const int N,
const TAlpha alpha,
const TData* x,
TData* y,
Context* context);
// Different from the Scale function above, if alpha is passed in
// as a pointer, we will assume that it lives on the Context device,
// for example on GPU.
template <typename T, class Context>
void Scale(const int N, const float* alpha, const T* x, T* y, Context* context);
template <typename TAlpha, typename TData, class Context>
void Scale(
const int N,
const TAlpha* alpha,
const TData* x,
TData* y,
Context* context);
template <typename T, class Context>
void Axpy(const int N, const float alpha, const T* x, T* y, Context* context);
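The new Scale declarations above are the heart of the signature change: the function is now templated on both the scalar type (TAlpha) and the data type (TData), and the pointer overload still expects alpha to live on the Context's device. A minimal usage sketch against the HIP specializations instantiated earlier in this diff (n, x, y, x_fp16, y_fp16, alpha_dev, and hip_context are illustrative names):

// Host-side alpha: a float scalar scaling float16 data, covered by
// CAFFE2_SPECIALIZED_HIP_SCALE(float, float16) above.
caffe2::math::Scale<float, caffe2::float16, caffe2::HIPContext>(
    n, 0.5f, x_fp16, y_fp16, &hip_context);

// Device-side alpha: alpha_dev points at a float already resident on the GPU.
caffe2::math::Scale<float, float, caffe2::HIPContext>(
    n, alpha_dev, x, y, &hip_context);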

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -424,6 +424,7 @@ class ReduceTensorGPUTest : public testing::Test {
X_dims.data(),
axes.size(),
axes.data(),
1.0f,
X_->data<float>(),
Y_->mutable_data<float>(),
cuda_context_.get());
@ -445,11 +446,12 @@ TEST_F(ReduceTensorGPUTest, ReduceMinGPUTest) {
const int* dims,
const int num_axes,
const int* axes,
const float alpha,
const float* X,
float* Y,
CUDAContext* context) {
return math::ReduceMin<float, CUDAContext>(
num_dims, dims, num_axes, axes, X, Y, context);
num_dims, dims, num_axes, axes, alpha, X, Y, context);
};
// Test for 1D tensor.
RunRedcueTensorTest(reduce_min, {3}, {0}, {1.0f, 2.0f, 3.0f}, {1.0f});
@ -499,11 +501,12 @@ TEST_F(ReduceTensorGPUTest, ReduceMaxGPUTest) {
const int* dims,
const int num_axes,
const int* axes,
const float alpha,
const float* X,
float* Y,
CUDAContext* context) {
return math::ReduceMax<float, CUDAContext>(
num_dims, dims, num_axes, axes, X, Y, context);
num_dims, dims, num_axes, axes, alpha, X, Y, context);
};
// Test for 1D tensor.
RunRedcueTensorTest(reduce_max, {3}, {0}, {1.0f, 2.0f, 3.0f}, {3.0f});

View File

@ -452,6 +452,7 @@ class ReduceTensorTest : public testing::Test {
X_dims.data(),
axes.size(),
axes.data(),
1.0f,
X_.data<float>(),
Y_.mutable_data<float>(),
cpu_context_.get());
@ -472,11 +473,12 @@ TEST_F(ReduceTensorTest, ReduceMinTest) {
const int* dims,
const int num_axes,
const int* axes,
const float alpha,
const float* X,
float* Y,
CPUContext* context) {
return math::ReduceMin<float, CPUContext>(
num_dims, dims, num_axes, axes, X, Y, context);
num_dims, dims, num_axes, axes, alpha, X, Y, context);
};
// Test for 1D tensor.
RunRedcueTensorTest(reduce_min, {3}, {0}, {1.0f, 2.0f, 3.0f}, {1.0f});
@ -523,11 +525,12 @@ TEST_F(ReduceTensorTest, ReduceMaxTest) {
const int* dims,
const int num_axes,
const int* axes,
const float alpha,
const float* X,
float* Y,
CPUContext* context) {
return math::ReduceMax<float, CPUContext>(
num_dims, dims, num_axes, axes, X, Y, context);
num_dims, dims, num_axes, axes, alpha, X, Y, context);
};
// Test for 1D tensor.
RunRedcueTensorTest(reduce_max, {3}, {0}, {1.0f, 2.0f, 3.0f}, {3.0f});

View File

@ -41,6 +41,75 @@ bool IsIdentityPermutation(const int n, const int* perm) {
return true;
}
bool IsRowwiseReduce(
const int ndim,
const int* A_dims,
const int* B_dims,
int* rows,
int* cols) {
*cols = 1;
int pivot = ndim - 1;
for (; pivot >= 0 && B_dims[pivot] == 1; --pivot) {
*cols *= A_dims[pivot];
}
*rows = 1;
for (int i = pivot; i >= 0; --i) {
if (A_dims[i] != B_dims[i]) {
return false;
}
*rows *= A_dims[i];
}
return true;
}
bool IsColwiseReduce(
const int ndim,
const int* A_dims,
const int* B_dims,
int* rows,
int* cols) {
*rows = 1;
int pivot = 0;
for (; pivot < ndim && B_dims[pivot] == 1; ++pivot) {
*rows *= A_dims[pivot];
}
*cols = 1;
for (int i = pivot; i < ndim; ++i) {
if (A_dims[i] != B_dims[i]) {
return false;
}
*cols *= A_dims[i];
}
return true;
}
bool IsBothEndsReduce(
const int ndim,
const int* A_dims,
const int* B_dims,
int* pre,
int* mid,
int* nxt) {
*nxt = 1;
int r = ndim - 1;
for (; r >= 0 && B_dims[r] == 1; --r) {
*nxt *= A_dims[r];
}
*pre = 1;
int l = 0;
for (; l <= r && B_dims[l] == 1; ++l) {
*pre *= A_dims[l];
}
*mid = 1;
for (int i = l; i <= r; ++i) {
if (A_dims[i] != B_dims[i]) {
return false;
}
*mid *= A_dims[i];
}
return true;
}
void ComputeBroadcastBinaryOpDims(
const int A_ndim,
const int* A_dims,
@ -146,7 +215,7 @@ bool IsColwiseBroadcastBinaryOp(
return true;
}
bool IsMiddleBroadcastBinaryOp(
bool IsBothEndsBroadcastBinaryOp(
const int ndim,
const int* A_dims,
const int* B_dims,
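The new IsRowwiseReduce / IsColwiseReduce / IsBothEndsReduce helpers classify a reduction purely from the input and output shapes, which is what lets the 2d and 3d cases dispatch to specialized paths. A worked sketch with hypothetical shapes, assuming the helpers live in the math::utils namespace alongside ComputeTransposeAxesForReduceOp used earlier (A_dims, B_rowwise, B_both, rows, cols, pre, mid, nxt are illustrative names):

// A 2x3x4 tensor reduced to 2x3x1 collapses to a 6x4 rowwise reduce.
const int A_dims[3] = {2, 3, 4};
const int B_rowwise[3] = {2, 3, 1};
int rows, cols;
if (caffe2::math::utils::IsRowwiseReduce(3, A_dims, B_rowwise, &rows, &cols)) {
  // rows == 6, cols == 4: reduce each length-4 row of the collapsed view.
}

// A 2x3x4 tensor reduced to 1x3x1 is a both-ends reduce around the middle dim.
const int B_both[3] = {1, 3, 1};
int pre, mid, nxt;
if (caffe2::math::utils::IsBothEndsReduce(3, A_dims, B_both, &pre, &mid, &nxt)) {
  // pre == 2, mid == 3, nxt == 4: reduce the outer and inner extents, keep the middle.
}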

View File

@ -49,6 +49,28 @@ int GetIndexFromDims(const int n, const int* dims, const int* index);
// Checks if the input permutation is an identity permutation;
bool IsIdentityPermutation(const int n, const int* perm);
bool IsRowwiseReduce(
const int ndim,
const int* X_dims,
const int* Y_dims,
int* rows,
int* cols);
bool IsColwiseReduce(
const int ndim,
const int* X_dims,
const int* Y_dims,
int* rows,
int* cols);
bool IsBothEndsReduce(
const int ndim,
const int* X_dims,
const int* Y_dims,
int* pre,
int* mid,
int* nxt);
// Computes the broadcast binary operation dims.
void ComputeBroadcastBinaryOpDims(
const int A_ndim,
@ -75,7 +97,7 @@ bool IsColwiseBroadcastBinaryOp(
int* cols,
bool* broadcast_1st);
bool IsMiddleBroadcastBinaryOp(
bool IsBothEndsBroadcastBinaryOp(
const int ndim,
const int* A_dims,
const int* B_dims,

View File

@ -137,7 +137,7 @@ bool SelectSmoothL1LossOp<float, CUDAContext>::RunOnDevice() {
buff_.size(), buff_.data<float>(), avg_loss_data, &context_);
// Average of input batch size
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
1, scale_, avg_loss_data, avg_loss_data, &context_);
return true;
}

View File

@ -109,7 +109,7 @@ bool SigmoidCrossEntropyLossOp<float, CUDAContext>::RunOnDevice() {
math::Div<float, CUDAContext>(
1, avg_loss_data, normalizer_data, avg_loss_data, &context_);
}
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
1, scale_, avg_loss_data, avg_loss_data, &context_);
return true;
@ -151,22 +151,22 @@ bool SigmoidCrossEntropyLossGradientOp<float, CUDAContext>::RunOnDevice() {
normalizer_data,
normalizer_data,
&context_);
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
1, scale_, normalizer_data, normalizer_data, &context_);
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
dX->size(),
normalizer_data,
dX->data<float>(),
dX->mutable_data<float>(),
&context_);
} else {
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
dX->size(),
scale_,
dX->data<float>(),
dX->mutable_data<float>(),
&context_);
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
dX->size(),
d_avg_loss.data<float>(),
dX->data<float>(),

View File

@ -137,7 +137,7 @@ bool SigmoidFocalLossOp<float, CUDAContext>::RunOnDevice() {
math::Sum<float, CUDAContext>(
losses_.size(), losses_.data<float>(), avg_loss_data, &context_);
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
1, scale_, avg_loss_data, avg_loss_data, &context_);
return true;
@ -165,8 +165,12 @@ bool SigmoidFocalLossGradientOp<float, CUDAContext>::RunOnDevice() {
N, D, H, W, X.data<float>(), T.data<int>(), dX->mutable_data<float>(),
wp.data<float>(), gamma_, alpha_, num_classes_,
d_avg_loss.data<float>());
math::Scale<float, CUDAContext>(
dX->size(), scale_, dX->data<float>(), dX->mutable_data<float>(), &context_);
math::Scale<float, float, CUDAContext>(
dX->size(),
scale_,
dX->data<float>(),
dX->mutable_data<float>(),
&context_);
return true;
}

View File

@ -116,7 +116,7 @@ bool SmoothL1LossOp<float, CUDAContext>::RunOnDevice() {
buff_.size(), buff_.data<float>(), avg_loss_data, &context_);
// Average of input batch size
// al := 1/N * al
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
1, scale_ / N, avg_loss_data, avg_loss_data, &context_);
return true;
}

View File

@ -189,7 +189,7 @@ bool SoftmaxFocalLossOp<float, CUDAContext>::RunOnDevice() {
float* avg_loss_data = avg_loss->mutable_data<float>();
math::Sum<float, CUDAContext>(
losses_.size(), losses_.data<float>(), avg_loss_data, &context_);
math::Scale<float, CUDAContext>(
math::Scale<float, float, CUDAContext>(
1, scale_, avg_loss_data, avg_loss_data, &context_);
return true;
@ -235,8 +235,11 @@ bool SoftmaxFocalLossGradientOp<float, CUDAContext>::RunOnDevice() {
0, context_.cuda_stream()>>>(
N, D, H, W, Pdata, Tdata, Bdata, d_avg_loss.data<float>(),
dX->mutable_data<float>(), num_classes_);
math::Scale<float, CUDAContext>(
dX->size(), scale_, dX->data<float>(), dX->mutable_data<float>(),
math::Scale<float, float, CUDAContext>(
dX->size(),
scale_,
dX->data<float>(),
dX->mutable_data<float>(),
&context_);
return true;
}