pytorch/torch/quantization/quantization_mappings.py
Alban Desmaison 25db74bf5e Revert D24486972: [quant][graphmode][fx] Support sigmoid/hardsigmoid/tanh in qat
Test Plan: revert-hammer

Differential Revision:
D24486972 (e927b62e73)

Original commit changeset: c9f139bfdd54

fbshipit-source-id: 2a75f5ec93d55a62b40d1cdd49adcf65436058f7
2020-10-26 12:47:05 -07:00

161 lines
5.6 KiB
Python

import torch
from torch import nn
import torch.nn.functional as F
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat
from .stubs import QuantStub, DeQuantStub
# Default map for swapping float modules to quantized ones in post training
# static quantization. Keys are float-side module classes: plain float
# modules, the QuantStub/DeQuantStub markers, the FloatFunctional wrapper,
# fused intrinsic modules and QAT modules. Values are the quantized module
# classes that replace them. Returned by get_default_static_quant_module_mappings()
# and used as the base table in get_static_quant_module_class().
DEFAULT_STATIC_QUANT_MODULE_MAPPINGS = {
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    nn.BatchNorm2d: nnq.BatchNorm2d,
    nn.BatchNorm3d: nnq.BatchNorm3d,
    nn.Conv1d: nnq.Conv1d,
    nn.Conv2d: nnq.Conv2d,
    nn.Conv3d: nnq.Conv3d,
    nn.ConvTranspose1d: nnq.ConvTranspose1d,
    nn.ConvTranspose2d: nnq.ConvTranspose2d,
    nn.ELU: nnq.ELU,
    nn.Embedding: nnq.Embedding,
    nn.EmbeddingBag: nnq.EmbeddingBag,
    nn.GroupNorm: nnq.GroupNorm,
    nn.Hardswish: nnq.Hardswish,
    nn.InstanceNorm1d: nnq.InstanceNorm1d,
    nn.InstanceNorm2d: nnq.InstanceNorm2d,
    nn.InstanceNorm3d: nnq.InstanceNorm3d,
    nn.LayerNorm: nnq.LayerNorm,
    nn.LeakyReLU: nnq.LeakyReLU,
    nn.Linear: nnq.Linear,
    nn.ReLU6: nnq.ReLU6,
    nn.ReLU: nnq.ReLU,
    # Wrapper Modules:
    nnq.FloatFunctional: nnq.QFunctional,
    # Intrinsic modules:
    nni.BNReLU2d: nniq.BNReLU2d,
    nni.BNReLU3d: nniq.BNReLU3d,
    nni.ConvReLU1d: nniq.ConvReLU1d,
    nni.ConvReLU2d: nniq.ConvReLU2d,
    nni.ConvReLU3d: nniq.ConvReLU3d,
    nni.LinearReLU: nniq.LinearReLU,
    # Fused QAT modules. Note that the ConvBn variants map to quantized conv
    # modules that have no separate batch-norm module.
    nniqat.ConvBn2d: nnq.Conv2d,
    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
    nniqat.ConvReLU2d: nniq.ConvReLU2d,
    nniqat.LinearReLU: nniq.LinearReLU,
    # QAT modules:
    nnqat.Linear: nnq.Linear,
    nnqat.Conv2d: nnq.Conv2d,
}
# Default map for swapping float modules (and fused intrinsic modules) to
# their quantization aware training (QAT) module classes. Returned by
# get_default_qat_module_mappings(); its keys and values also feed the
# qconfig-propagation and numeric-suite compare lists built below.
DEFAULT_QAT_MODULE_MAPPINGS = {
    nn.Conv2d: nnqat.Conv2d,
    nn.Linear: nnqat.Linear,
    # Intrinsic modules:
    nni.ConvBn2d: nniqat.ConvBn2d,
    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.LinearReLU: nniqat.LinearReLU
}
# Default map for post training dynamic quantization: maps float Linear and
# recurrent modules to their dynamically quantized counterparts
# (torch.nn.quantized.dynamic). Returned by
# get_default_dynamic_quant_module_mappings().
DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS = {
    nn.GRUCell: nnqd.GRUCell,
    nn.Linear: nnqd.Linear,
    nn.LSTM: nnqd.LSTM,
    nn.LSTMCell: nnqd.LSTMCell,
    nn.RNNCell: nnqd.RNNCell,
}
# Adjustments to the default qconfig propagation list (and to the numeric
# suite compare-output list), which are otherwise built from the default
# mapping tables above.
# Module classes that are always removed, even though they appear as
# mapping keys:
_EXCLUDE_QCONFIG_PROPAGATE_LIST = {
    DeQuantStub,
}
# Module classes that are always added, even though they are not keys of any
# default mapping:
_INCLUDE_QCONFIG_PROPAGATE_LIST = {
    nn.Sequential,
}
# Default mapping from floating point functions / torch ops (here, entries of
# torch.nn.functional) to the corresponding quantized ops under
# ``torch._ops.ops.quantized``. Looked up by get_quantized_operator().
# TODO: merge with default static mapping
DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = {
    F.elu: torch._ops.ops.quantized.elu,
    F.hardswish: torch._ops.ops.quantized.hardswish,
    F.instance_norm: torch._ops.ops.quantized.instance_norm,
    F.layer_norm: torch._ops.ops.quantized.layer_norm,
    F.leaky_relu: torch._ops.ops.quantized.leaky_relu,
}
def get_default_static_quant_module_mappings():
    ''' Get module mapping for post training static quantization

    Returns:
        A new dict (shallow copy of the module-level default) mapping float
        module classes to quantized module classes. A copy is returned so
        that callers which add or override entries cannot accidentally
        mutate the process-wide defaults.
    '''
    return dict(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS)
def get_static_quant_module_class(float_module_class, additional_static_quant_mapping=None):
    r""" Get the statically quantized module class corresponding to
    the floating point module class.

    Args:
        float_module_class: the float module class to look up.
        additional_static_quant_mapping: optional dict of extra
            ``float class -> quantized class`` entries; these override the
            defaults for the same key. The defaults themselves are not
            modified.

    Returns:
        The quantized module class mapped to ``float_module_class``.

    Raises:
        AssertionError: if no quantized class is mapped to
            ``float_module_class``.
    """
    # Work on a copy so the module-level defaults are never mutated.
    all_mappings = DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.copy()
    if additional_static_quant_mapping is not None:
        all_mappings.update(additional_static_quant_mapping)
    static_quant_module_class = all_mappings.get(float_module_class)
    if static_quant_module_class is None:
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would make this silently return None. The
        # exception type is kept as AssertionError for existing callers.
        raise AssertionError(
            "Floating point module class {}".format(str(float_module_class)) +
            " does not have a corresponding quantized module class")
    return static_quant_module_class
def get_default_qat_module_mappings():
    ''' Get default module mapping for quantization aware training

    Returns:
        A new dict (shallow copy of the module-level default) mapping float
        and fused intrinsic module classes to QAT module classes. A copy is
        returned so that callers cannot accidentally mutate the
        process-wide defaults.
    '''
    return dict(DEFAULT_QAT_MODULE_MAPPINGS)
def get_default_dynamic_quant_module_mappings():
    ''' Get module mapping for post training dynamic quantization

    Returns:
        A new dict (shallow copy of the module-level default) mapping float
        module classes to dynamically quantized module classes. A copy is
        returned so that callers cannot accidentally mutate the
        process-wide defaults.
    '''
    return dict(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS)
def get_default_qconfig_propagation_list():
    ''' Get the default set of module types that prepare attaches a
    ``qconfig`` attribute to.

    The set is every key of the default static, QAT and dynamic mappings,
    plus the classes in _INCLUDE_QCONFIG_PROPAGATE_LIST, minus the classes
    in _EXCLUDE_QCONFIG_PROPAGATE_LIST. A new set is built on each call.
    '''
    propagate_classes = set()
    for mapping in (DEFAULT_STATIC_QUANT_MODULE_MAPPINGS,
                    DEFAULT_QAT_MODULE_MAPPINGS,
                    DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS):
        propagate_classes.update(mapping.keys())
    propagate_classes.update(_INCLUDE_QCONFIG_PROPAGATE_LIST)
    # Exclusions are applied last so they win over any mapping key.
    propagate_classes.difference_update(_EXCLUDE_QCONFIG_PROPAGATE_LIST)
    return propagate_classes
def get_default_compare_output_module_list():
    ''' Get the set of module class types whose outputs the numeric suite
    records for comparison.

    The set is both sides (keys and values) of the default static, QAT and
    dynamic mappings, plus the classes in _INCLUDE_QCONFIG_PROPAGATE_LIST,
    minus the classes in _EXCLUDE_QCONFIG_PROPAGATE_LIST.
    '''
    compare_classes = set(_INCLUDE_QCONFIG_PROPAGATE_LIST)
    for mapping in (DEFAULT_STATIC_QUANT_MODULE_MAPPINGS,
                    DEFAULT_QAT_MODULE_MAPPINGS,
                    DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS):
        # Record outputs of both the float and the swapped-in module types.
        compare_classes.update(mapping.keys())
        compare_classes.update(mapping.values())
    return compare_classes - _EXCLUDE_QCONFIG_PROPAGATE_LIST
# TODO: merge with get_static_quant_module_class
def get_quantized_operator(float_op):
    ''' Get the quantized operator corresponding to the float operator.

    Args:
        float_op: a floating point function / torch op, looked up in
            DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.

    Returns:
        The quantized op mapped to ``float_op``.

    Raises:
        AssertionError: if ``float_op`` has no quantized counterpart.
    '''
    quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op)
    if quantized_op is None:
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would make this silently return None. The
        # exception type is kept as AssertionError for existing callers.
        raise AssertionError(
            'Operator {} does not have corresponding quantized op'.format(str(float_op)))
    return quantized_op