Summary: Currently, quantizer/quantize_pt2e imports things from specific quantizers (XNNPACKQuantizer, QuantizationConfig, etc.). This PR removes those imports so it is clearer that they are not part of the core quantization code base. It also removes get_supported_operators from the main Quantizer, since we have not seen a clear need for this API.
Test Plan: CIs
Imported from OSS
Differential Revision: D48340367
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107259
Approved by: https://github.com/kimishpatel

from __future__ import annotations

from typing import Dict, List

import torch
from torch.fx import Node

from .quantizer import QuantizationAnnotation, Quantizer

__all__ = [
    "ComposableQuantizer",
]


class ComposableQuantizer(Quantizer):
    """
    ComposableQuantizer allows users to combine more than one quantizer into a single quantizer.
    This allows users to quantize a model with multiple quantizers. E.g., embedding quantization
    may be supported by one quantizer while linear layers and other ops might be supported by
    another quantizer.

    ComposableQuantizer is initialized with a list of `Quantizer` instances.
    The order of the composition matters, since that is the order in which the quantizers will
    be applied.

    Example:
    ```
    embedding_quantizer = EmbeddingQuantizer()
    linear_quantizer = MyLinearQuantizer()
    xnnpack_quantizer = XNNPACKQuantizer()  # to handle ops not quantized by the previous two quantizers
    composed_quantizer = ComposableQuantizer(
        [embedding_quantizer, linear_quantizer, xnnpack_quantizer]
    )
    prepared_m = prepare_pt2e(model, composed_quantizer)
    ```
    """

    def __init__(self, quantizers: List[Quantizer]):
        super().__init__()
        self.quantizers = quantizers
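        # Snapshot of each node's QuantizationAnnotation, taken after each
        # quantizer runs; used to detect conflicting edits between quantizers.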
        self._graph_annotations: Dict[Node, QuantizationAnnotation] = {}

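    # Invoked after each quantizer's `annotate` pass; raises if that quantizer
    # changed or removed an annotation recorded from an earlier quantizer.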
    def _record_and_validate_annotations(
        self, gm: torch.fx.GraphModule, quantizer: Quantizer
    ) -> None:
        for n in gm.graph.nodes:
            if "quantization_annotation" in n.meta:
                # Check whether the annotation has been changed by comparing
                # QuantizationAnnotation object ids.
                if n in self._graph_annotations and (
                    id(self._graph_annotations[n])
                    != id(n.meta["quantization_annotation"])
                ):
                    raise RuntimeError(
                        f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}"
                    )
                else:
                    self._graph_annotations[n] = n.meta["quantization_annotation"]
            else:
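                # The node was previously annotated but no longer is: some
                # quantizer must have removed the annotation.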
                if n in self._graph_annotations:
                    raise RuntimeError(
                        f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}"
                    )

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """just handling global spec for now"""
        for quantizer in self.quantizers:
            quantizer.annotate(model)
            self._record_and_validate_annotations(model, quantizer)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
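        # Currently a no-op: ComposableQuantizer performs no validation of its
        # own beyond the annotation checks done during `annotate`.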
        pass