pytorch/torch/ao/quantization/quantizer/composable_quantizer.py
Jerry Zhang 28be2c674a [quant][pt2e] Move specific quantizer related things outside of main quant code base (#106806) (#107259)
Summary:

Currently in quantizer/quantize_pt2e we import things from specific quantizers (XNNPACKQuantizer, QuantizationConfig, etc.).
This PR removes them so it is clearer that they are not part of the core quantization code base.

This PR also removes get_supported_operators from the main Quantizer class since we haven't seen a clear need for this API.

Test Plan:
CIs

Imported from OSS

Differential Revision: D48340367

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107259
Approved by: https://github.com/kimishpatel
2023-08-18 21:29:09 +00:00


from __future__ import annotations

from typing import Dict, List

import torch
from torch.fx import Node

from .quantizer import QuantizationAnnotation, Quantizer

__all__ = [
    "ComposableQuantizer",
]


class ComposableQuantizer(Quantizer):
"""
ComposableQuantizer allows users to combine more than one quantizer into a single quantizer.
    This allows users to quantize a model with multiple quantizers. E.g., embedding quantization
    may be supported by one quantizer while linear layers and other ops might be supported by another
    quantizer.

    ComposableQuantizer is initialized with a list of `Quantizer` instances.
    The order of the composition matters since that is the order in which the quantizers will be
    applied.
Example:
```
embedding_quantizer = EmbeddingQuantizer()
linear_quantizer = MyLinearQuantizer()
        xnnpack_quantizer = XNNPACKQuantizer()  # to handle ops not quantized by the previous two quantizers
composed_quantizer = ComposableQuantizer([embedding_quantizer, linear_quantizer, xnnpack_quantizer])
prepared_m = prepare_pt2e(model, composed_quantizer)
```
"""
def __init__(self, quantizers: List[Quantizer]):
super().__init__()
self.quantizers = quantizers
        self._graph_annotations: Dict[Node, QuantizationAnnotation] = {}

    def _record_and_validate_annotations(
self, gm: torch.fx.GraphModule, quantizer: Quantizer
) -> None:
for n in gm.graph.nodes:
if "quantization_annotation" in n.meta:
# check if the annotation has been changed by
# comparing QuantizationAnnotation object id
if n in self._graph_annotations and (
id(self._graph_annotations[n])
!= id(n.meta["quantization_annotation"])
):
raise RuntimeError(
f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}"
)
else:
self._graph_annotations[n] = n.meta["quantization_annotation"]
else:
if n in self._graph_annotations:
raise RuntimeError(
f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}"
                    )

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""just handling global spec for now"""
for quantizer in self.quantizers:
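            # Apply each quantizer in list order, then check that it neither
            # changed nor removed annotations left by earlier quantizers.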
quantizer.annotate(model)
self._record_and_validate_annotations(model, quantizer)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
pass
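
For reference, below is a minimal, self-contained sketch of how this class is used end to end, and of how the per-quantizer annotation check fires. This is not part of the file above: `ToyModel` and `BadQuantizer` are hypothetical names introduced for illustration, and the export/prepare entry points (`capture_pre_autograd_graph`, `prepare_pt2e`, `XNNPACKQuantizer`) are assumed to be the PT2E APIs available around the time of this commit.

```
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer import QuantizationAnnotation, Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


class BadQuantizer(Quantizer):
    """Hypothetical quantizer that overwrites every existing annotation,
    which ComposableQuantizer detects via its object-identity check."""

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        for n in model.graph.nodes:
            n.meta["quantization_annotation"] = QuantizationAnnotation()
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass


example_inputs = (torch.randn(1, 8),)

# Normal composition: quantizers annotate the graph in list order.
m = capture_pre_autograd_graph(ToyModel(), example_inputs)
xnnpack_quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(m, ComposableQuantizer([xnnpack_quantizer]))

# A quantizer that mutates annotations left by an earlier quantizer trips
# the check: "Quantizer BadQuantizer has changed annotations on node ...".
m2 = capture_pre_autograd_graph(ToyModel(), example_inputs)
bad = ComposableQuantizer([
    XNNPACKQuantizer().set_global(get_symmetric_quantization_config()),
    BadQuantizer(),
])
try:
    prepare_pt2e(m2, bad)
except RuntimeError as e:
    print(e)
```

The identity comparison (rather than structural equality) in `_record_and_validate_annotations` means a later quantizer may not even rebuild an identical annotation object for an already-annotated node; it must leave earlier annotations untouched.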