mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/76637. The previous names `default_affine_fixed_qparams_observer` and `default_symmetric_fixed_qparams_observer` were uninformative, and users had to read the definitions in order to understand what these observers are. The new naming convention reveals information about the range of the observers. The analogous changes were also made for `default_symmetric_fixed_qparams_fake_quant` and `default_affine_fixed_qparams_fake_quant`. Test Plan: ``` python test/test_quantization.py ``` Differential Revision: D36054169. Reviewed By: vkuzo. Pulled By: dzdang. fbshipit-source-id: 215f7786a4b7abda7327f17cc61735697ec5cca9 (cherry picked from commit 21a4e6eda4467c8adca7fd534a506a14e975f9cf)
315 lines
12 KiB
Python
315 lines
12 KiB
Python
import copy
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
import torch.nn.functional as F
|
|
import torch.nn.intrinsic as nni
|
|
import torch.nn.intrinsic.quantized as nniq
|
|
import torch.nn.intrinsic.quantized.dynamic as nniqd
|
|
import torch.nn.intrinsic.qat as nniqat
|
|
import torch.nn.quantized as nnq
|
|
import torch.nn.quantized._reference as nnqr
|
|
import torch.nn.quantized.dynamic as nnqd
|
|
import torch.nn.qat as nnqat
|
|
import torch.nn.qat.dynamic as nnqatd
|
|
|
|
from typing import Optional, Union, Dict, Set, Callable, Any
|
|
|
|
import torch.ao.nn as ao_nn
|
|
from torch.ao.quantization.stubs import QuantStub, DeQuantStub
|
|
from torch.ao.quantization.fake_quantize import (
|
|
default_fixed_qparams_range_0to1_fake_quant,
|
|
default_fixed_qparams_range_neg1to1_fake_quant,
|
|
)
|
|
from torch.ao.quantization.utils import get_combined_dict
|
|
from torch.nn.utils.parametrize import type_before_parametrizations
|
|
|
|
# Default map for swapping float module to reference quantized modules
|
|
DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
|
|
QuantStub: nnq.Quantize,
|
|
DeQuantStub: nnq.DeQuantize,
|
|
nn.Linear: nnqr.Linear,
|
|
nn.Conv1d: nnqr.Conv1d,
|
|
nn.Conv2d: nnqr.Conv2d,
|
|
nn.Conv3d: nnqr.Conv3d,
|
|
nn.ConvTranspose1d: nnqr.ConvTranspose1d,
|
|
nn.ConvTranspose2d: nnqr.ConvTranspose2d,
|
|
nn.ConvTranspose3d: nnqr.ConvTranspose3d,
|
|
nn.Embedding: nnqr.Embedding,
|
|
nn.EmbeddingBag: nnqr.EmbeddingBag,
|
|
nn.GRUCell: nnqr.GRUCell,
|
|
nn.LSTMCell: nnqr.LSTMCell,
|
|
nn.RNNCell: nnqr.RNNCell,
|
|
nn.LSTM: nnqr.LSTM,
|
|
}
|
|
|
|
# Default map for swapping float module to quantized ones
|
|
DEFAULT_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
|
|
QuantStub: nnq.Quantize,
|
|
DeQuantStub: nnq.DeQuantize,
|
|
nn.BatchNorm2d: nnq.BatchNorm2d,
|
|
nn.BatchNorm3d: nnq.BatchNorm3d,
|
|
nn.Dropout: nnq.Dropout,
|
|
nn.Conv1d: nnq.Conv1d,
|
|
nn.Conv2d: nnq.Conv2d,
|
|
nn.Conv3d: nnq.Conv3d,
|
|
nn.ConvTranspose1d: nnq.ConvTranspose1d,
|
|
nn.ConvTranspose2d: nnq.ConvTranspose2d,
|
|
nn.ConvTranspose3d: nnq.ConvTranspose3d,
|
|
nn.ELU: nnq.ELU,
|
|
nn.Embedding: nnq.Embedding,
|
|
nn.EmbeddingBag: nnq.EmbeddingBag,
|
|
nn.GroupNorm: nnq.GroupNorm,
|
|
nn.Hardswish: nnq.Hardswish,
|
|
nn.InstanceNorm1d: nnq.InstanceNorm1d,
|
|
nn.InstanceNorm2d: nnq.InstanceNorm2d,
|
|
nn.InstanceNorm3d: nnq.InstanceNorm3d,
|
|
nn.LayerNorm: nnq.LayerNorm,
|
|
nn.LeakyReLU: nnq.LeakyReLU,
|
|
nn.modules.linear.NonDynamicallyQuantizableLinear: nnq.Linear,
|
|
nn.Linear: nnq.Linear,
|
|
nn.ReLU6: nnq.ReLU6,
|
|
nn.Dropout: nnq.Dropout,
|
|
# Wrapper Modules:
|
|
nnq.FloatFunctional: nnq.QFunctional,
|
|
# Intrinsic modules:
|
|
nni.BNReLU2d: nniq.BNReLU2d,
|
|
nni.BNReLU3d: nniq.BNReLU3d,
|
|
nni.ConvReLU1d: nniq.ConvReLU1d,
|
|
nni.ConvReLU2d: nniq.ConvReLU2d,
|
|
nni.ConvReLU3d: nniq.ConvReLU3d,
|
|
nni.LinearReLU: nniq.LinearReLU,
|
|
nniqat.ConvBn1d: nnq.Conv1d,
|
|
nniqat.ConvBn2d: nnq.Conv2d,
|
|
nniqat.ConvBn3d: nnq.Conv3d,
|
|
nniqat.ConvBnReLU1d: nniq.ConvReLU1d,
|
|
nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
|
|
nniqat.ConvBnReLU3d: nniq.ConvReLU3d,
|
|
nniqat.ConvReLU2d: nniq.ConvReLU2d,
|
|
nniqat.ConvReLU3d: nniq.ConvReLU3d,
|
|
nniqat.LinearReLU: nniq.LinearReLU,
|
|
nniqat.LinearBn1d: nnq.Linear,
|
|
# QAT modules:
|
|
nnqat.Linear: nnq.Linear,
|
|
nnqat.Conv2d: nnq.Conv2d,
|
|
nnqat.Conv3d: nnq.Conv3d,
|
|
}
|
|
|
|
# Default map for swapping float module to qat modules
|
|
DEFAULT_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = {
|
|
nn.Conv2d: nnqat.Conv2d,
|
|
nn.Conv3d: nnqat.Conv3d,
|
|
nn.Linear: nnqat.Linear,
|
|
nn.modules.linear.NonDynamicallyQuantizableLinear: nnqat.Linear,
|
|
# Intrinsic modules:
|
|
nni.ConvBn1d: nniqat.ConvBn1d,
|
|
nni.ConvBn2d: nniqat.ConvBn2d,
|
|
nni.ConvBn3d: nniqat.ConvBn3d,
|
|
nni.ConvBnReLU1d: nniqat.ConvBnReLU1d,
|
|
nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
|
|
nni.ConvBnReLU3d: nniqat.ConvBnReLU3d,
|
|
nni.ConvReLU2d: nniqat.ConvReLU2d,
|
|
nni.ConvReLU3d: nniqat.ConvReLU3d,
|
|
nni.LinearReLU: nniqat.LinearReLU,
|
|
nni.LinearBn1d: nniqat.LinearBn1d,
|
|
}
|
|
|
|
# Default map for swapping dynamic modules
|
|
DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
|
|
nn.GRUCell: nnqd.GRUCell,
|
|
nn.Linear: nnqd.Linear,
|
|
nnqatd.Linear: nnqd.Linear,
|
|
nn.modules.linear.NonDynamicallyQuantizableLinear: nnqd.Linear,
|
|
nn.LSTM: nnqd.LSTM,
|
|
nn.GRU: nnqd.GRU,
|
|
nn.LSTMCell: nnqd.LSTMCell,
|
|
nn.RNNCell: nnqd.RNNCell,
|
|
nni.LinearReLU: nniqd.LinearReLU,
|
|
nn.EmbeddingBag: nnq.EmbeddingBag,
|
|
nn.Embedding: nnq.Embedding,
|
|
# Don't want to enable these by default because the numerical
|
|
# accuracy is poor compared to other dynamic ops
|
|
# nn.Conv1d: nnqd.Conv1d,
|
|
# nn.Conv2d: nnqd.Conv2d,
|
|
# nn.Conv3d: nnqd.Conv3d,
|
|
# nn.ConvTranspose1d: nnqd.ConvTranspose1d,
|
|
# nn.ConvTranspose2d: nnqd.ConvTranspose2d,
|
|
# nn.ConvTranspose3d: nnqd.ConvTranspose3d,
|
|
}
|
|
|
|
# Allowlist for propagating the qconfig
|
|
_INCLUDE_QCONFIG_PROPAGATE_LIST : Set[Callable] = {
|
|
nn.Sequential,
|
|
}
|
|
|
|
# Default mapping from floating point function or torch ops to quantized ops
|
|
# TODO: merge with default static mapping
|
|
DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS : Dict[Union[Callable, str], Callable] = {
|
|
F.elu: torch.ops.quantized.elu,
|
|
F.hardswish: torch.ops.quantized.hardswish,
|
|
F.instance_norm: torch.ops.quantized.instance_norm,
|
|
F.layer_norm: torch.ops.quantized.layer_norm,
|
|
F.leaky_relu: torch.ops.quantized.leaky_relu,
|
|
F.dropout: torch.ops.quantized.dropout,
|
|
}
|
|
|
|
# mapping from module to output activation post process class
|
|
DEFAULT_MODULE_TO_ACT_POST_PROCESS : Dict[Callable, Callable] = {
|
|
nn.Hardsigmoid: default_fixed_qparams_range_0to1_fake_quant,
|
|
nn.Sigmoid: default_fixed_qparams_range_0to1_fake_quant,
|
|
nn.Softmax: default_fixed_qparams_range_0to1_fake_quant,
|
|
nn.Tanh: default_fixed_qparams_range_neg1to1_fake_quant,
|
|
}
|
|
|
|
# Default map for swapping float module to static sparse quantized ones
|
|
DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
|
|
nn.Linear: ao_nn.sparse.quantized.Linear
|
|
}
|
|
|
|
# Default map for swapping float module to dynamic sparse quantized ones
|
|
DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
|
|
nn.Linear: ao_nn.sparse.quantized.dynamic.Linear
|
|
}
|
|
|
|
def no_observer_set() -> Set[Any]:
|
|
r"""These modules cannot have observers inserted by default."""
|
|
no_observers = set([
|
|
nn.quantizable.LSTM,
|
|
nn.quantizable.MultiheadAttention
|
|
])
|
|
return no_observers
|
|
|
|
def get_default_static_quant_module_mappings() -> Dict[Callable, Any]:
|
|
''' Get module mapping for post training static quantization
|
|
'''
|
|
return copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS)
|
|
|
|
def get_default_static_quant_reference_module_mappings() -> Dict[Callable, Any]:
|
|
''' Get reference module mapping for post training static quantization
|
|
'''
|
|
return copy.deepcopy(DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS)
|
|
|
|
def get_embedding_static_quant_module_mappings() -> Dict[Callable, Any]:
|
|
''' Get module mapping, including mapping for embedding QAT
|
|
'''
|
|
mapping = copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS)
|
|
mapping[nnqat.EmbeddingBag] = nnq.EmbeddingBag
|
|
mapping[nnqat.Embedding] = nnq.Embedding
|
|
return mapping
|
|
|
|
def get_default_static_sparse_quant_module_mappings() -> Dict[Callable, Any]:
|
|
''' Get module mapping for post training static sparse quantization
|
|
'''
|
|
return copy.deepcopy(DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS)
|
|
|
|
def get_static_quant_module_class(
|
|
float_module_class: Callable,
|
|
additional_static_quant_mapping: Optional[Dict[Callable, Any]] = None,
|
|
is_reference: bool = False) -> Any:
|
|
r"""n Get the statically quantized module class corresponding to
|
|
the floating point module class
|
|
"""
|
|
if additional_static_quant_mapping is None:
|
|
additional_static_quant_mapping = {}
|
|
all_mappings = get_combined_dict(
|
|
DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS if is_reference
|
|
else DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, additional_static_quant_mapping)
|
|
static_quant_module_class = all_mappings.get(float_module_class, None)
|
|
assert static_quant_module_class is not None, \
|
|
"Floating point module class {}".format(str(float_module_class)) + \
|
|
" does not have a corresponding quantized module class"
|
|
return copy.deepcopy(static_quant_module_class)
|
|
|
|
def get_dynamic_quant_module_class(
|
|
float_module_class: Callable,
|
|
additional_dynamic_quant_mapping: Optional[Dict[Callable, Any]] = None) -> Any:
|
|
r"""n Get the dynamically quantized module class corresponding to
|
|
the floating point module class
|
|
"""
|
|
if additional_dynamic_quant_mapping is None:
|
|
additional_dynamic_quant_mapping = {}
|
|
all_mappings = get_combined_dict(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping)
|
|
dynamic_quant_module_class = all_mappings.get(float_module_class, None)
|
|
assert dynamic_quant_module_class is not None, \
|
|
"Floating point module class {}".format(str(float_module_class)) + \
|
|
" does not have a corresponding quantized module class"
|
|
return copy.deepcopy(dynamic_quant_module_class)
|
|
|
|
def get_default_qat_module_mappings() -> Dict[Callable, Any]:
|
|
''' Get default module mapping for quantization aware training
|
|
'''
|
|
return copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS)
|
|
|
|
def get_embedding_qat_module_mappings() -> Dict[Callable, Any]:
|
|
''' Get module mapping for quantization aware training
|
|
This is includes default values in addition to
|
|
enabling qat for embeddings.
|
|
'''
|
|
mapping = copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS)
|
|
mapping[nn.EmbeddingBag] = nnqat.EmbeddingBag
|
|
mapping[nn.Embedding] = nnqat.Embedding
|
|
return mapping
|
|
|
|
def get_default_dynamic_quant_module_mappings() -> Dict[Callable, Any]:
|
|
''' Get module mapping for post training dynamic quantization
|
|
'''
|
|
return DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS
|
|
|
|
def get_default_dynamic_sparse_quant_module_mappings() -> Dict[Callable, Any]:
|
|
''' Get module mapping for post training dynamic sparse quantization
|
|
'''
|
|
return DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS
|
|
|
|
def get_default_qconfig_propagation_list() -> Set[Callable]:
|
|
''' Get the default list of module types that we'll attach qconfig
|
|
attribute to in prepare
|
|
'''
|
|
QCONFIG_PROPAGATE_MODULE_CLASS_LIST = (
|
|
(set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) |
|
|
set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) |
|
|
set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) |
|
|
_INCLUDE_QCONFIG_PROPAGATE_LIST)
|
|
)
|
|
return copy.deepcopy(QCONFIG_PROPAGATE_MODULE_CLASS_LIST)
|
|
|
|
def get_default_compare_output_module_list() -> Set[Callable]:
|
|
''' Get list of module class types that we will record output
|
|
in numeric suite
|
|
'''
|
|
NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = (
|
|
set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.values())
|
|
| set(DEFAULT_QAT_MODULE_MAPPINGS.values())
|
|
| set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.values())
|
|
| set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys())
|
|
| set(DEFAULT_QAT_MODULE_MAPPINGS.keys())
|
|
| set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
|
|
| _INCLUDE_QCONFIG_PROPAGATE_LIST
|
|
)
|
|
return copy.deepcopy(NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST)
|
|
|
|
def get_default_float_to_quantized_operator_mappings(
|
|
) -> Dict[Union[Callable, str], Callable]:
|
|
return copy.deepcopy(DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS)
|
|
|
|
# TODO: merge with get_static_quant_module_class
|
|
def get_quantized_operator(float_op: Union[Callable, str]) -> Callable:
|
|
''' Get the quantized operator corresponding to the float operator
|
|
'''
|
|
quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None)
|
|
assert quantized_op is not None, \
|
|
'Operator {} does not have corresponding quantized op'.format(str(float_op))
|
|
return quantized_op
|
|
|
|
def _get_special_act_post_process(module: torch.nn.Module) -> Optional[Callable]:
|
|
r""" Get the special activation post process for `module`, this has
|
|
higher priority than the activation post process in `qconfig`
|
|
e.g.
|
|
input: torch.nn.Sigmoid
|
|
output: default_affine_fixed_qparam_fake_quant
|
|
"""
|
|
return DEFAULT_MODULE_TO_ACT_POST_PROCESS.get(type_before_parametrizations(module), None)
|
|
|
|
def _has_special_act_post_process(module: torch.nn.Module) -> bool:
|
|
return module.training and type(module) in DEFAULT_MODULE_TO_ACT_POST_PROCESS
|