Summary:
During the convert step, observers are first replaced by a Q-DQ pair. In
some scenarios, like the following, the output DQ has a fan-out:

                ---> OP2 -> Q -> DQ
               /
OP -> Q -> DQ -
               \
                ---> OP3 -> Q -> DQ

If either OP2 or OP3 is configured to be quantized, then its input is
expected to be quantized. In that case, the quantized equivalent of a
pattern that the quantizer asked to be quantized should look like
[DQ -> {pattern} -> Q]. However, in a scenario like the above, where the
DQ node is shared between multiple "quantized" patterns, the boundary of
each "quantized" pattern is unclear because the DQ node now belongs to
multiple quantized patterns.
This poses challenges for:
- Porting metadata: which "quantized" partition does this DQ node belong to?
- The quantized representation, which likewise needs to identify a
  self-contained quantized pattern that can be replaced by an equivalent
  pattern capturing the compute in the quantized precision.
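
The fix taken here is to duplicate the shared DQ node so that each
consumer owns its own copy. Below is a minimal sketch of that idea over
a torch.fx graph; the helper names (duplicate_shared_dq, _is_dq) and the
string-based dequantize check are illustrative assumptions, not the
actual DuplicateDQPass implementation:

    from torch.fx import GraphModule, Node

    def _is_dq(node: Node) -> bool:
        # Assumption for illustration: DQ nodes are call_function nodes
        # whose target is a decomposed dequantize op; the real pass
        # matches the actual quantized-decomposed ops.
        return node.op == "call_function" and "dequantize" in str(node.target)

    def duplicate_shared_dq(gm: GraphModule) -> GraphModule:
        for node in list(gm.graph.nodes):
            if _is_dq(node) and len(node.users) > 1:
                # Every user after the first gets a fresh copy of the DQ
                # node, so each quantized pattern becomes a self-contained
                # [DQ -> {pattern} -> Q].
                for user in list(node.users)[1:]:
                    with gm.graph.inserting_after(node):
                        dq_copy = gm.graph.node_copy(node)
                    user.replace_input_with(node, dq_copy)
        gm.graph.lint()
        gm.recompile()
        return gm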
Test Plan:
test_duplicate_dq_pass
Differential Revision: [D48663147](https://our.internmc.facebook.com/intern/diff/D48663147)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107900
Approved by: https://github.com/jerryzh168, https://github.com/andrewor14, https://github.com/leslie-fang-intel
ghstack dependencies: #107105, #107106, #107899
107 lines | 3.2 KiB | Python
from typing import Any, Tuple

from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
from torch.ao.quantization.quantizer import (  # noqa: F401
    Quantizer,
    QuantizationSpecBase,
    QuantizationSpec,
    FixedQParamsQuantizationSpec,
    SharedQuantizationSpec,
    DerivedQuantizationSpec,
    QuantizationAnnotation,
)
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_manager import PassManager

from .fx.prepare import prepare as fx_prepare
from .pt2e.prepare import prepare
from .pt2e.qat_utils import (
    _fuse_conv_bn_qat,
    _fold_conv_bn_qat,
)
from .pt2e.representation import reference_representation_rewrite
from .pt2e.utils import (
    _get_node_name_to_scope,
    _fuse_conv_bn_,
)
from .quantize_fx import _convert_to_reference_decomposed_fx

__all__ = [
    "prepare_pt2e",
    "prepare_qat_pt2e",
    "convert_pt2e",
]


def _prepare_pt2e_deprecated(
    model: GraphModule,
    qconfig_mapping: QConfigMapping,
    example_inputs: Tuple[Any, ...],
    backend_config: BackendConfig,
) -> GraphModule:
    """Deprecated QConfigMapping-based prepare path, kept for backward
    compatibility. Use `prepare_pt2e` with a `Quantizer` instead."""
    node_name_to_scope = _get_node_name_to_scope(model)

    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    model = fx_prepare(
        model,
        qconfig_mapping,
        False,  # is_qat
        node_name_to_scope,
        example_inputs,
        backend_config=backend_config,
    )
    return model


def prepare_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    """Prepare a captured model for post-training quantization: fuse
    conv-bn, let the quantizer annotate and validate the graph, then
    insert observers according to the annotations."""
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    model = prepare(model, node_name_to_scope, is_qat=False)
    model.meta.update(original_graph_meta)
    return model


def prepare_qat_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    """Prepare a captured model for quantization-aware training. Unlike
    `prepare_pt2e`, conv-bn fusion happens after annotation so ops in the
    fused subgraph are only quantized if they were annotated."""
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    # Perform fusion after annotate to avoid quantizing ops in the new
    # subgraph that don't need to be quantized
    # TODO: only fuse if conv and bn are both configured to be quantized
    _fuse_conv_bn_qat(model)
    model = prepare(model, node_name_to_scope, is_qat=True)
    model.meta.update(original_graph_meta)
    return model


def convert_pt2e(
    model: GraphModule,
    use_reference_representation: bool = False,
) -> GraphModule:
    """Convert a calibrated/trained model to its quantized form: replace
    observers with Q-DQ pairs, fold conv-bn for QAT, then duplicate any
    DQ node shared across patterns so each quantized pattern is
    self-contained."""
    original_graph_meta = model.meta
    model = _convert_to_reference_decomposed_fx(model)
    model = _fold_conv_bn_qat(model)
    pm = PassManager([DuplicateDQPass()])
    model = pm(model).graph_module

    if use_reference_representation:
        model = reference_representation_rewrite(model)

    model.meta.update(original_graph_meta)
    return model
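
For context, a hedged end-to-end sketch of the flow these entry points
implement. The capture API and quantizer used here
(capture_pre_autograd_graph, XNNPACKQuantizer) are assumptions that
moved between modules across releases; treat this as a sketch, not the
canonical workflow:

    import torch
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.linear(x)

    example_inputs = (torch.randn(1, 8),)
    # Assumption: capture_pre_autograd_graph is the capture API of this era.
    m = capture_pre_autograd_graph(M(), example_inputs)

    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)  # annotate and insert observers
    m(*example_inputs)              # calibrate
    m = convert_pt2e(m)             # lower to Q-DQ, run DuplicateDQPass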