import torch
from torch.fx import GraphModule
from torch.fx import Node

from .pt2e.prepare import prepare
from .pt2e.qat_utils import (
    _fuse_conv_bn_qat,
    _fold_conv_bn_qat,
)
from .pt2e.utils import (
    _get_node_name_to_scope,
    _fuse_conv_bn_,
    _disallow_eval_train,
)
from .pt2e.representation import reference_representation_rewrite
from .quantize_fx import _convert_to_reference_decomposed_fx
from torch.ao.quantization.quantizer import (  # noqa: F401
    Quantizer,
    QuantizationSpecBase,
    QuantizationSpec,
    FixedQParamsQuantizationSpec,
    SharedQuantizationSpec,
    DerivedQuantizationSpec,
    QuantizationAnnotation,
)
from torch.fx.passes.infra.pass_manager import PassManager
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
from torch._inductor.constant_folding import constant_fold

__all__ = [
    "prepare_pt2e",
    "prepare_qat_pt2e",
    "convert_pt2e",
]


def prepare_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    """Prepare a model for post-training quantization.

    Conv + batchnorm patterns are fused first. Then ``quantizer`` annotates
    and validates the graph. Finally ``prepare`` runs with ``is_qat=False``.
    The returned model is meant to be calibrated and then passed to
    :func:`convert_pt2e`.

    Args:
        model: the graph module to prepare. NOTE: the input is mutated.
            Conv/bn fusion and quantizer annotation are applied to it
            directly, and their return values are not used.
        quantizer: the :class:`Quantizer` whose ``annotate`` and
            ``validate`` methods are run on ``model``.

    Returns:
        The prepared ``GraphModule``. The input's graph-level ``meta``
        entries are copied onto it. It is then passed through
        ``_disallow_eval_train``, presumably to block ``.train()`` /
        ``.eval()`` on the result (see that helper).
    """
    # Alias, not a copy, of the caller's graph-level metadata. It is
    # re-applied at the end because ``prepare`` may hand back a different
    # GraphModule object.
    original_graph_meta = model.meta
    # Captured before fusion/annotation touch the graph; consumed by prepare().
    node_name_to_scope = _get_node_name_to_scope(model)
    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    model = prepare(model, node_name_to_scope, is_qat=False)
    model.meta.update(original_graph_meta)
    model = _disallow_eval_train(model)
    return model


def prepare_qat_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    """Prepare a model for quantization-aware training (QAT).

    Unlike :func:`prepare_pt2e`, the graph is annotated and validated
    *before* conv/bn fusion (see the inline comment for why). Then
    ``prepare`` runs with ``is_qat=True``. The returned model is meant to
    be trained and then passed to :func:`convert_pt2e`.

    Args:
        model: the graph module to prepare. NOTE: the input is mutated.
            Quantizer annotation and QAT conv/bn fusion are applied to it
            directly, and their return values are not used.
        quantizer: the :class:`Quantizer` whose ``annotate`` and
            ``validate`` methods are run on ``model``.

    Returns:
        The prepared ``GraphModule``. The input's graph-level ``meta``
        entries are copied onto it. It is then passed through
        ``_disallow_eval_train``.
    """
    # Alias, not a copy; re-applied after prepare() (see prepare_pt2e).
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    # Perform fusion after annotate to avoid quantizing ops in the new
    # subgraph that don't need to be quantized
    # TODO: only fuse if conv and bn are both configured to be quantized
    _fuse_conv_bn_qat(model)
    model = prepare(model, node_name_to_scope, is_qat=True)
    model.meta.update(original_graph_meta)
    model = _disallow_eval_train(model)
    return model


# Quantize op overloads that ``_quant_node_constraint`` allows
# ``constant_fold`` to fold in ``convert_pt2e``.
_QUANT_OPS = [
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
]


def _quant_node_constraint(n: Node) -> bool:
    """Return True if ``n`` is a ``call_function`` node targeting an op in
    ``_QUANT_OPS``.

    Used as the node filter for ``constant_fold`` in ``convert_pt2e``.
    Any pure ops between a get_attr and a quantize op will be constant
    propagated, e.g.::

        get_attr(weight) -> transpose -> quantize -> dequantize*

    (Note: the dequantize op is not going to be constant propagated.)

    This filter exists because we don't want to constant fold things that
    are not related to quantization.
    """
    return n.op == "call_function" and n.target in _QUANT_OPS


def convert_pt2e(
    model: GraphModule,
    use_reference_representation: bool = False,
    fold_quantize: bool = False,
) -> GraphModule:
    """Convert a calibrated/trained model to a quantized model.

    Args:
        model: calibrated/trained model
        use_reference_representation: boolean flag to indicate whether to
            produce reference representation or not
        fold_quantize: boolean flag to indicate whether to fold the quantize
            op or not

    Note: please set `fold_quantize` to True whenever you can. We'll
    deprecate this flag and make True the default option in the future. To
    make sure that change doesn't break BC for you, it's better to set the
    flag to True now.

    Returns:
        quantized model, either in q/dq representation or reference
        representation
    """
    # Alias, not a copy; re-applied at the end because the passes below
    # rebind ``model`` to new GraphModule objects.
    original_graph_meta = model.meta
    model = _convert_to_reference_decomposed_fx(model)
    model = _fold_conv_bn_qat(model)

    # The two passes run as separate PassManager invocations, in this order.
    pm = PassManager([DuplicateDQPass()])
    model = pm(model).graph_module

    pm = PassManager([PortNodeMetaForQDQ()])
    model = pm(model).graph_module

    if fold_quantize:
        # In-place; only nodes accepted by _quant_node_constraint are folded.
        constant_fold(model, _quant_node_constraint)

    if use_reference_representation:
        model = reference_representation_rewrite(model)

    model.meta.update(original_graph_meta)
    model = _disallow_eval_train(model)
    return model