From 6401d1d53d13cb2d564d48a30a8cf4f952ff773d Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 1 Jul 2025 18:46:45 +0000 Subject: [PATCH] Revert "Fused RMSNorm implementation (#153666)" This reverts commit e1aee86646aa6d1b9cb9d34351e43936401c5efc. Reverted https://github.com/pytorch/pytorch/pull/153666 on behalf of https://github.com/davidberard98 due to causing build failures on main branch [GH job link](https://github.com/pytorch/pytorch/actions/runs/16007148842/job/45156382001) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/e1aee86646aa6d1b9cb9d34351e43936401c5efc) ([comment](https://github.com/pytorch/pytorch/pull/153666#issuecomment-3025146176)) --- .../functorch/BatchRulesDecompositions.cpp | 1 - .../src/ATen/native/cuda/layer_norm_kernel.cu | 590 +++++------------- aten/src/ATen/native/layer_norm.cpp | 81 +-- aten/src/ATen/native/layer_norm.h | 6 - .../src/ATen/native/mps/operations/RMSNorm.mm | 13 +- aten/src/ATen/native/native_functions.yaml | 8 +- ...asDecompTest.test_has_decomposition.expect | 1 + test/test_decomp.py | 29 +- tools/autograd/derivatives.yaml | 5 - torch/_decomp/__init__.py | 1 - torch/_decomp/decompositions.py | 75 --- torch/csrc/autograd/FunctionsManual.cpp | 189 ------ torch/csrc/autograd/FunctionsManual.h | 23 - torch/overrides.py | 1 - 14 files changed, 184 insertions(+), 839 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index d58d436c511..4b66b30b62e 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -158,7 +158,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(kron); OP_DECOMPOSE(l1_loss); m.impl("layer_norm", native::layer_norm_symint); - m.impl("_fused_rms_norm", native::rms_norm_composite); OP_DECOMPOSE2(ldexp, Tensor); OP_DECOMPOSE2(less_equal, Tensor ); OP_DECOMPOSE2(less, Tensor ); diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index f765b515cd0..bdb169e26b1 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -50,7 +50,7 @@ bool can_vectorize(const T * ptr, int alignment) { }; -template +template __global__ void RowwiseMomentsCUDAKernel( int64_t N, T_ACC eps, @@ -84,17 +84,12 @@ __global__ void RowwiseMomentsCUDAKernel( T_ACC m1; T_ACC m2; thrust::tie(m2, m1) = welford_op.project(val); - if constexpr (!rms_norm){ - mean[i] = m1; - rstd[i] = c10::cuda::compat::rsqrt(m2 + eps); - } else { - rstd[i] = c10::cuda::compat::rsqrt(m2 + m1 * m1 + eps); - } - + mean[i] = m1; + rstd[i] = c10::cuda::compat::rsqrt(m2 + eps); } } -template +template __global__ void LayerNormForwardCUDAKernel( int64_t N, const T* X, @@ -108,15 +103,11 @@ __global__ void LayerNormForwardCUDAKernel( const int64_t index = i * N + j; const T_ACC gamma_v = gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); - if constexpr (!rms_norm){ - const T_ACC beta_v = - beta == nullptr ? T_ACC(0) : static_cast(beta[j]); - Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * - static_cast(rstd[i]) * gamma_v + - beta_v; - } else { - Y[index] = (static_cast(X[index])) * static_cast(rstd[i]) * gamma_v; - } + const T_ACC beta_v = + beta == nullptr ? T_ACC(0) : static_cast(beta[j]); + Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * + static_cast(rstd[i]) * gamma_v + + beta_v; } } @@ -128,48 +119,40 @@ struct WelfordDataLN{ C10_HOST_DEVICE WelfordDataLN(float mean, float sigma2, float count): mean(mean), sigma2(sigma2), count(count) {} }; -template __device__ +template __device__ WelfordDataLN cuWelfordOnlineSum( const U val, const WelfordDataLN& curr_sum) { - if constexpr (!rms_norm){ - U delta = val - curr_sum.mean; - U new_count = curr_sum.count + 1.f; - U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster - return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; - } else{ - return {0.f, curr_sum.sigma2 + val * val, 0}; - } + U delta = val - curr_sum.mean; + U new_count = curr_sum.count + 1.f; + U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster + return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; } -template __device__ +__device__ WelfordDataLN cuWelfordCombine( const WelfordDataLN dataB, const WelfordDataLN dataA ) { - if constexpr (!rms_norm){ - using U = decltype(dataB.count); - U delta = dataB.mean - dataA.mean; - U count = dataA.count + dataB.count; - U mean, sigma2; - if (count > decltype(dataB.count){0}) { - auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division - auto nA = dataA.count * coef; - auto nB = dataB.count * coef; - mean = nA*dataA.mean + nB*dataB.mean; - sigma2 = dataA.sigma2 + dataB.sigma2 + delta * delta * dataA.count * nB; - } else { - mean = U(0); - sigma2 = U(0); - } - return {mean, sigma2, count}; + using U = decltype(dataB.count); + U delta = dataB.mean - dataA.mean; + U count = dataA.count + dataB.count; + U mean, sigma2; + if (count > decltype(dataB.count){0}) { + auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division + auto nA = dataA.count * coef; + auto nB = dataB.count * coef; + mean = nA*dataA.mean + nB*dataB.mean; + sigma2 = dataA.sigma2 + dataB.sigma2 + delta * delta * dataA.count * nB; } else { - return {0.f, dataB.sigma2 + dataA.sigma2, 0}; + mean = U(0); + sigma2 = U(0); } + return {mean, sigma2, count}; } -template +template __device__ WelfordDataLN compute_stats( const T* __restrict__ X, const int N, @@ -188,13 +171,14 @@ __device__ WelfordDataLN compute_stats( vec_t data = X_vec[i]; #pragma unroll for (int ii=0; ii < vec_size; ii++){ - wd = cuWelfordOnlineSum(static_cast(data.val[ii]), wd); + wd = cuWelfordOnlineSum(static_cast(data.val[ii]), wd); } } // intra-warp reduction for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { - WelfordDataLN wdB{WARP_SHFL_DOWN(wd.mean, offset), WARP_SHFL_DOWN(wd.sigma2, offset), WARP_SHFL_DOWN(wd.count, offset)}; - wd = cuWelfordCombine(wd, wdB); + WelfordDataLN wdB{WARP_SHFL_DOWN(wd.mean, offset), + WARP_SHFL_DOWN(wd.sigma2, offset), WARP_SHFL_DOWN(wd.count, offset)}; + wd = cuWelfordCombine(wd, wdB); } // threadIdx.x == 0 has correct values for each warp // inter-warp reductions @@ -215,7 +199,7 @@ __device__ WelfordDataLN compute_stats( WelfordDataLN wdB{meansigmabuf[2*threadIdx.y], meansigmabuf[2*threadIdx.y+1], countbuf[threadIdx.y]}; - wd = cuWelfordCombine(wd, wdB); + wd = cuWelfordCombine(wd, wdB); } __syncthreads(); } @@ -232,7 +216,7 @@ __device__ WelfordDataLN compute_stats( } -template , int> = 0> __device__ __inline__ void vectorized_layer_norm_kernel_impl( const int N, @@ -247,7 +231,7 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( //as one thread would have to write 3 consecutive floats auto i1 = blockIdx.x; const T * block_row = X + i1 * N; - WelfordDataLN wd = compute_stats(block_row, N, s_data); + WelfordDataLN wd = compute_stats(block_row, N, s_data); using vec_t = aligned_vector; const vec_t * X_vec = reinterpret_cast(block_row); @@ -270,48 +254,34 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( if (gamma_vec != nullptr && beta_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - if constexpr (!rms_norm){ - out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)) - + static_cast(beta_vec[i].val[ii]); - } else { - out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * static_cast(data.val[ii])); - } + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + + static_cast(beta_vec[i].val[ii]); } } else if (gamma_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - if constexpr (!rms_norm){ - out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)); - } else { - out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * static_cast(data.val[ii])); - } + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)); } } else if (beta_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + static_cast(beta_vec[i].val[ii]); + out.val[ii] = (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + static_cast(beta_vec[i].val[ii]); } } else { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - if constexpr (!rms_norm){ - out.val[ii] = rstd_val * (static_cast(data.val[ii]) - wd.mean); - } else { - out.val[ii] = rstd_val * static_cast(data.val[ii]); - } + out.val[ii] = rstd_val * (static_cast(data.val[ii]) - wd.mean); } } Y_vec[i] = out; } if (thrx == 0) { - if constexpr (!rms_norm){ - mean[i1] = wd.mean; - } + mean[i1] = wd.mean; rstd[i1] = rstd_val; } } -template , int> = 0> __device__ __inline__ void vectorized_layer_norm_kernel_impl( const int /*N*/, @@ -326,7 +296,7 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( } //to avoid windows SFINAE errors -template +template __global__ void vectorized_layer_norm_kernel( const int N, T_ACC eps, @@ -336,11 +306,11 @@ __global__ void vectorized_layer_norm_kernel( T_ACC* mean, T_ACC* rstd, T* Y){ - vectorized_layer_norm_kernel_impl(N, eps, X, gamma, beta, mean, rstd, Y); + vectorized_layer_norm_kernel_impl(N, eps, X, gamma, beta, mean, rstd, Y); } -template +template __device__ __inline__ void compute_gI( const T* __restrict__ dY, const T* __restrict__ X, @@ -351,10 +321,7 @@ __device__ __inline__ void compute_gI( const int N, T_ACC * buf){ const auto i1 = blockIdx.x; - T_ACC mean_val = 0; - if constexpr (!rms_norm){ - mean_val = mean[i1]; - } + const T_ACC mean_val = mean[i1]; const T_ACC rstd_val = rstd[i1]; T_ACC stats_x1{0}, stats_x2{0}; constexpr int unroll = 4; @@ -370,39 +337,26 @@ __device__ __inline__ void compute_gI( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l+k]) : T_ACC(1); const auto c_h = static_cast(X_i[l+k]); const auto c_loss = static_cast(dY_i[l+k]); - if constexpr (!rms_norm){ - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; - } else { - stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; - } + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; } } for (; l < N; l ++) { const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); const auto c_h = static_cast(X_i[l]); const auto c_loss = static_cast(dY_i[l]); - if constexpr (!rms_norm){ - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; - } else { - stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; - } - } - if constexpr (!rms_norm){ - stats_x1 = cuda_utils::BlockReduceSum(stats_x1, buf); + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; } + + stats_x1 = cuda_utils::BlockReduceSum(stats_x1, buf); stats_x2 = cuda_utils::BlockReduceSum(stats_x2, buf); if (threadIdx.x == 0) { - if constexpr (!rms_norm){ - buf[0] = stats_x1; - } + buf[0] = stats_x1; buf[1] = stats_x2; } __syncthreads(); - if constexpr (!rms_norm){ - stats_x1 = buf[0]; - } + stats_x1 = buf[0]; stats_x2 = buf[1]; T_ACC fH = N; T_ACC term1 = (T_ACC(1) / fH) * rstd_val; @@ -413,20 +367,15 @@ __device__ __inline__ void compute_gI( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); T_ACC f_grad_input = fH * gamma_val * dy; - if constexpr (!rms_norm){ - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; - } else { - f_grad_input -= (x) * rstd_val * stats_x2; - } - + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; f_grad_input *= term1; dX_i[l] = f_grad_input; } } -template +template __global__ void layer_norm_grad_input_kernel( const T* __restrict__ dY, const T* __restrict__ X, @@ -438,7 +387,7 @@ __global__ void layer_norm_grad_input_kernel( alignas(sizeof(double)) extern __shared__ char s_data1[]; T_ACC * buf = reinterpret_cast(&s_data1); - compute_gI(dY, X, mean, rstd, gamma, dX, N, buf); + compute_gI(dY, X, mean, rstd, gamma, dX, N, buf); } @@ -447,7 +396,7 @@ __global__ void layer_norm_grad_input_kernel( // faster measured at PT operator level, with cases seeing a 2X speedup (where N >> M). // There are no noticeable regressions on the rest of the sizes. -template +template __global__ void layer_norm_grad_input_kernel_vectorized( const T* __restrict__ dY, const T* __restrict__ X, @@ -460,10 +409,7 @@ __global__ void layer_norm_grad_input_kernel_vectorized( T_ACC* reduce_buf = reinterpret_cast(&shared_data); const auto bIdx = blockIdx.x; - T_ACC mean_val = 0; - if constexpr (!rms_norm){ - mean_val = mean[bIdx]; - } + const T_ACC mean_val = mean[bIdx]; const T_ACC rstd_val = rstd[bIdx]; const T* X_i = X + bIdx * N; const T* dY_i = dY + bIdx * N; @@ -495,12 +441,8 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = static_cast(gamma_vec_reg.val[k]); const auto c_h = static_cast(X_i_vec_reg.val[k]); const auto c_loss = static_cast(dY_i_vec_reg.val[k]); - if constexpr (!rms_norm){ - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; - } else { - stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; - } + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; } } @@ -509,29 +451,19 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); const auto c_h = static_cast(X_i[l]); const auto c_loss = static_cast(dY_i[l]); - if constexpr (!rms_norm){ - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; - } else{ - stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; - } + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; } // Reduction in Shared Memory - if constexpr (!rms_norm){ - stats_x1 = cuda_utils::BlockReduceSum(stats_x1, reduce_buf); - } + stats_x1 = cuda_utils::BlockReduceSum(stats_x1, reduce_buf); stats_x2 = cuda_utils::BlockReduceSum(stats_x2, reduce_buf); if (threadIdx.x == 0) { - if constexpr (!rms_norm){ - reduce_buf[0] = stats_x1; - } + reduce_buf[0] = stats_x1; reduce_buf[1] = stats_x2; } __syncthreads(); - if constexpr (!rms_norm){ - stats_x1 = reduce_buf[0]; - } + stats_x1 = reduce_buf[0]; stats_x2 = reduce_buf[1]; T_ACC fH = N; @@ -553,12 +485,8 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto dy = static_cast(dY_i_vec_reg.val[k]); T_ACC f_grad_input = fH * gamma_val * dy; - if constexpr (!rms_norm){ - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; - } else { - f_grad_input -= (x) * rstd_val * stats_x2; - } + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; f_grad_input *= term1; dX_i_vec_reg.val[k] = f_grad_input; } @@ -573,19 +501,15 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); T_ACC f_grad_input = fH * gamma_val * dy; - if constexpr (!rms_norm){ - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; - } else { - f_grad_input -= (x) * rstd_val * stats_x2; - } + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; f_grad_input *= term1; dX_i[l] = f_grad_input; } } -template +template __global__ void GammaBetaBackwardSimpleCUDAKernel( int64_t M, int64_t N, @@ -601,25 +525,17 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel( T_ACC sum2 = 0; for (int64_t i = 0; i < M; ++i) { const int64_t index = i * N + j; - if constexpr (!rms_norm){ - sum1 += dg == nullptr ? T_ACC(0) - : static_cast(dY[index]) * - (static_cast(X[index]) - static_cast(mean[i])) * - static_cast(rstd[i]); - sum2 += db == nullptr ? T_ACC(0) : static_cast(dY[index]); - } else { - sum1 += dg == nullptr ? T_ACC(0) - : static_cast(dY[index]) * - (static_cast(X[index])) * static_cast(rstd[i]); - } + sum1 += dg == nullptr ? T_ACC(0) + : static_cast(dY[index]) * + (static_cast(X[index]) - static_cast(mean[i])) * + static_cast(rstd[i]); + sum2 += db == nullptr ? T_ACC(0) : static_cast(dY[index]); } if (dg != nullptr) { dg[j] = sum1; } if (db != nullptr) { - if constexpr (!rms_norm){ - db[j] = sum2; - } + db[j] = sum2; } } } @@ -629,8 +545,7 @@ unsigned int block_dim_x, unsigned int block_dim_y, unsigned int rows_per_block_y, bool check_x, -bool check_y, -bool rms_norm> +bool check_y> __device__ __forceinline__ void @@ -654,9 +569,7 @@ blockReduceGammaBetaBackwardsHelper( int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y; T_ACC warp_mean = 0, warp_rstd = 0; if (lane_id < rows_per_thread_y && mean_index + lane_id < M) { - if constexpr (!rms_norm){ - warp_mean = mean[mean_index + lane_id]; - } + warp_mean = mean[mean_index + lane_id]; warp_rstd = rstd[mean_index + lane_id]; } // We do a WARP_SYNC() here because we use WARP_SHFL below to access @@ -683,14 +596,10 @@ blockReduceGammaBetaBackwardsHelper( #pragma unroll for (int i = 0; i < rows_per_thread_y; ++i) { + T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize); - if constexpr (!rms_norm){ - T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); - dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; - db_sum += dY_regs[i]; - } else{ - dg_sum += dY_regs[i] * (X_regs[i]) * rstd_reg; - } + dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; + db_sum += dY_regs[i]; } } @@ -699,8 +608,7 @@ unsigned int block_dim_x, unsigned int block_dim_y, unsigned int rows_per_block_y, bool check_x, -bool check_y, -bool rms_norm> +bool check_y> __device__ __forceinline__ void @@ -721,10 +629,10 @@ blockReduceGammaBetaBackwardsWithChecks( M_start += rows_per_block_y * gridDim.y) { int64_t M_end = M_start + rows_per_block_y - 1; if (!check_y || M_end < M) { - blockReduceGammaBetaBackwardsHelper + blockReduceGammaBetaBackwardsHelper (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { - blockReduceGammaBetaBackwardsHelper + blockReduceGammaBetaBackwardsHelper (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } } @@ -746,8 +654,7 @@ template __global__ void @@ -772,7 +679,7 @@ __launch_bounds__(block_dim_x * block_dim_y) // When N and M align perfectly with block_dim_x and block_dim_y, we // can skip boundary condition checks that waste instruction issue slots. blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { // In the general case we need to check boundary conditions in the M @@ -780,11 +687,11 @@ __launch_bounds__(block_dim_x * block_dim_y) // for the inner blocks. So try to avoid those checks when possible. if (blockIdx.x * block_dim_x + block_dim_x - 1 < N) { blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } } @@ -799,7 +706,7 @@ __launch_bounds__(block_dim_x * block_dim_y) if (dg) { dg[thread_y * N + thread_x] = dg_sum; } - if (db && !rms_norm) { + if (db) { db[thread_y * N + thread_x] = db_sum; } } @@ -845,7 +752,7 @@ __launch_bounds__(block_dim_x * block_dim_y) if (dg) { dg[out_index] = reg_dg; } - if (db && !rms_norm) { + if (db) { db[out_index] = reg_db; } } @@ -856,8 +763,7 @@ __launch_bounds__(block_dim_x * block_dim_y) template +bool partial_reduction> void LaunchAndCheckGammaBetaBackwardKernel( bool aligned_grid, dim3 blocks, @@ -873,7 +779,7 @@ void LaunchAndCheckGammaBetaBackwardKernel( T* dgamma_data, T* dbeta_data) { if (aligned_grid) { - GammaBetaBackwardCUDAKernelTemplate + GammaBetaBackwardCUDAKernelTemplate <<>>( M, N, @@ -884,7 +790,7 @@ if (aligned_grid) { dgamma_data, dbeta_data); } else { - GammaBetaBackwardCUDAKernelTemplate + GammaBetaBackwardCUDAKernelTemplate <<>>( M, N, @@ -900,7 +806,7 @@ if (aligned_grid) { template +int rows_per_block_y> void ConfigureAndLaunchGammaBetaBackwardKernel( const T* dY_data, const T* X_data, @@ -923,16 +829,16 @@ void ConfigureAndLaunchGammaBetaBackwardKernel( if (blocks.y == 1 && threads.y == 1) { // Optimization: since there is just one thread doing all the summation, we don't need a reduction // across threads. So we set partial_reduction to true. - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); } else { - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); } } -template +template void LaunchGammaBetaBackwardCUDAKernel( const T* dY_data, const T* X_data, @@ -970,21 +876,19 @@ void LaunchGammaBetaBackwardCUDAKernel( dgamma_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); dgamma_blocks_ptr = dgamma_blocks.data_ptr(); } - if (dbeta->defined() && !rms_norm) { + if (dbeta->defined()) { auto options = dbeta->options(); dbeta_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); dbeta_blocks_ptr = dbeta_blocks.data_ptr(); } - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_blocks_ptr, dbeta_blocks_ptr); if (dgamma_blocks.defined()) { *dgamma = dgamma_blocks.sum(0); } - if constexpr (!rms_norm){ - if (dbeta_blocks.defined()) { - *dbeta = dbeta_blocks.sum(0); - } + if (dbeta_blocks.defined()) { + *dbeta = dbeta_blocks.sum(0); } } else { // We are in the normal case where M is not that large. @@ -992,18 +896,18 @@ void LaunchGammaBetaBackwardCUDAKernel( // For small M it is faster to have a smaller tile, otherwise we could have idle threads. // For larger M we use a bigger tile size. if (M < 64) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else if (M < 128) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else if (M < 256) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } } } -template +template void launch_vectorized_layer_norm_kernel( int N, int64_t M, @@ -1032,7 +936,7 @@ void launch_vectorized_layer_norm_kernel( TORCH_INTERNAL_ASSERT_DEBUG_ONLY(threads.y % 2 == 0 || threads.y == 1); int nshared = threads.y > 1 ? threads.y * 3/2 *sizeof(T_ACC) : 0; - vectorized_layer_norm_kernel<<>>(N, eps, X_data, + vectorized_layer_norm_kernel<<>>(N, eps, X_data, gamma_data, beta_data, mean_data, rstd_data, Y_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -1054,7 +958,7 @@ void launch_vectorized_layer_norm_kernel( blocks.x = (remaining > blocks.x) ? blocks.x : remaining; - vectorized_layer_norm_kernel<<>>(N, eps, X_data2, + vectorized_layer_norm_kernel<<>>(N, eps, X_data2, gamma_data, beta_data, mean_data2, rstd_data2, Y_data2); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -1064,7 +968,7 @@ void launch_vectorized_layer_norm_kernel( } -template +template void LayerNormKernelImplInternal( const Tensor& X, const Tensor& gamma, @@ -1083,7 +987,7 @@ void LayerNormKernelImplInternal( const T* gamma_data = gamma.defined() ? gamma.const_data_ptr() : nullptr; const T* beta_data = beta.defined() ? beta.const_data_ptr() : nullptr; T* Y_data = Y->data_ptr(); - T_ACC* mean_data = !rms_norm ? mean->data_ptr() : nullptr; + T_ACC* mean_data = mean->data_ptr(); T_ACC* rstd_data = rstd->data_ptr(); // check if can take fast path - all tensors are properly aligned, N is less than 2^24 (to use float count), @@ -1098,14 +1002,14 @@ void LayerNormKernelImplInternal( if ((std::is_same_v || std::is_same_v || std::is_same_v) && N <= static_cast(1ULL << std::numeric_limits::digits) && N % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_gamma && can_vec_beta) { - launch_vectorized_layer_norm_kernel(static_cast(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data); + launch_vectorized_layer_norm_kernel(static_cast(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data); } else { cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); - RowwiseMomentsCUDAKernel + RowwiseMomentsCUDAKernel <<>>( N, eps, X_data, mean_data, rstd_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); - LayerNormForwardCUDAKernel<<>>( + LayerNormForwardCUDAKernel<<>>( N, X_data, mean_data, rstd_data, gamma_data, beta_data, Y_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1133,29 +1037,7 @@ void LayerNormKernelImpl( }); } -void RmsNormKernelImpl( - const Tensor& X, - const Tensor& gamma, - int64_t M, - int64_t N, - double eps, - Tensor* Y, - Tensor* rstd) { -AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - X.scalar_type(), - "LayerNormKernelImpl", - [&]() { - using acc_t = acc_type; - // rms_norm = true - LayerNormKernelImplInternal( - // pass in at::Tensor() for gamma and nullptr for mean, it won't be accessed with rms_norm = True - X, gamma, at::Tensor(), M, N, static_cast(eps), Y, nullptr, rstd); - }); -} - -template __device__ +template __device__ void cuLoadWriteStridedInputs( const int i1_block, const int thr_load_row_off, @@ -1173,10 +1055,7 @@ void cuLoadWriteStridedInputs( { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - T_ACC curr_mean = 0; - if constexpr (!rms_norm){ - curr_mean = mean[i1]; - } + T_ACC curr_mean = mean[i1]; T_ACC curr_rstd = rstd[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; @@ -1201,7 +1080,7 @@ void cuLoadWriteStridedInputs( } } -template __device__ +template __device__ void cuLoadAddStridedInputs( const int i1_block, const int thr_load_row_off, @@ -1219,11 +1098,7 @@ void cuLoadAddStridedInputs( { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - - T_ACC curr_mean = 0; - if constexpr (!rms_norm){ - curr_mean = mean[i1]; - } + T_ACC curr_mean = mean[i1]; T_ACC curr_rstd = rstd[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; @@ -1239,7 +1114,7 @@ void cuLoadAddStridedInputs( } } -template __global__ +template __global__ void cuComputePartGradGammaBeta( const T* __restrict__ dout, const T* __restrict__ input, @@ -1265,9 +1140,9 @@ void cuComputePartGradGammaBeta( T_ACC* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; // compute partial sums from strided inputs // do this to increase number of loads in flight - cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { - cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); } __syncthreads(); // inter-warp reductions @@ -1306,7 +1181,7 @@ void cuComputePartGradGammaBeta( } } -template __global__ +template __global__ void cuComputeGradGammaBeta( const T_ACC* part_grad_gamma, const T_ACC* part_grad_beta, @@ -1331,9 +1206,7 @@ void cuComputeGradGammaBeta( if (i2 < N) { for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { sum_gamma += part_grad_gamma_ptr[warp_offset*N]; - if constexpr (!rms_norm){ - sum_beta += part_grad_beta_ptr[warp_offset*N]; - } + sum_beta += part_grad_beta_ptr[warp_offset*N]; } } @@ -1351,9 +1224,7 @@ void cuComputeGradGammaBeta( if (threadIdx.y < offset) { const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; sum_gamma += buf[read_idx]; - if constexpr (!rms_norm){ - sum_beta += buf[read_idx+nbsize3]; - } + sum_beta += buf[read_idx+nbsize3]; } __syncthreads(); } @@ -1364,14 +1235,12 @@ void cuComputeGradGammaBeta( grad_gamma[i2] = sum_gamma; } if (grad_beta) { - if constexpr (!rms_norm){ - grad_beta[i2] = sum_beta; - } + grad_beta[i2] = sum_beta; } } } -template __global__ +template __global__ void cuComputeGradInput( const T* __restrict__ dout, const T* __restrict__ input, @@ -1385,10 +1254,7 @@ void cuComputeGradInput( for (int i1=blockIdx.y; i1 < M; i1 += gridDim.y) { T_ACC sum_loss1 = T_ACC(0); T_ACC sum_loss2 = T_ACC(0); - T_ACC c_mean = 0; - if constexpr (!rms_norm){ - c_mean = mean[i1]; - } + T_ACC c_mean = mean[i1]; const T_ACC c_rstd = rstd[i1]; const T* k_input = input + i1*N; const T* k_dout = dout + i1*N; @@ -1401,31 +1267,21 @@ void cuComputeGradInput( const T_ACC gamma_idx = static_cast((idx((idx((idx((idx((idx 0; mask /= 2) { - if constexpr (!rms_norm){ - sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); - } + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); } // inter-warp reductions @@ -1436,33 +1292,25 @@ void cuComputeGradInput( // upper half of warps write to shared if (threadIdx.y >= offset && threadIdx.y < 2*offset) { const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - if constexpr (!rms_norm){ - buf[2*wrt_i] = sum_loss1; - } + buf[2*wrt_i] = sum_loss1; buf[2*wrt_i+1] = sum_loss2; } __syncthreads(); // lower half merges if (threadIdx.y < offset) { const int read_i = threadIdx.y * blockDim.x + threadIdx.x; - if constexpr (!rms_norm){ - sum_loss1 += buf[2*read_i]; - } + sum_loss1 += buf[2*read_i]; sum_loss2 += buf[2*read_i+1]; } __syncthreads(); } if (threadIdx.y == 0) { - if constexpr (!rms_norm){ - buf[2*threadIdx.x] = sum_loss1; - } + buf[2*threadIdx.x] = sum_loss1; buf[2*threadIdx.x+1] = sum_loss2; } __syncthreads(); if (threadIdx.y !=0) { - if constexpr (!rms_norm){ - sum_loss1 = buf[2*threadIdx.x]; - } + sum_loss1 = buf[2*threadIdx.x]; sum_loss2 = buf[2*threadIdx.x+1]; } } @@ -1475,12 +1323,8 @@ void cuComputeGradInput( const T_ACC c_h = static_cast(k_input[l]); const T_ACC c_loss = static_cast(k_dout[l]); T_ACC f_grad_input = fH * c_loss * gamma[l]; - if constexpr (!rms_norm){ - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; - } else { - f_grad_input -= (c_h) * c_rstd * sum_loss2; - } + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } @@ -1489,12 +1333,8 @@ void cuComputeGradInput( const T_ACC c_h = static_cast(k_input[l]); const T_ACC c_loss = static_cast(k_dout[l]); T_ACC f_grad_input = fH * c_loss; - if constexpr (!rms_norm){ - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; - } else { - f_grad_input -= (c_h) * c_rstd * sum_loss2; - } + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } @@ -1504,7 +1344,7 @@ void cuComputeGradInput( } } -template +template void LayerNormBackwardKernelImplInternal( const Tensor& dY, const Tensor& X, @@ -1518,9 +1358,7 @@ void LayerNormBackwardKernelImplInternal( Tensor* dbeta) { using T_ACC = acc_type; TORCH_CHECK(dY.numel() == M * N); - if constexpr (!rms_norm){ - TORCH_CHECK(mean.numel() == M); - } + TORCH_CHECK(mean.numel() == M); TORCH_CHECK(rstd.numel() == M); TORCH_CHECK(M <= at::cuda::getCurrentDeviceProperties()->maxGridSize[0], "M should be less than maximum CUDA grid size, \ file a support request to support bigger batches"); @@ -1546,7 +1384,7 @@ void LayerNormBackwardKernelImplInternal( threads1.y > 1 ? threads1.y*threads1.x*sizeof(T_ACC) : 0; - cuComputeGradInput<<>>( + cuComputeGradInput<<>>( dY_data, X_data, M, N, @@ -1558,7 +1396,7 @@ void LayerNormBackwardKernelImplInternal( } else { const dim3 blocks(M); int nshared = (num_threads()/warp_size) * sizeof(T_ACC); - layer_norm_grad_input_kernel<<>>(dY_data, + layer_norm_grad_input_kernel<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1572,12 +1410,13 @@ void LayerNormBackwardKernelImplInternal( const unsigned int alignment = sizeof(T) * vec_size; bool bAlignedBuffers = can_vectorize(dY_data, alignment) && can_vectorize(X_data, alignment) && can_vectorize(gamma_data, alignment) && can_vectorize(dX_data, alignment); + if (bAlignedBuffers && bTargetDataTypes && bVectorSizeMultiple) { - layer_norm_grad_input_kernel_vectorized<<>>(dY_data, + layer_norm_grad_input_kernel_vectorized<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { - layer_norm_grad_input_kernel<<>>(dY_data, + layer_norm_grad_input_kernel<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1593,7 +1432,7 @@ void LayerNormBackwardKernelImplInternal( if (M < 128) { // For small batch size, do colwise reduce directly. const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads; - GammaBetaBackwardSimpleCUDAKernel + GammaBetaBackwardSimpleCUDAKernel <<>>( M, N, @@ -1617,7 +1456,7 @@ void LayerNormBackwardKernelImplInternal( Tensor part_grad_gamma = at::empty({part_size,N}, gamma.options().dtype(part_grad_dtype)); Tensor part_grad_beta = at::native::empty_like(part_grad_gamma); - cuComputePartGradGammaBeta<<>>( + cuComputePartGradGammaBeta<<>>( dY_data, X_data, M,N, @@ -1631,7 +1470,7 @@ void LayerNormBackwardKernelImplInternal( const dim3 blocks3((N + threads3.x - 1) / threads3.x, 1, 1); const int nshared3 = threads3.x * threads3.y * sizeof(T_ACC); - cuComputeGradGammaBeta<<>>( + cuComputeGradGammaBeta<<>>( part_grad_gamma.template data_ptr(), part_grad_beta.template data_ptr(), part_size, @@ -1641,7 +1480,7 @@ void LayerNormBackwardKernelImplInternal( C10_CUDA_KERNEL_LAUNCH_CHECK(); } #else - LaunchGammaBetaBackwardCUDAKernel( + LaunchGammaBetaBackwardCUDAKernel( dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); #endif } @@ -1669,29 +1508,8 @@ void LayerNormBackwardKernelImpl( }); } -void RMSNormBackwardKernelImpl( - const Tensor& dY, - const Tensor& X, - const Tensor& rstd, - const Tensor& gamma, - int64_t M, - int64_t N, - Tensor* dX, - Tensor* dgamma) { - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - X.scalar_type(), - "LayerNormBackwardKernelImpl", - [&]() { - LayerNormBackwardKernelImplInternal( - dY.contiguous(), X, rstd, rstd, gamma, M, N, dX, dgamma, dgamma); - }); -} - } // namespace - std::tuple layer_norm_cuda( const Tensor& input, IntArrayRef normalized_shape, @@ -1820,108 +1638,6 @@ std::tuple layer_norm_backward_cuda( return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta)); } -/* RMSNorm is implemented by reusing layer_norm's kernels */ -std::tuple _fused_rms_norm_cuda( - const Tensor& input, - IntArrayRef normalized_shape, - const std::optional& weight_opt /* optional */, - std::optional eps){ - - c10::MaybeOwned weight_maybe_owned = - at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, weight); - auto M = M_N.first; - auto N = M_N.second; - auto X = input.expect_contiguous(); - auto gamma = weight.expect_contiguous(); - - double eps_val = eps.value_or(std::numeric_limits::epsilon()); - - Tensor Y = at::native::empty_like( - *X, - std::nullopt /* dtype */, - std::nullopt /* layout */, - std::nullopt /* device */, - std::nullopt /* pin_memory */, - LEGACY_CONTIGUOUS_MEMORY_FORMAT); - auto acc_type = at::toAccumulateType(input.scalar_type(), /*is_cuda=*/true); - Tensor rstd = at::empty({M}, X->options().dtype(acc_type)); - - if (M > 0) { - RmsNormKernelImpl(*X, *gamma, M, N, eps_val, &Y, &rstd); - } - - const auto input_shape = input.sizes(); - const size_t axis = input.dim() - normalized_shape.size(); - - std::vector stat_shape; - for (const auto idx: c10::irange(axis)) { - stat_shape.push_back(input_shape[idx]); - } - for ([[maybe_unused]] const auto idx : c10::irange(axis, input.dim())) { - stat_shape.push_back(1); - } - - rstd = rstd.view(stat_shape); - - return std::make_tuple(std::move(Y), std::move(rstd)); -} - - -std::tuple _fused_rms_norm_backward_cuda( - const Tensor& dY, - const Tensor& input, - IntArrayRef normalized_shape, - const Tensor& rstd, - const std::optional& weight_opt /* optional */, - std::array grad_input_mask) { - - c10::MaybeOwned weight_maybe_owned = - at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - - auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, weight); - auto M = M_N.first; - auto N = M_N.second; - auto X = input.expect_contiguous(); - auto gamma = weight.expect_contiguous(); - - Tensor dX; - Tensor dgamma; - if (grad_input_mask[0]) { - dX = at::native::empty_like( - *X, - std::nullopt /* dtype */, - std::nullopt /* layout */, - std::nullopt /* device */, - std::nullopt /* pin_memory */, - LEGACY_CONTIGUOUS_MEMORY_FORMAT); - } - if (grad_input_mask[1]) { - dgamma = M > 0 ? at::native::empty_like( - *gamma, - std::nullopt /* dtype */, - std::nullopt /* layout */, - std::nullopt /* device */, - std::nullopt /* pin_memory */, - LEGACY_CONTIGUOUS_MEMORY_FORMAT) - : at::native::zeros_like( - *gamma, - std::nullopt /* dtype */, - std::nullopt /* layout */, - std::nullopt /* device */, - std::nullopt /* pin_memory */, - LEGACY_CONTIGUOUS_MEMORY_FORMAT); - } - - if (M > 0 && N > 0) { - RMSNormBackwardKernelImpl( - dY, *X, rstd, *gamma, M, N, &dX, &dgamma); - } - return std::make_tuple(std::move(dX), std::move(dgamma)); -} - REGISTER_DISPATCH(LayerNormKernel, &LayerNormKernelImpl) REGISTER_DISPATCH(LayerNormBackwardKernel, &LayerNormBackwardKernelImpl) diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp index 207f092a676..da6bb5fec39 100644 --- a/aten/src/ATen/native/layer_norm.cpp +++ b/aten/src/ATen/native/layer_norm.cpp @@ -261,11 +261,30 @@ std::tuple math_native_layer_norm( return outputs; } -std::tuple rms_norm_composite( +Tensor rms_norm_symint( const Tensor& input, - IntArrayRef normalized_shape, + c10::SymIntArrayRef normalized_shape, const std::optional& weight_opt /* optional */, std::optional eps) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + _check_rms_norm_inputs_symint(input, normalized_shape, weight); + +#ifdef USE_MPS + if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) { + const Tensor weight = weight_opt.value(); + const bool any_nested = input.is_nested() || weight.is_nested(); + const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad(); + const bool is_input_fp = isFloatingType(input.scalar_type()); + const bool is_weight_fp = isFloatingType(weight.scalar_type()); + + if (!(GradMode::is_enabled() && any_inputs_require_grad) && !any_nested && is_input_fp && is_weight_fp) { + auto eps_val = eps.value_or(std::numeric_limits::epsilon()); + return at::_fused_rms_norm(input.contiguous(), normalized_shape.size(), weight.contiguous(), eps_val); + } + } +#endif std::vector dims_to_reduce; for (const auto i : c10::irange(normalized_shape.size())) { @@ -302,60 +321,10 @@ std::tuple rms_norm_composite( upcasted_result = upcasted_result.mul(weight_opt.value()); } - // if nested do not make contiguous - if(input.is_nested() || (weight_opt.has_value() && weight_opt.value().is_nested())){ - return std::make_tuple(upcasted_result, rqrst_input); - } - - if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){ - return std::make_tuple(upcasted_result, rqrst_input); - } - - return std::make_tuple(upcasted_result.contiguous(), rqrst_input.contiguous()); + return upcasted_result; }); - return std::make_tuple( - std::get<0>(result).type_as(input), // Cast normalized result to original input type - std::get<1>(result) // rsqrt_val - ); + + return result.type_as(input); + } - - -Tensor rms_norm_symint( - const Tensor& input, - c10::SymIntArrayRef normalized_shape, - const std::optional& weight_opt /* optional */, - const std::optional eps) { - - c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - _check_rms_norm_inputs_symint(input, normalized_shape, weight); - - // composite fallback for channels last - if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){ - return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); - } - - // composite fallback for complex datatypes - if(input.is_complex()){ - return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); - } - - #ifdef USE_MPS - if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) { - const Tensor weight = weight_opt.value(); - const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad(); - - if (!(GradMode::is_enabled() && any_inputs_require_grad)) { - return std::get<0>(at::_fused_rms_norm(input.contiguous(), IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); - } - } - - if (input.device().type() == DeviceType::MPS){ - return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); - } - #endif - - return std::get<0>(at::_fused_rms_norm(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); -} - } // namespace at::native diff --git a/aten/src/ATen/native/layer_norm.h b/aten/src/ATen/native/layer_norm.h index 0debe942dd0..0181f35fd6e 100644 --- a/aten/src/ATen/native/layer_norm.h +++ b/aten/src/ATen/native/layer_norm.h @@ -106,12 +106,6 @@ void layer_norm_cpu_out( int64_t M, int64_t N); -std::tuple rms_norm_composite( - const Tensor& input, - IntArrayRef normalized_shape, - const std::optional& weight_opt /* optional */, - std::optional eps); - Tensor rms_norm_symint( const Tensor& input, c10::SymIntArrayRef normalized_shape, diff --git a/aten/src/ATen/native/mps/operations/RMSNorm.mm b/aten/src/ATen/native/mps/operations/RMSNorm.mm index 7948b5acd8e..71128297d5b 100644 --- a/aten/src/ATen/native/mps/operations/RMSNorm.mm +++ b/aten/src/ATen/native/mps/operations/RMSNorm.mm @@ -19,14 +19,7 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary(); #include #endif -std::tuple _fused_rms_norm_mps(const Tensor& input, - IntArrayRef normalized_shape, - const std::optional& weight_opt, - const std::optional eps) { - const Tensor weight = weight_opt.value().contiguous(); - const int64_t normalized_ndim = normalized_shape.size(); - auto eps_val = eps.value_or(std::numeric_limits::epsilon()); - +Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, const Tensor& weight, const double eps) { TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(), "Expected contiguous input and weight tensors"); auto output = at::empty_like(input); const auto input_shape = input.sizes(); @@ -48,7 +41,7 @@ std::tuple _fused_rms_norm_mps(const Tensor& input, const std::string kernel = fmt::format("{}_{}", name, scalarToMetalTypeString(output)); id rms_norm_pso = lib.getPipelineStateForFunc(kernel); [computeEncoder setComputePipelineState:rms_norm_pso]; - mtl_setArgs(computeEncoder, input, weight, output, eps_val, N, 1); + mtl_setArgs(computeEncoder, input, weight, output, eps, N, 1); const auto maxThreadsPerGroup = static_cast([rms_norm_pso maxTotalThreadsPerThreadgroup]); size_t threadgroup_size = maxThreadsPerGroup; @@ -65,7 +58,7 @@ std::tuple _fused_rms_norm_mps(const Tensor& input, } }); - return std::make_tuple(output, Tensor()); + return output; } } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 0e88a20a044..0fe784626f3 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3315,15 +3315,9 @@ dispatch: CompositeImplicitAutograd: rms_norm_symint -- func: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor) +- func: _fused_rms_norm(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor dispatch: - CUDA: _fused_rms_norm_cuda MPS: _fused_rms_norm_mps - CompositeImplicitAutograd: rms_norm_composite - -- func: _fused_rms_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor rstd, Tensor? weight, bool[2] output_mask) -> (Tensor, Tensor) - dispatch: - CUDA: _fused_rms_norm_backward_cuda - func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor variants: function, method diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index a590713ad0f..042959c22cd 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -374,6 +374,7 @@ aten::_fused_adamw_.tensor_lr aten::_fused_moving_avg_obs_fq_helper aten::_fused_moving_avg_obs_fq_helper.out aten::_fused_moving_avg_obs_fq_helper_functional +aten::_fused_rms_norm aten::_fused_sdp_choice aten::_fused_sgd aten::_fused_sgd.out diff --git a/test/test_decomp.py b/test/test_decomp.py index 6869d56982c..07dcd8252c5 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -15,7 +15,7 @@ from torch._dispatch.python import enable_python_dispatcher from torch._export.utils import _is_cia_op from torch._ops import DispatchKey from torch.testing import make_tensor -from torch.testing._internal.common_cuda import SM70OrLater, tf32_off +from torch.testing._internal.common_cuda import tf32_off from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCPU, @@ -1226,33 +1226,6 @@ class DecompOneOffTests(TestCase): for o_ref, o in zip(out_ref, out): self.assertEqual(o_ref.dtype, o.dtype) - @onlyCUDA - @unittest.skipIf(not SM70OrLater, "triton") - def test_rms_norm_decomp_cuda(self, device): - @torch.compile - def rms_norm_sinh(a, b, c): - output = torch.nn.functional.rms_norm(a, b, c) - return torch.sinh(output) - - normalized_shape_arg = (3, 3, 3) - input_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) - weight_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) - - def forward_pass_fn(): - return rms_norm_sinh(input_tensor, normalized_shape_arg, weight_tensor) - - model_output, generated_codes = torch._inductor.utils.run_fw_bw_and_get_code( - forward_pass_fn - ) - - # check RMSNorm was fused with sinh - self.assertTrue( - "triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0] - ) - self.assertTrue( - "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1] - ) - instantiate_device_type_tests(DecompOneOffTests, globals()) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index f0349c2484b..e2419aab268 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1267,11 +1267,6 @@ mean: not_implemented("native_layer_norm_backward mean") rstd: not_implemented("native_layer_norm_backward rstd") -- name: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor) - input, weight: "GradMode::is_enabled() || grads[1].defined() ? infinitely_differentiable_native_rms_norm_backward(grads[0], grads[1], input, normalized_shape, result1, weight, grad_input_mask) : (grads[0].defined() ? _fused_rms_norm_backward(grads[0], input, normalized_shape, result1, weight, grad_input_mask) : std::tuple())" - result0: rms_norm_jvp(input_p, input_t, weight_p, weight_t, result1, normalized_shape) - result1: rms_norm_rstd_jvp(input_p, input_t, result1, normalized_shape) - - name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple())" result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group) diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 8e9796d2f7c..abb94b109cc 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -418,7 +418,6 @@ def _core_aten_decompositions_post_autograd() -> dict[ aten.native_dropout_backward, aten.native_group_norm_backward, aten.native_layer_norm_backward, - aten._fused_rms_norm_backward, aten.new_empty, aten.new_full, aten.new_ones, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index d1bb2ed632c..0ff7e46f839 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1743,81 +1743,6 @@ def native_layer_norm_backward_out( return grad_input -@register_decomposition(aten._fused_rms_norm_backward.default) -def _fused_rms_norm_backward( - grad_out: Tensor, - input: Tensor, - normalized_shape: list[int], - rstd: Tensor, - weight: Optional[Tensor], - output_mask: list[bool], -) -> tuple[Optional[Tensor], Optional[Tensor]]: - input_shape = input.shape - input_ndim = input.dim() - computation_dtype = utils.get_computation_dtype(input.dtype) - - grad_out_cast = grad_out.to( - computation_dtype, memory_format=torch.contiguous_format - ) - input_cast = input.to(computation_dtype, memory_format=torch.contiguous_format) - weight_cast = ( - weight.to(computation_dtype, memory_format=torch.contiguous_format) - if weight is not None - else None - ) - assert grad_out_cast is not None - - axis = input_ndim - len(normalized_shape) - inner_dims = input_shape[axis:] - outer_dims = input_shape[:axis] - inner_dim_indices: list[int] = [] - outer_dim_indices: list[int] = [] - for i in range(input_ndim): - if i >= axis: - inner_dim_indices.append(i) - else: - outer_dim_indices.append(i) - - N = prod(inner_dims) # type: ignore[arg-type] - M = prod(outer_dims) # type: ignore[arg-type] - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious - - if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0): - return ( - input.new_zeros(input_shape) if output_mask[0] else None, - input.new_zeros(input_shape[axis:]) if output_mask[1] else None, - ) - - rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr] - if weight_cast is not None: - grad_x_hat = grad_out_cast * weight_cast - else: - grad_x_hat = grad_out_cast - - d_input: Optional[Tensor] = None - d_weight: Optional[Tensor] = None - - x_hat = input_cast * rstd - - if output_mask[0]: - sum_val = torch.sum(x_hat * grad_x_hat, dim=inner_dim_indices, keepdim=True) - d_input = (grad_x_hat - (x_hat / N) * sum_val) * rstd - - if output_mask[1] and weight_cast is not None: - d_weight_full_shape = grad_out_cast * x_hat - if len(outer_dim_indices) > 0: - d_weight = torch.sum( - d_weight_full_shape, dim=outer_dim_indices, keepdim=False - ) - else: - d_weight = d_weight_full_shape - - return ( - _maybe_cast(d_input, input.dtype), - _maybe_cast(d_weight, input.dtype), - ) - - def native_batch_norm_helper( input: Tensor, weight: Optional[Tensor], diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index af04f617260..2258917ba20 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -5022,103 +5022,6 @@ std::tuple layer_norm_double_backward( return std::tuple{gI, gG, ggO}; } -std::tuple infinitely_differentiable_native_rms_norm_backward( - const Tensor& dY, - const Tensor& drstd, - const Tensor& input, - IntArrayRef normalized_shape, - const Tensor& rstd, - const std::optional& weight_opt, - std::array grad_input_mask) { - c10::MaybeOwned weight_maybe_owned = - at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - - const auto input_shape = input.sizes(); - const auto input_ndim = input.dim(); - const int normalized_ndim = normalized_shape.size(); - const int axis = input_ndim - normalized_ndim; - - int64_t N_rms = 1; - for (int i = 0; i < normalized_ndim; ++i) { - N_rms *= input_shape[axis + i]; - } - - Tensor dX; - Tensor dgamma; - - std::vector rstd_view_shape = rstd.sizes().vec(); - for (int i = 0; - i < std::max(static_cast(normalized_ndim - rstd.dim()), 0); - ++i) { - rstd_view_shape.push_back(1); - } - Tensor rstd_broadcast = rstd.view(rstd_view_shape); - Tensor rstd_pow3 = rstd_broadcast.pow(3); - Tensor grad_x_hat; - - if (dY.defined()) { - if (weight.defined()) { - grad_x_hat = dY * weight; - } else { - grad_x_hat = dY; - } - } - - if (grad_input_mask[0]) { - Tensor dX_from_dY_path; - Tensor dX_from_drstd_path; - - std::vector inner_sum_dims; - inner_sum_dims.reserve(normalized_ndim); - for (int i = 0; i < normalized_ndim; ++i) { - inner_sum_dims.push_back(axis + i); - } - - if (dY.defined() && grad_x_hat.defined()) { - Tensor sum_input_times_grad_x_hat = - sum(input * grad_x_hat, inner_sum_dims, /*keepdim=*/true); - dX_from_dY_path = rstd_broadcast * grad_x_hat - - (input * rstd_pow3 / static_cast(N_rms)) * - sum_input_times_grad_x_hat; - } - - if (drstd.defined()) { - Tensor drstd_broadcast = drstd.view(rstd_view_shape); - dX_from_drstd_path = - -(input * rstd_pow3 / static_cast(N_rms)) * drstd_broadcast; - } - - if (dX_from_dY_path.defined() && dX_from_drstd_path.defined()) { - dX = dX_from_dY_path + dX_from_drstd_path; - } else if (dX_from_dY_path.defined()) { - dX = dX_from_dY_path; - } else if (dX_from_drstd_path.defined()) { - dX = dX_from_drstd_path; - } - } - - if (grad_input_mask[1] && weight.defined()) { - if (dY.defined()) { - Tensor x_hat = input * rstd_broadcast; - Tensor dgamma_full_shape = dY * x_hat; - - if (axis > 0) { - std::vector outer_sum_dims; - outer_sum_dims.reserve(axis); - for (int i = 0; i < axis; ++i) { - outer_sum_dims.push_back(i); - } - dgamma = sum(dgamma_full_shape, outer_sum_dims, /*keepdim=*/false); - } else { - dgamma = dgamma_full_shape; - } - } - } - - return std::make_tuple(dX, dgamma); -} - std::tuple infinitely_differentiable_native_group_norm_backward( const Tensor& dY, @@ -6473,98 +6376,6 @@ Tensor layer_norm_jvp( bias_t.defined() ? bias_t.view(view_size_affine) : bias_t); } -Tensor rms_norm_jvp( - const Tensor& input_p, - const Tensor& input_t, - const Tensor& weight_p, - const Tensor& weight_t, - const Tensor& saved_rstd, - IntArrayRef normalized_shape) { - auto dims = std::vector{}; - auto view_size = input_t.sizes().vec(); - auto view_size_affine = input_t.sizes().vec(); - - int64_t numel = 1; - for (const auto i : c10::irange(view_size.size())) { - if (i < view_size.size() - normalized_shape.size()) { - view_size_affine[i] = 1; - } else { - numel *= input_t.size(static_cast(i)); - view_size[i] = 1; - dims.push_back(static_cast(i)); - } - } - - auto rstd_p = saved_rstd.view(view_size); - - Tensor rstd_t; - if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || - input_t._is_zerotensor()) { - rstd_t = -rstd_p.pow(3) * (input_t) * (input_p); - } else { - rstd_t = input_t * input_p; - rstd_t *= -rstd_p.pow(3); - } - rstd_t = rstd_t.sum(dims, true); - rstd_t /= numel; - - Tensor result_t; - if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || - input_t._is_zerotensor()) { - result_t = (input_t)*rstd_p + (input_p)*rstd_t; - } else { - result_t = input_t * rstd_p; - auto temp = input_p * rstd_t; - result_t += temp; - } - - std::optional result_p = std::nullopt; - if (weight_p.defined()) { - result_p = std::optional(input_p * rstd_p); - } - - return _affine_jvp( - result_p, - result_t, - weight_p.defined() ? weight_p.view(view_size_affine) : weight_p, - weight_t.defined() ? weight_t.view(view_size_affine) : weight_t, - Tensor()); -} - -Tensor rms_norm_rstd_jvp( - const Tensor& input_p, - const Tensor& input_t, - const Tensor& saved_rstd, - IntArrayRef normalized_shape) { - auto dims = std::vector{}; - auto view_size = input_t.sizes().vec(); - auto view_size_affine = input_t.sizes().vec(); - - int64_t numel = 1; - for (const auto i : c10::irange(view_size.size())) { - if (i < view_size.size() - normalized_shape.size()) { - view_size_affine[i] = 1; - } else { - numel *= input_t.size(static_cast(i)); - view_size[i] = 1; - dims.push_back(static_cast(i)); - } - } - - auto rstd_p = saved_rstd.view(view_size); - Tensor rstd_t; - if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || - input_t._is_zerotensor()) { - rstd_t = -rstd_p.pow(3) * (input_t) * (input_p); - } else { - rstd_t = input_t * input_p; - rstd_t *= -rstd_p.pow(3); - } - rstd_t = rstd_t.sum(dims, true); - rstd_t /= numel; - return rstd_t; -} - Tensor group_norm_jvp( const Tensor& input_p, const Tensor& input_t, diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index c49c614c7ec..1bbad0ae92d 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -828,15 +828,6 @@ std::tuple layer_norm_double_backward( c10::SymIntArrayRef normalized_shape, std::array output_mask); -std::tuple infinitely_differentiable_native_rms_norm_backward( - const Tensor& dY, - const Tensor& drstd, - const Tensor& input, - IntArrayRef normalized_shape, - const Tensor& rstd, - const std::optional& weight_opt, - std::array grad_input_mask); - std::tuple householder_product_backward( const Tensor& grad, const Tensor& result, @@ -976,20 +967,6 @@ Tensor layer_norm_jvp( const Tensor& saved_invstd, c10::SymIntArrayRef normalized_shape); -Tensor rms_norm_jvp( - const Tensor& input_p, - const Tensor& input_t, - const Tensor& weight_p, - const Tensor& weight_t, - const Tensor& saved_rstd, - IntArrayRef normalized_shape); - -Tensor rms_norm_rstd_jvp( - const Tensor& input_p, - const Tensor& input_t, - const Tensor& saved_rstd, - IntArrayRef normalized_shape); - Tensor group_norm_jvp( const Tensor& input_p, const Tensor& input_t, diff --git a/torch/overrides.py b/torch/overrides.py index 28a58445cdc..562141ff1cf 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -820,7 +820,6 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1, torch.native_dropout: lambda input, p, train: -1, torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1, - torch._fused_rms_norm: lambda input, normalized_shape, weight=None, eps=1e-05: -1, torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1, torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1, torch.native_channel_shuffle: lambda input, groups: -1,