[Inductor][ATen][FP8] Relax stride check for block-wise scaling when scaling dimension is 1 (#163829)

Summary: Relax the stride check for block-wise scaling (1x128, 128x128) when a dimension of the scaling factor has size 1. When the scaling tensor has a dimension of size 1, the stride along that dimension is effectively "meaningless" to PyTorch, i.e. PyTorch replaces it with a default stride of 1 (yielding strides of `[1, 1]` for these scale tensors). The old stride check, however, required the stride to match the value expected from the block-wise layout, so such tensors were rejected. Here, we relax the stride check whenever a dimension has size 1 (and thus an effective stride of 1), in order to allow for cases in which `K <= 128` and `N <= 128`.
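
For context, a minimal PyTorch snippet (illustrative only, not part of this change) showing the stride behavior that the old check tripped over:

```python
import torch

# 1x128 block-wise scales for an FP8 operand with M = 256, K = 128:
# ceil(K / 128) = 1, so the scale tensor has shape (256, 1).
scale = torch.ones(256, 1, dtype=torch.float32)

print(scale.shape)    # torch.Size([256, 1])
print(scale.stride()) # (1, 1) -- PyTorch reports stride 1 for any size-1 dim,
                      # so an exact check against the expected column-major
                      # stride (scale.stride(1) == 256) would reject this tensor.
```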

Test Plan:
```
pytest -s -v test/test_matmul_cuda.py::TestFP8MatmulCUDA::test_scaled_mm_vs_emulated_block_wise_float32_lhs_block_1_rhs_block_128_cuda   2>&1 | tee ~/personal/stride_check.log
```

Differential Revision: D83023706

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163829
Approved by: https://github.com/lw, https://github.com/eqy
Commit: e2c894c97d (parent: 6b473c90cf)
Author: Janani Sriram, 2025-09-29 17:28:23 +00:00, committed by PyTorch MergeBot
2 changed files with 30 additions and 9 deletions

@@ -1124,6 +1124,17 @@ bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
           && scale.is_contiguous());
 }
+bool check_size_stride(const at::Tensor& scale, int dim, int size, int stride) {
+  // For Blockwise1x128 and Blockwise128x128,
+  // when the scale tensor has a dimension of size 1, the stride is effectively
+  // "meaningless", i.e. PyTorch decides to use a stride of 1. Thus, the regular
+  // stride check fails. Here, we relax the stride check when the effective
+  // stride is 1.
+  return (
+      scale.size(dim) == size && (size <= 1 || scale.stride(dim) == stride));
+}
 // 1x16 blocks for packed nvfp4 data and fp8_e4m3fn scales
 bool is_blockwise_1x16_scaling(const at::Tensor& t, const at::Tensor& scale) {
   // Multiply t.size(1) by 2 to adjust for fp4x2 packing
@@ -1149,15 +1160,24 @@ bool is_blockwise_1x32_scaling(const at::Tensor& t, const at::Tensor& scale) {
 }
 bool is_blockwise_1x128_scaling(const at::Tensor& t, const at::Tensor& scale) {
-  return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2
-          && scale.size(0) == t.size(0) && scale.size(1) == ceil_div<int64_t>(t.size(1), 128)
-          && scale.stride(0) == 1 && scale.stride(1) == t.size(0));
+  return (
+      isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat &&
+      scale.dim() == 2 && check_size_stride(scale, 0, t.size(0), 1) &&
+      check_size_stride(
+          scale, 1, ceil_div<int64_t>(t.size(1), 128), t.size(0)));
 }
 bool is_blockwise_128x128_scaling(const at::Tensor& t, const at::Tensor& scale) {
-  return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2
-          && scale.size(0) == ceil_div<int64_t>(t.size(0), 128) && scale.size(1) == ceil_div<int64_t>(t.size(1), 128)
-          && scale.stride(0) == round_up<int64_t>(ceil_div<int64_t>(t.size(1), 128), 4) && scale.stride(1) == 1);
+  return (
+      isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat &&
+      scale.dim() == 2 &&
+      check_size_stride(
+          scale,
+          0,
+          ceil_div<int64_t>(t.size(0), 128),
+          round_up<int64_t>(ceil_div<int64_t>(t.size(1), 128), 4)) &&
+      check_size_stride(
+          scale, 1, ceil_div<int64_t>(t.size(1), 128), 1));
 }
 bool is_desired_scaling(const at::Tensor& t, const at::Tensor& scale, ScalingType desired_scaling) {

@@ -1579,11 +1579,12 @@ class TestFP8Matmul(TestCase):
     )
     @parametrize("output_dtype", [torch.bfloat16, torch.float32])
     @parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
-    def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block):
+    @parametrize("M,N,K", [(256, 768, 512), (256, 128, 256), (256, 256, 128)])
+    def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block, M, N, K):
         torch.manual_seed(42)
-        x = torch.randn(256, 512, device="cuda", dtype=output_dtype).pow(3)
-        y = torch.randn(768, 512, device="cuda", dtype=output_dtype).pow(3)
+        x = torch.randn(M, K, device="cuda", dtype=output_dtype).pow(3)
+        y = torch.randn(N, K, device="cuda", dtype=output_dtype).pow(3)
         x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
         y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
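
For reference, a small Python sketch (not part of the PR) that mirrors the new `check_size_stride` helper and shows why the added `(M, N, K) = (256, 256, 128)` case needs the relaxed check for 1x128 scaling:

```python
import torch

def check_size_stride(scale, dim, size, stride):
    # Python rendering of the relaxed C++ check: the stride comparison is
    # skipped when the dimension has size <= 1, since PyTorch reports a
    # stride of 1 for size-1 dimensions.
    return scale.size(dim) == size and (size <= 1 or scale.stride(dim) == stride)

# 1x128 scaling of an (M, K) = (256, 128) operand: the scale tensor has shape
# (M, ceil(K / 128)) = (256, 1), which PyTorch stores with strides (1, 1).
M, K = 256, 128
scale = torch.ones(M, (K + 127) // 128, dtype=torch.float32)

# Old check: stride(0) == 1 and stride(1) == M -> fails, since stride(1) is 1.
# Relaxed check: size(1) == 1, so the stride(1) comparison is skipped.
assert check_size_stride(scale, 0, M, 1)
assert check_size_stride(scale, 1, (K + 127) // 128, M)
```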