pytorch/torch/ao/quantization/quantizer/composable_quantizer.py

from __future__ import annotations

from typing import TYPE_CHECKING

from .quantizer import QuantizationAnnotation, Quantizer


if TYPE_CHECKING:
    import torch
    from torch.fx import Node

__all__ = [
    "ComposableQuantizer",
]


class ComposableQuantizer(Quantizer):
"""
ComposableQuantizer allows users to combine more than one quantizer into a single quantizer.
This allows users to quantize a model with multiple quantizers. E.g., embedding quantization
maybe supported by one quantizer while linear layers and other ops might be supported by another
quantizer.
ComposableQuantizer is initialized with a list of `Quantizer` instances.
The order of the composition matters since that is the order in which the quantizers will be
applies.
Example:
```
embedding_quantizer = EmbeddingQuantizer()
linear_quantizer = MyLinearQuantizer()
xnnpack_quantizer = (
XNNPackQuantizer()
) # to handle ops not quantized by previous two quantizers
composed_quantizer = ComposableQuantizer(
[embedding_quantizer, linear_quantizer, xnnpack_quantizer]
)
prepared_m = prepare_pt2e(model, composed_quantizer)
```
"""
    def __init__(self, quantizers: list[Quantizer]):
        super().__init__()
        self.quantizers = quantizers
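        # First-seen QuantizationAnnotation per node, used to detect when a later
        # quantizer in the chain mutates or removes an earlier quantizer's annotations.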
        self._graph_annotations: dict[Node, QuantizationAnnotation] = {}

    def _record_and_validate_annotations(
        self, gm: torch.fx.GraphModule, quantizer: Quantizer
    ) -> None:
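        """Snapshot newly annotated nodes and verify that `quantizer` neither
        changed nor removed annotations recorded for earlier quantizers."""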
        for n in gm.graph.nodes:
            if "quantization_annotation" in n.meta:
                # check if the annotation has been changed by
                # comparing QuantizationAnnotation object id
                if n in self._graph_annotations and (
                    id(self._graph_annotations[n])
                    != id(n.meta["quantization_annotation"])
                ):
                    raise RuntimeError(
                        f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}"
                    )
                else:
                    self._graph_annotations[n] = n.meta["quantization_annotation"]
            else:
                if n in self._graph_annotations:
                    raise RuntimeError(
                        f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}"
                    )

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """just handling global spec for now"""
        for quantizer in self.quantizers:
            quantizer.annotate(model)
            self._record_and_validate_annotations(model, quantizer)
        return model

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
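        """Apply each quantizer's `transform_for_annotation` in composition order,
        so every quantizer sees the graph produced by the ones before it."""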
        for quantizer in self.quantizers:
            model = quantizer.transform_for_annotation(model)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
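        # No-op: `Quantizer.validate` is abstract, so an override is required
        # even though the composed quantizer has nothing extra to check.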
        pass
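

# A minimal end-to-end sketch of composing quantizers with prepare_pt2e/convert_pt2e.
# Assumptions: `XNNPACKQuantizer` and `get_symmetric_quantization_config` have moved
# between torch.ao and executorch across releases, so the import paths below may need
# adjusting for your version; `model` and `example_inputs` are placeholders.
#
#     import torch
#     from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#     from torch.ao.quantization.quantizer.composable_quantizer import (
#         ComposableQuantizer,
#     )
#     from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer
#     from torch.ao.quantization.quantizer.xnnpack_quantizer import (
#         XNNPACKQuantizer,
#         get_symmetric_quantization_config,
#     )
#
#     xnnpack_quantizer = XNNPACKQuantizer()
#     xnnpack_quantizer.set_global(get_symmetric_quantization_config())
#
#     # The embedding quantizer runs first; XNNPACKQuantizer picks up remaining ops.
#     composed = ComposableQuantizer([EmbeddingQuantizer(), xnnpack_quantizer])
#
#     exported = torch.export.export(model, example_inputs).module()
#     prepared = prepare_pt2e(exported, composed)
#     prepared(*example_inputs)  # calibration pass
#     quantized = convert_pt2e(prepared)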