[quant][pt2e] Fix annotation for conv no bias case (#107971)

Summary: This fixes conv annotation for the no-bias case.
Previously, annotating a conv without bias raised an
index-out-of-bounds error, since the new aten.conv2d op may not
have the bias arg at all (unlike the old aten.convolution op).
This was not caught earlier due to missing test coverage; the
relevant test cases are added in this commit.
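
For illustration, a minimal sketch of the failure mode and of the
guard applied in the diff below (M mirrors the new test module;
_get_conv_bias is a hypothetical helper, not code from this PR):

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # With bias=False, the exported aten.conv2d node may carry
        # only (input, weight) as positional args, so an unguarded
        # conv_node.args[2] raises IndexError during annotation.
        self.conv = torch.nn.Conv2d(3, 3, 3, bias=False)

    def forward(self, x):
        return self.conv(x)

# Hypothetical helper showing the guarded access used in the fix:
def _get_conv_bias(conv_node):
    return conv_node.args[2] if len(conv_node.args) > 2 else None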

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_qat_conv_no_bias
python test/test_quantization.py TestQuantizePT2E.test_qat_conv_bn_relu_fusion_no_conv_bias

Reviewers: jerryzh168, kimishpatel

Subscribers: jerryzh168, kimishpatel

Differential Revision: [D48696874](https://our.internmc.facebook.com/intern/diff/D48696874)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107971
Approved by: https://github.com/jerryzh168
commit 240bdbea61 (parent 25d98a3e3b)
Author: andrewor14
Date: 2023-08-25 13:05:03 -07:00
Committed by: PyTorch MergeBot
2 changed files with 50 additions and 3 deletions

@@ -1830,6 +1830,34 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             qconfig_mapping,
         )
 
+    def test_qat_conv_no_bias(self):
+        class M(torch.nn.Module):
+            def __init__(self, has_relu: bool):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 3, 3, bias=False)
+                self.relu = torch.nn.ReLU() if has_relu else torch.nn.Identity()
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.relu(x)
+                return x
+
+        example_inputs = (torch.randn(1, 3, 5, 5),)
+        # simple conv
+        self._verify_symmetric_qnnpack_qat_numerics(
+            M(has_relu=False), example_inputs, is_per_channel=False, verify_convert=True,
+        )
+        self._verify_symmetric_qnnpack_qat_numerics(
+            M(has_relu=False), example_inputs, is_per_channel=True, verify_convert=True,
+        )
+        # conv + relu
+        self._verify_symmetric_qnnpack_qat_numerics(
+            M(has_relu=True), example_inputs, is_per_channel=False, verify_convert=True,
+        )
+        self._verify_symmetric_qnnpack_qat_numerics(
+            M(has_relu=True), example_inputs, is_per_channel=True, verify_convert=True,
+        )
+
     def test_prepare_qat_conv_bn_fusion(self):
         example_inputs = (torch.randn(1, 3, 5, 5),)
         m = TestHelperModules.ConvWithBNRelu(relu=False)
@@ -1931,6 +1959,25 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             m1, example_inputs, is_per_channel=True, has_relu=True
         )
 
+    def test_qat_conv_bn_relu_fusion_no_conv_bias(self):
+        m1 = TestHelperModules.ConvWithBNRelu(relu=True, bias=False)
+        example_inputs = (torch.randn(3, 3, 5, 5),)
+        self._verify_symmetric_qnnpack_qat_graph(
+            m1, example_inputs, is_per_channel=False, has_relu=True, has_bias=False,
+        )
+        m1 = TestHelperModules.ConvWithBNRelu(relu=True, bias=False)
+        self._verify_symmetric_qnnpack_qat_numerics(
+            m1, example_inputs, is_per_channel=False, verify_convert=True,
+        )
+        m1 = TestHelperModules.ConvWithBNRelu(relu=True, bias=False)
+        self._verify_symmetric_qnnpack_qat_graph(
+            m1, example_inputs, is_per_channel=True, has_relu=True, has_bias=False,
+        )
+        m1 = TestHelperModules.ConvWithBNRelu(relu=True, bias=False)
+        self._verify_symmetric_qnnpack_qat_numerics(
+            m1, example_inputs, is_per_channel=True, verify_convert=True,
+        )
+
     def test_qat_inplace_add_relu(self):
         class M(torch.nn.Module):
             def __init__(self):

@@ -246,7 +246,7 @@ def _annotate_conv2d(
         assert isinstance(weight, Node)
         input_qspec_map[weight] = get_weight_qspec(quantization_config)
-        bias = conv_node.args[2]
+        bias = conv_node.args[2] if len(conv_node.args) > 2 else None
         if isinstance(bias, Node):
             input_qspec_map[bias] = get_bias_qspec(quantization_config)
@@ -299,7 +299,7 @@ def _annotate_conv2d_relu(
         assert isinstance(weight, Node)
         input_qspec_map[weight] = get_weight_qspec(quantization_config)
-        bias = conv_node.args[2]
+        bias = conv_node.args[2] if len(conv_node.args) > 2 else None
         if isinstance(bias, Node):
             input_qspec_map[bias] = get_bias_qspec(quantization_config)
@@ -408,7 +408,7 @@ def _annotate_conv2d_bn_relu(
         assert isinstance(weight, Node)
         input_qspec_map[weight] = get_weight_qspec(quantization_config)
-        bias = conv_node.args[2]
+        bias = conv_node.args[2] if len(conv_node.args) > 2 else None
         if isinstance(bias, Node):
             input_qspec_map[bias] = get_bias_qspec(quantization_config)
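
Note on the fix: returning None when the bias arg is absent keeps
the downstream isinstance(bias, Node) check unchanged in all three
helpers; None simply fails the isinstance check and no bias qspec
is recorded, matching the old behavior when aten.convolution
carried an explicit None bias.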