from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.ao.quantization import ObserverOrFakeQuantize
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.fx import Node

__all__ = [
    "Quantizer",
    "QuantizationSpecBase",
    "QuantizationSpec",
    "FixedQParamsQuantizationSpec",
    "EdgeOrNode",
    "SharedQuantizationSpec",
    "DerivedQuantizationSpec",
    "QuantizationAnnotation",
]

# TODO: maybe remove torch.float32
SUPPORTED_DTYPES = [
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.float16,
    torch.float32,
]
SUPPORTED_QSCHEMES = [
    torch.per_tensor_affine,
    torch.per_tensor_symmetric,
    torch.per_channel_affine,
    torch.per_channel_symmetric,
    torch.per_channel_affine_float_qparams,
]


class QuantizationSpecBase(ABC):  # noqa: B024
    """Base class for the different types of quantization specs that allow users to
    specify how to quantize a Tensor (an input or output of a Node) in the model.
    """

    pass


@dataclass(eq=True, frozen=True)
class QuantizationSpec(QuantizationSpecBase):
    """Quantization spec for common operators that allows users to specify how to
    quantize a Tensor; this includes dtype, quant_min, quant_max, etc.
    """

    dtype: torch.dtype
    # observer or fake_quantize constructor such as
    # MinMaxObserver, PerChannelMinMaxObserver etc.
    # or we can attach some custom args to them
    # e.g. MinMaxObserver.with_args(eps=eps)
    observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None
    ch_axis: Optional[int] = None
    is_dynamic: bool = False

    def __post_init__(self):
        # check that dtype is one of the supported types
        if self.dtype not in SUPPORTED_DTYPES:
            raise TypeError(f"Unsupported dtype {self.dtype}.")

        # quant_min must be <= quant_max
        if (
            self.quant_min is not None
            and self.quant_max is not None
            and self.quant_min > self.quant_max
        ):
            raise ValueError(
                f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}."
            )

        # check that qscheme is one of the supported ones
        if self.qscheme is not None and self.qscheme not in SUPPORTED_QSCHEMES:
            raise ValueError(f"Unsupported qscheme {self.qscheme}.")

        # ch_axis must be less than the number of channels,
        # but there is no way to check that here; just check that it is not negative.
        if self.ch_axis is not None and self.ch_axis < 0:
            raise ValueError("Ch_axis is < 0.")


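# Illustrative sketch (not part of this module's API): constructing a QuantizationSpec
# for 8-bit per-tensor affine activation quantization. MinMaxObserver is just one
# possible choice of observer constructor here:
#
#   from torch.ao.quantization.observer import MinMaxObserver
#
#   act_qspec = QuantizationSpec(
#       dtype=torch.int8,
#       quant_min=-128,
#       quant_max=127,
#       qscheme=torch.per_tensor_affine,
#       observer_or_fake_quant_ctr=MinMaxObserver.with_args(eps=2**-12),
#       is_dynamic=False,
#   )

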
@dataclass(eq=True, frozen=True)
class FixedQParamsQuantizationSpec(QuantizationSpecBase):
    """Quantization spec for Tensors whose scale and zero_point are fixed in advance
    rather than computed from observed statistics.
    """

    dtype: torch.dtype
    scale: float
    zero_point: int
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None


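# Illustrative sketch: a fixed-qparams spec as it might be used for an op with a
# known output range, e.g. a sigmoid output quantized to uint8 with scale 1/256 and
# zero_point 0 (the values here are an example assumption, not mandated by this module):
#
#   sigmoid_qspec = FixedQParamsQuantizationSpec(
#       dtype=torch.uint8,
#       scale=1.0 / 256.0,
#       zero_point=0,
#       quant_min=0,
#       quant_max=255,
#       qscheme=torch.per_tensor_affine,
#   )

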
"""
|
|
The way we refer to other points of quantization in the graph will be either
|
|
an input edge or an output value
|
|
input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node]
|
|
output value is an fx Node
|
|
"""
|
|
EdgeOrNode = Union[Tuple[Node, Node], Node]
|
|
EdgeOrNode.__module__ = "torch.ao.quantization.quantizer.quantizer"
|
|
|
|
|
|
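# Illustrative sketch: for a graph `conv -> relu`, the output value of the conv is
# referred to simply by its Node, while the input edge from conv into relu is the
# pair of Nodes (the node names below are hypothetical):
#
#   conv_output: EdgeOrNode = conv_node
#   conv_to_relu_input: EdgeOrNode = (conv_node, relu_node)

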
@dataclass(eq=True, frozen=True)
class SharedQuantizationSpec(QuantizationSpecBase):
    """
    Quantization spec for Tensors whose quantization parameters are shared with other Tensors
    """

    # the edge or node to share observer or fake quant instances with
    edge_or_node: EdgeOrNode


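# Illustrative sketch: for an `add` node whose two inputs come from `conv1_node` and
# `conv2_node`, the second input edge can reuse the observer of the first input edge
# (the node names are hypothetical):
#
#   share_with_first_input = SharedQuantizationSpec((conv1_node, add_node))
#   # ... then use `share_with_first_input` as the qspec for the (conv2_node, add_node) edge

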
@dataclass(eq=True, frozen=True)
class DerivedQuantizationSpec(QuantizationSpecBase):
    """Quantization spec for Tensors whose quantization parameters are derived from other Tensors"""

    derived_from: List[EdgeOrNode]
    derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]]
    dtype: torch.dtype
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None
    ch_axis: Optional[int] = None


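# Illustrative sketch: deriving a conv bias scale from the already-observed input
# activation and weight, a common pattern in backend quantizers. The node names and
# the int32 range below are example assumptions:
#
#   def derive_bias_qparams(obs_or_fqs: List[ObserverOrFakeQuantize]) -> Tuple[Tensor, Tensor]:
#       act_scale, _ = obs_or_fqs[0].calculate_qparams()
#       weight_scale, _ = obs_or_fqs[1].calculate_qparams()
#       derived_scale = (act_scale * weight_scale).reshape(-1)
#       return derived_scale, torch.zeros_like(derived_scale, dtype=torch.int32)
#
#   bias_qspec = DerivedQuantizationSpec(
#       derived_from=[(input_act_node, conv_node), (weight_node, conv_node)],
#       derive_qparams_fn=derive_bias_qparams,
#       dtype=torch.int32,
#       quant_min=-(2**31),
#       quant_max=2**31 - 1,
#       qscheme=torch.per_tensor_symmetric,
#   )

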
@dataclass
class QuantizationAnnotation:
    """Describes how the input arguments and output of a node should be quantized,
    expressed as QuantizationSpecs; this corresponds to how a Tensor in the
    operator Graph is observed (PTQ) or fake quantized (QAT)
    """

    # a map from torch.fx.Node to a type of QuantizationSpecBase
    input_qspec_map: Dict[Node, Optional[QuantizationSpecBase]] = field(
        default_factory=dict
    )

    # How the output of this node is quantized, expressed as QuantizationSpec
    # TODO: change the value to QuantizationSpec in a separate PR
    output_qspec: Optional[QuantizationSpecBase] = None

    # For a Node: node1 and edge: (node1, node2), since they are observing the same
    # Tensor, we may want to implicitly share observers; this flag allows people to
    # turn off this behavior for the output of the node
    allow_implicit_sharing: bool = True

    # whether the node is annotated or not
    _annotated: bool = False


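# Illustrative sketch: attaching an annotation to a conv node inside a Quantizer's
# `annotate` method. The "quantization_annotation" meta key and the qspec variables
# are assumptions made for the sake of the example:
#
#   conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
#       input_qspec_map={input_act_node: act_qspec, weight_node: weight_qspec},
#       output_qspec=act_qspec,
#       _annotated=True,
#   )

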
class Quantizer(ABC):
    # annotate nodes in the graph with observer or fake quant constructors
    # to convey the desired way of quantization
    @abstractmethod
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        pass

    # validate that the annotated graph is supported by the backend
    @abstractmethod
    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

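
# Illustrative sketch: a minimal Quantizer subclass that annotates nothing and
# accepts any graph, shown only to demonstrate the required interface (this is
# an example, not a quantizer shipped with PyTorch):
#
#   class NoopQuantizer(Quantizer):
#       def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
#           return model
#
#       def validate(self, model: torch.fx.GraphModule) -> None:
#           pass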