pytorch/torch/quantization/utils.py
Angela Yi c0b7c59e55 [quant] Equalization Observer modifications (#59953)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59953

The following modifications were made to the equalization
observers due to design changes:
- [InputEqualizationObserver] Replaced `calculate_qparams()` with
`calculate_scaled_minmax()` since we will need to return the scaled
min/max values to update the following input quantization observer
- [WeightEqualizationObserver] We no longer need a row observer since
this will be taken care of by the following weight quantization observer
- [WeightEqualizationObserver] Following the previous comment, we no
longer need to calculate the scaled qparam values. Instead, we will use
the equalization scale to later scale the weights and the qparams will
be taken care of by the weight quantization observer.

Test Plan:
`python test/test_quantization.py TestEqualizeFx.test_input_weight_eq_observer`

Imported from OSS

Reviewed By: supriyar

Differential Revision: D29135332

fbshipit-source-id: be7e468273c8b62fc183b1e1ec50f6bd6d8cf831
2021-06-16 22:32:30 -07:00

136 lines
4.9 KiB
Python

"""
Utils shared by different modes of quantization (eager/graph)
"""
import warnings
import torch
from .quant_type import QuantType, quant_type_to_str
def get_combined_dict(default_dict, additional_dict):
    """Return a new mapping holding ``default_dict`` overlaid with ``additional_dict``.

    On key collisions the value from ``additional_dict`` wins. Neither input is
    mutated: the result starts as ``default_dict.copy()`` and is then updated.
    """
    merged = default_dict.copy()
    merged.update(additional_dict)
    return merged
def is_per_tensor(qscheme):
    """Return True when ``qscheme`` is ``torch.per_tensor_affine`` or
    ``torch.per_tensor_symmetric``, False otherwise."""
    per_tensor_schemes = (torch.per_tensor_affine, torch.per_tensor_symmetric)
    return qscheme in per_tensor_schemes
def is_per_channel(qscheme):
    """Return True when ``qscheme`` is one of the per-channel schemes:
    ``torch.per_channel_affine``, ``torch.per_channel_affine_float_qparams``
    or ``torch.per_channel_symmetric``."""
    per_channel_schemes = (
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
        torch.per_channel_symmetric,
    )
    return qscheme in per_channel_schemes
def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
    """Look up the class that ``custom_module`` should be swapped to.

    Args:
        custom_module: an instance of either a float or an observed custom module.
        custom_module_class_mapping: mapping keyed by the quant-type string
            (as produced by ``quant_type_to_str``) whose values map a module
            class to its observed (or quantized) replacement class.
        qconfig: the qconfig configured for ``custom_module``; it determines
            which quant-type sub-mapping is consulted.

    Returns:
        The observed/quantized custom module class registered for
        ``type(custom_module)``.

    Raises:
        AssertionError: if no class is registered for ``type(custom_module)``
            under the qconfig's quant type.
    """
    type_key = quant_type_to_str(get_quant_type(qconfig))
    # A missing quant type behaves like an empty sub-mapping, so it falls
    # through to the assertion below.
    per_type_mapping = custom_module_class_mapping.get(type_key, {})
    module_cls = type(custom_module)
    assert module_cls in per_type_mapping, (
        "did not find corresponding observed "
        "module class for {} in mapping: {}"
    ).format(module_cls, per_type_mapping)
    return per_type_mapping[module_cls]
def activation_dtype(qconfig):
    """Return the ``dtype`` of the activation observer built from ``qconfig``.

    ``qconfig.activation`` is called to construct an observer instance, and
    that instance's ``dtype`` attribute is returned.

    Raises:
        AssertionError: if ``qconfig`` is None.
    """
    assert qconfig is not None
    return qconfig.activation().dtype
def weight_dtype(qconfig):
    """Return the ``dtype`` of the weight observer built from ``qconfig``.

    ``qconfig.weight`` is called to construct an observer instance, and that
    instance's ``dtype`` attribute is returned.

    Raises:
        AssertionError: if ``qconfig`` is None.
    """
    assert qconfig is not None
    return qconfig.weight().dtype
def activation_is_statically_quantized(qconfig):
    """Decide whether the activation configured by ``qconfig`` is quantized.

    Returns True when the activation observer's dtype is quint8, qint8 or
    float16, and False for any other dtype.
    """
    dtype = activation_dtype(qconfig)
    return dtype in (torch.quint8, torch.qint8, torch.float16)
def activation_is_int8_quantized(qconfig):
    """Decide whether the activation configured by ``qconfig`` is quantized to
    an 8-bit integer type.

    Returns True for quint8 and qint8 activation observers, False otherwise.
    """
    dtype = activation_dtype(qconfig)
    return dtype in (torch.quint8, torch.qint8)
def weight_is_quantized(qconfig):
    """Decide whether the weight configured by ``qconfig`` is quantized.

    Returns True when the weight observer's dtype is quint8, qint8 or float16,
    and False for any other dtype.
    """
    dtype = weight_dtype(qconfig)
    return dtype in (torch.quint8, torch.qint8, torch.float16)
def weight_is_statically_quantized(qconfig):
    """Decide whether the weight configured by ``qconfig`` is statically
    quantized.

    Returns True for quint8 and qint8 weight observers. Returns False for
    everything else, including float16.
    """
    dtype = weight_dtype(qconfig)
    return dtype in (torch.quint8, torch.qint8)
def get_qconfig_dtypes(qconfig):
    r"""Return the dtype triple described by ``qconfig``.

    The triple is ``(activation_dtype, weight_dtype, activation_compute_dtype)``.
    Observers are instantiated from ``qconfig.activation`` and
    ``qconfig.weight``. The third element is the activation observer's
    ``compute_dtype`` attribute, or None when the observer has no such
    attribute.
    """
    assert qconfig is not None
    act_obs = qconfig.activation()
    wt_obs = qconfig.weight()
    # Only some observers define ``compute_dtype``; fall back to None otherwise.
    act_compute_dtype = getattr(act_obs, 'compute_dtype', None)
    return (act_obs.dtype, wt_obs.dtype, act_compute_dtype)
def get_quant_type(qconfig):
    """Classify ``qconfig`` into a ``QuantType``.

    Observers are instantiated from ``qconfig.activation`` and
    ``qconfig.weight`` and their dtypes are inspected:

    * weight dtype in (quint8, qint8):
        - activation dtype in (quint8, qint8)           -> ``QuantType.STATIC``
        - activation ``compute_dtype`` in (quint8, qint8) -> ``QuantType.DYNAMIC``
        - anything else                                 -> ``QuantType.WEIGHT_ONLY``
    * weight dtype float16:
        - activation dtype float   -> ``QuantType.DYNAMIC``
        - activation dtype float16 -> ``QuantType.STATIC``

    Raises:
        AssertionError: if ``qconfig`` is None.
        Exception: for any other activation/weight dtype combination.
    """
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    static_dtypes = (torch.quint8, torch.qint8)
    if weight.dtype in static_dtypes:
        if activation.dtype in static_dtypes:
            return QuantType.STATIC
        # Not every activation observer defines ``compute_dtype``; a missing
        # attribute is treated as "not int8".
        elif getattr(activation, 'compute_dtype', None) in static_dtypes:
            return QuantType.DYNAMIC
        else:
            return QuantType.WEIGHT_ONLY
    if weight.dtype == torch.float16:
        if activation.dtype == torch.float:
            return QuantType.DYNAMIC
        elif activation.dtype == torch.float16:
            return QuantType.STATIC
    # Previously the two adjacent literals had no separating space, which
    # produced "activation(...),weight(...)".
    raise Exception("Unrecognized dtype combination in get_quant_type: activation({}), "
                    "weight({})".format(activation.dtype, weight.dtype))
def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
    """ Checks if the given minimum and maximum values are valid, meaning that
    they exist and the min value is less than or equal to the max value.

    Returns:
        False, after emitting a warning, when either tensor is empty, or when
        0-dim inputs hold ``min_val == +inf`` and ``max_val == -inf``
        (presumably an observer's untouched initial values — confirm against
        the calling observer). True otherwise.

    Raises:
        AssertionError: if ``min_val`` is greater than ``max_val``. For
            non-scalar tensors the comparison is element-wise and every element
            must satisfy ``min <= max``.
    """
    not_run_msg = (
        "must run observer before calling calculate_qparams. " +
        "Returning default values."
    )

    if min_val.numel() == 0 or max_val.numel() == 0:
        warnings.warn(not_run_msg)
        return False

    if min_val.dim() == 0 or max_val.dim() == 0:
        if min_val == float("inf") and max_val == float("-inf"):
            warnings.warn(not_run_msg)
            return False
        # Equality is valid (e.g. a constant input), so the message says
        # "less than or equal to" to match the ``<=`` check.
        assert min_val <= max_val, "min {} should be less than or equal to max {}".format(
            min_val, max_val
        )
    else:
        assert torch.all(
            min_val <= max_val
        ), "min {} should be less than or equal to max {}".format(min_val, max_val)
    return True