pytorch/torch/ao/quantization/utils.py
Vasiliy Kuznetsov 727debb18e dbr quant: enable reference module support for torch.qint32 (#73493)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73493

This PR enables basic support for reference modules in DBR quant.
For now, the support is limited to:
1. modules that have reference versions defined only (no functions)
2. torch.qint32 dtype only

Currently, the reference module logic is enabled whenever dtype is
torch.qint32, because that is the dtype the first use case needs.
A future PR will support more dtypes and also add the `is_reference`
flag to the API.

Test Plan:
```
python test/test_quantization.py TestQuantizeDBR.test_conv_int32_reference_model
```

Reviewed By: jerryzh168

Differential Revision: D34520759

Pulled By: vkuzo

fbshipit-source-id: 363db715315c5c7c20962a1818330ce288948778
(cherry picked from commit 6ccdfe2889c252211f191edc49f4147f66e803a4)
2022-03-04 17:35:31 +00:00

352 lines
12 KiB
Python

"""
Utils shared by different modes of quantization (eager/graph)
"""
import warnings
import functools
import torch
from torch.ao.quantization.quant_type import QuantType, quant_type_to_str
from typing import Tuple, Any, Union, Callable
# Type alias for fusion patterns. Real patterns can be more deeply nested than
# the shapes spelled out here (hence the trailing `Any`); see pattern.md for docs.
# TODO: not sure if typing supports recursive data types
Pattern = Union[Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any]
# TODO: maybe rename this to MatchInputNode
class MatchAllNode:
    """Wildcard node pattern that matches all nodes.

    Used when defining fusion patterns in FX Graph Mode Quantization.
    The class carries no state or behavior of its own.
    """
# Module classes recognized by check_node() for "call_module" nodes: a node
# matches when the exact type of the module it targets is a member of this set.
# NOTE: despite the "_list" suffix, this is a set.
module_type_list = {
    torch.nn.ReLU,
    torch.nn.ReLU6,
    torch.nn.AdaptiveAvgPool1d,
    torch.nn.AdaptiveAvgPool2d,
    torch.nn.AdaptiveAvgPool3d,
    torch.nn.AvgPool1d,
    torch.nn.AvgPool2d,
    torch.nn.AvgPool3d,
    torch.nn.MaxPool1d,
    torch.nn.MaxPool2d,
    torch.nn.MaxPool3d,
    torch.nn.Identity,
    torch.nn.Hardsigmoid,
    torch.nn.Sigmoid,
    torch.nn.Tanh,
}
# Callables recognized by check_node() for "call_function" nodes: a node
# matches when its target is a member of this set.
# NOTE: despite the "_list" suffix, this is a set.
func_list = {
    torch.nn.functional.adaptive_avg_pool1d,
    torch.nn.functional.adaptive_avg_pool2d,
    torch.nn.functional.adaptive_avg_pool3d,
    torch.nn.functional.elu,
    torch.nn.functional.hardswish,
    torch.nn.functional.instance_norm,
    torch.nn.functional.layer_norm,
    torch.nn.functional.leaky_relu,
    torch.nn.functional.silu,
    torch.nn.functional.mish,
    torch.nn.functional.dropout,
    torch.nn.functional.max_pool1d,
    torch.nn.functional.max_pool2d,
    torch.nn.functional.max_pool3d,
    torch.nn.functional.relu,
    torch.nn.functional.hardtanh,
    torch.nn.functional.hardtanh_,
    torch.nn.functional.hardsigmoid,
    torch.nn.functional.sigmoid,
    torch.transpose,
    torch.repeat_interleave,
    torch.sigmoid,
    torch.squeeze,
    torch.stack,
    torch.sum,
    torch.tanh,
    torch.unsqueeze,
    torch.cat,
}
# Targets recognized by check_node() for "call_method" nodes: a node matches
# when its target is a member of this set.
# NOTE: despite the "_list" suffix, this is a set.
method_list = {
    # NOTE(review): this entry is a function object, while every other entry is
    # a method-name string. If "call_method" targets are always strings
    # (TODO confirm), this entry can never match; 'mean' may have been intended.
    torch.mean,
    'relu',
    'relu_',
    'contiguous',
    'detach',
    'detach_',
    'hardsigmoid',
    'hardsigmoid_',
    'permute',
    'repeat',
    'repeat_interleave',
    'reshape',
    'resize_',
    'shape',
    'sigmoid',
    'sigmoid_',
    'size',
    'squeeze',
    'squeeze_',
    'tanh',
    'tanh_',
    'transpose',
    'unsqueeze',
    'unsqueeze_',
    'view',
}
def check_node(node, modules):
    """Classify ``node`` against the module-level op tables.

    Args:
        node: object exposing ``op`` and ``target`` attributes.
        modules: mapping from ``str(node.target)`` to a module instance; only
            consulted when ``node.op == "call_module"``.

    Returns:
        A tuple ``(is_call_function, is_call_method, is_call_module)`` of
        booleans. Each flag is True when the node has the corresponding op
        kind and its target (or the exact type of the targeted module) is in
        ``func_list`` / ``method_list`` / ``module_type_list`` respectively.
    """
    # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py
    op_kind = node.op
    function_hit = op_kind == "call_function" and node.target in func_list
    method_hit = op_kind == "call_method" and node.target in method_list
    # Short-circuiting keeps the `modules` lookup from running for other op kinds.
    module_hit = op_kind == "call_module" and type(modules[str(node.target)]) in module_type_list
    return function_hit, method_hit, module_hit
def get_combined_dict(default_dict, additional_dict):
    """Return a copy of ``default_dict`` updated with ``additional_dict``.

    Entries from ``additional_dict`` win on key collisions. Neither input is
    modified; the copy is shallow.
    """
    merged = default_dict.copy()
    merged.update(additional_dict)
    return merged
def is_per_tensor(qscheme):
    """Return True if ``qscheme`` is one of the per-tensor schemes
    (``torch.per_tensor_affine`` or ``torch.per_tensor_symmetric``).
    """
    per_tensor_schemes = (torch.per_tensor_affine, torch.per_tensor_symmetric)
    return qscheme in per_tensor_schemes
def is_per_channel(qscheme):
    """Return True if ``qscheme`` is one of the per-channel schemes
    (affine, affine with float qparams, or symmetric).
    """
    per_channel_schemes = (
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
        torch.per_channel_symmetric,
    )
    return qscheme in per_channel_schemes
def getattr_from_fqn(obj: Any, fqn: str) -> Any:
    """
    Given an obj and a fqn such as "foo.bar.baz", returns obj.foo.bar.baz.

    The name is split on "." and each component is resolved with ``getattr``
    on the result of the previous one.
    """
    current = obj
    for attr_name in fqn.split("."):
        current = getattr(current, attr_name)
    return current
def get_qparam_dict(observer_or_fake_quant):
    """Collect quantization parameters from an observer or fake-quant object.

    The result always has "qscheme" and "dtype". If the object has no
    ``qscheme`` attribute (or it is falsy), that is all that is returned.
    Otherwise the dict also gets "scale" and "zero_point" from
    ``calculate_qparams()``, plus "axis" (from ``ch_axis``) for per-channel
    schemes, and "qscheme" is rewritten to its affine counterpart.

    Raises:
        RuntimeError: if the qscheme is neither per-tensor nor per-channel.
    """
    qscheme = getattr(observer_or_fake_quant, "qscheme", None)
    qparams = {"qscheme": qscheme, "dtype": observer_or_fake_quant.dtype}
    if not qscheme:
        return qparams

    # Symmetric schemes are mapped to affine ones, since we do not have a
    # symmetric quantized Tensor.
    if is_per_tensor(qscheme):
        normalized_qscheme = torch.per_tensor_affine
    elif is_per_channel(qscheme):
        if qscheme == torch.per_channel_symmetric:
            normalized_qscheme = torch.per_channel_affine
        else:
            normalized_qscheme = qscheme
        qparams["axis"] = observer_or_fake_quant.ch_axis
    else:
        raise RuntimeError(f"Unrecognized qscheme: {qscheme}")
    qparams["qscheme"] = normalized_qscheme

    scale, zero_point = observer_or_fake_quant.calculate_qparams()
    qparams.update(scale=scale, zero_point=zero_point)
    return qparams
def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
    """ Look up the observed/quantized custom module class that
    `custom_module` should be swapped to.

    Input:
        custom_module: input, can be an instance of either a float or observed custom module
        custom_module_class_mapping: the float to observed or observed to quantized custom module
            class mapping, keyed first by quant type string and then by module class
        qconfig: qconfig configured for the custom module

    Output:
        corresponding observed/quantized custom module class for input custom module instance

    Fails an assertion if the type of `custom_module` has no entry in the
    mapping for the quant type derived from `qconfig`.
    """
    quant_type_key = quant_type_to_str(get_quant_type(qconfig))
    # A missing quant type yields an empty mapping, so the assert below fires.
    candidates = custom_module_class_mapping.get(quant_type_key, {})
    module_cls = type(custom_module)
    assert module_cls in candidates, "did not find corresponding observed " \
        "module class for {} in mapping: {}".format(module_cls, candidates)
    return candidates[module_cls]
def activation_dtype(qconfig):
    """Return the ``dtype`` of the object built by ``qconfig.activation()``.

    ``qconfig`` must not be None.
    """
    assert qconfig is not None
    return qconfig.activation().dtype
def weight_dtype(qconfig):
    """Return the ``dtype`` of the object built by ``qconfig.weight()``.

    ``qconfig`` must not be None.
    """
    assert qconfig is not None
    return qconfig.weight().dtype
def activation_is_statically_quantized(qconfig):
    """ Given a qconfig, return True if its activation dtype is one of
    quint8, qint8 or float16
    """
    statically_quantized_dtypes = [torch.quint8, torch.qint8, torch.float16]
    return activation_dtype(qconfig) in statically_quantized_dtypes
def activation_is_int8_quantized(qconfig):
    """ Given a qconfig, return True if its activation dtype is an
    int8 dtype (quint8 or qint8)
    """
    int8_dtypes = [torch.quint8, torch.qint8]
    return activation_dtype(qconfig) in int8_dtypes
def activation_is_int32_quantized(qconfig):
    """ Given a qconfig, return True if its activation dtype is qint32
    """
    act_dtype = activation_dtype(qconfig)
    return act_dtype == torch.qint32
def weight_is_quantized(qconfig):
    """ Given a qconfig, return True if its weight dtype is one of
    quint8, qint8 or float16
    """
    quantized_weight_dtypes = [torch.quint8, torch.qint8, torch.float16]
    return weight_dtype(qconfig) in quantized_weight_dtypes
def weight_is_statically_quantized(qconfig):
    """ Given a qconfig, return True if its weight dtype is an
    int8 dtype (quint8 or qint8)
    """
    int8_dtypes = [torch.quint8, torch.qint8]
    return weight_dtype(qconfig) in int8_dtypes
def op_is_int8_dynamically_quantized(qconfig) -> bool:
    """ Given a qconfig, returns True if this op is using int8 dynamic
    quantization, i.e. float activations, qint8 weights and a quint8
    activation compute dtype
    """
    act_dtype, wt_dtype, act_compute_dtype = get_qconfig_dtypes(qconfig)
    if act_dtype is not torch.float:
        return False
    # for now, the checks below assume fbgemm or qnnpack
    if wt_dtype is not torch.qint8:
        return False
    return act_compute_dtype is torch.quint8
def get_qconfig_dtypes(qconfig):
    r""" Returns the dtype tuple for qconfig:
    (activation_dtype, weight_dtype, activation_compute_dtype)

    activation_compute_dtype is None when the activation object has no
    ``compute_dtype`` attribute. ``qconfig`` must not be None.
    """
    assert qconfig is not None
    act = qconfig.activation()
    wt = qconfig.weight()
    act_compute_dtype = getattr(act, 'compute_dtype', None)
    return (act.dtype, wt.dtype, act_compute_dtype)
def get_quant_type(qconfig):
    """Map the weight/activation dtypes of ``qconfig`` to a ``QuantType``.

    * int8 weight (quint8/qint8): STATIC if the activation dtype is also
      int8, DYNAMIC if the activation has an int8 ``compute_dtype``,
      otherwise WEIGHT_ONLY.
    * float16 weight: DYNAMIC for a float activation, STATIC for a float16
      activation.

    Raises:
        Exception: for any other dtype combination.
    """
    assert qconfig is not None
    act = qconfig.activation()
    wt = qconfig.weight()
    int8_dtypes = [torch.quint8, torch.qint8]

    if wt.dtype in int8_dtypes:
        if act.dtype in int8_dtypes:
            return QuantType.STATIC
        if hasattr(act, 'compute_dtype') and act.compute_dtype in int8_dtypes:
            return QuantType.DYNAMIC
        return QuantType.WEIGHT_ONLY

    if wt.dtype == torch.float16:
        if act.dtype == torch.float:
            return QuantType.DYNAMIC
        if act.dtype == torch.float16:
            return QuantType.STATIC

    # Covers unsupported weight dtypes and float16 weights paired with an
    # activation dtype that matched neither case above.
    raise Exception("Unrecognized dtype combination in get_quant_type: activation({}),"
                    "weight({})".format(act.dtype, wt.dtype))
def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
    """ Checks whether the given minimum and maximum values are usable.

    Returns False (after emitting a warning) when either tensor is empty, or
    when at least one is 0-dim and the pair is still (inf, -inf). Otherwise
    asserts that min <= max (elementwise when both tensors have dimensions)
    and returns True.
    """
    not_observed_msg = (
        "must run observer before calling calculate_qparams. " +
        "Returning default values."
    )

    if min_val.numel() == 0 or max_val.numel() == 0:
        warnings.warn(not_observed_msg)
        return False

    has_scalar = min_val.dim() == 0 or max_val.dim() == 0
    if not has_scalar:
        assert torch.all(
            min_val <= max_val
        ), "min {} should be less than max {}".format(min_val, max_val)
        return True

    if min_val == float("inf") and max_val == float("-inf"):
        warnings.warn(not_observed_msg)
        return False

    assert min_val <= max_val, "min {} should be less than max {}".format(
        min_val, max_val
    )
    return True
def calculate_qmin_qmax(quant_min: int, quant_max: int, has_customized_qrange: bool, dtype: torch.dtype,
                        reduce_range: bool) -> Tuple[int, int]:
    r"""Calculates actual qmin and qmax based on the quantization range,
    observer datatype and if range is reduced.

    Args:
        quant_min: custom lower bound; only read when ``has_customized_qrange``.
        quant_max: custom upper bound; only read when ``has_customized_qrange``.
        has_customized_qrange: selects the custom-range path over the
            per-dtype defaults.
        dtype: quantized dtype the range is computed for.
        reduce_range: whether to shrink the range (see below).

    Returns:
        ``(quant_min, quant_max)``.

        Custom path: only the *length* of the custom range is kept. For
        ``torch.qint8`` the result is ``(-len // 2, len // 2 - 1)``; for every
        other dtype it is ``(0, len - 1)``. With ``reduce_range`` both bounds
        are floor-divided by 2.

        Default path: qint8 -> (-128, 127), or (-64, 63) when reduced;
        quint8 -> (0, 255), or (0, 127) when reduced;
        qint32 -> (-2**31, 2**31 - 1); any other dtype -> (0, 15).
        ``reduce_range`` is ignored for the last two.

    Raises:
        AssertionError: on the custom path, if the range length is not positive
            or exceeds 256 (qint8) or 2**31 (qint32). Other dtypes are not
            checked.
    """
    if has_customized_qrange:
        # Seed both bounds with plain ints first. This was introduced to
        # resolve TorchScript compilation issues: it lets the Optional custom
        # bounds be refined to ints by the `is not None` check below. The seed
        # values are overwritten whenever both custom bounds are given.
        if dtype == torch.qint32:
            initial_quant_min, initial_quant_max = 0, 2**31 - 1
        else:
            initial_quant_min, initial_quant_max = 0, 255
        custom_quant_min, custom_quant_max = quant_min, quant_max
        if custom_quant_min is not None and custom_quant_max is not None:
            initial_quant_min, initial_quant_max = (
                custom_quant_min,
                custom_quant_max,
            )

        qrange_len = initial_quant_max - initial_quant_min + 1
        if dtype == torch.qint8:
            assert (
                0 < qrange_len <= 256
            ), "quantization range should be positive and not exceed the maximum bit range (=256)."
        elif dtype == torch.qint32:
            # The message reports the bound that is actually enforced
            # (2**31 == 2147483648); it previously claimed 2**32.
            assert (
                0 < qrange_len <= 2**31
            ), "quantization range should be positive and not exceed the maximum bit range (=2147483648)."

        if dtype == torch.qint8:
            # Signed: centre the range around zero.
            quant_min, quant_max = -qrange_len // 2, qrange_len // 2 - 1
        else:
            quant_min, quant_max = 0, qrange_len - 1
        if reduce_range:
            quant_min, quant_max = quant_min // 2, quant_max // 2
    else:
        # Fallback onto default 8-bit qmin and qmax calculation if dynamic range is not used.
        if dtype == torch.qint8:
            if reduce_range:
                quant_min, quant_max = -64, 63
            else:
                quant_min, quant_max = -128, 127
        elif dtype == torch.quint8:
            if reduce_range:
                quant_min, quant_max = 0, 127
            else:
                quant_min, quant_max = 0, 255
        elif dtype == torch.qint32:
            quant_min, quant_max = -1 * (2 ** 31), (2 ** 31) - 1
        else:
            quant_min, quant_max = 0, 15
    return quant_min, quant_max
def _parent_name(target):
"""
Turn 'foo.bar' into ['foo', 'bar']
"""
r = target.rsplit('.', 1)
if len(r) == 1:
return '', r[0]
else:
return r[0], r[1]