pytorch/torch/ao/quantization/__init__.py
Kimish Patel eb67c452c8 [Quant] Add DQ duplication pass (#107900)
Summary:
During the convert step, observers are first replaced by Q-DQ pairs. In some
scenarios, like the following, the output DQ node has a fan-out:

                 ---> OP2 -> Q -> DQ
                /
OP -> Q -> DQ -
                \
                 ---> OP3 -> Q -> DQ

If either OP2 or OP3 is configured to be quantized, then its input is
expected to be quantized as well. In that case, the quantized equivalent of a
pattern that the quantizer asked to be quantized should look like
[DQ -> {pattern} -> Q]. However, in a scenario like the one above, where the
DQ node is shared between multiple "quantized" patterns, the boundary of each
"quantized" pattern is unclear, because the DQ node now belongs to multiple
quantized patterns.

This poses a challenge for:
- Porting metadata: it is ambiguous which "quantized" partition the shared DQ
node belongs to.
- The quantized representation, which needs to identify a self-contained
quantized pattern so that it can be replaced by an equivalent pattern that
performs the compute in the quantized precision.

The pass added here duplicates such shared DQ nodes so that each consumer
pattern owns a private copy; see the sketch below.
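
Below is a minimal sketch of the duplication idea using the torch.fx graph
APIs. It is an illustration, not the actual pass implementation; in
particular, _is_dq is a placeholder predicate standing in for matching the
concrete dequantize ops produced during convert:

    import torch.fx as fx

    def _is_dq(node: fx.Node) -> bool:
        # Placeholder: the real pass matches the concrete dequantize ops.
        return node.op == "call_function" and "dequantize" in str(node.target)

    def duplicate_dq(gm: fx.GraphModule) -> None:
        for node in list(gm.graph.nodes):
            if _is_dq(node) and len(node.users) > 1:
                # Keep the first user on the original DQ node; give every
                # other user its own copy so that each quantized pattern owns
                # a private DQ -> {pattern} -> Q chain.
                for user in list(node.users)[1:]:
                    with gm.graph.inserting_after(node):
                        dq_copy = gm.graph.node_copy(node)
                    user.replace_input_with(node, dq_copy)
        gm.graph.lint()
        gm.recompile()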

Test Plan:
test_duplicate_dq_pass

Differential Revision: [D48663147](https://our.internmc.facebook.com/intern/diff/D48663147)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107900
Approved by: https://github.com/jerryzh168, https://github.com/andrewor14, https://github.com/leslie-fang-intel
ghstack dependencies: #107105, #107106, #107899
2023-09-02 06:20:03 +00:00


# flake8: noqa: F403

from .fake_quantize import *  # noqa: F403
from .fuse_modules import fuse_modules  # noqa: F403
from .fuse_modules import fuse_modules_qat  # noqa: F403
from .fuser_method_mappings import *  # noqa: F403
from .observer import *  # noqa: F403
from .qconfig import *  # noqa: F403
from .qconfig_mapping import *  # noqa: F403
from .quant_type import *  # noqa: F403
from .quantization_mappings import *  # type: ignore[no-redef]
from .quantize import *  # noqa: F403
from .quantize_jit import *  # noqa: F403
from .stubs import *  # noqa: F403
from .pt2e.eval_utils import _move_model_to_eval as move_model_to_eval

from typing import Union, List, Callable, Tuple, Optional
from torch import Tensor
import torch

ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase]
ObserverOrFakeQuantize.__module__ = "torch.ao.quantization"

__all__ = [
    "DeQuantStub",
    "FakeQuantize",
    "FakeQuantizeBase",
    "FixedQParamsFakeQuantize",
    "FixedQParamsObserver",
    "FusedMovingAvgObsFakeQuantize",
    "HistogramObserver",
    "MatchAllNode",
    "MinMaxObserver",
    "MovingAverageMinMaxObserver",
    "MovingAveragePerChannelMinMaxObserver",
    "NoopObserver",
    "ObserverBase",
    "ObserverOrFakeQuantize",
    "Pattern",
    "PerChannelMinMaxObserver",
    "PlaceholderObserver",
    "QConfig",
    "QConfigAny",
    "QConfigDynamic",
    "QConfigMapping",
    "QuantStub",
    "QuantType",
    "QuantWrapper",
    "RecordingObserver",
    "ReuseInputObserver",
    "UniformQuantizationObserverBase",
    "add_quant_dequant",
    "convert",
    "convert_dynamic_jit",
    "convert_jit",
    "default_affine_fixed_qparams_fake_quant",
    "default_affine_fixed_qparams_observer",
    "default_debug_observer",
    "default_dynamic_fake_quant",
    "default_dynamic_quant_observer",
    "default_embedding_fake_quant",
    "default_embedding_fake_quant_4bit",
    "default_eval_fn",
    "default_fake_quant",
    "default_fixed_qparams_range_0to1_fake_quant",
    "default_fixed_qparams_range_0to1_observer",
    "default_fixed_qparams_range_neg1to1_fake_quant",
    "default_fixed_qparams_range_neg1to1_observer",
    "default_float_qparams_observer",
    "default_float_qparams_observer_4bit",
    "default_fused_act_fake_quant",
    "default_fused_per_channel_wt_fake_quant",
    "default_fused_wt_fake_quant",
    "default_histogram_fake_quant",
    "default_histogram_observer",
    "default_observer",
    "default_per_channel_weight_fake_quant",
    "default_per_channel_weight_observer",
    "default_placeholder_observer",
    "default_reuse_input_observer",
    "default_symmetric_fixed_qparams_fake_quant",
    "default_symmetric_fixed_qparams_observer",
    "default_weight_fake_quant",
    "default_weight_observer",
    "disable_fake_quant",
    "disable_observer",
    "enable_fake_quant",
    "enable_observer",
    "fuse_conv_bn",
    "fuse_conv_bn_jit",
    "fuse_conv_bn_relu",
    "fuse_convtranspose_bn",
    "fuse_linear_bn",
    "fuse_modules",
    "fuse_modules_qat",
    "fused_per_channel_wt_fake_quant_range_neg_127_to_127",
    "fused_wt_fake_quant_range_neg_127_to_127",
    "get_combined_dict",
    "get_default_compare_output_module_list",
    "get_default_custom_config_dict",
    "get_default_dynamic_quant_module_mappings",
    "get_default_dynamic_sparse_quant_module_mappings",
    "get_default_float_to_quantized_operator_mappings",
    "get_default_qat_module_mappings",
    "get_default_qat_qconfig",
    "get_default_qat_qconfig_dict",
    "get_default_qat_qconfig_mapping",
    "get_default_qconfig",
    "get_default_qconfig_dict",
    "get_default_qconfig_mapping",
    "get_default_qconfig_propagation_list",
    "get_default_static_quant_module_mappings",
    "get_default_static_quant_reference_module_mappings",
    "get_default_static_sparse_quant_module_mappings",
    "get_dynamic_quant_module_class",
    "get_embedding_qat_module_mappings",
    "get_embedding_static_quant_module_mappings",
    "get_fuser_method",
    "get_fuser_method_new",
    "get_observer_state_dict",
    "get_quantized_operator",
    "get_static_quant_module_class",
    "load_observer_state_dict",
    "move_model_to_eval",
    "no_observer_set",
    "per_channel_weight_observer_range_neg_127_to_127",
    "prepare",
    "prepare_dynamic_jit",
    "prepare_jit",
    "prepare_qat",
    "propagate_qconfig_",
    "qconfig_equals",
    "quantize",
    "quantize_dynamic",
    "quantize_dynamic_jit",
    "quantize_jit",
    "quantize_qat",
    "script_qconfig",
    "script_qconfig_dict",
    "swap_module",
    "weight_observer_range_neg_127_to_127",
]

def default_eval_fn(model, calib_data):
    r"""
    Default evaluation function: takes a torch.utils.data.Dataset or a list of
    input Tensors and runs the model on the dataset.
    """
    for data, target in calib_data:
        model(data)
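# Illustrative usage of `default_eval_fn` (a sketch with a hypothetical model
# and calibration data, not part of this module):
#
#     float_model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3)).eval()
#     calib_data = [(torch.randn(1, 3, 32, 32), 0) for _ in range(8)]
#     default_eval_fn(float_model, calib_data)  # runs the model on each batch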
class _DerivedObserverOrFakeQuantize(ObserverBase):
    r"""This observer is used to describe an observer whose quantization parameters
    are derived from other observers
    """

    def __init__(
        self,
        dtype: torch.dtype,
        obs_or_fqs: List[ObserverOrFakeQuantize],
        derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]],
        quant_min: Optional[int] = None,
        quant_max: Optional[int] = None,
        qscheme: Optional[torch.qscheme] = None,
        ch_axis: Optional[int] = None,
    ):
        super().__init__(dtype)
        self.obs_or_fqs = obs_or_fqs
        self.derive_qparams_fn = derive_qparams_fn
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.qscheme = qscheme
        self.ch_axis = ch_axis

        from .utils import is_per_channel

        if is_per_channel(self.qscheme):
            assert self.ch_axis is not None, "Must provide a valid ch_axis if qscheme is per channel"

    def forward(self, x: Tensor) -> Tensor:
        return x

    def calculate_qparams(self):
        return self.derive_qparams_fn(self.obs_or_fqs)
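# Illustrative use of `_DerivedObserverOrFakeQuantize` (a sketch, not part of
# this module): deriving bias quantization parameters from activation and
# weight observers, where bias_scale = act_scale * weight_scale and the zero
# point is fixed at 0. `act_obs`, `weight_obs`, and `derive_bias_qparams` are
# hypothetical names introduced for the example.
#
#     act_obs = MinMaxObserver(dtype=torch.quint8)
#     weight_obs = MinMaxObserver(dtype=torch.qint8)
#
#     def derive_bias_qparams(obs_or_fqs):
#         act_scale, _ = obs_or_fqs[0].calculate_qparams()
#         weight_scale, _ = obs_or_fqs[1].calculate_qparams()
#         scale = act_scale.to(torch.float64) * weight_scale.to(torch.float64)
#         zero_point = torch.zeros_like(scale, dtype=torch.int64)
#         return scale, zero_point
#
#     bias_obs = _DerivedObserverOrFakeQuantize(
#         dtype=torch.int32,
#         obs_or_fqs=[act_obs, weight_obs],
#         derive_qparams_fn=derive_bias_qparams,
#     )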