Revert "[BE][MPS] Refactor core matmul logic into matmul_core (#155969)"

This reverts commit 769d754ab2.

Reverted https://github.com/pytorch/pytorch/pull/155969 on behalf of https://github.com/atalman due to the need to revert the Eigen test ([comment](https://github.com/pytorch/pytorch/pull/155969#issuecomment-2992502683))
PyTorch MergeBot 2025-06-20 18:40:38 +00:00
parent 96d082d06b
commit d309cd1d50

@@ -7,31 +7,36 @@ using namespace metal;
 constant uint TILE_DIM = 16;
 
 template <typename T>
-inline c10::metal::opmath_t<T> matmul_inner(
-    constant T* mat1Data,
-    constant T* mat2Data,
-    constant array<ulong2, 3>& strides,
-    constant uint3& sizes,
-    threadgroup T A_tile[TILE_DIM][TILE_DIM],
-    threadgroup T B_tile[TILE_DIM][TILE_DIM],
-    uint2 tid,
-    uint2 thread_id) {
+kernel void matmul(
+    constant T* mat1Data [[buffer(0)]],
+    constant T* mat2Data [[buffer(1)]],
+    device T* outputData [[buffer(2)]],
+    constant array<ulong2, 3>& strides [[buffer(3)]],
+    constant uint3& sizes [[buffer(4)]],
+    uint2 tid [[thread_position_in_threadgroup]],
+    uint2 group_id [[threadgroup_position_in_grid]]) {
+  uint col = group_id.x * TILE_DIM + tid.x;
+  uint row = group_id.y * TILE_DIM + tid.y;
   c10::metal::opmath_t<T> sum = 0;
+  threadgroup T A_tile[TILE_DIM][TILE_DIM];
+  threadgroup T B_tile[TILE_DIM][TILE_DIM];
   uint numTiles = (sizes.y + TILE_DIM - 1) / TILE_DIM;
   for (uint t = 0; t < numTiles; t++) {
     uint tiledCol = t * TILE_DIM + tid.x;
-    if (thread_id.y < sizes.x && tiledCol < sizes.y) {
+    if (row < sizes.x && tiledCol < sizes.y) {
       A_tile[tid.y][tid.x] =
-          mat1Data[thread_id.y * strides[0].x + tiledCol * strides[0].y];
+          mat1Data[row * strides[0].x + tiledCol * strides[0].y];
     } else {
       A_tile[tid.y][tid.x] = 0;
     }
     uint tiledRow = t * TILE_DIM + tid.y;
-    if (tiledRow < sizes.y && thread_id.x < sizes.z) {
+    if (tiledRow < sizes.y && col < sizes.z) {
       B_tile[tid.y][tid.x] =
-          mat2Data[tiledRow * strides[1].x + thread_id.x * strides[1].y];
+          mat2Data[tiledRow * strides[1].x + col * strides[1].y];
     } else {
       B_tile[tid.y][tid.x] = 0;
     }
@@ -45,26 +50,8 @@ inline c10::metal::opmath_t<T> matmul_inner(
     threadgroup_barrier(mem_flags::mem_threadgroup);
   }
-  return sum;
-}
-
-template <typename T>
-kernel void matmul(
-    constant T* mat1Data [[buffer(0)]],
-    constant T* mat2Data [[buffer(1)]],
-    device T* outputData [[buffer(2)]],
-    constant array<ulong2, 3>& strides [[buffer(3)]],
-    constant uint3& sizes [[buffer(4)]],
-    uint2 tid [[thread_position_in_threadgroup]],
-    uint2 thread_id [[thread_position_in_grid]]) {
-  threadgroup T A_tile[TILE_DIM][TILE_DIM];
-  threadgroup T B_tile[TILE_DIM][TILE_DIM];
-  auto sum = matmul_inner(
-      mat1Data, mat2Data, strides, sizes, A_tile, B_tile, tid, thread_id);
-  if (thread_id.y < sizes.x && thread_id.x < sizes.z) {
-    outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] =
-        static_cast<T>(sum);
+  if (row < sizes.x && col < sizes.z) {
+    outputData[row * strides[2].x + col * strides[2].y] = static_cast<T>(sum);
   }
 }
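
For orientation, the kernel restored by this revert computes a standard tiled GEMM: each threadgroup cooperatively loads TILE_DIM x TILE_DIM tiles of mat1 and mat2 into threadgroup memory, accumulates partial products across the inner dimension (sizes.y), and each thread writes one output element. Below is a minimal CPU-side sketch of the same indexing; it is not part of this commit, the name matmul_reference is hypothetical, and it assumes contiguous row-major buffers where the real kernel uses the strides array and a wider opmath accumulation type.

// Hypothetical reference (not from the PyTorch sources): a CPU version of the
// tiled matmul the Metal kernel implements, assuming contiguous row-major
// layout. sizes maps to {M, K, N}: mat1 is M x K, mat2 is K x N, out is M x N.
#include <cstdint>

constexpr uint32_t TILE_DIM = 16;

template <typename T>
void matmul_reference(const T* mat1, const T* mat2, T* out,
                      uint32_t M, uint32_t K, uint32_t N) {
  // Same tile count the kernel derives from sizes.y (the K dimension).
  const uint32_t numTiles = (K + TILE_DIM - 1) / TILE_DIM;
  for (uint32_t row = 0; row < M; ++row) {
    for (uint32_t col = 0; col < N; ++col) {
      T sum = 0;  // the kernel accumulates in c10::metal::opmath_t<T> instead
      // Walk K one TILE_DIM-wide slab at a time, mirroring the kernel's
      // per-tile loads of A_tile/B_tile (out-of-range elements contribute 0).
      for (uint32_t t = 0; t < numTiles; ++t) {
        const uint32_t kEnd = (t + 1) * TILE_DIM < K ? (t + 1) * TILE_DIM : K;
        for (uint32_t k = t * TILE_DIM; k < kEnd; ++k) {
          sum += mat1[row * K + k] * mat2[k * N + col];
        }
      }
      out[row * N + col] = sum;
    }
  }
}

The point of the tiling on the GPU is memory traffic, not arithmetic: by staging each tile in threadgroup memory behind a threadgroup_barrier, every element of mat1 and mat2 is read from device memory once per tile by the whole threadgroup instead of once per output element per thread.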