diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_convert.h b/aten/src/ATen/cpu/vec/vec512/vec512_convert.h
index c7f15c5ab1a..a4adc222fa1 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_convert.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_convert.h
@@ -312,6 +312,28 @@ struct VecConvert
   }
 };
 
+template <>
+struct VecConvert<Float8_e5m2, 1, float, 1> {
+  static inline VectorizedN<Float8_e5m2, 1> apply(
+      const VectorizedN<float, 1>& src_n) {
+    at::vec::Vectorized<float> src = src_n[0];
+    __m128i res128 = cvtfp32_fp8e5m2(src);
+    return at::vec::Vectorized<Float8_e5m2>(_mm512_castsi128_si512(res128));
+  }
+};
+
+template <>
+struct VecConvert<float, 1, Float8_e5m2, 1> {
+  static inline VectorizedN<float, 1> apply(
+      const VectorizedN<Float8_e5m2, 1>& src_n) {
+    // cvt first 16x8 bits from Float8_e5m2 to float
+    at::vec::Vectorized<Float8_e5m2> src = src_n[0];
+    __m512 result;
+    cvtfp8e5m2_fp32(_mm512_castsi512_si128(src), result);
+    return at::vec::Vectorized<float>(result);
+  }
+};
+
 #endif
 
 } // namespace CPU_CAPABILITY
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index 9bb3b84f2e4..11a6c8739bf 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -1418,10 +1418,15 @@ class CPUReproTests(TestCase):
         use_quant_list = [False, True]
         use_tensor_overload_list = [False, True]
 
-        assert dtype in [torch.uint8, torch.int8, torch.float8_e4m3fn]
+        assert dtype in [
+            torch.uint8,
+            torch.int8,
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+        ]
         quant_min = 0 if dtype == torch.uint8 else -128
         quant_max = 255 if dtype == torch.uint8 else 127
-        if dtype == torch.float8_e4m3fn:
+        if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
             quant_min = int(torch.finfo(dtype).min)
             quant_max = int(torch.finfo(dtype).max)
             use_tensor_overload_list = [
@@ -1486,6 +1491,10 @@ class CPUReproTests(TestCase):
     def test_dequant_quant_lowering_fp8_e4m3(self):
         self._test_dequant_quant_lowering_helper(torch.float8_e4m3fn)
 
+    @requires_vectorization
+    def test_dequant_quant_lowering_fp8_e5m2(self):
+        self._test_dequant_quant_lowering_helper(torch.float8_e5m2)
+
     def _test_dequant_maxpool2d_lowering_helper(self, dtype):
         def fn(x, scale, zero_point, quant_min, quant_max, dtype):
             x = torch.ops.quantized_decomposed.dequantize_per_tensor(
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 95ce24369f4..ee6a728aabe 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -155,6 +155,7 @@ VECTORIZABLE_DTYPES: list[torch.dtype] = [
     torch.int32,
     torch.int64,
     torch.float8_e4m3fn,
+    torch.float8_e5m2,
 ]
 
 MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
@@ -1609,6 +1610,7 @@ class CppVecOverrides(CppOverrides):
             torch.int32,
             torch.int64,
             torch.float8_e4m3fn,
+            torch.float8_e5m2,
         ], f"{__name__} does not support {dtype}"
         assert isinstance(x, CppCSEVariable)
         src_dtype = x.dtype