From 39a8f66d5939e892bcb07ef97462af47d3201491 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 7 Jul 2025 17:14:13 -0700 Subject: [PATCH] [BE] Use `simdgroup_size` constexpr (#157751) Instead of every shader defining it separately, move it to `c10/metal/common.h` Pull Request resolved: https://github.com/pytorch/pytorch/pull/157751 Approved by: https://github.com/Skylion007, https://github.com/dcci ghstack dependencies: #157746 --- aten/src/ATen/native/mps/kernels/LayerNorm.metal | 14 +++++++------- aten/src/ATen/native/mps/kernels/RMSNorm.metal | 8 ++++---- aten/src/ATen/native/mps/kernels/ScanKernel.metal | 6 ++---- c10/metal/common.h | 2 ++ c10/metal/reduction_utils.h | 2 -- 5 files changed, 15 insertions(+), 17 deletions(-) diff --git a/aten/src/ATen/native/mps/kernels/LayerNorm.metal b/aten/src/ATen/native/mps/kernels/LayerNorm.metal index eea3ff10cf5..1ca9f916c2c 100644 --- a/aten/src/ATen/native/mps/kernels/LayerNorm.metal +++ b/aten/src/ATen/native/mps/kernels/LayerNorm.metal @@ -1,6 +1,8 @@ +#include #include #include using namespace metal; +using c10::metal::simdgroup_size; template kernel void layer_norm_single_row( @@ -18,7 +20,6 @@ kernel void layer_norm_single_row( uint tid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simdgroup_id [[simdgroup_index_in_threadgroup]]) { - constexpr int SIMD_SIZE = 32; constexpr int N_READS = 4; // each threadgroup handles one full “row” of length axis_size @@ -52,8 +53,8 @@ kernel void layer_norm_single_row( } // threadgroup‐wide reduction - threadgroup float local_sums[SIMD_SIZE]; - threadgroup float local_sums_sq[SIMD_SIZE]; + threadgroup float local_sums[simdgroup_size]; + threadgroup float local_sums_sq[simdgroup_size]; threadgroup float tg_mean[1]; threadgroup float tg_inv_std[1]; @@ -142,7 +143,6 @@ kernel void layer_norm_looped( uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simdgroup_id [[simdgroup_index_in_threadgroup]]) { - constexpr int SIMD_SIZE = 32; constexpr int N_READS = 4; uint row_offset = tg_id * axis_size; @@ -178,8 +178,8 @@ kernel void layer_norm_looped( partial_sum = simd_sum(partial_sum); partial_sum_sq = simd_sum(partial_sum_sq); - threadgroup float local_sums[SIMD_SIZE]; - threadgroup float local_sums_sq[SIMD_SIZE]; + threadgroup float local_sums[simdgroup_size]; + threadgroup float local_sums_sq[simdgroup_size]; threadgroup float tg_mean[1]; threadgroup float tg_inv_std[1]; @@ -291,4 +291,4 @@ kernel void layer_norm_looped( instantiate_layer_norm(float) instantiate_layer_norm(half) #if __METAL_VERSION__ >= 310 instantiate_layer_norm(bfloat) -#endif \ No newline at end of file +#endif diff --git a/aten/src/ATen/native/mps/kernels/RMSNorm.metal b/aten/src/ATen/native/mps/kernels/RMSNorm.metal index 681231d2aaa..f66dcb035df 100644 --- a/aten/src/ATen/native/mps/kernels/RMSNorm.metal +++ b/aten/src/ATen/native/mps/kernels/RMSNorm.metal @@ -2,11 +2,13 @@ // https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/rms_norm.metal // Copyright © 2024 Apple Inc. +#include #include #include #include using namespace metal; +using c10::metal::simdgroup_size; template [[kernel]] void rms_single_row( @@ -20,11 +22,10 @@ template uint lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - constexpr int SIMD_SIZE = 32; constexpr int N_READS = 4; threadgroup float local_inv_mean[1]; - threadgroup float local_sums[SIMD_SIZE]; + threadgroup float local_sums[simdgroup_size]; float acc = 0; x += gid * size_t(axis_size) + lid * N_READS; @@ -92,10 +93,9 @@ template uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - constexpr int SIMD_SIZE = 32; constexpr int N_READS = 4; threadgroup float local_inv_mean[1]; - threadgroup float local_sums[SIMD_SIZE]; + threadgroup float local_sums[simdgroup_size]; float acc = 0; x += gid * size_t(axis_size) + lid * N_READS; diff --git a/aten/src/ATen/native/mps/kernels/ScanKernel.metal b/aten/src/ATen/native/mps/kernels/ScanKernel.metal index c12fdb33cd7..e6d739cac13 100644 --- a/aten/src/ATen/native/mps/kernels/ScanKernel.metal +++ b/aten/src/ATen/native/mps/kernels/ScanKernel.metal @@ -398,6 +398,8 @@ REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool); #else // __METAL_VERSION__ >= 310 +C10_METAL_CONSTEXPR auto simd_size = c10::metal::simdgroup_size; + // The reminder of this file contains cummin and cummax implementations adapted // from MLX: // https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/scan.h @@ -710,7 +712,6 @@ kernel void scan_innermost_dim( uint3 lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - constexpr int simd_size = 32; Op op; // Position the pointers @@ -808,7 +809,6 @@ kernel void scan_outer_dim( uint3 lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - constexpr int simd_size = 32; constexpr int BM = 32; constexpr int BN = 32; constexpr int BN_pad = 32 + 16 / sizeof(T); @@ -907,7 +907,6 @@ kernel void scan_with_indices_innermost_dim( uint3 lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - constexpr int simd_size = 32; Op op; using pair_t = typename Op::pair_t; @@ -999,7 +998,6 @@ kernel void scan_with_indices_outer_dim( uint3 lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - constexpr int simd_size = 32; constexpr int BM = 32; constexpr int BN = 32; constexpr int BN_pad = 32 + 16 / sizeof(T); diff --git a/c10/metal/common.h b/c10/metal/common.h index a8274dd8fb4..b953dec9025 100644 --- a/c10/metal/common.h +++ b/c10/metal/common.h @@ -39,6 +39,8 @@ namespace c10 { namespace metal { C10_METAL_CONSTEXPR unsigned max_ndim = 16; +C10_METAL_CONSTEXPR unsigned simdgroup_size = 32; + #ifdef __METAL__ template using array = ::metal::array; diff --git a/c10/metal/reduction_utils.h b/c10/metal/reduction_utils.h index 7e16aa1569a..2d9834c34ee 100644 --- a/c10/metal/reduction_utils.h +++ b/c10/metal/reduction_utils.h @@ -6,8 +6,6 @@ namespace c10 { namespace metal { -constant constexpr ushort simdgroup_size = 32; - template inline ::metal::enable_if_t, T> simd_sum(T val) { return ::metal::simd_sum(val);