from __future__ import annotations

from typing import TYPE_CHECKING

from .quantizer import QuantizationAnnotation, Quantizer


if TYPE_CHECKING:
    import torch
    from torch.fx import Node


__all__ = [
    "ComposableQuantizer",
]


class ComposableQuantizer(Quantizer):
    """
    ComposableQuantizer allows users to combine more than one quantizer into a single quantizer.
    This allows users to quantize a model with multiple quantizers. E.g., embedding quantization
    may be supported by one quantizer while linear layers and other ops might be supported by
    another quantizer.

    ComposableQuantizer is initialized with a list of `Quantizer` instances.
    The order of the composition matters since that is the order in which the quantizers will
    be applied.
    Example:
    ```
    embedding_quantizer = EmbeddingQuantizer()
    linear_quantizer = MyLinearQuantizer()
    xnnpack_quantizer = (
        XNNPACKQuantizer()
    )  # to handle ops not quantized by the previous two quantizers
    composed_quantizer = ComposableQuantizer(
        [embedding_quantizer, linear_quantizer, xnnpack_quantizer]
    )
    prepared_m = prepare_pt2e(model, composed_quantizer)
    ```
    """

    def __init__(self, quantizers: list[Quantizer]):
        super().__init__()
        self.quantizers = quantizers
        # Snapshot of node -> annotation, used to detect later quantizers
        # changing or removing annotations set by earlier quantizers.
        self._graph_annotations: dict[Node, QuantizationAnnotation] = {}

    def _record_and_validate_annotations(
        self, gm: torch.fx.GraphModule, quantizer: Quantizer
    ) -> None:
        for n in gm.graph.nodes:
            if "quantization_annotation" in n.meta:
                # check if the annotation has been changed by
                # comparing QuantizationAnnotation object ids
                if n in self._graph_annotations and (
                    id(self._graph_annotations[n])
                    != id(n.meta["quantization_annotation"])
                ):
                    raise RuntimeError(
                        f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}"
                    )
                else:
                    self._graph_annotations[n] = n.meta["quantization_annotation"]
            else:
                if n in self._graph_annotations:
                    raise RuntimeError(
                        f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}"
                    )
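
    # Note (an illustrative observation, not an enforced contract): the
    # identity check above detects a quantizer *replacing* an annotation
    # object, e.g.
    #
    #     node.meta["quantization_annotation"] = QuantizationAnnotation(...)
    #
    # In-place mutation of an existing QuantizationAnnotation keeps the same
    # object id and would not be flagged by this check.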

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Just handling global spec for now."""
        # Quantizers annotate in list order; each earlier annotation must not
        # be changed or removed by a later quantizer.
        for quantizer in self.quantizers:
            quantizer.annotate(model)
            self._record_and_validate_annotations(model, quantizer)
        return model

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        # Give each quantizer a chance to transform the graph before
        # annotation, in the same order as annotation.
        for quantizer in self.quantizers:
            model = quantizer.transform_for_annotation(model)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        # No composite-level validation; the composed quantizers' own
        # `validate` methods are not invoked here.
        pass
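
# A fuller end-to-end usage sketch, assuming the PT2E workflow in
# torch.ao.quantization.quantize_pt2e and the embedding/XNNPACK quantizers
# shipped with PyTorch; `model` and `example_inputs` are placeholders. Kept as
# comments since this module is imported as a library:
#
#     import torch
#     from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#     from torch.ao.quantization.quantizer.embedding_quantizer import (
#         EmbeddingQuantizer,
#     )
#     from torch.ao.quantization.quantizer.xnnpack_quantizer import (
#         XNNPACKQuantizer,
#         get_symmetric_quantization_config,
#     )
#
#     xnnpack_quantizer = XNNPACKQuantizer()
#     xnnpack_quantizer.set_global(get_symmetric_quantization_config())
#     composed = ComposableQuantizer([EmbeddingQuantizer(), xnnpack_quantizer])
#
#     exported = torch.export.export(model, example_inputs).module()
#     prepared = prepare_pt2e(exported, composed)
#     prepared(*example_inputs)  # calibration pass
#     converted = convert_pt2e(prepared)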