[ROCm] Add scaled_mm v2 support. (#165528)

Add MXFP4 support in Blas.cpp.
Updated the scale_kernel_dispatch array and ScaledGemmImplementation enum to include MXFP4 support.
Modify the tests under test_scaled_matmul_cuda accordingly.

PYTORCH_TEST_WITH_ROCM=1 python test/test_scaled_matmul_cuda.py -v -k test_blockwise
115 tests passed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165528
Approved by: https://github.com/jeffdaily
This commit is contained in:
Jagadish Krishnamoorthy 2025-10-16 18:36:37 +00:00 committed by PyTorch MergeBot
parent 86fd4fc23e
commit 7669ac9402
2 changed files with 123 additions and 11 deletions

View File

@ -1759,6 +1759,7 @@ enum class ScaledGemmImplementation {
MXFP8_MXFP8 = 6,
NVFP4_NVFP4 = 7,
NVFP4_NVFP4_SINGLE_SCALE = 8,
MXFP4_MXFP4 = 9,
};
/**
@ -1955,10 +1956,39 @@ bool check_mxfp8_recipe(c10::ScalarType type_a,
return true;
}
/**
 * Acceptance test for the MXFP4 x MXFP4 recipe.
 *
 * Returns true only when:
 *   - both operand types are Float4_e2m1fn_x2,
 *   - each operand comes with exactly one scale tensor and one recipe entry,
 *   - that recipe is BlockWise1x32 and the scale dtype is Float8_e8m0fnu.
 */
bool check_mxfp4_recipe(c10::ScalarType type_a,
                        std::vector<ScalingType>& recipe_a,
                        ArrayRef<Tensor>& scales_a,
                        c10::ScalarType type_b,
                        std::vector<ScalingType>& recipe_b,
                        ArrayRef<Tensor>& scales_b) {
  // Operand dtypes: fp4 on both sides.
  const bool operands_are_fp4 =
      type_a == ScalarType::Float4_e2m1fn_x2 &&
      type_b == ScalarType::Float4_e2m1fn_x2;
  if (!operands_are_fp4) {
    return false;
  }

  // Exactly one scale and one recipe per operand; this also guards the
  // [0] accesses below.
  const bool one_scale_per_operand =
      scales_a.size() == 1 && recipe_a.size() == 1 &&
      scales_b.size() == 1 && recipe_b.size() == 1;
  if (!one_scale_per_operand) {
    return false;
  }

  // Each side needs the pair {BlockWise1x32, e8m0}.
  const auto is_mx_block_scale = [](const ScalingType recipe, const Tensor& scale) {
    return recipe == ScalingType::BlockWise1x32 &&
           scale.scalar_type() == ScalarType::Float8_e8m0fnu;
  };
  return is_mx_block_scale(recipe_a[0], scales_a[0]) &&
         is_mx_block_scale(recipe_b[0], scales_b[0]);
}
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
using namespace std::placeholders;
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8> scale_kernel_dispatch = {{
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 9> scale_kernel_dispatch = {{
{ "tensorwise_tensorwise", check_tensorwise_recipe, ScaledGemmImplementation::TENSORWISE_TENSORWISE },
{ "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
{ "block_1x128_128x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise128x128, _1, _2, _3, _4, _5, _6),
@ -1969,7 +1999,8 @@ std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8>
ScaledGemmImplementation::BLOCK_1x128_1x128},
{ "nvfp4_nvfp4", check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4},
{ "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
{ "mxfp4_mxfp4", check_mxfp4_recipe, ScaledGemmImplementation::MXFP4_MXFP4}}};
Tensor&
_scaled_tensorwise_tensorwise(
@ -2187,15 +2218,22 @@ _scaled_mxfp8_mxfp8(
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
#ifdef USE_ROCM
auto scale_a_elems = ceil_div<int64_t>(mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(mat_b.size(1), 32) * mat_b.size(0);
#else
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_a.size(1), 32), 4);
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_b.size(0), 32), 4);
#endif
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
#ifndef USE_ROCM
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
#endif
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
"For Blockwise scaling both scales should be contiguous");
@ -2225,6 +2263,56 @@ _scaled_mxfp8_mxfp8(
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}
/**
 * MXFP4 x MXFP4 scaled matmul: validates the inputs and forwards to
 * _scaled_gemm with BlockWise1x32 scaling on both operands.
 *
 * Only implemented for ROCm builds; on other builds the first check raises
 * NotImplemented. On ROCm it additionally requires ROCm >= 7.0 and a gfx950
 * device.
 *
 * Checks performed (in order):
 *   - mat_a and mat_b are both Float4_e2m1fn_x2;
 *   - scale_a.numel() == ceil_div(2 * mat_a.size(0), 32) * mat_a.size(1) and
 *     scale_b.numel() == ceil_div(2 * mat_b.size(1), 32) * mat_b.size(0);
 *   - both scale tensors are contiguous;
 *   - out.scalar_type() matches out_dtype;
 *   - (ROCm >= 7.0) every dimension of mat_a / mat_b is a multiple of 32 and
 *     the output dtype is BFloat16 or Half.
 *
 * swizzle_a / swizzle_b are accepted for signature parity with the other
 * recipes but are not inspected here.
 *
 * Returns the reference returned by _scaled_gemm.
 */
Tensor&
_scaled_mxfp4_mxfp4(
    const Tensor& mat_a, const Tensor& mat_b,
    const Tensor& scale_a, const SwizzleType swizzle_a,
    const Tensor& scale_b, const SwizzleType swizzle_b,
    const std::optional<Tensor>& bias,
    const c10::ScalarType out_dtype,
    Tensor& out) {
#ifndef USE_ROCM
  TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
#endif
  // Restrictions:
  // A and B are FP4 (Float4_e2m1fn_x2); the expected scale sizes are computed below.
  TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
                    mat_a.scalar_type(), mat_b.scalar_type());

  // NOTE(review): the factor of 2 presumably accounts for Float4_e2m1fn_x2
  // packing two fp4 values per element -- confirm against the packing layout.
  auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
  auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
  TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
                    "For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
  TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
                    "For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());

  TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
                    "For Blockwise scaling both scales should be contiguous");

  // Report the actual dtype of `out` in the "but got" slot (the message
  // previously printed out_dtype twice, hiding the offending value).
  TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out.scalar_type());

  auto scaling_choice_a = ScalingType::BlockWise1x32;
  auto scaling_choice_b = ScalingType::BlockWise1x32;

#if ROCM_VERSION >= 70000
  TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
                              "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");

  TORCH_CHECK_VALUE(mat_a.size(0) % 32 == 0 && mat_a.size(1) % 32 == 0 &&
                    mat_b.size(0) % 32 == 0 && mat_b.size(1) % 32 == 0,
                    "Matrix dimensions must be multiples of 32 for block-wise scaling");

  TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
                    out.scalar_type() == ScalarType::Half,
                    "Block-wise scaling only supports BFloat16 or Half output types");
#else
  TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
#endif

  return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}
Tensor&
_scaled_nvfp4_nvfp4(
const Tensor& mat_a, const Tensor& mat_b,
@ -2468,6 +2556,8 @@ _scaled_mm_cuda_v2_out(
TORCH_CHECK_NOT_IMPLEMENTED(false, "Only single-scale NVFP4 currently supported");
} else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE) {
return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, true /* single_scale */, out);
} else if (gemm_impl == ScaledGemmImplementation::MXFP4_MXFP4) {
return _scaled_mxfp4_mxfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
} else {
TORCH_CHECK_VALUE(false, "Invalid state - found an implementation, but not really");
}

View File

@ -152,15 +152,34 @@ def infer_scale_swizzle(mat, scale):
):
return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
# MX
# MXFP4 w/o swizzle
if (
scale.numel()
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
or scale.numel()
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)
scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1]
or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0]
and mat.dtype == torch.float4_e2m1fn_x2
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
if not torch.version.hip:
# MXFP8 w/ swizzle
if (
scale.numel()
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
or scale.numel()
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
else:
# MXFP8 w/o swizzle
if (
scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0]
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
return None, None
@ -1489,7 +1508,7 @@ class TestFP8Matmul(TestCase):
assert sqnr.item() > approx_match_sqnr_target
@unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM or IS_WINDOWS, mx_skip_msg)
@parametrize("recipe", ["mxfp8", "nvfp4"])
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None:
M, K, N = (1024, 512, 2048)
BLOCK_SIZE_K = 16 if recipe == "nvfp4" else 32
@ -1503,7 +1522,7 @@ class TestFP8Matmul(TestCase):
if recipe == "mxfp8":
x_lowp = x.to(e4m3_type)
y_lowp = y.to(e4m3_type).t()
else: # nvfp4
else: # nvfp4 #mxfp4
x_lowp = _bfloat16_to_float4_e2m1fn_x2(x.bfloat16())
y_lowp = _bfloat16_to_float4_e2m1fn_x2(y.bfloat16()).t()
@ -1517,7 +1536,10 @@ class TestFP8Matmul(TestCase):
if recipe == "nvfp4"
else ScalingType.BlockWise1x32
)
swizzle = SwizzleType.SWIZZLE_32_4_4
if torch.version.hip:
swizzle = SwizzleType.NO_SWIZZLE
else:
swizzle = SwizzleType.SWIZZLE_32_4_4
# Test wrong scale tensor size for scale_a with correct dtype
with self.assertRaisesRegex(