[quant][eagermode][refactor] Add set/get method for quantization and fusion mappings (#43990)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43990

Allow user to register custom quantization and fusion patterns
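
For example, a user-defined quantized module can now be plugged into the eager mode convert step (a minimal sketch; `MyFloatModule` and `MyQuantizedModule` are hypothetical placeholder classes, while `register_static_quant_module_mapping` is the API added in this diff):

import torch.nn as nn
from torch.quantization import register_static_quant_module_mapping

class MyFloatModule(nn.Module):
    def forward(self, x):
        return x

class MyQuantizedModule(nn.Module):
    # a `from_float` classmethod is required by the registration API
    @classmethod
    def from_float(cls, float_module):
        return cls()

    def forward(self, x):
        return x

# convert() will now swap MyFloatModule instances for MyQuantizedModule
register_static_quant_module_mapping(MyFloatModule, MyQuantizedModule)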

Test Plan: Imported from OSS

Reviewed By: z-a-f

Differential Revision: D23485344

fbshipit-source-id: 4f0174ee6d8000d83de0f73cb370e9a1941d54aa
Author: Jerry Zhang
Date: 2020-09-10 21:13:49 -07:00
Committed by: Facebook GitHub Bot
Parent: f7278473d3
Commit: 0c58a017bd

11 changed files with 357 additions and 150 deletions

View File

@@ -56,9 +56,6 @@ ignore_errors = True
 [mypy-torch.testing._internal.*]
 ignore_errors = True

-[mypy-torch.quantization.default_mappings]
-ignore_errors = True
-
 [mypy-torch.quantization.observer]
 ignore_errors = True
@@ -74,6 +71,7 @@ ignore_errors = True
 [mypy-torch.quantization._numeric_suite]
 ignore_errors = True

 [mypy-torch.quantization.quantize_fx]
 ignore_errors = True

View File

@@ -8,6 +8,8 @@ from .stubs import *
 from .quant_type import *
 from .quantize_jit import *
 from .quantize_fx import *
+from .quantization_mappings import *
+from .fuser_method_mappings import *

 def default_eval_fn(model, calib_data):
     r"""
@@ -28,6 +30,17 @@ _all__ = [
     'fuse_fx', 'quantize_fx',  # TODO: add quantize_dynamic_fx
     'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',
     'QuantType',  # quantization type
+    # custom module APIs
+    'register_static_quant_module_mapping',
+    'get_static_quant_module_mappings', 'get_static_quant_module_class',
+    'register_dynamic_quant_module_mapping',
+    'get_dynamic_quant_module_mappings',
+    'register_qat_module_mapping',
+    'get_qat_module_mappings',
+    'get_qconfig_propagation_list',
+    'get_compare_output_module_list',
+    'register_quantized_operator_mapping', 'get_quantized_operator',
+    'register_fuser_method', 'get_fuser_method',
     # Sub functions for `prepare` and `swap_module`
     'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module',
     'default_eval_fn', 'get_observer_dict',

View File

@@ -6,25 +6,10 @@ import torch.nn.quantized as nnq
 import torch.nn.quantized.dynamic as nnqd
 from torch.quantization import prepare
-from .default_mappings import (
-    _EXCLUDE_QCONFIG_PROPAGATE_LIST,
-    _INCLUDE_QCONFIG_PROPAGATE_LIST,
-    DEFAULT_DYNAMIC_MODULE_MAPPING,
-    DEFAULT_MODULE_MAPPING,
-    DEFAULT_QAT_MODULE_MAPPING,
-)
-
-DEFAULT_NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_ALLOWED_LIST = (
-    set(DEFAULT_MODULE_MAPPING.values())
-    | set(DEFAULT_QAT_MODULE_MAPPING.values())
-    | set(DEFAULT_DYNAMIC_MODULE_MAPPING.values())
-    | set(DEFAULT_MODULE_MAPPING.keys())
-    | set(DEFAULT_QAT_MODULE_MAPPING.keys())
-    | set(DEFAULT_DYNAMIC_MODULE_MAPPING.keys())
-    | _INCLUDE_QCONFIG_PROPAGATE_LIST
-) - _EXCLUDE_QCONFIG_PROPAGATE_LIST
+from .quantization_mappings import (
+    get_compare_output_module_list,
+)

 NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
     nnqd.Linear,
     nnq.Linear,
@@ -409,7 +394,7 @@ def prepare_model_outputs(
     float_module,
     q_module,
     Logger=OutputLogger,
-    allow_list=DEFAULT_NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_ALLOWED_LIST,
+    allow_list=None
 ):
     r"""Prepare the model by attaching the logger to both float module
     and quantized module if they are in the allow_list.
@@ -420,6 +405,9 @@ def prepare_model_outputs(
         Logger: type of logger to be attached to float_module and q_module
         allow_list: list of module types to attach logger
     """
+    if allow_list is None:
+        allow_list = get_compare_output_module_list()
+
     qconfig_debug = torch.quantization.QConfig(activation=Logger, weight=None)
     float_module.qconfig = qconfig_debug
     prepare(float_module, inplace=True, allow_list=allow_list)
@@ -437,7 +425,7 @@ def compare_model_outputs(
     q_model,
     *data,
     Logger=OutputLogger,
-    allow_list=DEFAULT_NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_ALLOWED_LIST,
+    allow_list=None
 ):
     r"""Compare output activations between float and quantized models at
     corresponding locations for the same input. Return a dict with key corresponding
@@ -463,6 +451,8 @@ def compare_model_outputs(
         and each entry being a dictionary with two keys 'float' and 'quantized',
         containing the matching float and quantized activations
     """
+    if allow_list is None:
+        allow_list = get_compare_output_module_list()
     prepare_model_outputs(float_model, q_model, Logger, allow_list)
     float_model(*data)
     q_model(*data)
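
Since the allow list is now resolved at call time, a caller can also start from the default set and extend it (a sketch; `float_model`, `q_model`, `data`, and `MyFloatModule` are hypothetical stand-ins):

from torch.quantization import get_compare_output_module_list
from torch.quantization._numeric_suite import compare_model_outputs

# extend the default compare list with a custom module type
allow_list = set(get_compare_output_module_list()) | {MyFloatModule}
act_compare_dict = compare_model_outputs(
    float_model, q_model, data, allow_list=allow_list)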

View File

@@ -2,95 +2,14 @@ from __future__ import absolute_import, division, print_function, unicode_literals

 import copy
-import torch.nn.intrinsic.modules.fused as torch_fused
 import torch.nn as nn
-import torch.nn.intrinsic as nni

-from typing import Type, List, Optional, Union, Callable, Tuple, Dict
+from .fuser_method_mappings import get_fuser_method
+# for backward compatibility
+from .fuser_method_mappings import fuse_conv_bn  # noqa: F401
+from .fuser_method_mappings import fuse_conv_bn_relu  # noqa: F401

-def fuse_conv_bn(conv, bn):
-    r"""Given the conv and bn modules, fuses them and returns the fused module
-
-    Args:
-        conv: Module instance of type conv2d/conv3d
-        bn: Spatial BN instance that needs to be fused with the conv
-
-    Examples::
-
-        >>> m1 = nn.Conv2d(10, 20, 3)
-        >>> b1 = nn.BatchNorm2d(20)
-        >>> m2 = fuse_conv_bn(m1, b1)
-    """
-    assert(conv.training == bn.training),\
-        "Conv and BN both must be in the same mode (train or eval)."
-
-    is_3d = isinstance(conv, nn.Conv3d)
-
-    if conv.training:
-        assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
-        assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
-        assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
-        return nni.ConvBn3d(conv, bn) if is_3d \
-            else nni.ConvBn2d(conv, bn)
-    else:
-        return nn.utils.fuse_conv_bn_eval(conv, bn)
-
-def fuse_conv_bn_relu(conv, bn, relu):
-    r"""Given the conv and bn modules, fuses them and returns the fused module
-
-    Args:
-        conv: Module instance of type conv2d/conv3d
-        bn: Spatial BN instance that needs to be fused with the conv
-
-    Examples::
-
-        >>> m1 = nn.Conv2d(10, 20, 3)
-        >>> b1 = nn.BatchNorm2d(20)
-        >>> m2 = fuse_conv_bn(m1, b1)
-    """
-    assert(conv.training == bn.training == relu.training),\
-        "Conv and BN both must be in the same mode (train or eval)."
-    fused_module : Optional[Type[nn.Sequential]] = None
-    if conv.training:
-        map_to_fused_module_train = {
-            nn.Conv2d: torch_fused.ConvBnReLU2d,
-            nn.Conv3d: torch_fused.ConvBnReLU3d,
-        }
-        assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
-        assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
-        assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
-        fused_module = map_to_fused_module_train.get(type(conv))
-        if fused_module is not None:
-            return fused_module(conv, bn, relu)
-        else:
-            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
-    else:
-        map_to_fused_module_eval = {
-            nn.Conv1d: torch_fused.ConvReLU1d,
-            nn.Conv2d: torch_fused.ConvReLU2d,
-            nn.Conv3d: torch_fused.ConvReLU3d,
-        }
-        fused_module = map_to_fused_module_eval[type(conv)]
-        if fused_module is not None:
-            fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
-            return fused_module(fused_conv, relu)
-        else:
-            raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
-
-OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
-    (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
-    (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
-    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
-    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
-    (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
-    (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
-    (nn.Conv1d, nn.ReLU): nni.ConvReLU1d,
-    (nn.Conv2d, nn.ReLU): nni.ConvReLU2d,
-    (nn.Conv3d, nn.ReLU): nni.ConvReLU3d,
-    (nn.Linear, nn.ReLU): nni.LinearReLU,
-    (nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d,
-    (nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
-}
+from typing import List, Optional

 # Generalization of getattr
 def _get_module(model, submodule_key):
@@ -123,7 +42,7 @@ def fuse_known_modules(mod_list):
     the fused operation. The rest of the elements are set to nn.Identity()
     """
     types = tuple(type(m) for m in mod_list)
-    fuser_method = OP_LIST_TO_FUSER_METHOD.get(types)
+    fuser_method = get_fuser_method(types)
     if fuser_method is None:
         raise NotImplementedError("Cannot fuse modules: {}".format(types))
     new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)

View File

@@ -0,0 +1,101 @@
+import torch.nn as nn
+import torch.nn.intrinsic as nni
+
+from typing import Union, Callable, Tuple, Dict, Optional, Type
+
+def fuse_conv_bn(conv, bn):
+    r"""Given the conv and bn modules, fuses them and returns the fused module
+
+    Args:
+        conv: Module instance of type conv2d/conv3d
+        bn: Spatial BN instance that needs to be fused with the conv
+
+    Examples::
+
+        >>> m1 = nn.Conv2d(10, 20, 3)
+        >>> b1 = nn.BatchNorm2d(20)
+        >>> m2 = fuse_conv_bn(m1, b1)
+    """
+    assert(conv.training == bn.training),\
+        "Conv and BN both must be in the same mode (train or eval)."
+
+    is_3d = isinstance(conv, nn.Conv3d)
+
+    if conv.training:
+        assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
+        assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
+        assert bn.track_running_stats, 'Only support fusing BatchNorm2d with track_running_stats set to True'
+        return nni.ConvBn3d(conv, bn) if is_3d \
+            else nni.ConvBn2d(conv, bn)
+    else:
+        return nn.utils.fuse_conv_bn_eval(conv, bn)
+
+def fuse_conv_bn_relu(conv, bn, relu):
+    r"""Given the conv, bn and relu modules, fuses them and returns the fused module
+
+    Args:
+        conv: Module instance of type conv2d/conv3d
+        bn: Spatial BN instance that needs to be fused with the conv
+        relu: ReLU instance that needs to be fused with the conv and bn
+
+    Examples::
+
+        >>> m1 = nn.Conv2d(10, 20, 3)
+        >>> b1 = nn.BatchNorm2d(20)
+        >>> r1 = nn.ReLU(inplace=False)
+        >>> m2 = fuse_conv_bn_relu(m1, b1, r1)
+    """
+    assert(conv.training == bn.training == relu.training),\
+        "Conv, BN and ReLU must all be in the same mode (train or eval)."
+    fused_module : Optional[Type[nn.Sequential]] = None
+    if conv.training:
+        map_to_fused_module_train = {
+            nn.Conv2d: nni.ConvBnReLU2d,
+            nn.Conv3d: nni.ConvBnReLU3d,
+        }
+        assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
+        assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
+        assert bn.track_running_stats, 'Only support fusing BatchNorm with track_running_stats set to True'
+        fused_module = map_to_fused_module_train.get(type(conv))
+        if fused_module is not None:
+            return fused_module(conv, bn, relu)
+        else:
+            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
+    else:
+        map_to_fused_module_eval = {
+            nn.Conv1d: nni.ConvReLU1d,
+            nn.Conv2d: nni.ConvReLU2d,
+            nn.Conv3d: nni.ConvReLU3d,
+        }
+        # use .get so that an unsupported conv type falls through to the
+        # NotImplementedError below instead of raising KeyError
+        fused_module = map_to_fused_module_eval.get(type(conv))
+        if fused_module is not None:
+            fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
+            return fused_module(fused_conv, relu)
+        else:
+            raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
+
+OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
+    (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
+    (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
+    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
+    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
+    (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
+    (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
+    (nn.Conv1d, nn.ReLU): nni.ConvReLU1d,
+    (nn.Conv2d, nn.ReLU): nni.ConvReLU2d,
+    (nn.Conv3d, nn.ReLU): nni.ConvReLU3d,
+    (nn.Linear, nn.ReLU): nni.LinearReLU,
+    (nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d,
+    (nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
+}
+
+def register_fuser_method(op_list, fuser_method):
+    ''' Register a fuser method for a tuple of ops, will be called
+    during fusion step
+    '''
+    assert isinstance(op_list, tuple), 'op list must be a tuple'
+    OP_LIST_TO_FUSER_METHOD[op_list] = fuser_method
+
+def get_fuser_method(op_list):
+    ''' Get fuser method for the given list of module types,
+    return None if fuser method does not exist
+    '''
+    return OP_LIST_TO_FUSER_METHOD.get(op_list, None)
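
A sketch of how these two entry points compose; the `(nn.Linear, nn.BatchNorm1d)` pattern and the `fuse_linear_bn` helper below are illustrative assumptions (eval mode, affine BatchNorm1d with tracked stats), not part of this diff:

import torch
import torch.nn as nn
from torch.quantization import fuse_modules, register_fuser_method

def fuse_linear_bn(linear, bn):
    # hypothetical eval-mode fusion: fold BatchNorm1d statistics into Linear
    assert not (linear.training or bn.training), 'sketch covers eval mode only'
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    bias = linear.bias if linear.bias is not None else torch.zeros(linear.out_features)
    fused = nn.Linear(linear.in_features, linear.out_features)
    fused.weight = nn.Parameter(linear.weight * scale.reshape(-1, 1))
    fused.bias = nn.Parameter((bias - bn.running_mean) * scale + bn.bias)
    return fused

register_fuser_method((nn.Linear, nn.BatchNorm1d), fuse_linear_bn)

# fuse_known_modules (and hence fuse_modules) now resolves the new pattern
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4)).eval()
fused = fuse_modules(model, [['0', '1']])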

View File

@@ -3,7 +3,7 @@ from .pattern_utils import (
     register_fusion_pattern,
 )
 from .utils import _parent_name
-from ..fuse_modules import OP_LIST_TO_FUSER_METHOD
+from ..fuser_method_mappings import get_fuser_method

 # ---------------------
 # Fusion Patterns
@@ -60,7 +60,7 @@ class ConvBNReLUFusion():
         op_list.reverse()
         op_type_list = tuple(type(m) for m in op_list)
         conv_parent_name, conv_name = _parent_name(self.conv_node.target)
-        fuser_method = OP_LIST_TO_FUSER_METHOD.get(op_type_list, None)
+        fuser_method = get_fuser_method(op_type_list)
         if fuser_method is None:
             raise NotImplementedError("Cannot fuse modules: {}".format(op_type_list))
         setattr(quantizer.modules[conv_parent_name], conv_name, fuser_method(*op_list))
@@ -104,8 +104,6 @@ class ModuleReLUFusion():
         op_list.reverse()
         op_type_list = tuple(type(m) for m in op_list)
         module_parent_name, module_name = _parent_name(self.module_node.target)
-        fuser_method = OP_LIST_TO_FUSER_METHOD.get(op_type_list, None)
-        if fuser_method is None:
-            raise NotImplementedError("Cannot fuse modules: {}".format(types))
+        fuser_method = get_fuser_method(op_type_list)
         setattr(quantizer.modules[module_parent_name], module_name, fuser_method(*op_list))
         return quantizer.fused_graph.node_copy(self.module_node, load_arg)

View File

@@ -1,11 +1,11 @@
 import torch
-from torch.quantization.default_mappings import (
-    DEFAULT_MODULE_MAPPING,
-    DEFAULT_OPERATOR_MAPPING,
-)
 from torch.fx.graph import (
     Node,
 )
+from ..quantization_mappings import (
+    get_static_quant_module_class,
+    get_quantized_operator,
+)
 from .pattern_utils import (
     register_quant_pattern,
     register_dynamic_quant_pattern,
@@ -181,10 +181,7 @@ class ConvRelu(QuantizeHandler):
         else:
             self.conv.activation_post_process = quantizer.activation_post_process_map[node.name]
         # 2. select quantized class
-        # TODO: make the mapping configurable?
-        assert type(self.conv) in DEFAULT_MODULE_MAPPING, \
-            'unhandled conv type:{}'.format(type(self.conv))
-        qconv_cls = DEFAULT_MODULE_MAPPING[type(self.conv)]
+        qconv_cls = get_static_quant_module_class(type(self.conv))
         quantized = qconv_cls.from_float(self.conv)
         parent_name, name = _parent_name(self.conv_node.target)
         setattr(quantizer.modules[parent_name], name, quantized)
@@ -335,7 +332,7 @@ class BatchNorm(QuantizeHandler):
             self.bn[1].activation_post_process = activation_post_process
         else:
             self.bn.activation_post_process = activation_post_process
-        qbn_cls = DEFAULT_MODULE_MAPPING[type(self.bn)]
+        qbn_cls = get_static_quant_module_class(type(self.bn))
         quantized = qbn_cls.from_float(self.bn)
         parent_name, name = _parent_name(self.bn_node.target)
         setattr(quantizer.modules[parent_name], name, quantized)
@@ -371,7 +368,8 @@ class DefaultNode(QuantizeHandler):
         if node.op == 'call_module':
             module = quantizer.modules[node.target]
             module.activation_post_process = activation_post_process
-            quantized_module = DEFAULT_MODULE_MAPPING[type(module)].from_float(module)
+            quantized_module_cls = get_static_quant_module_class(type(module))
+            quantized_module = quantized_module_cls.from_float(module)
             parent_name, name = _parent_name(node.target)
             setattr(quantizer.modules[parent_name], name, quantized_module)
             return quantizer.quantized_graph.create_node(
@@ -385,7 +383,7 @@ class DefaultNode(QuantizeHandler):
             scale = float(scale)
             zero_point = int(zero_point)
-            quantized_op = DEFAULT_OPERATOR_MAPPING[node.target]
+            quantized_op = get_quantized_operator(node.target)
             args = load_arg(quantized=[0])(node.args)
             kwargs = load_arg(quantized=False)(node.kwargs)
             kwargs.update({'output_scale': scale, 'output_zero_point': zero_point})
@@ -405,7 +403,7 @@ class ELU(QuantizeHandler):
         scale, zero_point = activation_post_process.calculate_qparams()
         scale = float(scale)
         zero_point = int(zero_point)
-        quantized_op = DEFAULT_OPERATOR_MAPPING[node.target]
+        quantized_op = get_quantized_operator(node.target)
         args = load_arg(quantized=[0])(node.args)
         kwargs = load_arg(quantized=False)(node.kwargs)
         kwargs.update({'output_scale': scale, 'output_zero_point': zero_point})

View File

@@ -4,8 +4,8 @@ from torch.quantization import (
     convert,
 )
-from torch.quantization.default_mappings import (
-    DEFAULT_QAT_MODULE_MAPPING,
-)
+from ..quantization_mappings import (
+    get_qat_module_mappings,
+)
 from torch.fx import (
@@ -163,7 +163,7 @@ class Quantizer:
     def _qat_swap_modules(self, root):
-        convert(root, mapping=DEFAULT_QAT_MODULE_MAPPING, inplace=True, remove_qconfig=False)
+        convert(root, mapping=get_qat_module_mappings(), inplace=True, remove_qconfig=False)

     def _generate_qconfig_map(self, root, input_graph):
         def get_qconfig(module):

View File

@@ -0,0 +1,187 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torch.nn.intrinsic as nni
+import torch.nn.intrinsic.quantized as nniq
+import torch.nn.intrinsic.qat as nniqat
+import torch.nn.quantized as nnq
+import torch.nn.quantized.dynamic as nnqd
+import torch.nn.qat as nnqat
+
+from .stubs import QuantStub, DeQuantStub
+
+# Map for swapping float module to quantized ones
+STATIC_QUANT_MODULE_MAPPINGS = {
+    nn.Linear: nnq.Linear,
+    nn.ReLU: nnq.ReLU,
+    nn.ReLU6: nnq.ReLU6,
+    nn.Hardswish: nnq.Hardswish,
+    nn.ELU: nnq.ELU,
+    nn.Conv1d: nnq.Conv1d,
+    nn.Conv2d: nnq.Conv2d,
+    nn.Conv3d: nnq.Conv3d,
+    nn.BatchNorm2d: nnq.BatchNorm2d,
+    nn.BatchNorm3d: nnq.BatchNorm3d,
+    nn.LayerNorm: nnq.LayerNorm,
+    nn.GroupNorm: nnq.GroupNorm,
+    nn.InstanceNorm1d: nnq.InstanceNorm1d,
+    nn.InstanceNorm2d: nnq.InstanceNorm2d,
+    nn.InstanceNorm3d: nnq.InstanceNorm3d,
+    nn.Embedding: nnq.Embedding,
+    nn.EmbeddingBag: nnq.EmbeddingBag,
+    QuantStub: nnq.Quantize,
+    DeQuantStub: nnq.DeQuantize,
+    # Wrapper Modules:
+    nnq.FloatFunctional: nnq.QFunctional,
+    # Intrinsic modules:
+    nni.ConvReLU1d: nniq.ConvReLU1d,
+    nni.ConvReLU2d: nniq.ConvReLU2d,
+    nni.ConvReLU3d: nniq.ConvReLU3d,
+    nni.LinearReLU: nniq.LinearReLU,
+    nni.BNReLU2d: nniq.BNReLU2d,
+    nni.BNReLU3d: nniq.BNReLU3d,
+    nniqat.ConvReLU2d: nniq.ConvReLU2d,
+    nniqat.LinearReLU: nniq.LinearReLU,
+    nniqat.ConvBn2d: nnq.Conv2d,
+    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
+    # QAT modules:
+    nnqat.Linear: nnq.Linear,
+    nnqat.Conv2d: nnq.Conv2d,
+}
+
+# Map for swapping float module to qat modules
+QAT_MODULE_MAPPINGS = {
+    nn.Linear: nnqat.Linear,
+    nn.Conv2d: nnqat.Conv2d,
+    # Intrinsic modules:
+    nni.ConvBn2d: nniqat.ConvBn2d,
+    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
+    nni.ConvReLU2d: nniqat.ConvReLU2d,
+    nni.LinearReLU: nniqat.LinearReLU
+}
+
+# Map for swapping dynamic modules
+DYNAMIC_QUANT_MODULE_MAPPINGS = {
+    nn.Linear: nnqd.Linear,
+    nn.LSTM: nnqd.LSTM,
+    nn.LSTMCell: nnqd.LSTMCell,
+    nn.RNNCell: nnqd.RNNCell,
+    nn.GRUCell: nnqd.GRUCell,
+}
+
+# Module classes excluded from qconfig propagation
+_EXCLUDE_QCONFIG_PROPAGATE_LIST = {
+    DeQuantStub,
+}
+# Module classes included for qconfig propagation
+_INCLUDE_QCONFIG_PROPAGATE_LIST = {
+    nn.Sequential,
+}
+
+# mapping from floating point function or torch ops to quantized ops
+FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = {
+    F.elu: torch._ops.ops.quantized.elu,
+    F.hardswish: torch._ops.ops.quantized.hardswish,
+    F.instance_norm: torch._ops.ops.quantized.instance_norm,
+    F.layer_norm: torch._ops.ops.quantized.layer_norm,
+}
+
+def register_static_quant_module_mapping(
+        float_source_module_class, static_quant_target_module_class):
+    ''' Register a mapping from `float_source_module_class` to `static_quant_target_module_class`
+    `static_quant_target_module_class` must have from_float defined as a class method
+    The mapping is used in the convert step of post training static quantization to
+    convert a float module to a statically quantized module.
+    '''
+    assert hasattr(static_quant_target_module_class, 'from_float'), 'from_float must be defined' + \
+        ' in quantized module class'
+    STATIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = static_quant_target_module_class
+
+def get_static_quant_module_mappings():
+    ''' Get module mapping for post training static quantization
+    '''
+    return STATIC_QUANT_MODULE_MAPPINGS
+
+def get_static_quant_module_class(float_module_class):
+    ''' Get the statically quantized module class corresponding to
+    the floating point module class
+    '''
+    static_quant_module_class = STATIC_QUANT_MODULE_MAPPINGS.get(float_module_class, None)
+    assert static_quant_module_class is not None, \
+        'Floating point module class {}'.format(float_module_class) + \
+        ' does not have a corresponding quantized module class'
+    return static_quant_module_class
+
+def register_qat_module_mapping(float_source_module_class, qat_target_module_class):
+    ''' Register a mapping from `float_source_module_class` to `qat_target_module_class`,
+    `qat_target_module_class` must have from_float defined as a class method
+    This mapping is used in the prepare step of quantization aware training to swap
+    a float module to a qat module.
+    '''
+    assert hasattr(qat_target_module_class, 'from_float'), 'from_float must be defined' + \
+        ' in qat module class'
+    QAT_MODULE_MAPPINGS[float_source_module_class] = qat_target_module_class
+
+def get_qat_module_mappings():
+    ''' Get module mapping for quantization aware training
+    '''
+    return QAT_MODULE_MAPPINGS
+
+def register_dynamic_quant_module_mapping(float_source_module_class, dynamic_quant_target_module_class):
+    ''' Register a mapping from `float_source_module_class` to `dynamic_quant_target_module_class`,
+    `dynamic_quant_target_module_class` must have from_float defined as a class method
+    This mapping is used in the convert step of post training dynamic
+    quantization to swap a float module to a dynamically quantized
+    module.
+    '''
+    assert hasattr(dynamic_quant_target_module_class, 'from_float'), 'from_float must be defined' + \
+        ' in dynamically quantized module type'
+    DYNAMIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = dynamic_quant_target_module_class
+
+def get_dynamic_quant_module_mappings():
+    ''' Get module mapping for post training dynamic quantization
+    '''
+    return DYNAMIC_QUANT_MODULE_MAPPINGS
+
+def get_qconfig_propagation_list():
+    ''' Get the list of module types that we'll attach qconfig
+    attribute to in prepare
+    '''
+    QCONFIG_PROPAGATE_MODULE_CLASS_LIST = (
+        (set(STATIC_QUANT_MODULE_MAPPINGS.keys()) |
+         set(QAT_MODULE_MAPPINGS.keys()) |
+         set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) |
+         _INCLUDE_QCONFIG_PROPAGATE_LIST) -
+        _EXCLUDE_QCONFIG_PROPAGATE_LIST
+    )
+    return QCONFIG_PROPAGATE_MODULE_CLASS_LIST
+
+def get_compare_output_module_list():
+    ''' Get list of module class types that we will record output
+    in numeric suite
+    '''
+    NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = (
+        set(STATIC_QUANT_MODULE_MAPPINGS.values())
+        | set(QAT_MODULE_MAPPINGS.values())
+        | set(DYNAMIC_QUANT_MODULE_MAPPINGS.values())
+        | set(STATIC_QUANT_MODULE_MAPPINGS.keys())
+        | set(QAT_MODULE_MAPPINGS.keys())
+        | set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
+        | _INCLUDE_QCONFIG_PROPAGATE_LIST
+    ) - _EXCLUDE_QCONFIG_PROPAGATE_LIST
+    return NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST
+
+def register_quantized_operator_mapping(float_op, quantized_op):
+    ''' Register a mapping from `float_op` (torch or functional) to `quantized_op`
+    This is used in the convert step of fx based graph mode quantization
+    to convert a float op to a quantized op.
+    '''
+    FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS[float_op] = quantized_op
+
+def get_quantized_operator(float_op):
+    ''' Get the quantized operator corresponding to the float operator
+    '''
+    quantized_op = FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None)
+    assert quantized_op is not None, \
+        'Operator {} does not have corresponding quantized op'.format(float_op)
+    return quantized_op
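
The operator-level hook works the same way (sketch; `my_float_op` and `my_quantized_op` are hypothetical, and the `output_scale`/`output_zero_point` keyword convention follows the FX `DefaultNode` handler above):

from torch.quantization import (
    register_quantized_operator_mapping,
    get_quantized_operator,
)

def my_float_op(x):
    return x

def my_quantized_op(x, output_scale, output_zero_point):
    # FX graph mode convert passes output qparams as keyword arguments
    return x

register_quantized_operator_mapping(my_float_op, my_quantized_op)
assert get_quantized_operator(my_float_op) is my_quantized_op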

View File

@@ -10,10 +10,11 @@ import torch.nn.intrinsic as nni
 import torch.nn.quantized as nnq
 import torch.nn.intrinsic.qat as nniqat
-from .default_mappings import (DEFAULT_DYNAMIC_MODULE_MAPPING,
-                               DEFAULT_MODULE_MAPPING,
-                               DEFAULT_QAT_MODULE_MAPPING,
-                               DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST)
+from .quantization_mappings import (get_dynamic_quant_module_mappings,
+                                    get_static_quant_module_mappings,
+                                    get_qat_module_mappings,
+                                    get_qconfig_propagation_list)
 from .stubs import DeQuantStub, QuantWrapper
 from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_dynamic_qconfig
@@ -37,7 +38,7 @@ def _propagate_qconfig_helper(module, qconfig_dict, allow_list=None,
     """
     # TODO: Add test
     if allow_list is None:
-        allow_list = DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST
+        allow_list = get_qconfig_propagation_list()

     module_qconfig = qconfig_dict.get(type(module), qconfig_parent)
     module_qconfig = qconfig_dict.get(prefix, module_qconfig)
@@ -100,7 +101,8 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None,
         None, module is modified inplace with added observer modules and forward_hooks
     """
     if qconfig_propagation_list is None:
-        qconfig_propagation_list = DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST
+        qconfig_propagation_list = get_qconfig_propagation_list()
+
     # respect device affinity when adding observers
     if device is None:
         devices = get_unique_devices_(module)
@@ -194,9 +196,10 @@ def prepare(model, inplace=False, allow_list=None,
     """
     if not inplace:
         model = copy.deepcopy(model)
-    propagate_qconfig_list = allow_list
-    if propagate_qconfig_list is None:
-        propagate_qconfig_list = DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST
+
+    qconfig_propagation_list = allow_list
+    if qconfig_propagation_list is None:
+        qconfig_propagation_list = get_qconfig_propagation_list()
     propagate_qconfig_(model, qconfig_dict=None)

     # sanity check common API misusage
@@ -205,7 +208,7 @@ def prepare(model, inplace=False, allow_list=None,
             "passed correct configuration through `qconfig_dict` or "
             "by assigning the `.qconfig` attribute directly on submodules")

-    add_observer_(model, propagate_qconfig_list, observer_non_leaf_module_list, prehook=prehook)
+    add_observer_(model, qconfig_propagation_list, observer_non_leaf_module_list, prehook=prehook)
     return model

 def _remove_qconfig(module):
@@ -239,7 +242,7 @@ def quantize(model, run_fn, run_args, mapping=None, inplace=False):
         Quantized model.
     """
     if mapping is None:
-        mapping = DEFAULT_MODULE_MAPPING
+        mapping = get_static_quant_module_mappings()
     if not inplace:
         model = copy.deepcopy(model)
     model.eval()
@@ -316,7 +319,7 @@ def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
         qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))

     if mapping is None:
-        mapping = DEFAULT_DYNAMIC_MODULE_MAPPING
+        mapping = get_dynamic_quant_module_mappings()

     if not inplace:
         model = copy.deepcopy(model)
@@ -341,7 +344,7 @@ def prepare_qat(model, mapping=None, inplace=False):
         is mutated
     """
     if mapping is None:
-        mapping = DEFAULT_QAT_MODULE_MAPPING
+        mapping = get_qat_module_mappings()

     if not inplace:
         model = copy.deepcopy(model)
@@ -406,7 +409,7 @@ def _convert(module, mapping=None, inplace=False):
     """
     if mapping is None:
-        mapping = DEFAULT_MODULE_MAPPING
+        mapping = get_static_quant_module_mappings()
     if not inplace:
         module = copy.deepcopy(module)
     reassign = {}
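
Net effect: the eager mode entry points read their defaults at call time, so a mapping registered before `convert`/`quantize_dynamic` runs is picked up automatically. A minimal sketch:

import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic

model = nn.Sequential(nn.Linear(4, 4))
# mapping=None now resolves to get_dynamic_quant_module_mappings(),
# including any mappings registered up to this point
qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)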

View File

@@ -17,10 +17,10 @@ from torch.testing._internal.common_utils import TestCase
 from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \
     default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \
     propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_dynamic_qconfig
-from torch.quantization.default_mappings import (
-    DEFAULT_DYNAMIC_MODULE_MAPPING,
-    DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST,
-    DEFAULT_QAT_MODULE_MAPPING,
+from torch.quantization.quantization_mappings import (
+    get_dynamic_quant_module_mappings,
+    get_qconfig_propagation_list,
+    get_qat_module_mappings,
 )
 # symbolic trace
 from torch.fx import symbolic_trace
@@ -191,7 +191,7 @@ def run_ddp(rank, world_size, prepared):

 def convert_dynamic(module):
-    convert(module, DEFAULT_DYNAMIC_MODULE_MAPPING, inplace=True)
+    convert(module, get_dynamic_quant_module_mappings(), inplace=True)

 def prepare_dynamic(model, qconfig_dict=None):
     propagate_qconfig_(model, qconfig_dict)
@@ -347,7 +347,7 @@ class QuantizationTestCase(TestCase):
            have observers in preparation for quantization
         """
         if propagate_qconfig_list is None:
-            propagate_qconfig_list = DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST
+            propagate_qconfig_list = get_qconfig_propagation_list()
         if hasattr(module, 'qconfig') and module.qconfig is not None and \
                 len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential) \
                 and type(module) in propagate_qconfig_list:
@@ -355,7 +355,7 @@ class QuantizationTestCase(TestCase):
                             'module: ' + str(type(module)) + ' do not have observer')
         # we don't need to check observers for child modules of the
         # qat modules
-        if type(module) not in DEFAULT_QAT_MODULE_MAPPING.values():
+        if type(module) not in get_qat_module_mappings().values():
             for child in module.children():
                 self.checkObservers(child)