Revert "[cuda] Add new faster gammabeta backward kernel (#148605)"
This reverts commit 114d404b07.
Reverted https://github.com/pytorch/pytorch/pull/148605 on behalf of https://github.com/drisspg due to See https://github.com/pytorch/pytorch/issues/150266#issuecomment-2773907902 for more details ([comment](https://github.com/pytorch/pytorch/pull/148605#issuecomment-2773928838))
This commit is contained in:
parent de15ef0ee8
commit 61a1f09b5b
@@ -508,6 +508,7 @@ __global__ void layer_norm_grad_input_kernel_vectorized(
   }
 }
 
+
 template <typename T, typename T_ACC>
 __global__ void GammaBetaBackwardSimpleCUDAKernel(
     int64_t M,
@@ -539,364 +540,191 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel(
   }
 }
 
-template <typename T, typename T_ACC,
-          unsigned int block_dim_x,
-          unsigned int block_dim_y,
-          unsigned int rows_per_block_y,
-          bool check_x,
-          bool check_y>
-__device__
-__forceinline__
-void
-blockReduceGammaBetaBackwardsHelper(
-    int64_t M_start,
-    int64_t M,
-    int64_t N,
-    const T* __restrict__ dY,
-    const T* __restrict__ X,
-    const T_ACC* __restrict__ mean,
-    const T_ACC* __restrict__ rstd,
-    T* __restrict__ dg,
-    T* __restrict__ db,
-    T_ACC &dg_sum,
-    T_ACC &db_sum
-    ) {
-  constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y;
-  int64_t thread_x = blockIdx.x * block_dim_x + threadIdx.x;
-
-  int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) & (kWarpSize - 1);
-  int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y;
-  T_ACC warp_mean = 0, warp_rstd = 0;
-  if (lane_id < rows_per_thread_y && mean_index + lane_id < M) {
-    warp_mean = mean[mean_index + lane_id];
-    warp_rstd = rstd[mean_index + lane_id];
-  }
-  // We do a WARP_SYNC() here because we use WARP_SHFL below to access
-  // warp_mean and warp_rstd.
-  WARP_SYNC();
-
-  T_ACC dY_regs[rows_per_thread_y] = {0};
-  T_ACC X_regs[rows_per_thread_y] = {0};
-  #pragma unroll
-  for (int i = 0; i < rows_per_thread_y; ++i) {
-    int64_t current_y = M_start + threadIdx.y * rows_per_thread_y + i;
-    bool active = true;
-    if (check_x && thread_x >= N) {
-      active = false;
-    }
-    if (check_y && current_y >= M) {
-      active = false;
-    }
-    if (active) {
-      dY_regs[i] = dY[current_y * N + thread_x];
-      X_regs[i] = X[current_y * N + thread_x];
-    }
-  }
-
-  #pragma unroll
-  for (int i = 0; i < rows_per_thread_y; ++i) {
-    T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize);
-    T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize);
-    dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg;
-    db_sum += dY_regs[i];
-  }
-}
-
-template <typename T, typename T_ACC,
-          unsigned int block_dim_x,
-          unsigned int block_dim_y,
-          unsigned int rows_per_block_y,
-          bool check_x,
-          bool check_y>
-__device__
-__forceinline__
-void
-blockReduceGammaBetaBackwardsWithChecks(
-    int64_t M,
-    int64_t N,
-    const T* __restrict__ dY,
-    const T* __restrict__ X,
-    const T_ACC* __restrict__ mean,
-    const T_ACC* __restrict__ rstd,
-    T* __restrict__ dg,
-    T* __restrict__ db,
-    T_ACC &dg_sum,
-    T_ACC &db_sum
-    ) {
-  for (int64_t M_start = blockIdx.y * rows_per_block_y;
-       M_start < M;
-       M_start += rows_per_block_y * gridDim.y) {
-    int64_t M_end = M_start + rows_per_block_y - 1;
-    if (!check_y || M_end < M) {
-      blockReduceGammaBetaBackwardsHelper<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, check_x, false>
-        (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
-    } else {
-      blockReduceGammaBetaBackwardsHelper<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, check_x, true>
-        (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
-    }
-  }
-}
-
-// block_dim_x is the number of threads in the x dimension per block.
-// block_dim_y is the number of threads in the y dimension per block.
-// rows_per_block_y is the size of the tile (number of data elements)
-// in the y dimension per block.
-// partial_reduction indicates whether we need to reduce across threads
-// or not. If set to true, we will not reduce across threads. This can
-// be faster in the M >> N case but requires another kernel to do a full
-// final reduction.
-// aligned_grid means the data size is a multiple of tile size. In that
-// case we don't need to check for boundary conditions which can provide
-// a further speedup by not needing instructions to check for edge cases
-// and not needing predicate registers.
-template <typename T, typename T_ACC,
-          unsigned int block_dim_x, unsigned int block_dim_y,
-          unsigned int rows_per_block_y,
-          bool partial_reduction,
-          bool aligned_grid
-          >
-__global__
-void
-GammaBetaBackwardCUDAKernelTemplate(
-    int64_t M,
-    int64_t N,
-    const T* __restrict__ dY,
-    const T* __restrict__ X,
-    const T_ACC* __restrict__ mean,
-    const T_ACC* __restrict__ rstd,
-    T* __restrict__ dg,
-    T* __restrict__ db) {
-  // This assert is a compile-time check only.
-  constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y;
-  static_assert(rows_per_thread_y <= kWarpSize);
-
-  T_ACC dg_sum = 0;
-  T_ACC db_sum = 0;
-
-  if (aligned_grid) {
-    // When N and M align perfectly with block_dim_x and block_dim_y, we
-    // can skip boundary condition checks that waste instruction issue slots.
-    blockReduceGammaBetaBackwardsWithChecks
-      <T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, false, false>
-      (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
-  } else {
-    // In the general case we need to check boundary conditions in the M
-    // dimension. However, we can still avoid boundary checks in the N dimension
-    // for the inner blocks. So try to avoid those checks when possible.
-    if (blockIdx.x * block_dim_x + block_dim_x - 1 < N) {
-      blockReduceGammaBetaBackwardsWithChecks
-        <T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, false, true>
-        (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
-    } else {
-      blockReduceGammaBetaBackwardsWithChecks
-        <T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, true, true>
-        (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
-    }
-  }
-
-  int64_t thread_x = ((int64_t)blockIdx.x) * block_dim_x + threadIdx.x;
-
-  // When partial_reduction is requested, we don't reduce within a block.
-  // We also don't reduce if we are only a single block in the y dimension.
-  if (partial_reduction || (blockDim.y == 1 && gridDim.y == 1)) {
-    if (aligned_grid || thread_x < N) {
-      int64_t thread_y = ((int64_t)blockIdx.y) * blockDim.y + threadIdx.y;
-      if (dg) {
-        dg[thread_y * N + thread_x] = dg_sum;
-      }
-      if (db) {
-        db[thread_y * N + thread_x] = db_sum;
-      }
-    }
-  } else {
-    // The caller requested a full reduction so we must reduce across
-    // warps using shared memory and warp shuffles.
-    static_assert(rows_per_thread_y <= C10_WARP_SIZE);
-    alignas(sizeof(double)) extern __shared__ char s_data1[];
-    T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
-    T_ACC* s_dg;
-    T_ACC* s_db;
-    int padded_bx = (block_dim_x + 1);
-    // Transpose dg and db.
-    s_dg = s_data_typed;
-    s_db = s_data_typed + (padded_bx * block_dim_y);
-    s_dg[threadIdx.y * padded_bx + threadIdx.x] = dg_sum;
-    s_db[threadIdx.y * padded_bx + threadIdx.x] = db_sum;
-    __syncthreads();
-
-    // Load transposed so that a warp holds an entire column
-    // Because block_dim_x != block_dim_y in the general case, we need
-    // some code to handle the general case.
-    static_assert(block_dim_x * block_dim_y % C10_WARP_SIZE == 0);
-    constexpr int warps_available_to_reduce = block_dim_x * block_dim_y / C10_WARP_SIZE;
-    int thread_id = threadIdx.y * block_dim_x + threadIdx.x;
-    int warp_id = thread_id / C10_WARP_SIZE;
-    int lane_id = thread_id & (C10_WARP_SIZE - 1);
-    #pragma unroll
-    for (int i = warp_id; i < block_dim_x; i += warps_available_to_reduce) {
-      T_ACC reg_db, reg_dg;
-      if (lane_id < block_dim_y) {
-        reg_dg = s_dg[lane_id * padded_bx + i];
-        reg_db = s_db[lane_id * padded_bx + i];
-      }
-      #pragma unroll
-      for (unsigned delta = block_dim_y >> 1; delta >= 1; delta >>= 1) {
-        reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize);
-        reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize);
-      }
-      // Reduce is done. Now write it out to global memory.
-      int64_t out_index = ((int64_t)blockIdx.x) * block_dim_x + i;
-      if (threadIdx.x == 0 && (aligned_grid || out_index < N)) {
-        if (dg) {
-          dg[out_index] = reg_dg;
-        }
-        if (db) {
-          db[out_index] = reg_db;
-        }
-      }
-    }
-  }
-}
-
-template<typename T, typename T_ACC,
-    int block_dim_x, int block_dim_y,
-    int rows_per_block_y,
-    bool partial_reduction>
-void LaunchAndCheckGammaBetaBackwardKernel(
-    bool aligned_grid,
-    dim3 blocks,
-    dim3 threads,
-    size_t shmem_sz,
-    cudaStream_t cuda_stream,
-    const T* dY_data,
-    const T* X_data,
-    const T_ACC* mean_data,
-    const T_ACC* rstd_data,
-    int64_t M,
-    int64_t N,
-    T* dgamma_data,
-    T* dbeta_data) {
-  if (aligned_grid) {
-    GammaBetaBackwardCUDAKernelTemplate<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, partial_reduction, true>
-        <<<blocks, threads, shmem_sz, cuda_stream>>>(
-            M,
-            N,
-            dY_data,
-            X_data,
-            mean_data,
-            rstd_data,
-            dgamma_data,
-            dbeta_data);
-  } else {
-    GammaBetaBackwardCUDAKernelTemplate<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, partial_reduction, false>
-        <<<blocks, threads, shmem_sz, cuda_stream>>>(
-            M,
-            N,
-            dY_data,
-            X_data,
-            mean_data,
-            rstd_data,
-            dgamma_data,
-            dbeta_data);
-  }
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
-}
-
-template<typename T, typename T_ACC,
-    int block_dim_x, int block_dim_y,
-    int rows_per_block_y>
-void ConfigureAndLaunchGammaBetaBackwardKernel(
-    const T* dY_data,
-    const T* X_data,
-    const T_ACC* mean_data,
-    const T_ACC* rstd_data,
-    int64_t M,
-    int64_t N,
-    Tensor* dgamma,
-    Tensor* dbeta,
-    cudaStream_t cuda_stream) {
-  T* dgamma_data =
-      dgamma->defined() ? dgamma->template data_ptr<T>() : nullptr;
-  T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T>() : nullptr;
-  bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0);
-  dim3 threads{block_dim_x, block_dim_y};
-  dim3 blocks;
-  blocks.x = (N + block_dim_x - 1) / block_dim_x;
-  blocks.y = 1;
-  size_t shmem_sz = (block_dim_x + 1) * block_dim_y * sizeof(T_ACC) * 2;
-  if (blocks.y == 1 && threads.y == 1) {
-    // Optimization: since there is just one thread doing all the summation, we don't need a reduction
-    // across threads. So we set partial_reduction to true.
-    LaunchAndCheckGammaBetaBackwardKernel<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, true>(
-      aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data);
-  } else {
-    LaunchAndCheckGammaBetaBackwardKernel<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, false>(
-      aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data);
-  }
-
-}
-
-template<typename T, typename T_ACC>
-void LaunchGammaBetaBackwardCUDAKernel(
-    const T* dY_data,
-    const T* X_data,
-    const T_ACC* mean_data,
-    const T_ACC* rstd_data,
-    int64_t M,
-    int64_t N,
-    Tensor* dgamma,
-    Tensor* dbeta,
-    cudaStream_t cuda_stream) {
-  constexpr int block_dim_x = 32;
-  const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
-  if (M > 64 * 1024 && N / block_dim_x < sm_count / 2) {
-    // We have a situation where M >> N and N is small.
-    // In this case we can speed up the computation by parallelizing in the M dimension.
-    // We launch multiple blocks in the y-dimension, and compute partial sums for the
-    // gradient in the first pass. Then we do a .sum(0) to do a final reduction.
-    // Although we launch 2 kernels, we can get up to a 10x speedup for large M.
-    constexpr int block_dim_y = 1;
-    constexpr int rows_per_block_y = 32;
-    bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0);
-    dim3 threads{block_dim_x, block_dim_y};
-    dim3 blocks;
-    blocks.x = (N + block_dim_x - 1) / block_dim_x;
-    // int rows_per_block = my_gamma_beta_unroll_factor *
-    blocks.y = (M + rows_per_block_y - 1) / rows_per_block_y;
-    constexpr int max_grid_size = 64 * 1024 / 2;
-    blocks.y = std::min<unsigned int>(max_grid_size / blocks.x, blocks.y);
-    Tensor dgamma_blocks;
-    Tensor dbeta_blocks;
-    T * dgamma_blocks_ptr = nullptr;
-    T * dbeta_blocks_ptr = nullptr;
-    if (dgamma->defined()) {
-      auto options = dgamma->options();
-      dgamma_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options);
-      dgamma_blocks_ptr = dgamma_blocks.data_ptr<T>();
-    }
-    if (dbeta->defined()) {
-      auto options = dbeta->options();
-      dbeta_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options);
-      dbeta_blocks_ptr = dbeta_blocks.data_ptr<T>();
-    }
-    LaunchAndCheckGammaBetaBackwardKernel<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, true>(
-      aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_blocks_ptr, dbeta_blocks_ptr);
-
-    *dgamma = dgamma_blocks.sum(0);
-    *dbeta = dbeta_blocks.sum(0);
-  } else {
-    // We are in the normal case where M is not that large.
-    // We can change the tile shape (which is the last template parameter) in accordance with M.
-    // For small M it is faster to have a smaller tile, otherwise we could have idle threads.
-    // For larger M we use a bigger tile size.
-    if (M < 64) {
-      ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 1, 8>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
-    } else if (M < 128) {
-      ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 8, 64>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
-    } else if (M < 256) {
-      ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 16, 128>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
-    } else {
-      ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 32, 256>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
-    }
+// This implementation gets called if M and N divide with 32. This case should
+// be the most common. We can then make better use of warp level intrinsics
+// to improve performance.
+
+template <typename T, typename T_ACC>
+__global__ void GammaBetaBackwardCUDAKernel_32x32(
+    int64_t M,
+    int64_t N,
+    const T* dY,
+    const T* X,
+    const T_ACC* mean,
+    const T_ACC* rstd,
+    T* dg,
+    T* db) {
+  alignas(sizeof(double)) extern __shared__ char s_data1[];
+  T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
+  T_ACC* s_dg;
+  T_ACC* s_db;
+
+  T_ACC dg_sum = 0;
+  T_ACC db_sum = 0;
+
+  const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
+
+  if (j < N) {
+    constexpr int unroll_factor = 8;
+    int laneId = threadIdx.x & (C10_WARP_SIZE - 1);
+
+    T_ACC mean_reg, mean_reg_tmp;
+    T_ACC rstd_reg, rstd_reg_tmp;
+    T dY_reg;
+    T X_reg;
+
+    // Main loop
+    int bcounter;
+    for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor);
+         bcounter++) {
+      int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
+
+      if (laneId < unroll_factor) {
+        mean_reg_tmp = mean[offset + laneId];
+        rstd_reg_tmp = rstd[offset + laneId];
+      }
+      WARP_SYNC();
+
+      #pragma unroll
+      for (int ii = 0; ii < unroll_factor; ++ii) {
+        dY_reg = dY[(offset + ii) * N + j];
+        X_reg = X[(offset + ii) * N + j];
+        mean_reg = WARP_SHFL(mean_reg_tmp, ii, kWarpSize);
+        rstd_reg = WARP_SHFL(rstd_reg_tmp, ii, kWarpSize);
+        dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
+        db_sum += dY_reg;
+      }
+    }
+
+    // Remainder loop
+    int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
+    for (int ii = 0; ii < unroll_factor; ii++) {
+      if ((offset + ii) < M) {
+        mean_reg = mean[offset + ii];
+        rstd_reg = rstd[offset + ii];
+        dY_reg = dY[(offset + ii) * N + j];
+        X_reg = X[(offset + ii) * N + j];
+        dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
+        db_sum += dY_reg;
+      }
+    }
+
+    // This kernel uses a block of (C10_WARP_SIZE x C10_WARP_SIZE) and
+    // gets called when M; N divide by 32. We can use warp shuffles
+    // for the final reduction step. This removes 4 shmem loads and
+    // stores with their corresponding __syncthreads()
+
+    // This greatly reduces bank conflicts at the expense of a little
+    // extra shared memory. It does not impact occupancy
+    int padded_bx = (1 + blockDim.x);
+
+    s_dg = s_data_typed;
+    s_db = s_data_typed + (padded_bx * blockDim.y);
+    s_dg[threadIdx.y * padded_bx + threadIdx.x] = dg_sum;
+    s_db[threadIdx.y * padded_bx + threadIdx.x] = db_sum;
+    __syncthreads();
+
+    // Load transposed so that a warp holds an entire column
+    T_ACC reg_dg = s_dg[threadIdx.x * padded_bx + threadIdx.y];
+    T_ACC reg_db = s_db[threadIdx.x * padded_bx + threadIdx.y];
+    for (unsigned delta = C10_WARP_SIZE >> 1; delta >= 1; delta >>= 1) {
+      reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize);
+      reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize);
+    }
+
+    if (threadIdx.x == 0) {
+      const int64_t j = blockIdx.x * blockDim.x + threadIdx.y;
+      if (dg) {
+        dg[j] = reg_dg;
+      }
+      if (db) {
+        db[j] = reg_db;
+      }
+    }
+  }
+}
+
+template <typename T, typename T_ACC>
+__global__ void GammaBetaBackwardCUDAKernel(
+    int64_t M,
+    int64_t N,
+    const T* dY,
+    const T* X,
+    const T_ACC* mean,
+    const T_ACC* rstd,
+    T* dg,
+    T* db) {
+  alignas(sizeof(double)) extern __shared__ char s_data1[];
+  T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
+  T_ACC* s_dg;
+  T_ACC* s_db;
+
+  const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
+
+  T_ACC dg_sum = 0;
+  T_ACC db_sum = 0;
+
+  if (j < N) {
+    constexpr int unroll_factor = 8;
+
+    T_ACC mean_reg;
+    T_ACC rstd_reg;
+    T dY_reg;
+    T X_reg;
+
+    // Main Loop
+    int bcounter;
+    for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); bcounter++){
+      int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
+
+      #pragma unroll
+      for (int ii = 0; ii < unroll_factor; ++ii) {
+        dY_reg = dY[(offset + ii) * N + j];
+        X_reg = X[(offset + ii) * N + j];
+        mean_reg = mean[offset + ii];
+        rstd_reg = rstd[offset + ii];
+        dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
+        db_sum += dY_reg;
+      }
+    }
+
+    // Remainder loop
+    int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
+    for (int ii = 0; ii < unroll_factor; ii++ ){
+      if ((offset + ii) < M) {
+        dY_reg = dY[(offset + ii) * N + j ];
+        X_reg = X[(offset + ii) * N + j];
+        mean_reg = mean[offset + ii];
+        rstd_reg = rstd[offset + ii];
+        dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
+        db_sum += dY_reg;
+      }
+    }
+
+    // Do the final reduction in shared memory
+    s_dg = s_data_typed;
+    s_db = s_data_typed + blockDim.x * blockDim.y;
+    s_dg[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum;
+    s_db[threadIdx.y * blockDim.x + threadIdx.x] = db_sum;
+    __syncthreads();
+
+    for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
+      if (threadIdx.y < offset) {
+        s_dg[threadIdx.y * blockDim.x + threadIdx.x] +=
+            s_dg[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
+        s_db[threadIdx.y * blockDim.x + threadIdx.x] +=
+            s_db[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
+      }
+      __syncthreads();
+    }
+
+    if (threadIdx.y == 0) {
+      if (dg) {
+        dg[j] = s_dg[threadIdx.x];
+      }
+      if (db) {
+        db[j] = s_db[threadIdx.x];
+      }
     }
   }
 }
 
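Both the reverted kernel and the restored GammaBetaBackwardCUDAKernel_32x32 finish their column reduction the same way the comments above describe: stage per-thread partial sums in shared memory with one element of padding per row, read them back transposed so that a single warp owns an entire column, and finish with warp shuffles. The sketch below is illustrative only and is not the PyTorch code; the kernel name column_sum_32x32, the fixed 32x32 block, and the float-only types are assumptions made for this example.

    // Sketch, not the PyTorch kernel: per-column sums of an M x N row-major
    // matrix using the padded transpose + warp-shuffle reduction idea.
    #include <cuda_runtime.h>

    __global__ void column_sum_32x32(const float* in, float* out, int M, int N) {
      constexpr int kWarp = 32;
      // Pad each tile row by one element so the transposed reads below land
      // in 32 different shared-memory banks instead of one.
      __shared__ float tile[kWarp][kWarp + 1];

      const int col = blockIdx.x * kWarp + threadIdx.x;
      float partial = 0.f;
      // Each thread strides down its column; a block covers 32 columns.
      for (int row = threadIdx.y; row < M; row += kWarp) {
        if (col < N) {
          partial += in[(long long)row * N + col];
        }
      }
      tile[threadIdx.y][threadIdx.x] = partial;
      __syncthreads();

      // Read the tile transposed: one warp now holds the 32 partial sums of a
      // single column, so a shuffle butterfly finishes the reduction without
      // another round trip through shared memory.
      float v = tile[threadIdx.x][threadIdx.y];
      for (int delta = kWarp / 2; delta >= 1; delta /= 2) {
        v += __shfl_xor_sync(0xffffffff, v, delta);
      }
      if (threadIdx.x == 0) {
        const int out_col = blockIdx.x * kWarp + threadIdx.y;
        if (out_col < N) {
          out[out_col] = v;
        }
      }
    }

Launched with a 32x32 block per group of 32 columns, for example column_sum_32x32<<<(N + 31) / 32, dim3(32, 32)>>>(in, out, M, N), each block writes 32 column sums; the +1 padding changes the shared-memory row stride from 32 to 33, which is what the "reduces bank conflicts at the expense of a little extra shared memory" comment refers to.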
@@ -1422,7 +1250,6 @@ void LayerNormBackwardKernelImplInternal(
         dgamma->defined() ? dgamma->template data_ptr<T>() : nullptr;
     T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T>() : nullptr;
 
-#if defined(USE_ROCM)
     if (M < 128) {
       // For small batch size, do colwise reduce directly.
       const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads;
@@ -1438,6 +1265,7 @@ void LayerNormBackwardKernelImplInternal(
           dbeta_data);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     } else {
+#if defined(USE_ROCM)
       // For small batch size, do colwise reduce directly.
       const int part_size = warp_size;
       const dim3 threads2(warp_size, 4, 1);
@@ -1472,11 +1300,47 @@ void LayerNormBackwardKernelImplInternal(
           dgamma_data,
           dbeta_data);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
-    }
 #else
-    LaunchGammaBetaBackwardCUDAKernel(
-        dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
+      if ((M % kWarpSize == 0) && (N % kWarpSize == 0)) {
+        // This implementation relies on warp primitives and requires that M and N divide
+        // exactly to warp size.
+        dim3 threads{kWarpSize, kWarpSize};
+        int blocks = (N + threads.x - 1) / threads.x;
+
+        // If M and N divide by warp_size, we can use warp shuffles for the final reduction.
+        // That requires transposing values in shared memory, so we apply a padding to
+        // reduce bank conflicts.
+
+        size_t shmem_sz = 2 * sizeof(T_ACC) * (threads.x + 1) * threads.y;
+        GammaBetaBackwardCUDAKernel_32x32<T, T_ACC>
+            <<<blocks, threads, shmem_sz, cuda_stream>>>(
+                M,
+                N,
+                dY_data,
+                X_data,
+                mean_data,
+                rstd_data,
+                dgamma_data,
+                dbeta_data);
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
+      } else {
+        dim3 threads{16, 32};
+        int blocks = (N + threads.x - 1) / threads.x;
+        size_t shmem_sz = 2 * sizeof(T_ACC) * threads.x * threads.y;
+        GammaBetaBackwardCUDAKernel<T, T_ACC>
+            <<<blocks, threads, shmem_sz, cuda_stream>>>(
+                M,
+                N,
+                dY_data,
+                X_data,
+                mean_data,
+                rstd_data,
+                dgamma_data,
+                dbeta_data);
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
+      }
 #endif
+    }
   }
 }
 
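The comments in the removed LaunchGammaBetaBackwardCUDAKernel describe a two-pass strategy for the M >> N case: many y-blocks each write one row of partial column sums, and a second pass (the dgamma_blocks.sum(0) / dbeta_blocks.sum(0) calls in the removed host code) folds those rows into the final gradient. Below is a minimal standalone sketch of that idea, not the PyTorch implementation; the kernel names are invented, it uses grid-stride row ownership instead of fixed rows_per_block_y tiles, and it is float-only.

    // Sketch of the two-pass partial reduction for tall, narrow inputs.
    #include <cuda_runtime.h>

    __global__ void partial_column_sums(const float* in, float* partials,
                                        long long M, long long N) {
      const long long col = (long long)blockIdx.x * blockDim.x + threadIdx.x;
      if (col >= N) return;
      float acc = 0.f;
      // Each y-block owns a disjoint, strided set of rows, so no atomics are
      // needed: every block writes its own row of the partials buffer.
      for (long long row = blockIdx.y; row < M; row += gridDim.y) {
        acc += in[row * N + col];
      }
      partials[(long long)blockIdx.y * N + col] = acc;
    }

    __global__ void finalize_column_sums(const float* partials, float* out,
                                         int num_partials, long long N) {
      const long long col = (long long)blockIdx.x * blockDim.x + threadIdx.x;
      if (col >= N) return;
      float acc = 0.f;
      // Sum the small (num_partials x N) buffer; this plays the role of the
      // .sum(0) reduction mentioned in the removed comments.
      for (int i = 0; i < num_partials; ++i) {
        acc += partials[(long long)i * N + col];
      }
      out[col] = acc;
    }

Pass 1 parallelizes over M, which is where the work is when M >> N; pass 2 touches only gridDim.y rows of N columns, so its cost is negligible even though a second kernel launch is required.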
@@ -7195,26 +7195,6 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
         ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False)
         self.assertEqual(ln.forward(x), torch.zeros_like(x))
 
-    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
-    def test_layer_norm_backwards_eps(self):
-        dtype = torch.float
-        m_x_n_list = [(3, 3), (5, 5), (11, 11), (55, 55),
-                      (32, 32), (1024, 32), (1024, 1024),
-                      (33, 33), (1025, 33), (1025, 1025)]
-        for m, n in m_x_n_list:
-            x = torch.randn((m, n), dtype=dtype, requires_grad=True)
-            grad_output = torch.rand_like(x)
-            x_cuda = x.clone().detach().to("cuda").requires_grad_()
-            grad_output_cuda = grad_output.clone().detach().to("cuda")
-            ln = nn.LayerNorm(n, dtype=dtype)
-            ln_cuda = nn.LayerNorm(n, device="cuda", dtype=dtype)
-            ln_out = ln(x)
-            ln_out_cuda = ln_cuda(x_cuda)
-            ln_out.backward(grad_output)
-            ln_out_cuda.backward(grad_output_cuda)
-            self.assertEqual(ln.weight.grad, ln_cuda.weight.grad, f"weight grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4)
-            self.assertEqual(ln.bias.grad, ln_cuda.bias.grad, f"bias grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4)
-
     @largeTensorTest("40GB", device="cuda")
     def test_layer_norm_large_tensor(self):
         # test for https://github.com/pytorch/pytorch/issues/136291