[quant][pt2][be] Remove add/relu from conv-bn QAT pattern (#113006)
Summary: This commit significantly simplifies the QAT fusion code for the `conv-bn` pattern by removing add and relu nodes from the match and replacement patterns. This does not reduce functionality; patterns like `conv-bn-relu`, `conv-bn-add`, and `conv-bn-add-relu` are still supported. We simply no longer match these extra nodes, since there is no need to replace them. This has the additional benefit of reducing the number of patterns being matched by 16x, since each add and relu variant of the `conv-bn` pattern also has an in-place variant. It also enables more flexible `conv-bn` pattern matching in the future and keeps the number of patterns scalable.

One important change in this commit is the removal of the match filter that required the input and output activations to be quantized. This was necessary because otherwise we would always expect q-dq nodes immediately after the getitem node, rather than after the add or relu nodes, for example. A side benefit is that QAT fusion stays flexible enough to support weight-only quantization.

Test Plan:
python test/test_quantization.py TestQuantizePT2EQAT

Reviewers: jerryzh168, kimishpatel

Subscribers: jerryzh168, kimishpatel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113006
Approved by: https://github.com/jerryzh168
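Note on the 16x figure: the old helper enumerated seven boolean flags per pattern (is_per_channel, has_relu, has_bias, relu_is_inplace, has_add, add_is_inplace, bias_is_quantized), while the new helper enumerates only three, dropping the four add/relu flags. A minimal sketch of that count (ignoring the duplicate-pattern filtering the helper applies afterwards):

    import itertools

    # Old sweep: seven flags, including the four add/relu flags removed by this change
    old_flags = [
        "is_per_channel", "has_relu", "has_bias", "relu_is_inplace",
        "has_add", "add_is_inplace", "bias_is_quantized",
    ]
    # New sweep: only the three flags that still affect the conv-bn pattern
    new_flags = ["is_per_channel", "has_bias", "bias_is_quantized"]

    old_count = len(list(itertools.product([True, False], repeat=len(old_flags))))  # 2**7 = 128
    new_count = len(list(itertools.product([True, False], repeat=len(new_flags))))  # 2**3 = 8
    assert old_count // new_count == 16  # the 16x reduction described above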
This commit is contained in:
parent a7b75f586a
commit 14eb92cb43
@@ -1,7 +1,7 @@
 import dataclasses
 import itertools
 import operator
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Tuple
 
 import torch
 from torch.fx import Graph, GraphModule, Node
@@ -52,7 +52,6 @@ def _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(
     is_per_channel: bool,
     has_bias: bool,
     is_cuda: bool,
-    has_add: bool = False,
 ) -> Dict[str, Any]:
     """
     Optional example inputs for both `_quantized_qat_conv2d_bn_pattern`
@@ -67,9 +66,6 @@ def _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(
         kwargs["zero_point"] = torch.tensor([0], dtype=torch.int)
     if has_bias:
         kwargs["conv_bias"] = torch.randn(1)
-    if has_add:
-        # extra_input_for_add, use same shape as x here since conv weight is torch.randn(1, 1, 1, 1)
-        kwargs["extra_input"] = torch.randn(1, 1, 3, 3)
     if is_cuda:
         for k, v in kwargs.items():
             if isinstance(v, torch.Tensor):
@@ -147,53 +143,6 @@ def _qat_conv2d_bn_pattern_no_conv_bias(
     x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
     return x
 
-def _get_input_output_quantized_filter():
-    def _input_output_quantized_filter(
-        match: "InternalMatch",  # type: ignore[name-defined]
-        original_graph: Graph,
-        pattern_graph: Graph,
-    ) -> bool:
-        """
-        Make sure that the matched pattern's input is coming from dq node
-        and the output is from q node. This is used to filter out the nodes for
-        conv-bn pattern.
-        We need to replace qat's conv-bn pattern with just conv-bn nodes.
-        QAT's conv-bn pattern has q-dq node inserted after convert step.
-        In order to replace QAT pattern, see _get_quantized_qat_conv2d_bn_pattern,
-        with a simpler pattern, see _get_folded_quantized_qat_conv2d_bn_pattern,
-        we need to port the quantization parameters from q/dq nodes of weight.
-        This porting becomes easier if there is only one q/dq node because we dont have to
-        reason about about finding the right q/dq node from original graph.
-        In order to facilitate that matched pattern and replacement pattern cannot have q for
-        input activation and dq for output of the fusion. Thus those nodes are removed from
-        pattern to be matched, however we still want to make sure that input activation of
-        the pattern is actually quantized and output is dequantized. Hence this filter.
-        """
-        input_dq_node = None
-        output_q_node = None
-        for pattern_node, original_node in match.nodes_map.items():
-            if pattern_node.op == "placeholder":
-                if (
-                    original_node.target
-                    == torch.ops.quantized_decomposed.dequantize_per_tensor.default
-                ):
-                    input_dq_node = original_node
-            # output node is not a separate node in the list of nodes seen in the match
-            # it is a node in the node.users list of the last node.
-            if (
-                len(pattern_node.users) == 1
-                and next(iter(pattern_node.users.keys())).op == "output"
-            ):
-                output_node = next(iter(original_node.users.keys()))
-                if (
-                    output_node.target
-                    == torch.ops.quantized_decomposed.quantize_per_tensor.default
-                ):
-                    output_q_node = original_node
-        return (input_dq_node is not None) and (output_q_node is not None)
-
-    return _input_output_quantized_filter
-
 def _append_qdq(x, is_per_channel, kwargs):
     """
     Helper function to append q-dq ops after `x`, using dummy values for the qparams
@@ -221,11 +170,7 @@ def _append_qdq(x, is_per_channel, kwargs):
 
 def _get_quantized_qat_conv2d_bn_pattern(
     is_per_channel: bool,
-    has_relu: bool,
     has_bias: bool,
-    relu_is_inplace: bool,
-    has_add: bool,
-    add_is_inplace: bool,
     bias_is_quantized: bool,
 ) -> Callable:
     """
@@ -266,28 +211,12 @@ def _get_quantized_qat_conv2d_bn_pattern(
         if has_bias:
             x = x + kwargs["conv_bias"].reshape(bias_shape)
         x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
-
-        if has_add:
-            if add_is_inplace:
-                x += kwargs["extra_input"]
-            else:
-                x = x + kwargs["extra_input"]
-
-        if has_relu:
-            if relu_is_inplace:
-                x = F.relu_(x)
-            else:
-                x = F.relu(x)
         return x
     return _quantized_qat_conv2d_bn_pattern
 
 def _get_folded_quantized_qat_conv2d_bn_pattern(
     is_per_channel: bool,
-    has_relu: bool,
     has_bias: bool,
-    relu_is_inplace: bool,
-    has_add: bool,
-    add_is_inplace: bool,
     bias_is_quantized: bool,
 ) -> Callable:
     """
@@ -314,18 +243,6 @@ def _get_folded_quantized_qat_conv2d_bn_pattern(
             bias = None
         x = F.conv2d(x, conv_weight, bias)
         x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
-
-        if has_add:
-            if add_is_inplace:
-                x += kwargs["extra_input"]
-            else:
-                x = x + kwargs["extra_input"]
-
-        if has_relu:
-            if relu_is_inplace:
-                x = F.relu_(x)
-            else:
-                x = F.relu(x)
         return x
     return _folded_quantized_qat_conv2d_bn_pattern
 
@@ -382,15 +299,14 @@ def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Nod
     The following names may exist in the map:
 
         "conv_weight_q", "conv_weight_dq", "conv_bias",
-        "conv_bias_q", "conv_bias_dq", "add", "relu"
+        "conv_bias_q", "conv_bias_dq"
     """
-    def _get_nodes(nodes: List[Node]) -> Tuple[Node, Node, Node, Optional[Node], Optional[Node]]:
+    def _get_nodes(nodes: List[Node]) -> Tuple[Node, Node, Node]:
         """
-        Return a 5-tuple of (conv_node, bn_node, getitem_node, add_node, relu_node).
-        This asserts that the match contains exactly one conv, bn, and getitem,
-        and at most one add and one relu.
+        Return a 3-tuple of (conv_node, bn_node, getitem_node).
+        This asserts that the match contains exactly one of each node.
         """
-        conv_node, bn_node, getitem_node, add_node, relu_node = None, None, None, None, None
+        conv_node, bn_node, getitem_node = None, None, None
         for n in nodes:
             if n.op != "call_function":
                 continue
@@ -403,21 +319,10 @@ def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Nod
             if n.target == operator.getitem:
                 assert getitem_node is None
                 getitem_node = n
-            if n.target == torch.ops.aten.relu.default:
-                assert relu_node is None
-                relu_node = n
-            if (n.target in [torch.ops.aten.add_.Tensor, torch.ops.aten.add.Tensor]) and (
-                (isinstance(n.args[0], torch.fx.Node) and n.args[0].target == operator.getitem)
-                or (isinstance(n.args[1], torch.fx.Node) and n.args[1].target == operator.getitem)
-            ):
-                # One of Add's input should be BN's getitem node
-                assert add_node is None
-                add_node = n
-
         assert conv_node is not None
         assert bn_node is not None
         assert getitem_node is not None
-        return (conv_node, bn_node, getitem_node, add_node, relu_node)
+        return (conv_node, bn_node, getitem_node)
 
     def _get_q_dq_nodes(n: Node) -> Tuple[Node, Node, Node]:
         """
@@ -432,8 +337,8 @@ def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Nod
         return (orig_node, q_node, n)
 
     original_nodes = list(_filter_nodes_map(r.nodes_map).values())
-    o_conv, o_bn, o_getitem, o_add, o_relu = _get_nodes(original_nodes)
-    r_conv, r_bn, r_getitem, r_add, r_relu = _get_nodes(r.replacements)
+    o_conv, o_bn, o_getitem = _get_nodes(original_nodes)
+    r_conv, r_bn, r_getitem = _get_nodes(r.replacements)
 
     # Create the mapping from original node to replacement node
     mapping = {
@@ -441,17 +346,11 @@ def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Nod
         "bn": (o_bn, r_bn),
         "getitem": (o_getitem, r_getitem),
     }
-    if o_relu is not None:
-        assert r_relu is not None
-        mapping["relu"] = (o_relu, r_relu)
-    if o_add is not None:
-        assert r_add is not None
-        mapping["add"] = (o_add, r_add)
 
     # Extract conv input and weight
     # Note: here we extract the original nodes indirectly through the pattern nodes
     # because the args of the original nodes are no longer available after replacement
-    (p_conv, _, _, _, _) = _get_nodes(list(r.nodes_map.keys()))
+    (p_conv, _, _) = _get_nodes(list(r.nodes_map.keys()))
     (p_conv_input, p_conv_weight, *_) = p_conv.args
     (r_conv_input, r_conv_weight, *_) = r_conv.args
     assert isinstance(p_conv_input, Node)
@@ -681,7 +580,7 @@ def _fuse_conv_bn_qat_helper(m: GraphModule, is_cuda: bool) -> GraphModule:
     all_original_to_replacement_nodes = {}
     for r in replacements_with_conv_bias + replacements_no_conv_bias:
         for original_node, replacement_node in _get_conv_bn_pattern_nodes(r).values():
-            # Step (3a): Copy over metadata for all nodes in [conv - bn - getitem (- add) (- relu)]
+            # Step (3a): Copy over metadata for all nodes in [conv - bn - getitem]
             replacement_node.meta = original_node.meta
             if original_node.target == torch.ops.aten.conv2d.default:
                 # Step (3b): Copy over conv literal args
@@ -789,34 +688,20 @@ def _fold_conv_bn_qat_helper(m: GraphModule, is_cuda: bool) -> GraphModule:
     replacements = []
     replacement_options = itertools.product(
         [True, False],  # is_per_channel
-        [True, False],  # has_relu
         [True, False],  # has_bias
-        [True, False],  # relu_is_inplace
-        [True, False],  # has_add
-        [True, False],  # add_is_inplace
         [True, False],  # bias_is_quantized
     )
-    for (
-        is_per_channel, has_relu, has_bias, relu_is_inplace,
-        has_add, add_is_inplace, bias_is_quantized,
-    ) in replacement_options:
-        # For the cases without relu, `relu_is_inplace` is irrelevant, so here we arbitrarily
+    for is_per_channel, has_bias, bias_is_quantized in replacement_options:
+        # For the cases without bias, `bias_is_quantized` is irrelevant, so here we arbitrarily
         # filter out one of the values for this flag to avoid having duplicate patterns
-        # Same for `add_is_inplace` and `bias_is_quantized`
-        if (
-            (not has_relu and relu_is_inplace) or
-            (not has_add and add_is_inplace) or
-            (not has_bias and bias_is_quantized)
-        ):
+        if not has_bias and bias_is_quantized:
             continue
         example_inputs = _quantized_conv2d_bn_pattern_example_inputs
-        kwargs = _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(is_per_channel, has_bias, is_cuda, has_add)
-        match_pattern = _get_quantized_qat_conv2d_bn_pattern(
-            is_per_channel, has_relu, has_bias, relu_is_inplace, has_add, add_is_inplace, bias_is_quantized
-        )
+        kwargs = _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(is_per_channel, has_bias, is_cuda)
+        match_pattern = _get_quantized_qat_conv2d_bn_pattern(is_per_channel, has_bias, bias_is_quantized)
         match_pattern = get_aten_graph_module(match_pattern, example_inputs, is_cuda, **kwargs)
         replacement_pattern = _get_folded_quantized_qat_conv2d_bn_pattern(
-            is_per_channel, has_relu, has_bias, relu_is_inplace, has_add, add_is_inplace, bias_is_quantized
+            is_per_channel, has_bias, bias_is_quantized,
         )
         replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, is_cuda, **kwargs)
         replacements.extend(
@@ -824,7 +709,6 @@ def _fold_conv_bn_qat_helper(m: GraphModule, is_cuda: bool) -> GraphModule:
                 m,
                 match_pattern,
                 replacement_pattern,
-                match_filters=[_get_input_output_quantized_filter()],
                 ignore_literals=True,
             )
         )