pytorch/torch/ao/quantization/quantize_pt2e.py
Kimish Patel eb67c452c8 [Quant] Add DQ duplication pass (#107900)
Summary:
During the convert step, observers are first replaced by a Q-DQ pair. In some
scenarios, like the following, the output DQ node has a fan-out:

                 ---> OP2 -> Q -> DQ
                /
OP -> Q -> DQ -
                \
                 ---> OP3 -> Q -> DQ

If either OP2 or OP3 is configured to be quantized, then its input is
expected to be quantized. In that case the quantized equivalent of a
pattern that the quantizer asked to be quantized should look like
[DQ -> {pattern} -> Q]. However, in a scenario like the above, where the
DQ node is shared between multiple "quantized" patterns, the boundary of
each "quantized" pattern is unclear because the DQ node now belongs to
multiple quantized patterns.

This poses a challenge for:
- Porting metadata: it is ambiguous which "quantized" partition the shared
DQ node belongs to.
- Quantized representation: the rewrite, equivalently, needs to identify a
self-contained quantized pattern to replace with an equivalent pattern
that captures the compute in the quantized precision.

The fix is to duplicate the shared DQ node so that each consuming pattern
gets its own copy; see the sketch below.
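
As a rough illustration only (not the actual DuplicateDQPass, whose
matching is done against the decomposed quantized ops), a minimal sketch
of such a duplication pass on a torch.fx graph could look like:

from torch.fx import GraphModule

def _duplicate_shared_dq(model: GraphModule) -> GraphModule:
    graph = model.graph
    for node in list(graph.nodes):
        # Simplified predicate for spotting DQ nodes; the real pass checks
        # for the decomposed dequantize ops explicitly.
        if node.op == "call_function" and "dequantize" in str(node.target):
            # Leave the first user on the original DQ node; every other
            # user gets a private copy, so each [DQ -> pattern -> Q]
            # chain becomes self-contained.
            for user in list(node.users)[1:]:
                with graph.inserting_after(node):
                    dq_copy = graph.node_copy(node)
                user.replace_input_with(node, dq_copy)
    graph.lint()
    model.recompile()
    return model

After duplication, the example above becomes:

                 ---> DQ -> OP2 -> Q -> DQ
                /
       OP -> Q -
                \
                 ---> DQ -> OP3 -> Q -> DQ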

Test Plan:
test_duplicate_dq_pass

Differential Revision: [D48663147](https://our.internmc.facebook.com/intern/diff/D48663147)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107900
Approved by: https://github.com/jerryzh168, https://github.com/andrewor14, https://github.com/leslie-fang-intel
ghstack dependencies: #107105, #107106, #107899
2023-09-02 06:20:03 +00:00

from torch.fx import GraphModule

from .pt2e.prepare import prepare
from .pt2e.qat_utils import (
    _fuse_conv_bn_qat,
    _fold_conv_bn_qat,
)
from .pt2e.utils import (
    _get_node_name_to_scope,
    _fuse_conv_bn_,
)
from .pt2e.representation import reference_representation_rewrite
from .fx.prepare import prepare as fx_prepare
from .quantize_fx import _convert_to_reference_decomposed_fx
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.quantizer import (  # noqa: F401
    Quantizer,
    QuantizationSpecBase,
    QuantizationSpec,
    FixedQParamsQuantizationSpec,
    SharedQuantizationSpec,
    DerivedQuantizationSpec,
    QuantizationAnnotation,
)
from torch.ao.quantization.backend_config import BackendConfig
from typing import Any, Tuple
from torch.fx.passes.infra.pass_manager import PassManager
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass

__all__ = [
    "prepare_pt2e",
    "prepare_qat_pt2e",
    "convert_pt2e",
]


def _prepare_pt2e_deprecated(
    model: GraphModule,
    qconfig_mapping: QConfigMapping,
    example_inputs: Tuple[Any, ...],
    backend_config: BackendConfig,
) -> GraphModule:
    node_name_to_scope = _get_node_name_to_scope(model)
    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    model = fx_prepare(
        model,
        qconfig_mapping,
        False,  # is_qat
        node_name_to_scope,
        example_inputs,
        backend_config=backend_config,
    )
    return model


def prepare_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    """Prepare a captured FX ``GraphModule`` for post-training quantization:
    annotate the graph with the given ``Quantizer``, then insert observers."""
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    model = prepare(model, node_name_to_scope, is_qat=False)
    model.meta.update(original_graph_meta)
    return model


def prepare_qat_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    """Prepare a captured FX ``GraphModule`` for quantization-aware training:
    annotate the graph, fuse conv-bn into its QAT form, then insert fake
    quantizers."""
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    # Perform fusion after annotate to avoid quantizing ops in the new
    # subgraph that don't need to be quantized
    # TODO: only fuse if conv and bn are both configured to be quantized
    _fuse_conv_bn_qat(model)
    model = prepare(model, node_name_to_scope, is_qat=True)
    model.meta.update(original_graph_meta)
    return model


def convert_pt2e(
    model: GraphModule,
    use_reference_representation: bool = False,
) -> GraphModule:
    """Convert a calibrated or trained model to its quantized form: replace
    observers with Q-DQ pairs, fold conv-bn back, and duplicate any DQ node
    shared by multiple quantized patterns (see ``DuplicateDQPass``)."""
    original_graph_meta = model.meta
    model = _convert_to_reference_decomposed_fx(model)
    model = _fold_conv_bn_qat(model)
    pm = PassManager([DuplicateDQPass()])
    model = pm(model).graph_module
    if use_reference_representation:
        model = reference_representation_rewrite(model)
    model.meta.update(original_graph_meta)
    return model
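
For context, a minimal usage sketch of the public API defined above. The
capture step and the XNNPACKQuantizer import paths are assumptions based on
the PT2E tooling of the same era, not something defined in this file:

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e

# A toy model and representative inputs (hypothetical).
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

# Capture the eager model into an FX GraphModule before quantization.
exported = capture_pre_autograd_graph(model, example_inputs)

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibrate with representative data
quantized = convert_pt2e(prepared)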