enable torch.compile for torch._scaled_mm nvfp4 recipe (#150462)

Summary:

Updates the meta registration for `torch._scaled_mm` to work for the
nvfp4 recipe.
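
(For reviewers less familiar with the compile stack: inductor traces ops on
fake/meta tensors, so `torch._scaled_mm` needs a meta function that computes
the output shape and dtype without launching a kernel. Below is only a rough
sketch of that idea, with made-up sizes and a hypothetical helper name; the
real logic is `meta_scaled_mm` in the diff further down.)

```python
import torch

# Illustrative only: what a meta/"fake" implementation of a scaled matmul
# conceptually does. Names and sizes here are assumptions for the sketch.
def scaled_mm_meta_sketch(mat1, mat2, scale_a, scale_b, out_dtype=torch.bfloat16):
    assert mat1.dim() == 2 and mat2.dim() == 2
    m, _k = mat1.shape
    n = mat2.size(1)
    # No kernel runs here; only output metadata is produced.
    return mat1.new_empty((m, n), dtype=out_dtype)

a = torch.empty(128, 64, device="meta", dtype=torch.float8_e4m3fn)
b = torch.empty(64, 128, device="meta", dtype=torch.float8_e4m3fn)
out = scaled_mm_meta_sketch(a, b, None, None)
print(out.shape, out.dtype)  # torch.Size([128, 128]) torch.bfloat16
```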

Test Plan:

```bash
pytest test/test_matmul_cuda.py -s -k test_blockwise_nvfp4
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150462
Approved by: https://github.com/eellison
vasiliy 2025-04-01 13:19:59 -07:00 committed by PyTorch MergeBot
parent ee97299961
commit c974b5322a
3 changed files with 53 additions and 8 deletions

@@ -1397,6 +1397,35 @@ class TestFP8MatmulCuda(TestCase):
         )
         torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
 
+    @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
+    def test_blockwise_nvfp4_compile(self) -> None:
+        device = "cuda"
+        M, K, N = 128, 128, 128
+        BLOCK_SIZE = 16
+        A_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
+        B_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
+        A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
+        B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
+        A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+        B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+        C_ref = A_ref @ B_ref.t()
+        compiled_scaled_mm = torch.compile(torch._scaled_mm, backend="inductor")
+        # C = torch._scaled_mm(
+        C = compiled_scaled_mm(
+            A,
+            B.t(),
+            A_scale,
+            B_scale,
+            out_dtype=torch.bfloat16,
+            use_fast_accum=False,
+        )
+        torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
+
 @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
 @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")

@@ -6182,12 +6182,13 @@ def meta_scaled_mm(
     out_dtype: Optional[torch.dtype] = None,
     use_fast_accum: bool = False,
 ):
-    def is_fp8_type(dtype):
+    def is_fp8_or_fp4_type(dtype):
         return dtype in (
             torch.float8_e4m3fn,
             torch.float8_e5m2,
             torch.float8_e4m3fnuz,
             torch.float8_e5m2fnuz,
+            torch.float4_e2m1fn_x2,
         )
 
     torch._check(
@@ -6195,8 +6196,8 @@ def meta_scaled_mm(
         lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}",
     )
     torch._check(
-        is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype),
-        lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
+        is_fp8_or_fp4_type(self.dtype) and is_fp8_or_fp4_type(mat2.dtype),
+        lambda: f"Expected both inputs to be fp8 or fp4 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
     )
 
     if device_hint(self) == "cuda":
@@ -6232,18 +6233,32 @@ def meta_scaled_mm(
         m, _k = self.shape
         n = mat2.size(1)
 
+        is_blockwise_scaling = (
+            scale_a.dtype == torch.float8_e8m0fnu
+            and scale_b.dtype == torch.float8_e8m0fnu
+        ) or (
+            scale_a.dtype == torch.float8_e4m3fn
+            and scale_b.dtype == torch.float8_e4m3fn
+        )
+
         if scale_a.numel() == 1 and scale_b.numel() == 1:
             # tensorwise scaling
             torch._check(
                 scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32,
                 lambda: "For tensorwise scaling, both scale_a and scale_b must be float (fp32) tensors.",
             )
-        elif (
-            scale_a.dtype == torch.float8_e8m0fnu
-            and scale_b.dtype == torch.float8_e8m0fnu
-        ):
+        elif is_blockwise_scaling:
             # blockwise scaling
-            block_size_k = 32
+            if scale_a.dtype == torch.float8_e4m3fn:
+                # NVIDIA's nvfp4 recipe:
+                # * block size is 16 elements packed (32 unpacked)
+                # * _k needs to be translated to the unpacked version
+                block_size_k = 16
+                _k = _k * 2
+            else:
+                block_size_k = 32
             block_size_mn = 128
 
             def ceil_div(a, b):
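
To make the `block_size_k = 16` / `_k = _k * 2` translation above concrete, here is a
small worked example (sizes are assumptions matching the new test, not code from the patch):

```python
from math import ceil

# nvfp4 packs two fp4 values into one float4_e2m1fn_x2 element, so a logical
# (128, 128) operand is stored as a (128, 64) packed tensor.
M, K_logical, N = 128, 128, 128      # assumed sizes, as in the test above
k_packed = K_logical // 2            # what meta_scaled_mm sees as _k
block_size_k = 16                    # nvfp4 block size along K
k_unpacked = k_packed * 2            # the "_k = _k * 2" step above

# Blockwise scale shapes the meta function expects for these sizes:
scale_cols = ceil(k_unpacked / block_size_k)      # 8
print((M, scale_cols), (N, scale_cols))           # (128, 8) (128, 8)
```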

@@ -222,6 +222,7 @@ dtype_abbrs = {
     torch.float8_e4m3fnuz: "f8e4m3fnuz",
     torch.float8_e5m2fnuz: "f8e5m2fnuz",
     torch.float8_e8m0fnu: "f8e8m0fnu",
+    torch.float4_e2m1fn_x2: "f4e2m1fnx2",
     torch.complex32: "c32",
     torch.complex64: "c64",
     torch.complex128: "c128",