[Quant] Refactor reference module mapping (#72755)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72755

Add an is_reference flag to the convert function.

Test Plan:
python3 test/test_quantization.py TestQuantizeEagerOps.test_conv_transpose_2d

Imported from OSS

Reviewed By: mruberry

Differential Revision: D34188856

fbshipit-source-id: 291014a7b3b4d4b40ca0ca76a80711097dcc4b58
(cherry picked from commit cfba3b8dc0373708712c0d847d590f0d587df002)
This commit is contained in: parent 5993f48711, commit 4e6aefaf72
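For orientation, here is a minimal eager-mode sketch of the new flag in use. The toy module, qconfig choice, and tensor shapes are illustrative only and are not part of this commit:

import torch
import torch.nn as nn
from torch.ao.quantization import QuantStub, DeQuantStub, default_qconfig, prepare, convert

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.conv = nn.Conv2d(3, 8, 3)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))

m = ToyModel().eval()
m.qconfig = default_qconfig
m = prepare(m)
m(torch.randn(1, 3, 16, 16))             # calibrate the observers
ref_m = convert(m, is_reference=True)    # swap to reference modules instead of quantized kernels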
@@ -3,7 +3,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.quantized as nnq
-import torch.nn.quantized._reference as nnqr
 from torch.nn.utils.rnn import PackedSequence
 from torch.ao.quantization import (
     quantize,
@@ -140,17 +139,7 @@ class TestQuantizeEagerOps(QuantizationTestCase):
 
         ref_m = prepare(original_ref_m)
         ref_m(data)
-        reference_module_mapping = {
-            QuantStub: nnq.Quantize,
-            DeQuantStub: nnq.DeQuantize,
-            nn.Conv1d: nnqr.Conv1d,
-            nn.Conv2d: nnqr.Conv2d,
-            nn.Conv3d: nnqr.Conv3d,
-            nn.ConvTranspose1d: nnqr.ConvTranspose1d,
-            nn.ConvTranspose2d: nnqr.ConvTranspose2d,
-            nn.ConvTranspose3d: nnqr.ConvTranspose3d,
-        }
-        ref_m = convert(ref_m, mapping=reference_module_mapping)
+        ref_m = convert(ref_m, is_reference=True)
         ref_res = ref_m(data)
         self.assertEqual(res, ref_res)
 
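To compare the two convert paths on the same prepared model, one can convert a copy with and without the flag. This is a sketch, assuming ref_m is the prepared module from the test above:

import copy
from torch.ao.quantization import convert

quant_m = convert(copy.deepcopy(ref_m))                     # default mapping: torch.nn.quantized modules
ref_q_m = convert(copy.deepcopy(ref_m), is_reference=True)  # reference modules from torch.nn.quantized._reference
print(quant_m)   # e.g. quantized conv modules with packed int8 weights
print(ref_q_m)   # reference modules that dequantize their weights and call the float op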
@@ -202,6 +191,14 @@ class TestQuantizeEagerOps(QuantizationTestCase):
             (16, 1, 10, 10, 10)
         )
 
+    def test_linear(self):
+        self._test_reference_module_impl(
+            nn.Linear,
+            nnq.Linear,
+            {'in_features': 5, 'out_features': 10},
+            (16, 5)
+        )
+
     def _test_activation_op_impl(
             self, float_module_class, quantized_module_class, extra_module_kwargs):
         """ Implementation for testing common activation ops like leaky relu
@@ -26,6 +26,8 @@ from torch.ao.quantization.utils import get_combined_dict
 
 # Default map for swapping float module to reference quantized modules
 DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
+    QuantStub: nnq.Quantize,
+    DeQuantStub: nnq.DeQuantize,
     nn.Linear: nnqr.Linear,
     nn.Conv1d: nnqr.Conv1d,
     nn.Conv2d: nnqr.Conv2d,
@@ -175,6 +177,11 @@ def get_default_static_quant_module_mappings() -> Dict[Callable, Any]:
     '''
     return copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS)
 
+def get_default_static_quant_reference_module_mappings() -> Dict[Callable, Any]:
+    ''' Get reference module mapping for post training static quantization
+    '''
+    return copy.deepcopy(DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS)
+
 def get_embedding_static_quant_module_mappings() -> Dict[Callable, Any]:
     ''' Get module mapping, including mapping for embedding QAT
     '''
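A small sketch of the new getter in use; the printing loop is illustrative:

from torch.ao.quantization.quantization_mappings import (
    get_default_static_quant_reference_module_mappings,
)

# The getter returns a deep copy, so callers can extend the mapping
# without mutating the shared default.
ref_mapping = get_default_static_quant_reference_module_mappings()
for float_cls, ref_cls in ref_mapping.items():
    print(f"{float_cls.__name__} -> {ref_cls.__module__}.{ref_cls.__name__}")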
@@ -10,6 +10,7 @@ from torch.nn.intrinsic import _FusedModule
 from torch.ao.quantization.quantization_mappings import (
     get_default_dynamic_quant_module_mappings,
     get_default_static_quant_module_mappings,
+    get_default_static_quant_reference_module_mappings,
     get_default_qat_module_mappings,
     get_default_qconfig_propagation_list,
     no_observer_set,
@@ -472,7 +473,7 @@ def quantize_qat(model, run_fn, run_args, inplace=False):
 
 def convert(
         module, mapping=None, inplace=False, remove_qconfig=True,
-        convert_custom_config_dict=None):
+        is_reference=False, convert_custom_config_dict=None):
     r"""Converts submodules in input module to a different module according to `mapping`
     by calling `from_float` method on the target module class. And remove qconfig at the
     end if remove_qconfig is set to True.
@@ -503,7 +504,7 @@ def convert(
     if not inplace:
         module = copy.deepcopy(module)
     _convert(
-        module, mapping, inplace=True,
+        module, mapping, inplace=True, is_reference=is_reference,
         convert_custom_config_dict=convert_custom_config_dict)
     if remove_qconfig:
         _remove_qconfig(module)
@@ -511,7 +512,7 @@ def convert(
 
 def _convert(
         module, mapping=None, inplace=False,
-        convert_custom_config_dict=None):
+        is_reference=False, convert_custom_config_dict=None):
     r"""Converts submodules in input module to a different module according to `mapping`
     by calling `from_float` method on the target module class
 
@@ -522,10 +523,12 @@ def _convert(
                  Modules
         inplace: carry out model transformations in-place, the original module
                  is mutated
+        is_reference: a flag to enable quantized reference module
 
     """
     if mapping is None:
-        mapping = get_default_static_quant_module_mappings()
+        mapping = get_default_static_quant_reference_module_mappings() if is_reference \
+            else get_default_static_quant_module_mappings()
     if convert_custom_config_dict is None:
         convert_custom_config_dict = {}
     custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})
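Because the default is only filled in when mapping is None, an explicitly passed mapping still takes precedence over is_reference. A sketch of the equivalent selection logic (pick_mapping is an illustrative helper, not part of this change):

from torch.ao.quantization.quantization_mappings import (
    get_default_static_quant_module_mappings,
    get_default_static_quant_reference_module_mappings,
)

def pick_mapping(mapping=None, is_reference=False):
    # Mirrors the branch added in _convert above: fall back to a default
    # mapping only when the caller did not supply one.
    if mapping is None:
        mapping = (get_default_static_quant_reference_module_mappings()
                   if is_reference
                   else get_default_static_quant_module_mappings())
    return mapping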
@@ -539,7 +542,7 @@ def _convert(
         if not isinstance(mod, _FusedModule) and \
            type(mod) not in custom_module_class_mapping:
             _convert(mod, mapping, True,  # inplace
-                     convert_custom_config_dict)
+                     is_reference, convert_custom_config_dict)
         reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
 
     for key, value in reassign.items():
@@ -12,6 +12,8 @@ class Linear(nn.Linear, ReferenceQuantizedModule):
     and dequantize the weight before running the floating point functional
     linear operator.
     """
+    _IS_REFERENCE = True
+
     def __init__(
             self,
             in_features: int,
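The _IS_REFERENCE class attribute gives downstream code a cheap way to tell reference modules apart from ordinary quantized ones; the helper below is an illustrative sketch, not part of this commit:

def is_reference_module(mod):
    # Reference quantized modules (e.g. the reference Linear above) advertise
    # themselves via the _IS_REFERENCE class attribute.
    return getattr(type(mod), "_IS_REFERENCE", False)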