pytorch/torch/quantization/default_mappings.py
Zafar Takhirov 111da77912 Factored out the default mappings
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27164

Test Plan: Imported from OSS

Differential Revision: D17694475

Pulled By: zafartahirov

fbshipit-source-id: df8df5f7d66062ed35da957064a31344e1d3c961
2019-10-03 11:52:21 -07:00

60 lines
1.5 KiB
Python

from torch import nn
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat
from .stubs import QuantStub, DeQuantStub
# Map for swapping float module to quantized ones.
# Keys are the (float / observed / QAT) module classes found in a prepared
# model; values are the statically-quantized replacements to swap in at
# convert time.
DEFAULT_MODULE_MAPPING = {
    nn.Linear: nnq.Linear,
    nn.ReLU: nnq.ReLU,
    nn.Conv2d: nnq.Conv2d,
    # Observer stubs are replaced by the actual quantize / dequantize ops.
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    # Wrapper Modules:
    nnq.FloatFunctional: nnq.QFunctional,
    # Intrinsic (fused) modules and their QAT counterparts:
    nni.ConvReLU2d: nniq.ConvReLU2d,
    nni.LinearReLU: nniq.LinearReLU,
    nniqat.ConvReLU2d: nniq.ConvReLU2d,
    nniqat.LinearReLU: nniq.LinearReLU,
    # NOTE(review): ConvBn2d converts to a plain quantized Conv2d —
    # presumably the batch norm is folded into the conv weights during
    # conversion; confirm against the convert implementation.
    nniqat.ConvBn2d: nnq.Conv2d,
    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
    # QAT modules:
    nnqat.Linear: nnq.Linear,
    nnqat.Conv2d: nnq.Conv2d,
}
# Map for swapping float module to qat modules.
# Used when preparing a model for quantization-aware training: each float
# module class is replaced by its fake-quantized QAT equivalent.
DEFAULT_QAT_MODULE_MAPPING = {
    nn.Linear: nnqat.Linear,
    nn.Conv2d: nnqat.Conv2d,
    # Intrinsic (fused) modules:
    nni.ConvBn2d: nniqat.ConvBn2d,
    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.LinearReLU: nniqat.LinearReLU
}
# Map for swapping dynamic modules.
# Used for dynamic (weight-only) quantization: each float module class is
# replaced by its dynamically-quantized equivalent.
DEFAULT_DYNAMIC_MODULE_MAPPING = {
    nn.Linear: nnqd.Linear,
    nn.LSTM: nnqd.LSTM,
}
# List of modules to skip the qconfig propagation.
# These module types never receive a qconfig when configs are propagated
# through the model hierarchy — presumably because they either carry no
# learnable weights to quantize (dropout, identity, pooling) or already
# mark a quantization boundary (DeQuantStub); confirm against the
# propagation logic in quantize.py.
DEFAULT_SKIP_LIST = (
    nn.Dropout,
    nn.Identity,
    nn.MaxPool2d,
    nn.AvgPool2d,
    nn.AdaptiveAvgPool2d,
    DeQuantStub
)