from __future__ import annotations

import copy
import functools

from typing import Any, Callable, Dict, List, Optional, Set

import torch
import torch._dynamo as torchdynamo
import torch.nn.functional as F
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
from torch.ao.quantization.observer import (
    HistogramObserver,
    MinMaxObserver,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    PerChannelMinMaxObserver,
    PlaceholderObserver,
)

from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor

from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    OP_TO_ANNOTATOR,
    OperatorConfig,
    OperatorPatternType,
    QuantizationConfig,
)

from torch.fx import Node


__all__ = [
    "XNNPACKQuantizer",
    "get_symmetric_quantization_config",
]


def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph:
    gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs)
    gm.graph.eliminate_dead_code()
    return gm.graph


def _get_linear_patterns(input_size: List[int]):
    in_channels = input_size[-1]
    out_channels = 8  # hard coding but this should not matter
    weight = torch.ones((out_channels, in_channels))
    bias = torch.ones((out_channels,))
    act = torch.ones(input_size)

    def linear_op(act, weight, bias=None):
        return F.linear(act, weight, bias)

    pattern_w_bias = _get_dynamo_graph(linear_op, (act, weight, bias))
    pattern_wo_bias = _get_dynamo_graph(linear_op, (act, weight))
    return [pattern_w_bias, pattern_wo_bias]


def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
    supported_operators: Dict[str, List[OperatorPatternType]] = {
        # Both conv and linear should be able to handle relu + hardtanh fusion since
        # those are clamp ops
        "conv2d": [
            [torch.nn.Conv2d, torch.nn.ReLU],
            [torch.nn.Conv2d, F.relu],
            [F.conv2d, torch.nn.ReLU],
            [F.conv2d, F.relu],
        ],
        "linear": [[torch.nn.Linear], [F.linear]],
        "add": [[torch.add]],
        "max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
        "adaptive_avg_pool2d": [
            [torch.nn.AdaptiveAvgPool2d],
            [F.adaptive_avg_pool2d],
        ],
    }
    return copy.deepcopy(supported_operators)


def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
    supported_config_and_operators: List[OperatorConfig] = []
    for quantization_config in [
        get_symmetric_quantization_config(),
        get_symmetric_quantization_config(is_qat=True),
        get_symmetric_quantization_config(is_per_channel=True),
        get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
    ]:
        ops = _supported_symmetric_quantized_operators()
        for pattern_list in ops.values():
            supported_config_and_operators.append(
                OperatorConfig(quantization_config, pattern_list)
            )
    return copy.deepcopy(supported_config_and_operators)


@functools.lru_cache
def get_symmetric_quantization_config(
    is_per_channel: bool = False,
    is_qat: bool = False,
    is_dynamic: bool = False,
):
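    """Build the symmetric int8 QuantizationConfig used by XNNPACKQuantizer.

    The activation spec is per-tensor affine int8; the weight spec is symmetric
    int8, per-channel when `is_per_channel` is set. `is_qat` swaps observers for
    fake-quantize modules, and `is_dynamic` makes the input activation dynamically
    quantized (the output activation spec is then left as None). Results are
    cached via `functools.lru_cache`.

    Example (a rough usage sketch, not an exhaustive configuration):
        config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer = XNNPACKQuantizer().set_global(config)
    """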
    if is_qat:
        if is_dynamic:
            raise NotImplementedError(
                "dynamic quantization for qat is not yet implemented."
            )
        act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
    else:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
        else:
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]

    act_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
        is_dynamic=is_dynamic,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
            eps=2**-12
        ),
    )
    qscheme = (
        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
    )
    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
        MinMaxObserver
    )
    if is_qat:
        weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
    elif is_per_channel:
        weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver

    extra_args: Dict[str, Any] = {"eps": 2**-12}
    if is_qat:
        if qscheme == torch.per_tensor_symmetric:
            extra_args["observer"] = MovingAverageMinMaxObserver
        else:
            extra_args["observer"] = MovingAveragePerChannelMinMaxObserver  # type: ignore[dict-item]
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-127,
        quant_max=127,
        qscheme=qscheme,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
            **extra_args
        ),
    )

    bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
        PlaceholderObserver
    )
    bias_quantization_spec = QuantizationSpec(
        dtype=torch.float, observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
    )
    if is_dynamic:
        quantization_config = QuantizationConfig(
            act_quantization_spec,
            None,
            weight_quantization_spec,
            bias_quantization_spec,
            is_qat,
        )
    else:
        quantization_config = QuantizationConfig(
            act_quantization_spec,
            act_quantization_spec,
            weight_quantization_spec,
            bias_quantization_spec,
            is_qat,
        )
    return quantization_config

def _get_supported_config_and_operators() -> List[OperatorConfig]:
    return _get_supported_symmetric_config_and_operators()


def _get_module_name_filter(module_name: str):
    """Get the module_name_filter function for a given module name; the filter
    accepts a node and checks whether the node comes from a module with the given
    module name.

    For example:
        node: linear_op = call_function[...](...)  # comes from a module with name blocks.sub.linear1

        >> module_name_filter = _get_module_name_filter("blocks.sub")
        >> print(module_name_filter(node))
        True  # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
    """

    def module_name_filter(n: Node) -> bool:
        # example: {
        #    'L__self___sub': ("L['self'].sub", <class '....Sub'>),
        #    'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
        # }
        nn_module_stack = n.meta["nn_module_stack"]
        names = [
            n[len("L__self___") :].replace("_", ".") for n in nn_module_stack.keys()
        ]
        return module_name in names

    return module_name_filter


def _get_module_type_filter(tp: Callable):
    """Get the module_type_filter function for a given module type; the filter
    accepts a node and checks whether the node comes from a module of the given
    type.

    For example:
        node: linear_op = call_function[...](...)  # comes from a module with type Block -> Sub -> Linear

        >> module_type_filter = _get_module_type_filter(Sub)  # submodule with type `Sub`, under the `Block` submodule
        >> print(module_type_filter(node))
        True  # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
    """

    def module_type_filter(n: Node) -> bool:
        # example: {
        #    'L__self___sub': ("L['self'].sub", <class '....Sub'>),
        #    'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
        # }
        nn_module_stack = n.meta["nn_module_stack"]
        types = [t for _, t in nn_module_stack.values()]
        return tp in types

    return module_type_filter


class XNNPACKQuantizer(Quantizer):
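    """Quantizer for the XNNPACK backend: annotates models with the symmetric int8
    quantization configs defined above.

    A rough usage sketch with the PT2E flow (assuming `prepare_pt2e`/`convert_pt2e`
    from `torch.ao.quantization.quantize_pt2e`; the capture step may differ between
    releases):

        quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
        exported, _ = torchdynamo.export(model, aten_graph=True)(*example_inputs)
        prepared = prepare_pt2e(exported, quantizer)
        prepared(*example_inputs)  # calibrate
        quantized = convert_pt2e(prepared)
    """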

    supported_config_and_operators = _get_supported_config_and_operators()

    def __init__(self):
        super().__init__()
        self.global_config: Optional[QuantizationConfig] = None
        self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {}
        self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
        self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}

    @classmethod
    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
        op_configs: Set[QuantizationConfig] = set({})
        for spec, _ in cls.supported_config_and_operators:
            op_configs.add(spec)
        return list(op_configs)

    @classmethod
    def get_supported_operator_for_quantization_config(
        cls, quantization_config: Optional[QuantizationConfig]
    ) -> List[OperatorPatternType]:
        if quantization_config is None:
            all_ops = []
            for _, ops in cls.supported_config_and_operators:
                all_ops.extend(ops)
            return all_ops

        for config, ops in cls.supported_config_and_operators:
            # note: this assumes each entry in cls.supported_config_and_operators
            # corresponds to one spec, e.g. we don't have
            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
            # where the first and second entry have the same spec but did not
            # merge the op list
            if config == quantization_config:
                return ops
        return []

    def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer:
        self.global_config = quantization_config
        return self

    def set_operator_type(
        self, operator_type: str, quantization_config: QuantizationConfig
    ) -> XNNPACKQuantizer:
        self.operator_type_config[operator_type] = quantization_config
        return self

    def set_module_type(
        self, module_type: Callable, quantization_config: QuantizationConfig
    ):
        """Set quantization_config for submodules with type `module_type`, for example:
        quantizer.set_module_type(Sub) or quantizer.set_module_type(nn.Linear). All supported
        operator patterns in submodules of this type will be quantized with the given
        `quantization_config`.
        """
        self.module_type_config[module_type] = quantization_config
        return self

    def set_module_name(
        self, module_name: str, quantization_config: Optional[QuantizationConfig]
    ):
        """Set quantization_config for a submodule with name `module_name`, for example:
        quantizer.set_module_name("blocks.sub"). All supported operator patterns in the
        submodule with this name will be quantized with the given `quantization_config`.
        """
        assert (
            quantization_config is not None
        ), "quantization_config == None is not supported yet"
        self.module_name_config[module_name] = quantization_config
        return self
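
    # Example of mixing the config setters above (a sketch; `Sub` is a hypothetical
    # user-defined submodule class, not something defined in this file):
    #
    #   quantizer = XNNPACKQuantizer()
    #   quantizer.set_global(get_symmetric_quantization_config())
    #   quantizer.set_module_type(Sub, get_symmetric_quantization_config(is_per_channel=True))
    #   quantizer.set_module_name("blocks.sub", get_symmetric_quantization_config(is_qat=True))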

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """just handling global spec for now"""
        # hacked for handling dynamic linear quant. will fix later.
        if self.global_config and self.global_config.input_activation.is_dynamic:  # type: ignore[union-attr]
            model = self._annotate_for_dynamic_quantization_config(model)
        else:
            model = self._annotate_for_static_quantization_config(model)
        return model

    def _annotate_all_patterns(
        self,
        model: torch.fx.GraphModule,
        config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> torch.fx.GraphModule:
        # TODO: implement the support for None to be canceling out previous annotations
        if config is None:
            return model

        # assert config is not None

        self._annotate_linear(model, config, filter_fn)
        self._annotate_conv2d_patterns(model, config, filter_fn)
        self._annotate_max_pool2d(model, config, filter_fn)
        self._annotate_add_patterns(model, config, filter_fn)
        self._annotate_adaptive_avg_pool2d(model, config, filter_fn)
        self._annotate_gru_io_only(model, config, filter_fn)
        return model

    def _annotate_for_static_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
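        # Annotate in order of decreasing specificity: module_name configs first,
        # then module_type configs, then the global config. Assuming the shared
        # annotators in xnnpack_quantizer_utils skip nodes that are already
        # annotated, the earlier (more specific) configs take precedence.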
        for module_name, config in self.module_name_config.items():
            self._annotate_all_patterns(
                model, config, _get_module_name_filter(module_name)
            )

        for module_type, config in self.module_type_config.items():
            self._annotate_all_patterns(
                model, config, _get_module_type_filter(module_type)
            )

        self._annotate_all_patterns(model, self.global_config)
        return model

    def _annotate_for_dynamic_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        for module_name, config in self.module_name_config.items():
            self._annotate_linear(model, config, _get_module_name_filter(module_name))

        for module_type, config in self.module_type_config.items():
            self._annotate_linear(model, config, _get_module_type_filter(module_type))

        self._annotate_linear(model, self.global_config)
        return model

    def _annotate_conv2d_patterns(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        if quantization_config is None:
            return
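
        # Larger fused patterns are annotated before smaller ones (conv+bn+relu,
        # then conv+bn, then conv+relu, then plain conv). Assuming the shared
        # annotators skip already-annotated nodes, each conv ends up claimed by
        # the largest pattern that matches it.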
        if quantization_config.is_qat:
            self._annotate_conv2d_bn_relu(gm, quantization_config, filter_fn)
            self._annotate_conv2d_bn(gm, quantization_config, filter_fn)
        self._annotate_conv2d_relu(gm, quantization_config, filter_fn)
        self._annotate_conv2d(gm, quantization_config, filter_fn)

    def _annotate_conv2d_bn(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        """
        Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
        """
        return OP_TO_ANNOTATOR["conv2d_bn"](gm, quantization_config, filter_fn)

    def _annotate_conv2d_bn_relu(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        """
        Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
        """
        return OP_TO_ANNOTATOR["conv2d_bn_relu"](gm, quantization_config, filter_fn)

    def _annotate_conv2d_relu(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        return OP_TO_ANNOTATOR["conv2d_relu"](gm, quantization_config, filter_fn)

    def _annotate_conv2d(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        return OP_TO_ANNOTATOR["conv2d"](gm, quantization_config, filter_fn)

    def _annotate_linear(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        return OP_TO_ANNOTATOR["linear"](gm, quantization_config, filter_fn)

    def _annotate_adaptive_avg_pool2d(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        return OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
            gm, quantization_config, filter_fn
        )

    # TODO: move this to BoltNNQuantizer?
    def _annotate_gru_io_only(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        return OP_TO_ANNOTATOR["gru_io_only"](gm, quantization_config, filter_fn)

    def _annotate_max_pool2d(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        return OP_TO_ANNOTATOR["max_pool2d"](gm, quantization_config, filter_fn)

    def _annotate_add_patterns(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        self._annotate_add_relu(gm, quantization_config, filter_fn)
        self._annotate_add(gm, quantization_config, filter_fn)

    def _annotate_add_relu(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        return OP_TO_ANNOTATOR["add_relu"](gm, quantization_config, filter_fn)

    def _annotate_add(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        return OP_TO_ANNOTATOR["add"](gm, quantization_config, filter_fn)

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return cls.supported_config_and_operators