""" Utils shared by different modes of quantization (eager/graph) """ import warnings import functools import torch from torch.ao.quantization.quant_type import QuantType, quant_type_to_str from typing import Tuple, Any, Union, Callable # Type for fusion patterns, it can be more complicated than the following actually, # see pattern.md for docs # TODO: not sure if typing supports recursive data types Pattern = Union[Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any] # TODO: maybe rename this to MatchInputNode class MatchAllNode: """ A node pattern that matches all nodes, used in defining fusion patterns in FX Graph Mode Quantization """ pass module_type_list = { torch.nn.ReLU, torch.nn.ReLU6, torch.nn.AdaptiveAvgPool1d, torch.nn.AdaptiveAvgPool2d, torch.nn.AdaptiveAvgPool3d, torch.nn.AvgPool1d, torch.nn.AvgPool2d, torch.nn.AvgPool3d, torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.Identity, torch.nn.Hardsigmoid, torch.nn.Sigmoid, torch.nn.Tanh, } func_list = { torch.nn.functional.adaptive_avg_pool1d, torch.nn.functional.adaptive_avg_pool2d, torch.nn.functional.adaptive_avg_pool3d, torch.nn.functional.elu, torch.nn.functional.hardswish, torch.nn.functional.instance_norm, torch.nn.functional.layer_norm, torch.nn.functional.leaky_relu, torch.nn.functional.silu, torch.nn.functional.mish, torch.nn.functional.dropout, torch.nn.functional.max_pool1d, torch.nn.functional.max_pool2d, torch.nn.functional.max_pool3d, torch.nn.functional.relu, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_, torch.nn.functional.hardsigmoid, torch.nn.functional.sigmoid, torch.transpose, torch.repeat_interleave, torch.sigmoid, torch.squeeze, torch.stack, torch.sum, torch.tanh, torch.unsqueeze, torch.cat, } method_list = { torch.mean, 'relu', 'relu_', 'contiguous', 'detach', 'detach_', 'hardsigmoid', 'hardsigmoid_', 'permute', 'repeat', 'repeat_interleave', 'reshape', 'resize_', 'shape', 'sigmoid', 'sigmoid_', 'size', 'squeeze', 'squeeze_', 'tanh', 'tanh_', 'transpose', 'unsqueeze', 'unsqueeze_', 'view', } def check_node(node, modules): # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py is_call_function = node.op == "call_function" and node.target in func_list is_call_method = node.op == "call_method" and node.target in method_list is_call_module = node.op == "call_module" and type(modules[str(node.target)]) in module_type_list return is_call_function, is_call_method, is_call_module def get_combined_dict(default_dict, additional_dict): d = default_dict.copy() d.update(additional_dict) return d def is_per_tensor(qscheme): return qscheme == torch.per_tensor_affine or \ qscheme == torch.per_tensor_symmetric def is_per_channel(qscheme): return qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric] def getattr_from_fqn(obj: Any, fqn: str) -> Any: """ Given an obj and a fqn such as "foo.bar.baz", returns gm.foo.bar.baz. 
""" return functools.reduce(getattr, fqn.split("."), obj) def get_qparam_dict(observer_or_fake_quant): qscheme = observer_or_fake_quant.qscheme if hasattr(observer_or_fake_quant, "qscheme") else None dtype = observer_or_fake_quant.dtype qparams = {"qscheme": qscheme, "dtype": dtype} if not qscheme: return qparams if is_per_tensor(qscheme): qscheme = torch.per_tensor_affine elif is_per_channel(qscheme): # change symmetric to affine since we do not have symmetric # quantized Tensor if qscheme == torch.per_channel_symmetric: qscheme = torch.per_channel_affine qparams["axis"] = observer_or_fake_quant.ch_axis else: raise RuntimeError(f"Unrecognized qscheme: {qscheme}") # update qscheme, since we don't have symmetric quant qscheme # in quantized Tensor qparams["qscheme"] = qscheme scale, zero_point = observer_or_fake_quant.calculate_qparams() qparams["scale"] = scale qparams["zero_point"] = zero_point return qparams def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig): """ Get the observed/quantized custom module class that we need to swap `custom_module` to Input: custom_module: input, can be an instance of either a float or observed custom module custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping qconfig: qconfig configured for the custom module Output: corresponding observed/quantized custom module class for input custom module instance """ quant_type = get_quant_type(qconfig) quant_type_str = quant_type_to_str(quant_type) class_mapping = custom_module_class_mapping.get(quant_type_str, {}) assert type(custom_module) in class_mapping, "did not find corresponding observed " \ "module class for {} in mapping: {}".format(type(custom_module), class_mapping) return class_mapping[type(custom_module)] def activation_dtype(qconfig): assert qconfig is not None activation = qconfig.activation() return activation.dtype def weight_dtype(qconfig): assert qconfig is not None weight = qconfig.weight() return weight.dtype def activation_is_statically_quantized(qconfig): """ Given a qconfig, decide if the activation needs to be quantized or not, this includes quantizing to quint8, qint8 and float16 """ return activation_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16] def activation_is_dynamically_quantized(qconfig): """ Given a qconfig, decide if the activation needs to be dynamically quantized or not, this includes dynamically quantizing to quint8, qint8 and float16 """ activation_dtype, _, activation_compute_dtype = \ get_qconfig_dtypes(qconfig) return activation_dtype == torch.float and \ activation_compute_dtype in [torch.quint8, torch.qint8, torch.float16] def activation_is_int8_quantized(qconfig): """ Given a qconfig, decide if the activation needs to be quantized to int8 or not, this includes quantizing to quint8, qint8 """ return activation_dtype(qconfig) in [torch.quint8, torch.qint8] def activation_is_int32_quantized(qconfig): """ Given a qconfig, decide if the activation needs to be quantized to int32 or not """ return activation_dtype(qconfig) == torch.qint32 def weight_is_quantized(qconfig): """ Given a qconfig, decide if the weight needs to be quantized or not """ return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16, torch.quint4x2] def weight_is_statically_quantized(qconfig): """ Given a qconfig, decide if the weight needs to be statically quantized or not """ return weight_dtype(qconfig) in [torch.quint8, torch.qint8] def op_is_int8_dynamically_quantized(qconfig) -> bool: 
""" Given a qconfig, returns True if this op is using int8 dynamic quantization """ activation_dtype, weight_dtype, activation_compute_dtype = \ get_qconfig_dtypes(qconfig) return ( activation_dtype is torch.float and # for now, the lines below assume fbgemm or qnnpack weight_dtype is torch.qint8 and activation_compute_dtype is torch.quint8 ) def get_qconfig_dtypes(qconfig): r""" returns the qconfig tuple for qconfig: (activation_dtype, weight_dtype, activation_compute_dtype) """ assert qconfig is not None activation = qconfig.activation() weight = qconfig.weight() compute_dtype = activation.compute_dtype if hasattr(activation, 'compute_dtype') else None return (activation.dtype, weight.dtype, compute_dtype) def get_quant_type(qconfig): assert qconfig is not None activation = qconfig.activation() weight = qconfig.weight() static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2] if weight.dtype in static_dtypes: if activation.dtype in static_dtypes: return QuantType.STATIC elif hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes: return QuantType.DYNAMIC else: return QuantType.WEIGHT_ONLY if weight.dtype == torch.float16: if activation.dtype == torch.float: return QuantType.DYNAMIC elif activation.dtype == torch.float16: return QuantType.STATIC raise Exception("Unrecognized dtype combination in get_quant_type: activation({})," "weight({})".format(activation.dtype, weight.dtype)) def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool: """ Checks if the given minimum and maximum values are valid, meaning that they exist and the min value is less than the max value. """ if min_val.numel() == 0 or max_val.numel() == 0: warnings.warn( "must run observer before calling calculate_qparams. " + "Returning default values." ) return False if min_val.dim() == 0 or max_val.dim() == 0: if min_val == float("inf") and max_val == float("-inf"): warnings.warn( "must run observer before calling calculate_qparams. " + "Returning default values." ) return False assert min_val <= max_val, "min {} should be less than max {}".format( min_val, max_val ) else: assert torch.all( min_val <= max_val ), "min {} should be less than max {}".format(min_val, max_val) return True def calculate_qmin_qmax(quant_min: int, quant_max: int, has_customized_qrange: bool, dtype: torch.dtype, reduce_range: bool) -> Tuple[int, int]: r"""Calculates actual qmin and qmax based on the quantization range, observer datatype and if range is reduced. """ if has_customized_qrange: # This initialization here is to be resolve TorchScript compilation issues and allow # using of refinement to decouple initial_qmin and initial_qmax from quantization range. # The actual values of initial_qmin and initial_qmax will be reset below. if dtype == torch.qint32: initial_quant_min, initial_quant_max = 0, 2**31 - 1 else: initial_quant_min, initial_quant_max = 0, 255 # The following assignment of self.qmin and self.qmax to the local variables and the if check refine the # attribute from Optional valid integers for use, based on TorchScript's requirements. custom_quant_min, custom_quant_max = quant_min, quant_max if custom_quant_min is not None and custom_quant_max is not None: initial_quant_min, initial_quant_max = ( custom_quant_min, custom_quant_max, ) qrange_len = initial_quant_max - initial_quant_min + 1 if dtype == torch.qint8: assert ( 0 < qrange_len <= 256 ), "quantization range should be positive and not exceed the maximum bit range (=256)." 
def calculate_qmin_qmax(quant_min: int, quant_max: int, has_customized_qrange: bool, dtype: torch.dtype,
                        reduce_range: bool) -> Tuple[int, int]:
    r"""Calculates actual qmin and qmax based on the quantization range,
    observer datatype and if range is reduced.
    """
    if has_customized_qrange:
        # This initialization here is to resolve TorchScript compilation issues and allow
        # the use of refinement to decouple initial_qmin and initial_qmax from the quantization range.
        # The actual values of initial_qmin and initial_qmax will be reset below.
        if dtype == torch.qint32:
            initial_quant_min, initial_quant_max = 0, 2**31 - 1
        else:
            initial_quant_min, initial_quant_max = 0, 255
        # The following assignment of quant_min and quant_max to the local variables and the if check refine the
        # attribute from Optional valid integers for use, based on TorchScript's requirements.
        custom_quant_min, custom_quant_max = quant_min, quant_max
        if custom_quant_min is not None and custom_quant_max is not None:
            initial_quant_min, initial_quant_max = (
                custom_quant_min,
                custom_quant_max,
            )

        qrange_len = initial_quant_max - initial_quant_min + 1
        if dtype == torch.qint8:
            assert (
                0 < qrange_len <= 256
            ), "quantization range should be positive and not exceed the maximum bit range (=256)."
        elif dtype == torch.qint32:
            assert (
                0 < qrange_len <= 2**31
            ), "quantization range should be positive and not exceed the maximum bit range (=2147483648)."
        if dtype == torch.qint8:
            quant_min, quant_max = -qrange_len // 2, qrange_len // 2 - 1
        else:
            quant_min, quant_max = 0, qrange_len - 1
        if reduce_range:
            quant_min, quant_max = quant_min // 2, quant_max // 2
    else:
        # Fallback onto default 8-bit qmin and qmax calculation if a customized quantization range is not used.
        if dtype == torch.qint8:
            if reduce_range:
                quant_min, quant_max = -64, 63
            else:
                quant_min, quant_max = -128, 127
        elif dtype == torch.quint8:
            if reduce_range:
                quant_min, quant_max = 0, 127
            else:
                quant_min, quant_max = 0, 255
        elif dtype == torch.qint32:
            quant_min, quant_max = -1 * (2 ** 31), (2 ** 31) - 1
        else:
            quant_min, quant_max = 0, 15
    return quant_min, quant_max

def _parent_name(target):
    """
    Turn 'foo.bar' into ('foo', 'bar')
    """
    r = target.rsplit('.', 1)
    if len(r) == 1:
        return '', r[0]
    else:
        return r[0], r[1]
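
# Illustrative sketch (not part of the original module): expected quantization
# ranges for a few common dtype / reduce_range combinations, plus `_parent_name`
# splitting a dotted module path into (parent, leaf). The function name
# `_example_qrange_and_parent_name` is hypothetical; note that the quant_min and
# quant_max arguments are ignored when has_customized_qrange is False.
def _example_qrange_and_parent_name():
    assert calculate_qmin_qmax(0, 255, False, torch.quint8, False) == (0, 255)
    assert calculate_qmin_qmax(0, 255, False, torch.qint8, True) == (-64, 63)
    # a customized full signed 8-bit range passes through unchanged
    assert calculate_qmin_qmax(-128, 127, True, torch.qint8, False) == (-128, 127)
    assert _parent_name("features.0.conv") == ("features.0", "conv")
    assert _parent_name("conv") == ("", "conv")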