[mxfp8 torch._scaled_grouped_mm] fix meta registration for 3d tensor (#162765)
The meta registration checks for torch._scaled_grouped_mm have a bug for 3d "B" tensors. Namely, the scale for such a tensor should be 2d with shape (G, blocked_K * blocked_N), but the check currently enforces a 3d shape of (G, blocked_K, blocked_N).
See Blas.cpp for the correct validation logic [here](8e217a9f6d/aten/src/ATen/native/cuda/Blas.cpp (L1622)).
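For illustration, a minimal sketch of the shape arithmetic the fix encodes, with hypothetical sizes (the standalone `round_up` below re-implements the helper used by the meta registration, for the example only):

```python
import math

def round_up(x: float, multiple: int) -> int:
    # Round x up to the nearest multiple of `multiple` (illustrative
    # re-implementation of the helper used by the meta registration).
    return int(math.ceil(x / multiple) * multiple)

# Hypothetical 3d "B" tensor in the 2d-3d grouped GEMM case: G groups of K x N.
G, K, N = 4, 256, 512
block_size = 32  # MXFP8 uses one scale per 32 elements along K

blocked_K = round_up(K / block_size, 4)  # 256 / 32 = 8, already a multiple of 4
blocked_N = round_up(N, 128)             # 512, already a multiple of 128

# Correct expectation (matches Blas.cpp): a 2d scale of shape (G, blocked_K * blocked_N).
print((G, blocked_K * blocked_N))  # (4, 4096)

# The buggy meta check instead demanded a 3d scale of shape (G, blocked_K, blocked_N).
```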
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162765
Approved by: https://github.com/ngimel
parent e8eeb06034
commit 872ed60679
```diff
@@ -7547,18 +7547,18 @@ def _meta_grouped_mm_common(
         # scale sizes at compile time.
         if is_mxfp8:
             torch._check(
-                mat.ndim == scale.ndim,
-                lambda: f"For MXFP8, scale should have same number of dimensions as target tensor, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}",  # noqa: B950
+                scale.ndim == mat.ndim - 1,
+                lambda: f"For MXFP8, 3d tensor should have 2d scales, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}",  # noqa: B950
             )
             # TODO: This logic only holds for RHS tensor in 2d-3d case.
             # We'll need to update it to handle LHS 3d tensor in 3d-2d and 3d-3d cases.
-            G, K, N = scale.shape
+            G, K, N = mat.shape
             block_size = 32
             blocked_K = round_up(K / block_size, 4)
             blocked_N = round_up(N, 128)
             torch._check(
-                mat.shape[-2] == blocked_K and mat.shape[-1] == blocked_N,
-                lambda: f"For MXFP8, expected mat.shape={mat.shape} to have scale shape of ({G},{blocked_K},{blocked_N}), but got {scale.shape}",  # noqa: B950
+                scale.shape[0] == G and scale.shape[1] == blocked_K * blocked_N,
+                lambda: f"For MXFP8, expected mat.shape={mat.shape} to have scale shape of ({G},{blocked_K * blocked_N}), but got {scale.shape}",  # noqa: B950
             )
         else:
             torch._check(
```
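Outside the PyTorch source, the fixed validation for a 3d RHS tensor boils down to the following standalone sketch (the function name and assert-based checking are illustrative stand-ins for the actual torch._check-based code):

```python
import math

def check_mxfp8_scale_shape_3d_rhs(mat_shape, scale_shape):
    # mat is a 3d (G, K, N) fp8 tensor; its MXFP8 scales must be 2d.
    G, K, N = mat_shape
    block_size = 32
    blocked_K = int(math.ceil(K / block_size / 4) * 4)  # round_up(K / 32, 4)
    blocked_N = int(math.ceil(N / 128) * 128)           # round_up(N, 128)
    assert len(scale_shape) == 2, "For MXFP8, a 3d tensor should have 2d scales"
    assert scale_shape == (G, blocked_K * blocked_N), (
        f"expected scale shape ({G},{blocked_K * blocked_N}), got {scale_shape}"
    )

# Passes with the 2d scale layout the kernel actually expects.
check_mxfp8_scale_shape_3d_rhs((4, 256, 512), (4, 8 * 512))

# The pre-fix expectation, a 3d (G, blocked_K, blocked_N) scale, now fails:
try:
    check_mxfp8_scale_shape_3d_rhs((4, 256, 512), (4, 8, 512))
except AssertionError as e:
    print(e)  # For MXFP8, a 3d tensor should have 2d scales
```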