[quant][pt2] Handle constant conv args in prepare QAT fusion (#100525)

Summary: Previously, we would only match and replace conv + BN
patterns whose conv used the default constant args (stride, padding,
dilation, etc.). If the user set any of these args to a non-default
value, we would simply not fuse the pattern. This is due to a
limitation in the subgraph rewriter:
see https://github.com/pytorch/pytorch/issues/100419.
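To illustrate the rewriter behavior in isolation, here is a hypothetical standalone sketch (not code from this PR): the pattern bakes in the literal 1, so a graph that uses a different literal never matches, analogous to a conv traced with a non-default stride:

    import torch
    import torch.fx as fx
    from torch.fx import subgraph_rewriter

    def pattern(x):
        return torch.add(x, 1)  # literal baked into the pattern, like conv's default stride

    def replacement(x):
        return torch.sub(x, 1)

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.add(x, 2)  # different literal, so the pattern is not matched

    traced = fx.symbolic_trace(M())
    matches = subgraph_rewriter.replace_pattern(traced, pattern, replacement)
    print(matches)  # expected: [] -- the literals differ, so nothing is replaced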

This commit works around the above limitation by first
configuring the subgraph rewriter to ignore literals when
matching, and then manually copying the constant args over to
the new subgraph after `replace_pattern`.
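As a rough sketch of that two-step workaround, continuing the toy example above (and assuming the public replace_pattern_with_filters exposes ignore_literals, as in recent PyTorch releases; this commit itself goes through the internal _replace_pattern helper):

    from torch.fx.subgraph_rewriter import replace_pattern_with_filters

    traced = fx.symbolic_trace(M())
    # Step 1: match and replace even though the literal args differ
    matches = replace_pattern_with_filters(
        traced, pattern, replacement, ignore_literals=True
    )
    # Step 2: the spliced-in replacement still carries the pattern's literal (1),
    # so copy the original literal (2) from the matched node onto the new node.
    # This mirrors what the commit does for conv's stride/padding/dilation/etc.
    for m in matches:
        for _pattern_node, original_node in m.nodes_map.items():
            if original_node.op == "call_function" and original_node.target is torch.add:
                new_node = m.replacements[0]  # the torch.sub node added by the rewriter
                new_node.args = (new_node.args[0],) + original_node.args[1:]
    traced.recompile()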

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_prepare_qat_conv_bn_fusion_constant_args

Reviewers: jerryzh168, kimishpatel

Differential Revision: [D45515437](https://our.internmc.facebook.com/intern/diff/D45515437)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100525
Approved by: https://github.com/jerryzh168
andrewor14 2023-05-11 10:45:46 -07:00 committed by PyTorch MergeBot
parent 3f734c584e
commit 4434b9af6a
2 changed files with 52 additions and 5 deletions

@@ -2,7 +2,7 @@
 import copy
 import operator
 import unittest
-from typing import Any, List, Tuple
+from typing import Any, List, Optional, Tuple
 
 import torch
 import torch._dynamo as torchdynamo
@@ -535,6 +535,30 @@ class TestQuantizePT2E(QuantizationTestCase):
         self._verify_symmetric_qnnpack_qat_graph(M(), example_inputs, is_per_channel=False, has_relu=False)
         self._verify_symmetric_qnnpack_qat_graph(M(), example_inputs, is_per_channel=True, has_relu=False)
 
+    def test_prepare_qat_conv_bn_fusion_constant_args(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 3, 3, stride=(2, 2), padding=(4, 4))
+                self.bn = torch.nn.BatchNorm2d(3)
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.bn(x)
+                return x
+
+        example_inputs = (torch.randn(1, 3, 5, 5),)
+        # stride, padding, dilation, transposed, output_padding, groups
+        conv_args = ((2, 2), (4, 4), (1, 1), False, (0, 0), 1)
+        self._verify_symmetric_qnnpack_qat_graph(
+            M(), example_inputs, is_per_channel=False, has_relu=False, expected_conv_constant_args=conv_args
+        )
+        self._verify_symmetric_qnnpack_qat_graph(
+            M(), example_inputs, is_per_channel=True, has_relu=False, expected_conv_constant_args=conv_args
+        )
+        self._verify_symmetric_qnnpack_qat_numerics(M(), example_inputs, is_per_channel=False)
+        self._verify_symmetric_qnnpack_qat_numerics(M(), example_inputs, is_per_channel=True)
+
     def test_prepare_qat_conv_bn_relu_fusion(self):
         class M(torch.nn.Module):
             def __init__(self):
@@ -559,6 +583,7 @@ class TestQuantizePT2E(QuantizationTestCase):
         example_inputs: Tuple[Any, ...],
         is_per_channel: bool,
         has_relu: bool,
+        expected_conv_constant_args: Optional[Tuple[Any, ...]] = None,
     ):
         """
         Verify that the graph module matches the fused QAT [conv - bn (- relu)] pattern
@@ -610,6 +635,12 @@ class TestQuantizePT2E(QuantizationTestCase):
         self.assertEqual(conv_node.target, torch.ops.aten.convolution.default)
         self.assertEqual(scale_factor_reshape_node.target, torch.ops.aten.view.default)
 
+        # Verify: conv constant args
+        if expected_conv_constant_args is not None:
+            assert len(expected_conv_constant_args) == 6, "wrong num conv args, bad test setup"
+            for i in range(6):
+                self.assertEqual(conv_node.args[i + 3], expected_conv_constant_args[i])
+
         # Verify: conv input activation fake quantize
         conv_input_fq_node = conv_node.args[0]
         conv_input_node = conv_input_fq_node.args[0]
@@ -657,6 +688,7 @@ class TestQuantizePT2E(QuantizationTestCase):
         self.assertTrue("tensor_constant" in bn_running_var_node.target)
         self.assertEqual(eps, 1e-5)
 
+    # TODO: merge these numerics tests with the graph tests above
     def test_prepare_qat_conv_bn_numerics(self):
         class M(torch.nn.Module):
             def __init__(self):

@@ -204,12 +204,22 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
     match_pattern = _get_aten_graph_module(_conv2d_bn_pattern, example_inputs)
     replacement_pattern = _get_aten_graph_module(_qat_conv2d_bn_pattern, example_inputs)
     # TODO: use the public replace_pattern API once it also returns replacement nodes
-    match_and_replacement = _replace_pattern(m, match_pattern, replacement_pattern)
+    match_and_replacement = _replace_pattern(m, match_pattern, replacement_pattern, ignore_literals=True)
     m.recompile()
 
-    # Copy over metadata from original subgraph
-    # This ensures the stack traces and annotations are preserved in the new subgraph
-    # TODO: handle this in replace_pattern
+    # Due to limited functionality in the subgraph rewriter, here we manually
+    # update the replacement graph as follows:
+    #
+    # (1) Copy over metadata from original subgraph. This ensures the stack traces
+    #     and annotations are preserved in the new subgraph
+    #
+    # (2) Copy over constant args for conv from the original subgraph
+    #     TODO: do this for constant args for batchnorm as well
+    #
+    # In the future, we should try to push as much of this functionality into the
+    # subgraph rewriter as possible, so we don't have to manually copy anything over.
+    # For more detail, see https://github.com/pytorch/pytorch/issues/100419.
     for mr in match_and_replacement:
         # Find replacement conv and bn nodes by climbing upwards from anchor node
         assert len(mr.replacements) == 1, "expected only one replacement node"
@@ -227,9 +237,14 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
             n = n.args[0]
 
         # Copy over metadata for all three nodes in [conv - bn - getitem]
+        # Also copy over constant args for conv
        for match_pattern_node, original_node in mr.nodes_map.items():
             if original_node.target == torch.ops.aten.convolution.default:
                 replacement_conv_node.meta = original_node.meta
+                # Note: Unlike other tensor args like conv weights and biases, literal args are
+                # preserved in the original nodes after replacement, so we can access them here
+                # x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
+                replacement_conv_node.args = replacement_conv_node.args[:3] + original_node.args[3:]
             if original_node.target == torch.ops.aten._native_batch_norm_legit.default:
                 replacement_bn_node.meta = original_node.meta
             if original_node.target == operator.getitem: