mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor][CPP] Enable vectorized fp8 E5M2 quant dequant (#153365)
**Summary** This PR enables the vectorization codegen with Inductor CPP backend for `FP8_E5M2` `quant` from `float32` and `dequant` to `float32`. **Test Plan** ``` python test/inductor/test_cpu_repro.py -k test_dequant_quant_lowering_fp8_e5m2 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/153365 Approved by: https://github.com/jansel, https://github.com/jgong5 ghstack dependencies: #152417, #152418, #153364
This commit is contained in:
parent
84b657d0b5
commit
7ba6fb69e6
|
|
@ -312,6 +312,28 @@ struct VecConvert<float, 1, Float8_e4m3fn, 1> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct VecConvert<Float8_e5m2, 1, float, 1> {
|
||||||
|
static inline VectorizedN<Float8_e5m2, 1> apply(
|
||||||
|
const VectorizedN<float, 1>& src_n) {
|
||||||
|
at::vec::Vectorized<float> src = src_n[0];
|
||||||
|
__m128i res128 = cvtfp32_fp8e5m2(src);
|
||||||
|
return at::vec::Vectorized<Float8_e5m2>(_mm512_castsi128_si512(res128));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct VecConvert<float, 1, Float8_e5m2, 1> {
|
||||||
|
static inline VectorizedN<float, 1> apply(
|
||||||
|
const VectorizedN<Float8_e5m2, 1>& src_n) {
|
||||||
|
// cvt first 16x8 bits from Float8_e5m2 to float
|
||||||
|
at::vec::Vectorized<Float8_e5m2> src = src_n[0];
|
||||||
|
__m512 result;
|
||||||
|
cvtfp8e5m2_fp32(_mm512_castsi512_si128(src), result);
|
||||||
|
return at::vec::Vectorized<float>(result);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
} // namespace CPU_CAPABILITY
|
} // namespace CPU_CAPABILITY
|
||||||
|
|
|
||||||
|
|
@ -1418,10 +1418,15 @@ class CPUReproTests(TestCase):
|
||||||
use_quant_list = [False, True]
|
use_quant_list = [False, True]
|
||||||
use_tensor_overload_list = [False, True]
|
use_tensor_overload_list = [False, True]
|
||||||
|
|
||||||
assert dtype in [torch.uint8, torch.int8, torch.float8_e4m3fn]
|
assert dtype in [
|
||||||
|
torch.uint8,
|
||||||
|
torch.int8,
|
||||||
|
torch.float8_e4m3fn,
|
||||||
|
torch.float8_e5m2,
|
||||||
|
]
|
||||||
quant_min = 0 if dtype == torch.uint8 else -128
|
quant_min = 0 if dtype == torch.uint8 else -128
|
||||||
quant_max = 255 if dtype == torch.uint8 else 127
|
quant_max = 255 if dtype == torch.uint8 else 127
|
||||||
if dtype == torch.float8_e4m3fn:
|
if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||||
quant_min = int(torch.finfo(dtype).min)
|
quant_min = int(torch.finfo(dtype).min)
|
||||||
quant_max = int(torch.finfo(dtype).max)
|
quant_max = int(torch.finfo(dtype).max)
|
||||||
use_tensor_overload_list = [
|
use_tensor_overload_list = [
|
||||||
|
|
@ -1486,6 +1491,10 @@ class CPUReproTests(TestCase):
|
||||||
def test_dequant_quant_lowering_fp8_e4m3(self):
|
def test_dequant_quant_lowering_fp8_e4m3(self):
|
||||||
self._test_dequant_quant_lowering_helper(torch.float8_e4m3fn)
|
self._test_dequant_quant_lowering_helper(torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
@requires_vectorization
|
||||||
|
def test_dequant_quant_lowering_fp8_e5m2(self):
|
||||||
|
self._test_dequant_quant_lowering_helper(torch.float8_e5m2)
|
||||||
|
|
||||||
def _test_dequant_maxpool2d_lowering_helper(self, dtype):
|
def _test_dequant_maxpool2d_lowering_helper(self, dtype):
|
||||||
def fn(x, scale, zero_point, quant_min, quant_max, dtype):
|
def fn(x, scale, zero_point, quant_min, quant_max, dtype):
|
||||||
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
||||||
|
|
|
||||||
|
|
@ -155,6 +155,7 @@ VECTORIZABLE_DTYPES: list[torch.dtype] = [
|
||||||
torch.int32,
|
torch.int32,
|
||||||
torch.int64,
|
torch.int64,
|
||||||
torch.float8_e4m3fn,
|
torch.float8_e4m3fn,
|
||||||
|
torch.float8_e5m2,
|
||||||
]
|
]
|
||||||
|
|
||||||
MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
|
MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
|
||||||
|
|
@ -1609,6 +1610,7 @@ class CppVecOverrides(CppOverrides):
|
||||||
torch.int32,
|
torch.int32,
|
||||||
torch.int64,
|
torch.int64,
|
||||||
torch.float8_e4m3fn,
|
torch.float8_e4m3fn,
|
||||||
|
torch.float8_e5m2,
|
||||||
], f"{__name__} does not support {dtype}"
|
], f"{__name__} does not support {dtype}"
|
||||||
assert isinstance(x, CppCSEVariable)
|
assert isinstance(x, CppCSEVariable)
|
||||||
src_dtype = x.dtype
|
src_dtype = x.dtype
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user