[quant][pt2] Handle literal conv args in convert QAT (#103731)

Summary:
Similar to the prepare case, we need to manually copy
over literal conv args such as padding and stride to the new,
replaced conv nodes, since these args are not captured by the
subgraph rewriter.

Test Plan: python test/test_quantization.py TestQuantizePT2E.test_qat_conv_bn_fusion_literal_args

Reviewed By: jerryzh168

Differential Revision: D46383130

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103731
Approved by: https://github.com/jerryzh168
Author: Andrew Or
Date: 2023-06-16 17:15:37 +00:00 (committed by PyTorch MergeBot)
Parent: 08a054649c
Commit: 2bc56bec07
2 changed files with 57 additions and 33 deletions
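For context (not part of the patch): aten.convolution.default takes three tensor args followed by six literal args, i.e. (x, weight, bias, stride, padding, dilation, transposed, output_padding, groups). "Copying over literal args" therefore means keeping the replacement node's first three args and taking the remaining six from the matched original node. A minimal sketch of that slicing, using plain tuples in place of real fx node args:

    # args of aten.convolution.default:
    #   (x, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
    original_args = ("x", "w", "b", [2, 2], [4, 4], [1, 1], False, [0, 0], 1)
    replacement_args = ("x_fq", "w_fq", "b", [1, 1], [0, 0], [1, 1], False, [0, 0], 1)

    # keep the tensor args from the replacement, take the literal args from the original
    new_args = replacement_args[:3] + original_args[3:]
    assert new_args == ("x_fq", "w_fq", "b", [2, 2], [4, 4], [1, 1], False, [0, 0], 1)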


@@ -340,7 +340,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
is_per_channel: bool,
has_relu: bool,
has_bias: bool = True,
expected_conv_constant_args: Optional[Tuple[Any, ...]] = None,
expected_conv_literal_args: Optional[Tuple[Any, ...]] = None,
):
"""
Verify that the graph module matches the fused QAT [conv - bn (- relu)] pattern
@@ -400,13 +400,13 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
self.assertEqual(conv_node.target, torch.ops.aten.convolution.default)
self.assertEqual(scale_factor_reshape_node.target, torch.ops.aten.view.default)
# Verify: conv constant args
if expected_conv_constant_args is not None:
# Verify: conv literal args
if expected_conv_literal_args is not None:
assert (
len(expected_conv_constant_args) == 6
len(expected_conv_literal_args) == 6
), "wrong num conv args, bad test setup"
for i in range(6):
self.assertEqual(conv_node.args[i + 3], expected_conv_constant_args[i])
self.assertEqual(conv_node.args[i + 3], expected_conv_literal_args[i])
# Verify: conv input activation fake quantize
conv_input_fq_node = conv_node.args[0]
@@ -1364,7 +1364,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m, example_inputs, is_per_channel=True, has_relu=False
)
def test_prepare_qat_conv_bn_fusion_constant_args(self):
def test_qat_conv_bn_fusion_literal_args(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -1384,20 +1384,20 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
example_inputs,
is_per_channel=False,
has_relu=False,
expected_conv_constant_args=conv_args,
expected_conv_literal_args=conv_args,
)
self._verify_symmetric_qnnpack_qat_graph(
M(),
example_inputs,
is_per_channel=True,
has_relu=False,
expected_conv_constant_args=conv_args,
expected_conv_literal_args=conv_args,
)
self._verify_symmetric_qnnpack_qat_numerics(
M(), example_inputs, is_per_channel=False
M(), example_inputs, is_per_channel=False, verify_convert=True,
)
self._verify_symmetric_qnnpack_qat_numerics(
M(), example_inputs, is_per_channel=True
M(), example_inputs, is_per_channel=True, verify_convert=True,
)
def test_qat_conv_bn_fusion_no_conv_bias(self):

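The conv_args tuple built by the test above is not shown in this excerpt. As a hedged illustration, if M's conv were something like torch.nn.Conv2d(3, 3, 3, stride=(2, 2), padding=(4, 4)), the six literal args of the decomposed aten.convolution.default node, and hence the tuple compared against conv_node.args[3:9] by the checker, would be:

    conv_args = (
        [2, 2],   # stride
        [4, 4],   # padding
        [1, 1],   # dilation
        False,    # transposed
        [0, 0],   # output_padding
        1,        # groups
    )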

@@ -374,6 +374,43 @@ def _get_conv_bn_getitem_nodes(nodes: List[Node]) -> Tuple[Node, Node, Node]:
assert getitem_node is not None
return (conv_node, bn_node, getitem_node)
def _filter_nodes_map(nodes_map: Dict[Node, Node]) -> Dict[Node, Node]:
"""
Return a filtered `nodes_map` returned from the subgraph rewriter.
The filtered `nodes_map` will contain only nodes that are actually
matched in the pattern, excluding None or placeholder nodes.
"""
new_nodes_map: Dict[Node, Node] = {}
for pattern_node, graph_node in nodes_map.items():
# bias can be None
if graph_node is None:
continue
# skip pattern placeholder nodes
if pattern_node.op == "placeholder":
continue
new_nodes_map[pattern_node] = graph_node
return new_nodes_map
def _copy_over_literal_conv_args(original_node: Node, new_node: Node):
"""
Copy over literal args in conv, such as stride and padding, from the matched node
in the original graph to its replacement in the new graph.
This is needed due to the following limitation in the subgraph rewriter when used
with dynamo export: literal (non-tensor) args are not supported in the match and
replacement patterns. This is because dynamo export automatically inlines these
literal args, making them dead placeholder nodes. In the future, we should check
if dynamo export can optionally disable this inlining, or if subgraph rewriter
can do the copying for us. See https://github.com/pytorch/pytorch/issues/100419.
Note: Unlike other tensor args like conv weights and biases, literal args are
preserved in the original nodes after replacement, so we can access them here.
"""
assert original_node.target == torch.ops.aten.convolution.default
assert new_node.target == torch.ops.aten.convolution.default
# x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
new_node.args = new_node.args[:3] + original_node.args[3:]
def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
"""
Given a graph of decomposed aten ops, replace the (conv + bn) pattern with
@@ -430,8 +467,8 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
# (1) Copy over metadata from original subgraph. This ensures the stack traces
# and annotations are preserved in the new subgraph
#
# (2) Copy over constant args for conv from the original subgraph
# TODO: do this for constant args for batchnorm as well
# (2) Copy over literal args for conv from the original subgraph
# TODO: do this for literal args for batchnorm as well
#
# In the future, we should try to push as much of this functionality into the
# subgraph rewriter as possible, so we don't have to manually copy anything over.
@@ -442,34 +479,17 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
_get_conv_bn_getitem_nodes(r.replacements)
# Copy over metadata for all three nodes in [conv - bn - getitem]
# Also copy over constant args for conv
for match_pattern_node, original_node in r.nodes_map.items():
# bias can be None
if original_node is None:
continue
# We expect the subgraph rewriter to erase the non-literal args of the matched nodes.
# However, this is not done for placeholder nodes, since these nodes do not need to
# be replaced. Here we filter out these placeholder nodes since they do not need
# metadata copying. E.g. we want to filter out `getitem_placeholder` in this pattern:
#
# getitem_placeholder -> conv -> bn -> getitem
#
if any(isinstance(a, Node) for a in original_node.args):
continue
# Also copy over literal args for conv
for match_pattern_node, original_node in _filter_nodes_map(r.nodes_map).items():
if original_node.target == torch.ops.aten.convolution.default:
_copy_over_literal_conv_args(original_node, replacement_conv_node)
replacement_conv_node.meta = original_node.meta
# Note: Unlike other tensor args like conv weights and biases, literal args are
# preserved in the original nodes after replacement, so we can access them here
# x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
replacement_conv_node.args = replacement_conv_node.args[:3] + original_node.args[3:]
# original annotation is referring to the node object in the graph
# after rewrite we'll need to update this mapping (input_qspec_map)
# update quantization_annotation
original_input_qspec_map = original_node.meta["quantization_annotation"].input_qspec_map
if "quantization_annotation" not in original_node.meta:
continue
input_qspec_map = {}
# get the list of configs, it should be ordered as input, weight, bias
# note: this is really hacky, we need a better solution, hopefully
@@ -482,7 +502,6 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
# bias
if len(replacement_conv_node.args) > 2 and len(all_configs) > 2:
input_qspec_map[replacement_conv_node.args[2]] = all_configs[2][1]
replacement_conv_node.meta["quantization_annotation"].input_qspec_map = input_qspec_map
if original_node.target == torch.ops.aten._native_batch_norm_legit.default:
replacement_bn_node.meta = original_node.meta
@@ -575,6 +594,11 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
# fold bn weights into conv
_fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)
# Copy over literal args for conv
for _, original_node in _filter_nodes_map(r.nodes_map).items():
if original_node.target == torch.ops.aten.convolution.default:
_copy_over_literal_conv_args(original_node, conv_node)
m.graph.eliminate_dead_code()
m.recompile()
return m
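
To make the new _filter_nodes_map helper concrete, here is a small, illustrative usage sketch. It assumes _filter_nodes_map is in scope (e.g. copied from the diff above) and builds a throwaway fx graph by hand rather than using a real subgraph-rewriter match result:

    import torch
    import torch.fx as fx

    pattern_graph = fx.Graph()
    x_ph = pattern_graph.placeholder("x")
    bias_ph = pattern_graph.placeholder("bias")
    conv = pattern_graph.call_function(torch.ops.aten.convolution.default, args=(x_ph,))

    # Toy nodes_map shaped like the subgraph rewriter's output (pattern node ->
    # matched graph node); pattern and "graph" nodes are the same objects here for brevity.
    nodes_map = {x_ph: x_ph, bias_ph: None, conv: conv}

    filtered = _filter_nodes_map(nodes_map)
    # Placeholder entries and None-valued entries are dropped; only the conv match survives.
    assert list(filtered) == [conv]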