Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Support fp8 quantization (#123161)
This commit enables the float8_e5m2 and float8_e4m3fn dtypes in fx quantization and PT2E.

Motivation for using fp8 quantization instead of int8:
- inference works better when run with the same datatype the model was trained with,
- fp8 handles outliers better, which is one of the problems in LLM activations.

The numerical recipe we want to use this for is fp8 inference:
- bgemms/gemms running in float8_e4m3fn,
- per-tensor quantization/scaling,
- an amax observer for measurement, with input_backoff and weight_backoff.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123161
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
parent: f8f6c460cd
commit: 107f944f22
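For orientation before the diff: the recipe above boils down to building a QuantizationSpec with a float8 dtype and finfo-derived bounds, which is exactly what the parametrized test below does. A minimal sketch, assuming the standard torch.ao.quantization import paths (illustrative only, not part of this commit):

import torch
from torch.ao.quantization import observer
from torch.ao.quantization.quantizer import QuantizationSpec

# Per-tensor, static fp8 activation spec; quant_min/quant_max come from
# torch.finfo of the float8 dtype, mirroring the parametrized test in this diff.
dtype = torch.float8_e4m3fn  # or torch.float8_e5m2
act_qspec = QuantizationSpec(
    dtype=dtype,
    quant_min=int(torch.finfo(dtype).min),  # -448 for float8_e4m3fn
    quant_max=int(torch.finfo(dtype).max),  # 448 for float8_e4m3fn
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=observer.default_observer,
)

Such a spec is then plugged into a QuantizationConfig and a Quantizer, as the conv test below shows.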
@@ -49,7 +49,11 @@ from torch.testing._internal.common_quantization import (
     skipIfNoQNNPACK,
     TestHelperModules,
 )
-from torch.testing._internal.common_utils import TemporaryFileName
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    TemporaryFileName,
+)


 @skipIfNoQNNPACK
@@ -1175,14 +1179,15 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         self.assertIsNot(observers[0], observers[2])
         self.assertIsNot(observers[1], observers[2])

-    def test_int16(self):
-        class Int16ActQuantizer(Quantizer):
+    @parametrize("dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn))
+    def test_quantization_dtype(self, dtype):
+        class DtypeActQuantizer(Quantizer):
             def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
-                # using int32 to simulate int16
-                int16_qspec = QuantizationSpec(
-                    dtype=torch.int16,
-                    quant_min=-(2**15),
-                    quant_max=2**15 - 1,
+                info_fun = torch.iinfo if dtype == torch.int16 else torch.finfo
+                activate_qspec = QuantizationSpec(
+                    dtype=dtype,
+                    quant_min=int(info_fun(dtype).min),
+                    quant_max=int(info_fun(dtype).max),
                     qscheme=torch.per_tensor_affine,
                     is_dynamic=False,
                     observer_or_fake_quant_ctr=observer.default_observer,
@@ -1196,10 +1201,10 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                     observer_or_fake_quant_ctr=observer.default_weight_observer,
                 )
                 quantization_config = QuantizationConfig(
-                    input_activation=int16_qspec,
+                    input_activation=activate_qspec,
                     weight=int8_qspec,
                     bias=None,
-                    output_activation=int16_qspec,
+                    output_activation=activate_qspec,
                 )
                 OP_TO_ANNOTATOR["conv"](model, quantization_config)

@@ -1214,7 +1219,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             def forward(self, x):
                 return self.conv(x)

-        quantizer = Int16ActQuantizer()
+        quantizer = DtypeActQuantizer()
         node_occurrence = {
             # one for input of the first conv, one for output for the first conv
             torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
@@ -1230,7 +1235,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         self._test_quantizer(
             M().eval(),
             example_inputs,
-            Int16ActQuantizer(),
+            quantizer,
             node_occurrence,
             node_list,
         )
@@ -2248,3 +2253,6 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             node_occurrence,
             node_list,
         )
+
+
+instantiate_parametrized_tests(TestQuantizePT2E)
@@ -10,12 +10,11 @@ from torch.library import impl, Library
 # name is not too long
 quantized_decomposed_lib = Library("quantized_decomposed", "DEF")

-_DTYPE_TO_QVALUE_BOUNDS = {
-    torch.uint8: (0, 255),
-    torch.int8: (-128, 127),
-    torch.int16: (-(2**15), 2**15 - 1),
-    torch.int32: (-(2**31), 2**31 - 1),
-}
+_INTEGER_DTYPES = [torch.uint8, torch.int8, torch.int16, torch.int32]
+_FLOAT_DTYPES = [torch.float8_e5m2, torch.float8_e4m3fn]
+
+_DTYPE_TO_QVALUE_BOUNDS = {k : (torch.iinfo(k).min, torch.iinfo(k).max) for k in _INTEGER_DTYPES}
+_DTYPE_TO_QVALUE_BOUNDS.update({k : (int(torch.finfo(k).min), int(torch.finfo(k).max)) for k in _FLOAT_DTYPES})

 # Helper to check the passed in quant min and max are valid for the dtype
 def _quant_min_max_bounds_check(quant_min, quant_max, dtype):
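The comprehensions above reproduce the previous hand-written integer bounds and extend them to the float8 dtypes through torch.finfo. A quick illustrative check of the values they yield (not part of the diff):

import torch

# torch.int8 stays (-128, 127); the float8 dtypes pick up their finfo extremes:
print(int(torch.finfo(torch.float8_e4m3fn).min), int(torch.finfo(torch.float8_e4m3fn).max))  # -448 448
print(int(torch.finfo(torch.float8_e5m2).min), int(torch.finfo(torch.float8_e5m2).max))      # -57344 57344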
@@ -84,6 +84,18 @@ __all__ = [
     "convert_weighted_module",
 ]

+SUPPORTED_QDTYPES = [
+    torch.quint8,
+    torch.qint8,
+    torch.qint32,
+    torch.uint8,
+    torch.int8,
+    torch.int16,
+    torch.int32,
+    torch.float8_e5m2,
+    torch.float8_e4m3fn,
+]
+
 _QSCHEME_TO_CHOOSE_QPARAMS_OP = {
     torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor,
     torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor,
@@ -136,8 +148,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
     if hasattr(activation_post_process, "is_dynamic"):
         is_dynamic = activation_post_process.is_dynamic  # type: ignore[assignment]

-    if dtype in [torch.quint8, torch.qint8, torch.qint32, torch.uint8, torch.int8, torch.int16, torch.int32] and \
-            (not is_dynamic):
+    if dtype in SUPPORTED_QDTYPES and (not is_dynamic):
         # TODO: probably should cleanup this condition check, it's hard
         # to reason about this if and the following elif

@@ -372,7 +383,7 @@ def _replace_observer_with_quantize_dequantize_node(
     if hasattr(activation_post_process, "is_dynamic"):
         is_dynamic = activation_post_process.is_dynamic  # type: ignore[attr-defined, assignment]

-    if dtype in [torch.quint8, torch.qint8, torch.qint32] and \
+    if dtype in [torch.quint8, torch.qint8, torch.qint32, torch.float8_e5m2, torch.float8_e4m3fn] and \
             (not is_dynamic):
         # TODO: probably should cleanup this condition check, it's hard
         # to reason about this if and the following elif
@@ -477,15 +488,7 @@ def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool:
         is_dynamic = activation_post_process.is_dynamic  # type: ignore[attr-defined, assignment]

     return (
-        (dtype in [
-            torch.quint8,
-            torch.qint8,
-            torch.qint32,
-            torch.uint8,
-            torch.int8,
-            torch.int16,
-            torch.int32
-        ] and (not is_dynamic)) or  # type: ignore[return-value]
+        (dtype in SUPPORTED_QDTYPES and (not is_dynamic)) or  # type: ignore[return-value]
         is_dynamic or
         dtype == torch.float16
     )
@@ -138,7 +138,9 @@ _OBS_DTYPE_LIST = [
     torch.uint8,
     torch.int8,
     torch.int16,
-    torch.int32
+    torch.int32,
+    torch.float8_e5m2,
+    torch.float8_e4m3fn,
 ]

 _DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float)
@@ -244,6 +244,8 @@ class UniformQuantizationObserverBase(ObserverBase):
             torch.uint8,
             torch.int16,
             torch.int32,
+            torch.float8_e5m2,
+            torch.float8_e4m3fn,
         )

         assert self.dtype in _ALLOWED_DTYPES, f"Default Observer only works for {_ALLOWED_DTYPES} data type"
@@ -151,6 +151,8 @@ def to_underlying_dtype(qdtype):
         torch.int8: torch.int8,
         torch.int16: torch.int16,
         torch.int32: torch.int32,
+        torch.float8_e5m2: torch.float8_e5m2,
+        torch.float8_e4m3fn: torch.float8_e4m3fn,
     }
     assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype)
     return DTYPE_MAPPING[qdtype]
@@ -231,7 +233,9 @@ def activation_is_statically_quantized(qconfig):
             torch.uint8,
             torch.int8,
             torch.int16,
-            torch.int32
+            torch.int32,
+            torch.float8_e5m2,
+            torch.float8_e4m3fn,
         ]
         and (not activation_is_dynamically_quantized(qconfig))
     )
@@ -269,7 +273,9 @@ def weight_is_quantized(qconfig):
         torch.uint8,
         torch.int8,
         torch.int16,
-        torch.int32
+        torch.int32,
+        torch.float8_e5m2,
+        torch.float8_e4m3fn,
     ]

 def weight_is_statically_quantized(qconfig):
@@ -305,7 +311,18 @@ def get_quant_type(qconfig):
     assert qconfig is not None
     activation = qconfig.activation()
     weight = qconfig.weight()
-    static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32, torch.uint8, torch.int8, torch.int16, torch.int32]
+    static_dtypes = [
+        torch.quint8,
+        torch.qint8,
+        torch.quint4x2,
+        torch.qint32,
+        torch.uint8,
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.float8_e5m2,
+        torch.float8_e4m3fn
+    ]
     if weight.dtype in static_dtypes:
         if hasattr(activation, 'is_dynamic') and activation.is_dynamic:
             return QuantType.DYNAMIC