[BE] Use simdgroup_size constexpr (#157751)

Instead of every shader defining it separately, move it to `c10/metal/common.h`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157751
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #157746
This commit is contained in:
Nikita Shulga 2025-07-07 17:14:13 -07:00 committed by PyTorch MergeBot
parent 0b73f7c871
commit 39a8f66d59
5 changed files with 15 additions and 17 deletions

View File

@ -1,6 +1,8 @@
#include <c10/metal/common.h>
#include <metal_simdgroup> #include <metal_simdgroup>
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
using c10::metal::simdgroup_size;
template <typename T> template <typename T>
kernel void layer_norm_single_row( kernel void layer_norm_single_row(
@ -18,7 +20,6 @@ kernel void layer_norm_single_row(
uint tid [[thread_position_in_threadgroup]], uint tid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simdgroup_id [[simdgroup_index_in_threadgroup]]) { uint simdgroup_id [[simdgroup_index_in_threadgroup]]) {
constexpr int SIMD_SIZE = 32;
constexpr int N_READS = 4; constexpr int N_READS = 4;
// each threadgroup handles one full “row” of length axis_size // each threadgroup handles one full “row” of length axis_size
@ -52,8 +53,8 @@ kernel void layer_norm_single_row(
} }
// threadgroupwide reduction // threadgroupwide reduction
threadgroup float local_sums[SIMD_SIZE]; threadgroup float local_sums[simdgroup_size];
threadgroup float local_sums_sq[SIMD_SIZE]; threadgroup float local_sums_sq[simdgroup_size];
threadgroup float tg_mean[1]; threadgroup float tg_mean[1];
threadgroup float tg_inv_std[1]; threadgroup float tg_inv_std[1];
@ -142,7 +143,6 @@ kernel void layer_norm_looped(
uint lsize [[threads_per_threadgroup]], uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simdgroup_id [[simdgroup_index_in_threadgroup]]) { uint simdgroup_id [[simdgroup_index_in_threadgroup]]) {
constexpr int SIMD_SIZE = 32;
constexpr int N_READS = 4; constexpr int N_READS = 4;
uint row_offset = tg_id * axis_size; uint row_offset = tg_id * axis_size;
@ -178,8 +178,8 @@ kernel void layer_norm_looped(
partial_sum = simd_sum(partial_sum); partial_sum = simd_sum(partial_sum);
partial_sum_sq = simd_sum(partial_sum_sq); partial_sum_sq = simd_sum(partial_sum_sq);
threadgroup float local_sums[SIMD_SIZE]; threadgroup float local_sums[simdgroup_size];
threadgroup float local_sums_sq[SIMD_SIZE]; threadgroup float local_sums_sq[simdgroup_size];
threadgroup float tg_mean[1]; threadgroup float tg_mean[1];
threadgroup float tg_inv_std[1]; threadgroup float tg_inv_std[1];

View File

@ -2,11 +2,13 @@
// https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/rms_norm.metal // https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/rms_norm.metal
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include <c10/metal/common.h>
#include <metal_common> #include <metal_common>
#include <metal_simdgroup> #include <metal_simdgroup>
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
using c10::metal::simdgroup_size;
template <typename T> template <typename T>
[[kernel]] void rms_single_row( [[kernel]] void rms_single_row(
@ -20,11 +22,10 @@ template <typename T>
uint lid [[thread_position_in_threadgroup]], uint lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) { uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int SIMD_SIZE = 32;
constexpr int N_READS = 4; constexpr int N_READS = 4;
threadgroup float local_inv_mean[1]; threadgroup float local_inv_mean[1];
threadgroup float local_sums[SIMD_SIZE]; threadgroup float local_sums[simdgroup_size];
float acc = 0; float acc = 0;
x += gid * size_t(axis_size) + lid * N_READS; x += gid * size_t(axis_size) + lid * N_READS;
@ -92,10 +93,9 @@ template <typename T>
uint lsize [[threads_per_threadgroup]], uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) { uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int SIMD_SIZE = 32;
constexpr int N_READS = 4; constexpr int N_READS = 4;
threadgroup float local_inv_mean[1]; threadgroup float local_inv_mean[1];
threadgroup float local_sums[SIMD_SIZE]; threadgroup float local_sums[simdgroup_size];
float acc = 0; float acc = 0;
x += gid * size_t(axis_size) + lid * N_READS; x += gid * size_t(axis_size) + lid * N_READS;

View File

@ -398,6 +398,8 @@ REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool);
#else // __METAL_VERSION__ >= 310 #else // __METAL_VERSION__ >= 310
C10_METAL_CONSTEXPR auto simd_size = c10::metal::simdgroup_size;
// The reminder of this file contains cummin and cummax implementations adapted // The reminder of this file contains cummin and cummax implementations adapted
// from MLX: // from MLX:
// https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/scan.h // https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/scan.h
@ -710,7 +712,6 @@ kernel void scan_innermost_dim(
uint3 lsize [[threads_per_threadgroup]], uint3 lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) { uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int simd_size = 32;
Op op; Op op;
// Position the pointers // Position the pointers
@ -808,7 +809,6 @@ kernel void scan_outer_dim(
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) { uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int simd_size = 32;
constexpr int BM = 32; constexpr int BM = 32;
constexpr int BN = 32; constexpr int BN = 32;
constexpr int BN_pad = 32 + 16 / sizeof(T); constexpr int BN_pad = 32 + 16 / sizeof(T);
@ -907,7 +907,6 @@ kernel void scan_with_indices_innermost_dim(
uint3 lsize [[threads_per_threadgroup]], uint3 lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) { uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int simd_size = 32;
Op op; Op op;
using pair_t = typename Op::pair_t; using pair_t = typename Op::pair_t;
@ -999,7 +998,6 @@ kernel void scan_with_indices_outer_dim(
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) { uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int simd_size = 32;
constexpr int BM = 32; constexpr int BM = 32;
constexpr int BN = 32; constexpr int BN = 32;
constexpr int BN_pad = 32 + 16 / sizeof(T); constexpr int BN_pad = 32 + 16 / sizeof(T);

View File

@ -39,6 +39,8 @@
namespace c10 { namespace c10 {
namespace metal { namespace metal {
C10_METAL_CONSTEXPR unsigned max_ndim = 16; C10_METAL_CONSTEXPR unsigned max_ndim = 16;
C10_METAL_CONSTEXPR unsigned simdgroup_size = 32;
#ifdef __METAL__ #ifdef __METAL__
template <typename T, unsigned N> template <typename T, unsigned N>
using array = ::metal::array<T, N>; using array = ::metal::array<T, N>;

View File

@ -6,8 +6,6 @@
namespace c10 { namespace c10 {
namespace metal { namespace metal {
constant constexpr ushort simdgroup_size = 32;
template <typename T> template <typename T>
inline ::metal::enable_if_t<!::metal::is_same_v<T, long>, T> simd_sum(T val) { inline ::metal::enable_if_t<!::metal::is_same_v<T, long>, T> simd_sum(T val) {
return ::metal::simd_sum(val); return ::metal::simd_sum(val);