[quant][pt2][be] Remove add/relu from conv-bn QAT pattern (#113006)

Summary: This commit significantly simplifies the QAT fusion
code for the `conv-bn` pattern by removing add and relu nodes
from the match and replacement patterns. This does not reduce
functionality; patterns like `conv-bn-relu`, `conv-bn-add`,
and `conv-bn-add-relu` are still supported. We simply do not
match these extra nodes, since there is no need to
replace them.
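
Because the extra nodes are simply left in place, a trailing relu keeps
working after the conv-bn subgraph is rewritten. The following is a
minimal, self-contained sketch of that behavior using plain torch.fx;
it does not use the actual PT2E helpers touched below, and the model,
pattern, and replacement are made up for illustration:

import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace, subgraph_rewriter

class ConvBNReLU(torch.nn.Module):
    def forward(self, x, conv_weight, bn_running_mean, bn_running_var):
        x = F.conv2d(x, conv_weight)
        x = F.batch_norm(x, bn_running_mean, bn_running_var, training=True)
        # The trailing relu is deliberately NOT part of the pattern below
        return F.relu(x)

def pattern(x, conv_weight, bn_running_mean, bn_running_var):
    x = F.conv2d(x, conv_weight)
    return F.batch_norm(x, bn_running_mean, bn_running_var, training=True)

def replacement(x, conv_weight, bn_running_mean, bn_running_var):
    # Stand-in for the folded conv-bn subgraph the real pass produces
    x = F.conv2d(x, conv_weight)
    return F.batch_norm(x, bn_running_mean, bn_running_var, training=False)

gm = symbolic_trace(ConvBNReLU())
subgraph_rewriter.replace_pattern(gm, pattern, replacement)
# The rewriter repoints relu's input at the replacement's output, so the
# relu node survives untouched and never needed to be matched or replaced.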

This has the additional benefit of reducing the number of
patterns being matched by 16x, since for each add and relu
variant of the `conv-bn` pattern there is also an in-place
variant. This also enables more flexible `conv-bn` pattern
matching in the future and keeps the number of patterns
manageable.
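
As a rough sanity check of the 16x figure (a back-of-the-envelope
sketch, not code from this commit): the add/relu handling contributed
four boolean flags to the `replacement_options` product below, and
dropping four binary flags shrinks the raw number of enumerated
patterns by 2^4 = 16.

import itertools

old_flags = ["is_per_channel", "has_relu", "has_bias", "relu_is_inplace",
             "has_add", "add_is_inplace", "bias_is_quantized"]
new_flags = ["is_per_channel", "has_bias", "bias_is_quantized"]

old_count = len(list(itertools.product([False, True], repeat=len(old_flags))))  # 128
new_count = len(list(itertools.product([False, True], repeat=len(new_flags))))  # 8
assert old_count == 16 * new_count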

One important change needed in this commit was to remove
the match filter that required the input and output
activations to be quantized. This was necessary because
otherwise we would always expect q-dq nodes immediately
after the getitem node, instead of after the add or relu
nodes, for example. This has a further side benefit of
keeping QAT fusion flexible enough to support weight-only
quantization.
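
For context, match filters are plain callables handed to
`replace_pattern_with_filters`, which can veto a candidate match by
inspecting the surrounding graph. The sketch below is a hypothetical,
simplified stand-in for the removed filter; the model, pattern, and
`only_if_followed_by_relu` are made up, and only the filter signature
and the `match_filters` argument mirror the real API:

import torch
import torch.nn.functional as F
from torch.fx import Graph, symbolic_trace, subgraph_rewriter

def pattern(x, w):
    return F.conv2d(x, w)

def replacement(x, w):
    # Arbitrary stand-in replacement
    return F.conv2d(x, w) * 1.0

def only_if_followed_by_relu(match, original_graph: Graph, pattern_graph: Graph) -> bool:
    # Veto matches whose output is not consumed by a relu, analogous to how
    # the removed filter demanded a quantize node right after the output
    for pattern_node, original_node in match.nodes_map.items():
        if any(u.op == "output" for u in pattern_node.users):
            return any(u.target == F.relu for u in original_node.users)
    return False

class M(torch.nn.Module):
    def forward(self, x, w):
        return F.relu(F.conv2d(x, w))

gm = symbolic_trace(M())
subgraph_rewriter.replace_pattern_with_filters(
    gm, pattern, replacement, match_filters=[only_if_followed_by_relu]
)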

Test Plan:
python test/test_quantization.py TestQuantizePT2EQAT

Reviewers: jerryzh168, kimishpatel

Subscribers: jerryzh168, kimishpatel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113006
Approved by: https://github.com/jerryzh168
Author: andrewor14
Date: 2023-11-13 17:01:31 -08:00
Committed by: PyTorch MergeBot
Parent: a7b75f586a
Commit: 14eb92cb43


@@ -1,7 +1,7 @@
import dataclasses
import itertools
import operator
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Tuple
import torch
from torch.fx import Graph, GraphModule, Node
@@ -52,7 +52,6 @@ def _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(
is_per_channel: bool,
has_bias: bool,
is_cuda: bool,
has_add: bool = False,
) -> Dict[str, Any]:
"""
Optional example inputs for both `_quantized_qat_conv2d_bn_pattern`
@@ -67,9 +66,6 @@ def _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(
kwargs["zero_point"] = torch.tensor([0], dtype=torch.int)
if has_bias:
kwargs["conv_bias"] = torch.randn(1)
if has_add:
# extra_input_for_add, use same shape as x here since conv weight is torch.randn(1, 1, 1, 1)
kwargs["extra_input"] = torch.randn(1, 1, 3, 3)
if is_cuda:
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
@@ -147,53 +143,6 @@ def _qat_conv2d_bn_pattern_no_conv_bias(
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
return x
def _get_input_output_quantized_filter():
def _input_output_quantized_filter(
match: "InternalMatch", # type: ignore[name-defined]
original_graph: Graph,
pattern_graph: Graph,
) -> bool:
"""
Make sure that the matched pattern's input is coming from dq node
and the output is from q node. This is used to filter out the nodes for
conv-bn pattern.
We need to replace qat's conv-bn pattern with just conv-bn nodes.
QAT's conv-bn pattern has q-dq node inserted after convert step.
In order to replace QAT pattern, see _get_quantized_qat_conv2d_bn_pattern,
with a simpler pattern, see _get_folded_quantized_qat_conv2d_bn_pattern,
we need to port the quantization parameters from q/dq nodes of weight.
This porting becomes easier if there is only one q/dq node because we don't have to
reason about finding the right q/dq node from the original graph.
In order to facilitate that matched pattern and replacement pattern cannot have q for
input activation and dq for output of the fusion. Thus those nodes are removed from
pattern to be matched, however we still want to make sure that input activation of
the pattern is actually quantized and output is dequantized. Hence this filter.
"""
input_dq_node = None
output_q_node = None
for pattern_node, original_node in match.nodes_map.items():
if pattern_node.op == "placeholder":
if (
original_node.target
== torch.ops.quantized_decomposed.dequantize_per_tensor.default
):
input_dq_node = original_node
# output node is not a separate node in the list of nodes seen in the match
# it is a node in the node.users list of the last node.
if (
len(pattern_node.users) == 1
and next(iter(pattern_node.users.keys())).op == "output"
):
output_node = next(iter(original_node.users.keys()))
if (
output_node.target
== torch.ops.quantized_decomposed.quantize_per_tensor.default
):
output_q_node = original_node
return (input_dq_node is not None) and (output_q_node is not None)
return _input_output_quantized_filter
def _append_qdq(x, is_per_channel, kwargs):
"""
Helper function to append q-dq ops after `x`, using dummy values for the qparams
@@ -221,11 +170,7 @@ def _append_qdq(x, is_per_channel, kwargs):
def _get_quantized_qat_conv2d_bn_pattern(
is_per_channel: bool,
has_relu: bool,
has_bias: bool,
relu_is_inplace: bool,
has_add: bool,
add_is_inplace: bool,
bias_is_quantized: bool,
) -> Callable:
"""
@@ -266,28 +211,12 @@ def _get_quantized_qat_conv2d_bn_pattern(
if has_bias:
x = x + kwargs["conv_bias"].reshape(bias_shape)
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
if has_add:
if add_is_inplace:
x += kwargs["extra_input"]
else:
x = x + kwargs["extra_input"]
if has_relu:
if relu_is_inplace:
x = F.relu_(x)
else:
x = F.relu(x)
return x
return _quantized_qat_conv2d_bn_pattern
def _get_folded_quantized_qat_conv2d_bn_pattern(
is_per_channel: bool,
has_relu: bool,
has_bias: bool,
relu_is_inplace: bool,
has_add: bool,
add_is_inplace: bool,
bias_is_quantized: bool,
) -> Callable:
"""
@@ -314,18 +243,6 @@ def _get_folded_quantized_qat_conv2d_bn_pattern(
bias = None
x = F.conv2d(x, conv_weight, bias)
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
if has_add:
if add_is_inplace:
x += kwargs["extra_input"]
else:
x = x + kwargs["extra_input"]
if has_relu:
if relu_is_inplace:
x = F.relu_(x)
else:
x = F.relu(x)
return x
return _folded_quantized_qat_conv2d_bn_pattern
@@ -382,15 +299,14 @@ def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Nod
The following names may exist in the map:
"conv_weight_q", "conv_weight_dq", "conv_bias",
"conv_bias_q", "conv_bias_dq", "add", "relu"
"conv_bias_q", "conv_bias_dq"
"""
def _get_nodes(nodes: List[Node]) -> Tuple[Node, Node, Node, Optional[Node], Optional[Node]]:
def _get_nodes(nodes: List[Node]) -> Tuple[Node, Node, Node]:
"""
Return a 5-tuple of (conv_node, bn_node, getitem_node, add_node, relu_node).
This asserts that the match contains exactly one conv, bn, and getitem,
and at most one add and one relu.
Return a 3-tuple of (conv_node, bn_node, getitem_node).
This asserts that the match contains exactly one of each node.
"""
conv_node, bn_node, getitem_node, add_node, relu_node = None, None, None, None, None
conv_node, bn_node, getitem_node = None, None, None
for n in nodes:
if n.op != "call_function":
continue
@@ -403,21 +319,10 @@ def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Nod
if n.target == operator.getitem:
assert getitem_node is None
getitem_node = n
if n.target == torch.ops.aten.relu.default:
assert relu_node is None
relu_node = n
if (n.target in [torch.ops.aten.add_.Tensor, torch.ops.aten.add.Tensor]) and (
(isinstance(n.args[0], torch.fx.Node) and n.args[0].target == operator.getitem)
or (isinstance(n.args[1], torch.fx.Node) and n.args[1].target == operator.getitem)
):
# One of Add's input should be BN's getitem node
assert add_node is None
add_node = n
assert conv_node is not None
assert bn_node is not None
assert getitem_node is not None
return (conv_node, bn_node, getitem_node, add_node, relu_node)
return (conv_node, bn_node, getitem_node)
def _get_q_dq_nodes(n: Node) -> Tuple[Node, Node, Node]:
"""
@@ -432,8 +337,8 @@ def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Nod
return (orig_node, q_node, n)
original_nodes = list(_filter_nodes_map(r.nodes_map).values())
o_conv, o_bn, o_getitem, o_add, o_relu = _get_nodes(original_nodes)
r_conv, r_bn, r_getitem, r_add, r_relu = _get_nodes(r.replacements)
o_conv, o_bn, o_getitem = _get_nodes(original_nodes)
r_conv, r_bn, r_getitem = _get_nodes(r.replacements)
# Create the mapping from original node to replacement node
mapping = {
@@ -441,17 +346,11 @@ def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Nod
"bn": (o_bn, r_bn),
"getitem": (o_getitem, r_getitem),
}
if o_relu is not None:
assert r_relu is not None
mapping["relu"] = (o_relu, r_relu)
if o_add is not None:
assert r_add is not None
mapping["add"] = (o_add, r_add)
# Extract conv input and weight
# Note: here we extract the original nodes indirectly through the pattern nodes
# because the args of the original nodes are no longer available after replacement
(p_conv, _, _, _, _) = _get_nodes(list(r.nodes_map.keys()))
(p_conv, _, _) = _get_nodes(list(r.nodes_map.keys()))
(p_conv_input, p_conv_weight, *_) = p_conv.args
(r_conv_input, r_conv_weight, *_) = r_conv.args
assert isinstance(p_conv_input, Node)
@@ -681,7 +580,7 @@ def _fuse_conv_bn_qat_helper(m: GraphModule, is_cuda: bool) -> GraphModule:
all_original_to_replacement_nodes = {}
for r in replacements_with_conv_bias + replacements_no_conv_bias:
for original_node, replacement_node in _get_conv_bn_pattern_nodes(r).values():
# Step (3a): Copy over metadata for all nodes in [conv - bn - getitem (- add) (- relu)]
# Step (3a): Copy over metadata for all nodes in [conv - bn - getitem]
replacement_node.meta = original_node.meta
if original_node.target == torch.ops.aten.conv2d.default:
# Step (3b): Copy over conv literal args
@@ -789,34 +688,20 @@ def _fold_conv_bn_qat_helper(m: GraphModule, is_cuda: bool) -> GraphModule:
replacements = []
replacement_options = itertools.product(
[True, False], # is_per_channel
[True, False], # has_relu
[True, False], # has_bias
[True, False], # relu_is_inplace
[True, False], # has_add
[True, False], # add_is_inplace
[True, False], # bias_is_quantized
)
for (
is_per_channel, has_relu, has_bias, relu_is_inplace,
has_add, add_is_inplace, bias_is_quantized,
) in replacement_options:
# For the cases without relu, `relu_is_inplace` is irrelevant, so here we arbitrarily
for is_per_channel, has_bias, bias_is_quantized in replacement_options:
# For the cases without bias, `bias_is_quantized` is irrelevant, so here we arbitrarily
# filter out one of the values for this flag to avoid having duplicate patterns
# Same for `add_is_inplace` and `bias_is_quantized`
if (
(not has_relu and relu_is_inplace) or
(not has_add and add_is_inplace) or
(not has_bias and bias_is_quantized)
):
if not has_bias and bias_is_quantized:
continue
example_inputs = _quantized_conv2d_bn_pattern_example_inputs
kwargs = _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(is_per_channel, has_bias, is_cuda, has_add)
match_pattern = _get_quantized_qat_conv2d_bn_pattern(
is_per_channel, has_relu, has_bias, relu_is_inplace, has_add, add_is_inplace, bias_is_quantized
)
kwargs = _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(is_per_channel, has_bias, is_cuda)
match_pattern = _get_quantized_qat_conv2d_bn_pattern(is_per_channel, has_bias, bias_is_quantized)
match_pattern = get_aten_graph_module(match_pattern, example_inputs, is_cuda, **kwargs)
replacement_pattern = _get_folded_quantized_qat_conv2d_bn_pattern(
is_per_channel, has_relu, has_bias, relu_is_inplace, has_add, add_is_inplace, bias_is_quantized
is_per_channel, has_bias, bias_is_quantized,
)
replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, is_cuda, **kwargs)
replacements.extend(
@@ -824,7 +709,6 @@ def _fold_conv_bn_qat_helper(m: GraphModule, is_cuda: bool) -> GraphModule:
m,
match_pattern,
replacement_pattern,
match_filters=[_get_input_output_quantized_filter()],
ignore_literals=True,
)
)