diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 9feb30c2194..3ce2c24c18e 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -508,7 +508,6 @@ __global__ void layer_norm_grad_input_kernel_vectorized( } } - template __global__ void GammaBetaBackwardSimpleCUDAKernel( int64_t M, @@ -540,191 +539,365 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel( } } -// This implementation gets called if M and N divide with 32. This case should -// be the most common. We can then make better use of warp level intrinsics -// to improve performance. - -template -__global__ void GammaBetaBackwardCUDAKernel_32x32( +template +__device__ +__forceinline__ +void +blockReduceGammaBetaBackwardsHelper( + int64_t M_start, int64_t M, int64_t N, - const T* dY, - const T* X, - const T_ACC* mean, - const T_ACC* rstd, - T* dg, - T* db) { - alignas(sizeof(double)) extern __shared__ char s_data1[]; - T_ACC* s_data_typed = reinterpret_cast(&s_data1); - T_ACC* s_dg; - T_ACC* s_db; + const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ mean, + const T_ACC* __restrict__ rstd, + T* __restrict__ dg, + T* __restrict__ db, + T_ACC &dg_sum, + T_ACC &db_sum +) { + constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y; + int64_t thread_x = blockIdx.x * block_dim_x + threadIdx.x; + + int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) & (kWarpSize - 1); + int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y; + T_ACC warp_mean = 0, warp_rstd = 0; + if (lane_id < rows_per_thread_y && mean_index + lane_id < M) { + warp_mean = mean[mean_index + lane_id]; + warp_rstd = rstd[mean_index + lane_id]; + } + // We do a WARP_SYNC() here because we use WARP_SHFL below to access + // warp_mean and warp_rstd. + WARP_SYNC(); + + T_ACC dY_regs[rows_per_thread_y] = {0}; + T_ACC X_regs[rows_per_thread_y] = {0}; + #pragma unroll + for (int i = 0; i < rows_per_thread_y; ++i) { + int64_t current_y = M_start + threadIdx.y * rows_per_thread_y + i; + bool active = true; + if (check_x && thread_x >= N) { + active = false; + } + if (check_y && current_y >= M) { + active = false; + } + if (active) { + dY_regs[i] = dY[current_y * N + thread_x]; + X_regs[i] = X[current_y * N + thread_x]; + } + } + + #pragma unroll + for (int i = 0; i < rows_per_thread_y; ++i) { + T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); + T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize); + dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; + db_sum += dY_regs[i]; + } +} + +template +__device__ +__forceinline__ +void +blockReduceGammaBetaBackwardsWithChecks( + int64_t M, + int64_t N, + const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ mean, + const T_ACC* __restrict__ rstd, + T* __restrict__ dg, + T* __restrict__ db, + T_ACC &dg_sum, + T_ACC &db_sum +) { + for (int64_t M_start = blockIdx.y * rows_per_block_y; + M_start < M; + M_start += rows_per_block_y * gridDim.y) { + int64_t M_end = M_start + rows_per_block_y - 1; + if (!check_y || M_end < M) { + blockReduceGammaBetaBackwardsHelper + (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); + } else { + blockReduceGammaBetaBackwardsHelper + (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); + } + } +} + +// block_dim_x is the number of threads in the x dimension per block. +// block_dim_y is the number of threads in the y dimension per block. 
+// rows_per_block_y is the size of the tile (number of data elements) +// in the y dimension per block. +// partial_reduction indicates whether we need to reduce across threads +// or not. If set to true, we will not reduce across threads. This can +// be faster in the M >> N case but requires another kernel to do a full +// final reduction. +// aligned_grid means the data size is a multiple of tile size. In that +// case we don't need to check for boundary conditions which can provide +// a further speedup by not needing instructions to check for edge cases +// and not needing predicate registers. +template +__global__ +void +__launch_bounds__(block_dim_x * block_dim_y) + GammaBetaBackwardCUDAKernelTemplate( + int64_t M, + int64_t N, + const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ mean, + const T_ACC* __restrict__ rstd, + T* __restrict__ dg, + T* __restrict__ db) { + // This assert is a compile-time check only. + constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y; + static_assert(rows_per_thread_y <= kWarpSize); T_ACC dg_sum = 0; T_ACC db_sum = 0; - const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; + if (aligned_grid) { + // When N and M align perfectly with block_dim_x and block_dim_y, we + // can skip boundary condition checks that waste instruction issue slots. + blockReduceGammaBetaBackwardsWithChecks + + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); + } else { + // In the general case we need to check boundary conditions in the M + // dimension. However, we can still avoid boundary checks in the N dimension + // for the inner blocks. So try to avoid those checks when possible. + if (blockIdx.x * block_dim_x + block_dim_x - 1 < N) { + blockReduceGammaBetaBackwardsWithChecks + + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); + } else { + blockReduceGammaBetaBackwardsWithChecks + + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); + } + } - if (j < N) { - constexpr int unroll_factor = 8; - int laneId = threadIdx.x & (C10_WARP_SIZE - 1); + int64_t thread_x = ((int64_t)blockIdx.x) * block_dim_x + threadIdx.x; - T_ACC mean_reg, mean_reg_tmp; - T_ACC rstd_reg, rstd_reg_tmp; - T dY_reg; - T X_reg; - - // Main loop - int bcounter; - for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); - bcounter++) { - int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; - - if (laneId < unroll_factor) { - mean_reg_tmp = mean[offset + laneId]; - rstd_reg_tmp = rstd[offset + laneId]; + // When partial_reduction is requested, we don't reduce within a block. + // We also don't reduce if we are only a single block in the y dimension. 
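// ---------------------------------------------------------------------------
// [Illustrative aside, not part of this patch] The accumulation above relies on
// a warp-level broadcast: lanes 0..rows_per_thread_y-1 of a warp each load the
// mean/rstd of one row, and when row i is processed every lane reads lane i's
// copy with a shuffle instead of issuing another global load. A minimal
// standalone version of that pattern (float only, invented demo_* names, a 1-D
// block whose width is a multiple of the 32-thread warp, gridDim.y == 1):

#include <cstdint>  // int64_t

constexpr int kDemoRowsPerThread = 8;  // must be <= warp size (32)

__global__ void demo_shuffle_broadcast(
    const float* __restrict__ mean, const float* __restrict__ rstd,
    const float* __restrict__ dY, const float* __restrict__ X,
    float* __restrict__ dg_out, float* __restrict__ db_out,
    int64_t M, int64_t N) {
  int64_t col = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
  int lane = threadIdx.x % 32;
  float dg_sum = 0.f, db_sum = 0.f;
  for (int64_t row0 = 0; row0 < M; row0 += kDemoRowsPerThread) {
    // One load per lane fetches the statistics for kDemoRowsPerThread rows.
    float my_mean = 0.f, my_rstd = 0.f;
    if (lane < kDemoRowsPerThread && row0 + lane < M) {
      my_mean = mean[row0 + lane];
      my_rstd = rstd[row0 + lane];
    }
    __syncwarp();
    for (int i = 0; i < kDemoRowsPerThread; ++i) {
      // Broadcast row (row0 + i)'s statistics from lane i to the whole warp.
      float m = __shfl_sync(0xffffffffu, my_mean, i);
      float r = __shfl_sync(0xffffffffu, my_rstd, i);
      if (col < N && row0 + i < M) {
        float dy = dY[(row0 + i) * N + col];
        dg_sum += dy * (X[(row0 + i) * N + col] - m) * r;
        db_sum += dy;
      }
    }
  }
  if (col < N) {  // one column per thread; launch ceil(N / blockDim.x) blocks
    dg_out[col] = dg_sum;
    db_out[col] = db_sum;
  }
}
// The helper above additionally turns the column/row bounds checks into
// template parameters (check_x/check_y) so fully covered tiles pay no
// predication cost.
// [end of aside] -------------------------------------------------------------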
+ if (partial_reduction || (blockDim.y == 1 && gridDim.y == 1)) { + if (aligned_grid || thread_x < N) { + int64_t thread_y = ((int64_t)blockIdx.y) * blockDim.y + threadIdx.y; + if (dg) { + dg[thread_y * N + thread_x] = dg_sum; } - WARP_SYNC(); - - #pragma unroll - for (int ii = 0; ii < unroll_factor; ++ii) { - dY_reg = dY[(offset + ii) * N + j]; - X_reg = X[(offset + ii) * N + j]; - mean_reg = WARP_SHFL(mean_reg_tmp, ii, kWarpSize); - rstd_reg = WARP_SHFL(rstd_reg_tmp, ii, kWarpSize); - dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; - db_sum += dY_reg; + if (db) { + db[thread_y * N + thread_x] = db_sum; } } - - // Remainder loop - int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; - for (int ii = 0; ii < unroll_factor; ii++) { - if ((offset + ii) < M) { - mean_reg = mean[offset + ii]; - rstd_reg = rstd[offset + ii]; - dY_reg = dY[(offset + ii) * N + j]; - X_reg = X[(offset + ii) * N + j]; - dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; - db_sum += dY_reg; - } - } - - // This kernel uses a block of (C10_WARP_SIZE x C10_WARP_SIZE) and - // gets called when M; N divide by 32. We can use warp shuffles - // for the final reduction step. This removes 4 shmem loads and - // stores with their corresponding __syncthreads() - - // This greatly reduces bank conflicts at the expense of a little - // extra shared memory. It does not impact occupancy - int padded_bx = (1 + blockDim.x); - + } else { + // The caller requested a full reduction so we must reduce across + // warps using shared memory and warp shuffles. + static_assert(rows_per_thread_y <= C10_WARP_SIZE); + alignas(sizeof(double)) extern __shared__ char s_data1[]; + T_ACC* s_data_typed = reinterpret_cast(&s_data1); + T_ACC* s_dg; + T_ACC* s_db; + int padded_bx = (block_dim_x + 1); + // Transpose dg and db. s_dg = s_data_typed; - s_db = s_data_typed + (padded_bx * blockDim.y); + s_db = s_data_typed + (padded_bx * block_dim_y); s_dg[threadIdx.y * padded_bx + threadIdx.x] = dg_sum; s_db[threadIdx.y * padded_bx + threadIdx.x] = db_sum; __syncthreads(); // Load transposed so that a warp holds an entire column - T_ACC reg_dg = s_dg[threadIdx.x * padded_bx + threadIdx.y]; - T_ACC reg_db = s_db[threadIdx.x * padded_bx + threadIdx.y]; - for (unsigned delta = C10_WARP_SIZE >> 1; delta >= 1; delta >>= 1) { - reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize); - reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize); - } - - if (threadIdx.x == 0) { - const int64_t j = blockIdx.x * blockDim.x + threadIdx.y; - if (dg) { - dg[j] = reg_dg; + // Because block_dim_x != block_dim_y in the general case, we need + // some code to handle the general case. + static_assert(block_dim_x * block_dim_y % C10_WARP_SIZE == 0); + constexpr int warps_available_to_reduce = block_dim_x * block_dim_y / C10_WARP_SIZE; + int thread_id = threadIdx.y * block_dim_x + threadIdx.x; + int warp_id = thread_id / C10_WARP_SIZE; + int lane_id = thread_id & (C10_WARP_SIZE - 1); + #pragma unroll + for (int i = warp_id; i < block_dim_x; i += warps_available_to_reduce) { + T_ACC reg_db, reg_dg; + if (lane_id < block_dim_y) { + reg_dg = s_dg[lane_id * padded_bx + i]; + reg_db = s_db[lane_id * padded_bx + i]; } - if (db) { - db[j] = reg_db; + #pragma unroll + for (unsigned delta = block_dim_y >> 1; delta >= 1; delta >>= 1) { + reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize); + reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize); + } + // Reduce is done. Now write it out to global memory. 
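// ---------------------------------------------------------------------------
// [Illustrative aside, not part of this patch] The full-reduction branch above
// combines a block's per-thread partials in two steps: stage them in shared
// memory with a padded row stride, then reduce each column inside one warp
// with butterfly shuffles. A standalone float sketch for a fixed 32x32 block
// (invented demo_* names):

__global__ void demo_block_column_reduce(float* __restrict__ out /* gridDim.x * 32 */) {
  constexpr int kBx = 32, kBy = 32, kPaddedBx = kBx + 1, kWarpLanes = 32;
  __shared__ float s[kBy * kPaddedBx];

  float partial = 1.0f;  // stand-in for this thread's dg_sum / db_sum

  // Stage partials. The +1 padding matters for the transposed read below:
  // banks are (word index % 32), and (x * 33 + y) % 32 == (x + y) % 32 differs
  // for every lane x, whereas a stride of 32 would put all 32 lanes of a warp
  // into the same bank.
  s[threadIdx.y * kPaddedBx + threadIdx.x] = partial;
  __syncthreads();

  // Transposed read: warp y now holds the 32 partials of output column y.
  float v = s[threadIdx.x * kPaddedBx + threadIdx.y];

  // Butterfly reduction: after log2(32) steps every lane holds the column sum.
  for (int delta = kWarpLanes / 2; delta >= 1; delta /= 2) {
    v += __shfl_xor_sync(0xffffffffu, v, delta);
  }

  if (threadIdx.x == 0) {
    out[blockIdx.x * kBx + threadIdx.y] = v;  // 32.0f for the dummy partials
  }
}
// The kernel above generalizes this to block_dim_x != block_dim_y by looping
// the output columns over however many warps the block actually has.
// [end of aside] -------------------------------------------------------------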
+ int64_t out_index = ((int64_t)blockIdx.x) * block_dim_x + i; + if (threadIdx.x == 0 && (aligned_grid || out_index < N)) { + if (dg) { + dg[out_index] = reg_dg; + } + if (db) { + db[out_index] = reg_db; + } } } } } -template -__global__ void GammaBetaBackwardCUDAKernel( +template +void LaunchAndCheckGammaBetaBackwardKernel( + bool aligned_grid, + dim3 blocks, + dim3 threads, + size_t shmem_sz, + cudaStream_t cuda_stream, + const T* dY_data, + const T* X_data, + const T_ACC* mean_data, + const T_ACC* rstd_data, + int64_t M, + int64_t N, + T* dgamma_data, + T* dbeta_data) { +if (aligned_grid) { + GammaBetaBackwardCUDAKernelTemplate + <<>>( + M, + N, + dY_data, + X_data, + mean_data, + rstd_data, + dgamma_data, + dbeta_data); + } else { + GammaBetaBackwardCUDAKernelTemplate + <<>>( + M, + N, + dY_data, + X_data, + mean_data, + rstd_data, + dgamma_data, + dbeta_data); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void ConfigureAndLaunchGammaBetaBackwardKernel( + const T* dY_data, + const T* X_data, + const T_ACC* mean_data, + const T_ACC* rstd_data, int64_t M, int64_t N, - const T* dY, - const T* X, - const T_ACC* mean, - const T_ACC* rstd, - T* dg, - T* db) { - alignas(sizeof(double)) extern __shared__ char s_data1[]; - T_ACC* s_data_typed = reinterpret_cast(&s_data1); - T_ACC* s_dg; - T_ACC* s_db; + Tensor* dgamma, + Tensor* dbeta, + cudaStream_t cuda_stream) { + T* dgamma_data = + dgamma->defined() ? dgamma->template data_ptr() : nullptr; + T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr() : nullptr; + bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0); + dim3 threads{block_dim_x, block_dim_y}; + dim3 blocks; + blocks.x = (N + block_dim_x - 1) / block_dim_x; + blocks.y = 1; + size_t shmem_sz = (block_dim_x + 1) * block_dim_y * sizeof(T_ACC) * 2; + if (blocks.y == 1 && threads.y == 1) { + // Optimization: since there is just one thread doing all the summation, we don't need a reduction + // across threads. So we set partial_reduction to true. + LaunchAndCheckGammaBetaBackwardKernel( + aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); + } else { + LaunchAndCheckGammaBetaBackwardKernel( + aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); + } - const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; +} - T_ACC dg_sum = 0; - T_ACC db_sum = 0; - - if (j < N) { - constexpr int unroll_factor = 8; - - T_ACC mean_reg; - T_ACC rstd_reg; - T dY_reg; - T X_reg; - - // Main Loop - int bcounter; - for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); bcounter++){ - int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; - - #pragma unroll - for (int ii = 0; ii < unroll_factor; ++ii) { - dY_reg = dY[(offset + ii) * N + j]; - X_reg = X[(offset + ii) * N + j]; - mean_reg = mean[offset + ii]; - rstd_reg = rstd[offset + ii]; - dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; - db_sum += dY_reg; - } +template +void LaunchGammaBetaBackwardCUDAKernel( + const T* dY_data, + const T* X_data, + const T_ACC* mean_data, + const T_ACC* rstd_data, + int64_t M, + int64_t N, + Tensor* dgamma, + Tensor* dbeta, + cudaStream_t cuda_stream) { + constexpr int block_dim_x = 32; + const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + if (M > 64 * 1024 && N / block_dim_x < sm_count / 2) { + // We have a situation where M >> N and N is small. 
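// ---------------------------------------------------------------------------
// [Illustrative aside, not part of this patch] A quick worked example of when
// this branch fires. With block_dim_x = 32, a grid with a single block row has
// only N / 32 independent blocks; if that is far below the SM count, most of
// the device idles unless the M dimension is split as well. The SM count used
// here (132, an H100-class part) is just an example value, not something the
// patch assumes.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t M = 1 << 20, N = 128;  // tall, skinny input: M >> N
  const int block_dim_x = 32;
  const int sm_count = 132;            // example device

  const int64_t block_columns = N / block_dim_x;  // 4
  const bool use_two_pass = (M > 64 * 1024) && (block_columns < sm_count / 2);

  std::printf("block columns = %lld, SMs = %d -> two-pass path: %s\n",
              (long long)block_columns, sm_count, use_two_pass ? "yes" : "no");
  return 0;
}
// With 4 block columns and one block row, at most 4 of the 132 SMs (~3%) can
// be busy, hence the extra block rows plus a cheap .sum(0) second pass below.
// [end of aside] -------------------------------------------------------------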
+ // In this case we can speed up the computation by parallelizing in the M dimension. + // We launch multiple blocks in the y-dimension, and compute partial sums for the + // gradient in the first pass. Then we do a .sum(0) to do a final reduction. + // Although we launch 2 kernels, we can get up to a 10x speedup for large M. + constexpr int block_dim_y = 1; + constexpr int rows_per_block_y = 32; + bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0); + dim3 threads{block_dim_x, block_dim_y}; + dim3 blocks; + blocks.x = (N + block_dim_x - 1) / block_dim_x; + // int rows_per_block = my_gamma_beta_unroll_factor * + blocks.y = (M + rows_per_block_y - 1) / rows_per_block_y; + constexpr int max_grid_size = 64 * 1024 / 2; + blocks.y = std::min(max_grid_size / blocks.x, blocks.y); + Tensor dgamma_blocks; + Tensor dbeta_blocks; + T * dgamma_blocks_ptr = nullptr; + T * dbeta_blocks_ptr = nullptr; + if (dgamma->defined()) { + auto options = dgamma->options(); + dgamma_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); + dgamma_blocks_ptr = dgamma_blocks.data_ptr(); } - - // Remainder loop - int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; - for (int ii = 0; ii < unroll_factor; ii++ ){ - if ((offset + ii) < M) { - dY_reg = dY[(offset + ii) * N + j ]; - X_reg = X[(offset + ii) * N + j]; - mean_reg = mean[offset + ii]; - rstd_reg = rstd[offset + ii]; - dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; - db_sum += dY_reg; - } + if (dbeta->defined()) { + auto options = dbeta->options(); + dbeta_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); + dbeta_blocks_ptr = dbeta_blocks.data_ptr(); } + LaunchAndCheckGammaBetaBackwardKernel( + aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_blocks_ptr, dbeta_blocks_ptr); - // Do the final reduction in shared memory - s_dg = s_data_typed; - s_db = s_data_typed + blockDim.x * blockDim.y; - s_dg[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum; - s_db[threadIdx.y * blockDim.x + threadIdx.x] = db_sum; - __syncthreads(); - - for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { - if (threadIdx.y < offset) { - s_dg[threadIdx.y * blockDim.x + threadIdx.x] += - s_dg[(threadIdx.y + offset) * blockDim.x + threadIdx.x]; - s_db[threadIdx.y * blockDim.x + threadIdx.x] += - s_db[(threadIdx.y + offset) * blockDim.x + threadIdx.x]; - } - __syncthreads(); - } - - if (threadIdx.y == 0) { - if (dg) { - dg[j] = s_dg[threadIdx.x]; - } - if (db) { - db[j] = s_db[threadIdx.x]; - } + *dgamma = dgamma_blocks.sum(0); + *dbeta = dbeta_blocks.sum(0); + } else { + // We are in the normal case where M is not that large. + // We can change the tile shape (which is the last template parameter) in accordance with M. + // For small M it is faster to have a smaller tile, otherwise we could have idle threads. + // For larger M we use a bigger tile size. 
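// ---------------------------------------------------------------------------
// [Illustrative aside, not part of this patch] Before the normal-case dispatch
// below, here is the shape bookkeeping of the two-pass branch above, worked
// through on example sizes (M = 1,000,000, N = 256, block_dim_x = 32,
// block_dim_y = 1, rows_per_block_y = 32):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t M = 1000000, N = 256;
  const int64_t block_dim_x = 32, block_dim_y = 1, rows_per_block_y = 32;

  int64_t blocks_x = (N + block_dim_x - 1) / block_dim_x;            // 8
  int64_t blocks_y = (M + rows_per_block_y - 1) / rows_per_block_y;  // 31250
  const int64_t max_grid_size = 64 * 1024 / 2;                       // 32768
  blocks_y = std::min(max_grid_size / blocks_x, blocks_y);           // 4096

  // Pass 1 writes one partial row per block row: dgamma_blocks has shape
  // {blocks_y * block_dim_y, N} = {4096, 256}. Each block row grid-strides
  // over M, visiting up to ceil(M / (rows_per_block_y * blocks_y)) = 8 tiles
  // of 32 rows. Pass 2 is the column reduction dgamma_blocks.sum(0) -> {N}.
  const int64_t tiles_per_block_row =
      (M + rows_per_block_y * blocks_y - 1) / (rows_per_block_y * blocks_y);

  std::printf("grid = (%lld, %lld), partials = {%lld, %lld}, tiles per block row <= %lld\n",
              (long long)blocks_x, (long long)blocks_y,
              (long long)(blocks_y * block_dim_y), (long long)N,
              (long long)tiles_per_block_row);
  return 0;
}
// [end of aside] -------------------------------------------------------------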
+ if (M < 64) { + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + } else if (M < 128) { + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + } else if (M < 256) { + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + } else { + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } } } @@ -1250,6 +1423,7 @@ void LayerNormBackwardKernelImplInternal( dgamma->defined() ? dgamma->template data_ptr() : nullptr; T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr() : nullptr; +#if defined(USE_ROCM) if (M < 128) { // For small batch size, do colwise reduce directly. const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads; @@ -1265,7 +1439,6 @@ void LayerNormBackwardKernelImplInternal( dbeta_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { -#if defined(USE_ROCM) // For small batch size, do colwise reduce directly. const int part_size = warp_size; const dim3 threads2(warp_size, 4, 1); @@ -1300,47 +1473,11 @@ void LayerNormBackwardKernelImplInternal( dgamma_data, dbeta_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); -#else - if ((M % kWarpSize == 0) && (N % kWarpSize == 0)) { - // This implementation relies on warp primitives and requires that M and N divide - // exactly to warp size. - dim3 threads{kWarpSize, kWarpSize}; - int blocks = (N + threads.x - 1) / threads.x; - - // If M and N divide by warp_size, we can use warp shuffles for the final reduction. - // That requires transposing values in shared memory, so we apply a padding to - // reduce bank conflicts. - - size_t shmem_sz = 2 * sizeof(T_ACC) * (threads.x + 1) * threads.y; - GammaBetaBackwardCUDAKernel_32x32 - <<>>( - M, - N, - dY_data, - X_data, - mean_data, - rstd_data, - dgamma_data, - dbeta_data); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } else { - dim3 threads{16, 32}; - int blocks = (N + threads.x - 1) / threads.x; - size_t shmem_sz = 2 * sizeof(T_ACC) * threads.x * threads.y; - GammaBetaBackwardCUDAKernel - <<>>( - M, - N, - dY_data, - X_data, - mean_data, - rstd_data, - dgamma_data, - dbeta_data); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } -#endif } +#else + LaunchGammaBetaBackwardCUDAKernel( + dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); +#endif } } diff --git a/test/test_nn.py b/test/test_nn.py index 30fe71b4162..72c440ca5ec 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -7195,6 +7195,26 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""") ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False) self.assertEqual(ln.forward(x), torch.zeros_like(x)) + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_layer_norm_backwards_eps(self): + dtype = torch.float + m_x_n_list = [(3, 3), (5, 5), (11, 11), (55, 55), + (32, 32), (1024, 32), (1024, 1024), + (33, 33), (1025, 33), (1025, 1025)] + for m, n in m_x_n_list: + x = torch.randn((m, n), dtype=dtype, requires_grad=True) + grad_output = torch.rand_like(x) + x_cuda = x.clone().detach().to("cuda").requires_grad_() + grad_output_cuda = grad_output.clone().detach().to("cuda") + ln = nn.LayerNorm(n, dtype=dtype) + ln_cuda = nn.LayerNorm(n, device="cuda", dtype=dtype) + ln_out = ln(x) + ln_out_cuda = ln_cuda(x_cuda) + ln_out.backward(grad_output) + ln_out_cuda.backward(grad_output_cuda) + self.assertEqual(ln.weight.grad, ln_cuda.weight.grad, f"weight grad failed: {m=} 
{n=}", rtol=1e-5, atol=1e-4) + self.assertEqual(ln.bias.grad, ln_cuda.bias.grad, f"bias grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4) + @largeTensorTest("40GB", device="cuda") def test_layer_norm_large_tensor(self): # test for https://github.com/pytorch/pytorch/issues/136291