pytorch/torch/ao/quantization/quantizer/embedding_quantizer.py
Jerry Zhang 28be2c674a [quant][pt2e] Move specific quantizer related things outside of main quant code base (#106806) (#107259)
Summary:

Currently, quantizer/quantize_pt2e imports things from specific quantizers (XNNPACKQuantizer, QuantizationConfig, etc.).
This PR removes those imports so it is clearer that they are not part of the core quantization code base.

It also removes get_supported_operators from the main Quantizer class, since we haven't seen a clear need for this API.

Test Plan:
CIs

Imported from OSS

Differential Revision: D48340367

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107259
Approved by: https://github.com/kimishpatel
2023-08-18 21:29:09 +00:00


from __future__ import annotations

import copy
from typing import List, Set

import torch
import torch.nn.functional as F
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.quantizer.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    Quantizer,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    OperatorConfig,
    OperatorPatternType,
    QuantizationConfig,
)

__all__ = [
    "get_embedding_operators_config",
    "EmbeddingQuantizer",
]


def get_embedding_operators_config() -> OperatorConfig:
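    """Return the OperatorConfig for embedding operators.

    The config pairs a per-channel uint8 weight QuantizationSpec (affine with
    float qparams, channel axis 0) with the embedding operator patterns
    torch.nn.Embedding and F.embedding.
    """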
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=torch.per_channel_affine_float_qparams,
        ch_axis=0,
        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(eps=2**-12),
    )
    quantization_config = QuantizationConfig(None, None, weight_quantization_spec, None)
    ops: List[OperatorPatternType] = [[torch.nn.Embedding]]
    ops.append([F.embedding])
    supported_config_and_operators = OperatorConfig(
        config=quantization_config, operators=ops
    )
    return copy.deepcopy(supported_config_and_operators)


class EmbeddingQuantizer(Quantizer):
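    """Quantizer that annotates the weight input of aten.embedding.default ops.

    Weights are annotated with the per-channel uint8 QuantizationSpec from
    get_embedding_operators_config(); activations are left unannotated.
    """
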
    def __init__(self):
        super().__init__()

    @classmethod
    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
        op_configs: Set[QuantizationConfig] = set({})
        for spec, _ in cls.get_supported_operators():
            op_configs.add(spec)
        return list(op_configs)

    @classmethod
    def get_supported_operator_for_quantization_config(
        cls, quantization_config: QuantizationConfig
    ) -> List[OperatorPatternType]:
        for config, ops in cls.get_supported_operators():
            # note: this assumes each entry in cls.supported_spec_and_operators
            # corresponds to one spec, e.g. we don't have
            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
            # where the first and second entry have the same spec but did not
            # merge the op list
            if config == quantization_config:
                return ops
        return []

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """just handling global spec for now"""
        self._annotate_embedding_ops(model.graph)
        return model

    def _annotate_embedding_ops(self, graph: torch.fx.Graph) -> None:
        embedding_config: OperatorConfig = get_embedding_operators_config()
        for node in graph.nodes:
            # Keep node parsing based annotations instead of module partitioners
            # just as an example of alternate ways of annotating
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.embedding.default
            ):
                if embedding_config.config.weight is None:
                    raise ValueError(
                        "Embedding config must have a valid weight quantization spec."
                    )
                node.meta["quantization_annotation"] = QuantizationAnnotation(
                    input_qspec_map={
                        node.args[0]: embedding_config.config.weight,
                    }
                )

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return [get_embedding_operators_config()]
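

# Minimal usage sketch (illustrative, not exercised anywhere in the library):
# how EmbeddingQuantizer plugs into the PT2E quantization flow. It assumes the
# PyTorch 2.1-era entry points capture_pre_autograd_graph / prepare_pt2e /
# convert_pt2e and a toy embedding model made up for demonstration; whether the
# float-qparams embedding config lowers cleanly in convert_pt2e depends on
# backend support.
if __name__ == "__main__":
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    embedding = torch.nn.Embedding(num_embeddings=100, embedding_dim=16).eval()
    example_inputs = (torch.randint(0, 100, (1, 8)),)

    # Export to an aten-level graph, annotate it via EmbeddingQuantizer,
    # run one calibration pass, then convert to the quantized representation.
    exported = capture_pre_autograd_graph(embedding, example_inputs)
    prepared = prepare_pt2e(exported, EmbeddingQuantizer())
    prepared(*example_inputs)
    quantized = convert_pt2e(prepared)
    print(quantized)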