mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "Add CUDA MXFP4 scaled mm support via. FBGEMM (#166526)"
This reverts commit e3ae0594d1.
Reverted https://github.com/pytorch/pytorch/pull/166526 on behalf of https://github.com/atalman due to Failing internal test ([comment](https://github.com/pytorch/pytorch/pull/166526#issuecomment-3474907536))
This commit is contained in:
parent
d97144d31e
commit
93a70c717a
|
|
@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
|
|||
if(USE_CUDA)
|
||||
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
|
||||
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
|
||||
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*")
|
||||
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
|
||||
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
|
||||
|
|
|
|||
|
|
@ -59,22 +59,6 @@
|
|||
// forward declare
|
||||
class cublasCommonArgs;
|
||||
|
||||
namespace fbgemm_gpu {
|
||||
|
||||
// NOTE(slayton58): FBGemm_GPU kernels come from <fbgemm_gpu/torch_ops.h> within the FBGemm repo.
|
||||
// To update supported ops means a submodule bump, which is.. painful. Instead, we
|
||||
// can simply forward-declare the methods we want to use.. Works at least as a short-term
|
||||
// thing, but should still be fixed somewhere/somehow.
|
||||
at::Tensor f4f4bf16(
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
std::optional<at::Tensor>,
|
||||
bool use_mx);
|
||||
|
||||
} // namespace fbgemm_gpu
|
||||
|
||||
using at::blas::ScalingType;
|
||||
using at::blas::SwizzleType;
|
||||
|
||||
|
|
@ -1013,47 +997,26 @@ _scaled_mxfp4_mxfp4(
|
|||
const std::optional<Tensor>& bias,
|
||||
const c10::ScalarType out_dtype,
|
||||
Tensor& out) {
|
||||
#if !defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI)
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
|
||||
#ifndef USE_ROCM
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
|
||||
#endif
|
||||
// Restrictions:
|
||||
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
|
||||
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
|
||||
// Packed FP4 format means actual-K = 2 * reported-K -- adjust
|
||||
auto K_multiplier = 2;
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
auto scale_a_elems = ceil_div<int64_t>(K_multiplier * mat_a.size(0), 32) * mat_a.size(1);
|
||||
auto scale_b_elems = ceil_div<int64_t>(K_multiplier * mat_b.size(1), 32) * mat_b.size(0);
|
||||
#else
|
||||
// NVIDIA
|
||||
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_a.size(1), 32), 4);
|
||||
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_b.size(0), 32), 4);
|
||||
#endif
|
||||
auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
|
||||
auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
|
||||
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
|
||||
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
|
||||
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
|
||||
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)");
|
||||
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)");
|
||||
#else
|
||||
// NVIDIA
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
|
||||
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
|
||||
"For Blockwise scaling both scales should be contiguous");
|
||||
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x32;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x32;
|
||||
|
||||
|
|
@ -1068,29 +1031,11 @@ _scaled_mxfp4_mxfp4(
|
|||
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
|
||||
out.scalar_type() == ScalarType::Half,
|
||||
"Block-wise scaling only supports BFloat16 or Half output types");
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
|
||||
#endif
|
||||
|
||||
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
#else
|
||||
// NVIDIA
|
||||
// NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor,
|
||||
// but we have one we need to use. Two clear options are to copy into
|
||||
// our output (slow), or use a move-assignment-operator (faster).
|
||||
// However, the compiler can complain about the explicit move preventing
|
||||
// copy elision because the return from f4f4bf16 is a temporary object.
|
||||
// So we don't explicitly move, and trust the compiler here...
|
||||
// In the longer term this should be fixed on the FBGemm side.
|
||||
out = fbgemm_gpu::f4f4bf16(
|
||||
mat_a,
|
||||
mat_b.transpose(-2, -1),
|
||||
scale_a,
|
||||
scale_b,
|
||||
std::nullopt, /* global_scale */
|
||||
true /* use_mx */
|
||||
);
|
||||
|
||||
return out;
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
|
||||
|
|
@ -1215,20 +1160,17 @@ _scaled_mm_cuda_v2_out(
|
|||
mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")");
|
||||
}
|
||||
|
||||
// Handle fp4 packed-K dimension
|
||||
int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1;
|
||||
|
||||
TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1],
|
||||
" but got ", bias->numel());
|
||||
TORCH_CHECK_VALUE(
|
||||
K_multiplier * mat_a.sizes()[1] % 16 == 0,
|
||||
mat_a.sizes()[1] % 16 == 0,
|
||||
"Expected trailing dimension of mat1 to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes()[0],
|
||||
"x",
|
||||
K_multiplier * mat_a.sizes()[1],
|
||||
mat_a.sizes()[1],
|
||||
").");
|
||||
TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
|
||||
TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
|
||||
mat_b.sizes()[1], ") must be divisible by 16");
|
||||
|
||||
// TODO(slayton): Existing checks, not sure if they should really be here.
|
||||
|
|
|
|||
|
|
@ -144,36 +144,42 @@ def infer_scale_swizzle(mat, scale):
|
|||
] == math.ceil(mat.shape[1] // 128):
|
||||
return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE
|
||||
|
||||
# if we're checking for nvfp4, need to adjust for packed-K
|
||||
K_multiplier = 2 if mat.dtype == torch.float4_e2m1fn_x2 else 1
|
||||
# NVFP4
|
||||
if (
|
||||
(scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 16), 4)
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 16), 4))
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4))
|
||||
and mat.dtype == torch.float4_e2m1fn_x2
|
||||
and scale.dtype == torch.float8_e4m3fn
|
||||
):
|
||||
return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
|
||||
|
||||
# MX formats
|
||||
# MXFP4 w/o swizzle
|
||||
if (
|
||||
(scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1]
|
||||
or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0])
|
||||
and mat.dtype == torch.float4_e2m1fn_x2
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
|
||||
|
||||
if not torch.version.hip:
|
||||
# MX w/swizzle (NVIDIA)
|
||||
# MXFP8 w/ swizzle
|
||||
if (
|
||||
(scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 32), 4)
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 32), 4))
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4))
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
|
||||
|
||||
else:
|
||||
# MX w/o swizzle (AMD)
|
||||
# MXFP8 w/o swizzle
|
||||
if (
|
||||
(scale.numel() == math.ceil(mat.shape[0] // 32) * K_multiplier * mat.shape[1]
|
||||
or scale.numel() == math.ceil(K_multiplier * mat.shape[1] // 32) * mat.shape[0])
|
||||
(scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
|
||||
or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0])
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
|
||||
|
|
@ -1601,7 +1607,7 @@ class TestFP8Matmul(TestCase):
|
|||
(127, 96, 1024),
|
||||
(1025, 128, 96)
|
||||
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
|
||||
@parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
|
||||
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
|
||||
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
|
||||
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
|
||||
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
|
||||
|
|
@ -1615,12 +1621,8 @@ class TestFP8Matmul(TestCase):
|
|||
if not (M % 16 == 0 and K % 128 == 0 and N % 16 == 0):
|
||||
raise unittest.SkipTest("M and N must be multiples of 16 and K must be multiple of 128 on ROCm, skipping")
|
||||
|
||||
fp4_scaling_dtype = torch.float8_e8m0fnu if recipe == "mxfp4" else torch.float8_e4m3fn
|
||||
BLOCK_SIZE = 16 if recipe == "nvfp4" else 32
|
||||
|
||||
if K % BLOCK_SIZE != 0:
|
||||
raise unittest.SkipTest(f"K ({K}) must be divisible by BLOCK_SIZE ({BLOCK_SIZE}), skipping")
|
||||
|
||||
fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
|
||||
BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
|
||||
require_exact_match = True
|
||||
approx_match_sqnr_target = 22.0
|
||||
|
||||
|
|
@ -1798,7 +1800,7 @@ class TestFP8Matmul(TestCase):
|
|||
B = B.clamp(min=min_val, max=max_val)
|
||||
B = _bfloat16_to_float4_e2m1fn_x2(B)
|
||||
|
||||
approx_match_sqnr_target = 15 if recipe == "mxfp4" else 15.8
|
||||
approx_match_sqnr_target = 15 if torch.version.hip else 15.8
|
||||
|
||||
C_ref = A_ref @ B_ref.t()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user