Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59953

The following modifications were made to the equalization observers due to design changes:
- [InputEqualizationObserver] Replaced `calculate_qparams()` with `calculate_scaled_minmax()` since we will need to return the scaled min/max values to update the following input quantization observer
- [WeightEqualizationObserver] We no longer need a row observer since this will be taken care of by the following weight quantization observer
- [WeightEqualizationObserver] Following the previous comment, we no longer need to calculate the scaled qparam values. Instead, we will use the equalization scale to later scale the weights, and the qparams will be taken care of by the weight quantization observer.

Test Plan: `python test/test_quantization.py TestEqualizeFx.test_input_weight_eq_observer`

Imported from OSS

Reviewed By: supriyar

Differential Revision: D29135332

fbshipit-source-id: be7e468273c8b62fc183b1e1ec50f6bd6d8cf831
136 lines
4.9 KiB
Python
"""
|
|
Utils shared by different modes of quantization (eager/graph)
|
|
"""
|
|
import warnings
|
|
|
|
import torch
|
|
from .quant_type import QuantType, quant_type_to_str
|
|
|
|
def get_combined_dict(default_dict, additional_dict):
    d = default_dict.copy()
    d.update(additional_dict)
    return d

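# A minimal usage sketch of `get_combined_dict` (plain dict semantics: entries in
# `additional_dict` override entries in `default_dict` on key conflicts, and the
# inputs are left untouched). The example keys are made up for illustration:
#
#   >>> get_combined_dict({"linear": 1, "conv": 2}, {"conv": 3})
#   {'linear': 1, 'conv': 3}
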
def is_per_tensor(qscheme):
    return qscheme == torch.per_tensor_affine or \
        qscheme == torch.per_tensor_symmetric

def is_per_channel(qscheme):
    return qscheme in [torch.per_channel_affine,
                       torch.per_channel_affine_float_qparams,
                       torch.per_channel_symmetric]

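# A minimal sketch of the two qscheme predicates above, using qschemes that ship
# with torch; nothing here is backend specific:
#
#   >>> is_per_tensor(torch.per_tensor_affine)
#   True
#   >>> is_per_channel(torch.per_channel_symmetric)
#   True
#   >>> is_per_channel(torch.per_tensor_affine)
#   False
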
def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
    """ Get the observed/quantized custom module class that we need
    to swap `custom_module` to
    Input:
        custom_module: input, can be an instance of either a float or observed custom module
        custom_module_class_mapping: the float to observed or observed to quantized
            custom module class mapping
        qconfig: qconfig configured for the custom module

    Output:
        corresponding observed/quantized custom module class for input custom module instance
    """
    quant_type = get_quant_type(qconfig)
    quant_type_str = quant_type_to_str(quant_type)
    class_mapping = custom_module_class_mapping.get(quant_type_str, {})
    assert type(custom_module) in class_mapping, "did not find corresponding observed " \
        "module class for {} in mapping: {}".format(type(custom_module), class_mapping)
    return class_mapping[type(custom_module)]

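# A minimal sketch of how `get_swapped_custom_module_class` is meant to be called.
# `MyFloatModule` and `MyObservedModule` are hypothetical classes; the outer key of
# the mapping is the quant type string produced by `quant_type_to_str` (e.g. "static"):
#
#   mapping = {"static": {MyFloatModule: MyObservedModule}}
#   observed_cls = get_swapped_custom_module_class(float_instance, mapping, qconfig)
#   # observed_cls is MyObservedModule when `qconfig` resolves to static quantization
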
def activation_dtype(qconfig):
    assert qconfig is not None
    activation = qconfig.activation()
    return activation.dtype

def weight_dtype(qconfig):
    assert qconfig is not None
    weight = qconfig.weight()
    return weight.dtype

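# A minimal sketch of the two dtype helpers above, assuming the stock eager-mode
# `torch.quantization.default_qconfig`, which is typically configured with a
# quint8 activation observer and a qint8 weight observer:
#
#   qconfig = torch.quantization.default_qconfig
#   activation_dtype(qconfig)  # torch.quint8
#   weight_dtype(qconfig)      # torch.qint8
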
def activation_is_statically_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be statically
    quantized or not; this includes quantizing to quint8, qint8 and float16
    """
    return activation_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16]

def activation_is_int8_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be
    quantized to int8 or not; this includes quantizing to quint8 and qint8
    """
    return activation_dtype(qconfig) in [torch.quint8, torch.qint8]

def weight_is_quantized(qconfig):
    """ Given a qconfig, decide if the weight needs to be
    quantized or not
    """
    return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16]

def weight_is_statically_quantized(qconfig):
    """ Given a qconfig, decide if the weight needs to be statically
    quantized or not
    """
    return weight_dtype(qconfig) in [torch.quint8, torch.qint8]

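# A minimal sketch of the four predicates above, assuming the stock eager-mode
# qconfigs: `default_qconfig` (quint8 activation / qint8 weight) and
# `default_dynamic_qconfig` (float activation observed at runtime / qint8 weight):
#
#   from torch.quantization import default_qconfig, default_dynamic_qconfig
#   activation_is_statically_quantized(default_qconfig)          # True
#   activation_is_statically_quantized(default_dynamic_qconfig)  # False
#   activation_is_int8_quantized(default_qconfig)                # True
#   weight_is_quantized(default_dynamic_qconfig)                 # True
#   weight_is_statically_quantized(default_qconfig)              # True
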
def get_qconfig_dtypes(qconfig):
    r""" Returns the dtype tuple for the given qconfig:
    (activation_dtype, weight_dtype, activation_compute_dtype)
    """
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    compute_dtype = activation.compute_dtype if hasattr(activation, 'compute_dtype') else None
    return (activation.dtype, weight.dtype, compute_dtype)

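# A minimal sketch of `get_qconfig_dtypes`. The third element is the activation
# observer's `compute_dtype` when that attribute exists (as it does for the dynamic
# placeholder observer), otherwise None; exact values depend on the qconfig used:
#
#   get_qconfig_dtypes(torch.quantization.default_qconfig)
#   # e.g. (torch.quint8, torch.qint8, None) for the default static qconfig
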
def get_quant_type(qconfig):
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    static_dtypes = [torch.quint8, torch.qint8]
    if weight.dtype in static_dtypes:
        if activation.dtype in static_dtypes:
            return QuantType.STATIC
        elif hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes:
            return QuantType.DYNAMIC
        else:
            return QuantType.WEIGHT_ONLY

    if weight.dtype == torch.float16:
        if activation.dtype == torch.float:
            return QuantType.DYNAMIC
        elif activation.dtype == torch.float16:
            return QuantType.STATIC

    raise Exception("Unrecognized dtype combination in get_quant_type: activation({}), "
                    "weight({})".format(activation.dtype, weight.dtype))

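# A minimal sketch of how `get_quant_type` classifies the stock eager-mode qconfigs
# (assuming their usual dtypes: quint8/qint8 for static, a float activation with an
# int8 compute_dtype for dynamic, and fp16 weights for float16 dynamic):
#
#   from torch.quantization import default_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig
#   get_quant_type(default_qconfig)          # QuantType.STATIC
#   get_quant_type(default_dynamic_qconfig)  # QuantType.DYNAMIC
#   get_quant_type(float16_dynamic_qconfig)  # QuantType.DYNAMIC (fp16 weight, float activation)
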
def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
    """ Checks if the given minimum and maximum values are valid, meaning that
    they exist and the min value is less than or equal to the max value.
    """
    if min_val.numel() == 0 or max_val.numel() == 0:
        warnings.warn(
            "must run observer before calling calculate_qparams. " +
            "Returning default values."
        )
        return False

    if min_val.dim() == 0 or max_val.dim() == 0:
        if min_val == float("inf") and max_val == float("-inf"):
            warnings.warn(
                "must run observer before calling calculate_qparams. " +
                "Returning default values."
            )
            return False

        assert min_val <= max_val, "min {} should be less than or equal to max {}".format(
            min_val, max_val
        )
    else:
        assert torch.all(
            min_val <= max_val
        ), "min {} should be less than or equal to max {}".format(min_val, max_val)

    return True
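
# A minimal sketch of `check_min_max_valid`; the behaviour follows directly from the
# checks above, and the tensors are hand-built stand-ins for observer state:
#
#   check_min_max_valid(torch.tensor(float("inf")), torch.tensor(float("-inf")))  # False, warns
#   check_min_max_valid(torch.tensor(-1.0), torch.tensor(1.0))                    # True
#   check_min_max_valid(torch.tensor([]), torch.tensor([]))                       # False, warns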