Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47415
nn.ReLU works for both float and quantized input, so we don't want to define an nn.quantized.ReLU that does the same thing as nn.ReLU; similarly for nn.quantized.functional.relu. This also removes the numerical inconsistency for models that quantize nn.ReLU independently in QAT mode.
Test Plan: Imported from OSS
Reviewed By: z-a-f
Differential Revision: D24747035
fbshipit-source-id: b8fdf13e513a0d5f0c4c6c9835635bdf9fdc2769
185 lines · 6.5 KiB · Python
import torch
from torch import nn

import torch.nn.functional as F
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat

from .stubs import QuantStub, DeQuantStub
from .fake_quantize import (
    default_affine_fixed_qparams_fake_quant,
    default_symmetric_fixed_qparams_fake_quant,
)
from .utils import get_combined_dict

# Default map for swapping float modules to quantized ones
DEFAULT_STATIC_QUANT_MODULE_MAPPINGS = {
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    nn.BatchNorm2d: nnq.BatchNorm2d,
    nn.BatchNorm3d: nnq.BatchNorm3d,
    nn.Conv1d: nnq.Conv1d,
    nn.Conv2d: nnq.Conv2d,
    nn.Conv3d: nnq.Conv3d,
    nn.ConvTranspose1d: nnq.ConvTranspose1d,
    nn.ConvTranspose2d: nnq.ConvTranspose2d,
    nn.ELU: nnq.ELU,
    nn.Embedding: nnq.Embedding,
    nn.EmbeddingBag: nnq.EmbeddingBag,
    nn.GroupNorm: nnq.GroupNorm,
    nn.Hardswish: nnq.Hardswish,
    nn.InstanceNorm1d: nnq.InstanceNorm1d,
    nn.InstanceNorm2d: nnq.InstanceNorm2d,
    nn.InstanceNorm3d: nnq.InstanceNorm3d,
    nn.LayerNorm: nnq.LayerNorm,
    nn.LeakyReLU: nnq.LeakyReLU,
    nn.Linear: nnq.Linear,
    nn.ReLU6: nnq.ReLU6,
    # Wrapper Modules:
    nnq.FloatFunctional: nnq.QFunctional,
    # Intrinsic modules:
    nni.BNReLU2d: nniq.BNReLU2d,
    nni.BNReLU3d: nniq.BNReLU3d,
    nni.ConvReLU1d: nniq.ConvReLU1d,
    nni.ConvReLU2d: nniq.ConvReLU2d,
    nni.ConvReLU3d: nniq.ConvReLU3d,
    nni.LinearReLU: nniq.LinearReLU,
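    # Fused QAT conv+bn patterns map to plain quantized convs: the batch
    # norm statistics are folded into the conv weights when converting.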
    nniqat.ConvBn1d: nnq.Conv1d,
    nniqat.ConvBn2d: nnq.Conv2d,
    nniqat.ConvBnReLU1d: nniq.ConvReLU1d,
    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
    nniqat.ConvReLU2d: nniq.ConvReLU2d,
    nniqat.LinearReLU: nniq.LinearReLU,
    # QAT modules:
    nnqat.Linear: nnq.Linear,
    nnqat.Conv2d: nnq.Conv2d,
}

# Default map for swapping float modules to qat modules
DEFAULT_QAT_MODULE_MAPPINGS = {
    nn.Conv2d: nnqat.Conv2d,
    nn.Linear: nnqat.Linear,
    # Intrinsic modules:
    nni.ConvBn1d: nniqat.ConvBn1d,
    nni.ConvBn2d: nniqat.ConvBn2d,
    nni.ConvBnReLU1d: nniqat.ConvBnReLU1d,
    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.LinearReLU: nniqat.LinearReLU,
}

# Default map for swapping dynamic modules
DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS = {
    nn.GRUCell: nnqd.GRUCell,
    nn.Linear: nnqd.Linear,
    nn.LSTM: nnqd.LSTM,
    nn.LSTMCell: nnqd.LSTMCell,
    nn.RNNCell: nnqd.RNNCell,
}
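# Note: dynamic quantization quantizes only the weights ahead of time and
# quantizes activations on the fly, which is why this mapping covers only
# linear and recurrent layers.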

# Whitelist for propagating the qconfig
_EXCLUDE_QCONFIG_PROPAGATE_LIST = {
    DeQuantStub,
}
_INCLUDE_QCONFIG_PROPAGATE_LIST = {
    nn.Sequential,
}

# Default mapping from floating point function or torch ops to quantized ops
# TODO: merge with default static mapping
DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = {
    F.elu: torch._ops.ops.quantized.elu,
    F.hardswish: torch._ops.ops.quantized.hardswish,
    F.instance_norm: torch._ops.ops.quantized.instance_norm,
    F.layer_norm: torch._ops.ops.quantized.layer_norm,
    F.leaky_relu: torch._ops.ops.quantized.leaky_relu,
}
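# Note: unlike their float counterparts, these quantized ops take the output
# scale and zero_point as extra arguments, since their output is a quantized
# tensor.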

# Mapping from module to output activation post process class
DEFAULT_MODULE_TO_ACT_POST_PROCESS = {
    nn.Hardsigmoid: default_affine_fixed_qparams_fake_quant,
    nn.Sigmoid: default_affine_fixed_qparams_fake_quant,
    nn.Tanh: default_symmetric_fixed_qparams_fake_quant,
}
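# These activations have a fixed, known output range (Sigmoid and Hardsigmoid
# produce values in [0, 1], Tanh in [-1, 1]), so their output quantization
# parameters can be fixed up front instead of being observed.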

def get_default_static_quant_module_mappings():
    ''' Get module mapping for post training static quantization
    '''
    return DEFAULT_STATIC_QUANT_MODULE_MAPPINGS
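# Example (illustrative): the returned dict is keyed by float module class,
# e.g. get_default_static_quant_module_mappings()[nn.Linear] is nnq.Linear.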

def get_static_quant_module_class(float_module_class, additional_static_quant_mapping=None):
    r""" Get the statically quantized module class corresponding to
    the floating point module class
    """
    if additional_static_quant_mapping is None:
        additional_static_quant_mapping = {}
    all_mappings = get_combined_dict(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, additional_static_quant_mapping)
    static_quant_module_class = all_mappings.get(float_module_class, None)
    assert static_quant_module_class is not None, \
        "Floating point module class {}".format(str(float_module_class)) + \
        " does not have a corresponding quantized module class"
    return static_quant_module_class
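# Example (illustrative; MyQuantLinear is a hypothetical user-defined module):
#   get_static_quant_module_class(nn.Conv2d)  # -> nnq.Conv2d
#   get_static_quant_module_class(nn.Linear, {nn.Linear: MyQuantLinear})
#   # -> MyQuantLinear, since the additional mapping overrides the default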

def get_default_qat_module_mappings():
    ''' Get default module mapping for quantization aware training
    '''
    return DEFAULT_QAT_MODULE_MAPPINGS

def get_default_dynamic_quant_module_mappings():
    ''' Get module mapping for post training dynamic quantization
    '''
    return DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS

def get_default_qconfig_propagation_list():
    ''' Get the default list of module types to which we'll attach the
    qconfig attribute in prepare
    '''
    QCONFIG_PROPAGATE_MODULE_CLASS_LIST = (
        (set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) |
         set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) |
         set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) |
         _INCLUDE_QCONFIG_PROPAGATE_LIST) -
        _EXCLUDE_QCONFIG_PROPAGATE_LIST
    )
    return QCONFIG_PROPAGATE_MODULE_CLASS_LIST
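# Example (illustrative): nn.Linear is in the returned set because it is a key
# of the static, qat and dynamic mappings, while DeQuantStub is removed by
# _EXCLUDE_QCONFIG_PROPAGATE_LIST even though it is a static mapping key.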

def get_default_compare_output_module_list():
    ''' Get the list of module class types whose outputs we will record
    in the numeric suite
    '''
    NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = (
        set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.values())
        | set(DEFAULT_QAT_MODULE_MAPPINGS.values())
        | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.values())
        | set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys())
        | set(DEFAULT_QAT_MODULE_MAPPINGS.keys())
        | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
        | _INCLUDE_QCONFIG_PROPAGATE_LIST
    ) - _EXCLUDE_QCONFIG_PROPAGATE_LIST
    return NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST

# TODO: merge with get_static_quant_module_class
def get_quantized_operator(float_op):
    ''' Get the quantized operator corresponding to the float operator
    '''
    quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None)
    assert quantized_op is not None, \
        'Operator {} does not have a corresponding quantized op'.format(str(float_op))
    return quantized_op
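# Example (illustrative): get_quantized_operator(F.leaky_relu) returns
# torch._ops.ops.quantized.leaky_relu, per the mapping above.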

def _get_special_act_post_process(module):
    r""" Get the special activation post process for `module`; this has
    higher priority than the activation post process in `qconfig`
    e.g.
    input: torch.nn.Sigmoid
    output: default_affine_fixed_qparams_fake_quant
    """
    return DEFAULT_MODULE_TO_ACT_POST_PROCESS.get(type(module), None)

def _has_special_act_post_process(module):
    return module.training and type(module) in DEFAULT_MODULE_TO_ACT_POST_PROCESS
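# Note: _has_special_act_post_process also checks `module.training`, so the
# fixed-qparams fake quant is only attached while the model is in training
# (QAT) mode.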