Revert "[BE][MPS] Refactor core matmul logic into matmul_core (#155969)"
This reverts commit 769d754ab2.
Reverted https://github.com/pytorch/pytorch/pull/155969 on behalf of https://github.com/atalman due to need to revert eigen test ([comment](https://github.com/pytorch/pytorch/pull/155969#issuecomment-2992502683))
parent 96d082d06b
commit d309cd1d50
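The diff below is against the MPS matmul Metal shader. The revert removes the matmul_inner helper that #155969 had factored out of the kernel and restores the original self-contained matmul kernel, which derives its output coordinates (row, col) from the threadgroup id rather than a flat thread_position_in_grid, switching the bounds checks and load/store indices back accordingly.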
@@ -7,31 +7,36 @@ using namespace metal;
 constant uint TILE_DIM = 16;
 
 template <typename T>
-inline c10::metal::opmath_t<T> matmul_inner(
-    constant T* mat1Data,
-    constant T* mat2Data,
-    constant array<ulong2, 3>& strides,
-    constant uint3& sizes,
-    threadgroup T A_tile[TILE_DIM][TILE_DIM],
-    threadgroup T B_tile[TILE_DIM][TILE_DIM],
-    uint2 tid,
-    uint2 thread_id) {
+kernel void matmul(
+    constant T* mat1Data [[buffer(0)]],
+    constant T* mat2Data [[buffer(1)]],
+    device T* outputData [[buffer(2)]],
+    constant array<ulong2, 3>& strides [[buffer(3)]],
+    constant uint3& sizes [[buffer(4)]],
+    uint2 tid [[thread_position_in_threadgroup]],
+    uint2 group_id [[threadgroup_position_in_grid]]) {
+  uint col = group_id.x * TILE_DIM + tid.x;
+  uint row = group_id.y * TILE_DIM + tid.y;
+
   c10::metal::opmath_t<T> sum = 0;
 
+  threadgroup T A_tile[TILE_DIM][TILE_DIM];
+  threadgroup T B_tile[TILE_DIM][TILE_DIM];
+
   uint numTiles = (sizes.y + TILE_DIM - 1) / TILE_DIM;
   for (uint t = 0; t < numTiles; t++) {
     uint tiledCol = t * TILE_DIM + tid.x;
-    if (thread_id.y < sizes.x && tiledCol < sizes.y) {
+    if (row < sizes.x && tiledCol < sizes.y) {
       A_tile[tid.y][tid.x] =
-          mat1Data[thread_id.y * strides[0].x + tiledCol * strides[0].y];
+          mat1Data[row * strides[0].x + tiledCol * strides[0].y];
     } else {
       A_tile[tid.y][tid.x] = 0;
     }
 
     uint tiledRow = t * TILE_DIM + tid.y;
-    if (tiledRow < sizes.y && thread_id.x < sizes.z) {
+    if (tiledRow < sizes.y && col < sizes.z) {
       B_tile[tid.y][tid.x] =
-          mat2Data[tiledRow * strides[1].x + thread_id.x * strides[1].y];
+          mat2Data[tiledRow * strides[1].x + col * strides[1].y];
     } else {
       B_tile[tid.y][tid.x] = 0;
     }
@@ -45,26 +50,8 @@ inline c10::metal::opmath_t<T> matmul_inner(
     threadgroup_barrier(mem_flags::mem_threadgroup);
   }
 
-  return sum;
-}
-
-template <typename T>
-kernel void matmul(
-    constant T* mat1Data [[buffer(0)]],
-    constant T* mat2Data [[buffer(1)]],
-    device T* outputData [[buffer(2)]],
-    constant array<ulong2, 3>& strides [[buffer(3)]],
-    constant uint3& sizes [[buffer(4)]],
-    uint2 tid [[thread_position_in_threadgroup]],
-    uint2 thread_id [[thread_position_in_grid]]) {
-  threadgroup T A_tile[TILE_DIM][TILE_DIM];
-  threadgroup T B_tile[TILE_DIM][TILE_DIM];
-
-  auto sum = matmul_inner(
-      mat1Data, mat2Data, strides, sizes, A_tile, B_tile, tid, thread_id);
-  if (thread_id.y < sizes.x && thread_id.x < sizes.z) {
-    outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] =
-        static_cast<T>(sum);
+  if (row < sizes.x && col < sizes.z) {
+    outputData[row * strides[2].x + col * strides[2].y] = static_cast<T>(sum);
   }
 }
 
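As a reading aid, here is a plain C++ sketch (illustrative only, not part of PyTorch) that re-enacts the restored kernel's tiled, bounds-masked matmul on the CPU. It reads the shapes off the kernel's index expressions, assuming sizes = {M, K, N} and that each strides[i] holds {row stride, column stride} in elements; the per-tile accumulation loop falls between the two hunks shown above, so the standard TILE_DIM-wide dot product is assumed here.

// cpu_tiled_matmul_sketch.cpp -- illustrative CPU re-enactment, not PyTorch code.
// Assumes sizes = {M, K, N} and strides[i] = {row stride, col stride} per tensor.
#include <array>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr uint32_t TILE_DIM = 16;

using Strides = std::array<std::array<uint64_t, 2>, 3>;
using Sizes = std::array<uint32_t, 3>;

// Each (gy, gx) pair plays the role of one threadgroup; the ty/tx loops play
// the threads of that group.
void tiled_matmul(const float* mat1, const float* mat2, float* out,
                  const Strides& strides, const Sizes& sizes) {
  const uint32_t M = sizes[0], K = sizes[1], N = sizes[2];
  const uint32_t groupsY = (M + TILE_DIM - 1) / TILE_DIM;
  const uint32_t groupsX = (N + TILE_DIM - 1) / TILE_DIM;
  const uint32_t numTiles = (K + TILE_DIM - 1) / TILE_DIM;

  for (uint32_t gy = 0; gy < groupsY; ++gy) {
    for (uint32_t gx = 0; gx < groupsX; ++gx) {
      float A_tile[TILE_DIM][TILE_DIM];
      float B_tile[TILE_DIM][TILE_DIM];
      float sum[TILE_DIM][TILE_DIM] = {};

      for (uint32_t t = 0; t < numTiles; ++t) {
        // Load one tile of mat1 and mat2, zero-padding out-of-range elements,
        // mirroring the kernel's bounds checks.
        for (uint32_t ty = 0; ty < TILE_DIM; ++ty) {
          for (uint32_t tx = 0; tx < TILE_DIM; ++tx) {
            const uint32_t row = gy * TILE_DIM + ty;
            const uint32_t col = gx * TILE_DIM + tx;
            const uint32_t tiledCol = t * TILE_DIM + tx;
            const uint32_t tiledRow = t * TILE_DIM + ty;
            A_tile[ty][tx] = (row < M && tiledCol < K)
                ? mat1[row * strides[0][0] + tiledCol * strides[0][1]]
                : 0.f;
            B_tile[ty][tx] = (tiledRow < K && col < N)
                ? mat2[tiledRow * strides[1][0] + col * strides[1][1]]
                : 0.f;
          }
        }
        // Accumulate this tile's contribution (the part of the kernel that
        // falls between the two diff hunks and is assumed here).
        for (uint32_t ty = 0; ty < TILE_DIM; ++ty)
          for (uint32_t tx = 0; tx < TILE_DIM; ++tx)
            for (uint32_t k = 0; k < TILE_DIM; ++k)
              sum[ty][tx] += A_tile[ty][k] * B_tile[k][tx];
      }

      // Write back only the in-range part of the output tile.
      for (uint32_t ty = 0; ty < TILE_DIM; ++ty)
        for (uint32_t tx = 0; tx < TILE_DIM; ++tx) {
          const uint32_t row = gy * TILE_DIM + ty;
          const uint32_t col = gx * TILE_DIM + tx;
          if (row < M && col < N)
            out[row * strides[2][0] + col * strides[2][1]] = sum[ty][tx];
        }
    }
  }
}

int main() {
  const uint32_t M = 5, K = 7, N = 3; // deliberately not multiples of TILE_DIM
  std::vector<float> a(M * K), b(K * N), c(M * N, 0.f), ref(M * N, 0.f);
  for (size_t i = 0; i < a.size(); ++i) a[i] = float(i % 13) - 6.f;
  for (size_t i = 0; i < b.size(); ++i) b[i] = float(i % 7) - 3.f;

  const Strides strides = {{{K, 1}, {N, 1}, {N, 1}}}; // contiguous row-major
  tiled_matmul(a.data(), b.data(), c.data(), strides, {M, K, N});

  // Naive reference for comparison.
  for (uint32_t i = 0; i < M; ++i)
    for (uint32_t j = 0; j < N; ++j)
      for (uint32_t k = 0; k < K; ++k)
        ref[i * N + j] += a[i * K + k] * b[k * N + j];

  for (size_t i = 0; i < ref.size(); ++i) assert(c[i] == ref[i]);
  std::cout << "tiled result matches naive reference\n";
  return 0;
}

The zero-padding branches in the sketch correspond to the kernel's row < sizes.x / tiledCol < sizes.y and tiledRow < sizes.y / col < sizes.z guards, and the final write-back to its row < sizes.x && col < sizes.z check.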