Summary: This PR added constant folding for quantize ops so that instead of storing fp32 weights in the quantized model, we'll get int8/int16 etc. weights.

Test Plan: python test/test_quantization.py TestQuantizePT2E.test_fold_quantize
(will also verify in executorch later)

Differential Revision: [D49399210](https://our.internmc.facebook.com/intern/diff/D49399210)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109343
Approved by: https://github.com/kimishpatel, https://github.com/jgong5
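As a rough sketch of what the flag changes (illustrative only; `prepared_model` is a placeholder for a model that has already been exported, annotated via `prepare_pt2e`, and calibrated):

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e

# With fold_quantize=False (the current default) the converted graph still carries the
# fp32 weight plus a quantize op on it; with folding, the quantize op on the constant
# weight is evaluated at convert time, so the model stores the int8/int16 weight directly.
quantized_model = convert_pt2e(prepared_model, fold_quantize=True)
for n in quantized_model.graph.nodes:
    if n.op == "get_attr":
        t = getattr(quantized_model, n.target, None)
        if isinstance(t, torch.Tensor):
            print(n.target, t.dtype)  # folded weights should show torch.int8 etc.
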
import torch

from torch.fx import GraphModule
from torch.fx import Node

from .pt2e.prepare import prepare
from .pt2e.qat_utils import (
    _fuse_conv_bn_qat,
    _fold_conv_bn_qat,
)
from .pt2e.utils import (
    _get_node_name_to_scope,
    _fuse_conv_bn_,
    _disallow_eval_train,
)
from .pt2e.representation import reference_representation_rewrite
from .quantize_fx import _convert_to_reference_decomposed_fx
from torch.ao.quantization.quantizer import (  # noqa: F401
    Quantizer,
    QuantizationSpecBase,
    QuantizationSpec,
    FixedQParamsQuantizationSpec,
    SharedQuantizationSpec,
    DerivedQuantizationSpec,
    QuantizationAnnotation,
)
from torch.fx.passes.infra.pass_manager import PassManager
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
from torch._inductor.constant_folding import constant_fold

__all__ = [
    "prepare_pt2e",
    "prepare_qat_pt2e",
    "convert_pt2e",
]


def prepare_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    model = prepare(model, node_name_to_scope, is_qat=False)
    model.meta.update(original_graph_meta)
    model = _disallow_eval_train(model)
    return model


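# A minimal PTQ sketch for prepare_pt2e (illustrative only, not part of this
# module's API): it assumes the XNNPACKQuantizer and capture_pre_autograd_graph
# are available in this version of PyTorch, and that `model`, `example_inputs`
# and `calibration_data` are user-provided.
#
#   from torch._export import capture_pre_autograd_graph
#   from torch.ao.quantization.quantizer.xnnpack_quantizer import (
#       XNNPACKQuantizer,
#       get_symmetric_quantization_config,
#   )
#
#   exported_model = capture_pre_autograd_graph(model, example_inputs)
#   quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
#   prepared_model = prepare_pt2e(exported_model, quantizer)
#   for data in calibration_data:  # run calibration to collect observer stats
#       prepared_model(*data)

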
def prepare_qat_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    # Perform fusion after annotate to avoid quantizing ops in the new
    # subgraph that don't need to be quantized
    # TODO: only fuse if conv and bn are both configured to be quantized
    _fuse_conv_bn_qat(model)
    model = prepare(model, node_name_to_scope, is_qat=True)
    model.meta.update(original_graph_meta)
    model = _disallow_eval_train(model)
    return model


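# A rough QAT counterpart of the sketch above (illustrative; `model`,
# `example_inputs`, `quantizer` and the training loop are assumed to be
# user-provided):
#
#   exported_model = capture_pre_autograd_graph(model, example_inputs)
#   prepared_model = prepare_qat_pt2e(exported_model, quantizer)
#   train(prepared_model)  # fine-tune with fake-quant inserted in the graph
#   quantized_model = convert_pt2e(prepared_model, fold_quantize=True)

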
_QUANT_OPS = [
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
]


def _quant_node_constraint(n: Node) -> bool:
    """If there are any pure ops between get_attr and a quantize op, they will be const propagated,
    e.g. get_attr(weight) -> transpose -> quantize -> dequantize*
    (Note: the dequantize op is not going to be constant propagated)

    This filter is added because we don't want to constant fold things that are not
    related to quantization.
    """
    return n.op == "call_function" and n.target in _QUANT_OPS


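# To illustrate the constraint (a hypothetical weight pattern, not code from
# this module): with fold_quantize=True, convert_pt2e below calls
# constant_fold(model, _quant_node_constraint), which turns
#
#   get_attr(weight_fp32) -> quantize_per_tensor -> dequantize -> conv
#
# into
#
#   get_attr(weight_int8) -> dequantize -> conv
#
# so the serialized model stores the already-quantized weight. Quantize ops on
# activations are left alone because their inputs are not constants.

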
def convert_pt2e(
    model: GraphModule,
    use_reference_representation: bool = False,
    fold_quantize: bool = False,
) -> GraphModule:
    """Convert a calibrated/trained model to a quantized model

    Args:
        model: calibrated/trained model
        use_reference_representation: boolean flag to indicate whether to produce reference representation or not
        fold_quantize: boolean flag to indicate whether to fold the quantize op or not

    Note: please set `fold_quantize` to True whenever you can; we'll deprecate this flag and
    make True the default option in the future. To make sure the change doesn't break BC for you,
    it's better to set the flag to True now.

    Returns:
        quantized model, either in q/dq representation or reference representation
    """
    original_graph_meta = model.meta
    model = _convert_to_reference_decomposed_fx(model)
    model = _fold_conv_bn_qat(model)

    pm = PassManager([DuplicateDQPass()])
    model = pm(model).graph_module

    pm = PassManager([PortNodeMetaForQDQ()])
    model = pm(model).graph_module

    if fold_quantize:
        constant_fold(model, _quant_node_constraint)

    if use_reference_representation:
        model = reference_representation_rewrite(model)

    model.meta.update(original_graph_meta)
    model = _disallow_eval_train(model)
    return model
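

# Illustrative use of the two flags above (names are placeholders from the
# sketches earlier in this file, not part of the module; pick one of the two
# calls, since convert_pt2e mutates the prepared model):
#
#   # q/dq representation with constant weights folded to int8 (recommended):
#   quantized_model = convert_pt2e(prepared_model, fold_quantize=True)
#
#   # reference representation, where q/dq patterns are rewritten into the
#   # more explicit reference ops that backends can pattern-match and lower:
#   reference_model = convert_pt2e(
#       prepared_model, use_reference_representation=True, fold_quantize=True
#   )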