[Inductor][CPP] Enable vectorized fp8 E5M2 quant dequant (#153365)

**Summary**
This PR enables vectorized codegen in the Inductor CPP backend for `FP8_E5M2`: quantization from `float32` and dequantization back to `float32`.
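
For illustration, a minimal sketch of the pattern this enables, built from the decomposed quant ops the test exercises (shape, scale, and zero point are made up for the example; the ops are assumed to accept `float8_e5m2`, as the updated test implies):

```
import torch

# Quantize float32 -> fp8 e5m2 and dequantize back, using the same
# finfo-derived clamp range as the updated test helper.
qmin = int(torch.finfo(torch.float8_e5m2).min)
qmax = int(torch.finfo(torch.float8_e5m2).max)

def quant_dequant(x, scale=0.01, zero_point=0):
    q = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, scale, zero_point, qmin, qmax, torch.float8_e5m2
    )
    return torch.ops.quantized_decomposed.dequantize_per_tensor(
        q, scale, zero_point, qmin, qmax, torch.float8_e5m2
    )

# Compiling with the Inductor CPP backend is where the new vectorized
# conversions apply on CPU.
compiled = torch.compile(quant_dequant)
out = compiled(torch.randn(1, 32, 16, 16))
```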

**Test Plan**
```
python test/inductor/test_cpu_repro.py -k test_dequant_quant_lowering_fp8_e5m2
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153365
Approved by: https://github.com/jansel, https://github.com/jgong5
ghstack dependencies: #152417, #152418, #153364
Authored by leslie-fang-intel on 2025-05-21 22:33:49 -07:00; committed by PyTorch MergeBot
parent 84b657d0b5
commit 7ba6fb69e6
3 changed files with 35 insertions and 2 deletions


@@ -312,6 +312,28 @@ struct VecConvert<float, 1, Float8_e4m3fn, 1> {
   }
 };

+template <>
+struct VecConvert<Float8_e5m2, 1, float, 1> {
+  static inline VectorizedN<Float8_e5m2, 1> apply(
+      const VectorizedN<float, 1>& src_n) {
+    at::vec::Vectorized<float> src = src_n[0];
+    __m128i res128 = cvtfp32_fp8e5m2(src);
+    return at::vec::Vectorized<Float8_e5m2>(_mm512_castsi128_si512(res128));
+  }
+};
+
+template <>
+struct VecConvert<float, 1, Float8_e5m2, 1> {
+  static inline VectorizedN<float, 1> apply(
+      const VectorizedN<Float8_e5m2, 1>& src_n) {
+    // cvt first 16x8 bits from Float8_e5m2 to float
+    at::vec::Vectorized<Float8_e5m2> src = src_n[0];
+    __m512 result;
+    cvtfp8e5m2_fp32(_mm512_castsi512_si128(src), result);
+    return at::vec::Vectorized<float>(result);
+  }
+};
+
 #endif
 } // namespace CPU_CAPABILITY
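
For reference on the numerics behind the two new `VecConvert` specializations (float32 to `Float8_e5m2` via `cvtfp32_fp8e5m2`, and back via `cvtfp8e5m2_fp32`), the same element-wise conversion can be expressed with plain tensor casts; this is only a sanity-check sketch, not code from the PR:

```
import torch

# fp8 e5m2: 1 sign bit, 5 exponent bits, 2 mantissa bits.
x = torch.tensor([0.1, 1.0, 3.14159, 1024.0], dtype=torch.float32)
q = x.to(torch.float8_e5m2)   # same direction as VecConvert<Float8_e5m2, 1, float, 1>
d = q.to(torch.float32)       # same direction as VecConvert<float, 1, Float8_e5m2, 1>
print(d)  # each value rounded to the nearest representable e5m2 number
```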


@@ -1418,10 +1418,15 @@ class CPUReproTests(TestCase):
         use_quant_list = [False, True]
         use_tensor_overload_list = [False, True]
-        assert dtype in [torch.uint8, torch.int8, torch.float8_e4m3fn]
+        assert dtype in [
+            torch.uint8,
+            torch.int8,
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+        ]
         quant_min = 0 if dtype == torch.uint8 else -128
         quant_max = 255 if dtype == torch.uint8 else 127
-        if dtype == torch.float8_e4m3fn:
+        if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
             quant_min = int(torch.finfo(dtype).min)
             quant_max = int(torch.finfo(dtype).max)
             use_tensor_overload_list = [
@@ -1486,6 +1491,10 @@ class CPUReproTests(TestCase):
     def test_dequant_quant_lowering_fp8_e4m3(self):
         self._test_dequant_quant_lowering_helper(torch.float8_e4m3fn)

+    @requires_vectorization
+    def test_dequant_quant_lowering_fp8_e5m2(self):
+        self._test_dequant_quant_lowering_helper(torch.float8_e5m2)
+
     def _test_dequant_maxpool2d_lowering_helper(self, dtype):
         def fn(x, scale, zero_point, quant_min, quant_max, dtype):
             x = torch.ops.quantized_decomposed.dequantize_per_tensor(
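
The helper now derives the clamp range from the dtype's finite range instead of the uint8/int8 defaults; a quick way to see what those bounds evaluate to (printed at runtime rather than hard-coded here):

```
import torch

for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    print(dtype, int(torch.finfo(dtype).min), int(torch.finfo(dtype).max))
# e5m2 trades a mantissa bit for an extra exponent bit, so its representable
# range is much wider than e4m3fn's.
```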


@@ -155,6 +155,7 @@ VECTORIZABLE_DTYPES: list[torch.dtype] = [
     torch.int32,
     torch.int64,
     torch.float8_e4m3fn,
+    torch.float8_e5m2,
 ]

 MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
@@ -1609,6 +1610,7 @@ class CppVecOverrides(CppOverrides):
             torch.int32,
             torch.int64,
             torch.float8_e4m3fn,
+            torch.float8_e5m2,
         ], f"{__name__} does not support {dtype}"
         assert isinstance(x, CppCSEVariable)
         src_dtype = x.dtype
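
To check that the CPP backend actually vectorizes the fp8 e5m2 round trip once the dtype is in `VECTORIZABLE_DTYPES`, the generated kernel can be inspected with `run_and_get_cpp_code` from `torch._inductor.utils`; the `at::vec` substring check below is a rough heuristic, not the exact assertion used by the test.

```
import torch
from torch._inductor.utils import run_and_get_cpp_code

qmin = int(torch.finfo(torch.float8_e5m2).min)
qmax = int(torch.finfo(torch.float8_e5m2).max)

def fn(x):
    q = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, 0.01, 0, qmin, qmax, torch.float8_e5m2
    )
    return torch.ops.quantized_decomposed.dequantize_per_tensor(
        q, 0.01, 0, qmin, qmax, torch.float8_e5m2
    )

# Grab the generated C++ and look for vectorized conversions.
_, code = run_and_get_cpp_code(torch.compile(fn), torch.randn(2, 32, 8, 8))
print("at::vec" in code)
```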