"""
|
|
Utils shared by different modes of quantization (eager/graph)
|
|
"""
|
|
import warnings
|
|
import functools
|
|
import torch
|
|
from torch.ao.quantization.quant_type import QuantType, quant_type_to_str
|
|
from typing import Tuple, Any, Union, Callable
|
|
|
|
# Type for fusion patterns, it can be more complicated than the following actually,
# see pattern.md for docs
# TODO: not sure if typing supports recursive data types
Pattern = Union[Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any]

# TODO: maybe rename this to MatchInputNode
class MatchAllNode:
    """ A node pattern that matches all nodes, used in defining
    fusion patterns in FX Graph Mode Quantization
    """
    pass

module_type_list = {
    torch.nn.ReLU,
    torch.nn.ReLU6,
    torch.nn.AdaptiveAvgPool1d,
    torch.nn.AdaptiveAvgPool2d,
    torch.nn.AdaptiveAvgPool3d,
    torch.nn.AvgPool1d,
    torch.nn.AvgPool2d,
    torch.nn.AvgPool3d,
    torch.nn.MaxPool1d,
    torch.nn.MaxPool2d,
    torch.nn.MaxPool3d,
    torch.nn.Identity,
    torch.nn.Hardsigmoid,
    torch.nn.Sigmoid,
    torch.nn.Tanh,
}
func_list = {
    torch.nn.functional.adaptive_avg_pool1d,
    torch.nn.functional.adaptive_avg_pool2d,
    torch.nn.functional.adaptive_avg_pool3d,
    torch.nn.functional.elu,
    torch.nn.functional.hardswish,
    torch.nn.functional.instance_norm,
    torch.nn.functional.layer_norm,
    torch.nn.functional.leaky_relu,
    torch.nn.functional.silu,
    torch.nn.functional.mish,
    torch.nn.functional.dropout,
    torch.nn.functional.max_pool1d,
    torch.nn.functional.max_pool2d,
    torch.nn.functional.max_pool3d,
    torch.nn.functional.relu,
    torch.nn.functional.hardtanh,
    torch.nn.functional.hardtanh_,
    torch.nn.functional.hardsigmoid,
    torch.nn.functional.sigmoid,
    torch.transpose,
    torch.repeat_interleave,
    torch.sigmoid,
    torch.squeeze,
    torch.stack,
    torch.sum,
    torch.tanh,
    torch.unsqueeze,
    torch.cat,
}
method_list = {
    torch.mean,
    'relu',
    'relu_',
    'contiguous',
    'detach',
    'detach_',
    'hardsigmoid',
    'hardsigmoid_',
    'permute',
    'repeat',
    'repeat_interleave',
    'reshape',
    'resize_',
    'shape',
    'sigmoid',
    'sigmoid_',
    'size',
    'squeeze',
    'squeeze_',
    'tanh',
    'tanh_',
    'transpose',
    'unsqueeze',
    'unsqueeze_',
    'view',
}

def check_node(node, modules):
    # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py
    is_call_function = node.op == "call_function" and node.target in func_list
    is_call_method = node.op == "call_method" and node.target in method_list
    is_call_module = node.op == "call_module" and type(modules[str(node.target)]) in module_type_list
    return is_call_function, is_call_method, is_call_module

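# Illustrative usage sketch (added for documentation, not part of the original module):
# classify the nodes of a small FX-traced module. The module `M` below is hypothetical
# and only for demonstration; output is shown approximately.
#
#     >>> import torch.fx
#     >>> class M(torch.nn.Module):
#     ...     def __init__(self):
#     ...         super().__init__()
#     ...         self.relu = torch.nn.ReLU()
#     ...     def forward(self, x):
#     ...         return self.relu(x).sigmoid()
#     >>> gm = torch.fx.symbolic_trace(M())
#     >>> modules = dict(gm.named_modules())
#     >>> for node in gm.graph.nodes:
#     ...     if node.op in ("call_function", "call_method", "call_module"):
#     ...         print(node.target, check_node(node, modules))
#     relu (False, False, True)
#     sigmoid (False, True, False)
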
def get_combined_dict(default_dict, additional_dict):
    d = default_dict.copy()
    d.update(additional_dict)
    return d

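# Illustrative example (added for documentation, not part of the original module):
# entries in `additional_dict` override entries in `default_dict`.
#
#     >>> get_combined_dict({"a": 1, "b": 2}, {"b": 20, "c": 3})
#     {'a': 1, 'b': 20, 'c': 3}
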
def is_per_tensor(qscheme):
    return qscheme == torch.per_tensor_affine or \
        qscheme == torch.per_tensor_symmetric

def is_per_channel(qscheme):
    return qscheme in [torch.per_channel_affine,
                       torch.per_channel_affine_float_qparams,
                       torch.per_channel_symmetric]

def getattr_from_fqn(obj: Any, fqn: str) -> Any:
    """
    Given an obj and a fqn such as "foo.bar.baz", returns obj.foo.bar.baz.
    """
    return functools.reduce(getattr, fqn.split("."), obj)

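# Illustrative example (added for documentation, not part of the original module):
# fetch a nested attribute by its fully qualified name. The `torch.nn.Sequential`
# instance `m` below is hypothetical.
#
#     >>> m = torch.nn.Sequential(torch.nn.Sequential(torch.nn.Linear(2, 2)))
#     >>> getattr_from_fqn(m, "0.0.weight").shape
#     torch.Size([2, 2])
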
def get_qparam_dict(observer_or_fake_quant):
    qscheme = observer_or_fake_quant.qscheme if hasattr(observer_or_fake_quant, "qscheme") else None
    dtype = observer_or_fake_quant.dtype
    qparams = {"qscheme": qscheme, "dtype": dtype}

    if not qscheme:
        return qparams

    if is_per_tensor(qscheme):
        qscheme = torch.per_tensor_affine
    elif is_per_channel(qscheme):
        # change symmetric to affine since we do not have symmetric
        # quantized Tensor
        if qscheme == torch.per_channel_symmetric:
            qscheme = torch.per_channel_affine
        qparams["axis"] = observer_or_fake_quant.ch_axis
    else:
        raise RuntimeError(f"Unrecognized qscheme: {qscheme}")
    # update qscheme, since we don't have symmetric quant qscheme
    # in quantized Tensor
    qparams["qscheme"] = qscheme

    scale, zero_point = observer_or_fake_quant.calculate_qparams()
    qparams["scale"] = scale
    qparams["zero_point"] = zero_point

    return qparams

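# Illustrative example (added for documentation, not part of the original module),
# assuming a default per-tensor MinMaxObserver that has already seen some data:
#
#     >>> from torch.ao.quantization.observer import MinMaxObserver
#     >>> obs = MinMaxObserver()
#     >>> _ = obs(torch.randn(4, 4))
#     >>> qparams = get_qparam_dict(obs)
#     >>> sorted(qparams.keys())
#     ['dtype', 'qscheme', 'scale', 'zero_point']
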
def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
    """ Get the observed/quantized custom module class that we need
    to swap `custom_module` to
    Input:
        custom_module: input, can be an instance of either a float or observed custom module
        custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping
        qconfig: qconfig configured for the custom module

    Output:
        corresponding observed/quantized custom module class for input custom module instance
    """
    quant_type = get_quant_type(qconfig)
    quant_type_str = quant_type_to_str(quant_type)
    class_mapping = custom_module_class_mapping.get(quant_type_str, {})
    assert type(custom_module) in class_mapping, "did not find corresponding observed " \
        "module class for {} in mapping: {}".format(type(custom_module), class_mapping)
    return class_mapping[type(custom_module)]

def activation_dtype(qconfig):
    assert qconfig is not None
    activation = qconfig.activation()
    return activation.dtype

def weight_dtype(qconfig):
    assert qconfig is not None
    weight = qconfig.weight()
    return weight.dtype

def activation_is_statically_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be
    statically quantized or not; this includes quantizing to quint8, qint8 and float16
    """
    return activation_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16]

def activation_is_dynamically_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be
    dynamically quantized or not; this includes dynamically quantizing to
    quint8, qint8 and float16
    """
    activation_dtype, _, activation_compute_dtype = \
        get_qconfig_dtypes(qconfig)
    return activation_dtype == torch.float and \
        activation_compute_dtype in [torch.quint8, torch.qint8, torch.float16]

def activation_is_int8_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be
    quantized to int8 or not; this includes quantizing to quint8 and qint8
    """
    return activation_dtype(qconfig) in [torch.quint8, torch.qint8]

def activation_is_int32_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be
    quantized to int32 or not
    """
    return activation_dtype(qconfig) == torch.qint32

def weight_is_quantized(qconfig):
    """ Given a qconfig, decide if the weight needs to be
    quantized or not
    """
    return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16, torch.quint4x2]

def weight_is_statically_quantized(qconfig):
    """ Given a qconfig, decide if the weight needs to be statically
    quantized or not
    """
    return weight_dtype(qconfig) in [torch.quint8, torch.qint8]

def op_is_int8_dynamically_quantized(qconfig) -> bool:
    """ Given a qconfig, returns True if this op is using int8 dynamic
    quantization
    """
    activation_dtype, weight_dtype, activation_compute_dtype = \
        get_qconfig_dtypes(qconfig)
    return (
        activation_dtype is torch.float and
        # for now, the lines below assume fbgemm or qnnpack
        weight_dtype is torch.qint8 and
        activation_compute_dtype is torch.quint8
    )

def get_qconfig_dtypes(qconfig):
    r""" Returns the dtype tuple for the given qconfig:
    (activation_dtype, weight_dtype, activation_compute_dtype)
    """
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    compute_dtype = activation.compute_dtype if hasattr(activation, 'compute_dtype') else None
    return (activation.dtype, weight.dtype, compute_dtype)

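# Illustrative example (added for documentation, not part of the original module),
# assuming the stock dynamic qconfig from torch.ao.quantization:
#
#     >>> from torch.ao.quantization.qconfig import default_dynamic_qconfig
#     >>> get_qconfig_dtypes(default_dynamic_qconfig)
#     (torch.float32, torch.qint8, torch.quint8)
#     >>> activation_is_dynamically_quantized(default_dynamic_qconfig)
#     True
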
def get_quant_type(qconfig):
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2]
    if weight.dtype in static_dtypes:
        if activation.dtype in static_dtypes:
            return QuantType.STATIC
        elif hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes:
            return QuantType.DYNAMIC
        else:
            return QuantType.WEIGHT_ONLY

    if weight.dtype == torch.float16:
        if activation.dtype == torch.float:
            return QuantType.DYNAMIC
        elif activation.dtype == torch.float16:
            return QuantType.STATIC

    raise Exception("Unrecognized dtype combination in get_quant_type: activation({}), "
                    "weight({})".format(activation.dtype, weight.dtype))

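# Illustrative example (added for documentation, not part of the original module),
# assuming the stock static and dynamic qconfigs from torch.ao.quantization:
#
#     >>> from torch.ao.quantization.qconfig import default_qconfig, default_dynamic_qconfig
#     >>> get_quant_type(default_qconfig) == QuantType.STATIC
#     True
#     >>> get_quant_type(default_dynamic_qconfig) == QuantType.DYNAMIC
#     True
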
def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
    """ Checks if the given minimum and maximum values are valid, meaning that
    they exist and the min value is less than the max value.
    """
    if min_val.numel() == 0 or max_val.numel() == 0:
        warnings.warn(
            "must run observer before calling calculate_qparams. " +
            "Returning default values."
        )
        return False

    if min_val.dim() == 0 or max_val.dim() == 0:
        if min_val == float("inf") and max_val == float("-inf"):
            warnings.warn(
                "must run observer before calling calculate_qparams. " +
                "Returning default values."
            )

            return False

        assert min_val <= max_val, "min {} should be less than max {}".format(
            min_val, max_val
        )
    else:
        assert torch.all(
            min_val <= max_val
        ), "min {} should be less than max {}".format(min_val, max_val)

    return True

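# Illustrative example (added for documentation, not part of the original module):
# an observer that has never seen data still holds the default inf/-inf bounds,
# which this helper rejects (with a warning) instead of computing bogus qparams.
#
#     >>> check_min_max_valid(torch.tensor(float("inf")), torch.tensor(float("-inf")))
#     False
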
def calculate_qmin_qmax(quant_min: int, quant_max: int, has_customized_qrange: bool, dtype: torch.dtype,
                        reduce_range: bool) -> Tuple[int, int]:
    r"""Calculates actual qmin and qmax based on the quantization range,
    observer datatype and if range is reduced.
    """
    if has_customized_qrange:
        # This initialization is to resolve TorchScript compilation issues and allow
        # using refinement to decouple initial_qmin and initial_qmax from the quantization range.
        # The actual values of initial_qmin and initial_qmax will be reset below.
        if dtype == torch.qint32:
            initial_quant_min, initial_quant_max = 0, 2**31 - 1
        else:
            initial_quant_min, initial_quant_max = 0, 255
        # The following assignment of quant_min and quant_max to local variables, together
        # with the if check, refines them from Optional to valid integers, as required by TorchScript.
        custom_quant_min, custom_quant_max = quant_min, quant_max
        if custom_quant_min is not None and custom_quant_max is not None:
            initial_quant_min, initial_quant_max = (
                custom_quant_min,
                custom_quant_max,
            )

        qrange_len = initial_quant_max - initial_quant_min + 1
        if dtype == torch.qint8:
            assert (
                0 < qrange_len <= 256
            ), "quantization range should be positive and not exceed the maximum bit range (=256)."
        elif dtype == torch.qint32:
            assert (
                0 < qrange_len <= 2**31
            ), "quantization range should be positive and not exceed the maximum bit range (=4294967296)."
        if dtype == torch.qint8:
            quant_min, quant_max = -qrange_len // 2, qrange_len // 2 - 1
        else:
            quant_min, quant_max = 0, qrange_len - 1
        if reduce_range:
            quant_min, quant_max = quant_min // 2, quant_max // 2
    else:
        # Fall back to the default qmin and qmax for the dtype if a customized range is not used.
        if dtype == torch.qint8:
            if reduce_range:
                quant_min, quant_max = -64, 63
            else:
                quant_min, quant_max = -128, 127
        elif dtype == torch.quint8:
            if reduce_range:
                quant_min, quant_max = 0, 127
            else:
                quant_min, quant_max = 0, 255
        elif dtype == torch.qint32:
            quant_min, quant_max = -1 * (2 ** 31), (2 ** 31) - 1
        else:
            quant_min, quant_max = 0, 15
    return quant_min, quant_max

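# Illustrative example (added for documentation, not part of the original module):
# without a customized range the helper falls back to the full range of the dtype,
# optionally halved by `reduce_range`; with a customized range it uses the given bounds.
#
#     >>> calculate_qmin_qmax(None, None, False, torch.quint8, False)
#     (0, 255)
#     >>> calculate_qmin_qmax(None, None, False, torch.quint8, True)
#     (0, 127)
#     >>> calculate_qmin_qmax(0, 15, True, torch.quint8, False)
#     (0, 15)
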
def _parent_name(target):
    """
    Split 'foo.bar' into ('foo', 'bar'), i.e. the parent path and the attribute name.
    """
    r = target.rsplit('.', 1)
    if len(r) == 1:
        return '', r[0]
    else:
        return r[0], r[1]

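# Illustrative example (added for documentation, not part of the original module):
#
#     >>> _parent_name("features.0.conv")
#     ('features.0', 'conv')
#     >>> _parent_name("conv")
#     ('', 'conv')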