# pytorch/torch/ao/quantization/quantize_pt2e.py
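"""Top-level entry points for the PyTorch 2 Export (PT2E) quantization flow:
``prepare_pt2e`` for post-training quantization, ``prepare_qat_pt2e`` for
quantization-aware training, and ``convert_pt2e`` to produce the final
quantized model.
"""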
from torch.fx import GraphModule

from .pt2e.prepare import prepare
from .pt2e.qat_utils import (
    _fuse_conv_bn_qat,
    _fold_conv_bn_qat,
)
from .pt2e.utils import (
    _get_node_name_to_scope,
    _fuse_conv_bn_,
    _disallow_eval_train,
)
from .pt2e.representation import reference_representation_rewrite
from .quantize_fx import _convert_to_reference_decomposed_fx
from torch.ao.quantization.quantizer import (  # noqa: F401
    Quantizer,
    QuantizationSpecBase,
    QuantizationSpec,
    FixedQParamsQuantizationSpec,
    SharedQuantizationSpec,
    DerivedQuantizationSpec,
    QuantizationAnnotation,
)
from torch.fx.passes.infra.pass_manager import PassManager
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ

__all__ = [
    "prepare_pt2e",
    "prepare_qat_pt2e",
    "convert_pt2e",
]


def prepare_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
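    """Prepare a captured model for post-training quantization (PTQ).

    Fuses conv + batchnorm, runs ``quantizer.annotate`` / ``quantizer.validate``,
    and inserts observers according to the resulting annotations. Calibrate the
    returned module by running representative data through it, then pass it to
    :func:`convert_pt2e`.

    Args:
        model: a ``GraphModule`` captured by the PT2 export tracing API
        quantizer: a ``Quantizer`` describing how the model should be quantized

    Returns:
        A ``GraphModule`` with observers inserted, ready for calibration.

    A minimal usage sketch (assuming a model captured with
    ``torch._export.capture_pre_autograd_graph`` and the example
    ``XNNPACKQuantizer`` that ships with PyTorch; any ``Quantizer`` subclass
    works the same way)::

        import torch
        from torch._export import capture_pre_autograd_graph
        from torch.ao.quantization.quantizer.xnnpack_quantizer import (
            XNNPACKQuantizer,
            get_symmetric_quantization_config,
        )

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        example_inputs = (torch.randn(1, 5),)
        m = capture_pre_autograd_graph(M().eval(), example_inputs)
        quantizer = XNNPACKQuantizer().set_global(
            get_symmetric_quantization_config()
        )
        m = prepare_pt2e(m, quantizer)
        m(*example_inputs)  # calibrate with representative data
    """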
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    # insert observers based on the quantizer's annotations
    model = prepare(model, node_name_to_scope, is_qat=False)
    model.meta.update(original_graph_meta)
    # calling model.eval() or model.train() is not allowed after prepare
    model = _disallow_eval_train(model)
    return model


def prepare_qat_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
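    """Prepare a captured model for quantization-aware training (QAT).

    Runs the quantizer's ``annotate`` / ``validate``, swaps conv + batchnorm
    patterns for their fused QAT counterparts, and inserts fake quantizes
    according to the annotations. Train the returned module, then pass it to
    :func:`convert_pt2e`, which folds batchnorm back into conv.

    A minimal sketch (same assumptions as :func:`prepare_pt2e`; the ``is_qat``
    flag on ``get_symmetric_quantization_config`` is the example quantizer's
    way of requesting QAT-style annotations)::

        quantizer = XNNPACKQuantizer().set_global(
            get_symmetric_quantization_config(is_qat=True)
        )
        m = prepare_qat_pt2e(m, quantizer)
        # ... run the training loop on m ...
    """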
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    # Perform fusion after annotate to avoid quantizing ops in the new
    # subgraph that don't need to be quantized
    # TODO: only fuse if conv and bn are both configured to be quantized
    _fuse_conv_bn_qat(model)
    # insert fake quantizes based on the quantizer's annotations
    model = prepare(model, node_name_to_scope, is_qat=True)
    model.meta.update(original_graph_meta)
    model = _disallow_eval_train(model)
    return model


def convert_pt2e(
    model: GraphModule,
    use_reference_representation: bool = False,
) -> GraphModule:
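    """Convert a calibrated or trained model to its quantized form.

    Lowers the model to the reference decomposed representation, folds
    batchnorm back into conv for QAT models, and runs the ``DuplicateDQPass``
    and ``PortNodeMetaForQDQ`` passes. When ``use_reference_representation``
    is True, the q/dq ops are further rewritten into the integer reference
    representation.

    A minimal sketch, continuing the :func:`prepare_pt2e` example::

        m = convert_pt2e(m)
        # or: m = convert_pt2e(m, use_reference_representation=True)
    """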
    original_graph_meta = model.meta
    model = _convert_to_reference_decomposed_fx(model)
    model = _fold_conv_bn_qat(model)
    # give each user of a dequantize node its own copy, so downstream
    # pattern matching / lowering sees self-contained q-dq pairs
    pm = PassManager([DuplicateDQPass()])
    model = pm(model).graph_module
    # port metadata from the original nodes onto the inserted q-dq nodes
    pm = PassManager([PortNodeMetaForQDQ()])
    model = pm(model).graph_module
    if use_reference_representation:
        model = reference_representation_rewrite(model)
    model.meta.update(original_graph_meta)
    model = _disallow_eval_train(model)
    return model