enable torch.compile for torch._scaled_mm nvfp4 recipe (#150462)
Summary: Updates the meta registration for `torch._scaled_mm` to work for the nvfp4 recipe.

Test Plan:
```bash
pytest test/test_matmul_cuda.py -s -k test_blockwise_nvfp4
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150462
Approved by: https://github.com/eellison
This commit is contained in:
parent ee97299961
commit c974b5322a
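For context, the nvfp4 recipe feeds `torch._scaled_mm` fp4 data packed two values per byte (`torch.float4_e2m1fn_x2`) together with one `torch.float8_e4m3fn` scale per 16 unpacked elements along K. Below is a minimal sketch of that layout, assuming a build recent enough to expose `torch.float4_e2m1fn_x2`; the shapes and names are illustrative, and the uint8 view is only a stand-in for the test's `_bfloat16_to_float4_e2m1fn_x2` packing helper:

```python
import torch

# Illustrative shapes, mirroring the test below: K counts unpacked fp4 values.
M, K, N = 128, 128, 128
BLOCK_SIZE = 16

# Two fp4 values share one byte, so the packed payload has K // 2 elements per row.
# Viewing a uint8 buffer as float4_e2m1fn_x2 keeps the shape (both dtypes are one byte wide).
a_packed = torch.zeros(M, K // 2, dtype=torch.uint8).view(torch.float4_e2m1fn_x2)

# One float8_e4m3fn scale per block of 16 unpacked values along K.
a_scale = torch.full((M, K // BLOCK_SIZE), 1.0, dtype=torch.float8_e4m3fn)

print(a_packed.shape, a_packed.dtype)  # torch.Size([128, 64]) torch.float4_e2m1fn_x2
print(a_scale.shape, a_scale.dtype)    # torch.Size([128, 8]) torch.float8_e4m3fn
```

The test added below builds the same shapes with its packing helper and `torch.full`, then checks that the compiled call matches a bfloat16 reference matmul.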
```diff
@@ -1397,6 +1397,35 @@ class TestFP8MatmulCuda(TestCase):
         )
         torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
 
+    @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
+    def test_blockwise_nvfp4_compile(self) -> None:
+
+        device = "cuda"
+        M, K, N = 128, 128, 128
+        BLOCK_SIZE = 16
+
+        A_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
+        B_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
+
+        A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
+        B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
+
+        A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+        B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+        C_ref = A_ref @ B_ref.t()
+
+        compiled_scaled_mm = torch.compile(torch._scaled_mm, backend="inductor")
+        # C = torch._scaled_mm(
+        C = compiled_scaled_mm(
+            A,
+            B.t(),
+            A_scale,
+            B_scale,
+            out_dtype=torch.bfloat16,
+            use_fast_accum=False,
+        )
+        torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
+
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
```
```diff
@@ -6182,12 +6182,13 @@ def meta_scaled_mm(
     out_dtype: Optional[torch.dtype] = None,
     use_fast_accum: bool = False,
 ):
-    def is_fp8_type(dtype):
+    def is_fp8_or_fp4_type(dtype):
         return dtype in (
             torch.float8_e4m3fn,
             torch.float8_e5m2,
             torch.float8_e4m3fnuz,
             torch.float8_e5m2fnuz,
+            torch.float4_e2m1fn_x2,
         )
 
     torch._check(
@@ -6195,8 +6196,8 @@ def meta_scaled_mm(
         lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}",
     )
     torch._check(
-        is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype),
-        lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
+        is_fp8_or_fp4_type(self.dtype) and is_fp8_or_fp4_type(mat2.dtype),
+        lambda: f"Expected both inputs to be fp8 or fp4 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
     )
 
     if device_hint(self) == "cuda":
@@ -6232,18 +6233,32 @@ def meta_scaled_mm(
         m, _k = self.shape
         n = mat2.size(1)
 
+        is_blockwise_scaling = (
+            scale_a.dtype == torch.float8_e8m0fnu
+            and scale_b.dtype == torch.float8_e8m0fnu
+        ) or (
+            scale_a.dtype == torch.float8_e4m3fn
+            and scale_b.dtype == torch.float8_e4m3fn
+        )
+
         if scale_a.numel() == 1 and scale_b.numel() == 1:
             # tensorwise scaling
             torch._check(
                 scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32,
                 lambda: "For tensorwise scaling, both scale_a and scale_b must be float (fp32) tensors.",
             )
-        elif (
-            scale_a.dtype == torch.float8_e8m0fnu
-            and scale_b.dtype == torch.float8_e8m0fnu
-        ):
+        elif is_blockwise_scaling:
             # blockwise scaling
-            block_size_k = 32
+
+            if scale_a.dtype == torch.float8_e4m3fn:
+                # NVIDIA's nvfp4 recipe:
+                # * block size is 16 elements packed (32 unpacked)
+                # * _k needs to be translated to the unpacked version
+                block_size_k = 16
+                _k = _k * 2
+            else:
+                block_size_k = 32
+
             block_size_mn = 128
 
             def ceil_div(a, b):
```
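To make the nvfp4 branch above concrete, here is the shape bookkeeping it implies, written out as plain Python; the variable names (`K_packed`, `K_unpacked`, `expected_scale_cols`) are illustrative and not taken from the kernel:

```python
def ceil_div(a, b):
    return (a + b - 1) // b

# A float4_e2m1fn_x2 element holds two fp4 values, so the stored K is half the logical K.
M, K_packed, N = 128, 64, 128
K_unpacked = K_packed * 2        # the "_k = _k * 2" translation in the meta kernel
block_size_k = 16                # nvfp4 groups 16 unpacked values per scale

expected_scale_cols = ceil_div(K_unpacked, block_size_k)
print(expected_scale_cols)       # 8 -> scale_a is (M, 8) and scale_b is (N, 8)
```

This matches the test above, where `K = 128` and `BLOCK_SIZE = 16` produce scale tensors with 8 columns.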
```diff
@@ -222,6 +222,7 @@ dtype_abbrs = {
     torch.float8_e4m3fnuz: "f8e4m3fnuz",
     torch.float8_e5m2fnuz: "f8e5m2fnuz",
     torch.float8_e8m0fnu: "f8e8m0fnu",
+    torch.float4_e2m1fn_x2: "f4e2m1fnx2",
     torch.complex32: "c32",
     torch.complex64: "c64",
     torch.complex128: "c128",
```