Revert "[cuda] Add new faster gammabeta backward kernel (#148605)"
This reverts commit 114d404b07.
Reverted https://github.com/pytorch/pytorch/pull/148605 on behalf of https://github.com/drisspg due to See https://github.com/pytorch/pytorch/issues/150266#issuecomment-2773907902 for more details ([comment](https://github.com/pytorch/pytorch/pull/148605#issuecomment-2773928838))
This commit is contained in:
parent
de15ef0ee8
commit
61a1f09b5b
@@ -508,6 +508,7 @@ __global__ void layer_norm_grad_input_kernel_vectorized(
}
}

template <typename T, typename T_ACC>
__global__ void GammaBetaBackwardSimpleCUDAKernel(
int64_t M,
@@ -539,364 +540,191 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel(
}
}

template <typename T, typename T_ACC,
unsigned int block_dim_x,
unsigned int block_dim_y,
unsigned int rows_per_block_y,
bool check_x,
bool check_y>
__device__
__forceinline__
void
blockReduceGammaBetaBackwardsHelper(
int64_t M_start,
// This implementation gets called if M and N divide with 32. This case should
// be the most common. We can then make better use of warp level intrinsics
// to improve performance.

template <typename T, typename T_ACC>
__global__ void GammaBetaBackwardCUDAKernel_32x32(
int64_t M,
int64_t N,
const T* __restrict__ dY,
const T* __restrict__ X,
const T_ACC* __restrict__ mean,
const T_ACC* __restrict__ rstd,
T* __restrict__ dg,
T* __restrict__ db,
T_ACC &dg_sum,
T_ACC &db_sum
) {
constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y;
int64_t thread_x = blockIdx.x * block_dim_x + threadIdx.x;

int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) & (kWarpSize - 1);
int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y;
T_ACC warp_mean = 0, warp_rstd = 0;
if (lane_id < rows_per_thread_y && mean_index + lane_id < M) {
warp_mean = mean[mean_index + lane_id];
warp_rstd = rstd[mean_index + lane_id];
}
// We do a WARP_SYNC() here because we use WARP_SHFL below to access
// warp_mean and warp_rstd.
WARP_SYNC();

T_ACC dY_regs[rows_per_thread_y] = {0};
T_ACC X_regs[rows_per_thread_y] = {0};
#pragma unroll
for (int i = 0; i < rows_per_thread_y; ++i) {
int64_t current_y = M_start + threadIdx.y * rows_per_thread_y + i;
bool active = true;
if (check_x && thread_x >= N) {
active = false;
}
if (check_y && current_y >= M) {
active = false;
}
if (active) {
dY_regs[i] = dY[current_y * N + thread_x];
X_regs[i] = X[current_y * N + thread_x];
}
}

#pragma unroll
for (int i = 0; i < rows_per_thread_y; ++i) {
T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize);
T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize);
dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg;
db_sum += dY_regs[i];
}
}
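The helper above broadcasts each row's mean/rstd from the lane that loaded it to the whole warp via WARP_SHFL, after a WARP_SYNC() barrier; on CUDA these PyTorch wrappers correspond to __shfl_sync and __syncwarp. Below is a minimal, self-contained sketch of that broadcast-and-accumulate pattern. It is illustrative only: the names (broadcast_row_stats, kRowsPerThreadY) are hypothetical and not part of this commit.

// --- illustrative sketch (not part of this commit) ---
// Each of the first kRowsPerThreadY lanes owns one row statistic; __shfl_sync
// broadcasts it to the whole warp so every lane can accumulate against every row.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kRowsPerThreadY = 8;        // stands in for rows_per_thread_y
constexpr unsigned kFullMask = 0xffffffffu;

__global__ void broadcast_row_stats(const float* row_stats, float* out) {
  int lane = threadIdx.x & 31;
  float my_stat = (lane < kRowsPerThreadY) ? row_stats[lane] : 0.f;
  __syncwarp();                           // analogous to WARP_SYNC()

  float acc = 0.f;
  #pragma unroll
  for (int i = 0; i < kRowsPerThreadY; ++i) {
    // Every lane receives row i's statistic without touching shared memory.
    acc += __shfl_sync(kFullMask, my_stat, i, 32);
  }
  out[threadIdx.x] = acc;                 // all 32 lanes hold the same sum
}

int main() {
  float h_stats[kRowsPerThreadY] = {1, 2, 3, 4, 5, 6, 7, 8};
  float *d_stats, *d_out;
  cudaMalloc(&d_stats, sizeof(h_stats));
  cudaMalloc(&d_out, 32 * sizeof(float));
  cudaMemcpy(d_stats, h_stats, sizeof(h_stats), cudaMemcpyHostToDevice);
  broadcast_row_stats<<<1, 32>>>(d_stats, d_out);
  float h_out[32];
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("lane 0 sum = %.1f (expected 36.0)\n", h_out[0]);
  cudaFree(d_stats);
  cudaFree(d_out);
  return 0;
}
// --- end of sketch ---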

template <typename T, typename T_ACC,
unsigned int block_dim_x,
unsigned int block_dim_y,
unsigned int rows_per_block_y,
bool check_x,
bool check_y>
__device__
__forceinline__
void
blockReduceGammaBetaBackwardsWithChecks(
int64_t M,
int64_t N,
const T* __restrict__ dY,
const T* __restrict__ X,
const T_ACC* __restrict__ mean,
const T_ACC* __restrict__ rstd,
T* __restrict__ dg,
T* __restrict__ db,
T_ACC &dg_sum,
T_ACC &db_sum
) {
for (int64_t M_start = blockIdx.y * rows_per_block_y;
M_start < M;
M_start += rows_per_block_y * gridDim.y) {
int64_t M_end = M_start + rows_per_block_y - 1;
if (!check_y || M_end < M) {
blockReduceGammaBetaBackwardsHelper<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, check_x, false>
(M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
} else {
blockReduceGammaBetaBackwardsHelper<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, check_x, true>
(M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
}
}
}
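blockReduceGammaBetaBackwardsWithChecks walks the M dimension in rows_per_block_y-sized tiles with a grid-stride loop over blockIdx.y, and only a tile that can run past the end of the data pays for the check_y bounds test. Below is a stripped-down sketch of that pattern with a plain column sum standing in for the LayerNorm gradient math; sum_rows, ROWS_PER_BLOCK and CHECK_Y are hypothetical names, not code from the diff.

// --- illustrative sketch (not part of this commit) ---
#include <cstdio>
#include <cuda_runtime.h>

// A tile of ROWS_PER_BLOCK rows is processed per iteration; CHECK_Y mirrors
// check_y: fully covered tiles compile without the per-row bounds test.
template <int ROWS_PER_BLOCK, bool CHECK_Y>
__global__ void sum_rows(const float* data, int64_t M, int64_t N, float* out) {
  int64_t col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= N) return;
  float acc = 0.f;
  for (int64_t m_start = blockIdx.y * ROWS_PER_BLOCK; m_start < M;
       m_start += (int64_t)ROWS_PER_BLOCK * gridDim.y) {
    for (int r = 0; r < ROWS_PER_BLOCK; ++r) {
      int64_t row = m_start + r;
      if (CHECK_Y && row >= M) break;   // only needed in the edge tiles
      acc += data[row * N + col];
    }
  }
  atomicAdd(&out[col], acc);            // several y-blocks may add to one column
}

int main() {
  const int64_t M = 100, N = 64;        // M not a multiple of the tile height
  float *d_in, *d_out;
  cudaMalloc(&d_in, M * N * sizeof(float));
  cudaMalloc(&d_out, N * sizeof(float));
  cudaMemset(d_in, 0, M * N * sizeof(float));
  cudaMemset(d_out, 0, N * sizeof(float));
  dim3 threads(32, 1), blocks((N + 31) / 32, 4);
  sum_rows<32, true><<<blocks, threads>>>(d_in, M, N, d_out);
  cudaDeviceSynchronize();
  printf("launch ok: %d\n", cudaGetLastError() == cudaSuccess);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
// --- end of sketch ---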

// block_dim_x is the number of threads in the x dimension per block.
// block_dim_y is the number of threads in the y dimension per block.
// rows_per_block_y is the size of the tile (number of data elements)
// in the y dimension per block.
// partial_reduction indicates whether we need to reduce across threads
// or not. If set to true, we will not reduce across threads. This can
// be faster in the M >> N case but requires another kernel to do a full
// final reduction.
// aligned_grid means the data size is a multiple of tile size. In that
// case we don't need to check for boundary conditions which can provide
// a further speedup by not needing instructions to check for edge cases
// and not needing predicate registers.
template <typename T, typename T_ACC,
unsigned int block_dim_x, unsigned int block_dim_y,
unsigned int rows_per_block_y,
bool partial_reduction,
bool aligned_grid
>
__global__
void
GammaBetaBackwardCUDAKernelTemplate(
int64_t M,
int64_t N,
const T* __restrict__ dY,
const T* __restrict__ X,
const T_ACC* __restrict__ mean,
const T_ACC* __restrict__ rstd,
T* __restrict__ dg,
T* __restrict__ db) {
// This assert is a compile-time check only.
constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y;
static_assert(rows_per_thread_y <= kWarpSize);
const T* dY,
const T* X,
const T_ACC* mean,
const T_ACC* rstd,
T* dg,
T* db) {
alignas(sizeof(double)) extern __shared__ char s_data1[];
T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
T_ACC* s_dg;
T_ACC* s_db;

T_ACC dg_sum = 0;
T_ACC db_sum = 0;

if (aligned_grid) {
// When N and M align perfectly with block_dim_x and block_dim_y, we
// can skip boundary condition checks that waste instruction issue slots.
blockReduceGammaBetaBackwardsWithChecks
<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, false, false>
(M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
} else {
// In the general case we need to check boundary conditions in the M
// dimension. However, we can still avoid boundary checks in the N dimension
// for the inner blocks. So try to avoid those checks when possible.
if (blockIdx.x * block_dim_x + block_dim_x - 1 < N) {
blockReduceGammaBetaBackwardsWithChecks
<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, false, true>
(M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
} else {
blockReduceGammaBetaBackwardsWithChecks
<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, true, true>
(M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
}
}
const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;

int64_t thread_x = ((int64_t)blockIdx.x) * block_dim_x + threadIdx.x;
if (j < N) {
constexpr int unroll_factor = 8;
int laneId = threadIdx.x & (C10_WARP_SIZE - 1);

// When partial_reduction is requested, we don't reduce within a block.
// We also don't reduce if we are only a single block in the y dimension.
if (partial_reduction || (blockDim.y == 1 && gridDim.y == 1)) {
if (aligned_grid || thread_x < N) {
int64_t thread_y = ((int64_t)blockIdx.y) * blockDim.y + threadIdx.y;
if (dg) {
dg[thread_y * N + thread_x] = dg_sum;
T_ACC mean_reg, mean_reg_tmp;
T_ACC rstd_reg, rstd_reg_tmp;
T dY_reg;
T X_reg;

// Main loop
int bcounter;
for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor);
bcounter++) {
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;

if (laneId < unroll_factor) {
mean_reg_tmp = mean[offset + laneId];
rstd_reg_tmp = rstd[offset + laneId];
}
if (db) {
db[thread_y * N + thread_x] = db_sum;
WARP_SYNC();

#pragma unroll
for (int ii = 0; ii < unroll_factor; ++ii) {
dY_reg = dY[(offset + ii) * N + j];
X_reg = X[(offset + ii) * N + j];
mean_reg = WARP_SHFL(mean_reg_tmp, ii, kWarpSize);
rstd_reg = WARP_SHFL(rstd_reg_tmp, ii, kWarpSize);
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
db_sum += dY_reg;
}
}
} else {
// The caller requested a full reduction so we must reduce across
// warps using shared memory and warp shuffles.
static_assert(rows_per_thread_y <= C10_WARP_SIZE);
alignas(sizeof(double)) extern __shared__ char s_data1[];
T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
T_ACC* s_dg;
T_ACC* s_db;
int padded_bx = (block_dim_x + 1);
// Transpose dg and db.

// Remainder loop
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
for (int ii = 0; ii < unroll_factor; ii++) {
if ((offset + ii) < M) {
mean_reg = mean[offset + ii];
rstd_reg = rstd[offset + ii];
dY_reg = dY[(offset + ii) * N + j];
X_reg = X[(offset + ii) * N + j];
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
db_sum += dY_reg;
}
}

// This kernel uses a block of (C10_WARP_SIZE x C10_WARP_SIZE) and
// gets called when M; N divide by 32. We can use warp shuffles
// for the final reduction step. This removes 4 shmem loads and
// stores with their corresponding __syncthreads()

// This greatly reduces bank conflicts at the expense of a little
// extra shared memory. It does not impact occupancy
int padded_bx = (1 + blockDim.x);

s_dg = s_data_typed;
s_db = s_data_typed + (padded_bx * block_dim_y);
s_db = s_data_typed + (padded_bx * blockDim.y);
s_dg[threadIdx.y * padded_bx + threadIdx.x] = dg_sum;
s_db[threadIdx.y * padded_bx + threadIdx.x] = db_sum;
__syncthreads();

// Load transposed so that a warp holds an entire column
// Because block_dim_x != block_dim_y in the general case, we need
// some code to handle the general case.
static_assert(block_dim_x * block_dim_y % C10_WARP_SIZE == 0);
constexpr int warps_available_to_reduce = block_dim_x * block_dim_y / C10_WARP_SIZE;
int thread_id = threadIdx.y * block_dim_x + threadIdx.x;
int warp_id = thread_id / C10_WARP_SIZE;
int lane_id = thread_id & (C10_WARP_SIZE - 1);
#pragma unroll
for (int i = warp_id; i < block_dim_x; i += warps_available_to_reduce) {
T_ACC reg_db, reg_dg;
if (lane_id < block_dim_y) {
reg_dg = s_dg[lane_id * padded_bx + i];
reg_db = s_db[lane_id * padded_bx + i];
T_ACC reg_dg = s_dg[threadIdx.x * padded_bx + threadIdx.y];
T_ACC reg_db = s_db[threadIdx.x * padded_bx + threadIdx.y];
for (unsigned delta = C10_WARP_SIZE >> 1; delta >= 1; delta >>= 1) {
reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize);
reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize);
}

if (threadIdx.x == 0) {
const int64_t j = blockIdx.x * blockDim.x + threadIdx.y;
if (dg) {
dg[j] = reg_dg;
}
if (db) {
db[j] = reg_db;
}
}
}
}
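The full-reduction path above stages per-thread partial sums in shared memory and reads them back transposed, padding each row by one element (padded_bx = block_dim_x + 1) so the strided column reads do not all land in the same shared-memory bank. A minimal sketch of that padding trick on a 32x32 float tile follows; transpose_tile and kTile are hypothetical names used only for illustration.

// --- illustrative sketch (not part of this commit) ---
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kTile = 32;

__global__ void transpose_tile(const float* in, float* out) {
  __shared__ float tile[kTile][kTile + 1];  // the "+ 1" is the padding
  int x = threadIdx.x, y = threadIdx.y;
  tile[y][x] = in[y * kTile + x];           // coalesced row write
  __syncthreads();
  // Column read per warp: without the padding, tile[x][y] for x = 0..31 has a
  // stride of 32 floats and every lane hits the same bank; the extra column
  // staggers the addresses and removes the 32-way conflict.
  out[y * kTile + x] = tile[x][y];
}

int main() {
  const int n = kTile * kTile;
  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemset(d_in, 0, n * sizeof(float));
  transpose_tile<<<1, dim3(kTile, kTile)>>>(d_in, d_out);
  cudaDeviceSynchronize();
  printf("launch ok: %d\n", cudaGetLastError() == cudaSuccess);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
// --- end of sketch ---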

template <typename T, typename T_ACC>
__global__ void GammaBetaBackwardCUDAKernel(
int64_t M,
int64_t N,
const T* dY,
const T* X,
const T_ACC* mean,
const T_ACC* rstd,
T* dg,
T* db) {
alignas(sizeof(double)) extern __shared__ char s_data1[];
T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
T_ACC* s_dg;
T_ACC* s_db;

const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;

T_ACC dg_sum = 0;
T_ACC db_sum = 0;

if (j < N) {
constexpr int unroll_factor = 8;

T_ACC mean_reg;
T_ACC rstd_reg;
T dY_reg;
T X_reg;

// Main Loop
int bcounter;
for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); bcounter++){
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;

#pragma unroll
for (unsigned delta = block_dim_y >> 1; delta >= 1; delta >>= 1) {
reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize);
reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize);
}
// Reduce is done. Now write it out to global memory.
int64_t out_index = ((int64_t)blockIdx.x) * block_dim_x + i;
if (threadIdx.x == 0 && (aligned_grid || out_index < N)) {
if (dg) {
dg[out_index] = reg_dg;
}
if (db) {
db[out_index] = reg_db;
}
for (int ii = 0; ii < unroll_factor; ++ii) {
dY_reg = dY[(offset + ii) * N + j];
X_reg = X[(offset + ii) * N + j];
mean_reg = mean[offset + ii];
rstd_reg = rstd[offset + ii];
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
db_sum += dY_reg;
}
}
}
}

template<typename T, typename T_ACC,
int block_dim_x, int block_dim_y,
int rows_per_block_y,
bool partial_reduction>
void LaunchAndCheckGammaBetaBackwardKernel(
bool aligned_grid,
dim3 blocks,
dim3 threads,
size_t shmem_sz,
cudaStream_t cuda_stream,
const T* dY_data,
const T* X_data,
const T_ACC* mean_data,
const T_ACC* rstd_data,
int64_t M,
int64_t N,
T* dgamma_data,
T* dbeta_data) {
if (aligned_grid) {
GammaBetaBackwardCUDAKernelTemplate<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, partial_reduction, true>
<<<blocks, threads, shmem_sz, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
} else {
GammaBetaBackwardCUDAKernelTemplate<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, partial_reduction, false>
<<<blocks, threads, shmem_sz, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
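LaunchAndCheckGammaBetaBackwardKernel funnels every template instantiation through a single <<<...>>> launch followed by C10_CUDA_KERNEL_LAUNCH_CHECK(), so configuration errors (bad grid, block or dynamic shared-memory size) surface at the call site. A plain-CUDA sketch of the same idea using cudaGetLastError directly is below; the macro and kernel names are hypothetical and this is not the c10 implementation.

// --- illustrative sketch (not part of this commit) ---
#include <cstdio>
#include <cuda_runtime.h>

__global__ void noop_kernel() {}

// Checks the launch itself; configuration errors are reported immediately
// by cudaGetLastError, similar in spirit to C10_CUDA_KERNEL_LAUNCH_CHECK().
#define CHECK_LAUNCH()                                              \
  do {                                                              \
    cudaError_t err = cudaGetLastError();                           \
    if (err != cudaSuccess) {                                       \
      fprintf(stderr, "kernel launch failed: %s\n",                 \
              cudaGetErrorString(err));                             \
    }                                                               \
  } while (0)

int main() {
  noop_kernel<<<1, 128>>>();
  CHECK_LAUNCH();                 // fine

  noop_kernel<<<1, 4096>>>();     // exceeds the per-block thread limit
  CHECK_LAUNCH();                 // reports "invalid configuration argument"

  cudaDeviceSynchronize();
  return 0;
}
// --- end of sketch ---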

template<typename T, typename T_ACC,
int block_dim_x, int block_dim_y,
int rows_per_block_y>
void ConfigureAndLaunchGammaBetaBackwardKernel(
const T* dY_data,
const T* X_data,
const T_ACC* mean_data,
const T_ACC* rstd_data,
int64_t M,
int64_t N,
Tensor* dgamma,
Tensor* dbeta,
cudaStream_t cuda_stream) {
T* dgamma_data =
dgamma->defined() ? dgamma->template data_ptr<T>() : nullptr;
T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T>() : nullptr;
bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0);
dim3 threads{block_dim_x, block_dim_y};
dim3 blocks;
blocks.x = (N + block_dim_x - 1) / block_dim_x;
blocks.y = 1;
size_t shmem_sz = (block_dim_x + 1) * block_dim_y * sizeof(T_ACC) * 2;
if (blocks.y == 1 && threads.y == 1) {
// Optimization: since there is just one thread doing all the summation, we don't need a reduction
// across threads. So we set partial_reduction to true.
LaunchAndCheckGammaBetaBackwardKernel<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, true>(
aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data);
} else {
LaunchAndCheckGammaBetaBackwardKernel<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, false>(
aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data);
}

}
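A quick worked example of the sizing arithmetic above, assuming T_ACC = float: with M = 300, N = 1000, block_dim_x = 32, block_dim_y = 32 and rows_per_block_y = 256, the grid is not aligned (300 % 256 != 0 and 1000 % 32 != 0), blocks.x = ceil(1000 / 32) = 32, and the padded shared-memory footprint is (32 + 1) * 32 * sizeof(float) * 2 = 8448 bytes. The host-side snippet below just recomputes those numbers; the variable names mirror the code but it is not part of the diff.

// --- illustrative sketch (not part of this commit) ---
#include <cstdio>
#include <cstdint>

int main() {
  const int64_t M = 300, N = 1000;
  const int block_dim_x = 32, block_dim_y = 32, rows_per_block_y = 256;

  bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0);
  int64_t blocks_x = (N + block_dim_x - 1) / block_dim_x;
  size_t shmem_sz = (block_dim_x + 1) * block_dim_y * sizeof(float) * 2;

  printf("aligned_grid=%d blocks.x=%lld shmem=%zu bytes\n",
         aligned_grid, (long long)blocks_x, shmem_sz);
  // Prints: aligned_grid=0 blocks.x=32 shmem=8448 bytes
  return 0;
}
// --- end of sketch ---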

template<typename T, typename T_ACC>
void LaunchGammaBetaBackwardCUDAKernel(
const T* dY_data,
const T* X_data,
const T_ACC* mean_data,
const T_ACC* rstd_data,
int64_t M,
int64_t N,
Tensor* dgamma,
Tensor* dbeta,
cudaStream_t cuda_stream) {
constexpr int block_dim_x = 32;
const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
if (M > 64 * 1024 && N / block_dim_x < sm_count / 2) {
// We have a situation where M >> N and N is small.
// In this case we can speed up the computation by parallelizing in the M dimension.
// We launch multiple blocks in the y-dimension, and compute partial sums for the
// gradient in the first pass. Then we do a .sum(0) to do a final reduction.
// Although we launch 2 kernels, we can get up to a 10x speedup for large M.
constexpr int block_dim_y = 1;
constexpr int rows_per_block_y = 32;
bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0);
dim3 threads{block_dim_x, block_dim_y};
dim3 blocks;
blocks.x = (N + block_dim_x - 1) / block_dim_x;
// int rows_per_block = my_gamma_beta_unroll_factor *
blocks.y = (M + rows_per_block_y - 1) / rows_per_block_y;
constexpr int max_grid_size = 64 * 1024 / 2;
blocks.y = std::min<unsigned int>(max_grid_size / blocks.x, blocks.y);
Tensor dgamma_blocks;
Tensor dbeta_blocks;
T * dgamma_blocks_ptr = nullptr;
T * dbeta_blocks_ptr = nullptr;
if (dgamma->defined()) {
auto options = dgamma->options();
dgamma_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options);
dgamma_blocks_ptr = dgamma_blocks.data_ptr<T>();
// Remainder loop
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
for (int ii = 0; ii < unroll_factor; ii++ ){
if ((offset + ii) < M) {
dY_reg = dY[(offset + ii) * N + j ];
X_reg = X[(offset + ii) * N + j];
mean_reg = mean[offset + ii];
rstd_reg = rstd[offset + ii];
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
db_sum += dY_reg;
}
}
if (dbeta->defined()) {
auto options = dbeta->options();
dbeta_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options);
dbeta_blocks_ptr = dbeta_blocks.data_ptr<T>();
}
LaunchAndCheckGammaBetaBackwardKernel<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, true>(
aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_blocks_ptr, dbeta_blocks_ptr);

*dgamma = dgamma_blocks.sum(0);
*dbeta = dbeta_blocks.sum(0);
} else {
// We are in the normal case where M is not that large.
// We can change the tile shape (which is the last template parameter) in accordance with M.
// For small M it is faster to have a smaller tile, otherwise we could have idle threads.
// For larger M we use a bigger tile size.
if (M < 64) {
ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 1, 8>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
} else if (M < 128) {
ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 8, 64>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
} else if (M < 256) {
ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 16, 128>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
} else {
ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 32, 256>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
// Do the final reduction in shared memory
s_dg = s_data_typed;
s_db = s_data_typed + blockDim.x * blockDim.y;
s_dg[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum;
s_db[threadIdx.y * blockDim.x + threadIdx.x] = db_sum;
__syncthreads();

for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
if (threadIdx.y < offset) {
s_dg[threadIdx.y * blockDim.x + threadIdx.x] +=
s_dg[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
s_db[threadIdx.y * blockDim.x + threadIdx.x] +=
s_db[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
}
__syncthreads();
}

if (threadIdx.y == 0) {
if (dg) {
dg[j] = s_dg[threadIdx.x];
}
if (db) {
db[j] = s_db[threadIdx.x];
}
}
}
}
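For the M >> N branch above, the launcher requests partial_reduction, writes one partial gradient row per blockIdx.y strip into dgamma_blocks/dbeta_blocks, and finishes with a .sum(0) on the ATen side. Below is a self-contained two-pass sketch of that strategy, with a plain column sum standing in for the gamma/beta gradients; colsum_partial and colsum_final are hypothetical names, not code from this commit.

// --- illustrative sketch (not part of this commit) ---
#include <cstdio>
#include <cuda_runtime.h>

// Pass 1: each blockIdx.y strip reduces its own rows into partial[blockIdx.y][col].
__global__ void colsum_partial(const float* x, int64_t M, int64_t N,
                               int rows_per_block, float* partial) {
  int64_t col = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= N) return;
  int64_t row0 = (int64_t)blockIdx.y * rows_per_block;
  int64_t row_end = row0 + rows_per_block;
  if (row_end > M) row_end = M;
  float acc = 0.f;
  for (int64_t r = row0; r < row_end; ++r) acc += x[r * N + col];
  partial[(int64_t)blockIdx.y * N + col] = acc;
}

// Pass 2: fold the (num_strips x N) partials into the final N-vector,
// playing the role of dgamma_blocks.sum(0).
__global__ void colsum_final(const float* partial, int64_t num_strips,
                             int64_t N, float* out) {
  int64_t col = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= N) return;
  float acc = 0.f;
  for (int64_t s = 0; s < num_strips; ++s) acc += partial[s * N + col];
  out[col] = acc;
}

int main() {
  const int64_t M = 1 << 17, N = 64;       // M >> N, the case targeted above
  const int rows_per_block = 32;
  int64_t strips = (M + rows_per_block - 1) / rows_per_block;
  float *d_x, *d_partial, *d_out;
  cudaMalloc(&d_x, M * N * sizeof(float));
  cudaMalloc(&d_partial, strips * N * sizeof(float));
  cudaMalloc(&d_out, N * sizeof(float));
  cudaMemset(d_x, 0, M * N * sizeof(float));
  dim3 threads(32, 1), blocks((N + 31) / 32, (unsigned)strips);
  colsum_partial<<<blocks, threads>>>(d_x, M, N, rows_per_block, d_partial);
  colsum_final<<<(N + 31) / 32, 32>>>(d_partial, strips, N, d_out);
  cudaDeviceSynchronize();
  printf("ok: %d\n", cudaGetLastError() == cudaSuccess);
  cudaFree(d_x);
  cudaFree(d_partial);
  cudaFree(d_out);
  return 0;
}
// --- end of sketch ---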
@@ -1422,7 +1250,6 @@ void LayerNormBackwardKernelImplInternal(
dgamma->defined() ? dgamma->template data_ptr<T>() : nullptr;
T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T>() : nullptr;

#if defined(USE_ROCM)
if (M < 128) {
// For small batch size, do colwise reduce directly.
const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads;
@@ -1438,6 +1265,7 @@ void LayerNormBackwardKernelImplInternal(
dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
#if defined(USE_ROCM)
// For small batch size, do colwise reduce directly.
const int part_size = warp_size;
const dim3 threads2(warp_size, 4, 1);
@@ -1472,11 +1300,47 @@ void LayerNormBackwardKernelImplInternal(
dgamma_data,
dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#else
LaunchGammaBetaBackwardCUDAKernel(
dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
if ((M % kWarpSize == 0) && (N % kWarpSize == 0)) {
// This implementation relies on warp primitives and requires that M and N divide
// exactly to warp size.
dim3 threads{kWarpSize, kWarpSize};
int blocks = (N + threads.x - 1) / threads.x;

// If M and N divide by warp_size, we can use warp shuffles for the final reduction.
// That requires transposing values in shared memory, so we apply a padding to
// reduce bank conflicts.

size_t shmem_sz = 2 * sizeof(T_ACC) * (threads.x + 1) * threads.y;
GammaBetaBackwardCUDAKernel_32x32<T, T_ACC>
<<<blocks, threads, shmem_sz, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
dim3 threads{16, 32};
int blocks = (N + threads.x - 1) / threads.x;
size_t shmem_sz = 2 * sizeof(T_ACC) * threads.x * threads.y;
GammaBetaBackwardCUDAKernel<T, T_ACC>
<<<blocks, threads, shmem_sz, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#endif
}
}
}
@@ -7195,26 +7195,6 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
        ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False)
        self.assertEqual(ln.forward(x), torch.zeros_like(x))

    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_layer_norm_backwards_eps(self):
        dtype = torch.float
        m_x_n_list = [(3, 3), (5, 5), (11, 11), (55, 55),
                      (32, 32), (1024, 32), (1024, 1024),
                      (33, 33), (1025, 33), (1025, 1025)]
        for m, n in m_x_n_list:
            x = torch.randn((m, n), dtype=dtype, requires_grad=True)
            grad_output = torch.rand_like(x)
            x_cuda = x.clone().detach().to("cuda").requires_grad_()
            grad_output_cuda = grad_output.clone().detach().to("cuda")
            ln = nn.LayerNorm(n, dtype=dtype)
            ln_cuda = nn.LayerNorm(n, device="cuda", dtype=dtype)
            ln_out = ln(x)
            ln_out_cuda = ln_cuda(x_cuda)
            ln_out.backward(grad_output)
            ln_out_cuda.backward(grad_output_cuda)
            self.assertEqual(ln.weight.grad, ln_cuda.weight.grad, f"weight grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4)
            self.assertEqual(ln.bias.grad, ln_cuda.bias.grad, f"bias grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4)

    @largeTensorTest("40GB", device="cuda")
    def test_layer_norm_large_tensor(self):
        # test for https://github.com/pytorch/pytorch/issues/136291