Mirror of https://github.com/zebrajr/pytorch.git · synced 2025-12-06 12:20:52 +01:00
Summary: This commit adds a public-facing `torch.ao.quantization.move_model_to_eval` util function for QAT users. Instead of calling `model.eval()` on an exported model (which doesn't work; see https://github.com/pytorch/pytorch/issues/103681), the user calls this new util function instead. This ensures special ops such as dropout and batchnorm (not supported yet) have the right behavior when the graph is later used for inference.

Note: Support for an equivalent `move_model_to_train` will be added in the future. This is currently difficult for dropout because the eval pattern of dropout is simply a clone op, which we cannot just match and replace with a dropout op.

Test Plan: python test/test_quantization.py TestQuantizePT2E.test_move_model_to_eval

Reviewers: jerryzh168, kimishpatel

Subscribers: jerryzh168, kimishpatel, supriyar

Differential Revision: [D48814735](https://our.internmc.facebook.com/intern/diff/D48814735)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108184
Approved by: https://github.com/jerryzh168
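For context, a minimal sketch of where the new util fits in the PT2E QAT flow. The toy model, input shape, and the `capture_pre_autograd_graph` / `XNNPACKQuantizer` choices below are illustrative assumptions, not part of this commit:

    import torch
    from torch._export import capture_pre_autograd_graph  # assumed capture entry point
    from torch.ao.quantization import move_model_to_eval  # util added by this commit
    from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (  # example quantizer
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    class M(torch.nn.Module):  # toy model with dropout, for illustration only
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(16, 16)
            self.dropout = torch.nn.Dropout(0.5)

        def forward(self, x):
            return self.dropout(self.linear(x))

    example_inputs = (torch.randn(1, 16),)
    model = capture_pre_autograd_graph(M(), example_inputs)

    quantizer = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config(is_qat=True)
    )
    model = prepare_qat_pt2e(model, quantizer)
    # ... run QAT fine-tuning ...
    model = convert_pt2e(model)

    # model.eval() has no effect on an exported graph (see issue #103681);
    # call the new util instead so dropout takes its eval behavior:
    model = move_model_to_eval(model)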
104 lines
3.1 KiB
Python
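"""Top-level entry points for PT2 export (PT2E) quantization: prepare a
captured ``GraphModule`` for post-training quantization or for
quantization-aware training with a ``Quantizer``, and convert the
observed model afterwards.
"""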
from torch.fx import GraphModule

from .pt2e.prepare import prepare
from .pt2e._propagate_annotation import propagate_annotation
from .pt2e.qat_utils import (
    _fuse_conv_bn_qat,
    _fold_conv_bn_qat,
)
from .pt2e.utils import (
    _get_node_name_to_scope,
    _fuse_conv_bn_,
)
from .pt2e.representation import reference_representation_rewrite
from .fx.prepare import prepare as fx_prepare
from .quantize_fx import _convert_to_reference_decomposed_fx
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.quantizer import (  # noqa: F401
    Quantizer,
    QuantizationSpecBase,
    QuantizationSpec,
    FixedQParamsQuantizationSpec,
    SharedQuantizationSpec,
    DerivedQuantizationSpec,
    QuantizationAnnotation,
)
from torch.ao.quantization.backend_config import BackendConfig

from typing import Any, Tuple

__all__ = [
    "prepare_pt2e",
    "prepare_qat_pt2e",
    "convert_pt2e",
]

def _prepare_pt2e_deprecated(
    model: GraphModule,
    qconfig_mapping: QConfigMapping,
    example_inputs: Tuple[Any, ...],
    backend_config: BackendConfig,
) -> GraphModule:
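    """Deprecated qconfig-based prepare path: fuses conv + bn, then
    delegates to the FX graph mode prepare with a ``QConfigMapping`` and
    ``BackendConfig`` instead of a ``Quantizer``.
    """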
    node_name_to_scope = _get_node_name_to_scope(model)

    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    model = fx_prepare(
        model,
        qconfig_mapping,
        False,  # is_qat
        node_name_to_scope,
        example_inputs,
        backend_config=backend_config
    )
    return model

def prepare_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
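    """Prepare a captured model for post-training quantization: fuse
    conv + bn, have the ``quantizer`` annotate and validate the graph,
    propagate annotations, and insert observers (``is_qat=False``),
    preserving the original graph metadata on the returned model.
    """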
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    propagate_annotation(model)
    model = prepare(model, node_name_to_scope, is_qat=False)
    model.meta.update(original_graph_meta)
    return model

def prepare_qat_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
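    """Prepare a captured model for quantization-aware training: annotate
    and validate first, then fuse conv + bn into the QAT pattern so that
    ops introduced by fusion are not themselves quantized, and insert
    observers with ``is_qat=True``.
    """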
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    propagate_annotation(model)
    # Perform fusion after annotate to avoid quantizing ops in the new
    # subgraph that don't need to be quantized
    # TODO: only fuse if conv and bn are both configured to be quantized
    _fuse_conv_bn_qat(model)
    model = prepare(model, node_name_to_scope, is_qat=True)
    model.meta.update(original_graph_meta)
    return model

def convert_pt2e(
    model: GraphModule,
    use_reference_representation: bool = False,
) -> GraphModule:
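    """Convert an observed model to its quantized form: lower to the
    reference decomposed representation, fold the QAT conv + bn pattern
    back, and optionally rewrite the graph to the reference integer
    representation when ``use_reference_representation`` is True.
    """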
    original_graph_meta = model.meta
    model = _convert_to_reference_decomposed_fx(model)
    model = _fold_conv_bn_qat(model)
    if use_reference_representation:
        model = reference_representation_rewrite(model)

    model.meta.update(original_graph_meta)
    return model