From 3c7edf1ec003e4f429b770c5397f27f45d585f8d Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Sun, 22 Sep 2024 23:01:08 -0700 Subject: [PATCH] [Inductor][CPP] Fix int8 cvt half (#136353) Fix the correctness issue of https://github.com/pytorch/ao/pull/884/. The current implementation for converting between `Half/BFloat16` and `int8/uint8` incorrectly assumes that 1/4 of the int8/uint8 vector lane maps to 1/2 of the Half/BFloat16 vector lane. This assumption has caused accuracy issues since the full bit-width vectorization of the Half data type was introduced. When converting between int8 weights and the half data type, the generated code is as follows: ``` #include "/tmp/torchinductor_leslie/xw/cxww3s7wxrujoyxna7mlcjktid2uu6nntixqwm542xfkd756gl3x.h" extern "C" void kernel(const int8_t* in_ptr0, half* out_ptr0) { { for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2048L); x0+=static_cast<int64_t>(32L)) { auto tmp0 = at::vec::Vectorized<int8_t>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32)); auto tmp1 = at::vec::convert<half>(tmp0); tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32)); } } } ``` In this PR, we address the issue by changing the implementation to convert 1/2 of the int8/uint8 vector lane into a full vector lane of Half/BFloat16. **Test Plan** * AO: `python test/integration/test_integration.py -k test_int8_weight_only_quant_subclass_api` * `python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_convert_int8_to_half_vec` * Due to the CPP backend legalization pass, we are unable to create a unit test to simulate the conversion from `Half` to `int8`. Instead, we rely on a C++ test case. 
* `./build/bin/vec_test_all_types_AVX512 --gtest_filter="VecConvertTestsReducedFloat/*.ConvertReduced"` * `./build/bin/vec_test_all_types_AVX2 --gtest_filter="VecConvertTestsReducedFloat/*.ConvertReduced"` Pull Request resolved: https://github.com/pytorch/pytorch/pull/136353 Approved by: https://github.com/jgong5, https://github.com/jerryzh168 --- aten/src/ATen/cpu/vec/vec256/vec256_convert.h | 42 ++++++++++++++++- aten/src/ATen/cpu/vec/vec512/vec512_convert.h | 39 +++++++++++++++- aten/src/ATen/test/vec_test_all_types.cpp | 45 +++++++++++++++++++ test/inductor/test_cpu_repro.py | 20 +++++++++ 4 files changed, 142 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_convert.h b/aten/src/ATen/cpu/vec/vec256/vec256_convert.h index b0f109fc875..06eb7f82346 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_convert.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_convert.h @@ -208,8 +208,27 @@ struct VecConvert< (is_reduced_floating_point_v && is_8bit_integer_v), void>> { static inline VectorizedN apply(const VectorizedN& src) { - VectorizedN tmp_fp32 = VecConvert::apply(src); - return VecConvert::apply(tmp_fp32); + VectorizedN tmp_fp32 = VecConvert::apply(src); + return VecConvert::apply(tmp_fp32); + } +}; + +template +struct VecConvert< + dst_t, + 1, + float, + 2, + typename std::enable_if_t, + void>> { + static inline VectorizedN apply(const VectorizedN& src) { + at::vec::Vectorized vec1 = convert_float_to_int8(src[0]); + at::vec::Vectorized vec2 = convert_float_to_int8(src[1]); + __m128 lane2 = _mm256_castps256_ps128(_mm256_castsi256_ps(vec2)); + __m256 combined = _mm256_insertf128_ps(_mm256_castsi256_ps(vec1), lane2, 1); + // Shuffle [191:128] bit from combined in to [127:64] bit of result + __m256i result = _mm256_permute4x64_epi64(_mm256_castps_si256(combined), 0b11011000); + return at::vec::Vectorized(result); } }; @@ -226,6 +245,25 @@ struct VecConvert< } }; +template +struct VecConvert< + float, + 2, + src_t, + 1, + typename 
std::enable_if_t, + void>> { + static inline VectorizedN apply(const VectorizedN& src) { + // Shuffle [127:64] bit from src[0] in to [191:128] bit of shuffled + __m256i shuffled = _mm256_permute4x64_epi64(src[0], 0b11011000); + __m256i src2 = _mm256_castsi128_si256( + _mm_castps_si128( + _mm256_extractf128_ps(_mm256_castsi256_ps(shuffled), 1) // Extract the second 128-bit lane + ) + ); + return VectorizedN(convert_int8_to_float(src[0]), convert_int8_to_float(src2)); + } +}; template struct VecConvert< diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_convert.h b/aten/src/ATen/cpu/vec/vec512/vec512_convert.h index 78c7045fb30..cfb4ddb1373 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_convert.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_convert.h @@ -209,8 +209,25 @@ struct VecConvert< (is_reduced_floating_point_v && is_8bit_integer_v), void>> { static inline VectorizedN apply(const VectorizedN& src) { - VectorizedN tmp_fp32 = VecConvert::apply(src); - return VecConvert::apply(tmp_fp32); + VectorizedN tmp_fp32 = VecConvert::apply(src); + return VecConvert::apply(tmp_fp32); + } +}; + +template +struct VecConvert< + dst_t, + 1, + float, + 2, + typename std::enable_if_t, + void>> { + static inline VectorizedN apply(const VectorizedN& src) { + at::vec::Vectorized vec1 = convert_float_to_int8(src[0]); + at::vec::Vectorized vec2 = convert_float_to_int8(src[1]); + __m128 lane2 = _mm512_castps512_ps128(_mm512_castsi512_ps(vec2)); + __m512 result = _mm512_insertf32x4(_mm512_castsi512_ps(vec1), lane2, 1); // Insert lane2 into the second 128-bit lane + return at::vec::Vectorized(_mm512_castps_si512(result)); } }; @@ -227,6 +244,24 @@ struct VecConvert< } }; +template +struct VecConvert< + float, + 2, + src_t, + 1, + typename std::enable_if_t, + void>> { + static inline VectorizedN apply(const VectorizedN& src) { + __m512i src2 = _mm512_castsi128_si512( + _mm_castps_si128( + _mm512_extractf32x4_ps(_mm512_castsi512_ps(src[0]), 1) // Extract the second 128-bit lane + ) + ); + 
return VectorizedN(convert_int8_to_float(src[0]), convert_int8_to_float(src2)); + } +}; + template struct VecConvert< float, diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index a2c8da12c44..042bb56d6ff 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -71,6 +71,8 @@ namespace { template class VecConvertTests : public ::testing::Test {}; template + class VecConvertTestsReducedFloat : public ::testing::Test {}; + template class VecMaskTests : public ::testing::Test {}; using RealFloatTestedTypes = ::testing::Types; using FloatTestedTypes = ::testing::Types; @@ -121,6 +123,7 @@ namespace { TYPED_TEST_SUITE(FunctionalTests, RealFloatIntTestedTypes); TYPED_TEST_SUITE(FunctionalTestsReducedFloat, ReducedFloatTestedTypes); TYPED_TEST_SUITE(VecConvertTests, RealFloatIntTestedTypes); + TYPED_TEST_SUITE(VecConvertTestsReducedFloat, ReducedFloatTestedTypes); TYPED_TEST_SUITE(VecMaskTests, RealFloatIntTestedTypes); TYPED_TEST(Memory, UnAlignedLoadStore) { using vec = TypeParam; @@ -1663,6 +1666,48 @@ namespace { TEST_CONVERT_TO(double); #undef TEST_CONVERT_TO } + TYPED_TEST(VecConvertTestsReducedFloat, ConvertReduced) { + using vec = TypeParam; + using src_t = UholdType; + constexpr auto N = vec::size(); + #define TEST_CONVERT_TO(dst_t) \ + do { \ + CACHE_ALIGN src_t x[N]; \ + CACHE_ALIGN dst_t y[N]; \ + CACHE_ALIGN dst_t ref[N]; \ + auto seed = TestSeed(); \ + auto low = std::is_signed_v ? 
src_t(-100.0) : src_t(0); \ + ValueGen generator(low, src_t(100), seed); \ + for (const auto i : c10::irange(N)) { \ + x[i] = generator.get(); \ + } \ + for (const auto i : c10::irange(N)) { \ + ref[i] = static_cast(x[i]); \ + } \ + auto x_vec = vec::loadu(x); \ + auto y_vec = at::vec::convert(x_vec); \ + constexpr int num_dst_elements = \ + std::min(N, at::vec::Vectorized::size()); \ + y_vec.store(y, num_dst_elements); \ + for (const auto i : c10::irange(num_dst_elements)) { \ + ASSERT_EQ(y[i], ref[i]) \ + << "Failure Details:\nTest Seed to reproduce: " << seed \ + << " x[" << i << "]=" << x[i] << " dst_t=" #dst_t; \ + } \ + constexpr int dst_n = N / num_dst_elements; \ + auto y_vec_n = at::vec::convert( \ + at::vec::VectorizedN(x_vec)); \ + y_vec_n.store(y, N); \ + for (const auto i : c10::irange(N)) { \ + ASSERT_EQ(y[i], ref[i]) \ + << "Failure Details:\nTest Seed to reproduce: " << seed \ + << " x[" << i << "]=" << x[i] << " dst_t=" #dst_t; \ + } \ + } while (0) + TEST_CONVERT_TO(int8_t); + TEST_CONVERT_TO(uint8_t); + #undef TEST_CONVERT_TO + } #endif TYPED_TEST(VecMaskTests, MaskedLoad) { using vec = TypeParam; diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index e142b56d02f..6ce94b10b45 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -3734,6 +3734,26 @@ class CPUReproTests(TestCase): # TODO(jgong5): change to 1 with vectorized uint64 load assert metrics.generated_cpp_vec_kernel_count == 0 + def test_convert_int8_to_half_vec(self): + src_dtypes = [torch.int8, torch.uint8] + dst_dtypes = [torch.bfloat16, torch.half] + _simd_lens = [isa._bit_width for isa in cpu_vec_isa.valid_vec_isa_list()] + for src_dtype, dst_dtype, _simd_len in itertools.product( + src_dtypes, dst_dtypes, _simd_lens + ): + + def fn(x): + return x.to(dst_dtype) + + low = 0 if src_dtype == torch.uint8 else -100 + + x = torch.randint(low, 100, (32, 32), dtype=src_dtype) + with config.patch({"cpp.simdlen": _simd_len}): + 
torch._dynamo.reset() + metrics.reset() + self.common(fn, (x,)) + check_metrics_vec_kernel_count(1) + def test_convert_int32_to_int64_vec(self): def fn(x): return x.to(torch.int64)