Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[quant][eagermode][refactor] Add set/get method for quantization and fusion mappings (#43990)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43990

Allow user to register custom quantization and fusion patterns

Test Plan: Imported from OSS

Reviewed By: z-a-f

Differential Revision: D23485344

fbshipit-source-id: 4f0174ee6d8000d83de0f73cb370e9a1941d54aa
This commit is contained in:
  parent f7278473d3
  commit 0c58a017bd
mypy.ini (4 changed lines)

@@ -56,9 +56,6 @@ ignore_errors = True
[mypy-torch.testing._internal.*]
ignore_errors = True

[mypy-torch.quantization.default_mappings]
ignore_errors = True

[mypy-torch.quantization.observer]
ignore_errors = True

@@ -74,6 +71,7 @@ ignore_errors = True
[mypy-torch.quantization._numeric_suite]
ignore_errors = True

[mypy-torch.quantization.quantize_fx]
ignore_errors = True
@@ -8,6 +8,8 @@ from .stubs import *
from .quant_type import *
from .quantize_jit import *
from .quantize_fx import *
from .quantization_mappings import *
from .fuser_method_mappings import *

def default_eval_fn(model, calib_data):
    r"""

@@ -28,6 +30,17 @@ _all__ = [
    'fuse_fx', 'quantize_fx',  # TODO: add quantize_dynamic_fx
    'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',
    'QuantType',  # quantization type
    # custom module APIs
    'register_static_quant_module_mapping',
    'get_static_quant_module_mappings', 'get_static_quant_module_class',
    'register_dynamic_quant_module_mapping',
    'get_dynamic_quant_module_mappings',
    'register_qat_module_mapping',
    'get_qat_module_mappings',
    'get_qconfig_propagation_list',
    'get_compare_output_module_list',
    'register_quantized_operator_mapping', 'get_quantized_operator',
    'register_fuser_method', 'get_fuser_method',
    # Sub functions for `prepare` and `swap_module`
    'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module',
    'default_eval_fn', 'get_observer_dict',
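The additions to the export list above make the new registration and getter APIs importable directly from torch.quantization. A minimal sketch, assuming a build that contains this commit:

from torch.quantization import (
    register_static_quant_module_mapping,
    get_static_quant_module_mappings,
    register_fuser_method,
    get_fuser_method,
)

# The getters return the live default tables that prepare/convert/fuse consult.
print(len(get_static_quant_module_mappings()))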
@@ -6,25 +6,10 @@ import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
from torch.quantization import prepare

from .default_mappings import (
    _EXCLUDE_QCONFIG_PROPAGATE_LIST,
    _INCLUDE_QCONFIG_PROPAGATE_LIST,
    DEFAULT_DYNAMIC_MODULE_MAPPING,
    DEFAULT_MODULE_MAPPING,
    DEFAULT_QAT_MODULE_MAPPING,
from .quantization_mappings import (
    get_compare_output_module_list,
)


DEFAULT_NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_ALLOWED_LIST = (
    set(DEFAULT_MODULE_MAPPING.values())
    | set(DEFAULT_QAT_MODULE_MAPPING.values())
    | set(DEFAULT_DYNAMIC_MODULE_MAPPING.values())
    | set(DEFAULT_MODULE_MAPPING.keys())
    | set(DEFAULT_QAT_MODULE_MAPPING.keys())
    | set(DEFAULT_DYNAMIC_MODULE_MAPPING.keys())
    | _INCLUDE_QCONFIG_PROPAGATE_LIST
) - _EXCLUDE_QCONFIG_PROPAGATE_LIST

NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
    nnqd.Linear,
    nnq.Linear,

@@ -409,7 +394,7 @@ def prepare_model_outputs(
    float_module,
    q_module,
    Logger=OutputLogger,
    allow_list=DEFAULT_NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_ALLOWED_LIST,
    allow_list=None
):
    r"""Prepare the model by attaching the logger to both float module
    and quantized module if they are in the allow_list.

@@ -420,6 +405,9 @@ def prepare_model_outputs(
        Logger: type of logger to be attached to float_module and q_module
        allow_list: list of module types to attach logger
    """
    if allow_list is None:
        allow_list = get_compare_output_module_list()

    qconfig_debug = torch.quantization.QConfig(activation=Logger, weight=None)
    float_module.qconfig = qconfig_debug
    prepare(float_module, inplace=True, allow_list=allow_list)

@@ -437,7 +425,7 @@ def compare_model_outputs(
    q_model,
    *data,
    Logger=OutputLogger,
    allow_list=DEFAULT_NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_ALLOWED_LIST,
    allow_list=None
):
    r"""Compare output activations between float and quantized models at
    corresponding locations for the same input. Return a dict with key corresponding

@@ -463,6 +451,8 @@ def compare_model_outputs(
        and each entry being a dictionary with two keys 'float' and 'quantized',
        containing the matching float and quantized activations
    """
    if allow_list is None:
        allow_list = get_compare_output_module_list()
    prepare_model_outputs(float_model, q_model, Logger, allow_list)
    float_model(*data)
    q_model(*data)
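With the hard-coded DEFAULT_NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_ALLOWED_LIST gone, leaving allow_list as None in prepare_model_outputs/compare_model_outputs now defers to get_compare_output_module_list(), so module types registered through the new mappings are compared as well. A small sketch of the getter the default falls back to (module types shown are the stock PyTorch ones):

import torch.nn as nn
import torch.nn.quantized as nnq
from torch.quantization import get_compare_output_module_list

compare_list = get_compare_output_module_list()
# Both the float and quantized Linear are in the default compare list, so
# compare_model_outputs(float_model, q_model, data) logs their outputs when
# allow_list is left as None.
print(nn.Linear in compare_list)   # True
print(nnq.Linear in compare_list)  # True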
@@ -2,95 +2,14 @@ from __future__ import absolute_import, division, print_function, unicode_litera

import copy

import torch.nn.intrinsic.modules.fused as torch_fused
import torch.nn as nn
import torch.nn.intrinsic as nni

from typing import Type, List, Optional, Union, Callable, Tuple, Dict
from .fuser_method_mappings import get_fuser_method
# for backward compatiblity
from .fuser_method_mappings import fuse_conv_bn  # noqa: F401
from .fuser_method_mappings import fuse_conv_bn_relu  # noqa: F40

def fuse_conv_bn(conv, bn):
    r"""Given the conv and bn modules, fuses them and returns the fused module

    Args:
        conv: Module instance of type conv2d/conv3d
        bn: Spatial BN instance that needs to be fused with the conv

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> m2 = fuse_conv_bn(m1, b1)
    """
    assert(conv.training == bn.training),\
        "Conv and BN both must be in the same mode (train or eval)."

    is_3d = isinstance(conv, nn.Conv3d)

    if conv.training:
        assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
        assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
        return nni.ConvBn3d(conv, bn) if is_3d \
            else nni.ConvBn2d(conv, bn)
    else:
        return nn.utils.fuse_conv_bn_eval(conv, bn)

def fuse_conv_bn_relu(conv, bn, relu):
    r"""Given the conv and bn modules, fuses them and returns the fused module

    Args:
        conv: Module instance of type conv2d/conv3d
        bn: Spatial BN instance that needs to be fused with the conv

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> m2 = fuse_conv_bn(m1, b1)
    """
    assert(conv.training == bn.training == relu.training),\
        "Conv and BN both must be in the same mode (train or eval)."
    fused_module : Optional[Type[nn.Sequential]] = None
    if conv.training:
        map_to_fused_module_train = {
            nn.Conv2d: torch_fused.ConvBnReLU2d,
            nn.Conv3d: torch_fused.ConvBnReLU3d,
        }
        assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
        assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
        fused_module = map_to_fused_module_train.get(type(conv))
        if fused_module is not None:
            return fused_module(conv, bn, relu)
        else:
            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
    else:
        map_to_fused_module_eval = {
            nn.Conv1d: torch_fused.ConvReLU1d,
            nn.Conv2d: torch_fused.ConvReLU2d,
            nn.Conv3d: torch_fused.ConvReLU3d,
        }
        fused_module = map_to_fused_module_eval[type(conv)]
        if fused_module is not None:
            fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
            return fused_module(fused_conv, relu)
        else:
            raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))

OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
    (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
    (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
    (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv1d, nn.ReLU): nni.ConvReLU1d,
    (nn.Conv2d, nn.ReLU): nni.ConvReLU2d,
    (nn.Conv3d, nn.ReLU): nni.ConvReLU3d,
    (nn.Linear, nn.ReLU): nni.LinearReLU,
    (nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d,
    (nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
}
from typing import List, Optional

# Generalization of getattr
def _get_module(model, submodule_key):

@@ -123,7 +42,7 @@ def fuse_known_modules(mod_list):
    the fused operation. The rest of the elements are set to nn.Identity()
    """
    types = tuple(type(m) for m in mod_list)
    fuser_method = OP_LIST_TO_FUSER_METHOD.get(types)
    fuser_method = get_fuser_method(types)
    if fuser_method is None:
        raise NotImplementedError("Cannot fuse modules: {}".format(types))
    new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
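fuse_known_modules now resolves its fuser through get_fuser_method instead of reading OP_LIST_TO_FUSER_METHOD directly; the table and the fuse_conv_bn* helpers move into fuser_method_mappings.py below, with re-imports kept for backward compatibility. A minimal eager-mode sketch of the path this exercises; ConvBNReLU is just an illustrative toy module:

import torch.nn as nn
from torch.quantization import fuse_modules

class ConvBNReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

m = ConvBNReLU().eval()
# Internally fuse_known_modules looks up
# get_fuser_method((nn.Conv2d, nn.BatchNorm2d, nn.ReLU)) -> fuse_conv_bn_relu.
fused = fuse_modules(m, [['conv', 'bn', 'relu']])
print(type(fused.conv))  # intrinsic ConvReLU2d with the BatchNorm folded in
print(type(fused.bn))    # nn.Identity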
torch/quantization/fuser_method_mappings.py (101 lines, new file)

@@ -0,0 +1,101 @@
import torch.nn as nn
import torch.nn.intrinsic as nni

from typing import Union, Callable, Tuple, Dict, Optional, Type

def fuse_conv_bn(conv, bn):
    r"""Given the conv and bn modules, fuses them and returns the fused module

    Args:
        conv: Module instance of type conv2d/conv3d
        bn: Spatial BN instance that needs to be fused with the conv

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> m2 = fuse_conv_bn(m1, b1)
    """
    assert(conv.training == bn.training),\
        "Conv and BN both must be in the same mode (train or eval)."

    is_3d = isinstance(conv, nn.Conv3d)

    if conv.training:
        assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
        assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
        return nni.ConvBn3d(conv, bn) if is_3d \
            else nni.ConvBn2d(conv, bn)
    else:
        return nn.utils.fuse_conv_bn_eval(conv, bn)

def fuse_conv_bn_relu(conv, bn, relu):
    r"""Given the conv and bn modules, fuses them and returns the fused module

    Args:
        conv: Module instance of type conv2d/conv3d
        bn: Spatial BN instance that needs to be fused with the conv

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> m2 = fuse_conv_bn(m1, b1)
    """
    assert(conv.training == bn.training == relu.training),\
        "Conv and BN both must be in the same mode (train or eval)."
    fused_module : Optional[Type[nn.Sequential]] = None
    if conv.training:
        map_to_fused_module_train = {
            nn.Conv2d: nni.ConvBnReLU2d,
            nn.Conv3d: nni.ConvBnReLU3d,
        }
        assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
        assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
        fused_module = map_to_fused_module_train.get(type(conv))
        if fused_module is not None:
            return fused_module(conv, bn, relu)
        else:
            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
    else:
        map_to_fused_module_eval = {
            nn.Conv1d: nni.ConvReLU1d,
            nn.Conv2d: nni.ConvReLU2d,
            nn.Conv3d: nni.ConvReLU3d,
        }
        fused_module = map_to_fused_module_eval[type(conv)]
        if fused_module is not None:
            fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
            return fused_module(fused_conv, relu)
        else:
            raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))

OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
    (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
    (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
    (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv1d, nn.ReLU): nni.ConvReLU1d,
    (nn.Conv2d, nn.ReLU): nni.ConvReLU2d,
    (nn.Conv3d, nn.ReLU): nni.ConvReLU3d,
    (nn.Linear, nn.ReLU): nni.LinearReLU,
    (nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d,
    (nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
}

def register_fuser_method(op_list, fuser_method):
    ''' Register a fuser method for a tuple of ops, will be called
    during fusion step
    '''
    assert isinstance(op_list, tuple), 'op list must be a tuple'
    OP_LIST_TO_FUSER_METHOD[op_list] = fuser_method

def get_fuser_method(op_list):
    ''' Get fuser method for the given list of module types,
    return None if fuser method does not exist
    '''
    return OP_LIST_TO_FUSER_METHOD.get(op_list, None)
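The register_fuser_method/get_fuser_method pair added here is the fusion extension point the PR summary refers to. A hedged sketch of registering a custom entry; the (nn.Linear, nn.Dropout) pair and fuse_linear_dropout are hypothetical and only illustrate the API:

import torch.nn as nn
from torch.quantization import register_fuser_method, get_fuser_method

# Hypothetical fuser: for inference, dropping the Dropout leaves just the Linear.
def fuse_linear_dropout(linear, dropout):
    return linear

register_fuser_method((nn.Linear, nn.Dropout), fuse_linear_dropout)

assert get_fuser_method((nn.Linear, nn.Dropout)) is fuse_linear_dropout
# Built-in entries are still served from the same table:
print(get_fuser_method((nn.Conv2d, nn.BatchNorm2d)))  # fuse_conv_bn

Because both eager-mode fuse_known_modules and the FX fusion patterns below resolve through this getter, a single registration covers both flows.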
@@ -3,7 +3,7 @@ from .pattern_utils import (
    register_fusion_pattern,
)
from .utils import _parent_name
from ..fuse_modules import OP_LIST_TO_FUSER_METHOD
from ..fuser_method_mappings import get_fuser_method

# ---------------------
# Fusion Patterns

@@ -60,7 +60,7 @@ class ConvBNReLUFusion():
        op_list.reverse()
        op_type_list = tuple(type(m) for m in op_list)
        conv_parent_name, conv_name = _parent_name(self.conv_node.target)
        fuser_method = OP_LIST_TO_FUSER_METHOD.get(op_type_list, None)
        fuser_method = get_fuser_method(op_type_list)
        if fuser_method is None:
            raise NotImplementedError("Cannot fuse modules: {}".format(types))
        setattr(quantizer.modules[conv_parent_name], conv_name, fuser_method(*op_list))

@@ -104,8 +104,6 @@ class ModuleReLUFusion():
        op_list.reverse()
        op_type_list = tuple(type(m) for m in op_list)
        module_parent_name, module_name = _parent_name(self.module_node.target)
        fuser_method = OP_LIST_TO_FUSER_METHOD.get(op_type_list, None)
        if fuser_method is None:
            raise NotImplementedError("Cannot fuse modules: {}".format(types))
        fuser_method = get_fuser_method(op_type_list)
        setattr(quantizer.modules[module_parent_name], module_name, fuser_method(*op_list))
        return quantizer.fused_graph.node_copy(self.module_node, load_arg)
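Both FX fusion handlers now derive the fuser from the matched modules' types via get_fuser_method, mirroring the eager path. A tiny sketch of that resolution step; op_list here stands in for the modules a pattern matched:

import torch.nn as nn
from torch.quantization import get_fuser_method

op_list = [nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()]
op_type_list = tuple(type(m) for m in op_list)  # (nn.Conv2d, nn.BatchNorm2d, nn.ReLU)
fuser_method = get_fuser_method(op_type_list)
print(fuser_method)  # fuse_conv_bn_relu, the same method eager-mode fusion uses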
@@ -1,11 +1,11 @@
import torch
from torch.quantization.default_mappings import (
    DEFAULT_MODULE_MAPPING,
    DEFAULT_OPERATOR_MAPPING,
)
from torch.fx.graph import (
    Node,
)
from ..quantization_mappings import (
    get_static_quant_module_class,
    get_quantized_operator,
)
from .pattern_utils import (
    register_quant_pattern,
    register_dynamic_quant_pattern,

@@ -181,10 +181,7 @@ class ConvRelu(QuantizeHandler):
        else:
            self.conv.activation_post_process = quantizer.activation_post_process_map[node.name]
        # 2. select quantized class
        # TODO: make the mapping configurable?
        assert type(self.conv) in DEFAULT_MODULE_MAPPING, \
            'unhandled conv type:{}'.format(type(self.conv))
        qconv_cls = DEFAULT_MODULE_MAPPING[type(self.conv)]
        qconv_cls = get_static_quant_module_class(type(self.conv))
        quantized = qconv_cls.from_float(self.conv)
        parent_name, name = _parent_name(self.conv_node.target)
        setattr(quantizer.modules[parent_name], name, quantized)

@@ -335,7 +332,7 @@ class BatchNorm(QuantizeHandler):
            self.bn[1].activation_post_process = activation_post_process
        else:
            self.bn.activation_post_process = activation_post_process
        qbn_cls = DEFAULT_MODULE_MAPPING[type(self.bn)]
        qbn_cls = get_static_quant_module_class(type(self.bn))
        quantized = qbn_cls.from_float(self.bn)
        parent_name, name = _parent_name(self.bn_node.target)
        setattr(quantizer.modules[parent_name], name, quantized)

@@ -371,7 +368,8 @@ class DefaultNode(QuantizeHandler):
        if node.op == 'call_module':
            module = quantizer.modules[node.target]
            module.activation_post_process = activation_post_process
            quantized_module = DEFAULT_MODULE_MAPPING[type(module)].from_float(module)
            quantized_module_cls = get_static_quant_module_class(type(module))
            quantized_module = quantized_module_cls.from_float(module)
            parent_name, name = _parent_name(node.target)
            setattr(quantizer.modules[parent_name], name, quantized_module)
            return quantizer.quantized_graph.create_node(

@@ -385,7 +383,7 @@ class DefaultNode(QuantizeHandler):
            scale = float(scale)
            zero_point = int(zero_point)

            quantized_op = DEFAULT_OPERATOR_MAPPING[node.target]
            quantized_op = get_quantized_operator(node.target)
            args = load_arg(quantized=[0])(node.args)
            kwargs = load_arg(quantized=False)(node.kwargs)
            kwargs.update({'output_scale': scale, 'output_zero_point': zero_point})

@@ -405,7 +403,7 @@ class ELU(QuantizeHandler):
        scale, zero_point = activation_post_process.calculate_qparams()
        scale = float(scale)
        zero_point = int(zero_point)
        quantized_op = DEFAULT_OPERATOR_MAPPING[node.target]
        quantized_op = get_quantized_operator(node.target)
        args = load_arg(quantized=[0])(node.args)
        kwargs = load_arg(quantized=False)(node.kwargs)
        kwargs.update({'output_scale': scale, 'output_zero_point': zero_point})
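The convert handlers above now go through getters instead of indexing DEFAULT_MODULE_MAPPING / DEFAULT_OPERATOR_MAPPING, which is what makes the swaps configurable. A small sketch of the two lookups involved, using entries from the default tables in quantization_mappings.py below:

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.quantized as nnq
from torch.quantization import get_static_quant_module_class, get_quantized_operator

# Module swap used by the ConvRelu / BatchNorm / DefaultNode handlers:
print(get_static_quant_module_class(nn.Conv2d) is nnq.Conv2d)  # True

# Functional/op swap used by the DefaultNode and ELU handlers:
print(get_quantized_operator(F.elu))  # the quantized elu op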
@@ -4,8 +4,8 @@ from torch.quantization import (
    convert,
)

from torch.quantization.default_mappings import (
    DEFAULT_QAT_MODULE_MAPPING,
from ..quantization_mappings import (
    get_qat_module_mappings,
)

from torch.fx import (

@@ -163,7 +163,7 @@ class Quantizer:

    def _qat_swap_modules(self, root):
        convert(root, mapping=DEFAULT_QAT_MODULE_MAPPING, inplace=True, remove_qconfig=False)
        convert(root, mapping=get_qat_module_mappings(), inplace=True, remove_qconfig=False)

    def _generate_qconfig_map(self, root, input_graph):
        def get_qconfig(module):
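_qat_swap_modules now asks get_qat_module_mappings() at call time, so QAT modules registered after import are honored. A hedged sketch of the default table it returns:

import torch.nn as nn
import torch.nn.qat as nnqat
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.qat as nniqat
from torch.quantization import get_qat_module_mappings

qat_mappings = get_qat_module_mappings()
print(qat_mappings[nn.Conv2d] is nnqat.Conv2d)                # True
print(qat_mappings[nni.ConvBnReLU2d] is nniqat.ConvBnReLU2d)  # True
# A custom QAT module class (with a from_float classmethod) could be added
# through register_qat_module_mapping and would be swapped in here as well.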
torch/quantization/quantization_mappings.py (187 lines, new file)

@@ -0,0 +1,187 @@
import torch
from torch import nn

import torch.nn.functional as F
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat

from .stubs import QuantStub, DeQuantStub

# Map for swapping float module to quantized ones
STATIC_QUANT_MODULE_MAPPINGS = {
    nn.Linear: nnq.Linear,
    nn.ReLU: nnq.ReLU,
    nn.ReLU6: nnq.ReLU6,
    nn.Hardswish: nnq.Hardswish,
    nn.ELU: nnq.ELU,
    nn.Conv1d: nnq.Conv1d,
    nn.Conv2d: nnq.Conv2d,
    nn.Conv3d: nnq.Conv3d,
    nn.BatchNorm2d: nnq.BatchNorm2d,
    nn.BatchNorm3d: nnq.BatchNorm3d,
    nn.LayerNorm: nnq.LayerNorm,
    nn.GroupNorm: nnq.GroupNorm,
    nn.InstanceNorm1d: nnq.InstanceNorm1d,
    nn.InstanceNorm2d: nnq.InstanceNorm2d,
    nn.InstanceNorm3d: nnq.InstanceNorm3d,
    nn.Embedding: nnq.Embedding,
    nn.EmbeddingBag: nnq.EmbeddingBag,
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    # Wrapper Modules:
    nnq.FloatFunctional: nnq.QFunctional,
    # Intrinsic modules:
    nni.ConvReLU1d: nniq.ConvReLU1d,
    nni.ConvReLU2d: nniq.ConvReLU2d,
    nni.ConvReLU3d: nniq.ConvReLU3d,
    nni.LinearReLU: nniq.LinearReLU,
    nni.BNReLU2d: nniq.BNReLU2d,
    nni.BNReLU3d: nniq.BNReLU3d,
    nniqat.ConvReLU2d: nniq.ConvReLU2d,
    nniqat.LinearReLU: nniq.LinearReLU,
    nniqat.ConvBn2d: nnq.Conv2d,
    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
    # QAT modules:
    nnqat.Linear: nnq.Linear,
    nnqat.Conv2d: nnq.Conv2d,
}

# Map for swapping float module to qat modules
QAT_MODULE_MAPPINGS = {
    nn.Linear: nnqat.Linear,
    nn.Conv2d: nnqat.Conv2d,
    # Intrinsic modules:
    nni.ConvBn2d: nniqat.ConvBn2d,
    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.LinearReLU: nniqat.LinearReLU
}

# Map for swapping dynamic modules
DYNAMIC_QUANT_MODULE_MAPPINGS = {
    nn.Linear: nnqd.Linear,
    nn.LSTM: nnqd.LSTM,
    nn.LSTMCell: nnqd.LSTMCell,
    nn.RNNCell: nnqd.RNNCell,
    nn.GRUCell: nnqd.GRUCell,
}

# Whitelist for propagating the qconfig
_EXCLUDE_QCONFIG_PROPAGATE_LIST = {
    DeQuantStub,
}
_INCLUDE_QCONFIG_PROPAGATE_LIST = {
    nn.Sequential,
}

# mapping from floating point function or torch ops to quantized ops
FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = {
    F.elu: torch._ops.ops.quantized.elu,
    F.hardswish: torch._ops.ops.quantized.hardswish,
    F.instance_norm: torch._ops.ops.quantized.instance_norm,
    F.layer_norm: torch._ops.ops.quantized.layer_norm,
}

def register_static_quant_module_mapping(
        float_source_module_class, static_quant_target_module_class):
    ''' Register a mapping from `float_source__module_class` to `static_quant_target_module_class`
    `static_quant_target_module_class` must have from_float defined as a class method
    The mapping is used in the convert step of post training static quantization to
    convert a float module to a statically quantized module.
    '''
    assert hasattr(static_quant_target_module_class, 'from_float'), 'from_float must be defined' + \
        ' in quantized module class'
    STATIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = static_quant_target_module_class

def get_static_quant_module_mappings():
    ''' Get module mapping for post training static quantization
    '''
    return STATIC_QUANT_MODULE_MAPPINGS

def get_static_quant_module_class(float_module_class):
    ''' Get the statically quantized module class corresponding to
    the floating point module class
    '''
    static_quant_module_class = STATIC_QUANT_MODULE_MAPPINGS.get(float_module_class, None)
    assert static_quant_module_class is not None, \
        'Floating point module class {}'.format(float_module_class) + \
        ' does not have a corresponding quantized module class'
    return static_quant_module_class

def register_qat_module_mapping(float_source_module_class, qat_target_module_class):
    '''Register a mapping from `float_source_module_class` to `qat_target_module_class`,
    `qat_target_module_class` must have from_float defined as a class method
    This mapping is used in prepare step of quantization aware training to swap
    a float module to a qat module.
    '''
    assert hasattr(qat_target_module_class, 'from_float'), 'from_float must be defined' + \
        ' in qat module class'
    QAT_MODULE_MAPPINGS[float_source_module_class] = qat_target_module_class

def get_qat_module_mappings():
    ''' Get module mapping for quantization aware training
    '''
    return QAT_MODULE_MAPPINGS

def register_dynamic_quant_module_class(float_source_module_class, dynamic_quant_target_module_class):
    ''' Register a mapping from `float_source_module_class` to `dynamic_quant_target_module_class`,
    `dynamic_quant_target_module_class` must have from_float defined as a class method
    This mapping is used in convert step of post training dynamic
    quantization to swap a float module to a dynamically quantized
    module.
    '''
    assert hasattr(dynamic_quant_target_module_class, 'from_float'), 'from_float must be defined' + \
        ' in dynamically quantized module type'
    DYNAMIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = dynamic_quant_target_module_class

def get_dynamic_quant_module_mappings():
    ''' Get module mapping for post training dynamic quantization
    '''
    return DYNAMIC_QUANT_MODULE_MAPPINGS

def get_qconfig_propagation_list():
    ''' Get the list of module types that we'll attach qconfig
    attribute to in prepare
    '''
    QCONFIG_PROPAGATE_MODULE_CLASS_LIST = (
        (set(STATIC_QUANT_MODULE_MAPPINGS.keys()) |
         set(QAT_MODULE_MAPPINGS.keys()) |
         set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) |
         _INCLUDE_QCONFIG_PROPAGATE_LIST) -
        _EXCLUDE_QCONFIG_PROPAGATE_LIST
    )
    return QCONFIG_PROPAGATE_MODULE_CLASS_LIST

def get_compare_output_module_list():
    ''' Get list of module class types that we will record output
    in numeric suite
    '''
    NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = (
        set(STATIC_QUANT_MODULE_MAPPINGS.values())
        | set(QAT_MODULE_MAPPINGS.values())
        | set(DYNAMIC_QUANT_MODULE_MAPPINGS.values())
        | set(STATIC_QUANT_MODULE_MAPPINGS.keys())
        | set(QAT_MODULE_MAPPINGS.keys())
        | set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
        | _INCLUDE_QCONFIG_PROPAGATE_LIST
    ) - _EXCLUDE_QCONFIG_PROPAGATE_LIST
    return NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST

def register_quantized_operator_mapping(float_op, quantized_op):
    ''' Register a mapping from `floating_point_op` (torch or functional) to `quantized_op`
    This is used in convert step of fx based graph mode quantization
    to convert a float op to quantized op.
    '''
    FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS[float_op] = quantized_op

def get_quantized_operator(float_op):
    ''' Get the quantized operator corresponding to the float operator
    '''
    quantized_op = FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None)
    assert quantized_op is not None, \
        'Operator {} does not have corresponding quantized op'.format(float_op)
    return quantized_op
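This file is the central registry the rest of the diff switches to. A hedged end-to-end sketch of the registration path it enables; MyFloatOp and MyQuantizedOp are hypothetical placeholder modules, and a real quantized implementation would do more in from_float than construct an empty module:

import torch.nn as nn
from torch.quantization import (
    register_static_quant_module_mapping,
    get_static_quant_module_class,
)

class MyFloatOp(nn.Module):
    def forward(self, x):
        return x

class MyQuantizedOp(nn.Module):
    def forward(self, x):
        return x

    @classmethod
    def from_float(cls, float_mod):
        # A real implementation would read float_mod.activation_post_process
        # to derive scale/zero_point for the quantized module.
        return cls()

# register_static_quant_module_mapping asserts that from_float exists.
register_static_quant_module_mapping(MyFloatOp, MyQuantizedOp)
print(get_static_quant_module_class(MyFloatOp) is MyQuantizedOp)  # True
# convert() and the FX handlers above will now swap MyFloatOp instances via
# MyQuantizedOp.from_float during post training static quantization.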
@@ -10,10 +10,11 @@ import torch.nn.intrinsic as nni
import torch.nn.quantized as nnq
import torch.nn.intrinsic.qat as nniqat

from .default_mappings import (DEFAULT_DYNAMIC_MODULE_MAPPING,
                               DEFAULT_MODULE_MAPPING,
                               DEFAULT_QAT_MODULE_MAPPING,
                               DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST)
from .quantization_mappings import (get_dynamic_quant_module_mappings,
                                    get_static_quant_module_mappings,
                                    get_qat_module_mappings,
                                    get_qconfig_propagation_list)

from .stubs import DeQuantStub, QuantWrapper
from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_dynamic_qconfig

@@ -37,7 +38,7 @@ def _propagate_qconfig_helper(module, qconfig_dict, allow_list=None,
    """
    # TODO: Add test
    if allow_list is None:
        allow_list = DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST
        allow_list = get_qconfig_propagation_list()

    module_qconfig = qconfig_dict.get(type(module), qconfig_parent)
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)

@@ -100,7 +101,8 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No
        None, module is modified inplace with added observer modules and forward_hooks
    """
    if qconfig_propagation_list is None:
        qconfig_propagation_list = DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST
        qconfig_propagation_list = get_qconfig_propagation_list()

    # respect device affinity when adding observers
    if device is None:
        devices = get_unique_devices_(module)

@@ -194,9 +196,10 @@ def prepare(model, inplace=False, allow_list=None,
    """
    if not inplace:
        model = copy.deepcopy(model)
    propagate_qconfig_list = allow_list
    if propagate_qconfig_list is None:
        propagate_qconfig_list = DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST

    qconfig_propagation_list = allow_list
    if qconfig_propagation_list is None:
        qconfig_propagation_list = get_qconfig_propagation_list()
    propagate_qconfig_(model, qconfig_dict=None)

    # sanity check common API misusage

@@ -205,7 +208,7 @@ def prepare(model, inplace=False, allow_list=None,
                      "passed correct configuration through `qconfig_dict` or "
                      "by assigning the `.qconfig` attribute directly on submodules")

    add_observer_(model, propagate_qconfig_list, observer_non_leaf_module_list, prehook=prehook)
    add_observer_(model, qconfig_propagation_list, observer_non_leaf_module_list, prehook=prehook)
    return model

def _remove_qconfig(module):

@@ -239,7 +242,7 @@ def quantize(model, run_fn, run_args, mapping=None, inplace=False):
        Quantized model.
    """
    if mapping is None:
        mapping = DEFAULT_MODULE_MAPPING
        mapping = get_static_quant_module_mappings()
    if not inplace:
        model = copy.deepcopy(model)
    model.eval()

@@ -316,7 +319,7 @@ def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
        qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))

    if mapping is None:
        mapping = DEFAULT_DYNAMIC_MODULE_MAPPING
        mapping = get_dynamic_quant_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)

@@ -341,7 +344,7 @@ def prepare_qat(model, mapping=None, inplace=False):
        is mutated
    """
    if mapping is None:
        mapping = DEFAULT_QAT_MODULE_MAPPING
        mapping = get_qat_module_mappings()
    if not inplace:
        model = copy.deepcopy(model)

@@ -406,7 +409,7 @@ def _convert(module, mapping=None, inplace=False):

    """
    if mapping is None:
        mapping = DEFAULT_MODULE_MAPPING
        mapping = get_static_quant_module_mappings()
    if not inplace:
        module = copy.deepcopy(module)
    reassign = {}
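prepare, quantize, quantize_dynamic, prepare_qat, and _convert now resolve their None defaults through the getters, so the mapping tables are read when the function runs rather than frozen at import time. A small runnable sketch using dynamic quantization, where the default mapping sends nn.Linear to its dynamically quantized counterpart:

import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic, get_dynamic_quant_module_mappings

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
# mapping=None falls back to get_dynamic_quant_module_mappings()
qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
print(type(qmodel[0]))  # torch.nn.quantized.dynamic.Linear
print(nn.Linear in get_dynamic_quant_module_mappings())  # True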
@@ -17,10 +17,10 @@ from torch.testing._internal.common_utils import TestCase
from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \
    default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \
    propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_dynamic_qconfig
from torch.quantization.default_mappings import (
    DEFAULT_DYNAMIC_MODULE_MAPPING,
    DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST,
    DEFAULT_QAT_MODULE_MAPPING,
from torch.quantization.quantization_mappings import (
    get_dynamic_quant_module_mappings,
    get_qconfig_propagation_list,
    get_qat_module_mappings,
)
# symbolic trace
from torch.fx import symbolic_trace

@@ -191,7 +191,7 @@ def run_ddp(rank, world_size, prepared):

def convert_dynamic(module):
    convert(module, DEFAULT_DYNAMIC_MODULE_MAPPING, inplace=True)
    convert(module, get_dynamic_quant_module_mappings(), inplace=True)

def prepare_dynamic(model, qconfig_dict=None):
    propagate_qconfig_(model, qconfig_dict)

@@ -347,7 +347,7 @@ class QuantizationTestCase(TestCase):
           have observers in preperation for quantization
        """
        if propagate_qconfig_list is None:
            propagate_qconfig_list = DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST
            propagate_qconfig_list = get_qconfig_propagation_list()
        if hasattr(module, 'qconfig') and module.qconfig is not None and \
                len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential) \
                and type(module) in propagate_qconfig_list:

@@ -355,7 +355,7 @@ class QuantizationTestCase(TestCase):
                            'module: ' + str(type(module)) + ' do not have observer')
        # we don't need to check observers for child modules of the
        # qat modules
        if type(module) not in DEFAULT_QAT_MODULE_MAPPING.values():
        if type(module) not in get_qat_module_mappings().values():
            for child in module.children():
                self.checkObservers(child)
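The test helpers now consult the same getters as the library code, so checkObservers and convert_dynamic stay in sync with anything registered through the new APIs. A quick sketch of the propagation list those checks rely on:

import torch.nn as nn
from torch.quantization import get_qconfig_propagation_list

propagate_list = get_qconfig_propagation_list()
print(nn.Conv2d in propagate_list)   # True: prepare() attaches a qconfig/observer
print(nn.Dropout in propagate_list)  # False: unmapped module types are skipped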