[Inductor][CPP] Fix int8 cvt half (#136353)

Fix the correctness issue of https://github.com/pytorch/ao/pull/884/. The current implementation for converting between `Half/BFloat16` and `int8/uint8` incorrectly assumes that 1/4 of the int8/uint8 vector lane maps to 1/2 of the Half/BFloat16 vector lane. This assumption leads to accuracy issues after the full bit-width vectorization of the Half data type was introduced. When converting between int8 weights and the half data type, the generated code is as the following:
```
#include "/tmp/torchinductor_leslie/xw/cxww3s7wxrujoyxna7mlcjktid2uu6nntixqwm542xfkd756gl3x.h"
extern "C"  void kernel(const int8_t* in_ptr0,
                       half* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2048L); x0+=static_cast<int64_t>(32L))
        {
            auto tmp0 = at::vec::Vectorized<int8_t>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
            auto tmp1 = at::vec::convert<half>(tmp0);
            tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
        }
    }
}
```

In this PR, we address the issue by changing the implementation to convert 1/2 of the int8/uint8 vector lane into a full vector lane of Half/BFloat16.

**TestPlan**
* AO: `python test/integration/test_integration.py -k test_int8_weight_only_quant_subclass_api`
* `python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_convert_int8_to_half_vec`
* Due to the CPP backend legalization pass, we are unable to create a unit test to simulate the conversion from `Half` to `int8`. Instead, we rely on a C++ test case.
  * `./build/bin/vec_test_all_types_AVX512 --gtest_filter="VecConvertTestsReducedFloat/*.ConvertReduced"`
  * `./build/bin/vec_test_all_types_AVX2 --gtest_filter="VecConvertTestsReducedFloat/*.ConvertReduced"`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136353
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
This commit is contained in:
leslie-fang-intel 2024-09-22 23:01:08 -07:00 committed by PyTorch MergeBot
parent 8225e7706e
commit 3c7edf1ec0
4 changed files with 142 additions and 4 deletions

View File

@ -208,8 +208,27 @@ struct VecConvert<
(is_reduced_floating_point_v<src_t> && is_8bit_integer_v<dst_t>),
void>> {
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<src_t, 1>& src) {
VectorizedN<float, 1> tmp_fp32 = VecConvert<float, 1, src_t, 1>::apply(src);
return VecConvert<dst_t, 1, float, 1>::apply(tmp_fp32);
VectorizedN<float, 2> tmp_fp32 = VecConvert<float, 2, src_t, 1>::apply(src);
return VecConvert<dst_t, 1, float, 2>::apply(tmp_fp32);
}
};
template <typename dst_t>
struct VecConvert<
dst_t,
1,
float,
2,
typename std::enable_if_t<is_8bit_integer_v<dst_t>,
void>> {
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<float, 2>& src) {
at::vec::Vectorized<dst_t> vec1 = convert_float_to_int8<dst_t>(src[0]);
at::vec::Vectorized<dst_t> vec2 = convert_float_to_int8<dst_t>(src[1]);
__m128 lane2 = _mm256_castps256_ps128(_mm256_castsi256_ps(vec2));
__m256 combined = _mm256_insertf128_ps(_mm256_castsi256_ps(vec1), lane2, 1);
// Shuffle [191:128] bit from combined in to [127:64] bit of result
__m256i result = _mm256_permute4x64_epi64(_mm256_castps_si256(combined), 0b11011000);
return at::vec::Vectorized<dst_t>(result);
}
};
@ -226,6 +245,25 @@ struct VecConvert<
}
};
template <typename src_t>
struct VecConvert<
float,
2,
src_t,
1,
typename std::enable_if_t<is_8bit_integer_v<src_t>,
void>> {
static inline VectorizedN<float, 2> apply(const VectorizedN<src_t, 1>& src) {
// Shuffle [127:64] bit from src[0] in to [191:128] bit of shuffled
__m256i shuffled = _mm256_permute4x64_epi64(src[0], 0b11011000);
__m256i src2 = _mm256_castsi128_si256(
_mm_castps_si128(
_mm256_extractf128_ps(_mm256_castsi256_ps(shuffled), 1) // Extract the second 128-bit lane
)
);
return VectorizedN<float, 2>(convert_int8_to_float<src_t>(src[0]), convert_int8_to_float<src_t>(src2));
}
};
template <typename dst_t>
struct VecConvert<

View File

@ -209,8 +209,25 @@ struct VecConvert<
(is_reduced_floating_point_v<src_t> && is_8bit_integer_v<dst_t>),
void>> {
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<src_t, 1>& src) {
VectorizedN<float, 1> tmp_fp32 = VecConvert<float, 1, src_t, 1>::apply(src);
return VecConvert<dst_t, 1, float, 1>::apply(tmp_fp32);
VectorizedN<float, 2> tmp_fp32 = VecConvert<float, 2, src_t, 1>::apply(src);
return VecConvert<dst_t, 1, float, 2>::apply(tmp_fp32);
}
};
template <typename dst_t>
struct VecConvert<
dst_t,
1,
float,
2,
typename std::enable_if_t<is_8bit_integer_v<dst_t>,
void>> {
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<float, 2>& src) {
at::vec::Vectorized<dst_t> vec1 = convert_float_to_int8<dst_t>(src[0]);
at::vec::Vectorized<dst_t> vec2 = convert_float_to_int8<dst_t>(src[1]);
__m128 lane2 = _mm512_castps512_ps128(_mm512_castsi512_ps(vec2));
__m512 result = _mm512_insertf32x4(_mm512_castsi512_ps(vec1), lane2, 1); // Insert lane2 into the second 128-bit lane
return at::vec::Vectorized<dst_t>(_mm512_castps_si512(result));
}
};
@ -227,6 +244,24 @@ struct VecConvert<
}
};
template <typename src_t>
struct VecConvert<
float,
2,
src_t,
1,
typename std::enable_if_t<is_8bit_integer_v<src_t>,
void>> {
static inline VectorizedN<float, 2> apply(const VectorizedN<src_t, 1>& src) {
__m512i src2 = _mm512_castsi128_si512(
_mm_castps_si128(
_mm512_extractf32x4_ps(_mm512_castsi512_ps(src[0]), 1) // Extract the second 128-bit lane
)
);
return VectorizedN<float, 2>(convert_int8_to_float<src_t>(src[0]), convert_int8_to_float<src_t>(src2));
}
};
template <typename src_t>
struct VecConvert<
float,

View File

@ -71,6 +71,8 @@ namespace {
template <typename T>
class VecConvertTests : public ::testing::Test {};
template <typename T>
class VecConvertTestsReducedFloat : public ::testing::Test {};
template <typename T>
class VecMaskTests : public ::testing::Test {};
using RealFloatTestedTypes = ::testing::Types<vfloat, vdouble>;
using FloatTestedTypes = ::testing::Types<vfloat, vdouble, vcomplex, vcomplexDbl>;
@ -121,6 +123,7 @@ namespace {
TYPED_TEST_SUITE(FunctionalTests, RealFloatIntTestedTypes);
TYPED_TEST_SUITE(FunctionalTestsReducedFloat, ReducedFloatTestedTypes);
TYPED_TEST_SUITE(VecConvertTests, RealFloatIntTestedTypes);
TYPED_TEST_SUITE(VecConvertTestsReducedFloat, ReducedFloatTestedTypes);
TYPED_TEST_SUITE(VecMaskTests, RealFloatIntTestedTypes);
TYPED_TEST(Memory, UnAlignedLoadStore) {
using vec = TypeParam;
@ -1663,6 +1666,48 @@ namespace {
TEST_CONVERT_TO(double);
#undef TEST_CONVERT_TO
}
TYPED_TEST(VecConvertTestsReducedFloat, ConvertReduced) {
using vec = TypeParam;
using src_t = UholdType<TypeParam>;
constexpr auto N = vec::size();
#define TEST_CONVERT_TO(dst_t) \
do { \
CACHE_ALIGN src_t x[N]; \
CACHE_ALIGN dst_t y[N]; \
CACHE_ALIGN dst_t ref[N]; \
auto seed = TestSeed(); \
auto low = std::is_signed_v<dst_t> ? src_t(-100.0) : src_t(0); \
ValueGen<src_t> generator(low, src_t(100), seed); \
for (const auto i : c10::irange(N)) { \
x[i] = generator.get(); \
} \
for (const auto i : c10::irange(N)) { \
ref[i] = static_cast<dst_t>(x[i]); \
} \
auto x_vec = vec::loadu(x); \
auto y_vec = at::vec::convert<dst_t>(x_vec); \
constexpr int num_dst_elements = \
std::min(N, at::vec::Vectorized<dst_t>::size()); \
y_vec.store(y, num_dst_elements); \
for (const auto i : c10::irange(num_dst_elements)) { \
ASSERT_EQ(y[i], ref[i]) \
<< "Failure Details:\nTest Seed to reproduce: " << seed \
<< " x[" << i << "]=" << x[i] << " dst_t=" #dst_t; \
} \
constexpr int dst_n = N / num_dst_elements; \
auto y_vec_n = at::vec::convert<dst_t, dst_n, src_t, 1>( \
at::vec::VectorizedN<src_t, 1>(x_vec)); \
y_vec_n.store(y, N); \
for (const auto i : c10::irange(N)) { \
ASSERT_EQ(y[i], ref[i]) \
<< "Failure Details:\nTest Seed to reproduce: " << seed \
<< " x[" << i << "]=" << x[i] << " dst_t=" #dst_t; \
} \
} while (0)
TEST_CONVERT_TO(int8_t);
TEST_CONVERT_TO(uint8_t);
#undef TEST_CONVERT_TO
}
#endif
TYPED_TEST(VecMaskTests, MaskedLoad) {
using vec = TypeParam;

View File

@ -3734,6 +3734,26 @@ class CPUReproTests(TestCase):
# TODO(jgong5): change to 1 with vectorized uint64 load
assert metrics.generated_cpp_vec_kernel_count == 0
def test_convert_int8_to_half_vec(self):
src_dtypes = [torch.int8, torch.uint8]
dst_dtypes = [torch.bfloat16, torch.half]
_simd_lens = [isa._bit_width for isa in cpu_vec_isa.valid_vec_isa_list()]
for src_dtype, dst_dtype, _simd_len in itertools.product(
src_dtypes, dst_dtypes, _simd_lens
):
def fn(x):
return x.to(dst_dtype)
low = 0 if src_dtype == torch.uint8 else -100
x = torch.randint(low, 100, (32, 32), dtype=src_dtype)
with config.patch({"cpp.simdlen": _simd_len}):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x,))
check_metrics_vec_kernel_count(1)
def test_convert_int32_to_int64_vec(self):
def fn(x):
return x.to(torch.int64)