mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor][CPP] Enable vectorized fp8 E5M2 quant dequant (#153365)
**Summary** This PR enables the vectorization codegen with Inductor CPP backend for `FP8_E5M2` `quant` from `float32` and `dequant` to `float32`. **Test Plan** ``` python test/inductor/test_cpu_repro.py -k test_dequant_quant_lowering_fp8_e5m2 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/153365 Approved by: https://github.com/jansel, https://github.com/jgong5 ghstack dependencies: #152417, #152418, #153364
This commit is contained in:
parent
84b657d0b5
commit
7ba6fb69e6
|
|
@ -312,6 +312,28 @@ struct VecConvert<float, 1, Float8_e4m3fn, 1> {
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct VecConvert<Float8_e5m2, 1, float, 1> {
|
||||
static inline VectorizedN<Float8_e5m2, 1> apply(
|
||||
const VectorizedN<float, 1>& src_n) {
|
||||
at::vec::Vectorized<float> src = src_n[0];
|
||||
__m128i res128 = cvtfp32_fp8e5m2(src);
|
||||
return at::vec::Vectorized<Float8_e5m2>(_mm512_castsi128_si512(res128));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct VecConvert<float, 1, Float8_e5m2, 1> {
|
||||
static inline VectorizedN<float, 1> apply(
|
||||
const VectorizedN<Float8_e5m2, 1>& src_n) {
|
||||
// cvt first 16x8 bits from Float8_e5m2 to float
|
||||
at::vec::Vectorized<Float8_e5m2> src = src_n[0];
|
||||
__m512 result;
|
||||
cvtfp8e5m2_fp32(_mm512_castsi512_si128(src), result);
|
||||
return at::vec::Vectorized<float>(result);
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
|
|
|
|||
|
|
@ -1418,10 +1418,15 @@ class CPUReproTests(TestCase):
|
|||
use_quant_list = [False, True]
|
||||
use_tensor_overload_list = [False, True]
|
||||
|
||||
assert dtype in [torch.uint8, torch.int8, torch.float8_e4m3fn]
|
||||
assert dtype in [
|
||||
torch.uint8,
|
||||
torch.int8,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
]
|
||||
quant_min = 0 if dtype == torch.uint8 else -128
|
||||
quant_max = 255 if dtype == torch.uint8 else 127
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
quant_min = int(torch.finfo(dtype).min)
|
||||
quant_max = int(torch.finfo(dtype).max)
|
||||
use_tensor_overload_list = [
|
||||
|
|
@ -1486,6 +1491,10 @@ class CPUReproTests(TestCase):
|
|||
def test_dequant_quant_lowering_fp8_e4m3(self):
|
||||
self._test_dequant_quant_lowering_helper(torch.float8_e4m3fn)
|
||||
|
||||
@requires_vectorization
|
||||
def test_dequant_quant_lowering_fp8_e5m2(self):
|
||||
self._test_dequant_quant_lowering_helper(torch.float8_e5m2)
|
||||
|
||||
def _test_dequant_maxpool2d_lowering_helper(self, dtype):
|
||||
def fn(x, scale, zero_point, quant_min, quant_max, dtype):
|
||||
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
||||
|
|
|
|||
|
|
@ -155,6 +155,7 @@ VECTORIZABLE_DTYPES: list[torch.dtype] = [
|
|||
torch.int32,
|
||||
torch.int64,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
]
|
||||
|
||||
MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
|
||||
|
|
@ -1609,6 +1610,7 @@ class CppVecOverrides(CppOverrides):
|
|||
torch.int32,
|
||||
torch.int64,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
], f"{__name__} does not support {dtype}"
|
||||
assert isinstance(x, CppCSEVariable)
|
||||
src_dtype = x.dtype
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user