mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary:
During the convert step, observers are first replaced by a Q-DQ pair. In some
scenarios, like the following, the output DQ has a fan-out.
---> OP2 -> Q -> DQ
/
OP -> Q -> DQ -
\
---> OP3 -> Q -> DQ
If either op OP2 or OP3 is configured to be quantized, then the input
is expected to be quantized. In this case, the quantized equivalent of some
pattern, which the quantizer asked to be quantized, should look like:
[DQ -> {pattern} -> Q]. However, in scenario like above where DQ node
is shared between multiple "quantized" patterns, boundary of "quantized"
pattern is not clear because DQ now belongs to multiple quantized
patterns.
This poses challenge for:
- Porting metadata: which "quantized" partition this DQ node belongs to
- Quantized representation, equivalently, needs to identify
self-contained quantized pattern that is replaced by its equivalent pattern
that captures compute in the quantized precision.
Test Plan:
test_duplicate_dq_pass
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: [D48663147](https://our.internmc.facebook.com/intern/diff/D48663147)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107900
Approved by: https://github.com/jerryzh168, https://github.com/andrewor14, https://github.com/leslie-fang-intel
ghstack dependencies: #107105, #107106, #107899
182 lines
5.9 KiB
Python
182 lines
5.9 KiB
Python
# flake8: noqa: F403
|
|
|
|
from .fake_quantize import * # noqa: F403
|
|
from .fuse_modules import fuse_modules # noqa: F403
|
|
from .fuse_modules import fuse_modules_qat # noqa: F403
|
|
from .fuser_method_mappings import * # noqa: F403
|
|
from .observer import * # noqa: F403
|
|
from .qconfig import * # noqa: F403
|
|
from .qconfig_mapping import * # noqa: F403
|
|
from .quant_type import * # noqa: F403
|
|
from .quantization_mappings import * # type: ignore[no-redef]
|
|
from .quantize import * # noqa: F403
|
|
from .quantize_jit import * # noqa: F403
|
|
from .stubs import * # noqa: F403
|
|
from .pt2e.eval_utils import _move_model_to_eval as move_model_to_eval
|
|
from typing import Union, List, Callable, Tuple, Optional
|
|
from torch import Tensor
|
|
import torch
|
|
|
|
# Convenience alias for anything that can sit at an observation point:
# either a statistics-gathering ObserverBase or a FakeQuantizeBase.
ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase]
# Make the alias report this package as its home module (presumably so it
# renders as torch.ao.quantization.ObserverOrFakeQuantize in docs rather
# than the defining submodule's path -- NOTE(review): confirm intent).
ObserverOrFakeQuantize.__module__ = "torch.ao.quantization"
|
|
|
|
# Public API of this package, re-exported from the star-imports above.
# Keep sorted: classes first (alphabetical), then functions/constants.
__all__ = [
    "DeQuantStub",
    "FakeQuantize",
    "FakeQuantizeBase",
    "FixedQParamsFakeQuantize",
    "FixedQParamsObserver",
    "FusedMovingAvgObsFakeQuantize",
    "HistogramObserver",
    "MatchAllNode",
    "MinMaxObserver",
    "MovingAverageMinMaxObserver",
    "MovingAveragePerChannelMinMaxObserver",
    "NoopObserver",
    "ObserverBase",
    "ObserverOrFakeQuantize",
    "Pattern",
    "PerChannelMinMaxObserver",
    "PlaceholderObserver",
    "QConfig",
    "QConfigAny",
    "QConfigDynamic",
    "QConfigMapping",
    "QuantStub",
    "QuantType",
    "QuantWrapper",
    "RecordingObserver",
    "ReuseInputObserver",
    "UniformQuantizationObserverBase",
    "add_quant_dequant",
    "convert",
    "convert_dynamic_jit",
    "convert_jit",
    "default_affine_fixed_qparams_fake_quant",
    "default_affine_fixed_qparams_observer",
    "default_debug_observer",
    "default_dynamic_fake_quant",
    "default_dynamic_quant_observer",
    "default_embedding_fake_quant",
    "default_embedding_fake_quant_4bit",
    "default_eval_fn",
    "default_fake_quant",
    "default_fixed_qparams_range_0to1_fake_quant",
    "default_fixed_qparams_range_0to1_observer",
    "default_fixed_qparams_range_neg1to1_fake_quant",
    "default_fixed_qparams_range_neg1to1_observer",
    "default_float_qparams_observer",
    "default_float_qparams_observer_4bit",
    "default_fused_act_fake_quant",
    "default_fused_per_channel_wt_fake_quant",
    "default_fused_wt_fake_quant",
    "default_histogram_fake_quant",
    "default_histogram_observer",
    "default_observer",
    "default_per_channel_weight_fake_quant",
    "default_per_channel_weight_observer",
    "default_placeholder_observer",
    "default_reuse_input_observer",
    "default_symmetric_fixed_qparams_fake_quant",
    "default_symmetric_fixed_qparams_observer",
    "default_weight_fake_quant",
    "default_weight_observer",
    "disable_fake_quant",
    "disable_observer",
    "enable_fake_quant",
    "enable_observer",
    "fuse_conv_bn",
    "fuse_conv_bn_jit",
    "fuse_conv_bn_relu",
    "fuse_convtranspose_bn",
    "fuse_linear_bn",
    "fuse_modules",
    "fuse_modules_qat",
    "fused_per_channel_wt_fake_quant_range_neg_127_to_127",
    "fused_wt_fake_quant_range_neg_127_to_127",
    "get_combined_dict",
    "get_default_compare_output_module_list",
    "get_default_custom_config_dict",
    "get_default_dynamic_quant_module_mappings",
    "get_default_dynamic_sparse_quant_module_mappings",
    "get_default_float_to_quantized_operator_mappings",
    "get_default_qat_module_mappings",
    "get_default_qat_qconfig",
    "get_default_qat_qconfig_dict",
    "get_default_qat_qconfig_mapping",
    "get_default_qconfig",
    "get_default_qconfig_dict",
    "get_default_qconfig_mapping",
    "get_default_qconfig_propagation_list",
    "get_default_static_quant_module_mappings",
    "get_default_static_quant_reference_module_mappings",
    "get_default_static_sparse_quant_module_mappings",
    "get_dynamic_quant_module_class",
    "get_embedding_qat_module_mappings",
    "get_embedding_static_quant_module_mappings",
    "get_fuser_method",
    "get_fuser_method_new",
    "get_observer_state_dict",
    "get_quantized_operator",
    "get_static_quant_module_class",
    "load_observer_state_dict",
    "move_model_to_eval",
    "no_observer_set",
    "per_channel_weight_observer_range_neg_127_to_127",
    "prepare",
    "prepare_dynamic_jit",
    "prepare_jit",
    "prepare_qat",
    "propagate_qconfig_",
    "qconfig_equals",
    "quantize",
    "quantize_dynamic",
    "quantize_dynamic_jit",
    "quantize_jit",
    "quantize_qat",
    "script_qconfig",
    "script_qconfig_dict",
    "swap_module",
    "weight_observer_range_neg_127_to_127",
]
|
|
|
|
def default_eval_fn(model, calib_data):
    r"""Run ``model`` over every example in ``calib_data`` (for calibration).

    Default evaluation function takes a torch.utils.data.Dataset or a list of
    input Tensors and run the model on the dataset.

    Args:
        model: a callable (e.g. an ``nn.Module``) invoked once per example.
        calib_data: an iterable yielding either ``(input, target)``-style
            tuples/lists (only the first element is fed to the model; the
            target is ignored) or bare inputs.

    Returns:
        None. Model outputs are discarded -- only side effects of running the
        model (e.g. observer statistics) matter.
    """
    for item in calib_data:
        # Previously this unconditionally unpacked ``data, target = item``,
        # which crashed on the documented "list of input Tensors" case.
        # Accept both shapes: pairs (use the first element) and bare inputs.
        data = item[0] if isinstance(item, (tuple, list)) else item
        model(data)
|
|
|
|
class _DerivedObserverOrFakeQuantize(ObserverBase):
    r"""Observer whose quantization parameters are not measured directly but
    derived from a collection of other observers / fake-quantizes via a
    user-supplied callback.
    """
    def __init__(
        self,
        dtype: torch.dtype,
        obs_or_fqs: List[ObserverOrFakeQuantize],
        derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]],
        quant_min: Optional[int]=None,
        quant_max: Optional[int]=None,
        qscheme: Optional[torch.qscheme]=None,
        ch_axis: Optional[int] = None
    ):
        super().__init__(dtype)
        # Keep the source observers and the derivation callback around;
        # qparams are only computed on demand in calculate_qparams().
        self.obs_or_fqs = obs_or_fqs
        self.derive_qparams_fn = derive_qparams_fn
        self.quant_min, self.quant_max = quant_min, quant_max
        self.qscheme, self.ch_axis = qscheme, ch_axis

        from .utils import is_per_channel

        # A per-channel qscheme is meaningless without a channel axis.
        assert not is_per_channel(self.qscheme) or self.ch_axis is not None, \
            "Must provide a valid ch_axis if qscheme is per channel"

    def forward(self, x: Tensor) -> Tensor:
        # Pure pass-through: this observer records no statistics of its own.
        return x

    def calculate_qparams(self):
        # Delegate entirely to the user-supplied derivation function.
        return self.derive_qparams_fn(self.obs_or_fqs)