[Inductor][ATen][FP8] Relax stride check for block-wise scaling when scaling dimension is 1 (#163829)
Summary: Relax stride check for block-wise scaling (1x128, 128x128) when a dimension of the scaling factor is 1. When the scaling tensor has a dimension of size 1, the stride is effectively "meaningless" to PyTorch, i.e. PyTorch decides to replace its stride with a default of `[1, 1]`. However, the old stride check required the stride to match one of the scaling dimensions. Here, we relax the stride check when the effective stride is 1 in order to allow for cases in which `K <= 128` and `N <= 128`.

Test Plan:
```
pytest -s -v test/test_matmul_cuda.py::TestFP8MatmulCUDA::test_scaled_mm_vs_emulated_block_wise_float32_lhs_block_1_rhs_block_128_cuda 2>&1 | tee ~/personal/stride_check.log
```

Differential Revision: D83023706

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163829

Approved by: https://github.com/lw, https://github.com/eqy
This commit is contained in:
parent
6b473c90cf
commit
e2c894c97d
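For context, a minimal Python sketch (not part of the commit) of the behaviour the summary describes: when `K <= 128`, the 1x128 scale tensor collapses to a single column, and PyTorch reports a stride of 1 for that size-1 dimension rather than the column-major stride (`t.size(0)`) the old check required. The `amax` call below is just a stand-in for however the scales are actually computed.

```python
import torch

M, K = 256, 128  # K <= 128, so there is only one 1x128 block per row

x = torch.randn(M, K)
# One scale per 1x128 block along K -> shape (M, ceil(K / 128)) == (M, 1)
scale = x.abs().amax(dim=1, keepdim=True).float()

print(scale.shape)    # torch.Size([256, 1])
print(scale.stride())  # (1, 1): the size-1 dim reports a default stride of 1,
                       # not the column-major stride (== M) the old check expected
```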
```
@@ -1124,6 +1124,17 @@ bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
           && scale.is_contiguous());
 }
 
+bool check_size_stride(const at::Tensor& scale, int dim, int size, int stride) {
+  // For Blockwise1x128 and Blockwise128x128,
+  // when the scale tensor has a dimension of size 1, the stride is effectively
+  // "meaningless", i.e. PyTorch decides to use a stride of 1. Thus, the regular
+  // stride check fails. Here, we relax the stride check when the effective
+  // stride is 1.
+
+  return (
+      scale.size(dim) == size && (size <= 1 || scale.stride(dim) == stride));
+}
+
 // 1x16 blocks for packed nvfp4 data and fp8_e4m3fn scales
 bool is_blockwise_1x16_scaling(const at::Tensor& t, const at::Tensor& scale) {
   // Multiply t.size(1) by 2 to adjust for fp4x2 packing
```
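The new helper only skips the stride comparison when the checked dimension has size <= 1; every other dimension is still validated. A rough Python restatement of that predicate (for illustration only; the C++ above is the actual implementation):

```python
import torch

def check_size_stride(scale: torch.Tensor, dim: int, size: int, stride: int) -> bool:
    # Mirrors the C++ helper: the stride is only enforced when the dimension
    # has more than one element, since a size-1 dim's stride is arbitrary.
    return scale.size(dim) == size and (size <= 1 or scale.stride(dim) == stride)

M = 256
scale = torch.randn(M, 1)                      # 1x128 scale when K <= 128
print(check_size_stride(scale, 0, M, 1))       # True: size and stride both match
print(check_size_stride(scale, 1, 1, M))       # True: size-1 dim, stride check skipped
print(check_size_stride(scale[::2], 0, M, 1))  # False: wrong size on dim 0
```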
```
@@ -1149,15 +1160,24 @@ bool is_blockwise_1x32_scaling(const at::Tensor& t, const at::Tensor& scale) {
 }
 
 bool is_blockwise_1x128_scaling(const at::Tensor& t, const at::Tensor& scale) {
-  return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2
-          && scale.size(0) == t.size(0) && scale.size(1) == ceil_div<int64_t>(t.size(1), 128)
-          && scale.stride(0) == 1 && scale.stride(1) == t.size(0));
+  return (
+      isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat &&
+      scale.dim() == 2 && check_size_stride(scale, 0, t.size(0), 1) &&
+      check_size_stride(
+          scale, 1, ceil_div<int64_t>(t.size(1), 128), t.size(0)));
 }
 
 bool is_blockwise_128x128_scaling(const at::Tensor& t, const at::Tensor& scale) {
-  return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2
-          && scale.size(0) == ceil_div<int64_t>(t.size(0), 128) && scale.size(1) == ceil_div<int64_t>(t.size(1), 128)
-          && scale.stride(0) == round_up<int64_t>(ceil_div<int64_t>(t.size(1), 128), 4) && scale.stride(1) == 1);
+  return (
+      isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat &&
+      scale.dim() == 2 &&
+      check_size_stride(
+          scale,
+          0,
+          ceil_div<int64_t>(t.size(0), 128),
+          round_up<int64_t>(ceil_div<int64_t>(t.size(1), 128), 4)) &&
+      check_size_stride(
+          scale, 1, ceil_div<int64_t>(t.size(1), 128), 1));
 }
 
 bool is_desired_scaling(const at::Tensor& t, const at::Tensor& scale, ScalingType desired_scaling) {
```
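To make the 128x128 branch concrete: the scale is expected to have shape `(ceil(t.size(0)/128), ceil(t.size(1)/128))` with `stride(0)` rounded up to a multiple of 4. When the first of those dimensions collapses to 1, a plain contiguous scale no longer carries that padded stride, which is the kind of case the relaxed check now accepts. A small illustrative sketch, using plain-Python stand-ins for the C++ `ceil_div`/`round_up` helpers and an example 128 x 256 operand:

```python
import torch

def ceil_div(a: int, b: int) -> int:
    return -(-a // b)

def round_up(a: int, b: int) -> int:
    return ceil_div(a, b) * b

t = torch.randn(128, 256)  # first dim <= 128, so the scale's leading dim is 1
scale = torch.randn(ceil_div(t.size(0), 128), ceil_div(t.size(1), 128))  # shape (1, 2)

expected_stride0 = round_up(ceil_div(t.size(1), 128), 4)  # 4 (padded to a multiple of 4)
print(scale.stride())      # (2, 1): no padding, so the old stride(0) == 4 check failed
print(scale.size(0) <= 1)  # True: with the relaxed check, dim 0's stride is ignored
```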
```
@@ -1579,11 +1579,12 @@ class TestFP8Matmul(TestCase):
     )
     @parametrize("output_dtype", [torch.bfloat16, torch.float32])
     @parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
-    def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block):
+    @parametrize("M,N,K", [(256, 768, 512), (256, 128, 256), (256, 256, 128)])
+    def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block, M, N, K):
         torch.manual_seed(42)
 
-        x = torch.randn(256, 512, device="cuda", dtype=output_dtype).pow(3)
-        y = torch.randn(768, 512, device="cuda", dtype=output_dtype).pow(3)
+        x = torch.randn(M, K, device="cuda", dtype=output_dtype).pow(3)
+        y = torch.randn(N, K, device="cuda", dtype=output_dtype).pow(3)
 
         x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
         y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
```
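The two added shape triples are what reach the relaxed check: with `K = 128` the per-row 1x128 scale of `x` has a single column, and with `N = 128` the 128x128 scale of `y` (when `rhs_block == 128`) has a single row. A quick shape calculation, assuming one scale value per block tile as in `tensor_to_scale_block` above, with `x` of shape `(M, K)` and `y` of shape `(N, K)`:

```python
from math import ceil

# Blockwise scale shapes for each new (M, N, K) parametrization.
for M, N, K in [(256, 768, 512), (256, 128, 256), (256, 256, 128)]:
    x_scale_1x128 = (M, ceil(K / 128))                 # lhs_block == 1
    y_scale_128x128 = (ceil(N / 128), ceil(K / 128))   # rhs_block == 128
    print((M, N, K), x_scale_1x128, y_scale_128x128)
# (256, 768, 512) (256, 4) (6, 4)  -- no size-1 dims; the original test case
# (256, 128, 256) (256, 2) (1, 2)  -- y's scale has a single row (N <= 128)
# (256, 256, 128) (256, 1) (2, 1)  -- K <= 128: x's 1x128 scale has a single column
```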