mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[quant][pt2e] Fix annotation for conv no bias case (#107971)
Summary: This fixes the no bias case for conv annotations. Previously this would result in an index out of bounds, since the new aten.conv2d op may not have the bias arg (unlike the old aten.convolution op). This was not caught because of a lack of test cases, which are added in this commit. Test Plan: python test/test_quantization.py TestQuantizePT2E.test_qat_conv_no_bias python test/test_quantization.py TestQuantizePT2E.test_qat_conv_bn_relu_fusion_no_conv_bias Reviewers: jerryzh168, kimishpatel Subscribers: jerryzh168, kimishpatel Differential Revision: [D48696874](https://our.internmc.facebook.com/intern/diff/D48696874) Pull Request resolved: https://github.com/pytorch/pytorch/pull/107971 Approved by: https://github.com/jerryzh168
This commit is contained in:
parent
25d98a3e3b
commit
240bdbea61
|
|
@ -1830,6 +1830,34 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
|||
qconfig_mapping,
|
||||
)
|
||||
|
||||
def test_qat_conv_no_bias(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, has_relu: bool):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 3, 3, bias=False)
|
||||
self.relu = torch.nn.ReLU() if has_relu else torch.nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
example_inputs = (torch.randn(1, 3, 5, 5),)
|
||||
# simple conv
|
||||
self._verify_symmetric_qnnpack_qat_numerics(
|
||||
M(has_relu=False), example_inputs, is_per_channel=False, verify_convert=True,
|
||||
)
|
||||
self._verify_symmetric_qnnpack_qat_numerics(
|
||||
M(has_relu=False), example_inputs, is_per_channel=True, verify_convert=True,
|
||||
)
|
||||
# conv + relu
|
||||
self._verify_symmetric_qnnpack_qat_numerics(
|
||||
M(has_relu=True), example_inputs, is_per_channel=False, verify_convert=True,
|
||||
)
|
||||
self._verify_symmetric_qnnpack_qat_numerics(
|
||||
M(has_relu=True), example_inputs, is_per_channel=True, verify_convert=True,
|
||||
)
|
||||
|
||||
def test_prepare_qat_conv_bn_fusion(self):
|
||||
example_inputs = (torch.randn(1, 3, 5, 5),)
|
||||
m = TestHelperModules.ConvWithBNRelu(relu=False)
|
||||
|
|
@ -1931,6 +1959,25 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
|||
m1, example_inputs, is_per_channel=True, has_relu=True
|
||||
)
|
||||
|
||||
def test_qat_conv_bn_relu_fusion_no_conv_bias(self):
|
||||
m1 = TestHelperModules.ConvWithBNRelu(relu=True, bias=False)
|
||||
example_inputs = (torch.randn(3, 3, 5, 5),)
|
||||
self._verify_symmetric_qnnpack_qat_graph(
|
||||
m1, example_inputs, is_per_channel=False, has_relu=True, has_bias=False,
|
||||
)
|
||||
m1 = TestHelperModules.ConvWithBNRelu(relu=True, bias=False)
|
||||
self._verify_symmetric_qnnpack_qat_numerics(
|
||||
m1, example_inputs, is_per_channel=False, verify_convert=True,
|
||||
)
|
||||
m1 = TestHelperModules.ConvWithBNRelu(relu=True, bias=False)
|
||||
self._verify_symmetric_qnnpack_qat_graph(
|
||||
m1, example_inputs, is_per_channel=True, has_relu=True, has_bias=False,
|
||||
)
|
||||
m1 = TestHelperModules.ConvWithBNRelu(relu=True, bias=False)
|
||||
self._verify_symmetric_qnnpack_qat_numerics(
|
||||
m1, example_inputs, is_per_channel=True, verify_convert=True,
|
||||
)
|
||||
|
||||
def test_qat_inplace_add_relu(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -246,7 +246,7 @@ def _annotate_conv2d(
|
|||
assert isinstance(weight, Node)
|
||||
input_qspec_map[weight] = get_weight_qspec(quantization_config)
|
||||
|
||||
bias = conv_node.args[2]
|
||||
bias = conv_node.args[2] if len(conv_node.args) > 2 else None
|
||||
if isinstance(bias, Node):
|
||||
input_qspec_map[bias] = get_bias_qspec(quantization_config)
|
||||
|
||||
|
|
@ -299,7 +299,7 @@ def _annotate_conv2d_relu(
|
|||
assert isinstance(weight, Node)
|
||||
input_qspec_map[weight] = get_weight_qspec(quantization_config)
|
||||
|
||||
bias = conv_node.args[2]
|
||||
bias = conv_node.args[2] if len(conv_node.args) > 2 else None
|
||||
if isinstance(bias, Node):
|
||||
input_qspec_map[bias] = get_bias_qspec(quantization_config)
|
||||
|
||||
|
|
@ -408,7 +408,7 @@ def _annotate_conv2d_bn_relu(
|
|||
assert isinstance(weight, Node)
|
||||
input_qspec_map[weight] = get_weight_qspec(quantization_config)
|
||||
|
||||
bias = conv_node.args[2]
|
||||
bias = conv_node.args[2] if len(conv_node.args) > 2 else None
|
||||
if isinstance(bias, Node):
|
||||
input_qspec_map[bias] = get_bias_qspec(quantization_config)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user