pytorch/torch/ao/quantization/quantizer/xnnpack_quantizer.py
Jerry Zhang 28be2c674a [quant][pt2e] Move specific quantizer related things outside of main quant code base (#106806) (#107259)
Summary:

Currently in quantizer/quantize_pt2e we import things from specific quantizers (XNNPACKQuantizer, QuantizationConfig, etc.). This PR removes those imports so it is clearer that they are not part of the core quantization code base.

This PR also removes get_supported_operators from the main Quantizer class, since we haven't seen a clear need for this API.

Test Plan:
CIs

Imported from OSS

Differential Revision: D48340367

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107259
Approved by: https://github.com/kimishpatel
2023-08-18 21:29:09 +00:00


from __future__ import annotations
import copy
import functools
from typing import Any, Callable, Dict, List, Optional, Set
import torch
import torch._dynamo as torchdynamo
import torch.nn.functional as F
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
from torch.ao.quantization.observer import (
HistogramObserver,
MinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
PlaceholderObserver,
)
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
OP_TO_ANNOTATOR,
OperatorConfig,
OperatorPatternType,
QuantizationConfig,
)
from torch.fx import Node
__all__ = [
"XNNPACKQuantizer",
"get_symmetric_quantization_config",
]
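# A minimal end-to-end usage sketch (assumption: the pt2e workflow entry points
# `capture_pre_autograd_graph`, `prepare_pt2e`, and `convert_pt2e` are available
# at this version; the exact capture API may differ across releases):
#
#   import torch
#   from torch._export import capture_pre_autograd_graph
#   from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#   from torch.ao.quantization.quantizer.xnnpack_quantizer import (
#       XNNPACKQuantizer,
#       get_symmetric_quantization_config,
#   )
#
#   model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU()).eval()
#   example_inputs = (torch.randn(1, 16),)
#
#   # export to an ATen graph, annotate with the quantizer, calibrate, convert
#   exported = capture_pre_autograd_graph(model, example_inputs)
#   quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
#   prepared = prepare_pt2e(exported, quantizer)
#   prepared(*example_inputs)  # calibration run(s)
#   quantized = convert_pt2e(prepared)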
def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph:
gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs)
gm.graph.eliminate_dead_code()
return gm.graph
def _get_linear_patterns(input_size: List[int]):
in_channels = input_size[-1]
out_channels = 8  # hard-coded; the exact value should not matter for pattern extraction
weight = torch.ones((out_channels, in_channels))
bias = torch.ones((out_channels,))
act = torch.ones(input_size)
def linear_op(act, weight, bias=None):
return F.linear(act, weight, bias)
pattern_w_bias = _get_dynamo_graph(linear_op, (act, weight, bias))
pattern_wo_bias = _get_dynamo_graph(linear_op, (act, weight))
return [pattern_w_bias, pattern_wo_bias]
def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
supported_operators: Dict[str, List[OperatorPatternType]] = {
# Both conv and linear should be able to handle relu + hardtanh fusion since
# those are clamp ops
"conv2d": [
[torch.nn.Conv2d, torch.nn.ReLU],
[torch.nn.Conv2d, F.relu],
[F.conv2d, torch.nn.ReLU],
[F.conv2d, F.relu],
],
"linear": [[torch.nn.Linear], [F.linear]],
"add": [[torch.add]],
"max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
"adaptive_avg_pool2d": [
[torch.nn.AdaptiveAvgPool2d],
[F.adaptive_avg_pool2d],
],
}
return copy.deepcopy(supported_operators)
def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
supported_config_and_operators: List[OperatorConfig] = []
for quantization_config in [
get_symmetric_quantization_config(),
get_symmetric_quantization_config(is_qat=True),
get_symmetric_quantization_config(is_per_channel=True),
get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
]:
ops = _supported_symmetric_quantized_operators()
for pattern_list in ops.values():
supported_config_and_operators.append(
OperatorConfig(quantization_config, pattern_list)
)
return copy.deepcopy(supported_config_and_operators)
@functools.lru_cache
def get_symmetric_quantization_config(
is_per_channel: bool = False,
is_qat: bool = False,
is_dynamic: bool = False,
):
if is_qat:
if is_dynamic:
raise NotImplementedError(
"dynamic quantization for qat is not yet implemented."
)
act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
else:
if is_dynamic:
act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment]
else:
act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment]
act_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
is_dynamic=is_dynamic,
observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
eps=2**-12
),
)
qscheme = (
torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
)
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
MinMaxObserver
)
if is_qat:
weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
elif is_per_channel:
weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
extra_args: Dict[str, Any] = {"eps": 2**-12}
if is_qat:
if qscheme == torch.per_tensor_symmetric:
extra_args["observer"] = MovingAverageMinMaxObserver
else:
extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item]
weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=-127,
quant_max=127,
qscheme=qscheme,
ch_axis=0,
is_dynamic=False,
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
**extra_args
),
)
bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
PlaceholderObserver
)
bias_quantization_spec = QuantizationSpec(
dtype=torch.float, observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
)
if is_dynamic:
quantization_config = QuantizationConfig(
act_quantization_spec,
None,
weight_quantization_spec,
bias_quantization_spec,
is_qat,
)
else:
quantization_config = QuantizationConfig(
act_quantization_spec,
act_quantization_spec,
weight_quantization_spec,
bias_quantization_spec,
is_qat,
)
return quantization_config
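# What the configs above contain (derived from the code in this function):
#   - activations: int8 in [-128, 127], per-tensor affine; HistogramObserver for
#     static PTQ, PlaceholderObserver when is_dynamic=True, and
#     FusedMovingAvgObsFakeQuantize when is_qat=True
#   - weights: int8 in [-127, 127], symmetric; per-tensor (MinMaxObserver) by
#     default, per-channel along axis 0 (PerChannelMinMaxObserver) when
#     is_per_channel=True
#   - bias: left in float (PlaceholderObserver)
# For dynamic quantization the output-activation spec is None, so only input
# activations are quantized. Illustrative call:
#   qat_per_channel = get_symmetric_quantization_config(is_per_channel=True, is_qat=True)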
def _get_supported_config_and_operators() -> List[OperatorConfig]:
return _get_supported_symmetric_config_and_operators()
def _get_module_name_filter(module_name: str):
"""Get the module_name_filter function for a given module name, the filter accepts
a node and checks if the node comes from a module that has certain module name
For example:
node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1
>> module_name_filter = _get_module_name_filter("blocks.sub")
>> print(module_name_filter(node))
True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
"""
def module_name_filter(n: Node) -> bool:
# example: {
# 'L__self___sub': ("L['self'].sub", <class '....Sub'>),
# 'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
# }
nn_module_stack = n.meta["nn_module_stack"]
names = [
name[len("L__self___") :].replace("_", ".") for name in nn_module_stack.keys()
]
return module_name in names
return module_name_filter
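# Note: the filter above relies on the "L__self___<path>" naming that torchdynamo
# records as keys in node.meta["nn_module_stack"], converting those keys back to
# dotted fully qualified names before matching against `module_name`.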
def _get_module_type_filter(tp: Callable):
"""Get the module_type_filter function for a given module type, the filter accepts
a node and checks if the node comes from a module that has certain module type
For example:
node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear
>> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule
>> print(module_type_filter(node))
True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
"""
def module_type_filter(n: Node) -> bool:
# example: {
# 'L__self___sub': ("L['self'].sub", <class '....Sub'>),
# 'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
# }
nn_module_stack = n.meta["nn_module_stack"]
types = [t for _, t in nn_module_stack.values()]
return tp in types
return module_type_filter
class XNNPACKQuantizer(Quantizer):
supported_config_and_operators = _get_supported_config_and_operators()
def __init__(self):
super().__init__()
self.global_config: Optional[QuantizationConfig] = None
self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {}
self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}
@classmethod
def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
op_configs: Set[QuantizationConfig] = set()
for spec, _ in cls.supported_config_and_operators:
op_configs.add(spec)
return list(op_configs)
@classmethod
def get_supported_operator_for_quantization_config(
cls, quantization_config: Optional[QuantizationConfig]
) -> List[OperatorPatternType]:
if quantization_config is None:
all_ops = []
for _, ops in cls.supported_config_and_operators:
all_ops.extend(ops)
return all_ops
for config, ops in cls.supported_config_and_operators:
# note: this assumes each entry in cls.supported_config_and_operators
# corresponds to one spec, e.g. we don't have
# [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
# where the first and second entry have the same spec but did not
# merge the op list
if config == quantization_config:
return ops
return []
def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer:
self.global_config = quantization_config
return self
def set_operator_type(
self, operator_type: str, quantization_config: QuantizationConfig
) -> XNNPACKQuantizer:
self.operator_type_config[operator_type] = quantization_config
return self
def set_module_type(
self, module_type: Callable, quantization_config: QuantizationConfig
):
"""Set quantization_config for a submodule with type: `module_type`, for example:
quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator
patterns in the submodule with this module type with the given `quantization_config`
"""
self.module_type_config[module_type] = quantization_config
return self
def set_module_name(
self, module_name: str, quantization_config: Optional[QuantizationConfig]
):
"""Set quantization_config for a submodule with name: `module_name`, for example:
quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
patterns in the submodule with this module name with the given `quantization_config`
"""
assert (
quantization_config is not None
), "Setting quantization_config to None is not supported yet"
self.module_name_config[module_name] = quantization_config
return self
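# Illustrative configuration (module_name/module_type configs are applied before
# the global config in annotate() below, so they take precedence for the nodes
# they match, assuming the per-op annotators skip nodes that are already annotated):
#   quantizer = (
#       XNNPACKQuantizer()
#       .set_global(get_symmetric_quantization_config())
#       .set_module_type(torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=True))
#       .set_module_name("blocks.sub", get_symmetric_quantization_config(is_qat=True))
#   )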
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""just handling global spec for now"""
# hacked for handling dynamic linear quant. will fix later.
if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr]
model = self._annotate_for_dynamic_quantization_config(model)
else:
model = self._annotate_for_static_quantization_config(model)
return model
def _annotate_all_patterns(
self,
model: torch.fx.GraphModule,
config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> torch.fx.GraphModule:
# TODO: implement support for None to cancel out previous annotations
if config is None:
return model
# assert config is not None
self._annotate_linear(model, config, filter_fn)
self._annotate_conv2d_patterns(model, config, filter_fn)
self._annotate_max_pool2d(model, config, filter_fn)
self._annotate_add_patterns(model, config, filter_fn)
self._annotate_adaptive_avg_pool2d(model, config, filter_fn)
self._annotate_gru_io_only(model, config, filter_fn)
return model
def _annotate_for_static_quantization_config(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
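# apply the most specific configs first: per-module-name, then per-module-type,
# then the global config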
for module_name, config in self.module_name_config.items():
self._annotate_all_patterns(
model, config, _get_module_name_filter(module_name)
)
for module_type, config in self.module_type_config.items():
self._annotate_all_patterns(
model, config, _get_module_type_filter(module_type)
)
self._annotate_all_patterns(model, self.global_config)
return model
def _annotate_for_dynamic_quantization_config(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
for module_name, config in self.module_name_config.items():
self._annotate_linear(model, config, _get_module_name_filter(module_name))
for module_type, config in self.module_type_config.items():
self._annotate_linear(model, config, _get_module_type_filter(module_type))
self._annotate_linear(model, self.global_config)
return model
def _annotate_conv2d_patterns(
self,
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
if quantization_config is None:
return
if quantization_config.is_qat:
self._annotate_conv2d_bn_relu(gm, quantization_config, filter_fn)
self._annotate_conv2d_bn(gm, quantization_config, filter_fn)
self._annotate_conv2d_relu(gm, quantization_config, filter_fn)
self._annotate_conv2d(gm, quantization_config, filter_fn)
def _annotate_conv2d_bn(
self,
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
"""
Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
"""
return OP_TO_ANNOTATOR["conv2d_bn"](gm, quantization_config, filter_fn)
def _annotate_conv2d_bn_relu(
self,
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
"""
Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
"""
return OP_TO_ANNOTATOR["conv2d_bn_relu"](gm, quantization_config, filter_fn)
def _annotate_conv2d_relu(
self,
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
return OP_TO_ANNOTATOR["conv2d_relu"](gm, quantization_config, filter_fn)
def _annotate_conv2d(
self,
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
return OP_TO_ANNOTATOR["conv2d"](gm, quantization_config, filter_fn)
def _annotate_linear(
self,
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
return OP_TO_ANNOTATOR["linear"](gm, quantization_config, filter_fn)
def _annotate_adaptive_avg_pool2d(
self,
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
return OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
gm, quantization_config, filter_fn
)
# TODO: move this to BoltNNQuantizer?
def _annotate_gru_io_only(
self,
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
return OP_TO_ANNOTATOR["gru_io_only"](gm, quantization_config, filter_fn)
def _annotate_max_pool2d(
self,
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
return OP_TO_ANNOTATOR["max_pool2d"](gm, quantization_config, filter_fn)
def _annotate_add_patterns(
self,
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
self._annotate_add_relu(gm, quantization_config, filter_fn)
self._annotate_add(gm, quantization_config, filter_fn)
def _annotate_add_relu(
self,
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
return OP_TO_ANNOTATOR["add_relu"](gm, quantization_config, filter_fn)
def _annotate_add(
self,
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
return OP_TO_ANNOTATOR["add"](gm, quantization_config, filter_fn)
def validate(self, model: torch.fx.GraphModule) -> None:
pass
@classmethod
def get_supported_operators(cls) -> List[OperatorConfig]:
return cls.supported_config_and_operators