pytorch/torch/ao/quantization/quantizer/composable_quantizer.py

from __future__ import annotations

from typing import Dict, List

import torch
from torch.fx import Node

from .quantizer import OperatorConfig, QuantizationAnnotation, Quantizer

__all__ = [
    "ComposableQuantizer",
]


class ComposableQuantizer(Quantizer):
    """
    ComposableQuantizer allows users to combine more than one quantizer into a single quantizer.
    This allows users to quantize a model with multiple quantizers. E.g., embedding quantization
    may be supported by one quantizer while linear layers and other ops might be supported by
    another quantizer.

    ComposableQuantizer is initialized with a list of `Quantizer` instances.
    The order of the composition matters, since that is the order in which the quantizers will
    be applied.

    Example:
    ```
    embedding_quantizer = EmbeddingQuantizer()
    linear_quantizer = MyLinearQuantizer()
    xnnpack_quantizer = XNNPackQuantizer()  # to handle ops not quantized by the previous two quantizers
    composed_quantizer = ComposableQuantizer(
        [embedding_quantizer, linear_quantizer, xnnpack_quantizer]
    )
    prepared_m = prepare_pt2e(model, composed_quantizer)
    ```
    """

    def __init__(self, quantizers: List[Quantizer]):
        super().__init__()
        self.quantizers = quantizers
        self._graph_annotations: Dict[Node, QuantizationAnnotation] = {}

    def _record_and_validate_annotations(
        self, gm: torch.fx.GraphModule, quantizer: Quantizer
    ) -> None:
        for n in gm.graph.nodes:
            if "quantization_annotation" in n.meta:
                # check if the annotation has been changed by
                # comparing QuantizationAnnotation object id
                if n in self._graph_annotations and (
                    id(self._graph_annotations[n])
                    != id(n.meta["quantization_annotation"])
                ):
                    raise RuntimeError(
                        f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}"
                    )
                else:
                    self._graph_annotations[n] = n.meta["quantization_annotation"]
            else:
                if n in self._graph_annotations:
                    raise RuntimeError(
                        f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}"
                    )

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Run each quantizer's annotate pass in composition order
        (only global specs are handled for now)."""
        for quantizer in self.quantizers:
            quantizer.annotate(model)
            self._record_and_validate_annotations(model, quantizer)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return []
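

# A minimal, self-contained sketch (illustrative only, not part of the library
# API): the toy `_NoteEveryCallFunction`, `_ClobberAnnotations`, and `_MulAdd`
# names below are hypothetical, invented for this demo. It exercises the two
# guarantees documented above: a later quantizer sees (and must not mutate)
# annotations left by an earlier one, and any mutation trips the identity check
# in `_record_and_validate_annotations`. In real use you would compose concrete
# quantizers (e.g. an embedding quantizer plus a backend quantizer, as the
# class docstring shows) and pass the result to prepare_pt2e. Run with
# `python -m torch.ao.quantization.quantizer.composable_quantizer`.
if __name__ == "__main__":
    import torch.fx

    class _NoteEveryCallFunction(Quantizer):
        """Annotates every not-yet-annotated call_function node."""

        def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
            for n in model.graph.nodes:
                if n.op == "call_function" and "quantization_annotation" not in n.meta:
                    # QuantizationAnnotation is default-constructible here
                    # (all of its fields have defaults at this commit).
                    n.meta["quantization_annotation"] = QuantizationAnnotation()
            return model

        def validate(self, model: torch.fx.GraphModule) -> None:
            pass

        @classmethod
        def get_supported_operators(cls) -> List[OperatorConfig]:
            return []

    class _ClobberAnnotations(_NoteEveryCallFunction):
        """Deliberately replaces existing annotations to trip validation."""

        def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
            for n in model.graph.nodes:
                if n.op == "call_function":
                    n.meta["quantization_annotation"] = QuantizationAnnotation()
            return model

    class _MulAdd(torch.nn.Module):
        def forward(self, x):
            return x * 2 + 1

    # Well-behaved composition: the second quantizer skips nodes that are
    # already annotated, so the recorded object ids are unchanged and
    # validation passes.
    gm = torch.fx.symbolic_trace(_MulAdd())
    ComposableQuantizer(
        [_NoteEveryCallFunction(), _NoteEveryCallFunction()]
    ).annotate(gm)
    print([n.name for n in gm.graph.nodes if "quantization_annotation" in n.meta])

    # Misbehaving composition: the second quantizer overwrites existing
    # annotations, which ComposableQuantizer detects via object identity.
    try:
        ComposableQuantizer(
            [_NoteEveryCallFunction(), _ClobberAnnotations()]
        ).annotate(torch.fx.symbolic_trace(_MulAdd()))
    except RuntimeError as e:
        print(f"caught: {e}")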