Support fp8 quantization (#123161)

This commit enables float8_e5m2 and float8_e4m3fn dtypes in fx quantization and PT2E.
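As a quick illustration (not part of the commit itself), the quantization bounds that the change derives for the new float8 dtypes via torch.finfo resolve to the following values; these are properties of the dtypes and can be checked directly:

```python
# Sketch only: the quant_min/quant_max that the float8 entries resolve to,
# computed the same way the updated _DTYPE_TO_QVALUE_BOUNDS does.
import torch

for dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
    lo, hi = int(torch.finfo(dtype).min), int(torch.finfo(dtype).max)
    print(dtype, lo, hi)
# torch.float8_e5m2   -57344  57344
# torch.float8_e4m3fn   -448    448
```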

Motivation for using fp8 quantization instead of int8:
- inference works better when it runs in the same datatype the model was trained with,
- fp8 handles outliers better, and outliers are one of the known problems in LLM activations.

The numerical recipe we intend to use this for is fp8 inference (a configuration sketch follows this list):
- bgemms/gemms running in float8_e4m3fn,
- per-tensor quantization/scaling,
- an amax observer for measurement, with input_backoff and weight_backoff.
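A minimal, hedged sketch of what a per-tensor fp8 activation spec could look like in PT2E, mirroring the parametrized test added in this diff. Note that the amax observer and the backoff factors mentioned above are not part of this commit; observer.default_observer stands in here purely as a placeholder:

```python
import torch
from torch.ao.quantization import observer
from torch.ao.quantization.quantizer import QuantizationSpec

# Per-tensor static quantization spec for float8_e4m3fn activations.
# Bounds come from torch.finfo, exactly as in the parametrized test below.
act_fp8_qspec = QuantizationSpec(
    dtype=torch.float8_e4m3fn,
    quant_min=int(torch.finfo(torch.float8_e4m3fn).min),  # -448
    quant_max=int(torch.finfo(torch.float8_e4m3fn).max),  # 448
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=observer.default_observer,
)
```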
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123161
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
This commit is contained in:
Amadeusz Skrzypczak 2024-04-23 13:35:24 +00:00 committed by PyTorch MergeBot
parent f8f6c460cd
commit 107f944f22
6 changed files with 65 additions and 34 deletions

View File

@@ -49,7 +49,11 @@ from torch.testing._internal.common_quantization import (
skipIfNoQNNPACK,
TestHelperModules,
)
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
TemporaryFileName,
)
@skipIfNoQNNPACK
@@ -1175,14 +1179,15 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
self.assertIsNot(observers[0], observers[2])
self.assertIsNot(observers[1], observers[2])
def test_int16(self):
class Int16ActQuantizer(Quantizer):
@parametrize("dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn))
def test_quantization_dtype(self, dtype):
class DtypeActQuantizer(Quantizer):
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
# using int32 to simulate int16
int16_qspec = QuantizationSpec(
dtype=torch.int16,
quant_min=-(2**15),
quant_max=2**15 - 1,
info_fun = torch.iinfo if dtype == torch.int16 else torch.finfo
activate_qspec = QuantizationSpec(
dtype=dtype,
quant_min=int(info_fun(dtype).min),
quant_max=int(info_fun(dtype).max),
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=observer.default_observer,
@@ -1196,10 +1201,10 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
observer_or_fake_quant_ctr=observer.default_weight_observer,
)
quantization_config = QuantizationConfig(
input_activation=int16_qspec,
input_activation=activate_qspec,
weight=int8_qspec,
bias=None,
output_activation=int16_qspec,
output_activation=activate_qspec,
)
OP_TO_ANNOTATOR["conv"](model, quantization_config)
@@ -1214,7 +1219,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
def forward(self, x):
return self.conv(x)
quantizer = Int16ActQuantizer()
quantizer = DtypeActQuantizer()
node_occurrence = {
# one for input of the first conv, one for output for the first conv
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
@@ -1230,7 +1235,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
self._test_quantizer(
M().eval(),
example_inputs,
Int16ActQuantizer(),
quantizer,
node_occurrence,
node_list,
)
@@ -2248,3 +2253,6 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
node_occurrence,
node_list,
)
instantiate_parametrized_tests(TestQuantizePT2E)

View File

@@ -10,12 +10,11 @@ from torch.library import impl, Library
# name is not too long
quantized_decomposed_lib = Library("quantized_decomposed", "DEF")
_DTYPE_TO_QVALUE_BOUNDS = {
torch.uint8: (0, 255),
torch.int8: (-128, 127),
torch.int16: (-(2**15), 2**15 - 1),
torch.int32: (-(2**31), 2**31 - 1),
}
_INTEGER_DTYPES = [torch.uint8, torch.int8, torch.int16, torch.int32]
_FLOAT_DTYPES = [torch.float8_e5m2, torch.float8_e4m3fn]
_DTYPE_TO_QVALUE_BOUNDS = {k : (torch.iinfo(k).min, torch.iinfo(k).max) for k in _INTEGER_DTYPES}
_DTYPE_TO_QVALUE_BOUNDS.update({k : (int(torch.finfo(k).min), int(torch.finfo(k).max)) for k in _FLOAT_DTYPES})
# Helper to check the passed in quant min and max are valid for the dtype
def _quant_min_max_bounds_check(quant_min, quant_max, dtype):

View File

@@ -84,6 +84,18 @@ __all__ = [
"convert_weighted_module",
]
SUPPORTED_QDTYPES = [
torch.quint8,
torch.qint8,
torch.qint32,
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.float8_e5m2,
torch.float8_e4m3fn,
]
_QSCHEME_TO_CHOOSE_QPARAMS_OP = {
torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor,
torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor,
@@ -136,8 +148,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
if hasattr(activation_post_process, "is_dynamic"):
is_dynamic = activation_post_process.is_dynamic # type: ignore[assignment]
if dtype in [torch.quint8, torch.qint8, torch.qint32, torch.uint8, torch.int8, torch.int16, torch.int32] and \
(not is_dynamic):
if dtype in SUPPORTED_QDTYPES and (not is_dynamic):
# TODO: probably should cleanup this condition check, it's hard
# to reason about this if and the following elif
@@ -372,7 +383,7 @@ def _replace_observer_with_quantize_dequantize_node(
if hasattr(activation_post_process, "is_dynamic"):
is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment]
if dtype in [torch.quint8, torch.qint8, torch.qint32] and \
if dtype in [torch.quint8, torch.qint8, torch.qint32, torch.float8_e5m2, torch.float8_e4m3fn] and \
(not is_dynamic):
# TODO: probably should cleanup this condition check, it's hard
# to reason about this if and the following elif
@@ -477,15 +488,7 @@ def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool:
is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment]
return (
(dtype in [
torch.quint8,
torch.qint8,
torch.qint32,
torch.uint8,
torch.int8,
torch.int16,
torch.int32
] and (not is_dynamic)) or # type: ignore[return-value]
(dtype in SUPPORTED_QDTYPES and (not is_dynamic)) or # type: ignore[return-value]
is_dynamic or
dtype == torch.float16
)

View File

@@ -138,7 +138,9 @@ _OBS_DTYPE_LIST = [
torch.uint8,
torch.int8,
torch.int16,
torch.int32
torch.int32,
torch.float8_e5m2,
torch.float8_e4m3fn,
]
_DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float)

View File

@@ -244,6 +244,8 @@ class UniformQuantizationObserverBase(ObserverBase):
torch.uint8,
torch.int16,
torch.int32,
torch.float8_e5m2,
torch.float8_e4m3fn,
)
assert self.dtype in _ALLOWED_DTYPES, f"Default Observer only works for {_ALLOWED_DTYPES} data type"

View File

@@ -151,6 +151,8 @@ def to_underlying_dtype(qdtype):
torch.int8: torch.int8,
torch.int16: torch.int16,
torch.int32: torch.int32,
torch.float8_e5m2: torch.float8_e5m2,
torch.float8_e4m3fn: torch.float8_e4m3fn,
}
assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype)
return DTYPE_MAPPING[qdtype]
@@ -231,7 +233,9 @@ def activation_is_statically_quantized(qconfig):
torch.uint8,
torch.int8,
torch.int16,
torch.int32
torch.int32,
torch.float8_e5m2,
torch.float8_e4m3fn,
]
and (not activation_is_dynamically_quantized(qconfig))
)
@@ -269,7 +273,9 @@ def weight_is_quantized(qconfig):
torch.uint8,
torch.int8,
torch.int16,
torch.int32
torch.int32,
torch.float8_e5m2,
torch.float8_e4m3fn,
]
def weight_is_statically_quantized(qconfig):
@@ -305,7 +311,18 @@ def get_quant_type(qconfig):
assert qconfig is not None
activation = qconfig.activation()
weight = qconfig.weight()
static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32, torch.uint8, torch.int8, torch.int16, torch.int32]
static_dtypes = [
torch.quint8,
torch.qint8,
torch.quint4x2,
torch.qint32,
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.float8_e5m2,
torch.float8_e4m3fn
]
if weight.dtype in static_dtypes:
if hasattr(activation, 'is_dynamic') and activation.is_dynamic:
return QuantType.DYNAMIC