# mypy: allow-untyped-defs
from __future__ import annotations

import copy
import functools
import typing_extensions
from typing import Any, Optional, TYPE_CHECKING

import torch
import torch._dynamo as torchdynamo
import torch.nn.functional as F
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
    FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import (
    HistogramObserver,
    MinMaxObserver,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    PerChannelMinMaxObserver,
    PlaceholderObserver,
)
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.utils import _get_module_name_filter
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    _convert_scalars_to_attrs,
    OP_TO_ANNOTATOR,
    OperatorConfig,
    OperatorPatternType,
    propagate_annotation,
    QuantizationConfig,
)
from torch.fx._compatibility import compatibility


if TYPE_CHECKING:
    from collections.abc import Callable

    from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
    from torch.fx import Node


__all__ = [
    "XNNPACKQuantizer",
    "get_symmetric_quantization_config",
]


def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph:
    gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs)
    gm.graph.eliminate_dead_code()
    return gm.graph


def _get_linear_patterns(input_size: list[int]):
    in_channels = input_size[-1]
    out_channels = 8  # hard coding but this should not matter
    weight = torch.ones((out_channels, in_channels))
    bias = torch.ones((out_channels,))
    act = torch.ones(input_size)

    def linear_op(act, weight, bias=None):
        return F.linear(act, weight, bias)

    pattern_w_bias = _get_dynamo_graph(linear_op, (act, weight, bias))
    pattern_wo_bias = _get_dynamo_graph(linear_op, (act, weight))
    return [pattern_w_bias, pattern_wo_bias]


def _supported_symmetric_quantized_operators() -> dict[str, list[OperatorPatternType]]:
    supported_operators: dict[str, list[OperatorPatternType]] = {
        # Both conv and linear should be able to handle relu + hardtanh fusion since
        # those are clamp ops
        "conv2d": [
            [torch.nn.Conv2d, torch.nn.ReLU],
            [torch.nn.Conv2d, F.relu],
            [F.conv2d, torch.nn.ReLU],
            [F.conv2d, F.relu],
        ],
        "linear": [[torch.nn.Linear], [F.linear]],
        "add": [[torch.add]],
        "adaptive_avg_pool2d": [
            [torch.nn.AdaptiveAvgPool2d],
            [F.adaptive_avg_pool2d],
        ],
    }
    return copy.deepcopy(supported_operators)


def _get_supported_symmetric_config_and_operators() -> list[OperatorConfig]:
    supported_config_and_operators: list[OperatorConfig] = []
    for quantization_config in [
        get_symmetric_quantization_config(),
        get_symmetric_quantization_config(is_qat=True),
        get_symmetric_quantization_config(is_per_channel=True),
        get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
    ]:
        ops = _supported_symmetric_quantized_operators()
        supported_config_and_operators.extend(
            OperatorConfig(quantization_config, pattern_list)
            for pattern_list in ops.values()
        )
    return copy.deepcopy(supported_config_and_operators)
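
# Illustrative sketch of the structure built above (comments only, not executed):
# each OperatorConfig pairs one QuantizationConfig with one operator pattern list,
# so the four configs times the four operator groups above yield 16 entries, e.g.
#
#     OperatorConfig(
#         get_symmetric_quantization_config(is_per_channel=True),
#         [[torch.nn.Linear], [F.linear]],
#     )
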

@functools.lru_cache
def get_symmetric_quantization_config(
    is_per_channel: bool = False,
    is_qat: bool = False,
    is_dynamic: bool = False,
    act_qmin: int = -128,
    act_qmax: int = 127,
    weight_qmin: int = -127,
    weight_qmax: int = 127,
):
    extra_args: dict[str, Any] = {"eps": 2**-12}
    if is_qat:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = FakeQuantize
            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
                averaging_constant=1
            )
            extra_args["observer"] = dynamic_quant_observer
        else:
            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
    else:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
        else:
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]

    act_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=act_qmin,
        quant_max=act_qmax,
        qscheme=torch.per_tensor_affine,
        is_dynamic=is_dynamic,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
            **extra_args,
        ),
    )
    weight_qscheme = (
        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
    )
    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
        MinMaxObserver
    )
    if is_qat:
        # TODO: qat + per channel?
        weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
    elif is_per_channel:
        weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver

    extra_args = {"eps": 2**-12}
    if is_qat:
        if weight_qscheme == torch.per_tensor_symmetric:
            extra_args["observer"] = MovingAverageMinMaxObserver
        else:
            extra_args["observer"] = MovingAveragePerChannelMinMaxObserver  # type: ignore[dict-item]
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=weight_qmin,
        quant_max=weight_qmax,
        qscheme=weight_qscheme,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
            **extra_args
        ),
    )
    bias_quantization_spec = None
    if is_dynamic:
        quantization_config = QuantizationConfig(
            act_quantization_spec,
            None,
            weight_quantization_spec,
            bias_quantization_spec,
            is_qat,
        )
    else:
        quantization_config = QuantizationConfig(
            act_quantization_spec,
            act_quantization_spec,
            weight_quantization_spec,
            bias_quantization_spec,
            is_qat,
        )
    return quantization_config
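
# Illustrative usage of the factory above (a hedged sketch; the variable names
# are hypothetical):
#
#     # Static per-tensor PTQ: int8 activations (HistogramObserver) and
#     # int8 per-tensor symmetric weights (MinMaxObserver).
#     static_config = get_symmetric_quantization_config()
#
#     # Per-channel symmetric weights along ch_axis=0 (PerChannelMinMaxObserver).
#     per_channel_config = get_symmetric_quantization_config(is_per_channel=True)
#
#     # Dynamic quantization: the input activation spec gets is_dynamic=True
#     # (PlaceholderObserver) and the output activation spec is None.
#     dynamic_config = get_symmetric_quantization_config(is_dynamic=True)
#
# Note that the function is wrapped in functools.lru_cache, so repeated calls
# with the same flags return the same QuantizationConfig object.
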

def _get_supported_config_and_operators() -> list[OperatorConfig]:
    return _get_supported_symmetric_config_and_operators()


def _get_module_type_filter(tp: Callable):
    """Get the module_type_filter function for a given module type. The filter
    accepts a node and checks whether the node comes from a module of the given
    type.

    For example:
        node: linear_op = call_function[...](...)  # comes from a module with type Block -> Sub -> Linear

    >> module_type_filter = _get_module_type_filter(Sub)  # submodule with type `Sub`, under the `Block` submodule
    >> print(module_type_filter(node))
    True  # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
    """
    tp_str = tp.__module__ + "." + tp.__qualname__

    def module_type_filter(n: Node) -> bool:
        # example: {
        #     'L__self___sub': ("L['self'].sub", <class '....Sub'>),
        #     'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
        # }
        nn_module_stack = n.meta.get("nn_module_stack", {})
        types = []
        for _, t in nn_module_stack.values():
            # export() returns str, but older APIs (e.g. capture_pre_autograd_graph)
            # return type. Handle both cases.
            if isinstance(t, type):
                t = t.__module__ + "." + t.__qualname__
            types.append(t)
        return tp_str in types

    return module_type_filter


def _get_not_module_type_or_name_filter(
    tp_list: list[Callable], module_name_list: list[str]
) -> Callable[[Node], bool]:
    module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
    module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]

    def not_module_type_or_name_filter(n: Node) -> bool:
        return not any(f(n) for f in module_type_filters + module_name_list_filters)

    return not_module_type_or_name_filter
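
# How the helpers above compose (sketch; `Sub` and "blocks.sub" are hypothetical
# placeholders):
#
#     is_from_sub = _get_module_type_filter(Sub)
#     is_from_neither = _get_not_module_type_or_name_filter([Sub], ["blocks.sub"])
#     # For a node `n`, is_from_neither(n) is True only if `n` comes neither
#     # from a `Sub` instance nor from the submodule named "blocks.sub".
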

@compatibility(is_backward_compatible=False)
@typing_extensions.deprecated(
    "XNNPACKQuantizer is deprecated! Please use the XNNPACK quantizer in "
    "ExecuTorch (https://github.com/pytorch/executorch/tree/main/backends/xnnpack/quantizer) instead."
)
class XNNPACKQuantizer(Quantizer):
    """
    !!! DEPRECATED !!!
    XNNPACKQuantizer is marked as deprecated. It will be removed in the future.
    It has been moved to executorch.backends.xnnpack.quantizer.xnnpack_quantizer.XNNPACKQuantizer.
    Please use the new quantizer instead.
    """

    supported_config_and_operators = _get_supported_config_and_operators()
    STATIC_QAT_ONLY_OPS = [
        "conv_bn_relu",
        "conv_bn",
        "conv_transpose_bn_relu",
        "conv_transpose_bn",
    ]

    # static quantization ops (both PTQ and QAT)
    # Preserve the order that fusions come before singular ops
    STATIC_OPS = [
        "linear_relu",
        "linear",
        "conv_relu",
        "conv",
        "conv_transpose_relu",
        "adaptive_avg_pool2d",
        # TODO: move this to BoltNNQuantizer?
        "gru_io_only",
        "add_relu",
        "add",
        "mul_relu",
        "mul",
        "cat",
    ]

    DYNAMIC_OPS = [
        "linear",
    ]

    def __init__(self) -> None:
        super().__init__()
        self.global_config: Optional[QuantizationConfig] = None
        self.operator_type_config: dict[
            torch._ops.OpOverloadPacket, Optional[QuantizationConfig]
        ] = {}
        self.module_type_config: dict[Callable, Optional[QuantizationConfig]] = {}
        self.module_name_config: dict[str, Optional[QuantizationConfig]] = {}

    @classmethod
    def get_supported_quantization_configs(cls) -> list[QuantizationConfig]:
        op_configs: set[QuantizationConfig] = {
            spec for spec, _ in cls.supported_config_and_operators
        }
        return list(op_configs)

    @classmethod
    def get_supported_operator_for_quantization_config(
        cls, quantization_config: Optional[QuantizationConfig]
    ) -> list[OperatorPatternType]:
        if quantization_config is None:
            all_ops = []
            for _, ops in cls.supported_config_and_operators:
                all_ops.extend(ops)
            return all_ops

        for config, ops in cls.supported_config_and_operators:
            # note: this assumes each entry in cls.supported_config_and_operators
            # corresponds to one spec, e.g. we don't have
            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
            # where the first and second entry have the same spec but did not
            # merge the op list
            if config == quantization_config:
                return ops
        return []

    def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer:
        self.global_config = quantization_config
        return self

    def set_operator_type(
        self,
        operator_type: torch._ops.OpOverloadPacket,
        quantization_config: QuantizationConfig,
    ) -> XNNPACKQuantizer:
        self.operator_type_config[operator_type] = quantization_config
        return self

    def set_module_type(
        self, module_type: Callable, quantization_config: QuantizationConfig
    ):
        """Set quantization_config for a submodule with type `module_type`, for example:
        quantizer.set_module_type(Sub) or quantizer.set_module_type(nn.Linear); it will
        quantize all supported operator/operator patterns in submodules with this module
        type with the given `quantization_config`.
        """
        self.module_type_config[module_type] = quantization_config
        return self

    def set_module_name(
        self, module_name: str, quantization_config: Optional[QuantizationConfig]
    ):
        """Set quantization_config for a submodule with name `module_name`, for example:
        quantizer.set_module_name("blocks.sub"); it will quantize all supported
        operator/operator patterns in the submodule with this module name with the
        given `quantization_config`.
        """
        assert quantization_config is not None, (
            "quantization_config == None is not supported yet"
        )
        self.module_name_config[module_name] = quantization_config
        return self
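
    # Configuration precedence used by annotate() below (hedged sketch; `Sub`
    # and "blocks.sub" are hypothetical): module-name configs are annotated
    # first, then module-type configs, and the global config is applied only
    # to nodes matching neither. The setters return self, so they chain:
    #
    #     quantizer = (
    #         XNNPACKQuantizer()
    #         .set_global(get_symmetric_quantization_config())
    #         .set_module_type(Sub, get_symmetric_quantization_config(is_qat=True))
    #         .set_module_name("blocks.sub", get_symmetric_quantization_config(is_per_channel=True))
    #     )
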

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        """Transforms scalar values to tensor attributes"""
        return _convert_scalars_to_attrs(model)

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """just handling global spec for now"""
        # hacked for handling dynamic linear quant. will fix later.
        if self.global_config and self.global_config.input_activation.is_dynamic:  # type: ignore[union-attr]
            model = self._annotate_for_dynamic_quantization_config(model)
        else:
            model = self._annotate_for_static_quantization_config(model)
        propagate_annotation(model)
        return model

    def _annotate_all_static_patterns(
        self,
        model: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> torch.fx.GraphModule:
        # TODO: implement the support for None to be canceling out previous annotations
        if quantization_config is None:
            return model

        if quantization_config.is_qat:
            for op in self.STATIC_QAT_ONLY_OPS:
                OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
        for op in self.STATIC_OPS:
            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
        return model

    def _annotate_all_dynamic_patterns(
        self,
        model: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> torch.fx.GraphModule:
        # TODO: implement the support for None to be canceling out previous annotations
        if quantization_config is None:
            return model

        for op in self.DYNAMIC_OPS:
            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
        return model

    def _annotate_for_static_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        module_name_list = list(self.module_name_config.keys())
        for module_name, config in self.module_name_config.items():
            self._annotate_all_static_patterns(
                model, config, _get_module_name_filter(module_name)
            )

        tp_list = list(self.module_type_config.keys())
        for module_type, config in self.module_type_config.items():
            self._annotate_all_static_patterns(
                model, config, _get_module_type_filter(module_type)
            )

        self._annotate_all_static_patterns(
            model,
            self.global_config,
            _get_not_module_type_or_name_filter(tp_list, module_name_list),
        )
        return model

    def _annotate_for_dynamic_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        module_name_list = list(self.module_name_config.keys())
        for module_name, config in self.module_name_config.items():
            self._annotate_all_dynamic_patterns(
                model, config, _get_module_name_filter(module_name)
            )

        tp_list = list(self.module_type_config.keys())
        for module_type, config in self.module_type_config.items():
            self._annotate_all_dynamic_patterns(
                model, config, _get_module_type_filter(module_type)
            )

        self._annotate_all_dynamic_patterns(
            model,
            self.global_config,
            _get_not_module_type_or_name_filter(tp_list, module_name_list),
        )
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> list[OperatorConfig]:
        return cls.supported_config_and_operators
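
# End-to-end PT2E workflow this quantizer plugs into (hedged sketch; `m` and
# `example_inputs` are hypothetical placeholders, and new code should prefer the
# ExecuTorch quantizer referenced in the deprecation warning above):
#
#     from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#
#     quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
#     exported = torch.export.export_for_training(m, example_inputs).module()
#     prepared = prepare_pt2e(exported, quantizer)
#     prepared(*example_inputs)  # calibration
#     converted = convert_pt2e(prepared)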