[Quant] Refactor reference module mapping (#72755)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72755

Add is_reference flag to the convert function
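
A minimal sketch of the intended eager-mode usage (the toy model and qconfig setup below are illustrative, not part of this patch):

import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig, prepare, convert

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv(x)

m = M().eval()
m.qconfig = get_default_qconfig("fbgemm")
m = prepare(m)
m(torch.randn(1, 1, 4, 4))             # calibrate observers
ref_m = convert(m, is_reference=True)  # swap to reference modules instead of quantized kernels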

Test Plan:
python3 test/test_quantization.py TestQuantizeEagerOps.test_conv_transpose_2d

Imported from OSS

Reviewed By: mruberry

Differential Revision: D34188856

fbshipit-source-id: 291014a7b3b4d4b40ca0ca76a80711097dcc4b58
(cherry picked from commit cfba3b8dc0373708712c0d847d590f0d587df002)
Author: Terry Chen
Date: 2022-03-07 22:39:37 -08:00
Committed by: PyTorch MergeBot
parent 5993f48711
commit 4e6aefaf72
4 changed files with 26 additions and 17 deletions

View File

@@ -3,7 +3,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.quantized as nnq
-import torch.nn.quantized._reference as nnqr
 from torch.nn.utils.rnn import PackedSequence
 from torch.ao.quantization import (
     quantize,
@@ -140,17 +139,7 @@ class TestQuantizeEagerOps(QuantizationTestCase):
         ref_m = prepare(original_ref_m)
         ref_m(data)
-        reference_module_mapping = {
-            QuantStub: nnq.Quantize,
-            DeQuantStub: nnq.DeQuantize,
-            nn.Conv1d: nnqr.Conv1d,
-            nn.Conv2d: nnqr.Conv2d,
-            nn.Conv3d: nnqr.Conv3d,
-            nn.ConvTranspose1d: nnqr.ConvTranspose1d,
-            nn.ConvTranspose2d: nnqr.ConvTranspose2d,
-            nn.ConvTranspose3d: nnqr.ConvTranspose3d,
-        }
-        ref_m = convert(ref_m, mapping=reference_module_mapping)
+        ref_m = convert(ref_m, is_reference=True)
         ref_res = ref_m(data)
         self.assertEqual(res, ref_res)
@@ -202,6 +191,14 @@ class TestQuantizeEagerOps(QuantizationTestCase):
             (16, 1, 10, 10, 10)
         )

+    def test_linear(self):
+        self._test_reference_module_impl(
+            nn.Linear,
+            nnq.Linear,
+            {'in_features': 5, 'out_features': 10},
+            (16, 5)
+        )
+
     def _test_activation_op_impl(
             self, float_module_class, quantized_module_class, extra_module_kwargs):
         """ Implementation for testing common activation ops like leaky relu

View File

@@ -26,6 +26,8 @@ from torch.ao.quantization.utils import get_combined_dict

 # Default map for swapping float module to reference quantized modules
 DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
+    QuantStub: nnq.Quantize,
+    DeQuantStub: nnq.DeQuantize,
     nn.Linear: nnqr.Linear,
     nn.Conv1d: nnqr.Conv1d,
     nn.Conv2d: nnqr.Conv2d,
@@ -175,6 +177,11 @@ def get_default_static_quant_module_mappings() -> Dict[Callable, Any]:
     '''
     return copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS)

+def get_default_static_quant_reference_module_mappings() -> Dict[Callable, Any]:
+    ''' Get reference module mapping for post training static quantization
+    '''
+    return copy.deepcopy(DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS)
+
 def get_embedding_static_quant_module_mappings() -> Dict[Callable, Any]:
     ''' Get module mapping, including mapping for embedding QAT
     '''
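
Like the other getters in this file, the new helper returns a deep copy, so callers can modify the result without touching the module-level default. An illustrative check, assuming the import path added above:

import torch.nn as nn
from torch.ao.quantization.quantization_mappings import (
    get_default_static_quant_reference_module_mappings,
)

mapping = get_default_static_quant_reference_module_mappings()
assert nn.Conv2d in mapping   # maps to nnqr.Conv2d per the table above
mapping.pop(nn.Conv3d)        # safe: the deepcopy leaves the default intact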

View File

@@ -10,6 +10,7 @@ from torch.nn.intrinsic import _FusedModule
 from torch.ao.quantization.quantization_mappings import (
     get_default_dynamic_quant_module_mappings,
     get_default_static_quant_module_mappings,
+    get_default_static_quant_reference_module_mappings,
     get_default_qat_module_mappings,
     get_default_qconfig_propagation_list,
     no_observer_set,
@@ -472,7 +473,7 @@ def quantize_qat(model, run_fn, run_args, inplace=False):

 def convert(
         module, mapping=None, inplace=False, remove_qconfig=True,
-        convert_custom_config_dict=None):
+        is_reference=False, convert_custom_config_dict=None):
     r"""Converts submodules in input module to a different module according to `mapping`
     by calling `from_float` method on the target module class. And remove qconfig at the
     end if remove_qconfig is set to True.
@@ -503,7 +504,7 @@
     if not inplace:
         module = copy.deepcopy(module)
     _convert(
-        module, mapping, inplace=True,
+        module, mapping, inplace=True, is_reference=is_reference,
         convert_custom_config_dict=convert_custom_config_dict)
     if remove_qconfig:
         _remove_qconfig(module)
@@ -511,7 +512,7 @@

 def _convert(
         module, mapping=None, inplace=False,
-        convert_custom_config_dict=None):
+        is_reference=False, convert_custom_config_dict=None):
     r"""Converts submodules in input module to a different module according to `mapping`
     by calling `from_float` method on the target module class
@@ -522,10 +523,12 @@ def _convert(
                  Modules
        inplace: carry out model transformations in-place, the original module
                 is mutated
+       is_reference: a flag to enable quantized reference module
    """
    if mapping is None:
-        mapping = get_default_static_quant_module_mappings()
+        mapping = get_default_static_quant_reference_module_mappings() if is_reference \
+            else get_default_static_quant_module_mappings()
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})
@@ -539,7 +542,7 @@
        if not isinstance(mod, _FusedModule) and \
           type(mod) not in custom_module_class_mapping:
            _convert(mod, mapping, True,  # inplace
-                    convert_custom_config_dict)
+                    is_reference, convert_custom_config_dict)
        reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)

    for key, value in reassign.items():
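
Net effect: when no explicit mapping is given, is_reference=True just selects the reference mapping as the default, so the two calls in this sketch should produce the same module swaps (prepared_model stands for any prepared, calibrated eager-mode model):

from torch.ao.quantization import convert
from torch.ao.quantization.quantization_mappings import (
    get_default_static_quant_reference_module_mappings,
)

def to_reference(prepared_model):
    # equivalent after this change (convert copies the model by default):
    via_flag = convert(prepared_model, is_reference=True)
    via_mapping = convert(
        prepared_model,
        mapping=get_default_static_quant_reference_module_mappings())
    return via_flag, via_mapping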

View File

@@ -12,6 +12,8 @@ class Linear(nn.Linear, ReferenceQuantizedModule):
     and dequantize the weight before running the floating point functional
     linear operator.
     """
+    _IS_REFERENCE = True
+
     def __init__(
             self,
             in_features: int,
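
For context, the reference-module pattern that _IS_REFERENCE marks, per the class docstring above: keep the floating point operator, but quantize and dequantize the weight first. A rough sketch of the idea (illustrative only, not the actual class body):

import torch
import torch.nn.functional as F

def reference_linear(x, weight, bias, w_scale, w_zero_point):
    # quantize the weight, immediately dequantize it, then run the float op;
    # this makes the quantization arithmetic explicit and backend-agnostic
    q_weight = torch.quantize_per_tensor(weight, w_scale, w_zero_point, torch.qint8)
    return F.linear(x, q_weight.dequantize(), bias)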