# Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00).
#
# Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/36701
# Add module output comparison API.
# ghstack-source-id: 103368194
# Test Plan: buck test mode/dev caffe2/test:quantization -- 'test_compare_model_outputs'
# Differential Revision: D21053197
# fbshipit-source-id: cabcafbeeac1b604db069833a0f17ebce506ba65
#
# File: 88 lines, 2.5 KiB, Python
|
|
from torch import nn
|
|
|
|
import torch.nn.intrinsic as nni
|
|
import torch.nn.intrinsic.quantized as nniq
|
|
import torch.nn.intrinsic.qat as nniqat
|
|
import torch.nn.quantized as nnq
|
|
import torch.nn.quantized.dynamic as nnqd
|
|
import torch.nn.qat as nnqat
|
|
|
|
from .stubs import QuantStub, DeQuantStub
|
|
|
|
# Map used when swapping modules for their quantized versions. Each key is a
# float, wrapper, intrinsic (fused), intrinsic-QAT or QAT module class; the
# value is the quantized module class that replaces it.
DEFAULT_MODULE_MAPPING = dict([
    # Plain float modules.
    (nn.Linear, nnq.Linear),
    (nn.ReLU, nnq.ReLU),
    (nn.ReLU6, nnq.ReLU6),
    (nn.Hardswish, nnq.Hardswish),
    (nn.Conv1d, nnq.Conv1d),
    (nn.Conv2d, nnq.Conv2d),
    (nn.Conv3d, nnq.Conv3d),
    (nn.BatchNorm2d, nnq.BatchNorm2d),
    (nn.BatchNorm3d, nnq.BatchNorm3d),
    (nn.LayerNorm, nnq.LayerNorm),
    # Quant/dequant stubs become real quantize/dequantize modules.
    (QuantStub, nnq.Quantize),
    (DeQuantStub, nnq.DeQuantize),
    # Wrapper modules.
    (nnq.FloatFunctional, nnq.QFunctional),
    # Intrinsic (fused) float modules.
    (nni.ConvReLU2d, nniq.ConvReLU2d),
    (nni.ConvReLU3d, nniq.ConvReLU3d),
    (nni.LinearReLU, nniq.LinearReLU),
    (nni.BNReLU2d, nniq.BNReLU2d),
    (nni.BNReLU3d, nniq.BNReLU3d),
    # Intrinsic QAT modules. The Conv+BN variants map to quantized modules
    # that have no BN component (nnq.Conv2d / nniq.ConvReLU2d).
    (nniqat.ConvReLU2d, nniq.ConvReLU2d),
    (nniqat.LinearReLU, nniq.LinearReLU),
    (nniqat.ConvBn2d, nnq.Conv2d),
    (nniqat.ConvBnReLU2d, nniq.ConvReLU2d),
    # QAT modules.
    (nnqat.Linear, nnq.Linear),
    (nnqat.Conv2d, nnq.Conv2d),
    (nnqat.Hardswish, nnq.Hardswish),
])

# Map used when swapping float modules (including fused intrinsic float
# modules) for their quantization-aware-training (QAT) counterparts.
DEFAULT_QAT_MODULE_MAPPING = dict([
    # Plain float modules.
    (nn.Linear, nnqat.Linear),
    (nn.Conv2d, nnqat.Conv2d),
    (nn.Hardswish, nnqat.Hardswish),
    # Intrinsic (fused) float modules.
    (nni.ConvBn2d, nniqat.ConvBn2d),
    (nni.ConvBnReLU2d, nniqat.ConvBnReLU2d),
    (nni.ConvReLU2d, nniqat.ConvReLU2d),
    (nni.LinearReLU, nniqat.LinearReLU),
])

# Map used when swapping float modules for their dynamically quantized
# counterparts (classes from torch.nn.quantized.dynamic).
DEFAULT_DYNAMIC_MODULE_MAPPING = dict([
    (nn.Linear, nnqd.Linear),
    (nn.LSTM, nnqd.LSTM),
])

# Module classes subtracted from the qconfig-propagation whitelist (and from
# the numeric-suite output-comparison whitelist), even though they are keys
# of the default module mapping.
_EXCLUDE_QCONFIG_PROPAGATE_LIST = {DeQuantStub}

# Module classes added to the qconfig-propagation whitelist although they are
# not keys of any default mapping.
_INCLUDE_QCONFIG_PROPAGATE_LIST = {nn.Sequential}

# Whitelist of module classes for qconfig propagation: every key of the
# static, QAT and dynamic default mappings plus the explicitly included
# classes, with the explicitly excluded classes removed. Iterating a dict
# yields its keys, so passing the mappings to union() collects their keys.
DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST = set().union(
    DEFAULT_MODULE_MAPPING,
    DEFAULT_QAT_MODULE_MAPPING,
    DEFAULT_DYNAMIC_MODULE_MAPPING,
    _INCLUDE_QCONFIG_PROPAGATE_LIST,
).difference(_EXCLUDE_QCONFIG_PROPAGATE_LIST)

# Module classes whose outputs the numeric suite's model-output comparison
# covers by default: both sides (values and keys) of the static, QAT and
# dynamic default mappings, plus the explicitly included classes, with the
# explicitly excluded classes removed. Passing a dict to union() contributes
# its keys.
DEFAULT_NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_WHITE_LIST = set().union(
    DEFAULT_MODULE_MAPPING.values(),
    DEFAULT_QAT_MODULE_MAPPING.values(),
    DEFAULT_DYNAMIC_MODULE_MAPPING.values(),
    DEFAULT_MODULE_MAPPING,
    DEFAULT_QAT_MODULE_MAPPING,
    DEFAULT_DYNAMIC_MODULE_MAPPING,
    _INCLUDE_QCONFIG_PROPAGATE_LIST,
).difference(_EXCLUDE_QCONFIG_PROPAGATE_LIST)