Revert "[cuda] Add new faster gammabeta backward kernel (#148605)"

This reverts commit 114d404b07.

Reverted https://github.com/pytorch/pytorch/pull/148605 on behalf of https://github.com/drisspg due to See https://github.com/pytorch/pytorch/issues/150266#issuecomment-2773907902 for more details ([comment](https://github.com/pytorch/pytorch/pull/148605#issuecomment-2773928838))
This commit is contained in:
PyTorch MergeBot 2025-04-02 23:14:11 +00:00
parent de15ef0ee8
commit 61a1f09b5b
2 changed files with 202 additions and 358 deletions

View File

@ -508,6 +508,7 @@ __global__ void layer_norm_grad_input_kernel_vectorized(
} }
} }
template <typename T, typename T_ACC> template <typename T, typename T_ACC>
__global__ void GammaBetaBackwardSimpleCUDAKernel( __global__ void GammaBetaBackwardSimpleCUDAKernel(
int64_t M, int64_t M,
@ -539,364 +540,191 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel(
} }
} }
template <typename T, typename T_ACC, // This implementation gets called if M and N divide with 32. This case should
unsigned int block_dim_x, // be the most common. We can then make better use of warp level intrinsics
unsigned int block_dim_y, // to improve performance.
unsigned int rows_per_block_y,
bool check_x, template <typename T, typename T_ACC>
bool check_y> __global__ void GammaBetaBackwardCUDAKernel_32x32(
__device__
__forceinline__
void
blockReduceGammaBetaBackwardsHelper(
int64_t M_start,
int64_t M, int64_t M,
int64_t N, int64_t N,
const T* __restrict__ dY, const T* dY,
const T* __restrict__ X, const T* X,
const T_ACC* __restrict__ mean, const T_ACC* mean,
const T_ACC* __restrict__ rstd, const T_ACC* rstd,
T* __restrict__ dg, T* dg,
T* __restrict__ db, T* db) {
T_ACC &dg_sum,
T_ACC &db_sum
) {
constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y;
int64_t thread_x = blockIdx.x * block_dim_x + threadIdx.x;
int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) & (kWarpSize - 1);
int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y;
T_ACC warp_mean = 0, warp_rstd = 0;
if (lane_id < rows_per_thread_y && mean_index + lane_id < M) {
warp_mean = mean[mean_index + lane_id];
warp_rstd = rstd[mean_index + lane_id];
}
// We do a WARP_SYNC() here because we use WARP_SHFL below to access
// warp_mean and warp_rstd.
WARP_SYNC();
T_ACC dY_regs[rows_per_thread_y] = {0};
T_ACC X_regs[rows_per_thread_y] = {0};
#pragma unroll
for (int i = 0; i < rows_per_thread_y; ++i) {
int64_t current_y = M_start + threadIdx.y * rows_per_thread_y + i;
bool active = true;
if (check_x && thread_x >= N) {
active = false;
}
if (check_y && current_y >= M) {
active = false;
}
if (active) {
dY_regs[i] = dY[current_y * N + thread_x];
X_regs[i] = X[current_y * N + thread_x];
}
}
#pragma unroll
for (int i = 0; i < rows_per_thread_y; ++i) {
T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize);
T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize);
dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg;
db_sum += dY_regs[i];
}
}
template <typename T, typename T_ACC,
unsigned int block_dim_x,
unsigned int block_dim_y,
unsigned int rows_per_block_y,
bool check_x,
bool check_y>
__device__
__forceinline__
void
blockReduceGammaBetaBackwardsWithChecks(
int64_t M,
int64_t N,
const T* __restrict__ dY,
const T* __restrict__ X,
const T_ACC* __restrict__ mean,
const T_ACC* __restrict__ rstd,
T* __restrict__ dg,
T* __restrict__ db,
T_ACC &dg_sum,
T_ACC &db_sum
) {
for (int64_t M_start = blockIdx.y * rows_per_block_y;
M_start < M;
M_start += rows_per_block_y * gridDim.y) {
int64_t M_end = M_start + rows_per_block_y - 1;
if (!check_y || M_end < M) {
blockReduceGammaBetaBackwardsHelper<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, check_x, false>
(M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
} else {
blockReduceGammaBetaBackwardsHelper<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, check_x, true>
(M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
}
}
}
// block_dim_x is the number of threads in the x dimension per block.
// block_dim_y is the number of threads in the y dimension per block.
// rows_per_block_y is the size of the tile (number of data elements)
// in the y dimension per block.
// partial_reduction indicates whether we need to reduce across threads
// or not. If set to true, we will not reduce across threads. This can
// be faster in the M >> N case but requires another kernel to do a full
// final reduction.
// aligned_grid means the data size is a multiple of tile size. In that
// case we don't need to check for boundary conditions which can provide
// a further speedup by not needing instructions to check for edge cases
// and not needing predicate registers.
template <typename T, typename T_ACC,
unsigned int block_dim_x, unsigned int block_dim_y,
unsigned int rows_per_block_y,
bool partial_reduction,
bool aligned_grid
>
__global__
void
GammaBetaBackwardCUDAKernelTemplate(
int64_t M,
int64_t N,
const T* __restrict__ dY,
const T* __restrict__ X,
const T_ACC* __restrict__ mean,
const T_ACC* __restrict__ rstd,
T* __restrict__ dg,
T* __restrict__ db) {
// This assert is a compile-time check only.
constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y;
static_assert(rows_per_thread_y <= kWarpSize);
T_ACC dg_sum = 0;
T_ACC db_sum = 0;
if (aligned_grid) {
// When N and M align perfectly with block_dim_x and block_dim_y, we
// can skip boundary condition checks that waste instruction issue slots.
blockReduceGammaBetaBackwardsWithChecks
<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, false, false>
(M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
} else {
// In the general case we need to check boundary conditions in the M
// dimension. However, we can still avoid boundary checks in the N dimension
// for the inner blocks. So try to avoid those checks when possible.
if (blockIdx.x * block_dim_x + block_dim_x - 1 < N) {
blockReduceGammaBetaBackwardsWithChecks
<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, false, true>
(M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
} else {
blockReduceGammaBetaBackwardsWithChecks
<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, true, true>
(M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
}
}
int64_t thread_x = ((int64_t)blockIdx.x) * block_dim_x + threadIdx.x;
// When partial_reduction is requested, we don't reduce within a block.
// We also don't reduce if we are only a single block in the y dimension.
if (partial_reduction || (blockDim.y == 1 && gridDim.y == 1)) {
if (aligned_grid || thread_x < N) {
int64_t thread_y = ((int64_t)blockIdx.y) * blockDim.y + threadIdx.y;
if (dg) {
dg[thread_y * N + thread_x] = dg_sum;
}
if (db) {
db[thread_y * N + thread_x] = db_sum;
}
}
} else {
// The caller requested a full reduction so we must reduce across
// warps using shared memory and warp shuffles.
static_assert(rows_per_thread_y <= C10_WARP_SIZE);
alignas(sizeof(double)) extern __shared__ char s_data1[]; alignas(sizeof(double)) extern __shared__ char s_data1[];
T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1); T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
T_ACC* s_dg; T_ACC* s_dg;
T_ACC* s_db; T_ACC* s_db;
int padded_bx = (block_dim_x + 1);
// Transpose dg and db. T_ACC dg_sum = 0;
T_ACC db_sum = 0;
const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
if (j < N) {
constexpr int unroll_factor = 8;
int laneId = threadIdx.x & (C10_WARP_SIZE - 1);
T_ACC mean_reg, mean_reg_tmp;
T_ACC rstd_reg, rstd_reg_tmp;
T dY_reg;
T X_reg;
// Main loop
int bcounter;
for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor);
bcounter++) {
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
if (laneId < unroll_factor) {
mean_reg_tmp = mean[offset + laneId];
rstd_reg_tmp = rstd[offset + laneId];
}
WARP_SYNC();
#pragma unroll
for (int ii = 0; ii < unroll_factor; ++ii) {
dY_reg = dY[(offset + ii) * N + j];
X_reg = X[(offset + ii) * N + j];
mean_reg = WARP_SHFL(mean_reg_tmp, ii, kWarpSize);
rstd_reg = WARP_SHFL(rstd_reg_tmp, ii, kWarpSize);
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
db_sum += dY_reg;
}
}
// Remainder loop
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
for (int ii = 0; ii < unroll_factor; ii++) {
if ((offset + ii) < M) {
mean_reg = mean[offset + ii];
rstd_reg = rstd[offset + ii];
dY_reg = dY[(offset + ii) * N + j];
X_reg = X[(offset + ii) * N + j];
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
db_sum += dY_reg;
}
}
// This kernel uses a block of (C10_WARP_SIZE x C10_WARP_SIZE) and
// gets called when M; N divide by 32. We can use warp shuffles
// for the final reduction step. This removes 4 shmem loads and
// stores with their corresponding __syncthreads()
// This greatly reduces bank conflicts at the expense of a little
// extra shared memory. It does not impact occupancy
int padded_bx = (1 + blockDim.x);
s_dg = s_data_typed; s_dg = s_data_typed;
s_db = s_data_typed + (padded_bx * block_dim_y); s_db = s_data_typed + (padded_bx * blockDim.y);
s_dg[threadIdx.y * padded_bx + threadIdx.x] = dg_sum; s_dg[threadIdx.y * padded_bx + threadIdx.x] = dg_sum;
s_db[threadIdx.y * padded_bx + threadIdx.x] = db_sum; s_db[threadIdx.y * padded_bx + threadIdx.x] = db_sum;
__syncthreads(); __syncthreads();
// Load transposed so that a warp holds an entire column // Load transposed so that a warp holds an entire column
// Because block_dim_x != block_dim_y in the general case, we need T_ACC reg_dg = s_dg[threadIdx.x * padded_bx + threadIdx.y];
// some code to handle the general case. T_ACC reg_db = s_db[threadIdx.x * padded_bx + threadIdx.y];
static_assert(block_dim_x * block_dim_y % C10_WARP_SIZE == 0); for (unsigned delta = C10_WARP_SIZE >> 1; delta >= 1; delta >>= 1) {
constexpr int warps_available_to_reduce = block_dim_x * block_dim_y / C10_WARP_SIZE;
int thread_id = threadIdx.y * block_dim_x + threadIdx.x;
int warp_id = thread_id / C10_WARP_SIZE;
int lane_id = thread_id & (C10_WARP_SIZE - 1);
#pragma unroll
for (int i = warp_id; i < block_dim_x; i += warps_available_to_reduce) {
T_ACC reg_db, reg_dg;
if (lane_id < block_dim_y) {
reg_dg = s_dg[lane_id * padded_bx + i];
reg_db = s_db[lane_id * padded_bx + i];
}
#pragma unroll
for (unsigned delta = block_dim_y >> 1; delta >= 1; delta >>= 1) {
reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize); reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize);
reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize); reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize);
} }
// Reduce is done. Now write it out to global memory.
int64_t out_index = ((int64_t)blockIdx.x) * block_dim_x + i; if (threadIdx.x == 0) {
if (threadIdx.x == 0 && (aligned_grid || out_index < N)) { const int64_t j = blockIdx.x * blockDim.x + threadIdx.y;
if (dg) { if (dg) {
dg[out_index] = reg_dg; dg[j] = reg_dg;
} }
if (db) { if (db) {
db[out_index] = reg_db; db[j] = reg_db;
}
} }
} }
} }
} }
template<typename T, typename T_ACC, template <typename T, typename T_ACC>
int block_dim_x, int block_dim_y, __global__ void GammaBetaBackwardCUDAKernel(
int rows_per_block_y,
bool partial_reduction>
void LaunchAndCheckGammaBetaBackwardKernel(
bool aligned_grid,
dim3 blocks,
dim3 threads,
size_t shmem_sz,
cudaStream_t cuda_stream,
const T* dY_data,
const T* X_data,
const T_ACC* mean_data,
const T_ACC* rstd_data,
int64_t M, int64_t M,
int64_t N, int64_t N,
T* dgamma_data, const T* dY,
T* dbeta_data) { const T* X,
if (aligned_grid) { const T_ACC* mean,
GammaBetaBackwardCUDAKernelTemplate<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, partial_reduction, true> const T_ACC* rstd,
<<<blocks, threads, shmem_sz, cuda_stream>>>( T* dg,
M, T* db) {
N, alignas(sizeof(double)) extern __shared__ char s_data1[];
dY_data, T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
X_data, T_ACC* s_dg;
mean_data, T_ACC* s_db;
rstd_data,
dgamma_data,
dbeta_data);
} else {
GammaBetaBackwardCUDAKernelTemplate<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, partial_reduction, false>
<<<blocks, threads, shmem_sz, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template<typename T, typename T_ACC, const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
int block_dim_x, int block_dim_y,
int rows_per_block_y> T_ACC dg_sum = 0;
void ConfigureAndLaunchGammaBetaBackwardKernel( T_ACC db_sum = 0;
const T* dY_data,
const T* X_data, if (j < N) {
const T_ACC* mean_data, constexpr int unroll_factor = 8;
const T_ACC* rstd_data,
int64_t M, T_ACC mean_reg;
int64_t N, T_ACC rstd_reg;
Tensor* dgamma, T dY_reg;
Tensor* dbeta, T X_reg;
cudaStream_t cuda_stream) {
T* dgamma_data = // Main Loop
dgamma->defined() ? dgamma->template data_ptr<T>() : nullptr; int bcounter;
T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T>() : nullptr; for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); bcounter++){
bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0); int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
dim3 threads{block_dim_x, block_dim_y};
dim3 blocks; #pragma unroll
blocks.x = (N + block_dim_x - 1) / block_dim_x; for (int ii = 0; ii < unroll_factor; ++ii) {
blocks.y = 1; dY_reg = dY[(offset + ii) * N + j];
size_t shmem_sz = (block_dim_x + 1) * block_dim_y * sizeof(T_ACC) * 2; X_reg = X[(offset + ii) * N + j];
if (blocks.y == 1 && threads.y == 1) { mean_reg = mean[offset + ii];
// Optimization: since there is just one thread doing all the summation, we don't need a reduction rstd_reg = rstd[offset + ii];
// across threads. So we set partial_reduction to true. dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
LaunchAndCheckGammaBetaBackwardKernel<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, true>( db_sum += dY_reg;
aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); }
} else {
LaunchAndCheckGammaBetaBackwardKernel<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, false>(
aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data);
} }
} // Remainder loop
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
template<typename T, typename T_ACC> for (int ii = 0; ii < unroll_factor; ii++ ){
void LaunchGammaBetaBackwardCUDAKernel( if ((offset + ii) < M) {
const T* dY_data, dY_reg = dY[(offset + ii) * N + j ];
const T* X_data, X_reg = X[(offset + ii) * N + j];
const T_ACC* mean_data, mean_reg = mean[offset + ii];
const T_ACC* rstd_data, rstd_reg = rstd[offset + ii];
int64_t M, dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
int64_t N, db_sum += dY_reg;
Tensor* dgamma,
Tensor* dbeta,
cudaStream_t cuda_stream) {
constexpr int block_dim_x = 32;
const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
if (M > 64 * 1024 && N / block_dim_x < sm_count / 2) {
// We have a situation where M >> N and N is small.
// In this case we can speed up the computation by parallelizing in the M dimension.
// We launch multiple blocks in the y-dimension, and compute partial sums for the
// gradient in the first pass. Then we do a .sum(0) to do a final reduction.
// Although we launch 2 kernels, we can get up to a 10x speedup for large M.
constexpr int block_dim_y = 1;
constexpr int rows_per_block_y = 32;
bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0);
dim3 threads{block_dim_x, block_dim_y};
dim3 blocks;
blocks.x = (N + block_dim_x - 1) / block_dim_x;
// int rows_per_block = my_gamma_beta_unroll_factor *
blocks.y = (M + rows_per_block_y - 1) / rows_per_block_y;
constexpr int max_grid_size = 64 * 1024 / 2;
blocks.y = std::min<unsigned int>(max_grid_size / blocks.x, blocks.y);
Tensor dgamma_blocks;
Tensor dbeta_blocks;
T * dgamma_blocks_ptr = nullptr;
T * dbeta_blocks_ptr = nullptr;
if (dgamma->defined()) {
auto options = dgamma->options();
dgamma_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options);
dgamma_blocks_ptr = dgamma_blocks.data_ptr<T>();
} }
if (dbeta->defined()) {
auto options = dbeta->options();
dbeta_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options);
dbeta_blocks_ptr = dbeta_blocks.data_ptr<T>();
} }
LaunchAndCheckGammaBetaBackwardKernel<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, true>(
aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_blocks_ptr, dbeta_blocks_ptr);
*dgamma = dgamma_blocks.sum(0); // Do the final reduction in shared memory
*dbeta = dbeta_blocks.sum(0); s_dg = s_data_typed;
} else { s_db = s_data_typed + blockDim.x * blockDim.y;
// We are in the normal case where M is not that large. s_dg[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum;
// We can change the tile shape (which is the last template parameter) in accordance with M. s_db[threadIdx.y * blockDim.x + threadIdx.x] = db_sum;
// For small M it is faster to have a smaller tile, otherwise we could have idle threads. __syncthreads();
// For larger M we use a bigger tile size.
if (M < 64) { for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 1, 8>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); if (threadIdx.y < offset) {
} else if (M < 128) { s_dg[threadIdx.y * blockDim.x + threadIdx.x] +=
ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 8, 64>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); s_dg[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
} else if (M < 256) { s_db[threadIdx.y * blockDim.x + threadIdx.x] +=
ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 16, 128>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); s_db[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
} else { }
ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 32, 256>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); __syncthreads();
}
if (threadIdx.y == 0) {
if (dg) {
dg[j] = s_dg[threadIdx.x];
}
if (db) {
db[j] = s_db[threadIdx.x];
}
} }
} }
} }
@ -1422,7 +1250,6 @@ void LayerNormBackwardKernelImplInternal(
dgamma->defined() ? dgamma->template data_ptr<T>() : nullptr; dgamma->defined() ? dgamma->template data_ptr<T>() : nullptr;
T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T>() : nullptr; T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T>() : nullptr;
#if defined(USE_ROCM)
if (M < 128) { if (M < 128) {
// For small batch size, do colwise reduce directly. // For small batch size, do colwise reduce directly.
const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads; const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads;
@ -1438,6 +1265,7 @@ void LayerNormBackwardKernelImplInternal(
dbeta_data); dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK(); C10_CUDA_KERNEL_LAUNCH_CHECK();
} else { } else {
#if defined(USE_ROCM)
// For small batch size, do colwise reduce directly. // For small batch size, do colwise reduce directly.
const int part_size = warp_size; const int part_size = warp_size;
const dim3 threads2(warp_size, 4, 1); const dim3 threads2(warp_size, 4, 1);
@ -1472,12 +1300,48 @@ void LayerNormBackwardKernelImplInternal(
dgamma_data, dgamma_data,
dbeta_data); dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK(); C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#else #else
LaunchGammaBetaBackwardCUDAKernel( if ((M % kWarpSize == 0) && (N % kWarpSize == 0)) {
dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); // This implementation relies on warp primitives and requires that M and N divide
// exactly to warp size.
dim3 threads{kWarpSize, kWarpSize};
int blocks = (N + threads.x - 1) / threads.x;
// If M and N divide by warp_size, we can use warp shuffles for the final reduction.
// That requires transposing values in shared memory, so we apply a padding to
// reduce bank conflicts.
size_t shmem_sz = 2 * sizeof(T_ACC) * (threads.x + 1) * threads.y;
GammaBetaBackwardCUDAKernel_32x32<T, T_ACC>
<<<blocks, threads, shmem_sz, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
dim3 threads{16, 32};
int blocks = (N + threads.x - 1) / threads.x;
size_t shmem_sz = 2 * sizeof(T_ACC) * threads.x * threads.y;
GammaBetaBackwardCUDAKernel<T, T_ACC>
<<<blocks, threads, shmem_sz, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#endif #endif
} }
}
} }
void LayerNormBackwardKernelImpl( void LayerNormBackwardKernelImpl(

View File

@ -7195,26 +7195,6 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False) ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False)
self.assertEqual(ln.forward(x), torch.zeros_like(x)) self.assertEqual(ln.forward(x), torch.zeros_like(x))
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
def test_layer_norm_backwards_eps(self):
dtype = torch.float
m_x_n_list = [(3, 3), (5, 5), (11, 11), (55, 55),
(32, 32), (1024, 32), (1024, 1024),
(33, 33), (1025, 33), (1025, 1025)]
for m, n in m_x_n_list:
x = torch.randn((m, n), dtype=dtype, requires_grad=True)
grad_output = torch.rand_like(x)
x_cuda = x.clone().detach().to("cuda").requires_grad_()
grad_output_cuda = grad_output.clone().detach().to("cuda")
ln = nn.LayerNorm(n, dtype=dtype)
ln_cuda = nn.LayerNorm(n, device="cuda", dtype=dtype)
ln_out = ln(x)
ln_out_cuda = ln_cuda(x_cuda)
ln_out.backward(grad_output)
ln_out_cuda.backward(grad_output_cuda)
self.assertEqual(ln.weight.grad, ln_cuda.weight.grad, f"weight grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4)
self.assertEqual(ln.bias.grad, ln_cuda.bias.grad, f"bias grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4)
@largeTensorTest("40GB", device="cuda") @largeTensorTest("40GB", device="cuda")
def test_layer_norm_large_tensor(self): def test_layer_norm_large_tensor(self):
# test for https://github.com/pytorch/pytorch/issues/136291 # test for https://github.com/pytorch/pytorch/issues/136291