[BE] Use simdgroup_size constexpr (#157751)

Instead of having every shader define `simdgroup_size` separately, move it to `c10/metal/common.h`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157751
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #157746
Author: Nikita Shulga, 2025-07-07 17:14:13 -07:00 (committed by PyTorch MergeBot)
parent 0b73f7c871
commit 39a8f66d59
5 changed files with 15 additions and 17 deletions
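
For context, a minimal sketch of the pattern this change standardizes. The kernel below is hypothetical and not part of the diff (its name, buffers, and reduction logic are illustrative only): a shader includes `c10/metal/common.h` and sizes its threadgroup scratch with the shared `c10::metal::simdgroup_size` constant instead of declaring its own `constexpr int SIMD_SIZE = 32;`.

#include <c10/metal/common.h>
#include <metal_simdgroup>
#include <metal_stdlib>
using namespace metal;
using c10::metal::simdgroup_size; // 32, defined once for every shader

// Hypothetical reduction kernel, not from the patch; shown only to
// illustrate using the shared constant for threadgroup scratch sizing.
kernel void example_sum(
    device const float* x [[buffer(0)]],
    device float* out [[buffer(1)]],
    uint tid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // One scratch slot per simdgroup, sized with the shared constant.
  threadgroup float local_sums[simdgroup_size];
  float partial = simd_sum(x[tid]);
  if (simd_lane_id == 0) {
    local_sums[simd_group_id] = partial;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (tid == 0) {
    float total = 0;
    // Only the simdgroups actually launched wrote a slot.
    for (uint i = 0; i < (lsize + simdgroup_size - 1) / simdgroup_size; ++i) {
      total += local_sums[i];
    }
    *out = total;
  }
}

Because `simdgroup_size` is a compile-time constant, it can size `threadgroup` arrays exactly like the per-shader `SIMD_SIZE` constants it replaces.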

@@ -1,6 +1,8 @@
+#include <c10/metal/common.h>
 #include <metal_simdgroup>
 #include <metal_stdlib>
 using namespace metal;
+using c10::metal::simdgroup_size;
 template <typename T>
 kernel void layer_norm_single_row(
@@ -18,7 +20,6 @@ kernel void layer_norm_single_row(
 uint tid [[thread_position_in_threadgroup]],
 uint simd_lane_id [[thread_index_in_simdgroup]],
 uint simdgroup_id [[simdgroup_index_in_threadgroup]]) {
-constexpr int SIMD_SIZE = 32;
 constexpr int N_READS = 4;
 // each threadgroup handles one full “row” of length axis_size
@@ -52,8 +53,8 @@ kernel void layer_norm_single_row(
 }
 // threadgroupwide reduction
-threadgroup float local_sums[SIMD_SIZE];
-threadgroup float local_sums_sq[SIMD_SIZE];
+threadgroup float local_sums[simdgroup_size];
+threadgroup float local_sums_sq[simdgroup_size];
 threadgroup float tg_mean[1];
 threadgroup float tg_inv_std[1];
@@ -142,7 +143,6 @@ kernel void layer_norm_looped(
 uint lsize [[threads_per_threadgroup]],
 uint simd_lane_id [[thread_index_in_simdgroup]],
 uint simdgroup_id [[simdgroup_index_in_threadgroup]]) {
-constexpr int SIMD_SIZE = 32;
 constexpr int N_READS = 4;
 uint row_offset = tg_id * axis_size;
@@ -178,8 +178,8 @@
 partial_sum = simd_sum(partial_sum);
 partial_sum_sq = simd_sum(partial_sum_sq);
-threadgroup float local_sums[SIMD_SIZE];
-threadgroup float local_sums_sq[SIMD_SIZE];
+threadgroup float local_sums[simdgroup_size];
+threadgroup float local_sums_sq[simdgroup_size];
 threadgroup float tg_mean[1];
 threadgroup float tg_inv_std[1];
@@ -291,4 +291,4 @@ kernel void layer_norm_looped(
 instantiate_layer_norm(float) instantiate_layer_norm(half)
 #if __METAL_VERSION__ >= 310
 instantiate_layer_norm(bfloat)
-#endif
+#endif

@@ -2,11 +2,13 @@
 // https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/rms_norm.metal
 // Copyright © 2024 Apple Inc.
+#include <c10/metal/common.h>
 #include <metal_common>
 #include <metal_simdgroup>
 #include <metal_stdlib>
 using namespace metal;
+using c10::metal::simdgroup_size;
 template <typename T>
 [[kernel]] void rms_single_row(
@@ -20,11 +22,10 @@ template <typename T>
 uint lid [[thread_position_in_threadgroup]],
 uint simd_lane_id [[thread_index_in_simdgroup]],
 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-constexpr int SIMD_SIZE = 32;
 constexpr int N_READS = 4;
 threadgroup float local_inv_mean[1];
-threadgroup float local_sums[SIMD_SIZE];
+threadgroup float local_sums[simdgroup_size];
 float acc = 0;
 x += gid * size_t(axis_size) + lid * N_READS;
@@ -92,10 +93,9 @@ template <typename T>
 uint lsize [[threads_per_threadgroup]],
 uint simd_lane_id [[thread_index_in_simdgroup]],
 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-constexpr int SIMD_SIZE = 32;
 constexpr int N_READS = 4;
 threadgroup float local_inv_mean[1];
-threadgroup float local_sums[SIMD_SIZE];
+threadgroup float local_sums[simdgroup_size];
 float acc = 0;
 x += gid * size_t(axis_size) + lid * N_READS;

@@ -398,6 +398,8 @@ REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool);
 #else // __METAL_VERSION__ >= 310
+C10_METAL_CONSTEXPR auto simd_size = c10::metal::simdgroup_size;
 // The reminder of this file contains cummin and cummax implementations adapted
 // from MLX:
 // https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/scan.h
@@ -710,7 +712,6 @@ kernel void scan_innermost_dim(
 uint3 lsize [[threads_per_threadgroup]],
 uint simd_lane_id [[thread_index_in_simdgroup]],
 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-constexpr int simd_size = 32;
 Op op;
 // Position the pointers
@@ -808,7 +809,6 @@ kernel void scan_outer_dim(
 uint3 lid [[thread_position_in_threadgroup]],
 uint simd_lane_id [[thread_index_in_simdgroup]],
 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-constexpr int simd_size = 32;
 constexpr int BM = 32;
 constexpr int BN = 32;
 constexpr int BN_pad = 32 + 16 / sizeof(T);
@@ -907,7 +907,6 @@ kernel void scan_with_indices_innermost_dim(
 uint3 lsize [[threads_per_threadgroup]],
 uint simd_lane_id [[thread_index_in_simdgroup]],
 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-constexpr int simd_size = 32;
 Op op;
 using pair_t = typename Op::pair_t;
@@ -999,7 +998,6 @@ kernel void scan_with_indices_outer_dim(
 uint3 lid [[thread_position_in_threadgroup]],
 uint simd_lane_id [[thread_index_in_simdgroup]],
 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-constexpr int simd_size = 32;
 constexpr int BM = 32;
 constexpr int BN = 32;
 constexpr int BN_pad = 32 + 16 / sizeof(T);

@@ -39,6 +39,8 @@
 namespace c10 {
 namespace metal {
 C10_METAL_CONSTEXPR unsigned max_ndim = 16;
+C10_METAL_CONSTEXPR unsigned simdgroup_size = 32;
 #ifdef __METAL__
 template <typename T, unsigned N>
 using array = ::metal::array<T, N>;
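
The `C10_METAL_CONSTEXPR` wrapper is what lets this single definition in `c10/metal/common.h` be consumed both by Metal shaders and by host C++ code that includes the header. Its exact expansion is not shown in this diff; the sketch below is an assumption, not the verified definition.

// Assumed shape of the macro (not taken from this diff): an
// address-space-qualified constexpr when compiled as Metal,
// a plain constexpr when compiled as host C++.
#ifdef __METAL__
#define C10_METAL_CONSTEXPR constant constexpr
#else
#define C10_METAL_CONSTEXPR constexpr
#endif

C10_METAL_CONSTEXPR unsigned simdgroup_size = 32;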

@@ -6,8 +6,6 @@
 namespace c10 {
 namespace metal {
-constant constexpr ushort simdgroup_size = 32;
 template <typename T>
 inline ::metal::enable_if_t<!::metal::is_same_v<T, long>, T> simd_sum(T val) {
 return ::metal::simd_sum(val);