Optimize group_norm_op (#17945)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17945

Optimize group_norm_op

Reviewed By: houseroad

Differential Revision: D14419908

fbshipit-source-id: 4024b5c5dbeff97f4f026d61fc44af1f0e98ed68
Xiaomeng Yang 2019-03-21 12:56:20 -07:00 committed by Facebook Github Bot
parent 9214852da2
commit 43a5c636e2
3 changed files with 937 additions and 594 deletions

caffe2/operators/group_norm_op.cc

@@ -8,85 +8,376 @@
#include "caffe2/operators/group_norm_op.h"
#include "caffe2/utils/eigen_utils.h"
namespace caffe2 {
namespace {
template <typename T, StorageOrder kOrder>
void ComputeInternalGradients(
const std::array<int, 4>& dims,
const T* dY,
const T* X,
const T* gamma,
T* ds,
T* db) {
constexpr int kGDim = kOrder == StorageOrder::NCHW ? 1 : 2;
constexpr int kDDim = kOrder == StorageOrder::NCHW ? 2 : 3;
const int size = dims[0] * dims[1] * dims[2] * dims[3];
std::array<int, 4> index = {0, 0, 0, 0};
for (int i = 0; i < size; ++i) {
const int i_mu = index[0] * dims[kGDim] + index[kGDim];
const int i_gamma = index[kGDim] * dims[kDDim] + index[kDDim];
ds[i_mu] += gamma[i_gamma] * dY[i] * X[i];
db[i_mu] += gamma[i_gamma] * dY[i];
math::utils::IncreaseIndexInDims(4, dims.data(), index.data());
}
}
// Math:
// Y = gamma * (X - mu) * rsig + beta
// let s = gamma * rsig
// let b = beta - mu * gamma * rsig
// Y = s * X + b
// let n = K * HxW
// dL/dX = dL/dY * dY/dX = dL/dY * (d(s * X)/dX + db/dX)
// d(s * X)/dX = s + X * ds/dX = s + gamma * X * drsig/dX
// db/dX = -mu * drsig/dX - rsig * dmu/dX
// drsig/dX = -rsig^3 * (X - mu) / n
// dmu/dX = 1 / n
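// Combining the above (with ds = Sum(gamma * dY * X) and db = Sum(gamma * dY)
// taken over a group of n = K * HxW elements) gives the closed form that both
// the CPU and CUDA backward paths compute:
// dL/dX = gamma * dY * rsig
//       + ((db * mu - ds) * rsig^3 * (X - mu) - db * rsig) / n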
template <typename T, StorageOrder kOrder>
void ComputeInternalGradients(
int N,
int C,
int HxW,
const T* dY,
const T* X,
T* ds,
T* db);
template <>
void ComputeInternalGradients<float, StorageOrder::NCHW>(
const int N,
const int C,
const int HxW,
const float* dY,
const float* X,
float* ds,
float* db) {
ConstEigenArrayMap<float> dY_arr(dY, HxW, N * C);
ConstEigenArrayMap<float> X_arr(X, HxW, N * C);
for (int i = 0; i < N * C; ++i) {
ds[i] = (dY_arr.col(i) * X_arr.col(i)).sum();
db[i] = dY_arr.col(i).sum();
}
}
template <>
void ComputeInternalGradients<float, StorageOrder::NHWC>(
const int N,
const int C,
const int HxW,
const float* dY,
const float* X,
float* ds,
float* db) {
EigenArrayMap<float> ds_arr(ds, C, N);
EigenArrayMap<float> db_arr(db, C, N);
for (int i = 0; i < N; ++i) {
ConstEigenArrayMap<float> dY_arr(dY + i * C * HxW, C, HxW);
ConstEigenArrayMap<float> X_arr(X + i * C * HxW, C, HxW);
ds_arr.col(i) = dY_arr.col(0) * X_arr.col(0);
db_arr.col(i) = dY_arr.col(0);
for (int j = 1; j < HxW; ++j) {
ds_arr.col(i) += dY_arr.col(j) * X_arr.col(j);
db_arr.col(i) += dY_arr.col(j);
}
}
}
template <typename T>
void ComputeGradientFusedParams(
const int N,
const int G,
const int K,
const int HxW,
const T* ds,
const T* db,
const T* mu,
const T* rsig,
const T* gamma,
T* dY_scale,
T* X_scale,
T* bias) {
ConstEigenArrayMap<T> rsig_arr(rsig, G, N);
ConstEigenArrayMap<T> gamma_arr(gamma, K, G);
for (int i = 0; i < N; ++i) {
EigenArrayMap<T>(dY_scale + i * G * K, K, G) =
gamma_arr.rowwise() * (rsig_arr.col(i).transpose());
}
ConstEigenVectorArrayMap<T> mu_arr(mu, N * G);
ConstEigenVectorArrayMap<T> rsig_vec(rsig, N * G);
EigenVectorArrayMap<T> X_scale_arr(X_scale, N * G);
EigenVectorArrayMap<T> bias_arr(bias, N * G);
for (int i = 0; i < N; ++i) {
ConstEigenArrayMap<T> ds_arr(ds + i * G * K, K, G);
ConstEigenArrayMap<T> db_arr(db + i * G * K, K, G);
for (int j = 0; j < G; ++j) {
X_scale_arr(i * G + j) = (ds_arr.col(j) * gamma_arr.col(j)).sum();
bias_arr(i * G + j) = (db_arr.col(j) * gamma_arr.col(j)).sum();
}
}
const T alpha = T(1) / static_cast<T>(K * HxW);
X_scale_arr = (bias_arr * mu_arr - X_scale_arr) * rsig_vec.cube() * alpha;
bias_arr = -X_scale_arr * mu_arr - bias_arr * rsig_vec * alpha;
}
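// The fused parameters above collapse the per-element backward pass into one
// linear form, dX = dY_scale * dY + X_scale * X + bias. A scalar reference
// sketch for a single-channel group (hypothetical helper, not part of this
// commit; it should match GroupNormBackward up to rounding):
template <typename T>
void GroupNormBackwardRef(
    const int n,
    const T* dY,
    const T* X,
    const T gamma,
    const T mu,
    const T rsig,
    T* dX) {
  T ds = 0;
  T db = 0;
  for (int i = 0; i < n; ++i) {
    ds += gamma * dY[i] * X[i];  // dL/ds for this group
    db += gamma * dY[i];         // dL/db for this group
  }
  const T dY_scale = gamma * rsig;
  const T X_scale =
      (db * mu - ds) * math::utils::Cube(rsig) / static_cast<T>(n);
  const T bias = -X_scale * mu - db * rsig / static_cast<T>(n);
  for (int i = 0; i < n; ++i) {
    dX[i] = dY_scale * dY[i] + X_scale * X[i] + bias;
  }
}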
template <typename T, StorageOrder kOrder>
void GroupNormBackward(
int N,
int G,
int K,
int HxW,
const T* dY_scale,
const T* dY,
const T* X_scale,
const T* X,
const T* bias,
T* dX);
template <>
void GroupNormBackward<float, StorageOrder::NCHW>(
const int N,
const int G,
const int K,
const int HxW,
const float* dY_scale,
const float* dY,
const float* X_scale,
const float* X,
const float* bias,
float* dX) {
const int C = G * K;
ConstEigenArrayMap<float> dY_arr(dY, HxW, N * C);
ConstEigenArrayMap<float> X_arr(X, HxW, N * C);
EigenArrayMap<float> dX_arr(dX, HxW, N * C);
for (int i = 0; i < N * G; ++i) {
for (int j = 0; j < K; ++j) {
const int c = i * K + j;
dX_arr.col(c) =
dY_arr.col(c) * dY_scale[c] + X_arr.col(c) * X_scale[i] + bias[i];
}
}
}
template <>
void GroupNormBackward<float, StorageOrder::NHWC>(
const int N,
const int G,
const int K,
const int HxW,
const float* dY_scale,
const float* dY,
const float* X_scale,
const float* X,
const float* bias,
float* dX) {
const int C = G * K;
ConstEigenArrayMap<float> X_scale_arr(X_scale, G, N);
ConstEigenArrayMap<float> bias_arr(bias, G, N);
for (int n = 0; n < N; ++n) {
ConstEigenArrayMap<float> dY_scale_arr(dY_scale + n * C, K, G);
for (int i = 0; i < HxW; ++i) {
const int m = n * HxW + i;
ConstEigenArrayMap<float> dY_arr(dY + m * C, K, G);
ConstEigenArrayMap<float> X_arr(X + m * C, K, G);
EigenArrayMap<float> dX_arr(dX + m * C, K, G);
dX_arr = (dY_arr * dY_scale_arr +
X_arr.rowwise() * X_scale_arr.col(n).transpose())
.rowwise() +
bias_arr.col(n).transpose();
}
}
}
template <typename T>
void GammaBetaBackward(
const int N,
const int G,
const int K,
const T* ds,
const T* db,
const T* mu,
const T* rsig,
T* dgamma,
T* dbeta) {
const int C = G * K;
ConstEigenArrayMap<T> ds0_arr(ds, K, G);
ConstEigenArrayMap<T> db0_arr(db, K, G);
ConstEigenArrayMap<T> mu_arr(mu, G, N);
ConstEigenArrayMap<T> rsig_arr(rsig, G, N);
EigenArrayMap<T> dgamma_arr(dgamma, K, G);
EigenArrayMap<T> dbeta_arr(dbeta, K, G);
dgamma_arr =
(ds0_arr - db0_arr.rowwise() * mu_arr.col(0).transpose()).rowwise() *
rsig_arr.col(0).transpose();
dbeta_arr = db0_arr;
for (int i = 1; i < N; ++i) {
ConstEigenArrayMap<T> dsi_arr(ds + i * C, K, G);
ConstEigenArrayMap<T> dbi_arr(db + i * C, K, G);
dgamma_arr +=
(dsi_arr - dbi_arr.rowwise() * mu_arr.col(i).transpose()).rowwise() *
rsig_arr.col(i).transpose();
dbeta_arr += dbi_arr;
}
}
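// Sanity check for the reduction above: for channel c in group g,
// dgamma[c] = Sum_n Sum_i dY * (X - mu) * rsig
//           = Sum_n (ds[n][c] - db[n][c] * mu[n][g]) * rsig[n][g]
// dbeta[c]  = Sum_n Sum_i dY = Sum_n db[n][c]
// which is exactly what GammaBetaBackward accumulates from ds and db.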
} // namespace
template <>
void GroupNormOp<float, CPUContext>::ComputeFusedParams(
const int N,
const int G,
const int K,
const float* mu,
const float* rsig,
const float* gamma,
const float* beta,
float* scale,
float* bias) {
const int C = G * K;
ConstEigenArrayMap<float> mu_arr(mu, G, N);
ConstEigenArrayMap<float> rsig_arr(rsig, G, N);
ConstEigenArrayMap<float> gamma_arr(gamma, K, G);
ConstEigenArrayMap<float> beta_arr(beta, K, G);
for (int i = 0; i < N; ++i) {
EigenArrayMap<float> scale_arr(scale + i * C, K, G);
EigenArrayMap<float> bias_arr(bias + i * C, K, G);
scale_arr = gamma_arr.rowwise() * rsig_arr.col(i).transpose();
bias_arr = beta_arr - scale_arr.rowwise() * mu_arr.col(i).transpose();
}
}
template <>
void GroupNormOp<float, CPUContext>::GroupNormForwardNCHW(
const int N,
const int C,
const int HxW,
const float* X,
const float* scale,
const float* bias,
float* Y) {
EigenArrayMap<float>(Y, HxW, N * C) =
(ConstEigenArrayMap<float>(X, HxW, N * C).rowwise() *
ConstEigenVectorArrayMap<float>(scale, N * C).transpose())
.rowwise() +
ConstEigenVectorArrayMap<float>(bias, N * C).transpose();
}
template <>
void GroupNormOp<float, CPUContext>::GroupNormForwardNHWC(
const int N,
const int C,
const int HxW,
const float* X,
const float* scale,
const float* bias,
float* Y) {
const int stride = HxW * C;
for (int i = 0; i < N; ++i) {
EigenArrayMap<float>(Y + i * stride, C, HxW) =
(ConstEigenArrayMap<float>(X + i * stride, C, HxW).colwise() *
ConstEigenVectorArrayMap<float>(scale + i * C, C))
.colwise() +
ConstEigenVectorArrayMap<float>(bias + i * C, C);
}
}
template <>
bool GroupNormOp<float, CPUContext>::RunOnDeviceWithOrderNHWC(
const int N,
const int G,
const int K,
const int HxW,
const float* X,
const float* gamma,
const float* beta,
float* Y,
float* mu,
float* rsig) {
const int C = G * K;
ReinitializeTensor(&scale_, {N, C}, at::dtype<float>().device(CPU));
ReinitializeTensor(&bias_, {N, C}, at::dtype<float>().device(CPU));
float* scale_data = scale_.mutable_data<float>();
float* bias_data = bias_.mutable_data<float>();
EigenVectorArrayMap<float> mu_arr(mu, N * G);
EigenVectorArrayMap<float> rsig_arr(rsig, N * G);
mu_arr.setZero();
rsig_arr.setZero();
for (int n = 0; n < N; ++n) {
for (int i = 0; i < HxW; ++i) {
const int m = n * HxW + i;
ConstEigenArrayMap<float> X_arr(X + m * C, K, G);
for (int j = 0; j < G; ++j) {
mu_arr(n * G + j) += X_arr.col(j).sum();
rsig_arr(n * G + j) += X_arr.col(j).square().sum();
}
}
}
const float scale = 1.0f / static_cast<float>(K * HxW);
mu_arr *= scale;
rsig_arr = (rsig_arr * scale - mu_arr.square() + epsilon_).rsqrt();
ComputeFusedParams(N, G, K, mu, rsig, gamma, beta, scale_data, bias_data);
GroupNormForwardNHWC(N, C, HxW, X, scale_data, bias_data, Y);
return true;
}
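// A scalar sketch of the fused forward above for one group (hypothetical
// helper, not part of this commit): one pass accumulates Sum(x) and
// Sum(x^2), then var = E[x^2] - mu^2 and Y = s * X + b with
// s = gamma * rsig and b = beta - s * mu.
template <typename T>
void GroupNormForwardRef(
    const int n,
    const T* X,
    const T gamma,
    const T beta,
    const T epsilon,
    T* Y) {
  T sum = 0;
  T sumsq = 0;
  for (int i = 0; i < n; ++i) {
    sum += X[i];
    sumsq += X[i] * X[i];
  }
  const T mu = sum / static_cast<T>(n);
  const T rsig =
      T(1) / std::sqrt(sumsq / static_cast<T>(n) - mu * mu + epsilon);
  const T scale = gamma * rsig;
  const T bias = beta - scale * mu;
  for (int i = 0; i < n; ++i) {
    Y[i] = scale * X[i] + bias;
  }
}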
// Math:
// let: s = gamma * rsig
// let: b = beta - mu * gamma * rsig
// then: Y = s * X + b
template <>
bool GroupNormGradientOp<float, CPUContext>::RunOnDeviceWithOrderNCHW(
const int N,
const int G,
const int K,
const int HxW,
const float* dY_data,
const float* X_data,
const float* mu_data,
const float* rsig_data,
const float* gamma_data,
float* dX_data,
float* dgamma_data,
float* dbeta_data) {
const int C = G * K;
ReinitializeTensor(&ds_, {N, C}, at::dtype<float>().device(CPU));
ReinitializeTensor(&db_, {N, C}, at::dtype<float>().device(CPU));
ReinitializeTensor(&dY_scale_, {N, C}, at::dtype<float>().device(CPU));
ReinitializeTensor(&X_scale_, {N, G}, at::dtype<float>().device(CPU));
ReinitializeTensor(&bias_, {N, G}, at::dtype<float>().device(CPU));
float* ds_data = ds_.mutable_data<float>();
float* db_data = db_.mutable_data<float>();
float* dY_scale_data = dY_scale_.mutable_data<float>();
float* X_scale_data = X_scale_.mutable_data<float>();
float* bias_data = bias_.mutable_data<float>();
ComputeInternalGradients<float, StorageOrder::NCHW>(
N, C, HxW, dY_data, X_data, ds_data, db_data);
ComputeGradientFusedParams<float>(
N,
G,
K,
HxW,
ds_data,
db_data,
mu_data,
rsig_data,
gamma_data,
dY_scale_data,
X_scale_data,
bias_data);
GroupNormBackward<float, StorageOrder::NCHW>(
N,
G,
K,
HxW,
dY_scale_data,
dY_data,
X_scale_data,
X_data,
bias_data,
dX_data);
GammaBetaBackward<float>(
N, G, K, ds_data, db_data, mu_data, rsig_data, dgamma_data, dbeta_data);
return true;
}
template <>
bool GroupNormGradientOp<float, CPUContext>::RunOnDeviceWithOrderNHWC(
const int N,
const int G,
const int K,
const int HxW,
const float* dY_data,
const float* X_data,
@@ -96,60 +387,45 @@ bool GroupNormGradientOp<T, Context>::RunOnDeviceImpl(
float* dX_data,
float* dgamma_data,
float* dbeta_data) {
const std::array<int, 4> dims = order_ == StorageOrder::NCHW
? std::array<int, 4>{N, G, D, HxW}
: std::array<int, 4>{N, HxW, G, D};
// Computes dL/ds and dL/db.
// dL/ds = Sum(dL/dY * gamma * X)
// dL/db = Sum(dL/dY * gamma)
const int C = G * D;
ReinitializeTensor(
&ds_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
ReinitializeTensor(
&db_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
T* ds_data = ds_.template mutable_data<T>();
T* db_data = db_.template mutable_data<T>();
math::Set<T, Context>(N * G, T(0), ds_data, &context_);
math::Set<T, Context>(N * G, T(0), db_data, &context_);
if (order_ == StorageOrder::NCHW) {
ComputeInternalGradients<T, StorageOrder::NCHW>(
dims, dY_data, X_data, gamma_data, ds_data, db_data);
} else {
ComputeInternalGradients<T, StorageOrder::NHWC>(
dims, dY_data, X_data, gamma_data, ds_data, db_data);
}
// Computes dL/dX, dL/dgamma and dL/dbeta.
math::Set<T, Context>(C, T(0), dgamma_data, &context_);
math::Set<T, Context>(C, T(0), dbeta_data, &context_);
if (order_ == StorageOrder::NCHW) {
GroupNormBackward<T, StorageOrder::NCHW>(
dims,
dY_data,
X_data,
mu_data,
rsig_data,
gamma_data,
ds_data,
db_data,
dX_data,
dgamma_data,
dbeta_data);
} else {
GroupNormBackward<T, StorageOrder::NHWC>(
dims,
dY_data,
X_data,
mu_data,
rsig_data,
gamma_data,
ds_data,
db_data,
dX_data,
dgamma_data,
dbeta_data);
}
const int C = G * K;
ReinitializeTensor(&ds_, {N, C}, at::dtype<float>().device(CPU));
ReinitializeTensor(&db_, {N, C}, at::dtype<float>().device(CPU));
ReinitializeTensor(&dY_scale_, {N, C}, at::dtype<float>().device(CPU));
ReinitializeTensor(&X_scale_, {N, G}, at::dtype<float>().device(CPU));
ReinitializeTensor(&bias_, {N, G}, at::dtype<float>().device(CPU));
float* ds_data = ds_.mutable_data<float>();
float* db_data = db_.mutable_data<float>();
float* dY_scale_data = dY_scale_.mutable_data<float>();
float* X_scale_data = X_scale_.mutable_data<float>();
float* bias_data = bias_.mutable_data<float>();
ComputeInternalGradients<float, StorageOrder::NHWC>(
N, C, HxW, dY_data, X_data, ds_data, db_data);
ComputeGradientFusedParams<float>(
N,
G,
K,
HxW,
ds_data,
db_data,
mu_data,
rsig_data,
gamma_data,
dY_scale_data,
X_scale_data,
bias_data);
GroupNormBackward<float, StorageOrder::NHWC>(
N,
G,
K,
HxW,
dY_scale_data,
dY_data,
X_scale_data,
X_data,
bias_data,
dX_data);
GammaBetaBackward<float>(
N, G, K, ds_data, db_data, mu_data, rsig_data, dgamma_data, dbeta_data);
return true;
}
@@ -201,17 +477,21 @@ Group Normalization (GN) operation: https://arxiv.org/abs/1803.08494
// Input: dY, X, gamma, beta, mu, sig; Output: dX, dgamma, dbeta
OPERATOR_SCHEMA(GroupNormGradient).NumInputs(6).NumOutputs(3);
namespace {
class GetGroupNormGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
std::vector<OperatorDef> GetGradientDefs() override {
return SingleGradientDef(
"GroupNormGradient",
"",
std::vector<std::string>{GO(0), I(0), I(1), I(2), O(1), O(2)},
std::vector<std::string>{GI(0), GI(1), GI(2)});
}
};
} // namespace
REGISTER_GRADIENT(GroupNorm, GetGroupNormGradient);
} // namespace caffe2

caffe2/operators/group_norm_op.cu

@@ -8,9 +8,6 @@
#include "caffe2/operators/group_norm_op.h"
#include <cub/block/block_reduce.cuh>
#include <cub/cub.cuh>
#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/math.h"
#include "caffe2/utils/math/reduce.cuh"
@@ -21,6 +18,7 @@ namespace {
template <typename T>
__global__ void ComputeFusedParamsCUDAKernel(
const int N,
const int G,
const int K,
const T* mu,
@@ -28,162 +26,113 @@ __global__ void ComputeFusedParamsCUDAKernel(
const T* gamma,
const T* beta,
T* scale,
T* bias);
template <typename T>
__global__ void GroupNormForwardNCHWCUDAKernel(
const int M,
const int HxW,
const T* X,
const T* scale,
const T* bias,
T* Y);
template <>
__global__ void GroupNormForwardNCHWCUDAKernel<float>(
const int M,
const int HxW,
const float* X,
const float* scale,
const float* bias,
float* Y) {
const int nc = blockIdx.x / M;
const int hw = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
if (hw < HxW) {
const int index = nc * HxW + hw;
#if __CUDA_ARCH__ >= 350
Y[index] = fmaf(__ldg(X + index), __ldg(scale + nc), __ldg(bias + nc));
#else
Y[index] = fmaf(X[index], scale[nc], bias[nc]);
#endif
}
}
template <typename T>
__global__ void GroupNormForwardNHWCCUDAKernel(
const int C,
const int HxW,
const T* X,
const T* scale,
const T* bias,
T* Y);
template <>
__global__ void GroupNormForwardNHWCCUDAKernel<float>(
const int C,
const int HxW,
const float* X,
const float* scale,
const float* bias,
float* Y) {
const int n = blockIdx.x / HxW;
const int c = blockIdx.y * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
if (c < C) {
const int index = blockIdx.x * C + c;
const int nc = n * C + c;
#if __CUDA_ARCH__ >= 350
Y[index] = fmaf(__ldg(X + index), __ldg(scale + nc), __ldg(bias + nc));
#else
Y[index] = fmaf(X[index], scale[nc], bias[nc]);
#endif
}
}
template <>
__global__ void ComputeFusedParamsCUDAKernel<float>(
const int N,
const int G,
const int K,
const float* mu,
const float* rsig,
const float* gamma,
const float* beta,
float* scale,
float* bias) {
const int C = G * K;
const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
if (index < N * C) {
const int ng = index / K;
const int c = index % C;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
const float scale_val = __ldg(gamma + c) * __ldg(rsig + ng);
scale[index] = scale_val;
bias[index] = fmaf(-scale_val, __ldg(mu + ng), __ldg(beta + c));
#else
const float scale_val = gamma[c] * rsig[ng];
scale[index] = scale_val;
bias[index] = fmaf(-scale_val, mu[ng], beta[c]);
#endif
}
}
template <typename T, StorageOrder kOrder>
__global__ void GroupNormForwardCUDAKernel(
const int N,
const int C,
const int HxW,
const T* X,
const T* scale,
const T* bias,
T* Y);
template <>
__global__ void GroupNormForwardCUDAKernel<float, StorageOrder::NCHW>(
const int N,
const int C,
const int HxW,
const float* X,
const float* scale,
const float* bias,
float* Y) {
const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
if (index < N * C * HxW) {
const int nc = index / HxW;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
Y[index] = fmaf(__ldg(X + index), __ldg(scale + nc), __ldg(bias + nc));
#else
Y[index] = fmaf(X[index], scale[nc], bias[nc]);
#endif
}
}
template <>
__global__ void GroupNormForwardCUDAKernel<float, StorageOrder::NHWC>(
const int N,
const int C,
const int HxW,
const float* X,
const float* scale,
const float* bias,
float* Y) {
const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
if (index < N * C * HxW) {
const int nc = index / (HxW * C) * C + index % C;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
Y[index] = fmaf(__ldg(X + index), __ldg(scale + nc), __ldg(bias + nc));
#else
Y[index] = fmaf(X[index], scale[nc], bias[nc]);
#endif
}
}
template <typename T>
__global__ void ComputeInternalGradientsNCHWCUDAKernel(
const int HxW,
const T* dY,
const T* X,
T* ds,
T* db) {
__shared__ typename BlockReduce<T>::TempStorage ds_storage;
__shared__ typename BlockReduce<T>::TempStorage db_storage;
const int nc = blockIdx.x;
T ds_sum = 0;
T db_sum = 0;
for (int i = threadIdx.x; i < HxW; i += blockDim.x) {
const int index = nc * HxW + i;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
ds_sum += __ldg(dY + index) * __ldg(X + index);
db_sum += __ldg(dY + index);
#else
ds_sum += dY[index] * X[index];
db_sum += dY[index];
#endif
}
ds_sum = BlockReduce<T>(ds_storage).Sum(ds_sum);
db_sum = BlockReduce<T>(db_storage).Sum(db_sum);
if (threadIdx.x == 0) {
ds[nc] = ds_sum;
db[nc] = db_sum;
}
}
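// One thread block per (n, c) slice: each thread accumulates a strided
// partial sum over HxW, and cub's BlockReduce combines the partials so that
// a single thread writes ds[nc] and db[nc].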
@@ -192,174 +141,212 @@ __global__ void ComputeInternalGradientsNHWCCUDAKernel(
// let s = gamma * rsig
// let b = beta - mu * gamma * rsig
// Y = s * X + b
// let n = K * HxW
// dL/dX = dL/dY * dY/dX = dL/dY * (d(s * X)/dX + db/dX)
// d(s * X)/dX = s + X * ds/dX = s + gamma * X * drsig/dX
// db/dX = -mu * drsig/dX - rsig * dmu/dX
// drsig/dX = -rsig^3 * (X - mu) / n
// dmu/dX = 1 / n
template <typename T>
__global__ void ComputeYGradientScaleCUDAKernel(
const int N,
const int G,
const int K,
const T* rsig,
const T* gamma,
T* dY_scale) {
const int C = G * K;
const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
if (index < N * C) {
const int ng = index / K;
const int c = index % C;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
dY_scale[index] = __ldg(gamma + c) * __ldg(rsig + ng);
#else
dY_scale[index] = gamma[c] * rsig[ng];
#endif
}
}
template <typename T>
__global__ void ComputeXScaleAndBiasCUDAKernel(
const int G,
const int K,
const T alpha,
const T* ds,
const T* db,
const T* mu,
const T* rsig,
const T* gamma,
T* X_scale,
T* bias);
template <>
__global__ void ComputeXScaleAndBiasCUDAKernel<float>(
const int G,
const int K,
const float alpha,
const float* ds,
const float* db,
const float* mu,
const float* rsig,
const float* gamma,
float* X_scale,
float* bias) {
__shared__ typename BlockReduce<float>::TempStorage ds_storage;
__shared__ typename BlockReduce<float>::TempStorage db_storage;
const int n = blockIdx.x;
const int g = blockIdx.y;
const int ng = n * G + g;
float ds_sum = 0;
float db_sum = 0;
for (int i = threadIdx.x; i < K; i += blockDim.x) {
const int index = ng * K + i;
const int c = g * K + i;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
ds_sum += __ldg(ds + index) * __ldg(gamma + c);
db_sum += __ldg(db + index) * __ldg(gamma + c);
#else
ds_sum += ds[index] * gamma[c];
db_sum += db[index] * gamma[c];
#endif
}
ds_sum = BlockReduce<float>(ds_storage).Sum(ds_sum);
db_sum = BlockReduce<float>(db_storage).Sum(db_sum);
if (threadIdx.x == 0) {
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
const float x = fmaf(db_sum, __ldg(mu + ng), -ds_sum) *
math::utils::Cube<float>(__ldg(rsig + ng)) * alpha;
X_scale[ng] = x;
bias[ng] = -fmaf(x, __ldg(mu + ng), db_sum * __ldg(rsig + ng) * alpha);
#else
const float x = fmaf(db_sum, mu[ng], -ds_sum) *
math::utils::Cube<float>(rsig[ng]) * alpha;
X_scale[ng] = x;
bias[ng] = -fmaf(x, mu[ng], db_sum * rsig[ng] * alpha);
#endif
}
}
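// Per group ng, with ds' = Sum_k(ds * gamma) and db' = Sum_k(db * gamma):
// X_scale[ng] = (db' * mu - ds') * rsig^3 * alpha
// bias[ng] = -(X_scale[ng] * mu + db' * rsig * alpha)
// where alpha = 1 / (K * HxW); fmaf(a, b, c) evaluates a * b + c with a
// single rounding.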
template <typename T, StorageOrder kOrder>
__global__ void GroupNormBackwardCUDAKernel(
const int N,
const int G,
const int K,
const int HxW,
const T* dY_scale,
const T* dY,
const T* X_scale,
const T* X,
const T* bias,
T* dX);
template <>
__global__ void GroupNormBackwardCUDAKernel<float, StorageOrder::NCHW>(
const int N,
const int G,
const int K,
const int HxW,
const float* dY_scale,
const float* dY,
const float* X_scale,
const float* X,
const float* bias,
float* dX) {
const int C = G * K;
const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
if (index < N * C * HxW) {
const int nc = index / HxW;
const int ng = nc / K;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
dX[index] = fmaf(
__ldg(dY_scale + nc),
__ldg(dY + index),
fmaf(__ldg(X_scale + ng), __ldg(X + index), __ldg(bias + ng)));
#else
dX[index] =
fmaf(dY_scale[nc], dY[index], fmaf(X_scale[ng], X[index], bias[ng]));
#endif
}
}
template <>
__global__ void GroupNormBackwardCUDAKernel<float, StorageOrder::NHWC>(
const int N,
const int G,
const int K,
const int HxW,
const float* dY_scale,
const float* dY,
const float* X_scale,
const float* X,
const float* bias,
float* dX) {
const int C = G * K;
const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
if (index < N * C * HxW) {
const int nc = index / (HxW * C) * C + index % C;
const int ng = nc / K;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
dX[index] = fmaf(
__ldg(dY_scale + nc),
__ldg(dY + index),
fmaf(__ldg(X_scale + ng), __ldg(X + index), __ldg(bias + ng)));
#else
dX[index] =
fmaf(dY_scale[nc], dY[index], fmaf(X_scale[ng], X[index], bias[ng]));
#endif
}
}
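// The NHWC variant recovers the (n, c) pair from a flat NHWC offset with
// nc = index / (HxW * C) * C + index % C: the integer division yields the
// image index n and the remainder modulo C yields the channel. For example,
// with C = 4 and HxW = 3, the offset of (n = 1, hw = 2, c = 3) is
// (1 * 3 + 2) * 4 + 3 = 23, and 23 / 12 * 4 + 23 % 4 = 1 * 4 + 3 = nc.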
template <typename T>
__global__ void GammaBetaBackwardCUDAKernel(
const int N,
const int G,
const int K,
const T* ds,
const T* db,
const T* mu,
const T* rsig,
T* dgamma,
T* dbeta);
template <>
__global__ void GammaBetaBackwardCUDAKernel<float>(
const int N,
const int G,
const int K,
const float* ds,
const float* db,
const float* mu,
const float* rsig,
float* dgamma,
float* dbeta) {
__shared__ typename BlockReduce<float>::TempStorage dg_storage;
__shared__ typename BlockReduce<float>::TempStorage db_storage;
const int C = G * K;
const int g = blockIdx.x;
const int k = blockIdx.y;
const int c = g * K + k;
float dg_sum = 0;
float db_sum = 0;
for (int i = threadIdx.x; i < N; i += blockDim.x) {
const int nc = i * C + c;
const int ng = i * G + g;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
dg_sum += fmaf(-__ldg(db + nc), __ldg(mu + ng), __ldg(ds + nc)) *
__ldg(rsig + ng);
db_sum += __ldg(db + nc);
#else
dg_sum += fmaf(-db[nc], mu[ng], ds[nc]) * rsig[ng];
db_sum += db[nc];
#endif
}
dg_sum = BlockReduce<float>(dg_storage).Sum(dg_sum);
db_sum = BlockReduce<float>(db_storage).Sum(db_sum);
if (threadIdx.x == 0) {
dgamma[c] = dg_sum;
dbeta[c] = db_sum;
}
}
@@ -376,9 +363,10 @@ void GroupNormOp<float, CUDAContext>::ComputeFusedParams(
const float* beta,
float* scale,
float* bias) {
const int M = math::DivUp(N * G * K, CAFFE_CUDA_NUM_THREADS);
ComputeFusedParamsCUDAKernel<float>
<<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
N, G, K, mu, rsig, gamma, beta, scale, bias);
}
template <>
@@ -390,10 +378,10 @@ void GroupNormOp<float, CUDAContext>::GroupNormForwardNCHW(
const float* scale,
const float* bias,
float* Y) {
const int M = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
GroupNormForwardNCHWCUDAKernel<float>
<<<N * C * M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
M, HxW, X, scale, bias, Y);
const int M = math::DivUp(N * C * HxW, CAFFE_CUDA_NUM_THREADS);
GroupNormForwardCUDAKernel<float, StorageOrder::NCHW>
<<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
N, C, HxW, X, scale, bias, Y);
}
template <>
@@ -405,10 +393,10 @@ void GroupNormOp<float, CUDAContext>::GroupNormForwardNHWC(
const float* scale,
const float* bias,
float* Y) {
const int M = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
GroupNormForwardNHWCCUDAKernel<float>
<<<dim3(N * HxW, M), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
C, HxW, X, scale, bias, Y);
const int M = math::DivUp(N * C * HxW, CAFFE_CUDA_NUM_THREADS);
GroupNormForwardCUDAKernel<float, StorageOrder::NHWC>
<<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
N, C, HxW, X, scale, bias, Y);
}
// Math:
@@ -416,7 +404,7 @@ void GroupNormOp<float, CUDAContext>::GroupNormForwardNHWC(
// let: b = beta - mu * gamma * rsig
// then: Y = s * X + b
template <>
bool GroupNormGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW(
const int N,
const int G,
const int K,
@@ -430,119 +418,158 @@ bool GroupNormGradientOp<float, CUDAContext>::RunOnDeviceImpl(
float* dgamma_data,
float* dbeta_data) {
const int C = G * K;
ReinitializeTensor(&ds_, {N, G}, at::dtype<float>().device(CUDA));
ReinitializeTensor(&db_, {N, G}, at::dtype<float>().device(CUDA));
ReinitializeTensor(&ds_, {N, C}, at::dtype<float>().device(CUDA));
ReinitializeTensor(&db_, {N, C}, at::dtype<float>().device(CUDA));
ReinitializeTensor(&dY_scale_, {N, C}, at::dtype<float>().device(CUDA));
ReinitializeTensor(&X_scale_, {N, G}, at::dtype<float>().device(CUDA));
ReinitializeTensor(&bias_, {N, G}, at::dtype<float>().device(CUDA));
float* ds_data = ds_.mutable_data<float>();
float* db_data = db_.mutable_data<float>();
float* dY_scale_data = dY_scale_.mutable_data<float>();
float* X_scale_data = X_scale_.mutable_data<float>();
float* bias_data = bias_.mutable_data<float>();
// Computes dL/ds and dL/db.
ComputeInternalGradientsNCHWCUDAKernel<float>
<<<N * C, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
HxW, dY_data, X_data, ds_data, db_data);
// Computes dL/dX.
int M = math::DivUp(N * C, CAFFE_CUDA_NUM_THREADS);
ComputeYGradientScaleCUDAKernel<float>
<<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
N, G, K, rsig_data, gamma_data, dY_scale_data);
ComputeXScaleAndBiasCUDAKernel<float>
<<<dim3(N, G), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
G,
K,
1.0f / static_cast<float>(K * HxW),
ds_data,
db_data,
mu_data,
rsig_data,
gamma_data,
X_scale_data,
bias_data);
M = math::DivUp(N * C * HxW, CAFFE_CUDA_NUM_THREADS);
GroupNormBackwardCUDAKernel<float, StorageOrder::NCHW>
<<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
N,
G,
K,
HxW,
dY_scale_data,
dY_data,
X_scale_data,
X_data,
bias_data,
dX_data);
// Computes dL/dgamma and dL/dbeta.
GammaBetaBackwardCUDAKernel<
float><<<dim3(G, K), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
N, G, K, ds_data, db_data, mu_data, rsig_data, dgamma_data, dbeta_data);
return true;
}
template <>
bool GroupNormGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC(
const int N,
const int G,
const int K,
const int HxW,
const float* dY_data,
const float* X_data,
const float* mu_data,
const float* rsig_data,
const float* gamma_data,
float* dX_data,
float* dgamma_data,
float* dbeta_data) {
const int C = G * K;
ReinitializeTensor(&ds_, {N, C}, at::dtype<float>().device(CUDA));
ReinitializeTensor(&db_, {N, C}, at::dtype<float>().device(CUDA));
ReinitializeTensor(&dY_scale_, {N, C}, at::dtype<float>().device(CUDA));
ReinitializeTensor(&X_scale_, {N, G}, at::dtype<float>().device(CUDA));
ReinitializeTensor(&bias_, {N, G}, at::dtype<float>().device(CUDA));
ReinitializeTensor(&ones_, {HxW}, at::dtype<float>().device(CUDA));
float* ds_data = ds_.mutable_data<float>();
float* db_data = db_.mutable_data<float>();
float* dY_scale_data = dY_scale_.mutable_data<float>();
float* X_scale_data = X_scale_.mutable_data<float>();
float* bias_data = bias_.mutable_data<float>();
float* ones_data = ones_.mutable_data<float>();
math::Set<float, CUDAContext>(HxW, 1.0f, ones_data, &context_);
math::Mul<float, CUDAContext>(
N * C * HxW, dY_data, X_data, dX_data, &context_);
math::GemmStridedBatched<float, CUDAContext>(
CblasTrans,
CblasNoTrans,
N,
C,
1,
HxW,
1.0f,
dX_data,
C * HxW,
ones_data,
0,
0.0f,
ds_data,
C,
&context_);
math::GemmStridedBatched<float, CUDAContext>(
CblasTrans,
CblasNoTrans,
N,
C,
1,
HxW,
1.0f,
dY_data,
C * HxW,
ones_data,
0,
0.0f,
db_data,
C,
&context_);
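// The two strided GEMMs above are batched reductions over the spatial axis:
// multiplying the transposed HxW-by-C slice of each image by a length-HxW
// vector of ones yields per-channel column sums, so (with dX holding the
// elementwise product dY * X from the Mul above)
// ds[n * C + c] = Sum_hw(dY * X) and db[n * C + c] = Sum_hw(dY),
// the same quantities the NCHW path gets from
// ComputeInternalGradientsNCHWCUDAKernel.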
// Computes dL/dX.
int M = math::DivUp(N * C, CAFFE_CUDA_NUM_THREADS);
ComputeYGradientScaleCUDAKernel<float>
<<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
N, G, K, rsig_data, gamma_data, dY_scale_data);
ComputeXScaleAndBiasCUDAKernel<float>
<<<dim3(N, G), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
G,
K,
1.0f / static_cast<float>(K * HxW),
ds_data,
db_data,
mu_data,
rsig_data,
gamma_data,
X_scale_data,
bias_data);
M = math::DivUp(N * C * HxW, CAFFE_CUDA_NUM_THREADS);
GroupNormBackwardCUDAKernel<float, StorageOrder::NHWC>
<<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
N,
G,
K,
HxW,
dY_scale_data,
dY_data,
X_scale_data,
X_data,
bias_data,
dX_data);
// Computes dL/dgamma and dL/dbeta.
GammaBetaBackwardCUDAKernel<
float><<<dim3(G, K), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
N, G, K, ds_data, db_data, mu_data, rsig_data, dgamma_data, dbeta_data);
return true;
}

caffe2/operators/group_norm_op.h

@@ -8,7 +8,6 @@
#include "caffe2/core/common.h"
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
@@ -47,8 +46,7 @@ class GroupNormOp final : public Operator<Context> {
CAFFE_ENFORCE_EQ(gamma.numel(), C);
CAFFE_ENFORCE_EQ(beta.numel(), C);
const int G = group_;
const int D = C / G;
const int K = C / G;
auto* Y = Output(OUTPUT, X.sizes(), at::dtype<T>());
T* mu_data = nullptr;
T* rsig_data = nullptr;
@@ -65,24 +63,38 @@ class GroupNormOp final : public Operator<Context> {
mu_data = mu_.template mutable_data<T>();
rsig_data = rsig_.template mutable_data<T>();
}
return RunOnDeviceImpl(
N,
G,
D,
HxW,
X.template data<T>(),
gamma.template data<T>(),
beta.template data<T>(),
Y->template mutable_data<T>(),
mu_data,
rsig_data);
if (order_ == StorageOrder::NCHW) {
return RunOnDeviceWithOrderNCHW(
N,
G,
K,
HxW,
X.template data<T>(),
gamma.template data<T>(),
beta.template data<T>(),
Y->template mutable_data<T>(),
mu_data,
rsig_data);
} else {
return RunOnDeviceWithOrderNHWC(
N,
G,
K,
HxW,
X.template data<T>(),
gamma.template data<T>(),
beta.template data<T>(),
Y->template mutable_data<T>(),
mu_data,
rsig_data);
}
}
private:
bool RunOnDeviceWithOrderNCHW(
const int N,
const int G,
const int K,
const int HxW,
const T* X,
const T* gamma,
@@ -90,57 +102,63 @@ class GroupNormOp final : public Operator<Context> {
T* Y,
T* mu,
T* rsig) {
const int C = G * K;
ReinitializeTensor(
&scale_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
ReinitializeTensor(
&bias_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
T* scale_data = scale_.template mutable_data<T>();
T* bias_data = bias_.template mutable_data<T>();
if (order_ == StorageOrder::NCHW) {
const std::array<int, 2> X_dims = {N * G, D * HxW};
const std::array<int, 2> Y_dims = {N * G, 1};
math::Moments<T, Context>(
2, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
math::InvStd<T, Context>(
N * G, static_cast<T>(epsilon_), rsig, rsig, &context_);
ComputeFusedParams(N, G, D, mu, rsig, gamma, beta, scale_data, bias_data);
GroupNormForwardNCHW(N, C, HxW, X, scale_data, bias_data, Y);
} else {
const std::array<int, 4> X_dims = {N, HxW, G, D};
const std::array<int, 4> Y_dims = {N, 1, G, 1};
math::Moments<T, Context>(
4, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
math::InvStd<T, Context>(
N * G, static_cast<T>(epsilon_), rsig, rsig, &context_);
ComputeFusedParams(N, G, D, mu, rsig, gamma, beta, scale_data, bias_data);
GroupNormForwardNHWC(N, C, HxW, X, scale_data, bias_data, Y);
}
const std::array<int, 2> X_dims = {N * G, K * HxW};
const std::array<int, 2> Y_dims = {N * G, 1};
math::Moments<T, Context>(
2, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
math::InvStd<T, Context>(
N * G, static_cast<T>(epsilon_), rsig, rsig, &context_);
ComputeFusedParams(N, G, K, mu, rsig, gamma, beta, scale_data, bias_data);
GroupNormForwardNCHW(N, C, HxW, X, scale_data, bias_data, Y);
return true;
}
bool RunOnDeviceWithOrderNHWC(
const int N,
const int G,
const int K,
const int HxW,
const T* X,
const T* gamma,
const T* beta,
T* Y,
T* mu,
T* rsig) {
const int C = G * K;
ReinitializeTensor(
&scale_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
ReinitializeTensor(
&bias_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
T* scale_data = scale_.template mutable_data<T>();
T* bias_data = bias_.template mutable_data<T>();
const std::array<int, 4> X_dims = {N, HxW, G, K};
const std::array<int, 4> Y_dims = {N, 1, G, 1};
math::Moments<T, Context>(
4, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
math::InvStd<T, Context>(
N * G, static_cast<T>(epsilon_), rsig, rsig, &context_);
ComputeFusedParams(N, G, K, mu, rsig, gamma, beta, scale_data, bias_data);
GroupNormForwardNHWC(N, C, HxW, X, scale_data, bias_data, Y);
return true;
}
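// math::Moments over the {N, HxW, G, K} view reduces the HxW and K axes,
// leaving one mean/variance pair per (n, g); math::InvStd then rewrites the
// variance in place as rsig = 1 / sqrt(var + epsilon). A loop-level sketch
// of that reduction (illustrative, not the math:: implementation):
//   for each (n, g):
//     mu[n * G + g]   = mean of X[n][hw][g][k] over all (hw, k)
//     var             = mean of X^2 over all (hw, k) - mu^2
//     rsig[n * G + g] = 1 / sqrt(var + epsilon)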
void ComputeFusedParams(
int N,
int G,
int K,
const T* mu,
const T* rsig,
const T* gamma,
const T* beta,
T* scale,
T* bias);
void GroupNormForwardNCHW(
const int N,
@@ -149,13 +167,7 @@ class GroupNormOp final : public Operator<Context> {
const T* X,
const T* scale,
const T* bias,
T* Y);
void GroupNormForwardNHWC(
const int N,
@@ -164,16 +176,7 @@ class GroupNormOp final : public Operator<Context> {
const T* X,
const T* scale,
const T* bias,
T* Y);
const int group_;
const float epsilon_;
@@ -223,32 +226,61 @@ class GroupNormGradientOp final : public Operator<Context> {
CAFFE_ENFORCE_EQ(gamma.numel(), C);
CAFFE_ENFORCE_EQ(beta.numel(), C);
const int G = group_;
const int D = C / G;
const int K = C / G;
auto* dX = Output(INPUT_GRAD, X.sizes(), at::dtype<T>());
auto* dgamma = Output(GAMMA_GRAD, gamma.sizes(), at::dtype<T>());
auto* dbeta = Output(BETA_GRAD, beta.sizes(), at::dtype<T>());
return RunOnDeviceImpl(
N,
G,
D,
HxW,
dY.template data<T>(),
X.template data<T>(),
mu.template data<T>(),
rsig.template data<T>(),
gamma.template data<T>(),
dX->template mutable_data<T>(),
dgamma->template mutable_data<T>(),
dbeta->template mutable_data<T>());
if (order_ == StorageOrder::NCHW) {
return RunOnDeviceWithOrderNCHW(
N,
G,
K,
HxW,
dY.template data<T>(),
X.template data<T>(),
mu.template data<T>(),
rsig.template data<T>(),
gamma.template data<T>(),
dX->template mutable_data<T>(),
dgamma->template mutable_data<T>(),
dbeta->template mutable_data<T>());
} else {
return RunOnDeviceWithOrderNHWC(
N,
G,
K,
HxW,
dY.template data<T>(),
X.template data<T>(),
mu.template data<T>(),
rsig.template data<T>(),
gamma.template data<T>(),
dX->template mutable_data<T>(),
dgamma->template mutable_data<T>(),
dbeta->template mutable_data<T>());
}
}
protected:
bool RunOnDeviceWithOrderNCHW(
int N,
int G,
int K,
int HxW,
const T* dY_data,
const T* X_data,
const T* mu_data,
const T* rsig_data,
const T* gamma_data,
T* dX_data,
T* dgamma_data,
T* dbeta_data);
bool RunOnDeviceWithOrderNHWC(
int N,
int G,
int K,
int HxW,
const T* dY_data,
const T* X_data,
const T* mu_data,
@@ -263,6 +295,10 @@ class GroupNormGradientOp final : public Operator<Context> {
Tensor ds_;
Tensor db_;
Tensor dY_scale_;
Tensor X_scale_;
Tensor bias_;
Tensor ones_;
// Input: dY, X, gamma, beta, mu, inv_sig
// Output: dX, dgamma, dbeta