From 0c58a017bde6b1eb2044eec3471df584994929e7 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 10 Sep 2020 21:13:49 -0700 Subject: [PATCH] [quant][eagermode][refactor] Add set/get method for quantization and fusion mappings (#43990) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43990 Allow user to register custom quantization and fusion patterns Test Plan: Imported from OSS Reviewed By: z-a-f Differential Revision: D23485344 fbshipit-source-id: 4f0174ee6d8000d83de0f73cb370e9a1941d54aa --- mypy.ini | 4 +- torch/quantization/__init__.py | 13 ++ torch/quantization/_numeric_suite.py | 28 +-- torch/quantization/fuse_modules.py | 93 +-------- torch/quantization/fuser_method_mappings.py | 101 ++++++++++ torch/quantization/fx/fusion_patterns.py | 8 +- .../quantization/fx/quantization_patterns.py | 22 +-- torch/quantization/fx/quantize.py | 6 +- torch/quantization/quantization_mappings.py | 187 ++++++++++++++++++ torch/quantization/quantize.py | 31 +-- .../testing/_internal/common_quantization.py | 14 +- 11 files changed, 357 insertions(+), 150 deletions(-) create mode 100644 torch/quantization/fuser_method_mappings.py create mode 100644 torch/quantization/quantization_mappings.py diff --git a/mypy.ini b/mypy.ini index 9adf61abb8d..0634f2bfb49 100644 --- a/mypy.ini +++ b/mypy.ini @@ -56,9 +56,6 @@ ignore_errors = True [mypy-torch.testing._internal.*] ignore_errors = True -[mypy-torch.quantization.default_mappings] -ignore_errors = True - [mypy-torch.quantization.observer] ignore_errors = True @@ -74,6 +71,7 @@ ignore_errors = True [mypy-torch.quantization._numeric_suite] ignore_errors = True + [mypy-torch.quantization.quantize_fx] ignore_errors = True diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py index fca1435971b..661a72470b2 100644 --- a/torch/quantization/__init__.py +++ b/torch/quantization/__init__.py @@ -8,6 +8,8 @@ from .stubs import * from .quant_type import * from .quantize_jit import * from .quantize_fx import * +from .quantization_mappings import * +from .fuser_method_mappings import * def default_eval_fn(model, calib_data): r""" @@ -28,6 +30,17 @@ _all__ = [ 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx', 'QuantType', # quantization type + # custom module APIs + 'register_static_quant_module_mapping', + 'get_static_quant_module_mappings', 'get_static_quant_module_class', + 'register_dynamic_quant_module_mapping', + 'get_dynamic_quant_module_mappings', + 'register_qat_module_mapping', + 'get_qat_module_mappings', + 'get_qconfig_propagation_list', + 'get_compare_output_module_list', + 'register_quantized_operator_mapping', 'get_quantized_operator', + 'register_fuser_method', 'get_fuser_method', # Sub functions for `prepare` and `swap_module` 'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module', 'default_eval_fn', 'get_observer_dict', diff --git a/torch/quantization/_numeric_suite.py b/torch/quantization/_numeric_suite.py index b6f795b3fe6..5d55093e3ea 100644 --- a/torch/quantization/_numeric_suite.py +++ b/torch/quantization/_numeric_suite.py @@ -6,25 +6,10 @@ import torch.nn.quantized as nnq import torch.nn.quantized.dynamic as nnqd from torch.quantization import prepare -from .default_mappings import ( - _EXCLUDE_QCONFIG_PROPAGATE_LIST, - _INCLUDE_QCONFIG_PROPAGATE_LIST, - DEFAULT_DYNAMIC_MODULE_MAPPING, - DEFAULT_MODULE_MAPPING, - DEFAULT_QAT_MODULE_MAPPING, +from .quantization_mappings import ( + get_compare_output_module_list, ) - 
-DEFAULT_NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_ALLOWED_LIST = (
-    set(DEFAULT_MODULE_MAPPING.values())
-    | set(DEFAULT_QAT_MODULE_MAPPING.values())
-    | set(DEFAULT_DYNAMIC_MODULE_MAPPING.values())
-    | set(DEFAULT_MODULE_MAPPING.keys())
-    | set(DEFAULT_QAT_MODULE_MAPPING.keys())
-    | set(DEFAULT_DYNAMIC_MODULE_MAPPING.keys())
-    | _INCLUDE_QCONFIG_PROPAGATE_LIST
-) - _EXCLUDE_QCONFIG_PROPAGATE_LIST
-
 NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
     nnqd.Linear,
     nnq.Linear,
@@ -409,7 +394,7 @@ def prepare_model_outputs(
     float_module,
     q_module,
     Logger=OutputLogger,
-    allow_list=DEFAULT_NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_ALLOWED_LIST,
+    allow_list=None
 ):
     r"""Prepare the model by attaching the logger to both float module
     and quantized module if they are in the allow_list.
@@ -420,6 +405,9 @@ def prepare_model_outputs(
         Logger: type of logger to be attached to float_module and q_module
         allow_list: list of module types to attach logger
     """
+    if allow_list is None:
+        allow_list = get_compare_output_module_list()
+
     qconfig_debug = torch.quantization.QConfig(activation=Logger, weight=None)
     float_module.qconfig = qconfig_debug
     prepare(float_module, inplace=True, allow_list=allow_list)
@@ -437,7 +425,7 @@ def compare_model_outputs(
     q_model,
     *data,
     Logger=OutputLogger,
-    allow_list=DEFAULT_NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_ALLOWED_LIST,
+    allow_list=None
 ):
     r"""Compare output activations between float and quantized models at corresponding
     locations for the same input. Return a dict with key corresponding
@@ -463,6 +451,8 @@ def compare_model_outputs(
         and each entry being a dictionary with two keys 'float' and 'quantized',
         containing the matching float and quantized activations
     """
+    if allow_list is None:
+        allow_list = get_compare_output_module_list()
     prepare_model_outputs(float_model, q_model, Logger, allow_list)
     float_model(*data)
     q_model(*data)
diff --git a/torch/quantization/fuse_modules.py b/torch/quantization/fuse_modules.py
index 56aabdc1b37..3ab42f2cd22 100644
--- a/torch/quantization/fuse_modules.py
+++ b/torch/quantization/fuse_modules.py
@@ -2,95 +2,14 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 
 import copy
 
-import torch.nn.intrinsic.modules.fused as torch_fused
 import torch.nn as nn
-import torch.nn.intrinsic as nni
-from typing import Type, List, Optional, Union, Callable, Tuple, Dict
 
+from .fuser_method_mappings import get_fuser_method
+# for backward compatibility
+from .fuser_method_mappings import fuse_conv_bn # noqa: F401
+from .fuser_method_mappings import fuse_conv_bn_relu # noqa: F401
 
-def fuse_conv_bn(conv, bn):
-    r"""Given the conv and bn modules, fuses them and returns the fused module
-
-    Args:
-        conv: Module instance of type conv2d/conv3d
-        bn: Spatial BN instance that needs to be fused with the conv
-
-    Examples::
-
-        >>> m1 = nn.Conv2d(10, 20, 3)
-        >>> b1 = nn.BatchNorm2d(20)
-        >>> m2 = fuse_conv_bn(m1, b1)
-    """
-    assert(conv.training == bn.training),\
-        "Conv and BN both must be in the same mode (train or eval)."
- - is_3d = isinstance(conv, nn.Conv3d) - - if conv.training: - assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d' - assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True' - assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True' - return nni.ConvBn3d(conv, bn) if is_3d \ - else nni.ConvBn2d(conv, bn) - else: - return nn.utils.fuse_conv_bn_eval(conv, bn) - -def fuse_conv_bn_relu(conv, bn, relu): - r"""Given the conv and bn modules, fuses them and returns the fused module - - Args: - conv: Module instance of type conv2d/conv3d - bn: Spatial BN instance that needs to be fused with the conv - - Examples:: - - >>> m1 = nn.Conv2d(10, 20, 3) - >>> b1 = nn.BatchNorm2d(20) - >>> m2 = fuse_conv_bn(m1, b1) - """ - assert(conv.training == bn.training == relu.training),\ - "Conv and BN both must be in the same mode (train or eval)." - fused_module : Optional[Type[nn.Sequential]] = None - if conv.training: - map_to_fused_module_train = { - nn.Conv2d: torch_fused.ConvBnReLU2d, - nn.Conv3d: torch_fused.ConvBnReLU3d, - } - assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm' - assert bn.affine, 'Only support fusing BatchNorm with affine set to True' - assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True' - fused_module = map_to_fused_module_train.get(type(conv)) - if fused_module is not None: - return fused_module(conv, bn, relu) - else: - raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu))) - else: - map_to_fused_module_eval = { - nn.Conv1d: torch_fused.ConvReLU1d, - nn.Conv2d: torch_fused.ConvReLU2d, - nn.Conv3d: torch_fused.ConvReLU3d, - } - fused_module = map_to_fused_module_eval[type(conv)] - if fused_module is not None: - fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) - return fused_module(fused_conv, relu) - else: - raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu))) - -OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = { - (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn, - (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, - (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn, - (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu, - (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn, - (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu, - (nn.Conv1d, nn.ReLU): nni.ConvReLU1d, - (nn.Conv2d, nn.ReLU): nni.ConvReLU2d, - (nn.Conv3d, nn.ReLU): nni.ConvReLU3d, - (nn.Linear, nn.ReLU): nni.LinearReLU, - (nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d, - (nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d, -} +from typing import List, Optional # Generalization of getattr def _get_module(model, submodule_key): @@ -123,7 +42,7 @@ def fuse_known_modules(mod_list): the fused operation. 
The rest of the elements are set to nn.Identity()
     """
     types = tuple(type(m) for m in mod_list)
-    fuser_method = OP_LIST_TO_FUSER_METHOD.get(types)
+    fuser_method = get_fuser_method(types)
     if fuser_method is None:
         raise NotImplementedError("Cannot fuse modules: {}".format(types))
     new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
diff --git a/torch/quantization/fuser_method_mappings.py b/torch/quantization/fuser_method_mappings.py
new file mode 100644
index 00000000000..72ad5a7bcc7
--- /dev/null
+++ b/torch/quantization/fuser_method_mappings.py
@@ -0,0 +1,101 @@
+import torch.nn as nn
+import torch.nn.intrinsic as nni
+
+from typing import Union, Callable, Tuple, Dict, Optional, Type
+
+def fuse_conv_bn(conv, bn):
+    r"""Given the conv and bn modules, fuses them and returns the fused module
+
+    Args:
+        conv: Module instance of type conv2d/conv3d
+        bn: Spatial BN instance that needs to be fused with the conv
+
+    Examples::
+
+        >>> m1 = nn.Conv2d(10, 20, 3)
+        >>> b1 = nn.BatchNorm2d(20)
+        >>> m2 = fuse_conv_bn(m1, b1)
+    """
+    assert(conv.training == bn.training),\
+        "Conv and BN both must be in the same mode (train or eval)."
+
+    is_3d = isinstance(conv, nn.Conv3d)
+
+    if conv.training:
+        assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
+        assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
+        assert bn.track_running_stats, 'Only support fusing BatchNorm with track_running_stats set to True'
+        return nni.ConvBn3d(conv, bn) if is_3d \
+            else nni.ConvBn2d(conv, bn)
+    else:
+        return nn.utils.fuse_conv_bn_eval(conv, bn)
+
+def fuse_conv_bn_relu(conv, bn, relu):
+    r"""Given the conv, bn and relu modules, fuses them and returns the fused module
+
+    Args:
+        conv: Module instance of type conv2d/conv3d
+        bn: Spatial BN instance that needs to be fused with the conv
+        relu: ReLU instance that needs to be fused with the conv and bn
+
+    Examples::
+
+        >>> m1 = nn.Conv2d(10, 20, 3)
+        >>> b1 = nn.BatchNorm2d(20)
+        >>> r1 = nn.ReLU(inplace=False)
+        >>> m2 = fuse_conv_bn_relu(m1, b1, r1)
+    """
+    assert(conv.training == bn.training == relu.training),\
+        "Conv, BN and ReLU must all be in the same mode (train or eval)."
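+    # In training mode, swap in an intrinsic ConvBnReLU module (so BN statistics
+    # keep updating); in eval mode, fold BN into the conv weights and pair the
+    # folded conv with the ReLU.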
+    fused_module : Optional[Type[nn.Sequential]] = None
+    if conv.training:
+        map_to_fused_module_train = {
+            nn.Conv2d: nni.ConvBnReLU2d,
+            nn.Conv3d: nni.ConvBnReLU3d,
+        }
+        assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
+        assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
+        assert bn.track_running_stats, 'Only support fusing BatchNorm with track_running_stats set to True'
+        fused_module = map_to_fused_module_train.get(type(conv))
+        if fused_module is not None:
+            return fused_module(conv, bn, relu)
+        else:
+            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
+    else:
+        map_to_fused_module_eval = {
+            nn.Conv1d: nni.ConvReLU1d,
+            nn.Conv2d: nni.ConvReLU2d,
+            nn.Conv3d: nni.ConvReLU3d,
+        }
+        fused_module = map_to_fused_module_eval.get(type(conv))
+        if fused_module is not None:
+            fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
+            return fused_module(fused_conv, relu)
+        else:
+            raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
+
+OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[Type[nn.Sequential], Callable]] = {
+    (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
+    (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
+    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
+    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
+    (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
+    (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
+    (nn.Conv1d, nn.ReLU): nni.ConvReLU1d,
+    (nn.Conv2d, nn.ReLU): nni.ConvReLU2d,
+    (nn.Conv3d, nn.ReLU): nni.ConvReLU3d,
+    (nn.Linear, nn.ReLU): nni.LinearReLU,
+    (nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d,
+    (nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
+}
+
+def register_fuser_method(op_list, fuser_method):
+    ''' Register a fuser method for a tuple of module types; it will be called
+    on the matched modules during the fusion step
+    '''
+    assert isinstance(op_list, tuple), 'op list must be a tuple'
+    OP_LIST_TO_FUSER_METHOD[op_list] = fuser_method

+def get_fuser_method(op_list):
+    ''' Get the fuser method for the given tuple of module types;
+    returns None if no fuser method is registered
+    '''
+    return OP_LIST_TO_FUSER_METHOD.get(op_list, None)
diff --git a/torch/quantization/fx/fusion_patterns.py b/torch/quantization/fx/fusion_patterns.py
index 6de0acb9428..fe5631d8548 100644
--- a/torch/quantization/fx/fusion_patterns.py
+++ b/torch/quantization/fx/fusion_patterns.py
@@ -3,7 +3,7 @@ from .pattern_utils import (
     register_fusion_pattern,
 )
 from .utils import _parent_name
-from ..fuse_modules import OP_LIST_TO_FUSER_METHOD
+from ..fuser_method_mappings import get_fuser_method
 
 # ---------------------
 # Fusion Patterns
@@ -60,7 +60,7 @@ class ConvBNReLUFusion():
         op_list.reverse()
         op_type_list = tuple(type(m) for m in op_list)
         conv_parent_name, conv_name = _parent_name(self.conv_node.target)
-        fuser_method = OP_LIST_TO_FUSER_METHOD.get(op_type_list, None)
+        fuser_method = get_fuser_method(op_type_list)
         if fuser_method is None:
             raise NotImplementedError("Cannot fuse modules: {}".format(types))
         setattr(quantizer.modules[conv_parent_name], conv_name, fuser_method(*op_list))
@@ -104,8 +104,6 @@ class ModuleReLUFusion():
         op_list.reverse()
         op_type_list = tuple(type(m) for m in op_list)
         module_parent_name, module_name = _parent_name(self.module_node.target)
-        fuser_method = OP_LIST_TO_FUSER_METHOD.get(op_type_list, None)
-        if fuser_method is None:
-            raise NotImplementedError("Cannot fuse modules: {}".format(types))
+        fuser_method = get_fuser_method(op_type_list)
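+        # patterns matched by this handler are registered fusion patterns, so a
+        # fuser method is expected to be registered for op_type_list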
setattr(quantizer.modules[module_parent_name], module_name, fuser_method(*op_list)) return quantizer.fused_graph.node_copy(self.module_node, load_arg) diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 4b8b16c4ded..c0a376d89e1 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -1,11 +1,11 @@ import torch -from torch.quantization.default_mappings import ( - DEFAULT_MODULE_MAPPING, - DEFAULT_OPERATOR_MAPPING, -) from torch.fx.graph import ( Node, ) +from ..quantization_mappings import ( + get_static_quant_module_class, + get_quantized_operator, +) from .pattern_utils import ( register_quant_pattern, register_dynamic_quant_pattern, @@ -181,10 +181,7 @@ class ConvRelu(QuantizeHandler): else: self.conv.activation_post_process = quantizer.activation_post_process_map[node.name] # 2. select quantized class - # TODO: make the mapping configurable? - assert type(self.conv) in DEFAULT_MODULE_MAPPING, \ - 'unhandled conv type:{}'.format(type(self.conv)) - qconv_cls = DEFAULT_MODULE_MAPPING[type(self.conv)] + qconv_cls = get_static_quant_module_class(type(self.conv)) quantized = qconv_cls.from_float(self.conv) parent_name, name = _parent_name(self.conv_node.target) setattr(quantizer.modules[parent_name], name, quantized) @@ -335,7 +332,7 @@ class BatchNorm(QuantizeHandler): self.bn[1].activation_post_process = activation_post_process else: self.bn.activation_post_process = activation_post_process - qbn_cls = DEFAULT_MODULE_MAPPING[type(self.bn)] + qbn_cls = get_static_quant_module_class(type(self.bn)) quantized = qbn_cls.from_float(self.bn) parent_name, name = _parent_name(self.bn_node.target) setattr(quantizer.modules[parent_name], name, quantized) @@ -371,7 +368,8 @@ class DefaultNode(QuantizeHandler): if node.op == 'call_module': module = quantizer.modules[node.target] module.activation_post_process = activation_post_process - quantized_module = DEFAULT_MODULE_MAPPING[type(module)].from_float(module) + quantized_module_cls = get_static_quant_module_class(type(module)) + quantized_module = quantized_module_cls.from_float(module) parent_name, name = _parent_name(node.target) setattr(quantizer.modules[parent_name], name, quantized_module) return quantizer.quantized_graph.create_node( @@ -385,7 +383,7 @@ class DefaultNode(QuantizeHandler): scale = float(scale) zero_point = int(zero_point) - quantized_op = DEFAULT_OPERATOR_MAPPING[node.target] + quantized_op = get_quantized_operator(node.target) args = load_arg(quantized=[0])(node.args) kwargs = load_arg(quantized=False)(node.kwargs) kwargs.update({'output_scale': scale, 'output_zero_point': zero_point}) @@ -405,7 +403,7 @@ class ELU(QuantizeHandler): scale, zero_point = activation_post_process.calculate_qparams() scale = float(scale) zero_point = int(zero_point) - quantized_op = DEFAULT_OPERATOR_MAPPING[node.target] + quantized_op = get_quantized_operator(node.target) args = load_arg(quantized=[0])(node.args) kwargs = load_arg(quantized=False)(node.kwargs) kwargs.update({'output_scale': scale, 'output_zero_point': zero_point}) diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index 7f62f3616c1..0e2af583714 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -4,8 +4,8 @@ from torch.quantization import ( convert, ) -from torch.quantization.default_mappings import ( - DEFAULT_QAT_MODULE_MAPPING, +from ..quantization_mappings import ( + get_qat_module_mappings, 
 )
 
 from torch.fx import (
@@ -163,7 +163,7 @@ class Quantizer:
 
 
     def _qat_swap_modules(self, root):
-        convert(root, mapping=DEFAULT_QAT_MODULE_MAPPING, inplace=True, remove_qconfig=False)
+        convert(root, mapping=get_qat_module_mappings(), inplace=True, remove_qconfig=False)
 
     def _generate_qconfig_map(self, root, input_graph):
         def get_qconfig(module):
diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py
new file mode 100644
index 00000000000..585b018a5b0
--- /dev/null
+++ b/torch/quantization/quantization_mappings.py
@@ -0,0 +1,187 @@
+import torch
+from torch import nn
+
+import torch.nn.functional as F
+import torch.nn.intrinsic as nni
+import torch.nn.intrinsic.quantized as nniq
+import torch.nn.intrinsic.qat as nniqat
+import torch.nn.quantized as nnq
+import torch.nn.quantized.dynamic as nnqd
+import torch.nn.qat as nnqat
+
+from .stubs import QuantStub, DeQuantStub
+
+# Map for swapping float modules to quantized ones
+STATIC_QUANT_MODULE_MAPPINGS = {
+    nn.Linear: nnq.Linear,
+    nn.ReLU: nnq.ReLU,
+    nn.ReLU6: nnq.ReLU6,
+    nn.Hardswish: nnq.Hardswish,
+    nn.ELU: nnq.ELU,
+    nn.Conv1d: nnq.Conv1d,
+    nn.Conv2d: nnq.Conv2d,
+    nn.Conv3d: nnq.Conv3d,
+    nn.BatchNorm2d: nnq.BatchNorm2d,
+    nn.BatchNorm3d: nnq.BatchNorm3d,
+    nn.LayerNorm: nnq.LayerNorm,
+    nn.GroupNorm: nnq.GroupNorm,
+    nn.InstanceNorm1d: nnq.InstanceNorm1d,
+    nn.InstanceNorm2d: nnq.InstanceNorm2d,
+    nn.InstanceNorm3d: nnq.InstanceNorm3d,
+    nn.Embedding: nnq.Embedding,
+    nn.EmbeddingBag: nnq.EmbeddingBag,
+    QuantStub: nnq.Quantize,
+    DeQuantStub: nnq.DeQuantize,
+    # Wrapper Modules:
+    nnq.FloatFunctional: nnq.QFunctional,
+    # Intrinsic modules:
+    nni.ConvReLU1d: nniq.ConvReLU1d,
+    nni.ConvReLU2d: nniq.ConvReLU2d,
+    nni.ConvReLU3d: nniq.ConvReLU3d,
+    nni.LinearReLU: nniq.LinearReLU,
+    nni.BNReLU2d: nniq.BNReLU2d,
+    nni.BNReLU3d: nniq.BNReLU3d,
+    nniqat.ConvReLU2d: nniq.ConvReLU2d,
+    nniqat.LinearReLU: nniq.LinearReLU,
+    nniqat.ConvBn2d: nnq.Conv2d,
+    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
+    # QAT modules:
+    nnqat.Linear: nnq.Linear,
+    nnqat.Conv2d: nnq.Conv2d,
+}
+
+# Map for swapping float modules to qat modules
+QAT_MODULE_MAPPINGS = {
+    nn.Linear: nnqat.Linear,
+    nn.Conv2d: nnqat.Conv2d,
+    # Intrinsic modules:
+    nni.ConvBn2d: nniqat.ConvBn2d,
+    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
+    nni.ConvReLU2d: nniqat.ConvReLU2d,
+    nni.LinearReLU: nniqat.LinearReLU
+}
+
+# Map for swapping float modules to dynamically quantized ones
+DYNAMIC_QUANT_MODULE_MAPPINGS = {
+    nn.Linear: nnqd.Linear,
+    nn.LSTM: nnqd.LSTM,
+    nn.LSTMCell: nnqd.LSTMCell,
+    nn.RNNCell: nnqd.RNNCell,
+    nn.GRUCell: nnqd.GRUCell,
+}
+
+# Module classes excluded from (and explicitly included in) qconfig propagation
+_EXCLUDE_QCONFIG_PROPAGATE_LIST = {
+    DeQuantStub,
+}
+_INCLUDE_QCONFIG_PROPAGATE_LIST = {
+    nn.Sequential,
+}
+
+# mapping from floating point function or torch ops to quantized ops
+FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = {
+    F.elu: torch._ops.ops.quantized.elu,
+    F.hardswish: torch._ops.ops.quantized.hardswish,
+    F.instance_norm: torch._ops.ops.quantized.instance_norm,
+    F.layer_norm: torch._ops.ops.quantized.layer_norm,
+}
+
+def register_static_quant_module_mapping(
+        float_source_module_class, static_quant_target_module_class):
+    ''' Register a mapping from `float_source_module_class` to `static_quant_target_module_class`
+    `static_quant_target_module_class` must have from_float defined as a class method
+    The mapping is used in the convert step of post training static quantization to
+    convert a float module to a statically quantized module.
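+
+    Examples::
+
+        >>> # FloatCustomModule / QuantizedCustomModule are hypothetical user
+        >>> # classes; the target class must define a `from_float` classmethod
+        >>> register_static_quant_module_mapping(FloatCustomModule, QuantizedCustomModule)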
+    '''
+    assert hasattr(static_quant_target_module_class, 'from_float'), 'from_float must be defined' + \
+        ' in quantized module class'
+    STATIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = static_quant_target_module_class
+
+def get_static_quant_module_mappings():
+    ''' Get module mapping for post training static quantization
+    '''
+    return STATIC_QUANT_MODULE_MAPPINGS
+
+def get_static_quant_module_class(float_module_class):
+    ''' Get the statically quantized module class corresponding to
+    the floating point module class
+    '''
+    static_quant_module_class = STATIC_QUANT_MODULE_MAPPINGS.get(float_module_class, None)
+    assert static_quant_module_class is not None, \
+        'Floating point module class {}'.format(float_module_class) + \
+        ' does not have a corresponding quantized module class'
+    return static_quant_module_class
+
+def register_qat_module_mapping(float_source_module_class, qat_target_module_class):
+    ''' Register a mapping from `float_source_module_class` to `qat_target_module_class`,
+    `qat_target_module_class` must have from_float defined as a class method
+    This mapping is used in the prepare step of quantization aware training to swap
+    a float module to a qat module.
+    '''
+    assert hasattr(qat_target_module_class, 'from_float'), 'from_float must be defined' + \
+        ' in qat module class'
+    QAT_MODULE_MAPPINGS[float_source_module_class] = qat_target_module_class
+
+def get_qat_module_mappings():
+    ''' Get module mapping for quantization aware training
+    '''
+    return QAT_MODULE_MAPPINGS
+
+def register_dynamic_quant_module_mapping(float_source_module_class, dynamic_quant_target_module_class):
+    ''' Register a mapping from `float_source_module_class` to `dynamic_quant_target_module_class`,
+    `dynamic_quant_target_module_class` must have from_float defined as a class method
+    This mapping is used in the convert step of post training dynamic
+    quantization to swap a float module to a dynamically quantized
+    module.
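+
+    Examples::
+
+        >>> # FloatCustomRNN / DynamicQuantCustomRNN are hypothetical user
+        >>> # classes; the target class must define a `from_float` classmethod
+        >>> register_dynamic_quant_module_mapping(FloatCustomRNN, DynamicQuantCustomRNN)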
+ ''' + assert hasattr(dynamic_quant_target_module_class, 'from_float'), 'from_float must be defined' + \ + ' in dynamically quantized module type' + DYNAMIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = dynamic_quant_target_module_class + +def get_dynamic_quant_module_mappings(): + ''' Get module mapping for post training dynamic quantization + ''' + return DYNAMIC_QUANT_MODULE_MAPPINGS + +def get_qconfig_propagation_list(): + ''' Get the list of module types that we'll attach qconfig + attribute to in prepare + ''' + QCONFIG_PROPAGATE_MODULE_CLASS_LIST = ( + (set(STATIC_QUANT_MODULE_MAPPINGS.keys()) | + set(QAT_MODULE_MAPPINGS.keys()) | + set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) | + _INCLUDE_QCONFIG_PROPAGATE_LIST) - + _EXCLUDE_QCONFIG_PROPAGATE_LIST + ) + return QCONFIG_PROPAGATE_MODULE_CLASS_LIST + +def get_compare_output_module_list(): + ''' Get list of module class types that we will record output + in numeric suite + ''' + NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = ( + set(STATIC_QUANT_MODULE_MAPPINGS.values()) + | set(QAT_MODULE_MAPPINGS.values()) + | set(DYNAMIC_QUANT_MODULE_MAPPINGS.values()) + | set(STATIC_QUANT_MODULE_MAPPINGS.keys()) + | set(QAT_MODULE_MAPPINGS.keys()) + | set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) + | _INCLUDE_QCONFIG_PROPAGATE_LIST + ) - _EXCLUDE_QCONFIG_PROPAGATE_LIST + return NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST + +def register_quantized_operator_mapping(float_op, quantized_op): + ''' Register a mapping from `floating_point_op` (torch or functional) to `quantized_op` + This is used in convert step of fx based graph mode quantization + to convert a float op to quantized op. + ''' + FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS[float_op] = quantized_op + +def get_quantized_operator(float_op): + ''' Get the quantized operator corresponding to the float operator + ''' + quantized_op = FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None) + assert quantized_op is not None, \ + 'Operator {} does not have corresponding quantized op'.format(float_op) + return quantized_op diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py index 0c491abdb28..664074149e7 100644 --- a/torch/quantization/quantize.py +++ b/torch/quantization/quantize.py @@ -10,10 +10,11 @@ import torch.nn.intrinsic as nni import torch.nn.quantized as nnq import torch.nn.intrinsic.qat as nniqat -from .default_mappings import (DEFAULT_DYNAMIC_MODULE_MAPPING, - DEFAULT_MODULE_MAPPING, - DEFAULT_QAT_MODULE_MAPPING, - DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST) +from .quantization_mappings import (get_dynamic_quant_module_mappings, + get_static_quant_module_mappings, + get_qat_module_mappings, + get_qconfig_propagation_list) + from .stubs import DeQuantStub, QuantWrapper from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_dynamic_qconfig @@ -37,7 +38,7 @@ def _propagate_qconfig_helper(module, qconfig_dict, allow_list=None, """ # TODO: Add test if allow_list is None: - allow_list = DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST + allow_list = get_qconfig_propagation_list() module_qconfig = qconfig_dict.get(type(module), qconfig_parent) module_qconfig = qconfig_dict.get(prefix, module_qconfig) @@ -100,7 +101,8 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No None, module is modified inplace with added observer modules and forward_hooks """ if qconfig_propagation_list is None: - qconfig_propagation_list = DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST + qconfig_propagation_list = get_qconfig_propagation_list() + # respect 
device affinity when adding observers if device is None: devices = get_unique_devices_(module) @@ -194,9 +196,10 @@ def prepare(model, inplace=False, allow_list=None, """ if not inplace: model = copy.deepcopy(model) - propagate_qconfig_list = allow_list - if propagate_qconfig_list is None: - propagate_qconfig_list = DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST + + qconfig_propagation_list = allow_list + if qconfig_propagation_list is None: + qconfig_propagation_list = get_qconfig_propagation_list() propagate_qconfig_(model, qconfig_dict=None) # sanity check common API misusage @@ -205,7 +208,7 @@ def prepare(model, inplace=False, allow_list=None, "passed correct configuration through `qconfig_dict` or " "by assigning the `.qconfig` attribute directly on submodules") - add_observer_(model, propagate_qconfig_list, observer_non_leaf_module_list, prehook=prehook) + add_observer_(model, qconfig_propagation_list, observer_non_leaf_module_list, prehook=prehook) return model def _remove_qconfig(module): @@ -239,7 +242,7 @@ def quantize(model, run_fn, run_args, mapping=None, inplace=False): Quantized model. """ if mapping is None: - mapping = DEFAULT_MODULE_MAPPING + mapping = get_static_quant_module_mappings() if not inplace: model = copy.deepcopy(model) model.eval() @@ -316,7 +319,7 @@ def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig))) if mapping is None: - mapping = DEFAULT_DYNAMIC_MODULE_MAPPING + mapping = get_dynamic_quant_module_mappings() if not inplace: model = copy.deepcopy(model) @@ -341,7 +344,7 @@ def prepare_qat(model, mapping=None, inplace=False): is mutated """ if mapping is None: - mapping = DEFAULT_QAT_MODULE_MAPPING + mapping = get_qat_module_mappings() if not inplace: model = copy.deepcopy(model) @@ -406,7 +409,7 @@ def _convert(module, mapping=None, inplace=False): """ if mapping is None: - mapping = DEFAULT_MODULE_MAPPING + mapping = get_static_quant_module_mappings() if not inplace: module = copy.deepcopy(module) reassign = {} diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 79259042d44..55f1fdb21f0 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -17,10 +17,10 @@ from torch.testing._internal.common_utils import TestCase from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \ default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \ propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_dynamic_qconfig -from torch.quantization.default_mappings import ( - DEFAULT_DYNAMIC_MODULE_MAPPING, - DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST, - DEFAULT_QAT_MODULE_MAPPING, +from torch.quantization.quantization_mappings import ( + get_dynamic_quant_module_mappings, + get_qconfig_propagation_list, + get_qat_module_mappings, ) # symbolic trace from torch.fx import symbolic_trace @@ -191,7 +191,7 @@ def run_ddp(rank, world_size, prepared): def convert_dynamic(module): - convert(module, DEFAULT_DYNAMIC_MODULE_MAPPING, inplace=True) + convert(module, get_dynamic_quant_module_mappings(), inplace=True) def prepare_dynamic(model, qconfig_dict=None): propagate_qconfig_(model, qconfig_dict) @@ -347,7 +347,7 @@ class QuantizationTestCase(TestCase): have observers in preperation for quantization """ if propagate_qconfig_list is None: - propagate_qconfig_list = 
DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST + propagate_qconfig_list = get_qconfig_propagation_list() if hasattr(module, 'qconfig') and module.qconfig is not None and \ len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential) \ and type(module) in propagate_qconfig_list: @@ -355,7 +355,7 @@ class QuantizationTestCase(TestCase): 'module: ' + str(type(module)) + ' do not have observer') # we don't need to check observers for child modules of the # qat modules - if type(module) not in DEFAULT_QAT_MODULE_MAPPING.values(): + if type(module) not in get_qat_module_mappings().values(): for child in module.children(): self.checkObservers(child)
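
Usage notes (sketches, not part of the patch):

A minimal sketch of the new fuser-method registration API, assuming this
patch is applied. `fuse_linear_sigmoid` and the toy model `M` are
illustrative names, not APIs added by this PR:

    import torch
    import torch.nn as nn
    from torch.quantization import register_fuser_method, fuse_modules

    # A fuser method receives the matched module instances and returns the
    # module that replaces the first one; the rest become nn.Identity().
    def fuse_linear_sigmoid(linear, sigmoid):
        return nn.Sequential(linear, sigmoid)

    register_fuser_method((nn.Linear, nn.Sigmoid), fuse_linear_sigmoid)

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 4)
            self.act = nn.Sigmoid()

        def forward(self, x):
            return self.act(self.fc(x))

    m = fuse_modules(M().eval(), [['fc', 'act']])
    assert isinstance(m.fc, nn.Sequential)   # fused by the registered method
    assert isinstance(m.act, nn.Identity)    # placeholder left behind
    print(m(torch.randn(2, 4)).shape)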
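Similarly, a sketch of registering a custom static quantization module
mapping (`FloatCustomOp` / `QuantizedCustomOp` are illustrative; the target
class must define `from_float`, the hook `convert` uses to swap modules):

    import torch.nn as nn
    from torch.quantization import (
        register_static_quant_module_mapping,
        get_static_quant_module_class,
    )

    class FloatCustomOp(nn.Module):
        def forward(self, x):
            return x

    class QuantizedCustomOp(nn.Module):
        def forward(self, x):
            return x

        @classmethod
        def from_float(cls, mod):
            # construct the quantized module from the observed float module
            return cls()

    register_static_quant_module_mapping(FloatCustomOp, QuantizedCustomOp)
    assert get_static_quant_module_class(FloatCustomOp) is QuantizedCustomOp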
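The operator-level mapping works the same way; the `F.elu` entry below is
one of the defaults installed by this patch:

    import torch
    import torch.nn.functional as F
    from torch.quantization import get_quantized_operator

    # look up the quantized counterpart of a float functional
    assert get_quantized_operator(F.elu) is torch.ops.quantized.elu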