[BE] Use simdgroup_size constexpr (#157751)
Instead of every shader defining it separately, move it to `c10/metal/common.h`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157751
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #157746
This commit is contained in:
parent 0b73f7c871
commit 39a8f66d59
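For context, below is a minimal, hypothetical sketch of the pattern this change applies: a shader that used to declare its own `constexpr int SIMD_SIZE = 32;` now sizes its threadgroup scratch with the shared `c10::metal::simdgroup_size` constant. Only the include, the using-declaration, and the threadgroup array declaration mirror the actual change; the kernel name and the row-sum logic are illustrative and not part of the diff below.

#include <c10/metal/common.h>
#include <metal_simdgroup>
#include <metal_stdlib>
using namespace metal;
using c10::metal::simdgroup_size;

// Illustrative row-sum kernel (hypothetical, not one of the kernels in this diff):
// each threadgroup reduces one row of length axis_size.
kernel void row_sum_example(
    device const float* x [[buffer(0)]],
    device float* out [[buffer(1)]],
    constant uint& axis_size [[buffer(2)]],
    uint tg_id [[threadgroup_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simdgroup_id [[simdgroup_index_in_threadgroup]]) {
  // One partial sum per simdgroup; previously this would have been
  // local_sums[SIMD_SIZE], with SIMD_SIZE redefined in every shader.
  threadgroup float local_sums[simdgroup_size];

  float partial_sum = 0;
  for (uint i = tid; i < axis_size; i += lsize) {
    partial_sum += x[tg_id * axis_size + i];
  }
  partial_sum = simd_sum(partial_sum);
  if (simd_lane_id == 0) {
    local_sums[simdgroup_id] = partial_sum;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // The first thread adds up the per-simdgroup partials that were actually written.
  if (tid == 0) {
    float total = 0;
    uint num_simdgroups = (lsize + simdgroup_size - 1) / simdgroup_size;
    for (uint i = 0; i < num_simdgroups; i++) {
      total += local_sums[i];
    }
    out[tg_id] = total;
  }
}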
@@ -1,6 +1,8 @@
+#include <c10/metal/common.h>
 #include <metal_simdgroup>
 #include <metal_stdlib>
 using namespace metal;
+using c10::metal::simdgroup_size;

 template <typename T>
 kernel void layer_norm_single_row(
@@ -18,7 +20,6 @@ kernel void layer_norm_single_row(
     uint tid [[thread_position_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simdgroup_id [[simdgroup_index_in_threadgroup]]) {
-  constexpr int SIMD_SIZE = 32;
   constexpr int N_READS = 4;

   // each threadgroup handles one full “row” of length axis_size
@@ -52,8 +53,8 @@ kernel void layer_norm_single_row(
   }

   // threadgroup‐wide reduction
-  threadgroup float local_sums[SIMD_SIZE];
-  threadgroup float local_sums_sq[SIMD_SIZE];
+  threadgroup float local_sums[simdgroup_size];
+  threadgroup float local_sums_sq[simdgroup_size];
   threadgroup float tg_mean[1];
   threadgroup float tg_inv_std[1];

@@ -142,7 +143,6 @@ kernel void layer_norm_looped(
     uint lsize [[threads_per_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simdgroup_id [[simdgroup_index_in_threadgroup]]) {
-  constexpr int SIMD_SIZE = 32;
   constexpr int N_READS = 4;

   uint row_offset = tg_id * axis_size;
@@ -178,8 +178,8 @@
   partial_sum = simd_sum(partial_sum);
   partial_sum_sq = simd_sum(partial_sum_sq);

-  threadgroup float local_sums[SIMD_SIZE];
-  threadgroup float local_sums_sq[SIMD_SIZE];
+  threadgroup float local_sums[simdgroup_size];
+  threadgroup float local_sums_sq[simdgroup_size];
   threadgroup float tg_mean[1];
   threadgroup float tg_inv_std[1];

@@ -2,11 +2,13 @@
 // https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/rms_norm.metal
 // Copyright © 2024 Apple Inc.

+#include <c10/metal/common.h>
 #include <metal_common>
 #include <metal_simdgroup>
 #include <metal_stdlib>

 using namespace metal;
+using c10::metal::simdgroup_size;

 template <typename T>
 [[kernel]] void rms_single_row(
@@ -20,11 +22,10 @@ template <typename T>
     uint lid [[thread_position_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-  constexpr int SIMD_SIZE = 32;
   constexpr int N_READS = 4;

   threadgroup float local_inv_mean[1];
-  threadgroup float local_sums[SIMD_SIZE];
+  threadgroup float local_sums[simdgroup_size];

   float acc = 0;
   x += gid * size_t(axis_size) + lid * N_READS;
@@ -92,10 +93,9 @@ template <typename T>
     uint lsize [[threads_per_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-  constexpr int SIMD_SIZE = 32;
   constexpr int N_READS = 4;
   threadgroup float local_inv_mean[1];
-  threadgroup float local_sums[SIMD_SIZE];
+  threadgroup float local_sums[simdgroup_size];

   float acc = 0;
   x += gid * size_t(axis_size) + lid * N_READS;
@@ -398,6 +398,8 @@ REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool);

 #else // __METAL_VERSION__ >= 310

+C10_METAL_CONSTEXPR auto simd_size = c10::metal::simdgroup_size;
+
 // The reminder of this file contains cummin and cummax implementations adapted
 // from MLX:
 // https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/scan.h
@@ -710,7 +712,6 @@ kernel void scan_innermost_dim(
     uint3 lsize [[threads_per_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-  constexpr int simd_size = 32;
   Op op;

   // Position the pointers
@@ -808,7 +809,6 @@ kernel void scan_outer_dim(
     uint3 lid [[thread_position_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-  constexpr int simd_size = 32;
   constexpr int BM = 32;
   constexpr int BN = 32;
   constexpr int BN_pad = 32 + 16 / sizeof(T);
@@ -907,7 +907,6 @@ kernel void scan_with_indices_innermost_dim(
     uint3 lsize [[threads_per_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-  constexpr int simd_size = 32;
   Op op;
   using pair_t = typename Op::pair_t;

@@ -999,7 +998,6 @@ kernel void scan_with_indices_outer_dim(
     uint3 lid [[thread_position_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-  constexpr int simd_size = 32;
   constexpr int BM = 32;
   constexpr int BN = 32;
   constexpr int BN_pad = 32 + 16 / sizeof(T);
@@ -39,6 +39,8 @@
 namespace c10 {
 namespace metal {
 C10_METAL_CONSTEXPR unsigned max_ndim = 16;
+C10_METAL_CONSTEXPR unsigned simdgroup_size = 32;
+
 #ifdef __METAL__
 template <typename T, unsigned N>
 using array = ::metal::array<T, N>;
@@ -6,8 +6,6 @@
 namespace c10 {
 namespace metal {

-constant constexpr ushort simdgroup_size = 32;
-
 template <typename T>
 inline ::metal::enable_if_t<!::metal::is_same_v<T, long>, T> simd_sum(T val) {
   return ::metal::simd_sum(val);