[PT2][Quant] Move add/add relu pattern via module partitioner (#102397)
This diff uses module partitioners to find add and add + relu patterns.

Differential Revision: [D46095330](https://our.internmc.facebook.com/intern/diff/D46095330/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102397
Approved by: https://github.com/jerryzh168
This commit is contained in:
parent 3d8f405022
commit 9fac5afbcc
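For background, a minimal sketch (not part of the commit) of the two partitioner APIs this change builds on. The module paths and call signatures are taken from the diff below; the toy module M, example_inputs, and the export call are illustrative assumptions, not part of the change:

# Sketch: locating add and add + relu patterns with the partitioner APIs
# this PR switches to. Module paths follow the _pt2e layout at this commit;
# they are internal and may move in later releases.
import itertools
import operator

import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization._pt2e.graph_utils import find_sequential_partitions
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class M(torch.nn.Module):  # toy module, for illustration only
    def forward(self, x, y):
        return torch.nn.functional.relu(x + y)


example_inputs = (torch.randn(1, 3), torch.randn(1, 3))
gm, _ = torchdynamo.export(M(), *example_inputs, aten_graph=True)

# add followed by relu: yields (add_partition, relu_partition) pairs.
# Matching `x + y` against torch.add relies on the {torch.add, operator.add}
# equivalence set this commit adds to _EQUIVALENT_TYPES.
fused_partitions = find_sequential_partitions(gm, [torch.add, torch.nn.ReLU])

# Standalone adds: a dict keyed by source type, flattened the same way the
# new _annotate_add does.
add_partitions = get_source_partitions(gm.graph, [operator.add, torch.add])
add_partitions = list(itertools.chain(*add_partitions.values()))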
torch/ao/quantization/_pt2e/graph_utils.py

@@ -1,5 +1,6 @@
 import itertools
 from typing import Any, List, OrderedDict, Set
+import operator
 
 import torch
@@ -13,6 +14,7 @@ _EQUIVALENT_TYPES: List[Set] = [
     {torch.nn.Conv2d, torch.nn.functional.conv2d},
     {torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_},
     {torch.nn.BatchNorm2d, torch.nn.functional.batch_norm},
+    {torch.add, operator.add},
 ]
torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py

@@ -4,7 +4,8 @@ import copy
 import functools
+import itertools
-from typing import Callable, Dict, List, Optional, Set, Any
+import operator
+from typing import Any, Callable, Dict, List, Optional, Set
 
 import torch
 import torch._dynamo as torchdynamo
@@ -13,12 +14,22 @@ import torch.nn.functional as F
+from torch.ao.quantization._pt2e.graph_utils import find_sequential_partitions
 from torch.ao.quantization._pt2e.quantizer.utils import (
-    get_act_qspec,
-    get_weight_qspec,
-    get_bias_qspec,
     _annotate_input_qspec_map,
     _annotate_output_qspec,
+    get_act_qspec,
+    get_bias_qspec,
+    get_weight_qspec,
 )
+from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
+from torch.ao.quantization.observer import (
+    HistogramObserver,
+    MinMaxObserver,
+    MovingAverageMinMaxObserver,
+    MovingAveragePerChannelMinMaxObserver,
+    PerChannelMinMaxObserver,
+    PlaceholderObserver,
+)
+from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
 
 from torch.fx import Node
+from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
@@ -26,21 +37,11 @@ from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
 from .quantizer import (
     OperatorConfig,
     OperatorPatternType,
+    QuantizationAnnotation,
     QuantizationConfig,
     QuantizationSpec,
     Quantizer,
-    QuantizationAnnotation,
 )
-from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
-from torch.ao.quantization.observer import (
-    HistogramObserver,
-    MinMaxObserver,
-    PerChannelMinMaxObserver,
-    MovingAverageMinMaxObserver,
-    MovingAveragePerChannelMinMaxObserver,
-    PlaceholderObserver,
-)
-from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
 
 
 __all__ = [
@@ -58,6 +59,7 @@ def _mark_nodes_as_annotated(nodes: List[Node]):
                 node.meta["quantization_annotation"] = QuantizationAnnotation()
             node.meta["quantization_annotation"]._annotated = True
 
 
 def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph:
     gm, _ = torchdynamo.export(function, *inputs, aten_graph=True)
     gm.graph.eliminate_dead_code()
@@ -137,8 +139,9 @@ def get_symmetric_quantization_config(
     is_per_channel: bool = False,
     is_qat: bool = False,
 ):
-    act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \
+    act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
         FusedMovingAvgObsFakeQuantize if is_qat else HistogramObserver
+    )
 
     act_quantization_spec = QuantizationSpec(
         dtype=torch.int8,
@@ -146,12 +149,16 @@ def get_symmetric_quantization_config(
         quant_max=127,
         qscheme=torch.per_tensor_affine,
         is_dynamic=False,
-        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            eps=2**-12
+        ),
     )
     qscheme = (
         torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
     )
-    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = MinMaxObserver
+    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+        MinMaxObserver
+    )
     if is_qat:
         weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
     elif is_per_channel:
@@ -170,13 +177,16 @@ def get_symmetric_quantization_config(
         qscheme=qscheme,
         ch_axis=0,
         is_dynamic=False,
-        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args),
+        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
+            **extra_args
+        ),
     )
 
-    bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PlaceholderObserver
+    bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+        PlaceholderObserver
+    )
     bias_quantization_spec = QuantizationSpec(
-        dtype=torch.float,
-        observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
+        dtype=torch.float, observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
     )
     quantization_config = QuantizationConfig(
         act_quantization_spec, weight_quantization_spec, bias_quantization_spec, is_qat
@@ -187,6 +197,7 @@ def get_symmetric_quantization_config(
 def get_supported_config_and_operators() -> List[OperatorConfig]:
     return get_supported_symmetric_config_and_operators()
 
 
 def _is_annotated(nodes: List[Node]):
     """
     Given a list of nodes (that represents an operator pattern),
@@ -278,11 +289,11 @@ class QNNPackQuantizer(Quantizer):
         self._annotate_conv2d_relu(model, config)
         self._annotate_conv2d(model, config)
         self._annotate_maxpool2d(model, config)
+        self._annotate_add_relu(model, config)
+        self._annotate_add(model, config)
         for node in reversed(model.graph.nodes):
             # one improvement is to register node annotators for each
             # supported op type.
-            self._annotate_add_relu(node, config)
-            self._annotate_add(node, config)
             self._annotate_hardtanh(node, config)
             self._annotate_mean(node, config)
             self._annotate_adaptive_avg_pool2d(node, config)
@@ -329,13 +340,12 @@ class QNNPackQuantizer(Quantizer):
                 input_qspec_map[bias] = get_bias_qspec(quantization_config)
 
             conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
-                input_qspec_map=input_qspec_map,
-                _annotated=True
+                input_qspec_map=input_qspec_map, _annotated=True
             )
 
             bn_output_node.meta["quantization_annotation"] = QuantizationAnnotation(
                 output_qspec=get_act_qspec(quantization_config),  # type: ignore[arg-type]
-                _annotated=True
+                _annotated=True,
             )
             nodes_to_mark_annotated = list(conv_partition.nodes)
             nodes_to_mark_annotated.extend(list(bn_partition.nodes))
@@ -385,13 +395,12 @@ class QNNPackQuantizer(Quantizer):
                 input_qspec_map[bias] = get_bias_qspec(quantization_config)
 
             conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
-                input_qspec_map=input_qspec_map,
-                _annotated=True
+                input_qspec_map=input_qspec_map, _annotated=True
            )
 
            relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
                 output_qspec=get_act_qspec(quantization_config),  # type: ignore[arg-type]
-                _annotated=True
+                _annotated=True,
             )
             nodes_to_mark_annotated = list(conv_partition.nodes)
             nodes_to_mark_annotated.extend(list(bn_partition.nodes))
@@ -420,10 +429,10 @@ class QNNPackQuantizer(Quantizer):
             or conv_node.target != torch.ops.aten.convolution.default
         ):
             raise ValueError(f"{conv_node} is not an aten conv2d operator")
-        if (
-            relu_node.op != "call_function"
-            or relu_node.target not in [torch.ops.aten.relu.default, torch.ops.aten.relu_.default]
-        ):
+        if relu_node.op != "call_function" or relu_node.target not in [
+            torch.ops.aten.relu.default,
+            torch.ops.aten.relu_.default,
+        ]:
             raise ValueError(f"{relu_node} is not an aten relu operator")
 
         if _is_annotated([relu_node, conv_node]):
@@ -443,12 +452,11 @@ class QNNPackQuantizer(Quantizer):
             input_qspec_map[bias] = get_bias_qspec(quantization_config)
 
         conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            _annotated=True
+            input_qspec_map=input_qspec_map, _annotated=True
         )
         relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
             output_qspec=get_act_qspec(quantization_config),  # type: ignore[arg-type]
-            _annotated=True
+            _annotated=True,
         )
 
     def _annotate_conv2d(
@@ -487,7 +495,7 @@ class QNNPackQuantizer(Quantizer):
         conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
             input_qspec_map=input_qspec_map,
             output_qspec=get_act_qspec(quantization_config),
-            _annotated=True
+            _annotated=True,
         )
 
     def _annotate_linear(
@@ -624,72 +632,67 @@ class QNNPackQuantizer(Quantizer):
         )
 
     def _annotate_add_relu(
-        self, node: Node, quantization_config: QuantizationConfig
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
     ) -> None:
-        if node.op != "call_function" or node.target not in [
-            torch.ops.aten.relu_.default,
-            torch.ops.aten.relu.default,
-        ]:
-            return
-        relu_node = node
-        add_node = relu_node.args[0]
-        assert isinstance(add_node, Node)
-        if add_node.op != "call_function" or add_node.target not in [
-            torch.ops.aten.add.Tensor,
-            torch.ops.aten.add_.Tensor,
-        ]:
-            return
-        if _is_annotated([relu_node, add_node]):
-            return
+        fused_partitions = find_sequential_partitions(gm, [torch.add, torch.nn.ReLU])
+        for fused_partition in fused_partitions:
+            add_partition, relu_partition = fused_partition
+            if len(relu_partition.output_nodes) > 1:
+                raise ValueError("Relu partition has more than one output node")
+            relu_node = relu_partition.output_nodes[0]
+            if len(add_partition.output_nodes) > 1:
+                raise ValueError("add partition has more than one output node")
+            add_node = add_partition.output_nodes[0]
 
-        act_qspec = get_act_qspec(quantization_config)
+            if _is_annotated([relu_node, add_node]):
+                continue
 
-        input_qspec_map = {}
-        input_act0 = add_node.args[0]
-        if isinstance(input_act0, Node):
-            input_qspec_map[input_act0] = act_qspec
+            act_qspec = get_act_qspec(quantization_config)
 
-        input_act1 = add_node.args[1]
-        if isinstance(input_act1, Node):
-            input_qspec_map[input_act1] = act_qspec
+            input_qspec_map = {}
+            input_act0 = add_node.args[0]
+            if isinstance(input_act0, Node):
+                input_qspec_map[input_act0] = act_qspec
 
-        add_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            _annotated=True,
-        )
-        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            output_qspec=act_qspec,
-            _annotated=True,
-        )
+            input_act1 = add_node.args[1]
+            if isinstance(input_act1, Node):
+                input_qspec_map[input_act1] = act_qspec
+
+            add_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                _annotated=True,
+            )
+            relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                output_qspec=act_qspec,
+                _annotated=True,
+            )
 
     def _annotate_add(
-        self, node: Node, quantization_config: QuantizationConfig
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
     ) -> None:
-        add_node = node
-        if add_node.op != "call_function" or add_node.target not in [
-            torch.ops.aten.add.Tensor,
-            torch.ops.aten.add_.Tensor,
-        ]:
-            return
-        if _is_annotated([add_node]):
-            return
+        add_partitions = get_source_partitions(gm.graph, [operator.add, torch.add])
+        add_partitions = list(itertools.chain(*add_partitions.values()))
+        for add_partition in add_partitions:
+            add_node = add_partition.output_nodes[0]
+            if _is_annotated([add_node]):
+                continue
 
-        act_qspec = get_act_qspec(quantization_config)
+            act_qspec = get_act_qspec(quantization_config)
 
-        input_qspec_map = {}
-        input_act0 = add_node.args[0]
-        if isinstance(input_act0, Node):
-            input_qspec_map[input_act0] = act_qspec
+            input_qspec_map = {}
+            input_act0 = add_node.args[0]
+            if isinstance(input_act0, Node):
+                input_qspec_map[input_act0] = act_qspec
 
-        input_act1 = add_node.args[1]
-        if isinstance(input_act1, Node):
-            input_qspec_map[input_act1] = act_qspec
+            input_act1 = add_node.args[1]
+            if isinstance(input_act1, Node):
+                input_qspec_map[input_act1] = act_qspec
 
-        add_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            output_qspec=act_qspec,
-            _annotated=True,
-        )
+            add_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=act_qspec,
+                _annotated=True,
+            )
 
     def validate(self, model: torch.fx.GraphModule) -> None:
         pass
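As a usage note, a hedged end-to-end sketch (not part of the commit) of how the rewritten annotators would be exercised. The import path matches this file's location at this commit; set_global and annotate are the quantizer's configuration and entry points in the _pt2e flow at the time, and the toy module M is an illustrative assumption:

# Sketch: driving the partition-based annotators end to end. The import
# path and the set_global/annotate calls are assumed from the _pt2e
# quantizer API at this commit; M and example_inputs are illustrative.
import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer import (
    QNNPackQuantizer,
    get_symmetric_quantization_config,
)


class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.nn.functional.relu(x + y)


example_inputs = (torch.randn(1, 3), torch.randn(1, 3))
gm, _ = torchdynamo.export(M(), *example_inputs, aten_graph=True)

quantizer = QNNPackQuantizer().set_global(get_symmetric_quantization_config())
quantizer.annotate(gm)

# Each matched add / relu output node now carries the QuantizationAnnotation
# attached by the new _annotate_add_relu / _annotate_add above.
for node in gm.graph.nodes:
    if "quantization_annotation" in node.meta:
        print(node.name, node.meta["quantization_annotation"])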