pytorch/torch/ao/quantization/quantize_pt2e.py
andrewor14 057b807178 [quant] Move dropout replacement to move_model_to_eval (#108184)
Summary: This commit adds a public-facing
`torch.ao.quantization.move_model_to_eval` util function
for QAT users. Instead of calling model.eval() on an exported
model (which doesn't work, see
https://github.com/pytorch/pytorch/issues/103681), users
should call this new util function. This ensures special
ops such as dropout and batchnorm (batchnorm is not supported
yet) have the right behavior when the graph is later used
for inference.
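
For context, a minimal sketch of the intended QAT flow. The
export entry point and the example quantizer below are
assumptions about the surrounding PT2E APIs, not part of this
change:

import torch
from torch._export import capture_pre_autograd_graph  # assumed export entry point
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (  # assumed quantizer
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        self.dropout = torch.nn.Dropout(p=0.5)

    def forward(self, x):
        return self.dropout(self.linear(x))

example_inputs = (torch.randn(1, 8),)
m = capture_pre_autograd_graph(M(), example_inputs)

quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_qat=True)
)
m = prepare_qat_pt2e(m, quantizer)
# ... QAT training loop on `m` ...
m = convert_pt2e(m)

# Do NOT call m.eval() on the exported graph (see issue above);
# call the new util instead:
torch.ao.quantization.move_model_to_eval(m)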

Note: Support for an equivalent `move_model_to_train` will be
added in the future. This is currently difficult for dropout
because the eval pattern of dropout is simply a clone op,
which we cannot reliably match and replace with a dropout op.
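
For illustration only, a rough way to see the problem described
above. The export entry point is an assumption, and the exact op
the eval graph contains may vary by version:

import torch
from torch._export import capture_pre_autograd_graph  # assumed export entry point

class EvalDropout(torch.nn.Module):
    def forward(self, x):
        # training=False is the eval pattern referenced above
        return torch.nn.functional.dropout(x, p=0.5, training=False)

g = capture_pre_autograd_graph(EvalDropout(), (torch.randn(1, 4),))
# Per the note above, the captured eval graph ends up with only a
# clone-like op, so there is no dropout node left to match and
# replace with a train-mode dropout.
g.graph.print_tabular()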

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_move_model_to_eval

Reviewers: jerryzh168, kimishpatel

Subscribers: jerryzh168, kimishpatel, supriyar

Differential Revision: [D48814735](https://our.internmc.facebook.com/intern/diff/D48814735)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108184
Approved by: https://github.com/jerryzh168
2023-08-30 16:33:17 +00:00


from torch.fx import GraphModule

from .pt2e.prepare import prepare
from .pt2e._propagate_annotation import propagate_annotation
from .pt2e.qat_utils import (
    _fuse_conv_bn_qat,
    _fold_conv_bn_qat,
)
from .pt2e.utils import (
    _get_node_name_to_scope,
    _fuse_conv_bn_,
)
from .pt2e.representation import reference_representation_rewrite
from .fx.prepare import prepare as fx_prepare
from .quantize_fx import _convert_to_reference_decomposed_fx
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.quantizer import (  # noqa: F401
    Quantizer,
    QuantizationSpecBase,
    QuantizationSpec,
    FixedQParamsQuantizationSpec,
    SharedQuantizationSpec,
    DerivedQuantizationSpec,
    QuantizationAnnotation,
)
from torch.ao.quantization.backend_config import BackendConfig

from typing import Any, Tuple

__all__ = [
    "prepare_pt2e",
    "prepare_qat_pt2e",
    "convert_pt2e",
]


# Deprecated: old flow based on QConfigMapping + BackendConfig; prefer the
# Quantizer-based prepare_pt2e below.
def _prepare_pt2e_deprecated(
    model: GraphModule,
    qconfig_mapping: QConfigMapping,
    example_inputs: Tuple[Any, ...],
    backend_config: BackendConfig,
) -> GraphModule:
    node_name_to_scope = _get_node_name_to_scope(model)
    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    model = fx_prepare(
        model,
        qconfig_mapping,
        False,  # is_qat
        node_name_to_scope,
        example_inputs,
        backend_config=backend_config
    )
    return model


def prepare_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    """Prepare an exported model for post-training quantization: fuse conv+bn,
    annotate the graph using the given Quantizer, and insert observers
    according to the annotations.
    """
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    propagate_annotation(model)
    model = prepare(model, node_name_to_scope, is_qat=False)
    model.meta.update(original_graph_meta)
    return model


def prepare_qat_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    """Prepare an exported model for quantization-aware training: annotate the
    graph using the given Quantizer, fuse conv+bn into its QAT form, and
    insert fake quantizes/observers according to the annotations.
    """
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    propagate_annotation(model)
    # Perform fusion after annotate() to avoid quantizing ops in the new
    # subgraph that don't need to be quantized
    # TODO: only fuse if conv and bn are both configured to be quantized
    _fuse_conv_bn_qat(model)
    model = prepare(model, node_name_to_scope, is_qat=True)
    model.meta.update(original_graph_meta)
    return model


def convert_pt2e(
    model: GraphModule,
    use_reference_representation: bool = False,
) -> GraphModule:
    """Convert a calibrated or trained model to its reference quantized form:
    lower observers/fake quantizes to quantize/dequantize ops, fold the QAT
    conv+bn patterns back, and optionally rewrite the graph into the integer
    reference representation when `use_reference_representation` is True.
    """
    original_graph_meta = model.meta
    model = _convert_to_reference_decomposed_fx(model)
    model = _fold_conv_bn_qat(model)
    if use_reference_representation:
        model = reference_representation_rewrite(model)
    model.meta.update(original_graph_meta)
    return model