[inductor] fix tensor.to(uint8) error when tensor src type is float (#157267)
The CPU inductor backend handles .to(torch.uint8) incorrectly, producing results that diverge from eager mode. When the destination type is uint8_t, convert_float_to_int8 returns wrong values for negative inputs: an input such as -2.x is converted to 0 instead of the expected wraparound value (254 for -2.0). The issue stems from the clamping logic: min_val must not be narrowed to uint8_t too early, because a lower bound of 0 forces every negative input to 0.

Fixes https://github.com/pytorch/pytorch/issues/156788

@leslie-fang-intel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157267
Approved by: https://github.com/leslie-fang-intel
This commit is contained in:
parent e3f2597b45
commit 6f23f53599
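The change below rewrites convert_float_to_int8 in the AVX2 and AVX512 qint vectorization headers (the extracted hunks do not name the files, but they appear to be aten/src/ATen/cpu/vec/vec256/vec256_qint.h and aten/src/ATen/cpu/vec/vec512/vec512_qint.h), splitting the generic template into int8_t and uint8_t specializations, and adds a CPU inductor regression test. To see why the old logic fails, here is a minimal scalar model of the two clamping strategies; this is an illustrative sketch, not the shipped vectorized code, and the function names are invented.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <limits>

// Old behavior: the bounds are narrowed to uint8_t (0 and 255) before
// clamping, so every negative input collapses to the lower bound 0.
static uint8_t convert_with_uint8_bounds(float x) {
  constexpr auto min_val = std::numeric_limits<uint8_t>::min(); // 0
  constexpr auto max_val = std::numeric_limits<uint8_t>::max(); // 255
  float clamped = std::min(std::max(x, float(min_val)), float(max_val));
  return static_cast<uint8_t>(static_cast<int32_t>(clamped));
}

// Fixed behavior: clamp to the int32 range, truncate toward zero, and let
// the narrowing conversion wrap, matching the eager-mode scalar cast.
static uint8_t convert_with_int32_bounds(float x) {
  constexpr auto min_val = std::numeric_limits<int32_t>::min();
  constexpr auto max_val = std::numeric_limits<int32_t>::max();
  float clamped = std::min(std::max(x, float(min_val)), float(max_val));
  return static_cast<uint8_t>(static_cast<int32_t>(clamped));
}

int main() {
  std::printf("%u vs %u\n",
              (unsigned)convert_with_uint8_bounds(-2.5f),   // 0 (wrong)
              (unsigned)convert_with_int32_bounds(-2.5f));  // 254 (matches eager)
  return 0;
}

In the old vectorized path the same collapse happens through the unsigned-saturating pack: anything below zero saturates to 0 before the result is ever narrowed.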
@@ -121,27 +121,52 @@ typename std::enable_if_t<
 }
 
 template <typename T>
-typename std::enable_if_t<
-    std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
-    at::vec::Vectorized<
-        T>> inline convert_float_to_int8(at::vec::Vectorized<float> src) {
+at::vec::Vectorized<T> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src);
+
+template <>
+at::vec::Vectorized<int8_t> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src) {
   // Convert from float32 to int32 with truncation
   __m256i x_values_int32 = _mm256_cvttps_epi32(src);
 
   // Convert from int32 to int16 using signed saturation
   __m256i xy_packed_v = _mm256_packs_epi32(x_values_int32, x_values_int32);
 
-  constexpr auto min_val = std::numeric_limits<T>::min();
-  constexpr auto max_val = std::numeric_limits<T>::max();
+  constexpr auto min_val = std::numeric_limits<int8_t>::min();
+  constexpr auto max_val = std::numeric_limits<int8_t>::max();
 
-  // Convert from int16 to uint8/int8 using unsigned saturation
-  __m256i xyzw_clamped_v =
-      pack_saturate_and_clamp<T>(xy_packed_v, xy_packed_v, min_val, max_val);
+  // Convert from int16 to int8 using unsigned saturation
+  __m256i xyzw_clamped_v = pack_saturate_and_clamp<int8_t>(
+      xy_packed_v, xy_packed_v, min_val, max_val);
   __m256i permute_mask_v =
       _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
   return _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
 }
+
+template <>
+at::vec::Vectorized<uint8_t> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src) {
+  // The type of *_val should be int32_t to ensure correct clamping behavior.
+  constexpr auto min_val = std::numeric_limits<int32_t>::min();
+  constexpr auto max_val = std::numeric_limits<int32_t>::max();
+  __m256 float32_min_val = _mm256_set1_ps(float(min_val));
+  __m256 float32_max_val = _mm256_set1_ps(float(max_val));
+  __m256 float32_src = _mm256_max_ps(src, float32_min_val);
+  float32_src = _mm256_min_ps(float32_src, float32_max_val);
+  __m256i truncated_src = _mm256_cvttps_epi32(float32_src);
+
+  __m128i r1 = _mm256_castsi256_si128(truncated_src);
+  __m128i mask = _mm_setr_epi8(
+      0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1);
+  __m128i r1_shuffled = _mm_shuffle_epi8(r1, mask);
+  __m128i r2 = _mm256_extractf128_si256(truncated_src, 1);
+  __m128i r2_shuffled = _mm_shuffle_epi8(r2, mask);
+  __m128i result = _mm_unpacklo_epi32(r1_shuffled, r2_shuffled);
+
+  return _mm256_castsi128_si256(result);
+}
 
 template <typename T>
 __FORCE_INLINE void QuantizeAvx2(
     const float* src,
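The shuffle sequence in the new uint8_t specialization is dense, so as a reading aid here is a scalar equivalent of what the intrinsic sequence computes, relying only on the documented semantics of _mm256_cvttps_epi32, _mm_shuffle_epi8, and _mm_unpacklo_epi32; the function is a sketch with an invented name, not part of the patch.

#include <algorithm>
#include <cstdint>
#include <limits>

// Scalar model of the AVX2 uint8_t path above: clamp into the int32 range
// (_mm256_max_ps/_mm256_min_ps), truncate toward zero (_mm256_cvttps_epi32),
// then keep the low byte of each 32-bit lane -- which is exactly what the
// _mm_shuffle_epi8 mask {0, 4, 8, 12, -1, ...} extracts before
// _mm_unpacklo_epi32 merges the two 128-bit halves.
static void convert_float_to_uint8_scalar(const float* src, uint8_t* dst,
                                          int n) {
  const float lo = float(std::numeric_limits<int32_t>::min());
  const float hi = float(std::numeric_limits<int32_t>::max());
  for (int i = 0; i < n; ++i) {
    float clamped = std::min(std::max(src[i], lo), hi);
    int32_t truncated = static_cast<int32_t>(clamped);
    dst[i] = static_cast<uint8_t>(truncated & 0xFF); // low byte per lane
  }
}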
@@ -123,22 +123,24 @@ typename std::enable_if_t<
 }
 
 template <typename T>
-typename std::enable_if_t<
-    std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
-    at::vec::Vectorized<
-        T>> inline convert_float_to_int8(at::vec::Vectorized<float> src) {
+at::vec::Vectorized<T> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src);
+
+template <>
+at::vec::Vectorized<int8_t> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src) {
   // Convert from float32 to int32 with truncation
   __m512i x_values_int32 = _mm512_cvttps_epi32(src);
 
   // Convert from int32 to int16 using signed saturation
   __m512i xy_packed_v = _mm512_packs_epi32(x_values_int32, x_values_int32);
 
-  constexpr auto min_val = std::numeric_limits<T>::min();
-  constexpr auto max_val = std::numeric_limits<T>::max();
+  constexpr auto min_val = std::numeric_limits<int8_t>::min();
+  constexpr auto max_val = std::numeric_limits<int8_t>::max();
 
-  // Convert from int16 to uint8/int8 using unsigned saturation
-  __m512i xyzw_clamped_v =
-      pack_saturate_and_clamp<T>(xy_packed_v, xy_packed_v, min_val, max_val);
+  // Convert from int16 to int8 using unsigned saturation
+  __m512i xyzw_clamped_v = pack_saturate_and_clamp<int8_t>(
+      xy_packed_v, xy_packed_v, min_val, max_val);
   __m512i permute_mask_v = _mm512_set_epi32(
       0x0f,
       0x0b,

@@ -159,6 +161,21 @@ typename std::enable_if_t<
   return _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v);
 }
 
+template <>
+at::vec::Vectorized<uint8_t> inline convert_float_to_int8(
+    at::vec::Vectorized<float> src) {
+  // The type of *_val should be int32_t to ensure correct clamping behavior.
+  constexpr auto min_val = std::numeric_limits<int32_t>::min();
+  constexpr auto max_val = std::numeric_limits<int32_t>::max();
+  __m512 float32_min_val = _mm512_set1_ps(float(min_val));
+  __m512 float32_max_val = _mm512_set1_ps(float(max_val));
+  __m512 float32_src = _mm512_max_ps(src, float32_min_val);
+  float32_src = _mm512_min_ps(float32_src, float32_max_val);
+  __m512i int32_src_clamped = _mm512_cvttps_epi32(float32_src);
+  __m128i int8_src = _mm512_cvtepi32_epi8(int32_src_clamped);
+  return _mm512_castsi128_si512(int8_src);
+}
+
 template <typename T>
 __FORCE_INLINE void QuantizeAvx512(
     const float* src,
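The AVX512 variant is shorter because _mm512_cvtepi32_epi8 (VPMOVDB) already truncates each 32-bit lane to its low byte, so no shuffle/unpack sequence is needed. A scalar model of that lane-wise truncation (a sketch; the helper name is invented):

#include <cstdint>
#include <cstdio>

// One lane of _mm512_cvtepi32_epi8 (VPMOVDB): keep bits [7:0] of the 32-bit
// lane, so negative values wrap exactly like a C narrowing conversion.
static uint8_t vpmovdb_lane(int32_t lane) {
  return static_cast<uint8_t>(lane & 0xFF);
}

int main() {
  // -2 wraps to 254; 300 truncates to 44 (300 & 0xFF).
  std::printf("%u %u\n", (unsigned)vpmovdb_lane(-2), (unsigned)vpmovdb_lane(300));
  return 0;
}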
@@ -4330,6 +4330,19 @@ class CPUReproTests(TestCase):
         y = torch.randint(0, 255, (3, 3), dtype=torch.uint8)
         self.common(fn, (x, y))
 
+    def test_float32_to_uint8(self):
+        # https://github.com/pytorch/pytorch/issues/156788
+        @torch.compile
+        def fn(x):
+            return x.to(torch.uint8)
+
+        x = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32, device="cpu")
+        self.assertEqual(
+            x.to(torch.uint8),
+            fn(x),
+            msg=f"Expected {x.to(torch.uint8)} but got {fn(x)}",
+        )
+
     def test_non_contiguous_reduction_store(self):
         # https://github.com/pytorch/pytorch/issues/113018
         class M(torch.nn.Module):
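The test pins the compiled kernel to eager mode's behavior, which for negative floats is truncate-toward-zero followed by two's-complement wraparound (at least on x86 CPUs, where the scalar cast compiles to cvttss2si plus a byte truncation). As a sanity reference, here is a sketch of that scalar conversion and the values the test expects; this is illustrative, not PyTorch source:

#include <cstdint>
#include <cstdio>

// Scalar reference for eager-mode .to(torch.uint8) on x86: truncate toward
// zero to int32, then keep the low byte.
static uint8_t float_to_uint8_ref(float x) {
  return static_cast<uint8_t>(static_cast<int32_t>(x));
}

int main() {
  const float inputs[] = {-1.0f, -2.0f, -3.0f, -4.0f};
  for (float x : inputs) {
    // Prints 255, 254, 253, 252 for the test's inputs.
    std::printf("%.1f -> %u\n", x, (unsigned)float_to_uint8_ref(x));
  }
  return 0;
}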