from __future__ import annotations

import copy
import functools
from typing import Any, Callable, Dict, List, Optional, Set

import torch
import torch._dynamo as torchdynamo
import torch.nn.functional as F
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
from torch.ao.quantization.observer import (
    HistogramObserver,
    MinMaxObserver,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    PerChannelMinMaxObserver,
    PlaceholderObserver,
)
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    OP_TO_ANNOTATOR,
    OperatorConfig,
    OperatorPatternType,
    QuantizationConfig,
)
from torch.fx import Node

__all__ = [
    "XNNPACKQuantizer",
    "get_symmetric_quantization_config",
]
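
# A minimal usage sketch (not part of this module's API): the quantizer is
# typically paired with the PT2E workflow. `prepare_pt2e`/`convert_pt2e` come
# from torch.ao.quantization.quantize_pt2e; `MyModel` and `example_inputs`
# are hypothetical placeholders.
#
#   from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#
#   model, _ = torchdynamo.export(MyModel().eval(), aten_graph=True)(*example_inputs)
#   quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
#   model = prepare_pt2e(model, quantizer)
#   model(*example_inputs)  # calibrate with representative data
#   model = convert_pt2e(model)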

def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph:
    gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs)
    gm.graph.eliminate_dead_code()
    return gm.graph


def _get_linear_patterns(input_size: List[int]):
    in_channels = input_size[-1]
    out_channels = 8  # hard coded, but the value does not affect pattern matching
    weight = torch.ones((out_channels, in_channels))
    bias = torch.ones((out_channels,))
    act = torch.ones(input_size)

    def linear_op(act, weight, bias=None):
        return F.linear(act, weight, bias)

    pattern_w_bias = _get_dynamo_graph(linear_op, (act, weight, bias))
    pattern_wo_bias = _get_dynamo_graph(linear_op, (act, weight))
    return [pattern_w_bias, pattern_wo_bias]

def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
    supported_operators: Dict[str, List[OperatorPatternType]] = {
        # Both conv and linear should be able to handle relu + hardtanh fusion
        # since those are clamp ops
        "conv2d": [
            [torch.nn.Conv2d, torch.nn.ReLU],
            [torch.nn.Conv2d, F.relu],
            [F.conv2d, torch.nn.ReLU],
            [F.conv2d, F.relu],
        ],
        "linear": [[torch.nn.Linear], [F.linear]],
        "add": [[torch.add]],
        "max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
        "adaptive_avg_pool2d": [
            [torch.nn.AdaptiveAvgPool2d],
            [F.adaptive_avg_pool2d],
        ],
    }
    return copy.deepcopy(supported_operators)


def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
    supported_config_and_operators: List[OperatorConfig] = []
    for quantization_config in [
        get_symmetric_quantization_config(),
        get_symmetric_quantization_config(is_qat=True),
        get_symmetric_quantization_config(is_per_channel=True),
        get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
    ]:
        ops = _supported_symmetric_quantized_operators()
        for pattern_list in ops.values():
            supported_config_and_operators.append(
                OperatorConfig(quantization_config, pattern_list)
            )
    return copy.deepcopy(supported_config_and_operators)

@functools.lru_cache
def get_symmetric_quantization_config(
    is_per_channel: bool = False,
    is_qat: bool = False,
    is_dynamic: bool = False,
):
    if is_qat:
        if is_dynamic:
            raise NotImplementedError(
                "Dynamic quantization for QAT is not yet implemented."
            )
        act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
    else:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
        else:
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]

    act_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
        is_dynamic=is_dynamic,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
            eps=2**-12
        ),
    )
    qscheme = (
        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
    )
    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
        MinMaxObserver
    )
    if is_qat:
        weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
    elif is_per_channel:
        weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver

    extra_args: Dict[str, Any] = {"eps": 2**-12}
    if is_qat:
        if qscheme == torch.per_tensor_symmetric:
            extra_args["observer"] = MovingAverageMinMaxObserver
        else:
            extra_args["observer"] = MovingAveragePerChannelMinMaxObserver  # type: ignore[dict-item]
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-127,
        quant_max=127,
        qscheme=qscheme,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
            **extra_args
        ),
    )

    bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
        PlaceholderObserver
    )
    bias_quantization_spec = QuantizationSpec(
        dtype=torch.float, observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
    )
    if is_dynamic:
        # dynamic quantization only observes the input activation; the output
        # activation spec is left as None
        quantization_config = QuantizationConfig(
            act_quantization_spec,
            None,
            weight_quantization_spec,
            bias_quantization_spec,
            is_qat,
        )
    else:
        quantization_config = QuantizationConfig(
            act_quantization_spec,
            act_quantization_spec,
            weight_quantization_spec,
            bias_quantization_spec,
            is_qat,
        )
    return quantization_config

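# For reference, the variants enumerated in
# _get_supported_symmetric_config_and_operators above (each result is cached
# by functools.lru_cache, so repeated calls return the same config object):
#   get_symmetric_quantization_config()                     # PTQ, per-tensor weights
#   get_symmetric_quantization_config(is_per_channel=True)  # PTQ, per-channel weights
#   get_symmetric_quantization_config(is_qat=True)          # QAT, per-tensor weights
#   get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
# Dynamic quantization (linear only, see annotate below) is requested with:
#   get_symmetric_quantization_config(is_dynamic=True)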
def _get_supported_config_and_operators() -> List[OperatorConfig]:
    return _get_supported_symmetric_config_and_operators()

def _get_module_name_filter(module_name: str):
    """Get the module_name_filter function for a given module name.

    The filter accepts a node and checks whether the node comes from a module
    that has the given module name.

    For example:
        node: linear_op = call_function[...](...)  # comes from a module with name blocks.sub.linear1

    >>> module_name_filter = _get_module_name_filter("blocks.sub")
    >>> print(module_name_filter(node))
    True  # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
    """

    def module_name_filter(n: Node) -> bool:
        # example: {
        #     'L__self___sub': ("L['self'].sub", <class '....Sub'>),
        #     'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
        # }
        nn_module_stack = n.meta["nn_module_stack"]
        names = [
            name[len("L__self___") :].replace("_", ".")
            for name in nn_module_stack.keys()
        ]
        return module_name in names

    return module_name_filter


def _get_module_type_filter(tp: Callable):
    """Get the module_type_filter function for a given module type.

    The filter accepts a node and checks whether the node comes from a module
    that has the given module type.

    For example:
        node: linear_op = call_function[...](...)  # comes from a module with type Block -> Sub -> Linear

    >>> module_type_filter = _get_module_type_filter(Sub)  # submodule with type `Sub`, under the `Block` submodule
    >>> print(module_type_filter(node))
    True  # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
    """

    def module_type_filter(n: Node) -> bool:
        # example: {
        #     'L__self___sub': ("L['self'].sub", <class '....Sub'>),
        #     'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
        # }
        nn_module_stack = n.meta["nn_module_stack"]
        types = [t for _, t in nn_module_stack.values()]
        return tp in types

    return module_type_filter

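# Sketch: these filters pair with the annotators from xnnpack_quantizer_utils,
# e.g. to annotate only the linear nodes under "blocks.sub" (`model` is a
# hypothetical exported GraphModule):
#   OP_TO_ANNOTATOR["linear"](
#       model,
#       get_symmetric_quantization_config(),
#       _get_module_name_filter("blocks.sub"),
#   )
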
class XNNPACKQuantizer(Quantizer):
    supported_config_and_operators = _get_supported_config_and_operators()

    def __init__(self):
        super().__init__()
        self.global_config: Optional[QuantizationConfig] = None
        self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {}
        self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
        self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}

    @classmethod
    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
        op_configs: Set[QuantizationConfig] = set()
        for spec, _ in cls.supported_config_and_operators:
            op_configs.add(spec)
        return list(op_configs)

    @classmethod
    def get_supported_operator_for_quantization_config(
        cls, quantization_config: Optional[QuantizationConfig]
    ) -> List[OperatorPatternType]:
        if quantization_config is None:
            all_ops = []
            for _, ops in cls.supported_config_and_operators:
                all_ops.extend(ops)
            return all_ops

        for config, ops in cls.supported_config_and_operators:
            # note: this assumes each entry in cls.supported_config_and_operators
            # corresponds to exactly one spec, e.g. we don't have
            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
            # where the first and second entries share a spec but did not
            # merge their op lists
            if config == quantization_config:
                return ops
        return []

    def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer:
        self.global_config = quantization_config
        return self

    def set_operator_type(
        self, operator_type: str, quantization_config: QuantizationConfig
    ) -> XNNPACKQuantizer:
        self.operator_type_config[operator_type] = quantization_config
        return self

    def set_module_type(
        self, module_type: Callable, quantization_config: QuantizationConfig
    ):
        """Set quantization_config for a submodule with type `module_type`, for example:
        quantizer.set_module_type(Sub) or quantizer.set_module_type(nn.Linear); it will quantize all supported
        operator/operator patterns in any submodule with this module type with the given `quantization_config`
        """
        self.module_type_config[module_type] = quantization_config
        return self

    def set_module_name(
        self, module_name: str, quantization_config: Optional[QuantizationConfig]
    ):
        """Set quantization_config for a submodule with name `module_name`, for example:
        quantizer.set_module_name("blocks.sub"); it will quantize all supported operator/operator
        patterns in the submodule with this module name with the given `quantization_config`
        """
        assert (
            quantization_config is not None
        ), "quantization_config == None is not supported yet"
        self.module_name_config[module_name] = quantization_config
        return self

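    # Sketch of combining the setters above (they return self, so calls can be
    # chained); `Sub` is a hypothetical submodule class:
    #   quantizer = (
    #       XNNPACKQuantizer()
    #       .set_global(get_symmetric_quantization_config())
    #       .set_module_type(Sub, get_symmetric_quantization_config(is_per_channel=True))
    #       .set_module_name("blocks.sub", get_symmetric_quantization_config(is_qat=True))
    #   )
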
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Annotate the model according to the configured global, module-type,
        and module-name quantization specs."""
        # Workaround for handling dynamic linear quantization; TODO: fix later.
        if self.global_config and self.global_config.input_activation.is_dynamic:  # type: ignore[union-attr]
            model = self._annotate_for_dynamic_quantization_config(model)
        else:
            model = self._annotate_for_static_quantization_config(model)
        return model

    def _annotate_all_patterns(
        self,
        model: torch.fx.GraphModule,
        config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> torch.fx.GraphModule:
        # TODO: implement support for None canceling out previous annotations
        if config is None:
            return model

        self._annotate_linear(model, config, filter_fn)
        self._annotate_conv2d_patterns(model, config, filter_fn)
        self._annotate_max_pool2d(model, config, filter_fn)
        self._annotate_add_patterns(model, config, filter_fn)
        OP_TO_ANNOTATOR["mul_relu"](model, config, filter_fn)
        OP_TO_ANNOTATOR["mul"](model, config, filter_fn)
        self._annotate_adaptive_avg_pool2d(model, config, filter_fn)
        self._annotate_gru_io_only(model, config, filter_fn)
        return model

    def _annotate_for_static_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        for module_name, config in self.module_name_config.items():
            self._annotate_all_patterns(
                model, config, _get_module_name_filter(module_name)
            )

        for module_type, config in self.module_type_config.items():
            self._annotate_all_patterns(
                model, config, _get_module_type_filter(module_type)
            )

        self._annotate_all_patterns(model, self.global_config)
        return model

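    # Note on precedence: module-name configs are applied first, then
    # module-type configs, then the global config; annotators are expected to
    # skip nodes that an earlier pass has already annotated, so more specific
    # configs win.
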
    def _annotate_for_dynamic_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        # dynamic quantization is currently only supported for linear
        for module_name, config in self.module_name_config.items():
            self._annotate_linear(model, config, _get_module_name_filter(module_name))

        for module_type, config in self.module_type_config.items():
            self._annotate_linear(model, config, _get_module_type_filter(module_type))

        self._annotate_linear(model, self.global_config)
        return model

    def _annotate_conv2d_patterns(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        if quantization_config is None:
            return

        # annotate the most specific (fused) patterns first, so that e.g.
        # conv2d + bn + relu is matched before the bare conv2d
        if quantization_config.is_qat:
            self._annotate_conv2d_bn_relu(gm, quantization_config, filter_fn)
            self._annotate_conv2d_bn(gm, quantization_config, filter_fn)
        self._annotate_conv2d_relu(gm, quantization_config, filter_fn)
        self._annotate_conv2d(gm, quantization_config, filter_fn)

    def _annotate_conv2d_bn(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        """
        Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
        """
        return OP_TO_ANNOTATOR["conv2d_bn"](gm, quantization_config, filter_fn)

    def _annotate_conv2d_bn_relu(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        """
        Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
        """
        return OP_TO_ANNOTATOR["conv2d_bn_relu"](gm, quantization_config, filter_fn)

    def _annotate_conv2d_relu(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        return OP_TO_ANNOTATOR["conv2d_relu"](gm, quantization_config, filter_fn)

    def _annotate_conv2d(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        return OP_TO_ANNOTATOR["conv2d"](gm, quantization_config, filter_fn)

    def _annotate_linear(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        return OP_TO_ANNOTATOR["linear"](gm, quantization_config, filter_fn)

    def _annotate_adaptive_avg_pool2d(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        return OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
            gm, quantization_config, filter_fn
        )

    # TODO: move this to BoltNNQuantizer?
    def _annotate_gru_io_only(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        return OP_TO_ANNOTATOR["gru_io_only"](gm, quantization_config, filter_fn)

    def _annotate_max_pool2d(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        return OP_TO_ANNOTATOR["max_pool2d"](gm, quantization_config, filter_fn)

    def _annotate_add_patterns(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        self._annotate_add_relu(gm, quantization_config, filter_fn)
        self._annotate_add(gm, quantization_config, filter_fn)

    def _annotate_add_relu(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        return OP_TO_ANNOTATOR["add_relu"](gm, quantization_config, filter_fn)

    def _annotate_add(
        self,
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        return OP_TO_ANNOTATOR["add"](gm, quantization_config, filter_fn)

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return cls.supported_config_and_operators