pytorch/torch/quantization/default_mappings.py
Vasiliy Kuznetsov 65df8b3886 hardswish: make it work in static quantization (#36545)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36545

* adds a quantized nn.module for Hardswish so we can observe activation values
* modifies the hardswish op to allow specifying scale + zero_point
* makes the Hardswish module get properly swapped in static quantization

Test Plan:
added tests and they pass for:
* the new _out flavor of hardswish
* QNNPACK changes
* static quant e2e

Imported from OSS

Differential Revision: D21045320

fbshipit-source-id: ab7e52f0f54a7d5923ab6f58197022cc28c12354
2020-04-15 18:02:35 -07:00

75 lines
2.0 KiB
Python

from torch import nn
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat
from .stubs import QuantStub, DeQuantStub
# Module type -> quantized module type, used when swapping float modules (and
# their fused / QAT variants) for quantized implementations.
DEFAULT_MODULE_MAPPING = {
    # Plain float modules:
    nn.Linear: nnq.Linear,
    nn.ReLU: nnq.ReLU,
    nn.ReLU6: nnq.ReLU6,
    nn.Hardswish: nnq.Hardswish,
    nn.Conv1d: nnq.Conv1d,
    nn.Conv2d: nnq.Conv2d,
    nn.Conv3d: nnq.Conv3d,
    nn.BatchNorm2d: nnq.BatchNorm2d,
    nn.BatchNorm3d: nnq.BatchNorm3d,
    # Stubs marking where tensors enter / leave the quantized region:
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    # Wrapper modules:
    nnq.FloatFunctional: nnq.QFunctional,
    # Intrinsic (fused) float modules:
    nni.ConvReLU2d: nniq.ConvReLU2d,
    nni.ConvReLU3d: nniq.ConvReLU3d,
    nni.LinearReLU: nniq.LinearReLU,
    nni.BNReLU2d: nniq.BNReLU2d,
    nni.BNReLU3d: nniq.BNReLU3d,
    # Intrinsic QAT modules:
    nniqat.ConvReLU2d: nniq.ConvReLU2d,
    nniqat.LinearReLU: nniq.LinearReLU,
    # ConvBn variants map to plain (or ReLU-fused) quantized convs; presumably
    # the BatchNorm is folded into the conv weights on conversion -- verify.
    nniqat.ConvBn2d: nnq.Conv2d,
    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
    # QAT modules:
    nnqat.Linear: nnq.Linear,
    nnqat.Conv2d: nnq.Conv2d,
}
# Module type -> QAT module type, used when preparing a model for
# quantization-aware training.
DEFAULT_QAT_MODULE_MAPPING = {
    # Plain float modules:
    nn.Linear: nnqat.Linear,
    nn.Conv2d: nnqat.Conv2d,
    # Intrinsic (fused) float modules:
    nni.ConvBn2d: nniqat.ConvBn2d,
    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.LinearReLU: nniqat.LinearReLU,
}
# Module type -> dynamically quantized module type.
DEFAULT_DYNAMIC_MODULE_MAPPING = {
    nn.Linear: nnqd.Linear,
    nn.LSTM: nnqd.LSTM,
}
# Module types removed from the qconfig-propagation whitelist, even when they
# appear as keys of the swap mappings.
_EXCLUDE_QCONFIG_PROPAGATE_LIST = {DeQuantStub}
_INCLUDE_QCONFIG_PROPAGATE_LIST = {
nn.Sequential,
}
# Set of module types eligible for qconfig propagation: every key of the
# static, QAT and dynamic swap mappings plus the explicit includes, minus the
# explicit excludes. (Iterating a dict yields its keys, so union() takes the
# mappings directly.)
DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST = set().union(
    DEFAULT_MODULE_MAPPING,
    DEFAULT_QAT_MODULE_MAPPING,
    DEFAULT_DYNAMIC_MODULE_MAPPING,
    _INCLUDE_QCONFIG_PROPAGATE_LIST,
) - _EXCLUDE_QCONFIG_PROPAGATE_LIST