[mxfp8 torch._scaled_grouped_mm] fix meta registration for 3d tensor (#162765)

The meta registration check for torch._scaled_grouped_mm has a bug for 3d "B" tensors. Namely, the scale for such a tensor should be 2d with shape (G, blocked_K * blocked_N), but the check currently enforces an expected 3d shape of (G, blocked_K, blocked_N).

See the correct validation logic in Blas.cpp [here](https://github.com/pytorch/pytorch/blob/8e217a9f6d/aten/src/ATen/native/cuda/Blas.cpp#L1622).
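For concreteness, here is a minimal sketch (not the PyTorch source; `round_up` here is a hypothetical helper that rounds up to the given multiple) of how the expected 2d scale shape is derived for a 3d RHS tensor of shape (G, K, N), following the blocking rule described above:

```python
# Sketch only: derive the expected MXFP8 scale shape for a 3d "B" tensor (G, K, N).
def round_up(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple (hypothetical helper for this sketch).
    return ((x + multiple - 1) // multiple) * multiple

def expected_mxfp8_scale_shape_3d(G: int, K: int, N: int) -> tuple[int, int]:
    block_size = 32                            # MXFP8 block size along K
    blocked_K = round_up(K // block_size, 4)   # K blocks, padded to a multiple of 4
    blocked_N = round_up(N, 128)               # N padded to a multiple of 128
    # The scale is 2d: one row per group, with the blocked K/N dims flattened.
    return (G, blocked_K * blocked_N)

# e.g. a B tensor of shape (4, 256, 512) needs scales of shape (4, 8 * 512) == (4, 4096)
print(expected_mxfp8_scale_shape_3d(4, 256, 512))
```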

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162765
Approved by: https://github.com/ngimel
commit 872ed60679 (parent e8eeb06034)
Author: Daniel Vega-Myhre
Date: 2025-09-12 03:51:49 +00:00
Committed by: PyTorch MergeBot

@@ -7547,18 +7547,18 @@ def _meta_grouped_mm_common(
         # scale sizes at compile time.
         if is_mxfp8:
             torch._check(
-                mat.ndim == scale.ndim,
-                lambda: f"For MXFP8, scale should have same number of dimensions as target tensor, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}",  # noqa: B950
+                scale.ndim == mat.ndim - 1,
+                lambda: f"For MXFP8, 3d tensor should have 2d scales, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}",  # noqa: B950
             )
             # TODO: This logic only holds for RHS tensor in 2d-3d case.
             # We'll need to update it to handle LHS 3d tensor in 3d-2d and 3d-3d cases.
-            G, K, N = scale.shape
+            G, K, N = mat.shape
             block_size = 32
             blocked_K = round_up(K / block_size, 4)
             blocked_N = round_up(N, 128)
             torch._check(
-                mat.shape[-2] == blocked_K and mat.shape[-1] == blocked_N,
-                lambda: f"For MXFP8, expected mat.shape={mat.shape} to have scale shape of ({G},{blocked_K},{blocked_N}), but got {scale.shape}",  # noqa: B950
+                scale.shape[0] == G and scale.shape[1] == blocked_K * blocked_N,
+                lambda: f"For MXFP8, expected mat.shape={mat.shape} to have scale shape of ({G},{blocked_K * blocked_N}), but got {scale.shape}",  # noqa: B950
             )
         else:
             torch._check(