pytorch/torch/ao/quantization/quantizer/xnnpack_quantizer.py
Jerry Zhang 3a77f9aaaf [quant][api] Move torch.ao.quantization.pt2e.quantizer to torch.ao.quantization.quantizer (#105885)
Summary: moving quantizer to torch.ao.quantization to make it a public api, since pt2e is a folder for implementations

Test Plan:
CIs

sanity check: "buck test //executorch/backends/xnnpack/test:test_xnnpack_quantized_models -- test_resnet18"

Differential Revision: D47727838

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105885
Approved by: https://github.com/andrewor14
2023-07-26 18:20:09 +00:00

from __future__ import annotations
import copy
import functools
import itertools
import operator
from typing import Any, Callable, Dict, List, Optional, Set
import torch
import torch._dynamo as torchdynamo
import torch.nn.functional as F
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
from torch.ao.quantization.observer import (
HistogramObserver,
MinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
PlaceholderObserver,
)
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer.utils import (
_annotate_input_qspec_map,
_annotate_output_qspec,
_is_sym_size_node,
_node_only_used_for_sym_size,
get_bias_qspec,
get_input_act_qspec,
get_output_act_qspec,
get_weight_qspec,
)
from torch.fx import Node
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
from .quantizer import (
OperatorConfig,
OperatorPatternType,
QuantizationAnnotation,
QuantizationConfig,
QuantizationSpec,
QuantizationSpecBase,
Quantizer,
SharedQuantizationSpec,
)
__all__ = [
"XNNPACKQuantizer",
"get_symmetric_quantization_config",
]
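# Minimal end-to-end usage sketch (assumes the pt2e entry points
# prepare_pt2e/convert_pt2e from torch.ao.quantization.quantize_pt2e;
# exact names and import paths may differ across versions):
#
#   import torch._dynamo as torchdynamo
#   from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#
#   exported_model, _ = torchdynamo.export(model, *example_inputs, aten_graph=True)
#   quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
#   prepared_model = prepare_pt2e(exported_model, quantizer)
#   # ... run representative (calibration) data through prepared_model ...
#   quantized_model = convert_pt2e(prepared_model)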
def _mark_nodes_as_annotated(nodes: List[Node]) -> None:
for node in nodes:
if node is not None:
if "quantization_annotation" not in node.meta:
node.meta["quantization_annotation"] = QuantizationAnnotation()
node.meta["quantization_annotation"]._annotated = True
def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph:
gm, _ = torchdynamo.export(function, *inputs, aten_graph=True)
gm.graph.eliminate_dead_code()
return gm.graph
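# Capture linear call patterns (with and without bias) as ATen graphs.
# The returned graphs can serve as match patterns for FX subgraph matching;
# note that this helper is not invoked elsewhere in this file.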
def _get_linear_patterns(input_size: List[int]):
in_channels = input_size[-1]
out_channels = 8  # hard-coded, but the exact value should not matter for the pattern
weight = torch.ones((out_channels, in_channels))
bias = torch.ones((out_channels,))
act = torch.ones(input_size)
def linear_op(act, weight, bias=None):
return F.linear(act, weight, bias)
pattern_w_bias = _get_dynamo_graph(linear_op, (act, weight, bias))
pattern_wo_bias = _get_dynamo_graph(linear_op, (act, weight))
return [pattern_w_bias, pattern_wo_bias]
def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
supported_operators: Dict[str, List[OperatorPatternType]] = {
# Both conv and linear should be able to handle relu + hardtanh fusion since
# those are clamp ops
"conv2d": [
[torch.nn.Conv2d, torch.nn.ReLU],
[torch.nn.Conv2d, F.relu],
[F.conv2d, torch.nn.ReLU],
[F.conv2d, F.relu],
],
"linear": [[torch.nn.Linear], [F.linear]],
"add": [[torch.add]],
"maxpool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
"hardtanh": [[torch.nn.Hardtanh], [F.hardtanh]],
"mean": [[torch.mean]],
"adaptive_avgpool2d": [
[torch.nn.AdaptiveAvgPool2d],
[F.adaptive_avg_pool2d],
],
}
return copy.deepcopy(supported_operators)
def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
supported_config_and_operators: List[OperatorConfig] = []
for quantization_config in [
get_symmetric_quantization_config(),
get_symmetric_quantization_config(is_qat=True),
get_symmetric_quantization_config(is_per_channel=True),
get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
]:
ops = _supported_symmetric_quantized_operators()
for pattern_list in ops.values():
supported_config_and_operators.append(
OperatorConfig(quantization_config, pattern_list)
)
return copy.deepcopy(supported_config_and_operators)
@functools.lru_cache
def get_symmetric_quantization_config(
is_per_channel: bool = False,
is_qat: bool = False,
is_dynamic: bool = False,
):
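"""
Return a QuantizationConfig for symmetric int8 quantization.

Args:
is_per_channel: quantize weights per channel (per tensor otherwise).
is_qat: use fake-quantize observers suitable for quantization-aware
training instead of the PTQ observers.
is_dynamic: compute activation quantization parameters at runtime
(dynamic quantization); not yet supported together with is_qat.
"""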
if is_qat:
if is_dynamic:
raise NotImplementedError(
"dynamic quantization for qat is not yet implemented."
)
act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
else:
if is_dynamic:
act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment]
else:
act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment]
act_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
is_dynamic=is_dynamic,
observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
eps=2**-12
),
)
qscheme = (
torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
)
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
MinMaxObserver
)
if is_qat:
weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
elif is_per_channel:
weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
extra_args: Dict[str, Any] = {"eps": 2**-12}
if is_qat:
if qscheme == torch.per_tensor_symmetric:
extra_args["observer"] = MovingAverageMinMaxObserver
else:
extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item]
weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=-127,
quant_max=127,
qscheme=qscheme,
ch_axis=0,
is_dynamic=False,
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
**extra_args
),
)
bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
PlaceholderObserver
)
bias_quantization_spec = QuantizationSpec(
dtype=torch.float, observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
)
if is_dynamic:
quantization_config = QuantizationConfig(
act_quantization_spec,
None,
weight_quantization_spec,
bias_quantization_spec,
is_qat,
)
else:
quantization_config = QuantizationConfig(
act_quantization_spec,
act_quantization_spec,
weight_quantization_spec,
bias_quantization_spec,
is_qat,
)
return quantization_config
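# Thanks to @functools.lru_cache, repeated calls with the same flags return
# the same QuantizationConfig instance. Typical variants (a sketch; the
# names on the left are illustrative only):
#
#   ptq_config = get_symmetric_quantization_config()
#   qat_config = get_symmetric_quantization_config(is_qat=True)
#   per_channel_config = get_symmetric_quantization_config(is_per_channel=True)
#   dynamic_config = get_symmetric_quantization_config(is_dynamic=True)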
def _get_supported_config_and_operators() -> List[OperatorConfig]:
return _get_supported_symmetric_config_and_operators()
def _is_annotated(nodes: List[Node]) -> bool:
"""
Given a list of nodes (that represents an operator pattern),
return True if any of the nodes is annotated, otherwise return False.
"""
annotated = False
for node in nodes:
annotated = annotated or (
"quantization_annotation" in node.meta
and node.meta["quantization_annotation"]._annotated
)
return annotated
class XNNPACKQuantizer(Quantizer):
supported_config_and_operators = _get_supported_config_and_operators()
def __init__(self):
super().__init__()
self.global_config: QuantizationConfig = None # type: ignore[assignment]
self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {}
@classmethod
def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
op_configs: Set[QuantizationConfig] = set()
for spec, _ in cls.supported_config_and_operators:
op_configs.add(spec)
return list(op_configs)
@classmethod
def get_supported_operator_for_quantization_config(
cls, quantization_config: Optional[QuantizationConfig]
) -> List[OperatorPatternType]:
if quantization_config is None:
all_ops = []
for _, ops in cls.supported_config_and_operators:
all_ops.extend(ops)
return all_ops
for config, ops in cls.supported_config_and_operators:
# note: this assumes each entry in cls.supported_config_and_operators
# corresponds to a single config, e.g. we don't have
# [(config1, op_list1), (config1, op_list2), (config2, op_list3)]
# where the first two entries share the same config but their op
# lists were not merged
if config == quantization_config:
return ops
return []
def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer:
self.global_config = quantization_config
return self
def set_config_for_operator_type(
self, operator_type: str, quantization_config: QuantizationConfig
) -> XNNPACKQuantizer:
self.operator_type_config[operator_type] = quantization_config
return self
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""just handling global spec for now"""
# hacked for handling dynamic linear quant. will fix later.
if self.global_config.input_activation.is_dynamic: # type: ignore[union-attr]
model = self._annotate_for_dynamic_quantization_config(model)
else:
model = self._annotate_for_static_quantization_config(model)
return model
def _annotate_for_static_quantization_config(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
config = self.global_config
self._annotate_linear(model, config)
self._annotate_conv2d_patterns(model, config)
self._annotate_maxpool2d(model, config)
self._annotate_add_patterns(model, config)
self._annotate_hardtanh(model, config)
self._annotate_mean(model, config)
self._annotate_adaptive_avg_pool2d(model, config)
self._annotate_gru(model, config)
return model
def _annotate_for_dynamic_quantization_config(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
config = self.global_config
self._annotate_linear(model, config)
return model
def _annotate_conv2d_patterns(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
if quantization_config.is_qat:
self._annotate_conv2d_bn_relu(gm, quantization_config)
self._annotate_conv2d_bn(gm, quantization_config)
self._annotate_conv2d_relu(gm, quantization_config)
self._annotate_conv2d(gm, quantization_config)
def _annotate_conv2d_bn(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
"""
Find Conv2d + batchnorm partitions
Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
"""
fused_partitions = find_sequential_partitions(
gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d]
)
for fused_partition in fused_partitions:
conv_partition, bn_partition = fused_partition
if len(conv_partition.output_nodes) > 1:
raise ValueError("conv partition has more than one output node")
conv_node = conv_partition.output_nodes[0]
conv_node_users = list(conv_node.users.keys())
if len(conv_node_users) > 1:
raise ValueError(
"Conv node must be consumed by BN only for it to be fusable."
)
if len(bn_partition.output_nodes) > 1:
raise ValueError("BatchNorm partition has more than one output node")
bn_output_node = bn_partition.output_nodes[0]
if _is_annotated([bn_output_node, conv_node]):
continue
input_qspec_map = {}
input_act = conv_node.args[0]
assert isinstance(input_act, Node)
input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
weight = conv_node.args[1]
assert isinstance(weight, Node)
input_qspec_map[weight] = get_weight_qspec(quantization_config)
bias = conv_node.args[2]
if isinstance(bias, Node):
input_qspec_map[bias] = get_bias_qspec(quantization_config)
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map, _annotated=True
)
bn_output_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
)
nodes_to_mark_annotated = list(conv_partition.nodes)
nodes_to_mark_annotated.extend(list(bn_partition.nodes))
_mark_nodes_as_annotated(nodes_to_mark_annotated)
def _annotate_conv2d_bn_relu(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
"""
Find Conv2d + batchnorm + relu partitions
Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
"""
fused_partitions = find_sequential_partitions(
gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU]
)
for fused_partition in fused_partitions:
conv_partition, bn_partition, relu_partition = fused_partition
if len(relu_partition.output_nodes) > 1:
raise ValueError("Relu partition has more than one output node")
relu_node = relu_partition.output_nodes[0]
if len(conv_partition.output_nodes) > 1:
raise ValueError("conv partition has more than one output node")
conv_node = conv_partition.output_nodes[0]
conv_node_users = list(conv_node.users.keys())
if len(conv_node_users) > 1:
raise ValueError(
"Conv node must be consumed by BN only for it to be fusable."
)
if len(bn_partition.output_nodes) > 1:
raise ValueError("BatchNorm partition has more than one output node")
bn_output_node = bn_partition.output_nodes[0]
if _is_annotated([relu_node, bn_output_node, conv_node]):
continue
input_qspec_map = {}
input_act = conv_node.args[0]
assert isinstance(input_act, Node)
input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
weight = conv_node.args[1]
assert isinstance(weight, Node)
input_qspec_map[weight] = get_weight_qspec(quantization_config)
bias = conv_node.args[2]
if isinstance(bias, Node):
input_qspec_map[bias] = get_bias_qspec(quantization_config)
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map, _annotated=True
)
relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
)
nodes_to_mark_annotated = list(conv_partition.nodes)
nodes_to_mark_annotated.extend(list(bn_partition.nodes))
nodes_to_mark_annotated.extend(list(relu_partition.nodes))
_mark_nodes_as_annotated(nodes_to_mark_annotated)
def _annotate_conv2d_relu(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
fused_partitions = find_sequential_partitions(
gm, [torch.nn.Conv2d, torch.nn.ReLU]
)
for fused_partition in fused_partitions:
conv_partition, relu_partition = fused_partition
if len(relu_partition.output_nodes) > 1:
raise ValueError("Relu partition has more than one output node")
relu_node = relu_partition.output_nodes[0]
if len(conv_partition.output_nodes) > 1:
raise ValueError("conv partition has more than one output node")
conv_node = conv_partition.output_nodes[0]
if not isinstance(conv_node, Node):
raise ValueError(f"{conv_node} is not a Node")
if (
conv_node.op != "call_function"
or conv_node.target != torch.ops.aten.convolution.default
):
raise ValueError(f"{conv_node} is not an aten conv2d operator")
if relu_node.op != "call_function" or relu_node.target not in [
torch.ops.aten.relu.default,
torch.ops.aten.relu_.default,
]:
raise ValueError(f"{relu_node} is not an aten relu operator")
if _is_annotated([relu_node, conv_node]):
continue
input_qspec_map = {}
input_act = conv_node.args[0]
assert isinstance(input_act, Node)
input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
weight = conv_node.args[1]
assert isinstance(weight, Node)
input_qspec_map[weight] = get_weight_qspec(quantization_config)
bias = conv_node.args[2]
if isinstance(bias, Node):
input_qspec_map[bias] = get_bias_qspec(quantization_config)
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map, _annotated=True
)
relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
)
def _annotate_conv2d(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
conv_partitions = get_source_partitions(
gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d]
)
conv_partitions = list(itertools.chain(*conv_partitions.values()))
for conv_partition in conv_partitions:
if len(conv_partition.output_nodes) > 1:
raise ValueError("conv partition has more than one output node")
conv_node = conv_partition.output_nodes[0]
if (
conv_node.op != "call_function"
or conv_node.target != torch.ops.aten.convolution.default
):
raise ValueError(f"{conv_node} is not an aten conv2d operator")
# skip annotation if it is already annotated
if _is_annotated([conv_node]):
continue
input_qspec_map = {}
input_act = conv_node.args[0]
assert isinstance(input_act, Node)
input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
weight = conv_node.args[1]
assert isinstance(weight, Node)
input_qspec_map[weight] = get_weight_qspec(quantization_config)
bias = conv_node.args[2]
if isinstance(bias, Node):
input_qspec_map[bias] = get_bias_qspec(quantization_config)
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=get_output_act_qspec(quantization_config),
_annotated=True,
)
def _annotate_linear(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
module_partitions = get_source_partitions(
gm.graph, [torch.nn.Linear, torch.nn.functional.linear]
)
input_act_qspec = get_input_act_qspec(quantization_config)
output_act_qspec = get_output_act_qspec(quantization_config)
weight_qspec = get_weight_qspec(quantization_config)
bias_qspec = get_bias_qspec(quantization_config)
for partitions in module_partitions.values():
for p in partitions:
act_nodes = [
n
for n in p.input_nodes
if not _node_only_used_for_sym_size(n, p.nodes)
]
if len(act_nodes) > 1:
raise ValueError(
f"Multiple activation nodes found for partition {p} {act_nodes}"
)
if len(act_nodes) == 0:
raise ValueError(f"No activation node found for partition {p}")
act_node = act_nodes[0]
output_node = p.output_nodes[0]
weight_node = None
bias_node = None
for node in p.params:
weight_or_bias = getattr(gm, node.target) # type: ignore[arg-type]
if weight_or_bias.ndim == 2: # type: ignore[attr-defined]
weight_node = node
if weight_or_bias.ndim == 1: # type: ignore[attr-defined]
bias_node = node
if weight_node is None:
raise ValueError("No weight found in Linear pattern")
# find the use of the act node within the matched pattern
# When tracing with dynamic shapes, we end up with sym_size nodes.
# These nodes do not need quantization, so skip them.
# We could also have the quant workflow raise an exception when
# sym_size nodes are annotated.
# This is not specific to linear, so in future diffs we should
# streamline this.
act_node_users = [n for n in act_node.users if not _is_sym_size_node(n)]
act_use_node_in_p = set(act_node_users).intersection(set(p.nodes))
if len(act_use_node_in_p) != 1:
raise ValueError(
f"Could not find a valid use of act node. All uses {act_use_node_in_p}"
)
act_use_node = act_use_node_in_p.pop()
if not _is_annotated([act_use_node]):  # type: ignore[list-item]
_annotate_input_qspec_map(
act_use_node,
act_node,
input_act_qspec,
)
if bias_node and not _is_annotated([bias_node]):
_annotate_output_qspec(bias_node, bias_qspec)
if not _is_annotated([weight_node]):  # type: ignore[list-item]
_annotate_output_qspec(weight_node, weight_qspec)
if not _is_annotated([output_node]):
_annotate_output_qspec(output_node, output_act_qspec)
nodes_to_mark_annotated = list(p.nodes)
_mark_nodes_as_annotated(nodes_to_mark_annotated)
# TODO: move this to BoltNNQuantizer?
def _annotate_gru(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU])
gru_partitions = list(itertools.chain(*gru_partitions.values()))
for gru_partition in gru_partitions:
output_nodes = gru_partition.output_nodes
input_nodes = gru_partition.input_nodes
# skip annotation if it is already annotated
if _is_annotated(input_nodes + output_nodes):
continue
# inside each GRU partition, we should be able to annotate each linear
# subgraph
input_qspec_map: Dict[Node, QuantizationSpecBase] = {}
input_act = input_nodes[0]
input_act_user = list(input_act.users.keys())[0]
assert isinstance(input_act, Node)
assert isinstance(input_act_user, Node)
input_act_user.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
input_act: get_input_act_qspec(quantization_config),
},
_annotated=True,
)
hidden_state = input_nodes[1]
hidden_state_user = list(hidden_state.users.keys())[0]
assert isinstance(hidden_state, Node)
assert isinstance(hidden_state_user, Node)
hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
hidden_state: get_input_act_qspec(quantization_config),
},
_annotated=True,
)
assert len(output_nodes) == 2, "expecting GRU to have two outputs"
for output in output_nodes:
output.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=get_output_act_qspec(quantization_config),
_annotated=True,
)
nodes_to_mark_annotated = list(gru_partition.nodes)
_mark_nodes_as_annotated(nodes_to_mark_annotated)
def _annotate_maxpool2d(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
module_partitions = get_source_partitions(
gm.graph, [torch.nn.MaxPool2d, torch.nn.functional.max_pool2d]
)
maxpool_partitions = list(itertools.chain(*module_partitions.values()))
for maxpool_partition in maxpool_partitions:
output_node = maxpool_partition.output_nodes[0]
maxpool_node = None
for n in maxpool_partition.nodes:
if n.target == torch.ops.aten.max_pool2d_with_indices.default:
maxpool_node = n
if _is_annotated([output_node, maxpool_node]): # type: ignore[list-item]
continue
input_act = maxpool_node.args[0] # type: ignore[union-attr]
assert isinstance(input_act, Node)
# only annotate maxpool when the output of the input node is annotated
if (
"quantization_annotation" not in input_act.meta
or not input_act.meta["quantization_annotation"]._annotated
or input_act.meta["quantization_annotation"].output_qspec is None
):
continue
# the input and output of maxpool share the quantization parameters of maxpool's input
act_qspec = SharedQuantizationSpec(input_act)
maxpool_node.meta["quantization_annotation"] = QuantizationAnnotation( # type: ignore[union-attr]
input_qspec_map={
input_act: act_qspec,
},
_annotated=True,
)
output_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=act_qspec,
_annotated=True,
)
def _annotate_input_out_obs_sharing_op(
self,
op: Callable,
gm: torch.fx.GraphModule,
quantization_config: QuantizationConfig,
) -> None:
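"""
Annotate ops such as hardtanh, mean and adaptive_avg_pool2d whose input
and output should share one set of quantization parameters, by attaching
a SharedQuantizationSpec that points back at the (already annotated)
input node.
"""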
module_partitions = get_source_partitions(
gm.graph,
[op],
)
partitions = list(itertools.chain(*module_partitions.values()))
for partition in partitions:
io_obs_sharing_node = partition.output_nodes[0]
if _is_annotated([io_obs_sharing_node]):
continue
input_act = io_obs_sharing_node.args[0]
assert isinstance(input_act, Node)
# only annotate an input/output observer-sharing op when the
# output of its input node is already annotated
if (
"quantization_annotation" not in input_act.meta
or not input_act.meta["quantization_annotation"]._annotated
or input_act.meta["quantization_annotation"].output_qspec is None
):
continue
act_qspec = SharedQuantizationSpec(input_act)
io_obs_sharing_node.meta[
"quantization_annotation"
] = QuantizationAnnotation(
input_qspec_map={
input_act: act_qspec,
},
output_qspec=act_qspec,
_annotated=True,
)
def _annotate_hardtanh(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
self._annotate_input_out_obs_sharing_op(
torch.nn.modules.Hardtanh, gm, quantization_config
)
self._annotate_input_out_obs_sharing_op(
torch.nn.modules.ReLU6, gm, quantization_config
)
def _annotate_mean(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
self._annotate_input_out_obs_sharing_op(torch.mean, gm, quantization_config)
def _annotate_adaptive_avg_pool2d(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
self._annotate_input_out_obs_sharing_op(
torch.nn.AdaptiveAvgPool2d, gm, quantization_config
)
def _annotate_add_patterns(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
self._annotate_add_relu(gm, quantization_config)
self._annotate_add(gm, quantization_config)
def _annotate_add_relu(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
fused_partitions = find_sequential_partitions(gm, [torch.add, torch.nn.ReLU])
for fused_partition in fused_partitions:
add_partition, relu_partition = fused_partition
if len(relu_partition.output_nodes) > 1:
raise ValueError("Relu partition has more than one output node")
relu_node = relu_partition.output_nodes[0]
if len(add_partition.output_nodes) > 1:
raise ValueError("add partition has more than one output node")
add_node = add_partition.output_nodes[0]
if _is_annotated([relu_node, add_node]):
continue
input_act_qspec = get_input_act_qspec(quantization_config)
output_act_qspec = get_output_act_qspec(quantization_config)
input_qspec_map = {}
input_act0 = add_node.args[0]
if isinstance(input_act0, Node):
input_qspec_map[input_act0] = input_act_qspec
input_act1 = add_node.args[1]
if isinstance(input_act1, Node):
input_qspec_map[input_act1] = input_act_qspec
add_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
_annotated=True,
)
relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=output_act_qspec,
_annotated=True,
)
def _annotate_add(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
add_partitions = get_source_partitions(gm.graph, [operator.add, torch.add])
add_partitions = list(itertools.chain(*add_partitions.values()))
for add_partition in add_partitions:
add_node = add_partition.output_nodes[0]
if _is_annotated([add_node]):
continue
input_act_qspec = get_input_act_qspec(quantization_config)
output_act_qspec = get_output_act_qspec(quantization_config)
input_qspec_map = {}
input_act0 = add_node.args[0]
if isinstance(input_act0, Node):
input_qspec_map[input_act0] = input_act_qspec
input_act1 = add_node.args[1]
if isinstance(input_act1, Node):
input_qspec_map[input_act1] = input_act_qspec
add_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=output_act_qspec,
_annotated=True,
)
def validate(self, model: torch.fx.GraphModule) -> None:
pass
@classmethod
def get_supported_operators(cls) -> List[OperatorConfig]:
return cls.supported_config_and_operators