mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[BE] Use simdgroup_size constexpr (#157751)
Instead of every shader defining it separately, move it to `c10/metal/common.h` Pull Request resolved: https://github.com/pytorch/pytorch/pull/157751 Approved by: https://github.com/Skylion007, https://github.com/dcci ghstack dependencies: #157746
This commit is contained in:
parent
0b73f7c871
commit
39a8f66d59
|
|
@ -1,6 +1,8 @@
|
|||
#include <c10/metal/common.h>
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
using c10::metal::simdgroup_size;
|
||||
|
||||
template <typename T>
|
||||
kernel void layer_norm_single_row(
|
||||
|
|
@ -18,7 +20,6 @@ kernel void layer_norm_single_row(
|
|||
uint tid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simdgroup_id [[simdgroup_index_in_threadgroup]]) {
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
constexpr int N_READS = 4;
|
||||
|
||||
// each threadgroup handles one full “row” of length axis_size
|
||||
|
|
@ -52,8 +53,8 @@ kernel void layer_norm_single_row(
|
|||
}
|
||||
|
||||
// threadgroup‐wide reduction
|
||||
threadgroup float local_sums[SIMD_SIZE];
|
||||
threadgroup float local_sums_sq[SIMD_SIZE];
|
||||
threadgroup float local_sums[simdgroup_size];
|
||||
threadgroup float local_sums_sq[simdgroup_size];
|
||||
threadgroup float tg_mean[1];
|
||||
threadgroup float tg_inv_std[1];
|
||||
|
||||
|
|
@ -142,7 +143,6 @@ kernel void layer_norm_looped(
|
|||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simdgroup_id [[simdgroup_index_in_threadgroup]]) {
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
constexpr int N_READS = 4;
|
||||
|
||||
uint row_offset = tg_id * axis_size;
|
||||
|
|
@ -178,8 +178,8 @@ kernel void layer_norm_looped(
|
|||
partial_sum = simd_sum(partial_sum);
|
||||
partial_sum_sq = simd_sum(partial_sum_sq);
|
||||
|
||||
threadgroup float local_sums[SIMD_SIZE];
|
||||
threadgroup float local_sums_sq[SIMD_SIZE];
|
||||
threadgroup float local_sums[simdgroup_size];
|
||||
threadgroup float local_sums_sq[simdgroup_size];
|
||||
threadgroup float tg_mean[1];
|
||||
threadgroup float tg_inv_std[1];
|
||||
|
||||
|
|
|
|||
|
|
@ -2,11 +2,13 @@
|
|||
// https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/rms_norm.metal
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <c10/metal/common.h>
|
||||
#include <metal_common>
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
using c10::metal::simdgroup_size;
|
||||
|
||||
template <typename T>
|
||||
[[kernel]] void rms_single_row(
|
||||
|
|
@ -20,11 +22,10 @@ template <typename T>
|
|||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
constexpr int N_READS = 4;
|
||||
|
||||
threadgroup float local_inv_mean[1];
|
||||
threadgroup float local_sums[SIMD_SIZE];
|
||||
threadgroup float local_sums[simdgroup_size];
|
||||
|
||||
float acc = 0;
|
||||
x += gid * size_t(axis_size) + lid * N_READS;
|
||||
|
|
@ -92,10 +93,9 @@ template <typename T>
|
|||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
constexpr int N_READS = 4;
|
||||
threadgroup float local_inv_mean[1];
|
||||
threadgroup float local_sums[SIMD_SIZE];
|
||||
threadgroup float local_sums[simdgroup_size];
|
||||
|
||||
float acc = 0;
|
||||
x += gid * size_t(axis_size) + lid * N_READS;
|
||||
|
|
|
|||
|
|
@ -398,6 +398,8 @@ REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool);
|
|||
|
||||
#else // __METAL_VERSION__ >= 310
|
||||
|
||||
C10_METAL_CONSTEXPR auto simd_size = c10::metal::simdgroup_size;
|
||||
|
||||
// The reminder of this file contains cummin and cummax implementations adapted
|
||||
// from MLX:
|
||||
// https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/scan.h
|
||||
|
|
@ -710,7 +712,6 @@ kernel void scan_innermost_dim(
|
|||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
constexpr int simd_size = 32;
|
||||
Op op;
|
||||
|
||||
// Position the pointers
|
||||
|
|
@ -808,7 +809,6 @@ kernel void scan_outer_dim(
|
|||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
constexpr int simd_size = 32;
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
constexpr int BN_pad = 32 + 16 / sizeof(T);
|
||||
|
|
@ -907,7 +907,6 @@ kernel void scan_with_indices_innermost_dim(
|
|||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
constexpr int simd_size = 32;
|
||||
Op op;
|
||||
using pair_t = typename Op::pair_t;
|
||||
|
||||
|
|
@ -999,7 +998,6 @@ kernel void scan_with_indices_outer_dim(
|
|||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
constexpr int simd_size = 32;
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
constexpr int BN_pad = 32 + 16 / sizeof(T);
|
||||
|
|
|
|||
|
|
@ -39,6 +39,8 @@
|
|||
namespace c10 {
|
||||
namespace metal {
|
||||
C10_METAL_CONSTEXPR unsigned max_ndim = 16;
|
||||
C10_METAL_CONSTEXPR unsigned simdgroup_size = 32;
|
||||
|
||||
#ifdef __METAL__
|
||||
template <typename T, unsigned N>
|
||||
using array = ::metal::array<T, N>;
|
||||
|
|
|
|||
|
|
@ -6,8 +6,6 @@
|
|||
namespace c10 {
|
||||
namespace metal {
|
||||
|
||||
constant constexpr ushort simdgroup_size = 32;
|
||||
|
||||
template <typename T>
|
||||
inline ::metal::enable_if_t<!::metal::is_same_v<T, long>, T> simd_sum(T val) {
|
||||
return ::metal::simd_sum(val);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user