From 6f23f53599629a47d6e097b2a027048658a142d4 Mon Sep 17 00:00:00 2001
From: thenumberouscode
Date: Wed, 9 Jul 2025 07:03:38 +0000
Subject: [PATCH] [inductor] fix tensor.to(uint8) error when tensor src type is float (#157267)

The CPU inductor backend handles .to(torch.uint8) incorrectly when the
source dtype is float, leading to numerical inconsistencies with eager
mode. For uint8_t, convert_float_to_int8 returns incorrect results for
negative inputs such as -2.xx, producing 0 instead of the wrap-around
value 254 that eager's scalar static_cast yields. The issue stems from
the clamping logic: min_val was converted to uint8_t too early, so the
clamp range became [0, 255] and every negative value saturated to 0.
(Scalar sketches of both behaviors follow the patch.)

Fixes https://github.com/pytorch/pytorch/issues/156788

@leslie-fang-intel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157267
Approved by: https://github.com/leslie-fang-intel
---
 aten/src/ATen/cpu/vec/vec256/vec256_qint.h | 43 +++++++++++++++++-----
 aten/src/ATen/cpu/vec/vec512/vec512_qint.h | 35 +++++++++++++-----
 test/inductor/test_cpu_repro.py            | 13 +++++++
 3 files changed, 73 insertions(+), 18 deletions(-)

diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h
index 399cf206ab0..c9a379ddd78 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h
@@ -121,27 +121,52 @@ typename std::enable_if_t<
 }
 
 template <typename T>
-typename std::enable_if_t<
-    std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
-    at::vec::Vectorized<
-        T>> inline convert_float_to_int8(at::vec::Vectorized<float> src) {
+at::vec::Vectorized<T> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src);
+
+template <>
+at::vec::Vectorized<int8_t> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src) {
   // Convert from float32 to int32 with truncation
   __m256i x_values_int32 = _mm256_cvttps_epi32(src);
 
   // Convert from int32 to int16 using signed saturation
   __m256i xy_packed_v = _mm256_packs_epi32(x_values_int32, x_values_int32);
 
-  constexpr auto min_val = std::numeric_limits<T>::min();
-  constexpr auto max_val = std::numeric_limits<T>::max();
+  constexpr auto min_val = std::numeric_limits<int8_t>::min();
+  constexpr auto max_val = std::numeric_limits<int8_t>::max();
 
-  // Convert from int16 to uint8/int8 using unsigned saturation
-  __m256i xyzw_clamped_v =
-      pack_saturate_and_clamp<T>(xy_packed_v, xy_packed_v, min_val, max_val);
+  // Convert from int16 to int8 using unsigned saturation
+  __m256i xyzw_clamped_v = pack_saturate_and_clamp<int8_t>(
+      xy_packed_v, xy_packed_v, min_val, max_val);
   __m256i permute_mask_v =
       _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
   return _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
 }
 
+template <>
+at::vec::Vectorized<uint8_t> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src) {
+  // The type of *_val should be int32_t to ensure correct clamping behavior.
+  constexpr auto min_val = std::numeric_limits<int32_t>::min();
+  constexpr auto max_val = std::numeric_limits<int32_t>::max();
+  __m256 float32_min_val = _mm256_set1_ps(float(min_val));
+  __m256 float32_max_val = _mm256_set1_ps(float(max_val));
+  __m256 float32_src = _mm256_max_ps(src, float32_min_val);
+  float32_src = _mm256_min_ps(float32_src, float32_max_val);
+  __m256i truncated_src = _mm256_cvttps_epi32(float32_src);
+
+  __m128i r1 = _mm256_castsi256_si128(truncated_src);
+  __m128i mask = _mm_setr_epi8(
+      0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1);
+  __m128i r1_shuffled = _mm_shuffle_epi8(r1, mask);
+  __m128i r2 = _mm256_extractf128_si256(truncated_src, 1);
+  __m128i r2_shuffled = _mm_shuffle_epi8(r2, mask);
+  __m128i result = _mm_unpacklo_epi32(r1_shuffled, r2_shuffled);
+
+  return _mm256_castsi128_si256(result);
+}
+
 template <typename T>
 __FORCE_INLINE void QuantizeAvx2(
     const float* src,
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h
index 09789acaae1..cb73e97adef 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h
@@ -123,22 +123,24 @@ typename std::enable_if_t<
 }
 
 template <typename T>
-typename std::enable_if_t<
-    std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
-    at::vec::Vectorized<
-        T>> inline convert_float_to_int8(at::vec::Vectorized<float> src) {
+at::vec::Vectorized<T> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src);
+
+template <>
+at::vec::Vectorized<int8_t> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src) {
   // Convert from float32 to int32 with truncation
   __m512i x_values_int32 = _mm512_cvttps_epi32(src);
 
   // Convert from int32 to int16 using signed saturation
   __m512i xy_packed_v = _mm512_packs_epi32(x_values_int32, x_values_int32);
 
-  constexpr auto min_val = std::numeric_limits<T>::min();
-  constexpr auto max_val = std::numeric_limits<T>::max();
+  constexpr auto min_val = std::numeric_limits<int8_t>::min();
+  constexpr auto max_val = std::numeric_limits<int8_t>::max();
 
-  // Convert from int16 to uint8/int8 using unsigned saturation
-  __m512i xyzw_clamped_v =
-      pack_saturate_and_clamp<T>(xy_packed_v, xy_packed_v, min_val, max_val);
+  // Convert from int16 to int8 using unsigned saturation
+  __m512i xyzw_clamped_v = pack_saturate_and_clamp<int8_t>(
+      xy_packed_v, xy_packed_v, min_val, max_val);
   __m512i permute_mask_v = _mm512_set_epi32(
       0x0f,
       0x0b,
@@ -159,6 +161,21 @@ typename std::enable_if_t<
   return _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v);
 }
 
+template <>
+at::vec::Vectorized<uint8_t> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src) {
+  // The type of *_val should be int32_t to ensure correct clamping behavior.
+  constexpr auto min_val = std::numeric_limits<int32_t>::min();
+  constexpr auto max_val = std::numeric_limits<int32_t>::max();
+  __m512 float32_min_val = _mm512_set1_ps(float(min_val));
+  __m512 float32_max_val = _mm512_set1_ps(float(max_val));
+  __m512 float32_src = _mm512_max_ps(src, float32_min_val);
+  float32_src = _mm512_min_ps(float32_src, float32_max_val);
+  __m512i int32_src_clamped = _mm512_cvttps_epi32(float32_src);
+  __m128i int8_src = _mm512_cvtepi32_epi8(int32_src_clamped);
+  return _mm512_castsi128_si512(int8_src);
+}
+
 template <typename T>
 __FORCE_INLINE void QuantizeAvx512(
     const float* src,
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index b6a46176c27..50001c24fd0 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -4330,6 +4330,19 @@ class CPUReproTests(TestCase):
         y = torch.randint(0, 255, (3, 3), dtype=torch.uint8)
         self.common(fn, (x, y))
 
+    def test_float32_to_uint8(self):
+        # https://github.com/pytorch/pytorch/issues/156788
+        @torch.compile
+        def fn(x):
+            return x.to(torch.uint8)
+
+        x = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32, device="cpu")
+        self.assertEqual(
+            x.to(torch.uint8),
+            fn(x),
+            msg=f"Expected {x.to(torch.uint8)} but got {fn(x)}",
+        )
+
     def test_non_contiguous_reduction_store(self):
         # https://github.com/pytorch/pytorch/issues/113018
         class M(torch.nn.Module):
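
Note (illustration, not part of the patch): the whole fix comes down to when
the narrowing to uint8 happens. Below is a minimal scalar sketch of the old
saturating behavior versus the fixed wrap-around behavior; the two function
names are made up for this sketch:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <initializer_list>

    // Old vectorized path (simplified): min_val/max_val were narrowed to the
    // uint8_t limits up front, so the clamp range was [0, 255] and every
    // negative input collapsed to 0.
    uint8_t to_uint8_saturating(float x) {
      int32_t v = static_cast<int32_t>(x);  // truncate toward zero: -2.0f -> -2
      v = std::min(std::max(v, 0), 255);    // premature uint8 clamp: -2 -> 0
      return static_cast<uint8_t>(v);
    }

    // Fixed path (simplified): truncate to int32 first and narrow to uint8
    // only at the end, which wraps modulo 256 like eager's scalar static_cast.
    uint8_t to_uint8_wrapping(float x) {
      int32_t v = static_cast<int32_t>(x);  // -2.0f -> -2 (0xFFFFFFFE)
      return static_cast<uint8_t>(v);       // low byte: 0xFE -> 254
    }

    int main() {
      for (float x : {-1.0f, -2.0f, -3.0f, -4.0f}) {
        std::printf("%5.1f -> old %3d, fixed %3d\n", x,
                    static_cast<int>(to_uint8_saturating(x)),
                    static_cast<int>(to_uint8_wrapping(x)));
      }
      return 0;  // "fixed" column matches eager: 255, 254, 253, 252
    }

This is exactly what test_float32_to_uint8 above asserts: the compiled result
must equal eager's wrap-around values rather than saturated zeros.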
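Note (illustration, not part of the patch): both new uint8 kernels implement
the wrap-around narrowing as a low-byte extraction. The AVX2 path shuffles out
byte 0 of each 32-bit lane (_mm_shuffle_epi8 with mask 0, 4, 8, 12); the
AVX-512 path uses _mm512_cvtepi32_epi8 (vpmovdb), which truncates without
saturation. A scalar model of that step, assuming little-endian lanes as on
x86 (the function name is hypothetical):

    #include <array>
    #include <cstdint>
    #include <cstring>

    // Byte 0 of a little-endian int32 holds its value modulo 256, so
    // gathering the first byte of each lane is exactly the wrap-around
    // narrowing to uint8 that scalar static_cast performs.
    std::array<uint8_t, 8> narrow_lanes(const std::array<int32_t, 8>& lanes) {
      std::array<uint8_t, 8> out{};
      for (std::size_t i = 0; i < lanes.size(); ++i) {
        std::memcpy(&out[i], &lanes[i], 1);  // copy only the low byte
      }
      return out;
    }

    // narrow_lanes({{-1, -2, 300, 256, 0, 1, 2, 3}})
    //   -> {255, 254, 44, 0, 0, 1, 2, 3}

The float-side clamp to the int32 limits that precedes the truncation keeps
_mm256_cvttps_epi32 / _mm512_cvttps_epi32 away from the out-of-range case;
the uint8 semantics come from the byte extraction itself.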