mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27164 Test Plan: Imported from OSS Differential Revision: D17694475 Pulled By: zafartahirov fbshipit-source-id: df8df5f7d66062ed35da957064a31344e1d3c961
60 lines
1.5 KiB
Python
60 lines
1.5 KiB
Python
|
|
from torch import nn
|
|
|
|
import torch.nn.intrinsic as nni
|
|
import torch.nn.intrinsic.quantized as nniq
|
|
import torch.nn.intrinsic.qat as nniqat
|
|
import torch.nn.quantized as nnq
|
|
import torch.nn.quantized.dynamic as nnqd
|
|
import torch.nn.qat as nnqat
|
|
|
|
from .stubs import QuantStub, DeQuantStub
|
|
|
|
# Default map for swapping float modules to quantized ones.
# Keys are the module classes to be replaced (plain float modules, the
# quant/dequant stubs, fused intrinsic modules and QAT modules); values are
# the quantized module classes they are swapped for.
DEFAULT_MODULE_MAPPING = {
    nn.Linear: nnq.Linear,
    nn.ReLU: nnq.ReLU,
    nn.Conv2d: nnq.Conv2d,
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    # Wrapper Modules:
    nnq.FloatFunctional: nnq.QFunctional,
    # Intrinsic modules:
    nni.ConvReLU2d: nniq.ConvReLU2d,
    nni.LinearReLU: nniq.LinearReLU,
    nniqat.ConvReLU2d: nniq.ConvReLU2d,
    nniqat.LinearReLU: nniq.LinearReLU,
    nniqat.ConvBn2d: nnq.Conv2d,
    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
    # QAT modules:
    nnqat.Linear: nnq.Linear,
    nnqat.Conv2d: nnq.Conv2d,
}
|
|
|
|
# Default map for swapping float modules to QAT modules.
# Keys are float (or fused intrinsic) module classes; values are the
# QAT module classes they are swapped for.
DEFAULT_QAT_MODULE_MAPPING = {
    nn.Linear: nnqat.Linear,
    nn.Conv2d: nnqat.Conv2d,
    # Intrinsic modules:
    nni.ConvBn2d: nniqat.ConvBn2d,
    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.LinearReLU: nniqat.LinearReLU,
}
|
|
|
|
# Default map for swapping float modules to dynamic quantized modules.
# Keys are float module classes; values are the classes from
# torch.nn.quantized.dynamic they are swapped for.
DEFAULT_DYNAMIC_MODULE_MAPPING = {
    nn.Linear: nnqd.Linear,
    nn.LSTM: nnqd.LSTM,
}
|
|
|
|
# Module classes to skip during qconfig propagation.
# Kept as a tuple of classes, in the original order.
DEFAULT_SKIP_LIST = (
    nn.Dropout,
    nn.Identity,
    nn.MaxPool2d,
    nn.AvgPool2d,
    nn.AdaptiveAvgPool2d,
    DeQuantStub,
)
|