[PT2][Quant] Move add/add relu pattern via module partitioner (#102397)

This diff uses module partitioners to find add and add + relu patterns: instead of matching individual aten add/relu nodes, _annotate_add and _annotate_add_relu now take the whole GraphModule and locate the patterns via get_source_partitions and find_sequential_partitions.
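As context for the diff below, here is a minimal sketch of how the two annotators now locate their patterns. It is illustrative only (find_add_and_add_relu is a made-up helper, not part of the change); the partitioner entry points are the ones imported in this diff, though their exact signatures may differ across PyTorch versions.

import itertools
import operator

import torch
from torch.ao.quantization._pt2e.graph_utils import find_sequential_partitions
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


def find_add_and_add_relu(gm: torch.fx.GraphModule):
    # add + relu: ask for consecutive partitions whose sources are torch.add
    # followed by torch.nn.ReLU; _EQUIVALENT_TYPES lets operator.add and the
    # functional relu variants match the same buckets.
    fused_partitions = find_sequential_partitions(gm, [torch.add, torch.nn.ReLU])

    # standalone add: collect every partition recorded with operator.add or
    # torch.add as its source and work with its output node(s).
    add_partitions = get_source_partitions(gm.graph, [operator.add, torch.add])
    add_partitions = list(itertools.chain(*add_partitions.values()))
    return fused_partitions, add_partitions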

Differential Revision: [D46095330](https://our.internmc.facebook.com/intern/diff/D46095330/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102397
Approved by: https://github.com/jerryzh168
Author: Kimish Patel
Date: 2023-05-27 13:58:44 -07:00
Committed by: PyTorch MergeBot
parent 3d8f405022
commit 9fac5afbcc

2 changed files with 97 additions and 92 deletions


@ -1,5 +1,6 @@
import itertools
from typing import Any, List, OrderedDict, Set
+import operator
import torch
@ -13,6 +14,7 @@ _EQUIVALENT_TYPES: List[Set] = [
{torch.nn.Conv2d, torch.nn.functional.conv2d},
{torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_},
{torch.nn.BatchNorm2d, torch.nn.functional.batch_norm},
+    {torch.add, operator.add},
]
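The new entry above is what lets the partitioner treat both spellings of addition as the same source. A hypothetical module that should now be matched either way (illustrative, not part of the diff):

import torch


class TwoAdds(torch.nn.Module):
    def forward(self, x, y):
        a = x + y            # exported with operator.add recorded as the source
        b = torch.add(a, y)  # exported with torch.add recorded as the source
        return torch.nn.functional.relu(b)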


@ -4,7 +4,8 @@ import copy
import functools
import itertools
-from typing import Callable, Dict, List, Optional, Set, Any
+import operator
+from typing import Any, Callable, Dict, List, Optional, Set
import torch
import torch._dynamo as torchdynamo
@ -13,12 +14,22 @@ import torch.nn.functional as F
from torch.ao.quantization._pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization._pt2e.quantizer.utils import (
-    get_act_qspec,
-    get_weight_qspec,
-    get_bias_qspec,
_annotate_input_qspec_map,
_annotate_output_qspec,
+    get_act_qspec,
+    get_bias_qspec,
+    get_weight_qspec,
)
+from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
+from torch.ao.quantization.observer import (
+    HistogramObserver,
+    MinMaxObserver,
+    MovingAverageMinMaxObserver,
+    MovingAveragePerChannelMinMaxObserver,
+    PerChannelMinMaxObserver,
+    PlaceholderObserver,
+)
+from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.fx import Node
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
@ -26,21 +37,11 @@ from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
from .quantizer import (
OperatorConfig,
OperatorPatternType,
+    QuantizationAnnotation,
QuantizationConfig,
QuantizationSpec,
Quantizer,
-    QuantizationAnnotation,
)
-from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
-from torch.ao.quantization.observer import (
-    HistogramObserver,
-    MinMaxObserver,
-    PerChannelMinMaxObserver,
-    MovingAverageMinMaxObserver,
-    MovingAveragePerChannelMinMaxObserver,
-    PlaceholderObserver,
-)
-from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
__all__ = [
@ -58,6 +59,7 @@ def _mark_nodes_as_annotated(nodes: List[Node]):
node.meta["quantization_annotation"] = QuantizationAnnotation()
node.meta["quantization_annotation"]._annotated = True
def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph:
gm, _ = torchdynamo.export(function, *inputs, aten_graph=True)
gm.graph.eliminate_dead_code()
@ -137,8 +139,9 @@ def get_symmetric_quantization_config(
is_per_channel: bool = False,
is_qat: bool = False,
):
-    act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \
+    act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
        FusedMovingAvgObsFakeQuantize if is_qat else HistogramObserver
+    )
act_quantization_spec = QuantizationSpec(
dtype=torch.int8,
@ -146,12 +149,16 @@ def get_symmetric_quantization_config(
quant_max=127,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
-        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            eps=2**-12
+        ),
)
qscheme = (
torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
)
-    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = MinMaxObserver
+    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+        MinMaxObserver
+    )
if is_qat:
weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
elif is_per_channel:
@ -170,13 +177,16 @@ def get_symmetric_quantization_config(
qscheme=qscheme,
ch_axis=0,
is_dynamic=False,
-        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args),
+        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
+            **extra_args
+        ),
)
-    bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PlaceholderObserver
+    bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+        PlaceholderObserver
+    )
bias_quantization_spec = QuantizationSpec(
-        dtype=torch.float,
-        observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
+        dtype=torch.float, observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
)
quantization_config = QuantizationConfig(
act_quantization_spec, weight_quantization_spec, bias_quantization_spec, is_qat
@ -187,6 +197,7 @@ def get_symmetric_quantization_config(
def get_supported_config_and_operators() -> List[OperatorConfig]:
return get_supported_symmetric_config_and_operators()
def _is_annotated(nodes: List[Node]):
"""
Given a list of nodes (that represents an operator pattern),
@ -278,11 +289,11 @@ class QNNPackQuantizer(Quantizer):
self._annotate_conv2d_relu(model, config)
self._annotate_conv2d(model, config)
self._annotate_maxpool2d(model, config)
+        self._annotate_add_relu(model, config)
+        self._annotate_add(model, config)
for node in reversed(model.graph.nodes):
# one improvement is to register node annotators for each
# supported op type.
-            self._annotate_add_relu(node, config)
-            self._annotate_add(node, config)
self._annotate_hardtanh(node, config)
self._annotate_mean(node, config)
self._annotate_adaptive_avg_pool2d(node, config)
@ -329,13 +340,12 @@ class QNNPackQuantizer(Quantizer):
input_qspec_map[bias] = get_bias_qspec(quantization_config)
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            _annotated=True
+            input_qspec_map=input_qspec_map, _annotated=True
)
bn_output_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=get_act_qspec(quantization_config), # type: ignore[arg-type]
-            _annotated=True
+            _annotated=True,
)
nodes_to_mark_annotated = list(conv_partition.nodes)
nodes_to_mark_annotated.extend(list(bn_partition.nodes))
@ -385,13 +395,12 @@ class QNNPackQuantizer(Quantizer):
input_qspec_map[bias] = get_bias_qspec(quantization_config)
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            _annotated=True
+            input_qspec_map=input_qspec_map, _annotated=True
)
relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=get_act_qspec(quantization_config), # type: ignore[arg-type]
-            _annotated=True
+            _annotated=True,
)
nodes_to_mark_annotated = list(conv_partition.nodes)
nodes_to_mark_annotated.extend(list(bn_partition.nodes))
@ -420,10 +429,10 @@ class QNNPackQuantizer(Quantizer):
or conv_node.target != torch.ops.aten.convolution.default
):
raise ValueError(f"{conv_node} is not an aten conv2d operator")
-        if (
-            relu_node.op != "call_function"
-            or relu_node.target not in [torch.ops.aten.relu.default, torch.ops.aten.relu_.default]
-        ):
+        if relu_node.op != "call_function" or relu_node.target not in [
+            torch.ops.aten.relu.default,
+            torch.ops.aten.relu_.default,
+        ]:
raise ValueError(f"{relu_node} is not an aten relu operator")
if _is_annotated([relu_node, conv_node]):
@ -443,12 +452,11 @@ class QNNPackQuantizer(Quantizer):
input_qspec_map[bias] = get_bias_qspec(quantization_config)
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            _annotated=True
+            input_qspec_map=input_qspec_map, _annotated=True
)
relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=get_act_qspec(quantization_config), # type: ignore[arg-type]
-            _annotated=True
+            _annotated=True,
)
def _annotate_conv2d(
@ -487,7 +495,7 @@ class QNNPackQuantizer(Quantizer):
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=get_act_qspec(quantization_config),
-            _annotated=True
+            _annotated=True,
)
def _annotate_linear(
@ -624,72 +632,67 @@ class QNNPackQuantizer(Quantizer):
)
def _annotate_add_relu(
-        self, node: Node, quantization_config: QuantizationConfig
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
-        if node.op != "call_function" or node.target not in [
-            torch.ops.aten.relu_.default,
-            torch.ops.aten.relu.default,
-        ]:
-            return
-        relu_node = node
-        add_node = relu_node.args[0]
-        assert isinstance(add_node, Node)
-        if add_node.op != "call_function" or add_node.target not in [
-            torch.ops.aten.add.Tensor,
-            torch.ops.aten.add_.Tensor,
-        ]:
-            return
-        if _is_annotated([relu_node, add_node]):
-            return
-        act_qspec = get_act_qspec(quantization_config)
-        input_qspec_map = {}
-        input_act0 = add_node.args[0]
-        if isinstance(input_act0, Node):
-            input_qspec_map[input_act0] = act_qspec
-        input_act1 = add_node.args[1]
-        if isinstance(input_act1, Node):
-            input_qspec_map[input_act1] = act_qspec
-        add_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            _annotated=True,
-        )
-        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            output_qspec=act_qspec,
-            _annotated=True,
-        )
+        fused_partitions = find_sequential_partitions(gm, [torch.add, torch.nn.ReLU])
+        for fused_partition in fused_partitions:
+            add_partition, relu_partition = fused_partition
+            if len(relu_partition.output_nodes) > 1:
+                raise ValueError("Relu partition has more than one output node")
+            relu_node = relu_partition.output_nodes[0]
+            if len(add_partition.output_nodes) > 1:
+                raise ValueError("add partition has more than one output node")
+            add_node = add_partition.output_nodes[0]
+            if _is_annotated([relu_node, add_node]):
+                continue
+            act_qspec = get_act_qspec(quantization_config)
+            input_qspec_map = {}
+            input_act0 = add_node.args[0]
+            if isinstance(input_act0, Node):
+                input_qspec_map[input_act0] = act_qspec
+            input_act1 = add_node.args[1]
+            if isinstance(input_act1, Node):
+                input_qspec_map[input_act1] = act_qspec
+            add_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                _annotated=True,
+            )
+            relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                output_qspec=act_qspec,
+                _annotated=True,
+            )
def _annotate_add(
-        self, node: Node, quantization_config: QuantizationConfig
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
-        add_node = node
-        if add_node.op != "call_function" or add_node.target not in [
-            torch.ops.aten.add.Tensor,
-            torch.ops.aten.add_.Tensor,
-        ]:
-            return
-        if _is_annotated([add_node]):
-            return
-        act_qspec = get_act_qspec(quantization_config)
-        input_qspec_map = {}
-        input_act0 = add_node.args[0]
-        if isinstance(input_act0, Node):
-            input_qspec_map[input_act0] = act_qspec
-        input_act1 = add_node.args[1]
-        if isinstance(input_act1, Node):
-            input_qspec_map[input_act1] = act_qspec
-        add_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            output_qspec=act_qspec,
-            _annotated=True,
-        )
+        add_partitions = get_source_partitions(gm.graph, [operator.add, torch.add])
+        add_partitions = list(itertools.chain(*add_partitions.values()))
+        for add_partition in add_partitions:
+            add_node = add_partition.output_nodes[0]
+            if _is_annotated([add_node]):
+                continue
+            act_qspec = get_act_qspec(quantization_config)
+            input_qspec_map = {}
+            input_act0 = add_node.args[0]
+            if isinstance(input_act0, Node):
+                input_qspec_map[input_act0] = act_qspec
+            input_act1 = add_node.args[1]
+            if isinstance(input_act1, Node):
+                input_qspec_map[input_act1] = act_qspec
+            add_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=act_qspec,
+                _annotated=True,
+            )
def validate(self, model: torch.fx.GraphModule) -> None:
pass
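End to end, the rewritten annotators are driven through the quantizer's annotate() call. A rough usage sketch, assuming the prototype API as of this commit (the QNNPackQuantizer.set_global helper and the _pt2e module paths are reconstructed from that era and may have since been renamed or moved):

import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer import (
    QNNPackQuantizer,
    get_symmetric_quantization_config,
)


class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.nn.functional.relu(x + y)


example_inputs = (torch.randn(1, 8), torch.randn(1, 8))
# Same export path the quantizer code of this era uses (aten-level graph).
gm, _ = torchdynamo.export(M(), *example_inputs, aten_graph=True)

quantizer = QNNPackQuantizer()
quantizer.set_global(get_symmetric_quantization_config())
quantizer.annotate(gm)

# Nodes matched by _annotate_add / _annotate_add_relu now carry a
# "quantization_annotation" entry in node.meta.
print([n for n in gm.graph.nodes if "quantization_annotation" in n.meta])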