Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45712

Eager mode will still be able to use functional leaky relu, but it will be less accurate than LeakyReLU module. FX graph mode will support both leaky relu functional and module

Test Plan: Imported from OSS

Reviewed By: z-a-f

Differential Revision: D24069961

fbshipit-source-id: 8d91c3c50c0bcd068ba3072378ebb4da9549be3b
192 lines · 7.2 KiB · Python

import torch
from torch import nn

import torch.nn.functional as F
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat

from .stubs import QuantStub, DeQuantStub

# Map for swapping float module to quantized ones
STATIC_QUANT_MODULE_MAPPINGS = {
    nn.Linear: nnq.Linear,
    nn.ReLU: nnq.ReLU,
    nn.ReLU6: nnq.ReLU6,
    nn.Hardswish: nnq.Hardswish,
    nn.ELU: nnq.ELU,
    nn.LeakyReLU: nnq.LeakyReLU,
    nn.Conv1d: nnq.Conv1d,
    nn.Conv2d: nnq.Conv2d,
    nn.Conv3d: nnq.Conv3d,
    nn.ConvTranspose1d: nnq.ConvTranspose1d,
    nn.ConvTranspose2d: nnq.ConvTranspose2d,
    nn.BatchNorm2d: nnq.BatchNorm2d,
    nn.BatchNorm3d: nnq.BatchNorm3d,
    nn.LayerNorm: nnq.LayerNorm,
    nn.GroupNorm: nnq.GroupNorm,
    nn.InstanceNorm1d: nnq.InstanceNorm1d,
    nn.InstanceNorm2d: nnq.InstanceNorm2d,
    nn.InstanceNorm3d: nnq.InstanceNorm3d,
    nn.Embedding: nnq.Embedding,
    nn.EmbeddingBag: nnq.EmbeddingBag,
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    # Wrapper Modules:
    nnq.FloatFunctional: nnq.QFunctional,
    # Intrinsic modules:
    nni.ConvReLU1d: nniq.ConvReLU1d,
    nni.ConvReLU2d: nniq.ConvReLU2d,
    nni.ConvReLU3d: nniq.ConvReLU3d,
    nni.LinearReLU: nniq.LinearReLU,
    nni.BNReLU2d: nniq.BNReLU2d,
    nni.BNReLU3d: nniq.BNReLU3d,
    nniqat.ConvReLU2d: nniq.ConvReLU2d,
    nniqat.LinearReLU: nniq.LinearReLU,
    nniqat.ConvBn2d: nnq.Conv2d,
    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
    # QAT modules:
    nnqat.Linear: nnq.Linear,
    nnqat.Conv2d: nnq.Conv2d,
}

# Map for swapping float module to qat modules
QAT_MODULE_MAPPINGS = {
    nn.Linear: nnqat.Linear,
    nn.Conv2d: nnqat.Conv2d,
    # Intrinsic modules:
    nni.ConvBn2d: nniqat.ConvBn2d,
    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.LinearReLU: nniqat.LinearReLU
}

# Map for swapping dynamic modules
DYNAMIC_QUANT_MODULE_MAPPINGS = {
    nn.Linear: nnqd.Linear,
    nn.LSTM: nnqd.LSTM,
    nn.LSTMCell: nnqd.LSTMCell,
    nn.RNNCell: nnqd.RNNCell,
    nn.GRUCell: nnqd.GRUCell,
}

# Allow/deny lists for propagating the qconfig
_EXCLUDE_QCONFIG_PROPAGATE_LIST = {
    DeQuantStub,
}
_INCLUDE_QCONFIG_PROPAGATE_LIST = {
    nn.Sequential,
}

# Mapping from floating point functions or torch ops to quantized ops
FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = {
    F.elu: torch._ops.ops.quantized.elu,
    F.leaky_relu: torch._ops.ops.quantized.leaky_relu,
    F.hardswish: torch._ops.ops.quantized.hardswish,
    F.instance_norm: torch._ops.ops.quantized.instance_norm,
    F.layer_norm: torch._ops.ops.quantized.layer_norm,
}

def register_static_quant_module_mapping(
        float_source_module_class, static_quant_target_module_class):
    ''' Register a mapping from `float_source_module_class` to `static_quant_target_module_class`
    `static_quant_target_module_class` must have from_float defined as a class method
    The mapping is used in the convert step of post training static quantization to
    convert a float module to a statically quantized module.
    '''
    assert hasattr(static_quant_target_module_class, 'from_float'), 'from_float must be defined' + \
        ' in quantized module class'
    STATIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = static_quant_target_module_class

def get_static_quant_module_mappings():
    ''' Get module mapping for post training static quantization
    '''
    return STATIC_QUANT_MODULE_MAPPINGS
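
# Illustrative sketch, not the actual convert implementation: the mapping
# returned above is consumed by swapping each child module for
# mapping[type(child)].from_float(child). `_swap_child_modules_sketch` is a
# hypothetical example helper; it assumes the children have already been
# prepared with observers, which from_float generally requires.
def _swap_child_modules_sketch(module):
    mapping = get_static_quant_module_mappings()
    for name, child in module.named_children():
        _swap_child_modules_sketch(child)
        if type(child) in mapping:
            setattr(module, name, mapping[type(child)].from_float(child))
    return module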

def get_static_quant_module_class(float_module_class):
    ''' Get the statically quantized module class corresponding to
    the floating point module class
    '''
    static_quant_module_class = STATIC_QUANT_MODULE_MAPPINGS.get(float_module_class, None)
    assert static_quant_module_class is not None, \
        'Floating point module class {}'.format(float_module_class) + \
        ' does not have a corresponding quantized module class'
    return static_quant_module_class
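
# Illustrative sketch of registering a custom mapping: `_MyFloatModule` and
# `_MyQuantizedModule` are hypothetical placeholders, not part of torch. The
# only requirement checked at registration time is that the target class
# defines a from_float classmethod.
def _example_register_custom_static_mapping():
    class _MyFloatModule(nn.Module):
        def forward(self, x):
            return x

    class _MyQuantizedModule(nn.Module):
        @classmethod
        def from_float(cls, mod):
            # construct the quantized module from the observed float module
            return cls()

        def forward(self, x):
            return x

    register_static_quant_module_mapping(_MyFloatModule, _MyQuantizedModule)
    assert get_static_quant_module_class(_MyFloatModule) is _MyQuantizedModule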

def register_qat_module_mapping(float_source_module_class, qat_target_module_class):
    '''Register a mapping from `float_source_module_class` to `qat_target_module_class`,
    `qat_target_module_class` must have from_float defined as a class method
    This mapping is used in the prepare step of quantization aware training to swap
    a float module to a qat module.
    '''
    assert hasattr(qat_target_module_class, 'from_float'), 'from_float must be defined' + \
        ' in qat module class'
    QAT_MODULE_MAPPINGS[float_source_module_class] = qat_target_module_class

def get_qat_module_mappings():
    ''' Get module mapping for quantization aware training
    '''
    return QAT_MODULE_MAPPINGS
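
# Illustrative example of where this mapping is used in practice: eager mode
# quantization aware training typically goes through
# torch.quantization.prepare_qat, which swaps float modules (e.g. nn.Conv2d)
# for their qat counterparts using this mapping. The toy model and qconfig
# choice below are only an example.
def _example_prepare_qat():
    model = nn.Sequential(nn.Conv2d(3, 3, 1))
    model.train()
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    return torch.quantization.prepare_qat(model)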

def register_dynamic_quant_module_class(float_source_module_class, dynamic_quant_target_module_class):
    ''' Register a mapping from `float_source_module_class` to `dynamic_quant_target_module_class`,
    `dynamic_quant_target_module_class` must have from_float defined as a class method
    This mapping is used in the convert step of post training dynamic
    quantization to swap a float module to a dynamically quantized
    module.
    '''
    assert hasattr(dynamic_quant_target_module_class, 'from_float'), 'from_float must be defined' + \
        ' in dynamically quantized module type'
    DYNAMIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = dynamic_quant_target_module_class

def get_dynamic_quant_module_mappings():
    ''' Get module mapping for post training dynamic quantization
    '''
    return DYNAMIC_QUANT_MODULE_MAPPINGS
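
# Illustrative example of where this mapping is used in practice: post
# training dynamic quantization typically goes through
# torch.quantization.quantize_dynamic, which consults this mapping when
# swapping modules such as nn.Linear and nn.LSTM. The toy model below is
# only an example.
def _example_quantize_dynamic():
    float_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
    return torch.quantization.quantize_dynamic(
        float_model, {nn.Linear}, dtype=torch.qint8)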

def get_qconfig_propagation_list():
    ''' Get the list of module types that we will attach the qconfig
    attribute to during prepare
    '''
    QCONFIG_PROPAGATE_MODULE_CLASS_LIST = (
        (set(STATIC_QUANT_MODULE_MAPPINGS.keys()) |
         set(QAT_MODULE_MAPPINGS.keys()) |
         set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) |
         _INCLUDE_QCONFIG_PROPAGATE_LIST) -
        _EXCLUDE_QCONFIG_PROPAGATE_LIST
    )
    return QCONFIG_PROPAGATE_MODULE_CLASS_LIST
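
# Illustrative sketch, not the actual propagate_qconfig_ implementation:
# during prepare, qconfig propagation roughly amounts to attaching a qconfig
# to every submodule whose type appears in the list above.
# `_attach_qconfig_sketch` is a hypothetical example helper.
def _attach_qconfig_sketch(model, qconfig):
    allow_list = get_qconfig_propagation_list()
    for module in model.modules():
        if type(module) in allow_list:
            module.qconfig = qconfig
    return model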

def get_compare_output_module_list():
    ''' Get the list of module class types whose outputs Numeric Suite
    will record
    '''
    NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = (
        set(STATIC_QUANT_MODULE_MAPPINGS.values())
        | set(QAT_MODULE_MAPPINGS.values())
        | set(DYNAMIC_QUANT_MODULE_MAPPINGS.values())
        | set(STATIC_QUANT_MODULE_MAPPINGS.keys())
        | set(QAT_MODULE_MAPPINGS.keys())
        | set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
        | _INCLUDE_QCONFIG_PROPAGATE_LIST
    ) - _EXCLUDE_QCONFIG_PROPAGATE_LIST
    return NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST

def register_quantized_operator_mapping(float_op, quantized_op):
    ''' Register a mapping from `float_op` (torch or functional) to `quantized_op`
    This is used in the convert step of FX based graph mode quantization
    to convert a float op to a quantized op.
    '''
    FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS[float_op] = quantized_op

def get_quantized_operator(float_op):
    ''' Get the quantized operator corresponding to the float operator
    '''
    quantized_op = FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None)
    assert quantized_op is not None, \
        'Operator {} does not have corresponding quantized op'.format(float_op)
    return quantized_op
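
# Illustrative example: FX graph mode quantization resolves functional ops
# through the table above, e.g. F.leaky_relu maps to the quantized leaky_relu
# op. Only the lookup is shown here.
def _example_get_quantized_operator():
    quantized_leaky_relu = get_quantized_operator(F.leaky_relu)
    # quantized_leaky_relu is torch._ops.ops.quantized.leaky_relu, which
    # operates on quantized tensors; its exact call signature is defined by
    # the quantized kernel and is not spelled out in this example.
    return quantized_leaky_relu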