pytorch/torch/ao/quantization/quantizer/xnnpack_quantizer.py
Jerry Zhang 1474dad28c [quant][pt2e][xnnpack] Add support for QAT dynamic quantization for linear in XNNPACKQuantizer (#113288)
Summary:
FX graph mode quant workflow and also pt2e flow relies on the `is_dynamic` flag in observer/quantizationspec to
convert an observer to dynamic quantization patterns (choose_qparams -> q -> dq), this PR added is_dynamic flag
for all observers so that it's possible to convert these observers to the pattern.

However, this dynamic quantization pattern (choose_qparams -> q -> dq) is actually only valid for MovingAverageObserver(averaging_constant=1)
for the computation before convert and after convert to match in the context of QAT. So we'll have some sanity
checks in other observers to make sure the is_dynamic is False.

Test Plan:
python test/test_quantization.py TestXNNPACKQuantizer.test_qat_dynamic_linear

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D51124725](https://our.internmc.facebook.com/intern/diff/D51124725)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113288
Approved by: https://github.com/kimishpatel
2023-12-04 23:06:38 +00:00

450 lines
16 KiB
Python

from __future__ import annotations
import copy
import functools
from typing import Any, Callable, Dict, List, Optional, Set
import torch
import torch._dynamo as torchdynamo
import torch.nn.functional as F
from torch.ao.quantization.fake_quantize import (
FakeQuantize,
FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import (
HistogramObserver,
MinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
PlaceholderObserver,
)
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
_convert_scalars_to_attrs,
OP_TO_ANNOTATOR,
OperatorConfig,
OperatorPatternType,
propagate_annotation,
QuantizationConfig,
)
from torch.fx import Node
__all__ = [
"XNNPACKQuantizer",
"get_symmetric_quantization_config",
]
def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph:
gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs)
gm.graph.eliminate_dead_code()
return gm.graph
def _get_linear_patterns(input_size: List[int]):
in_channels = input_size[-1]
out_channels = 8 # hard coding but this should not matter
weight = torch.ones((out_channels, in_channels))
bias = torch.ones((out_channels,))
act = torch.ones(input_size)
def linear_op(act, weight, bias=None):
return F.linear(act, weight, bias)
pattern_w_bias = _get_dynamo_graph(linear_op, (act, weight, bias))
pattern_wo_bias = _get_dynamo_graph(linear_op, (act, weight))
return [pattern_w_bias, pattern_wo_bias]
def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
supported_operators: Dict[str, List[OperatorPatternType]] = {
# Both conv and linear should be able to handle relu + hardtanh fusion since
# those are clamp ops
"conv2d": [
[torch.nn.Conv2d, torch.nn.ReLU],
[torch.nn.Conv2d, F.relu],
[F.conv2d, torch.nn.ReLU],
[F.conv2d, F.relu],
],
"linear": [[torch.nn.Linear], [F.linear]],
"add": [[torch.add]],
"max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
"adaptive_avg_pool2d": [
[torch.nn.AdaptiveAvgPool2d],
[F.adaptive_avg_pool2d],
],
}
return copy.deepcopy(supported_operators)
def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
supported_config_and_operators: List[OperatorConfig] = []
for quantization_config in [
get_symmetric_quantization_config(),
get_symmetric_quantization_config(is_qat=True),
get_symmetric_quantization_config(is_per_channel=True),
get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
]:
ops = _supported_symmetric_quantized_operators()
for pattern_list in ops.values():
supported_config_and_operators.append(
OperatorConfig(quantization_config, pattern_list)
)
return copy.deepcopy(supported_config_and_operators)
@functools.lru_cache
def get_symmetric_quantization_config(
is_per_channel: bool = False,
is_qat: bool = False,
is_dynamic: bool = False,
):
extra_args: Dict[str, Any] = {"eps": 2**-12}
if is_qat:
if is_dynamic:
act_observer_or_fake_quant_ctr = FakeQuantize
dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
averaging_constant=1
)
extra_args["observer"] = dynamic_quant_observer
else:
act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment]
else:
if is_dynamic:
act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment]
else:
act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment]
act_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
is_dynamic=is_dynamic,
observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
**extra_args,
),
)
weight_qscheme = (
torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
)
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
MinMaxObserver
)
if is_qat:
# TODO: qat + per channel?
weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
elif is_per_channel:
weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
extra_args: Dict[str, Any] = {"eps": 2**-12}
if is_qat:
if weight_qscheme == torch.per_tensor_symmetric:
extra_args["observer"] = MovingAverageMinMaxObserver
else:
extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item]
weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=-127,
quant_max=127,
qscheme=weight_qscheme,
ch_axis=0,
is_dynamic=False,
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
**extra_args
),
)
bias_quantization_spec = None
if is_dynamic:
quantization_config = QuantizationConfig(
act_quantization_spec,
None,
weight_quantization_spec,
bias_quantization_spec,
is_qat,
)
else:
quantization_config = QuantizationConfig(
act_quantization_spec,
act_quantization_spec,
weight_quantization_spec,
bias_quantization_spec,
is_qat,
)
return quantization_config
def _get_supported_config_and_operators() -> List[OperatorConfig]:
return _get_supported_symmetric_config_and_operators()
def _get_module_name_filter(module_name: str):
"""Get the module_name_filter function for a given module name, the filter accepts
a node and checks if the node comes from a module that has certain module name
For example:
node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1
>> module_name_filter = _get_module_name_filter("blocks.sub")
>> print(module_name_filter(node))
True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
"""
def module_name_filter(n: Node) -> bool:
# example: {
# 'L__self___sub': ("L['self'].sub", <class '....Sub'>),
# 'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
# }
# get_attr nodes doesn't have nn_module_stack?
nn_module_stack = n.meta.get("nn_module_stack", {})
names = [
n[len("L__self___") :].replace("_", ".") for n in nn_module_stack.keys()
]
return module_name in names
return module_name_filter
def _get_module_type_filter(tp: Callable):
"""Get the module_type_filter function for a given module type, the filter accepts
a node and checks if the node comes from a module that has certain module type
For example:
node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear
>> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule
>> print(module_type_filter(node))
True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
"""
def module_type_filter(n: Node) -> bool:
# example: {
# 'L__self___sub': ("L['self'].sub", <class '....Sub'>),
# 'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
# }
nn_module_stack = n.meta.get("nn_module_stack", {})
types = [t for _, t in nn_module_stack.values()]
return tp in types
return module_type_filter
def _get_not_module_type_or_name_filter(
tp_list: List[Callable], module_name_list: List[str]
) -> Callable[[Node], bool]:
module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]
def not_module_type_or_name_filter(n: Node) -> bool:
return not any(f(n) for f in module_type_filters + module_name_list_filters)
return not_module_type_or_name_filter
class XNNPACKQuantizer(Quantizer):
supported_config_and_operators = _get_supported_config_and_operators()
STATIC_QAT_ONLY_OPS = [
"conv_bn_relu",
"conv_bn",
]
# static quantization ops (both PTQ and QAT)
STATIC_OPS = [
"linear",
"conv_relu",
"conv",
"adaptive_avg_pool2d",
# TODO: move this to BoltNNQuantizer?
"gru_io_only",
"max_pool2d",
"add_relu",
"add",
"mul_relu",
"mul",
"cat",
]
DYNAMIC_OPS = [
"linear",
]
def __init__(self):
super().__init__()
self.global_config: Optional[QuantizationConfig] = None
self.operator_type_config: Dict[
torch._ops.OpOverloadPacket, Optional[QuantizationConfig]
] = {}
self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}
@classmethod
def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
op_configs: Set[QuantizationConfig] = set({})
for spec, _ in cls.supported_config_and_operators:
op_configs.add(spec)
return list(op_configs)
@classmethod
def get_supported_operator_for_quantization_config(
cls, quantization_config: Optional[QuantizationConfig]
) -> List[OperatorPatternType]:
if quantization_config is None:
all_ops = []
for _, ops in cls.supported_config_and_operators:
all_ops.extend(ops)
return all_ops
for config, ops in cls.supported_config_and_operators:
# note: this assumes each entry in cls.supported_spec_and_operators
# corresponds to one spec, e.g. we don't have
# [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
# where the first and second entry have the same spec but did not
# merge the op list
if config == quantization_config:
return ops
return []
def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer:
self.global_config = quantization_config
return self
def set_operator_type(
self,
operator_type: torch._ops.OpOverloadPacket,
quantization_config: QuantizationConfig,
) -> XNNPACKQuantizer:
self.operator_type_config[operator_type] = quantization_config
return self
def set_module_type(
self, module_type: Callable, quantization_config: QuantizationConfig
):
"""Set quantization_config for a submodule with type: `module_type`, for example:
quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator
patterns in the submodule with this module type with the given `quantization_config`
"""
self.module_type_config[module_type] = quantization_config
return self
def set_module_name(
self, module_name: str, quantization_config: Optional[QuantizationConfig]
):
"""Set quantization_config for a submodule with name: `module_name`, for example:
quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
patterns in the submodule with this module name with the given `quantization_config`
"""
assert (
quantization_config is not None
), " quantization_config == None is not supported yet"
self.module_name_config[module_name] = quantization_config
return self
def transform_for_annotation(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
"""Transforms scalar values to tensor attributes"""
return _convert_scalars_to_attrs(model)
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""just handling global spec for now"""
# hacked for handling dynamic linear quant. will fix later.
if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr]
model = self._annotate_for_dynamic_quantization_config(model)
else:
model = self._annotate_for_static_quantization_config(model)
propagate_annotation(model)
return model
def _annotate_all_static_patterns(
self,
model: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> torch.fx.GraphModule:
# TODO: implement the support for None to be canceling out previous annotations
if quantization_config is None:
return model
if quantization_config.is_qat:
for op in self.STATIC_QAT_ONLY_OPS:
OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
for op in self.STATIC_OPS:
OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
return model
def _annotate_all_dynamic_patterns(
self,
model: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> torch.fx.GraphModule:
# TODO: implement the support for None to be canceling out previous annotations
if quantization_config is None:
return model
for op in self.DYNAMIC_OPS:
OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
return model
def _annotate_for_static_quantization_config(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
module_name_list = list(self.module_name_config.keys())
for module_name, config in self.module_name_config.items():
self._annotate_all_static_patterns(
model, config, _get_module_name_filter(module_name)
)
tp_list = list(self.module_type_config.keys())
for module_type, config in self.module_type_config.items():
self._annotate_all_static_patterns(
model, config, _get_module_type_filter(module_type)
)
self._annotate_all_static_patterns(
model,
self.global_config,
_get_not_module_type_or_name_filter(tp_list, module_name_list),
)
return model
def _annotate_for_dynamic_quantization_config(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
module_name_list = list(self.module_name_config.keys())
for module_name, config in self.module_name_config.items():
self._annotate_all_dynamic_patterns(
model, config, _get_module_name_filter(module_name)
)
tp_list = list(self.module_type_config.keys())
for module_type, config in self.module_type_config.items():
self._annotate_all_dynamic_patterns(
model, config, _get_module_type_filter(module_type)
)
self._annotate_all_dynamic_patterns(
model,
self.global_config,
_get_not_module_type_or_name_filter(tp_list, module_name_list),
)
return model
def validate(self, model: torch.fx.GraphModule) -> None:
pass
@classmethod
def get_supported_operators(cls) -> List[OperatorConfig]:
return cls.supported_config_and_operators