[inductor] fix tensor.to(uint8) error when tensor src type is float (#157267)

The CPU inductor compiles .to(torch.uint8) incorrectly, producing results that diverge from eager mode. For uint8_t, convert_float_to_int8 returns wrong values for negative inputs: an input such as -2.x is truncated to -2, which should wrap around to 254 as in eager mode, but the kernel produced 0. The issue stems from the clamping logic: min_val must not be converted to uint8_t too early, since that turns the lower clamp bound into 0 and flushes every negative value to it.
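A minimal scalar sketch of the failure mode (standalone illustrative code, not the kernel itself; the function names are hypothetical):

// build: g++ -std=c++17 sketch.cc
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Old (buggy) behavior: min_val/max_val were taken from uint8_t, so the
// clamp range became [0, 255] and every negative input collapsed to 0.
uint8_t convert_buggy(float x) {
  int32_t v = static_cast<int32_t>(x); // truncate toward zero: -2.5 -> -2
  v = std::clamp(v, 0, 255);           // uint8_t limits applied too early
  return static_cast<uint8_t>(v);
}

// Fixed behavior: keep the value in int32 range and let the narrowing
// cast wrap, matching eager mode: -2 -> 254.
uint8_t convert_fixed(float x) {
  int32_t v = static_cast<int32_t>(x);
  return static_cast<uint8_t>(v);
}

int main() {
  std::printf("%u %u\n", convert_buggy(-2.5f), convert_fixed(-2.5f)); // 0 254
}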
Fixes https://github.com/pytorch/pytorch/issues/156788
@leslie-fang-intel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157267
Approved by: https://github.com/leslie-fang-intel
thenumberouscode 2025-07-09 07:03:38 +00:00 committed by PyTorch MergeBot
parent e3f2597b45
commit 6f23f53599
3 changed files with 73 additions and 18 deletions

View File

@@ -121,27 +121,52 @@ typename std::enable_if_t<
 }
 
 template <typename T>
-typename std::enable_if_t<
-    std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
-    at::vec::Vectorized<
-        T>> inline convert_float_to_int8(at::vec::Vectorized<float> src) {
+at::vec::Vectorized<T> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src);
+
+template <>
+at::vec::Vectorized<int8_t> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src) {
   // Convert from float32 to int32 with truncation
   __m256i x_values_int32 = _mm256_cvttps_epi32(src);
   // Convert from int32 to int16 using signed saturation
   __m256i xy_packed_v = _mm256_packs_epi32(x_values_int32, x_values_int32);
-  constexpr auto min_val = std::numeric_limits<T>::min();
-  constexpr auto max_val = std::numeric_limits<T>::max();
-  // Convert from int16 to uint8/int8 using unsigned saturation
-  __m256i xyzw_clamped_v =
-      pack_saturate_and_clamp<T>(xy_packed_v, xy_packed_v, min_val, max_val);
+  constexpr auto min_val = std::numeric_limits<int8_t>::min();
+  constexpr auto max_val = std::numeric_limits<int8_t>::max();
+  // Convert from int16 to int8 using unsigned saturation
+  __m256i xyzw_clamped_v = pack_saturate_and_clamp<int8_t>(
+      xy_packed_v, xy_packed_v, min_val, max_val);
   __m256i permute_mask_v =
       _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
   return _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
 }
 
+template <>
+at::vec::Vectorized<uint8_t> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src) {
+  // The type of *_val should be int32_t to ensure correct clamping behavior.
+  constexpr auto min_val = std::numeric_limits<int32_t>::min();
+  constexpr auto max_val = std::numeric_limits<int32_t>::max();
+  __m256 float32_min_val = _mm256_set1_ps(float(min_val));
+  __m256 float32_max_val = _mm256_set1_ps(float(max_val));
+  __m256 float32_src = _mm256_max_ps(src, float32_min_val);
+  float32_src = _mm256_min_ps(float32_src, float32_max_val);
+  __m256i truncated_src = _mm256_cvttps_epi32(float32_src);
+  __m128i r1 = _mm256_castsi256_si128(truncated_src);
+  __m128i mask = _mm_setr_epi8(
+      0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1);
+  __m128i r1_shuffled = _mm_shuffle_epi8(r1, mask);
+  __m128i r2 = _mm256_extractf128_si256(truncated_src, 1);
+  __m128i r2_shuffled = _mm_shuffle_epi8(r2, mask);
+  __m128i result = _mm_unpacklo_epi32(r1_shuffled, r2_shuffled);
+  return _mm256_castsi128_si256(result);
+}
+
 template <typename T>
 __FORCE_INLINE void QuantizeAvx2(
     const float* src,
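The new uint8_t path above relies on _mm_shuffle_epi8 picking byte 0 of each 32-bit lane (mask entries 0, 4, 8, 12; the -1 entries zero the rest), and keeping the low byte of a truncated int32 is exactly uint8_t wraparound. A small standalone demo of that trick (a sketch, not part of the patch; requires SSSE3):

// build: g++ -mssse3 shuffle_demo.cc
#include <cstdint>
#include <cstdio>
#include <immintrin.h>

int main() {
  // Four truncated int32 lanes, as _mm256_cvttps_epi32 would produce them.
  __m128i lanes = _mm_setr_epi32(-1, -2, -3, 300);
  // Gather the low byte of each lane into bytes 0..3; -1 zeroes the rest.
  __m128i mask = _mm_setr_epi8(
      0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1);
  __m128i picked = _mm_shuffle_epi8(lanes, mask);
  uint8_t out[16];
  _mm_storeu_si128(reinterpret_cast<__m128i*>(out), picked);
  std::printf("%u %u %u %u\n", out[0], out[1], out[2], out[3]); // 255 254 253 44
}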

View File

@@ -123,22 +123,24 @@ typename std::enable_if_t<
 }
 
 template <typename T>
-typename std::enable_if_t<
-    std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
-    at::vec::Vectorized<
-        T>> inline convert_float_to_int8(at::vec::Vectorized<float> src) {
+at::vec::Vectorized<T> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src);
+
+template <>
+at::vec::Vectorized<int8_t> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src) {
   // Convert from float32 to int32 with truncation
   __m512i x_values_int32 = _mm512_cvttps_epi32(src);
   // Convert from int32 to int16 using signed saturation
   __m512i xy_packed_v = _mm512_packs_epi32(x_values_int32, x_values_int32);
-  constexpr auto min_val = std::numeric_limits<T>::min();
-  constexpr auto max_val = std::numeric_limits<T>::max();
-  // Convert from int16 to uint8/int8 using unsigned saturation
-  __m512i xyzw_clamped_v =
-      pack_saturate_and_clamp<T>(xy_packed_v, xy_packed_v, min_val, max_val);
+  constexpr auto min_val = std::numeric_limits<int8_t>::min();
+  constexpr auto max_val = std::numeric_limits<int8_t>::max();
+  // Convert from int16 to int8 using unsigned saturation
+  __m512i xyzw_clamped_v = pack_saturate_and_clamp<int8_t>(
+      xy_packed_v, xy_packed_v, min_val, max_val);
   __m512i permute_mask_v = _mm512_set_epi32(
       0x0f,
       0x0b,
@@ -159,6 +161,21 @@ typename std::enable_if_t<
   return _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v);
 }
 
+template <>
+at::vec::Vectorized<uint8_t> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src) {
+  // The type of *_val should be int32_t to ensure correct clamping behavior.
+  constexpr auto min_val = std::numeric_limits<int32_t>::min();
+  constexpr auto max_val = std::numeric_limits<int32_t>::max();
+  __m512 float32_min_val = _mm512_set1_ps(float(min_val));
+  __m512 float32_max_val = _mm512_set1_ps(float(max_val));
+  __m512 float32_src = _mm512_max_ps(src, float32_min_val);
+  float32_src = _mm512_min_ps(float32_src, float32_max_val);
+  __m512i int32_src_clamped = _mm512_cvttps_epi32(float32_src);
+  __m128i int8_src = _mm512_cvtepi32_epi8(int32_src_clamped);
+  return _mm512_castsi128_si512(int8_src);
+}
+
 template <typename T>
 __FORCE_INLINE void QuantizeAvx512(
     const float* src,
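The AVX-512 variant is simpler because _mm512_cvtepi32_epi8 (VPMOVDB) truncates each int32 lane to its low byte directly, which is the same wraparound; the saturating variants _mm512_cvtsepi32_epi8 and _mm512_cvtusepi32_epi8 would clamp instead and reintroduce the bug. A minimal check (a sketch, not part of the patch; assumes an AVX-512F machine):

// build: g++ -mavx512f cvt_demo.cc
#include <cstdint>
#include <cstdio>
#include <immintrin.h>

int main() {
  // Sixteen int32 lanes holding -2; VPMOVDB keeps the low byte of each.
  __m512i lanes = _mm512_set1_epi32(-2);
  __m128i bytes = _mm512_cvtepi32_epi8(lanes);
  uint8_t out[16];
  _mm_storeu_si128(reinterpret_cast<__m128i*>(out), bytes);
  std::printf("%u\n", out[0]); // 254, i.e. uint8_t wraparound of -2
}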

View File

@@ -4330,6 +4330,19 @@ class CPUReproTests(TestCase):
         y = torch.randint(0, 255, (3, 3), dtype=torch.uint8)
         self.common(fn, (x, y))
 
+    def test_float32_to_uint8(self):
+        # https://github.com/pytorch/pytorch/issues/156788
+        @torch.compile
+        def fn(x):
+            return x.to(torch.uint8)
+
+        x = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32, device="cpu")
+        self.assertEqual(
+            x.to(torch.uint8),
+            fn(x),
+            msg=f"Expected {x.to(torch.uint8)} but got {fn(x)}",
+        )
+
     def test_non_contiguous_reduction_store(self):
         # https://github.com/pytorch/pytorch/issues/113018
         class M(torch.nn.Module):
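For the input above, eager mode produces tensor([255, 254, 253, 252], dtype=torch.uint8) via wraparound; the old kernel returned zeros for all four elements, so this assertion fails before the fix and passes after it.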