[quant][pt2] Handle constant conv args in prepare QAT fusion (#100525)
Summary: Previously, we would only match and replace conv + BN patterns whose conv args (stride, padding, dilation, etc.) were the default constant values. If the user set any of these args to a non-default value, we would simply not fuse the pattern. This is due to a limitation in the subgraph rewriter: see https://github.com/pytorch/pytorch/issues/100419.

This commit works around the above limitation by first configuring the subgraph rewriter to ignore literals when matching, and then manually copying the constant args over to the new subgraph after `replace_pattern`.

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_prepare_qat_conv_bn_fusion_constant_args

Reviewers: jerryzh168, kimishpatel

Differential Revision: [D45515437](https://our.internmc.facebook.com/intern/diff/D45515437)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100525
Approved by: https://github.com/jerryzh168
This commit is contained in:
parent
3f734c584e
commit
4434b9af6a
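For orientation before the diff, here is a minimal self-contained sketch of the same workaround on a toy pattern, using the public torch.fx subgraph rewriter rather than the internal `_replace_pattern` helper this PR modifies. The toy model, the clamp ops, and the literal values are illustrative, not from the PR, and this assumes a PyTorch version where `replace_pattern_with_filters` exposes `ignore_literals` and returns `ReplacedPatterns` records:

import torch
from torch.fx import symbolic_trace
from torch.fx.subgraph_rewriter import replace_pattern_with_filters

class M(torch.nn.Module):
    def forward(self, x):
        # Non-default literals (0.0, 6.0) that a literal-exact match would miss
        return torch.clamp(x + 1, 0.0, 6.0)

def pattern(x):
    # Different literals than M uses; only matches with ignore_literals=True
    return torch.clamp(x + 1, -1.0, 1.0)

def replacement(x):
    # Placeholder literals; the matched ones are copied over manually below
    return torch.clamp(x + 2, -1.0, 1.0)

gm = symbolic_trace(M())
matches = replace_pattern_with_filters(
    gm, pattern, replacement, match_filters=None, ignore_literals=True
)

for mr in matches:
    # Literal args survive on the matched original nodes even after replacement
    original_clamp = next(
        n for n in mr.nodes_map.values() if n.target is torch.clamp
    )
    new_clamp = next(n for n in mr.replacements if n.target is torch.clamp)
    # Keep the replacement's tensor input, restore the original literal args
    new_clamp.args = new_clamp.args[:1] + original_clamp.args[1:]

gm.recompile()
print(gm(torch.randn(4)))  # computes torch.clamp(x + 2, 0.0, 6.0)

The manual copy in the loop is exactly the step this PR adds to the QAT fusion pass below, specialized to conv's constant args.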
@@ -2,7 +2,7 @@
 import copy
 import operator
 import unittest
-from typing import Any, List, Tuple
+from typing import Any, List, Optional, Tuple

 import torch
 import torch._dynamo as torchdynamo
@@ -535,6 +535,30 @@ class TestQuantizePT2E(QuantizationTestCase):
         self._verify_symmetric_qnnpack_qat_graph(M(), example_inputs, is_per_channel=False, has_relu=False)
         self._verify_symmetric_qnnpack_qat_graph(M(), example_inputs, is_per_channel=True, has_relu=False)

+    def test_prepare_qat_conv_bn_fusion_constant_args(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 3, 3, stride=(2, 2), padding=(4, 4))
+                self.bn = torch.nn.BatchNorm2d(3)
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.bn(x)
+                return x
+
+        example_inputs = (torch.randn(1, 3, 5, 5),)
+        # stride, padding, dilation, transposed, output_padding, groups
+        conv_args = ((2, 2), (4, 4), (1, 1), False, (0, 0), 1)
+        self._verify_symmetric_qnnpack_qat_graph(
+            M(), example_inputs, is_per_channel=False, has_relu=False, expected_conv_constant_args=conv_args
+        )
+        self._verify_symmetric_qnnpack_qat_graph(
+            M(), example_inputs, is_per_channel=True, has_relu=False, expected_conv_constant_args=conv_args
+        )
+        self._verify_symmetric_qnnpack_qat_numerics(M(), example_inputs, is_per_channel=False)
+        self._verify_symmetric_qnnpack_qat_numerics(M(), example_inputs, is_per_channel=True)
+
     def test_prepare_qat_conv_bn_relu_fusion(self):
         class M(torch.nn.Module):
             def __init__(self):
@@ -559,6 +583,7 @@ class TestQuantizePT2E(QuantizationTestCase):
         example_inputs: Tuple[Any, ...],
         is_per_channel: bool,
         has_relu: bool,
+        expected_conv_constant_args: Optional[Tuple[Any, ...]] = None,
     ):
         """
         Verify that the graph module matches the fused QAT [conv - bn (- relu)] pattern
@@ -610,6 +635,12 @@ class TestQuantizePT2E(QuantizationTestCase):
         self.assertEqual(conv_node.target, torch.ops.aten.convolution.default)
         self.assertEqual(scale_factor_reshape_node.target, torch.ops.aten.view.default)

+        # Verify: conv constant args
+        if expected_conv_constant_args is not None:
+            assert len(expected_conv_constant_args) == 6, "wrong num conv args, bad test setup"
+            for i in range(6):
+                self.assertEqual(conv_node.args[i + 3], expected_conv_constant_args[i])
+
         # Verify: conv input activation fake quantize
         conv_input_fq_node = conv_node.args[0]
         conv_input_node = conv_input_fq_node.args[0]
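The `args[i + 3]` indexing above relies on the positional schema of `torch.ops.aten.convolution.default`: its first three arguments are the input, weight, and bias tensors, and the six constant args follow. The ordering can be confirmed against the operator schema directly (a quick REPL check, not part of the PR):

import torch

# The printed schema lists input, weight, bias first, followed by stride,
# padding, dilation, transposed, output_padding, groups -- the six constant
# args the test verifies.
print(torch.ops.aten.convolution.default._schema)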
@@ -657,6 +688,7 @@ class TestQuantizePT2E(QuantizationTestCase):
         self.assertTrue("tensor_constant" in bn_running_var_node.target)
         self.assertEqual(eps, 1e-5)

+    # TODO: merge these numerics tests with the graph tests above
     def test_prepare_qat_conv_bn_numerics(self):
         class M(torch.nn.Module):
             def __init__(self):
The remaining hunks are in the QAT fusion pass that defines `_fuse_conv_bn_qat`:
@@ -204,12 +204,22 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
     match_pattern = _get_aten_graph_module(_conv2d_bn_pattern, example_inputs)
     replacement_pattern = _get_aten_graph_module(_qat_conv2d_bn_pattern, example_inputs)
     # TODO: use the public replace_pattern API once it also returns replacement nodes
-    match_and_replacement = _replace_pattern(m, match_pattern, replacement_pattern)
+    match_and_replacement = _replace_pattern(m, match_pattern, replacement_pattern, ignore_literals=True)
     m.recompile()

-    # Copy over metadata from original subgraph
-    # This ensures the stack traces and annotations are preserved in the new subgraph
-    # TODO: handle this in replace_pattern
+    # Due to limited functionality in the subgraph rewriter, here we manually
+    # update the replacement graph as follows:
+    #
+    #   (1) Copy over metadata from original subgraph. This ensures the stack traces
+    #       and annotations are preserved in the new subgraph
+    #
+    #   (2) Copy over constant args for conv from the original subgraph
+    #       TODO: do this for constant args for batchnorm as well
+    #
+    # In the future, we should try to push as much of this functionality into the
+    # subgraph rewriter as possible, so we don't have to manually copy anything over.
+    # For more detail, see https://github.com/pytorch/pytorch/issues/100419.
+
     for mr in match_and_replacement:
         # Find replacement conv and bn nodes by climbing upwards from anchor node
         assert len(mr.replacements) == 1, "expected only one replacement node"
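Each `mr` in this loop is a `ReplacedPatterns` record returned by the rewriter. For orientation, the fields used here and in the next hunk look roughly like this (paraphrased from `torch.fx.subgraph_rewriter`; comments are mine):

from dataclasses import dataclass
from typing import Dict, List

from torch.fx.node import Node

@dataclass
class ReplacedPatterns:
    anchor: Node                 # node in the original graph from which the match was found
    nodes_map: Dict[Node, Node]  # pattern-graph node -> matched original-graph node
    replacements: List[Node]     # nodes newly added into the graph by the replacement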
@@ -227,9 +237,14 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
             n = n.args[0]

         # Copy over metadata for all three nodes in [conv - bn - getitem]
+        # Also copy over constant args for conv
         for match_pattern_node, original_node in mr.nodes_map.items():
             if original_node.target == torch.ops.aten.convolution.default:
                 replacement_conv_node.meta = original_node.meta
+                # Note: Unlike other tensor args like conv weights and biases, literal args are
+                # preserved in the original nodes after replacement, so we can access them here
+                # x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
+                replacement_conv_node.args = replacement_conv_node.args[:3] + original_node.args[3:]
             if original_node.target == torch.ops.aten._native_batch_norm_legit.default:
                 replacement_bn_node.meta = original_node.meta
             if original_node.target == operator.getitem: