diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index ae762e1def3..8b283c417b7 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
   if(USE_CUDA)
     # To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
     # If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
-    set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*")
+    set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
     file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
       "${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
       "${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
diff --git a/aten/src/ATen/native/cuda/ScaledBlas.cpp b/aten/src/ATen/native/cuda/ScaledBlas.cpp
index 4af8ad8493a..f073df5d07c 100644
--- a/aten/src/ATen/native/cuda/ScaledBlas.cpp
+++ b/aten/src/ATen/native/cuda/ScaledBlas.cpp
@@ -59,22 +59,6 @@
 // forward declare
 class cublasCommonArgs;
 
-namespace fbgemm_gpu {
-
-// NOTE(slayton58): FBGemm_GPU kernels come from within the FBGemm repo.
-//                  To update supported ops means a submodule bump, which is.. painful. Instead, we
-//                  can simply forward-declare the methods we want to use.. Works at least as a short-term
-//                  thing, but should still be fixed somewhere/somehow.
-at::Tensor f4f4bf16(
-    at::Tensor,
-    at::Tensor,
-    at::Tensor,
-    at::Tensor,
-    std::optional<at::Tensor>,
-    bool use_mx);
-
-} // namespace fbgemm_gpu
-
 using at::blas::ScalingType;
 using at::blas::SwizzleType;
 
@@ -1013,47 +997,26 @@ _scaled_mxfp4_mxfp4(
           const std::optional<Tensor>& bias,
           const c10::ScalarType out_dtype,
           Tensor& out) {
-#if !defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI)
-  TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
+#ifndef USE_ROCM
+  TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
 #endif
   // Restrictions:
   // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
   TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
       mat_a.scalar_type(), mat_b.scalar_type());
 
-  // Packed FP4 format means actual-K = 2 * reported-K -- adjust
-  auto K_multiplier = 2;
-#ifdef USE_ROCM
-  // AMD
-  auto scale_a_elems = ceil_div<int64_t>(K_multiplier * mat_a.size(0), 32) * mat_a.size(1);
-  auto scale_b_elems = ceil_div<int64_t>(K_multiplier * mat_b.size(1), 32) * mat_b.size(0);
-#else
-  // NVIDIA
-  auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_a.size(1), 32), 4);
-  auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_b.size(0), 32), 4);
-#endif
+  auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
+  auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
   TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
       "For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
   TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
       "For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
 
-#ifdef USE_ROCM
-  // AMD
-  TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)");
-  TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)");
-#else
-  // NVIDIA
-  TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
-  TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
-#endif
-
   TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(), "For Blockwise scaling both scales should be contiguous");
 
   TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype);
 
-#ifdef USE_ROCM
-  // AMD
   auto scaling_choice_a = ScalingType::BlockWise1x32;
   auto scaling_choice_b = ScalingType::BlockWise1x32;
 
@@ -1068,29 +1031,11 @@ _scaled_mxfp4_mxfp4(
   TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 || out.scalar_type() == ScalarType::Half,
       "Block-wise scaling only supports BFloat16 or Half output types");
+#else
+  TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
 #endif
 
   return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
-#else
-  // NVIDIA
-  // NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor,
-  //                  but we have one we need to use. Two clear options are to copy into
-  //                  our output (slow), or use a move-assignment-operator (faster).
-  //                  However, the compiler can complain about the explicit move preventing
-  //                  copy elision because the return from f4f4bf16 is a temporary object.
-  //                  So we don't explicitly move, and trust the compiler here...
-  //                  In the longer term this should be fixed on the FBGemm side.
-  out = fbgemm_gpu::f4f4bf16(
-    mat_a,
-    mat_b.transpose(-2, -1),
-    scale_a,
-    scale_b,
-    std::nullopt, /* global_scale */
-    true /* use_mx */
-  );
-
-  return out;
-#endif
 }
 
 Tensor&
@@ -1215,20 +1160,17 @@ _scaled_mm_cuda_v2_out(
         mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")");
   }
 
-  // Handle fp4 packed-K dimension
-  int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1;
-
   TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1], " but got ", bias->numel());
 
   TORCH_CHECK_VALUE(
-      K_multiplier * mat_a.sizes()[1] % 16 == 0,
+      mat_a.sizes()[1] % 16 == 0,
       "Expected trailing dimension of mat1 to be divisible by 16 ",
       "but got mat1 shape: (",
       mat_a.sizes()[0],
       "x",
-      K_multiplier * mat_a.sizes()[1],
+      mat_a.sizes()[1],
       ").");
-  TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
+  TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
       mat_b.sizes()[1], ") must be divisible by 16");
 
   // TODO(slayton): Existing checks, not sure if they should really be here.
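
Note (illustrative, not part of the patch): the reinstated ROCm-only path sizes MXFP4 scales as blockwise 1x32 over the packed-FP4 data, where each packed element holds two FP4 values and therefore counts double. Below is a minimal Python sketch of the expected scale element counts, mirroring the C++ expressions in _scaled_mxfp4_mxfp4 above; the helper names (ceil_div, expected_mxfp4_scale_elems) are hypothetical, not PyTorch API.

# Illustrative sketch only -- mirrors the scale-size check reinstated above.
def ceil_div(a: int, b: int) -> int:
    # Integer ceiling division, matching at::ceil_div semantics for positive inputs.
    return -(-a // b)

def expected_mxfp4_scale_elems(mat_a_shape, mat_b_shape):
    # C++: scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1)
    #      scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0)
    scale_a_elems = ceil_div(2 * mat_a_shape[0], 32) * mat_a_shape[1]
    scale_b_elems = ceil_div(2 * mat_b_shape[1], 32) * mat_b_shape[0]
    return scale_a_elems, scale_b_elems

print(expected_mxfp4_scale_elems((128, 64), (64, 256)))  # -> (512, 1024)
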
diff --git a/test/test_scaled_matmul_cuda.py b/test/test_scaled_matmul_cuda.py
index 95eb22d84cb..258c7570c77 100644
--- a/test/test_scaled_matmul_cuda.py
+++ b/test/test_scaled_matmul_cuda.py
@@ -144,36 +144,42 @@ def infer_scale_swizzle(mat, scale):
     ] == math.ceil(mat.shape[1] // 128):
         return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE
 
-    # if we're checking for nvfp4, need to adjust for packed-K
-    K_multiplier = 2 if mat.dtype == torch.float4_e2m1fn_x2 else 1
     # NVFP4
     if (
         (scale.numel()
-            == round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 16), 4)
+            == round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4)
         or scale.numel()
-            == round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 16), 4))
+            == round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4))
         and mat.dtype == torch.float4_e2m1fn_x2
         and scale.dtype == torch.float8_e4m3fn
     ):
         return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
 
-    # MX formats
+    # MXFP4 w/o swizzle
+    if (
+        (scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1]
+        or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0])
+        and mat.dtype == torch.float4_e2m1fn_x2
+        and scale.dtype == torch.float8_e8m0fnu
+    ):
+        return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
+
     if not torch.version.hip:
-        # MX w/swizzle (NVIDIA)
+        # MXFP8 w/ swizzle
         if (
             (scale.numel()
-                == round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 32), 4)
+                == round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
             or scale.numel()
-                == round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 32), 4))
+                == round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4))
             and scale.dtype == torch.float8_e8m0fnu
         ):
             return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
     else:
-        # MX w/o swizzle (AMD)
+        # MXFP8 w/o swizzle
         if (
-            (scale.numel() == math.ceil(mat.shape[0] // 32) * K_multiplier * mat.shape[1]
-            or scale.numel() == math.ceil(K_multiplier * mat.shape[1] // 32) * mat.shape[0])
+            (scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
+            or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0])
             and scale.dtype == torch.float8_e8m0fnu
         ):
             return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
 
@@ -1601,7 +1607,7 @@ class TestFP8Matmul(TestCase):
         (127, 96, 1024),
         (1025, 128, 96)
     ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
-    @parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
+    @parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
     def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
         if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
             raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
@@ -1615,12 +1621,8 @@ class TestFP8Matmul(TestCase):
             if not (M % 16 == 0 and K % 128 == 0 and N % 16 == 0):
                 raise unittest.SkipTest("M and N must be multiples of 16 and K must be multiple of 128 on ROCm, skipping")
 
-        fp4_scaling_dtype = torch.float8_e8m0fnu if recipe == "mxfp4" else torch.float8_e4m3fn
-        BLOCK_SIZE = 16 if recipe == "nvfp4" else 32
-
-        if K % BLOCK_SIZE != 0:
-            raise unittest.SkipTest(f"K ({K}) must be divisible by BLOCK_SIZE ({BLOCK_SIZE}), skipping")
-
+        fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
+        BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
         require_exact_match = True
         approx_match_sqnr_target = 22.0
 
@@ -1798,7 +1800,7 @@ class TestFP8Matmul(TestCase):
                 B = B.clamp(min=min_val, max=max_val)
                 B = _bfloat16_to_float4_e2m1fn_x2(B)
 
-                approx_match_sqnr_target = 15 if recipe == "mxfp4" else 15.8
+                approx_match_sqnr_target = 15 if torch.version.hip else 15.8
 
             C_ref = A_ref @ B_ref.t()
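
Note (illustrative, not part of the patch): after this change the test's FP4 recipe and its parameters are keyed off the platform rather than the recipe string alone. A minimal sketch of the resulting choices, assuming a ROCm vs. CUDA build of torch is the only thing being distinguished.

# Illustrative sketch only -- summarizes the platform-gated selections made in the test above.
import torch

recipe = "mxfp4" if torch.version.hip else "nvfp4"  # mirrors the @parametrize change
fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
print(recipe, fp4_scaling_dtype, BLOCK_SIZE)  # ROCm: mxfp4 / e8m0 / 32; CUDA: nvfp4 / e4m3 / 16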