mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor][CPP] Fix int8 cvt half (#136353)
Fix the correctness issue reported in https://github.com/pytorch/ao/pull/884/. The current implementation for converting between `Half/BFloat16` and `int8/uint8` incorrectly assumes that 1/4 of the int8/uint8 vector lane maps to 1/2 of the Half/BFloat16 vector lane. This assumption leads to accuracy issues after the full bit-width vectorization of the Half data type was introduced. When converting between int8 weights and the half data type, the generated code is as follows:
```
#include "/tmp/torchinductor_leslie/xw/cxww3s7wxrujoyxna7mlcjktid2uu6nntixqwm542xfkd756gl3x.h"
extern "C" void kernel(const int8_t* in_ptr0, half* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2048L); x0+=static_cast<int64_t>(32L))
        {
            auto tmp0 = at::vec::Vectorized<int8_t>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
            auto tmp1 = at::vec::convert<half>(tmp0);
            tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
        }
    }
}
```
In this PR, we address the issue by changing the implementation to convert 1/2 of the int8/uint8 vector lane into a full vector lane of Half/BFloat16.

**Test Plan**
* AO: `python test/integration/test_integration.py -k test_int8_weight_only_quant_subclass_api`
* `python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_convert_int8_to_half_vec`
* Due to the CPP backend legalization pass, we are unable to create a unit test to simulate the conversion from `Half` to `int8`. Instead, we rely on a C++ test case.
  * `./build/bin/vec_test_all_types_AVX512 --gtest_filter="VecConvertTestsReducedFloat/*.ConvertReduced"`
  * `./build/bin/vec_test_all_types_AVX2 --gtest_filter="VecConvertTestsReducedFloat/*.ConvertReduced"`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136353
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
This commit is contained in:
parent
8225e7706e
commit
3c7edf1ec0
|
|
@ -208,8 +208,27 @@ struct VecConvert<
|
|||
(is_reduced_floating_point_v<src_t> && is_8bit_integer_v<dst_t>),
|
||||
void>> {
|
||||
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<src_t, 1>& src) {
|
||||
VectorizedN<float, 1> tmp_fp32 = VecConvert<float, 1, src_t, 1>::apply(src);
|
||||
return VecConvert<dst_t, 1, float, 1>::apply(tmp_fp32);
|
||||
VectorizedN<float, 2> tmp_fp32 = VecConvert<float, 2, src_t, 1>::apply(src);
|
||||
return VecConvert<dst_t, 1, float, 2>::apply(tmp_fp32);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename dst_t>
|
||||
struct VecConvert<
|
||||
dst_t,
|
||||
1,
|
||||
float,
|
||||
2,
|
||||
typename std::enable_if_t<is_8bit_integer_v<dst_t>,
|
||||
void>> {
|
||||
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<float, 2>& src) {
|
||||
at::vec::Vectorized<dst_t> vec1 = convert_float_to_int8<dst_t>(src[0]);
|
||||
at::vec::Vectorized<dst_t> vec2 = convert_float_to_int8<dst_t>(src[1]);
|
||||
__m128 lane2 = _mm256_castps256_ps128(_mm256_castsi256_ps(vec2));
|
||||
__m256 combined = _mm256_insertf128_ps(_mm256_castsi256_ps(vec1), lane2, 1);
|
||||
// Shuffle [191:128] bit from combined in to [127:64] bit of result
|
||||
__m256i result = _mm256_permute4x64_epi64(_mm256_castps_si256(combined), 0b11011000);
|
||||
return at::vec::Vectorized<dst_t>(result);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -226,6 +245,25 @@ struct VecConvert<
|
|||
}
|
||||
};
|
||||
|
||||
template <typename src_t>
|
||||
struct VecConvert<
|
||||
float,
|
||||
2,
|
||||
src_t,
|
||||
1,
|
||||
typename std::enable_if_t<is_8bit_integer_v<src_t>,
|
||||
void>> {
|
||||
static inline VectorizedN<float, 2> apply(const VectorizedN<src_t, 1>& src) {
|
||||
// Shuffle [127:64] bit from src[0] in to [191:128] bit of shuffled
|
||||
__m256i shuffled = _mm256_permute4x64_epi64(src[0], 0b11011000);
|
||||
__m256i src2 = _mm256_castsi128_si256(
|
||||
_mm_castps_si128(
|
||||
_mm256_extractf128_ps(_mm256_castsi256_ps(shuffled), 1) // Extract the second 128-bit lane
|
||||
)
|
||||
);
|
||||
return VectorizedN<float, 2>(convert_int8_to_float<src_t>(src[0]), convert_int8_to_float<src_t>(src2));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename dst_t>
|
||||
struct VecConvert<
|
||||
|
|
|
|||
|
|
@ -209,8 +209,25 @@ struct VecConvert<
|
|||
(is_reduced_floating_point_v<src_t> && is_8bit_integer_v<dst_t>),
|
||||
void>> {
|
||||
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<src_t, 1>& src) {
|
||||
VectorizedN<float, 1> tmp_fp32 = VecConvert<float, 1, src_t, 1>::apply(src);
|
||||
return VecConvert<dst_t, 1, float, 1>::apply(tmp_fp32);
|
||||
VectorizedN<float, 2> tmp_fp32 = VecConvert<float, 2, src_t, 1>::apply(src);
|
||||
return VecConvert<dst_t, 1, float, 2>::apply(tmp_fp32);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename dst_t>
|
||||
struct VecConvert<
|
||||
dst_t,
|
||||
1,
|
||||
float,
|
||||
2,
|
||||
typename std::enable_if_t<is_8bit_integer_v<dst_t>,
|
||||
void>> {
|
||||
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<float, 2>& src) {
|
||||
at::vec::Vectorized<dst_t> vec1 = convert_float_to_int8<dst_t>(src[0]);
|
||||
at::vec::Vectorized<dst_t> vec2 = convert_float_to_int8<dst_t>(src[1]);
|
||||
__m128 lane2 = _mm512_castps512_ps128(_mm512_castsi512_ps(vec2));
|
||||
__m512 result = _mm512_insertf32x4(_mm512_castsi512_ps(vec1), lane2, 1); // Insert lane2 into the second 128-bit lane
|
||||
return at::vec::Vectorized<dst_t>(_mm512_castps_si512(result));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -227,6 +244,24 @@ struct VecConvert<
|
|||
}
|
||||
};
|
||||
|
||||
template <typename src_t>
|
||||
struct VecConvert<
|
||||
float,
|
||||
2,
|
||||
src_t,
|
||||
1,
|
||||
typename std::enable_if_t<is_8bit_integer_v<src_t>,
|
||||
void>> {
|
||||
static inline VectorizedN<float, 2> apply(const VectorizedN<src_t, 1>& src) {
|
||||
__m512i src2 = _mm512_castsi128_si512(
|
||||
_mm_castps_si128(
|
||||
_mm512_extractf32x4_ps(_mm512_castsi512_ps(src[0]), 1) // Extract the second 128-bit lane
|
||||
)
|
||||
);
|
||||
return VectorizedN<float, 2>(convert_int8_to_float<src_t>(src[0]), convert_int8_to_float<src_t>(src2));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename src_t>
|
||||
struct VecConvert<
|
||||
float,
|
||||
|
|
|
|||
|
|
@ -71,6 +71,8 @@ namespace {
|
|||
template <typename T>
|
||||
class VecConvertTests : public ::testing::Test {};
|
||||
// Empty GoogleTest fixture for the typed conversion tests whose source vector
// holds a reduced floating point type (registered via TYPED_TEST_SUITE with
// ReducedFloatTestedTypes).
template <typename T>
|
||||
class VecConvertTestsReducedFloat : public ::testing::Test {};
|
||||
template <typename T>
|
||||
class VecMaskTests : public ::testing::Test {};
|
||||
using RealFloatTestedTypes = ::testing::Types<vfloat, vdouble>;
|
||||
using FloatTestedTypes = ::testing::Types<vfloat, vdouble, vcomplex, vcomplexDbl>;
|
||||
|
|
@ -121,6 +123,7 @@ namespace {
|
|||
TYPED_TEST_SUITE(FunctionalTests, RealFloatIntTestedTypes);
|
||||
TYPED_TEST_SUITE(FunctionalTestsReducedFloat, ReducedFloatTestedTypes);
|
||||
TYPED_TEST_SUITE(VecConvertTests, RealFloatIntTestedTypes);
|
||||
// Instantiate the VecConvertTestsReducedFloat tests for every type listed in
// ReducedFloatTestedTypes.
TYPED_TEST_SUITE(VecConvertTestsReducedFloat, ReducedFloatTestedTypes);
|
||||
TYPED_TEST_SUITE(VecMaskTests, RealFloatIntTestedTypes);
|
||||
TYPED_TEST(Memory, UnAlignedLoadStore) {
|
||||
using vec = TypeParam;
|
||||
|
|
@ -1663,6 +1666,48 @@ namespace {
|
|||
TEST_CONVERT_TO(double);
|
||||
#undef TEST_CONVERT_TO
|
||||
}
|
||||
// Checks at::vec::convert from the reduced floating point vector type under
// test (TypeParam) to int8_t and uint8_t.
//
// For each destination type, TEST_CONVERT_TO(dst_t):
//   1. fills x[N] from ValueGen<src_t> constructed with (low, 100, seed),
//      where low is -100 for signed dst_t and 0 otherwise;
//   2. builds the expected values with a per-element static_cast<dst_t>;
//   3. converts with at::vec::convert<dst_t>(x_vec), stores
//      min(N, Vectorized<dst_t>::size()) elements and compares them to ref;
//   4. converts again through the VectorizedN overload
//      convert<dst_t, dst_n, src_t, 1>, stores N elements and compares all N.
// Failures print the seed so the run can be reproduced.
//
// NOTE(review): y is not cleared between steps 3 and 4, so elements already
// written in step 3 would still compare equal if the second store wrote
// nothing - consider resetting y before the second store.
TYPED_TEST(VecConvertTestsReducedFloat, ConvertReduced) {
|
||||
using vec = TypeParam;
|
||||
using src_t = UholdType<TypeParam>;
|
||||
constexpr auto N = vec::size();
|
||||
#define TEST_CONVERT_TO(dst_t) \
|
||||
do { \
|
||||
CACHE_ALIGN src_t x[N]; \
|
||||
CACHE_ALIGN dst_t y[N]; \
|
||||
CACHE_ALIGN dst_t ref[N]; \
|
||||
auto seed = TestSeed(); \
|
||||
auto low = std::is_signed_v<dst_t> ? src_t(-100.0) : src_t(0); \
|
||||
ValueGen<src_t> generator(low, src_t(100), seed); \
|
||||
for (const auto i : c10::irange(N)) { \
|
||||
x[i] = generator.get(); \
|
||||
} \
|
||||
for (const auto i : c10::irange(N)) { \
|
||||
ref[i] = static_cast<dst_t>(x[i]); \
|
||||
} \
|
||||
auto x_vec = vec::loadu(x); \
|
||||
auto y_vec = at::vec::convert<dst_t>(x_vec); \
|
||||
constexpr int num_dst_elements = \
|
||||
std::min(N, at::vec::Vectorized<dst_t>::size()); \
|
||||
y_vec.store(y, num_dst_elements); \
|
||||
for (const auto i : c10::irange(num_dst_elements)) { \
|
||||
ASSERT_EQ(y[i], ref[i]) \
|
||||
<< "Failure Details:\nTest Seed to reproduce: " << seed \
|
||||
<< " x[" << i << "]=" << x[i] << " dst_t=" #dst_t; \
|
||||
} \
|
||||
constexpr int dst_n = N / num_dst_elements; \
|
||||
auto y_vec_n = at::vec::convert<dst_t, dst_n, src_t, 1>( \
|
||||
at::vec::VectorizedN<src_t, 1>(x_vec)); \
|
||||
y_vec_n.store(y, N); \
|
||||
for (const auto i : c10::irange(N)) { \
|
||||
ASSERT_EQ(y[i], ref[i]) \
|
||||
<< "Failure Details:\nTest Seed to reproduce: " << seed \
|
||||
<< " x[" << i << "]=" << x[i] << " dst_t=" #dst_t; \
|
||||
} \
|
||||
} while (0)
|
||||
TEST_CONVERT_TO(int8_t);
|
||||
TEST_CONVERT_TO(uint8_t);
|
||||
#undef TEST_CONVERT_TO
|
||||
}
|
||||
#endif
|
||||
TYPED_TEST(VecMaskTests, MaskedLoad) {
|
||||
using vec = TypeParam;
|
||||
|
|
|
|||
|
|
@ -3734,6 +3734,26 @@ class CPUReproTests(TestCase):
|
|||
# TODO(jgong5): change to 1 with vectorized uint64 load
|
||||
assert metrics.generated_cpp_vec_kernel_count == 0
|
||||
|
||||
def test_convert_int8_to_half_vec(self):
    """Check that int8/uint8 -> bfloat16/half casts produce one vectorized kernel.

    Every combination of source dtype, destination dtype and bit width of
    the ISAs returned by ``cpu_vec_isa.valid_vec_isa_list()`` is compiled
    with ``cpp.simdlen`` patched to that width, compared through
    ``self.common`` and expected to report exactly one vec kernel.
    """
    bit_widths = [isa._bit_width for isa in cpu_vec_isa.valid_vec_isa_list()]

    # Nested loops iterate in the same order as
    # itertools.product(src_dtypes, dst_dtypes, bit_widths).
    for int_dtype in (torch.int8, torch.uint8):
        # uint8 cannot hold negative values, so start its range at 0.
        lower_bound = 0 if int_dtype == torch.uint8 else -100
        for float_dtype in (torch.bfloat16, torch.half):
            for bit_width in bit_widths:

                def fn(x):
                    return x.to(float_dtype)

                x = torch.randint(lower_bound, 100, (32, 32), dtype=int_dtype)
                with config.patch({"cpp.simdlen": bit_width}):
                    # Start each combination from a clean compile state
                    # and zeroed counters.
                    torch._dynamo.reset()
                    metrics.reset()
                    self.common(fn, (x,))
                    check_metrics_vec_kernel_count(1)
|
||||
|
||||
def test_convert_int32_to_int64_vec(self):
|
||||
def fn(x):
|
||||
return x.to(torch.int64)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user