[quant][eagermode][refactor] Add set/get method for quantization and fusion mappings (#43990)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43990

Allow users to register custom quantization and fusion mappings. The previously hard-coded default tables (float module to quantized/QAT/dynamic module, float op to quantized op, and op tuple to fuser method) now sit behind set/get functions so they can be extended.
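
For example (a minimal sketch, not part of this diff; `MyFloatMod` and `MyQuantMod` are hypothetical placeholders, and the target class must define a `from_float` classmethod, which registration asserts):

import torch.nn as nn
import torch.quantization as tq

class MyFloatMod(nn.Module):
    def forward(self, x):
        return x

class MyQuantMod(nn.Module):
    # hypothetical quantized counterpart; from_float is required by the registration API
    @classmethod
    def from_float(cls, float_mod):
        return cls()
    def forward(self, x):
        return x

# the static quant mapping used by convert() now includes MyFloatMod -> MyQuantMod
tq.register_static_quant_module_mapping(MyFloatMod, MyQuantMod)
assert tq.get_static_quant_module_class(MyFloatMod) is MyQuantMod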

Test Plan: Imported from OSS

Reviewed By: z-a-f

Differential Revision: D23485344

fbshipit-source-id: 4f0174ee6d8000d83de0f73cb370e9a1941d54aa
Jerry Zhang 2020-09-10 21:13:49 -07:00 committed by Facebook GitHub Bot
parent f7278473d3
commit 0c58a017bd
11 changed files with 357 additions and 150 deletions

View File

@@ -56,9 +56,6 @@ ignore_errors = True
[mypy-torch.testing._internal.*]
ignore_errors = True
[mypy-torch.quantization.default_mappings]
ignore_errors = True
[mypy-torch.quantization.observer]
ignore_errors = True
@@ -74,6 +71,7 @@ ignore_errors = True
[mypy-torch.quantization._numeric_suite]
ignore_errors = True
[mypy-torch.quantization.quantize_fx]
ignore_errors = True

View File

@@ -8,6 +8,8 @@ from .stubs import *
from .quant_type import *
from .quantize_jit import *
from .quantize_fx import *
from .quantization_mappings import *
from .fuser_method_mappings import *
def default_eval_fn(model, calib_data):
r"""
@@ -28,6 +30,17 @@ _all__ = [
'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx
'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',
'QuantType', # quantization type
# custom module APIs
'register_static_quant_module_mapping',
'get_static_quant_module_mappings', 'get_static_quant_module_class',
'register_dynamic_quant_module_mapping',
'get_dynamic_quant_module_mappings',
'register_qat_module_mapping',
'get_qat_module_mappings',
'get_qconfig_propagation_list',
'get_compare_output_module_list',
'register_quantized_operator_mapping', 'get_quantized_operator',
'register_fuser_method', 'get_fuser_method',
# Sub functions for `prepare` and `swap_module`
'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module',
'default_eval_fn', 'get_observer_dict',

View File

@@ -6,25 +6,10 @@ import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
from torch.quantization import prepare
from .default_mappings import (
_EXCLUDE_QCONFIG_PROPAGATE_LIST,
_INCLUDE_QCONFIG_PROPAGATE_LIST,
DEFAULT_DYNAMIC_MODULE_MAPPING,
DEFAULT_MODULE_MAPPING,
DEFAULT_QAT_MODULE_MAPPING,
from .quantization_mappings import (
get_compare_output_module_list,
)
DEFAULT_NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_ALLOWED_LIST = (
set(DEFAULT_MODULE_MAPPING.values())
| set(DEFAULT_QAT_MODULE_MAPPING.values())
| set(DEFAULT_DYNAMIC_MODULE_MAPPING.values())
| set(DEFAULT_MODULE_MAPPING.keys())
| set(DEFAULT_QAT_MODULE_MAPPING.keys())
| set(DEFAULT_DYNAMIC_MODULE_MAPPING.keys())
| _INCLUDE_QCONFIG_PROPAGATE_LIST
) - _EXCLUDE_QCONFIG_PROPAGATE_LIST
NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
nnqd.Linear,
nnq.Linear,
@@ -409,7 +394,7 @@ def prepare_model_outputs(
float_module,
q_module,
Logger=OutputLogger,
allow_list=DEFAULT_NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_ALLOWED_LIST,
allow_list=None
):
r"""Prepare the model by attaching the logger to both float module
and quantized module if they are in the allow_list.
@@ -420,6 +405,9 @@ def prepare_model_outputs(
Logger: type of logger to be attached to float_module and q_module
allow_list: list of module types to attach logger
"""
if allow_list is None:
allow_list = get_compare_output_module_list()
qconfig_debug = torch.quantization.QConfig(activation=Logger, weight=None)
float_module.qconfig = qconfig_debug
prepare(float_module, inplace=True, allow_list=allow_list)
@@ -437,7 +425,7 @@ def compare_model_outputs(
q_model,
*data,
Logger=OutputLogger,
allow_list=DEFAULT_NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_ALLOWED_LIST,
allow_list=None
):
r"""Compare output activations between float and quantized models at
corresponding locations for the same input. Return a dict with key corresponding
@@ -463,6 +451,8 @@
and each entry being a dictionary with two keys 'float' and 'quantized',
containing the matching float and quantized activations
"""
if allow_list is None:
allow_list = get_compare_output_module_list()
prepare_model_outputs(float_model, q_model, Logger, allow_list)
float_model(*data)
q_model(*data)
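
A usage sketch, not part of the diff: `float_model`, `q_model`, and `data` below are assumed placeholders for a calibrated float model, its quantized counterpart, and an input batch. With `allow_list` omitted, the module list is now resolved from the current mappings:

import torch.quantization._numeric_suite as ns

# allow_list defaults to None and is resolved via get_compare_output_module_list(),
# so custom-registered module types are compared as well
act_dict = ns.compare_model_outputs(float_model, q_model, data)  # assumed placeholders
for name, entry in act_dict.items():
    print(name, type(entry['float']), type(entry['quantized']))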

View File

@@ -2,95 +2,14 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import torch.nn.intrinsic.modules.fused as torch_fused
import torch.nn as nn
import torch.nn.intrinsic as nni
from typing import Type, List, Optional, Union, Callable, Tuple, Dict
from .fuser_method_mappings import get_fuser_method
# for backward compatibility
from .fuser_method_mappings import fuse_conv_bn # noqa: F401
from .fuser_method_mappings import fuse_conv_bn_relu # noqa: F401
def fuse_conv_bn(conv, bn):
r"""Given the conv and bn modules, fuses them and returns the fused module
Args:
conv: Module instance of type conv2d/conv3d
bn: Spatial BN instance that needs to be fused with the conv
Examples::
>>> m1 = nn.Conv2d(10, 20, 3)
>>> b1 = nn.BatchNorm2d(20)
>>> m2 = fuse_conv_bn(m1, b1)
"""
assert(conv.training == bn.training),\
"Conv and BN both must be in the same mode (train or eval)."
is_3d = isinstance(conv, nn.Conv3d)
if conv.training:
assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
return nni.ConvBn3d(conv, bn) if is_3d \
else nni.ConvBn2d(conv, bn)
else:
return nn.utils.fuse_conv_bn_eval(conv, bn)
def fuse_conv_bn_relu(conv, bn, relu):
r"""Given the conv and bn modules, fuses them and returns the fused module
Args:
conv: Module instance of type conv2d/conv3d
bn: Spatial BN instance that needs to be fused with the conv
Examples::
>>> m1 = nn.Conv2d(10, 20, 3)
>>> b1 = nn.BatchNorm2d(20)
>>> m2 = fuse_conv_bn(m1, b1)
"""
assert(conv.training == bn.training == relu.training),\
"Conv and BN both must be in the same mode (train or eval)."
fused_module : Optional[Type[nn.Sequential]] = None
if conv.training:
map_to_fused_module_train = {
nn.Conv2d: torch_fused.ConvBnReLU2d,
nn.Conv3d: torch_fused.ConvBnReLU3d,
}
assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
fused_module = map_to_fused_module_train.get(type(conv))
if fused_module is not None:
return fused_module(conv, bn, relu)
else:
raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
else:
map_to_fused_module_eval = {
nn.Conv1d: torch_fused.ConvReLU1d,
nn.Conv2d: torch_fused.ConvReLU2d,
nn.Conv3d: torch_fused.ConvReLU3d,
}
fused_module = map_to_fused_module_eval[type(conv)]
if fused_module is not None:
fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
return fused_module(fused_conv, relu)
else:
raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
(nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
(nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
(nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
(nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv1d, nn.ReLU): nni.ConvReLU1d,
(nn.Conv2d, nn.ReLU): nni.ConvReLU2d,
(nn.Conv3d, nn.ReLU): nni.ConvReLU3d,
(nn.Linear, nn.ReLU): nni.LinearReLU,
(nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d,
(nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
}
from typing import List, Optional
# Generalization of getattr
def _get_module(model, submodule_key):
@@ -123,7 +42,7 @@ def fuse_known_modules(mod_list):
the fused operation. The rest of the elements are set to nn.Identity()
"""
types = tuple(type(m) for m in mod_list)
fuser_method = OP_LIST_TO_FUSER_METHOD.get(types)
fuser_method = get_fuser_method(types)
if fuser_method is None:
raise NotImplementedError("Cannot fuse modules: {}".format(types))
new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)

View File

@@ -0,0 +1,101 @@
import torch.nn as nn
import torch.nn.intrinsic as nni
from typing import Union, Callable, Tuple, Dict, Optional, Type
def fuse_conv_bn(conv, bn):
r"""Given the conv and bn modules, fuses them and returns the fused module
Args:
conv: Module instance of type conv2d/conv3d
bn: Spatial BN instance that needs to be fused with the conv
Examples::
>>> m1 = nn.Conv2d(10, 20, 3)
>>> b1 = nn.BatchNorm2d(20)
>>> m2 = fuse_conv_bn(m1, b1)
"""
assert(conv.training == bn.training),\
"Conv and BN both must be in the same mode (train or eval)."
is_3d = isinstance(conv, nn.Conv3d)
if conv.training:
assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
return nni.ConvBn3d(conv, bn) if is_3d \
else nni.ConvBn2d(conv, bn)
else:
return nn.utils.fuse_conv_bn_eval(conv, bn)
def fuse_conv_bn_relu(conv, bn, relu):
r"""Given the conv, bn and relu modules, fuses them and returns the fused module
Args:
conv: Module instance of type conv2d/conv3d
bn: Spatial BN instance that needs to be fused with the conv
relu: ReLU instance that needs to be fused with the conv and bn
Examples::
>>> m1 = nn.Conv2d(10, 20, 3)
>>> b1 = nn.BatchNorm2d(20)
>>> r1 = nn.ReLU(inplace=False)
>>> m2 = fuse_conv_bn_relu(m1, b1, r1)
"""
assert(conv.training == bn.training == relu.training),\
"Conv, BN and ReLU must all be in the same mode (train or eval)."
fused_module : Optional[Type[nn.Sequential]] = None
if conv.training:
map_to_fused_module_train = {
nn.Conv2d: nni.ConvBnReLU2d,
nn.Conv3d: nni.ConvBnReLU3d,
}
assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
fused_module = map_to_fused_module_train.get(type(conv))
if fused_module is not None:
return fused_module(conv, bn, relu)
else:
raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
else:
map_to_fused_module_eval = {
nn.Conv1d: nni.ConvReLU1d,
nn.Conv2d: nni.ConvReLU2d,
nn.Conv3d: nni.ConvReLU3d,
}
fused_module = map_to_fused_module_eval.get(type(conv))
if fused_module is not None:
fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
return fused_module(fused_conv, relu)
else:
raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
(nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
(nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
(nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
(nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv1d, nn.ReLU): nni.ConvReLU1d,
(nn.Conv2d, nn.ReLU): nni.ConvReLU2d,
(nn.Conv3d, nn.ReLU): nni.ConvReLU3d,
(nn.Linear, nn.ReLU): nni.LinearReLU,
(nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d,
(nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
}
def register_fuser_method(op_list, fuser_method):
''' Register a fuser method for a tuple of ops; it will be called
during the fusion step
'''
assert isinstance(op_list, tuple), 'op list must be a tuple'
OP_LIST_TO_FUSER_METHOD[op_list] = fuser_method
def get_fuser_method(op_list):
''' Get the fuser method for the given tuple of module types;
returns None if no fuser method is registered
'''
return OP_LIST_TO_FUSER_METHOD.get(op_list, None)
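
Usage sketch for the new hooks (not part of the diff; `fuse_linear_dropout` is a hypothetical user-supplied fuser):

import torch.nn as nn
from torch.quantization import register_fuser_method, get_fuser_method

def fuse_linear_dropout(linear, dropout):
    # dropout is the identity in eval mode, so the "fused" module is just the linear
    assert not linear.training, 'sketch only handles eval-mode fusion'
    return linear

# fuse_modules() will now handle (Linear, Dropout) pairs through fuse_known_modules
register_fuser_method((nn.Linear, nn.Dropout), fuse_linear_dropout)
assert get_fuser_method((nn.Linear, nn.Dropout)) is fuse_linear_dropout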

View File

@@ -3,7 +3,7 @@ from .pattern_utils import (
register_fusion_pattern,
)
from .utils import _parent_name
from ..fuse_modules import OP_LIST_TO_FUSER_METHOD
from ..fuser_method_mappings import get_fuser_method
# ---------------------
# Fusion Patterns
@@ -60,7 +60,7 @@ class ConvBNReLUFusion():
op_list.reverse()
op_type_list = tuple(type(m) for m in op_list)
conv_parent_name, conv_name = _parent_name(self.conv_node.target)
fuser_method = OP_LIST_TO_FUSER_METHOD.get(op_type_list, None)
fuser_method = get_fuser_method(op_type_list)
if fuser_method is None:
raise NotImplementedError("Cannot fuse modules: {}".format(types))
setattr(quantizer.modules[conv_parent_name], conv_name, fuser_method(*op_list))
@@ -104,8 +104,6 @@ class ModuleReLUFusion():
op_list.reverse()
op_type_list = tuple(type(m) for m in op_list)
module_parent_name, module_name = _parent_name(self.module_node.target)
fuser_method = OP_LIST_TO_FUSER_METHOD.get(op_type_list, None)
if fuser_method is None:
raise NotImplementedError("Cannot fuse modules: {}".format(types))
fuser_method = get_fuser_method(op_type_list)
setattr(quantizer.modules[module_parent_name], module_name, fuser_method(*op_list))
return quantizer.fused_graph.node_copy(self.module_node, load_arg)

View File

@@ -1,11 +1,11 @@
import torch
from torch.quantization.default_mappings import (
DEFAULT_MODULE_MAPPING,
DEFAULT_OPERATOR_MAPPING,
)
from torch.fx.graph import (
Node,
)
from ..quantization_mappings import (
get_static_quant_module_class,
get_quantized_operator,
)
from .pattern_utils import (
register_quant_pattern,
register_dynamic_quant_pattern,
@@ -181,10 +181,7 @@ class ConvRelu(QuantizeHandler):
else:
self.conv.activation_post_process = quantizer.activation_post_process_map[node.name]
# 2. select quantized class
# TODO: make the mapping configurable?
assert type(self.conv) in DEFAULT_MODULE_MAPPING, \
'unhandled conv type:{}'.format(type(self.conv))
qconv_cls = DEFAULT_MODULE_MAPPING[type(self.conv)]
qconv_cls = get_static_quant_module_class(type(self.conv))
quantized = qconv_cls.from_float(self.conv)
parent_name, name = _parent_name(self.conv_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
@@ -335,7 +332,7 @@ class BatchNorm(QuantizeHandler):
self.bn[1].activation_post_process = activation_post_process
else:
self.bn.activation_post_process = activation_post_process
qbn_cls = DEFAULT_MODULE_MAPPING[type(self.bn)]
qbn_cls = get_static_quant_module_class(type(self.bn))
quantized = qbn_cls.from_float(self.bn)
parent_name, name = _parent_name(self.bn_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
@@ -371,7 +368,8 @@ class DefaultNode(QuantizeHandler):
if node.op == 'call_module':
module = quantizer.modules[node.target]
module.activation_post_process = activation_post_process
quantized_module = DEFAULT_MODULE_MAPPING[type(module)].from_float(module)
quantized_module_cls = get_static_quant_module_class(type(module))
quantized_module = quantized_module_cls.from_float(module)
parent_name, name = _parent_name(node.target)
setattr(quantizer.modules[parent_name], name, quantized_module)
return quantizer.quantized_graph.create_node(
@@ -385,7 +383,7 @@
scale = float(scale)
zero_point = int(zero_point)
quantized_op = DEFAULT_OPERATOR_MAPPING[node.target]
quantized_op = get_quantized_operator(node.target)
args = load_arg(quantized=[0])(node.args)
kwargs = load_arg(quantized=False)(node.kwargs)
kwargs.update({'output_scale': scale, 'output_zero_point': zero_point})
@@ -405,7 +403,7 @@ class ELU(QuantizeHandler):
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
quantized_op = DEFAULT_OPERATOR_MAPPING[node.target]
quantized_op = get_quantized_operator(node.target)
args = load_arg(quantized=[0])(node.args)
kwargs = load_arg(quantized=False)(node.kwargs)
kwargs.update({'output_scale': scale, 'output_zero_point': zero_point})

View File

@@ -4,8 +4,8 @@ from torch.quantization import (
convert,
)
from torch.quantization.default_mappings import (
DEFAULT_QAT_MODULE_MAPPING,
from ..quantization_mappings import (
get_qat_module_mappings,
)
from torch.fx import (
@@ -163,7 +163,7 @@ class Quantizer:
def _qat_swap_modules(self, root):
convert(root, mapping=DEFAULT_QAT_MODULE_MAPPING, inplace=True, remove_qconfig=False)
convert(root, mapping=get_qat_module_mappings(), inplace=True, remove_qconfig=False)
def _generate_qconfig_map(self, root, input_graph):
def get_qconfig(module):

View File

@@ -0,0 +1,187 @@
import torch
from torch import nn
import torch.nn.functional as F
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat
from .stubs import QuantStub, DeQuantStub
# Map for swapping float module to quantized ones
STATIC_QUANT_MODULE_MAPPINGS = {
nn.Linear: nnq.Linear,
nn.ReLU: nnq.ReLU,
nn.ReLU6: nnq.ReLU6,
nn.Hardswish: nnq.Hardswish,
nn.ELU: nnq.ELU,
nn.Conv1d: nnq.Conv1d,
nn.Conv2d: nnq.Conv2d,
nn.Conv3d: nnq.Conv3d,
nn.BatchNorm2d: nnq.BatchNorm2d,
nn.BatchNorm3d: nnq.BatchNorm3d,
nn.LayerNorm: nnq.LayerNorm,
nn.GroupNorm: nnq.GroupNorm,
nn.InstanceNorm1d: nnq.InstanceNorm1d,
nn.InstanceNorm2d: nnq.InstanceNorm2d,
nn.InstanceNorm3d: nnq.InstanceNorm3d,
nn.Embedding: nnq.Embedding,
nn.EmbeddingBag: nnq.EmbeddingBag,
QuantStub: nnq.Quantize,
DeQuantStub: nnq.DeQuantize,
# Wrapper Modules:
nnq.FloatFunctional: nnq.QFunctional,
# Intrinsic modules:
nni.ConvReLU1d: nniq.ConvReLU1d,
nni.ConvReLU2d: nniq.ConvReLU2d,
nni.ConvReLU3d: nniq.ConvReLU3d,
nni.LinearReLU: nniq.LinearReLU,
nni.BNReLU2d: nniq.BNReLU2d,
nni.BNReLU3d: nniq.BNReLU3d,
nniqat.ConvReLU2d: nniq.ConvReLU2d,
nniqat.LinearReLU: nniq.LinearReLU,
nniqat.ConvBn2d: nnq.Conv2d,
nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
# QAT modules:
nnqat.Linear: nnq.Linear,
nnqat.Conv2d: nnq.Conv2d,
}
# Map for swapping float module to qat modules
QAT_MODULE_MAPPINGS = {
nn.Linear: nnqat.Linear,
nn.Conv2d: nnqat.Conv2d,
# Intrinsic modules:
nni.ConvBn2d: nniqat.ConvBn2d,
nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
nni.ConvReLU2d: nniqat.ConvReLU2d,
nni.LinearReLU: nniqat.LinearReLU
}
# Map for swapping dynamic modules
DYNAMIC_QUANT_MODULE_MAPPINGS = {
nn.Linear: nnqd.Linear,
nn.LSTM: nnqd.LSTM,
nn.LSTMCell: nnqd.LSTMCell,
nn.RNNCell: nnqd.RNNCell,
nn.GRUCell: nnqd.GRUCell,
}
# Lists controlling which module types the qconfig is propagated to
_EXCLUDE_QCONFIG_PROPAGATE_LIST = {
DeQuantStub,
}
_INCLUDE_QCONFIG_PROPAGATE_LIST = {
nn.Sequential,
}
# mapping from floating point function or torch ops to quantized ops
FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = {
F.elu: torch._ops.ops.quantized.elu,
F.hardswish: torch._ops.ops.quantized.hardswish,
F.instance_norm: torch._ops.ops.quantized.instance_norm,
F.layer_norm: torch._ops.ops.quantized.layer_norm,
}
def register_static_quant_module_mapping(
float_source_module_class, static_quant_target_module_class):
''' Register a mapping from `float_source_module_class` to `static_quant_target_module_class`
`static_quant_target_module_class` must have from_float defined as a class method
The mapping is used in the convert step of post training static quantization to
convert a float module to a statically quantized module.
'''
assert hasattr(static_quant_target_module_class, 'from_float'), 'from_float must be defined' + \
' in quantized module class'
STATIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = static_quant_target_module_class
def get_static_quant_module_mappings():
''' Get module mapping for post training static quantization
'''
return STATIC_QUANT_MODULE_MAPPINGS
def get_static_quant_module_class(float_module_class):
''' Get the statically quantized module class corresponding to
the floating point module class
'''
static_quant_module_class = STATIC_QUANT_MODULE_MAPPINGS.get(float_module_class, None)
assert static_quant_module_class is not None, \
'Floating point module class {}'.format(float_module_class) + \
' does not have a corresponding quantized module class'
return static_quant_module_class
def register_qat_module_mapping(float_source_module_class, qat_target_module_class):
'''Register a mapping from `float_source_module_class` to `qat_target_module_class`,
`qat_target_module_class` must have from_float defined as a class method
This mapping is used in the prepare step of quantization aware training to swap
a float module for a qat module.
'''
assert hasattr(qat_target_module_class, 'from_float'), 'from_float must be defined' + \
' in qat module class'
QAT_MODULE_MAPPINGS[float_source_module_class] = qat_target_module_class
def get_qat_module_mappings():
''' Get module mapping for quantization aware training
'''
return QAT_MODULE_MAPPINGS
def register_dynamic_quant_module_mapping(float_source_module_class, dynamic_quant_target_module_class):
''' Register a mapping from `float_source_module_class` to `dynamic_quant_target_module_class`,
`dynamic_quant_target_module_class` must have from_float defined as a class method
This mapping is used in the convert step of post training dynamic
quantization to swap a float module for a dynamically quantized
module.
'''
assert hasattr(dynamic_quant_target_module_class, 'from_float'), 'from_float must be defined' + \
' in dynamically quantized module type'
DYNAMIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = dynamic_quant_target_module_class
def get_dynamic_quant_module_mappings():
''' Get module mapping for post training dynamic quantization
'''
return DYNAMIC_QUANT_MODULE_MAPPINGS
def get_qconfig_propagation_list():
''' Get the list of module types that we'll attach the qconfig
attribute to in the prepare step
'''
QCONFIG_PROPAGATE_MODULE_CLASS_LIST = (
(set(STATIC_QUANT_MODULE_MAPPINGS.keys()) |
set(QAT_MODULE_MAPPINGS.keys()) |
set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) |
_INCLUDE_QCONFIG_PROPAGATE_LIST) -
_EXCLUDE_QCONFIG_PROPAGATE_LIST
)
return QCONFIG_PROPAGATE_MODULE_CLASS_LIST
def get_compare_output_module_list():
''' Get the list of module class types whose outputs we will record
in the numeric suite
'''
NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = (
set(STATIC_QUANT_MODULE_MAPPINGS.values())
| set(QAT_MODULE_MAPPINGS.values())
| set(DYNAMIC_QUANT_MODULE_MAPPINGS.values())
| set(STATIC_QUANT_MODULE_MAPPINGS.keys())
| set(QAT_MODULE_MAPPINGS.keys())
| set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
| _INCLUDE_QCONFIG_PROPAGATE_LIST
) - _EXCLUDE_QCONFIG_PROPAGATE_LIST
return NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST
def register_quantized_operator_mapping(float_op, quantized_op):
''' Register a mapping from `float_op` (torch or functional) to `quantized_op`
This is used in the convert step of fx based graph mode quantization
to convert a float op to a quantized op.
'''
FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS[float_op] = quantized_op
def get_quantized_operator(float_op):
''' Get the quantized operator corresponding to the float operator
'''
quantized_op = FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None)
assert quantized_op is not None, \
'Operator {} does not have a corresponding quantized op'.format(float_op)
return quantized_op
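
Usage sketch (not part of the diff), exercising the operator mapping against the `F.elu` entry registered above:

import torch
import torch.nn.functional as F
from torch.quantization import register_quantized_operator_mapping, get_quantized_operator

# look up the quantized counterpart of a float functional from the default table
assert get_quantized_operator(F.elu) is torch.ops.quantized.elu

# registration is a plain table update, consulted by fx graph mode convert
register_quantized_operator_mapping(F.elu, torch.ops.quantized.elu)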

View File

@@ -10,10 +10,11 @@ import torch.nn.intrinsic as nni
import torch.nn.quantized as nnq
import torch.nn.intrinsic.qat as nniqat
from .default_mappings import (DEFAULT_DYNAMIC_MODULE_MAPPING,
DEFAULT_MODULE_MAPPING,
DEFAULT_QAT_MODULE_MAPPING,
DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST)
from .quantization_mappings import (get_dynamic_quant_module_mappings,
get_static_quant_module_mappings,
get_qat_module_mappings,
get_qconfig_propagation_list)
from .stubs import DeQuantStub, QuantWrapper
from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_dynamic_qconfig
@@ -37,7 +38,7 @@ def _propagate_qconfig_helper(module, qconfig_dict, allow_list=None,
"""
# TODO: Add test
if allow_list is None:
allow_list = DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST
allow_list = get_qconfig_propagation_list()
module_qconfig = qconfig_dict.get(type(module), qconfig_parent)
module_qconfig = qconfig_dict.get(prefix, module_qconfig)
@@ -100,7 +101,8 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None
None, module is modified inplace with added observer modules and forward_hooks
"""
if qconfig_propagation_list is None:
qconfig_propagation_list = DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST
qconfig_propagation_list = get_qconfig_propagation_list()
# respect device affinity when adding observers
if device is None:
devices = get_unique_devices_(module)
@@ -194,9 +196,10 @@ def prepare(model, inplace=False, allow_list=None,
"""
if not inplace:
model = copy.deepcopy(model)
propagate_qconfig_list = allow_list
if propagate_qconfig_list is None:
propagate_qconfig_list = DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST
qconfig_propagation_list = allow_list
if qconfig_propagation_list is None:
qconfig_propagation_list = get_qconfig_propagation_list()
propagate_qconfig_(model, qconfig_dict=None)
# sanity check common API misusage
@@ -205,7 +208,7 @@
"passed correct configuration through `qconfig_dict` or "
"by assigning the `.qconfig` attribute directly on submodules")
add_observer_(model, propagate_qconfig_list, observer_non_leaf_module_list, prehook=prehook)
add_observer_(model, qconfig_propagation_list, observer_non_leaf_module_list, prehook=prehook)
return model
def _remove_qconfig(module):
@@ -239,7 +242,7 @@ def quantize(model, run_fn, run_args, mapping=None, inplace=False):
Quantized model.
"""
if mapping is None:
mapping = DEFAULT_MODULE_MAPPING
mapping = get_static_quant_module_mappings()
if not inplace:
model = copy.deepcopy(model)
model.eval()
@@ -316,7 +319,7 @@ def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))
if mapping is None:
mapping = DEFAULT_DYNAMIC_MODULE_MAPPING
mapping = get_dynamic_quant_module_mappings()
if not inplace:
model = copy.deepcopy(model)
@@ -341,7 +344,7 @@ def prepare_qat(model, mapping=None, inplace=False):
is mutated
"""
if mapping is None:
mapping = DEFAULT_QAT_MODULE_MAPPING
mapping = get_qat_module_mappings()
if not inplace:
model = copy.deepcopy(model)
@@ -406,7 +409,7 @@ def _convert(module, mapping=None, inplace=False):
"""
if mapping is None:
mapping = DEFAULT_MODULE_MAPPING
mapping = get_static_quant_module_mappings()
if not inplace:
module = copy.deepcopy(module)
reassign = {}
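
With the defaults now resolved at call time, the implicit and explicit forms are equivalent (sketch; `model` is an assumed placeholder for a prepared and calibrated float model):

import copy
from torch.quantization import convert, get_static_quant_module_mappings

# mapping=None resolves to get_static_quant_module_mappings() inside _convert
q1 = convert(copy.deepcopy(model), inplace=False)  # model is an assumed placeholder
q2 = convert(copy.deepcopy(model), mapping=get_static_quant_module_mappings(), inplace=False)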

View File

@@ -17,10 +17,10 @@ from torch.testing._internal.common_utils import TestCase
from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \
default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \
propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_dynamic_qconfig
from torch.quantization.default_mappings import (
DEFAULT_DYNAMIC_MODULE_MAPPING,
DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST,
DEFAULT_QAT_MODULE_MAPPING,
from torch.quantization.quantization_mappings import (
get_dynamic_quant_module_mappings,
get_qconfig_propagation_list,
get_qat_module_mappings,
)
# symbolic trace
from torch.fx import symbolic_trace
@@ -191,7 +191,7 @@ def run_ddp(rank, world_size, prepared):
def convert_dynamic(module):
convert(module, DEFAULT_DYNAMIC_MODULE_MAPPING, inplace=True)
convert(module, get_dynamic_quant_module_mappings(), inplace=True)
def prepare_dynamic(model, qconfig_dict=None):
propagate_qconfig_(model, qconfig_dict)
@@ -347,7 +347,7 @@ class QuantizationTestCase(TestCase):
have observers in preparation for quantization
"""
if propagate_qconfig_list is None:
propagate_qconfig_list = DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST
propagate_qconfig_list = get_qconfig_propagation_list()
if hasattr(module, 'qconfig') and module.qconfig is not None and \
len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential) \
and type(module) in propagate_qconfig_list:
@@ -355,7 +355,7 @@
'module: ' + str(type(module)) + ' does not have observer')
# we don't need to check observers for child modules of the
# qat modules
if type(module) not in DEFAULT_QAT_MODULE_MAPPING.values():
if type(module) not in get_qat_module_mappings().values():
for child in module.children():
self.checkObservers(child)