[Quant] [PT2] Fix an issue in Conv Binary Quantization Annotation (#114540)

**Summary**
When annotating a conv-binary pattern, the quantizer should skip the pattern if the conv node has more than one user, since a conv whose output also feeds other consumers cannot be fused with the binary op.
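
For illustration (this snippet is not from the diff; the function and weight names are hypothetical), the problematic shape is:

```
import torch
import torch.nn.functional as F

def pattern(x, w1, w2):
    tmp = F.conv2d(x, w1, padding=1)     # conv2d_1: output has 2 users (conv2d_2 and add)
    tmp2 = F.conv2d(tmp, w2, padding=1)  # conv2d_2: output has 1 user (add)
    return tmp + tmp2                    # fuse the add with conv2d_2, not conv2d_1
```

Annotating conv2d_1 for binary fusion would be wrong because `tmp` must stay live as the input to conv2d_2; the `Conv2dAddModule2` test helper added in this PR exercises exactly this shape.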

**Test Plan**
```
python -m pytest test_x86inductor_quantizer.py -k test_conv2d_binary2
python -m pytest test_x86inductor_quantizer.py -k test_qat_conv2d_binary2
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114540
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
Author: leslie-fang-intel
Date: 2023-11-26 09:09:59 +08:00
Committed by: PyTorch MergeBot
Parent: b1fb591272
Commit: e592b9a469
2 changed files with 118 additions and 2 deletions

File: test/quantization/pt2e/test_x86inductor_quantizer.py

@@ -258,6 +258,30 @@ class TestHelperModules:
        def forward(self, x):
            return self.postop(self.linear(x))

    class Conv2dAddModule2(torch.nn.Module):
        def __init__(self,
                     inplace_add: bool = False,
                     ) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1
            )
            self.conv2 = torch.nn.Conv2d(
                in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1
            )
            self.inplace_add = inplace_add
            self.bn = torch.nn.BatchNorm2d(3)
            self.bn2 = torch.nn.BatchNorm2d(3)

        def forward(self, x):
            if self.inplace_add:
                tmp = self.bn(self.conv(x))
                tmp += self.bn2(self.conv2(tmp))
                return tmp
            else:
                tmp = self.bn(self.conv(x))
                return tmp + self.bn2(self.conv2(tmp))

class X86InductorQuantTestCase(QuantizationTestCase):
    def _test_quantizer(
        self,
@@ -418,6 +442,46 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
            node_list,
        )

    @skipIfNoX86
    def test_conv2d_binary2(self):
        """
        Test Pattern:
            tmp = conv2d_1(x)
            tmp2 = conv2d_2(tmp)
            return tmp + tmp2
        Since conv2d_1 has 2 users, we should annotate conv2d_2 for binary fusion instead of conv2d_1.
        """
        example_inputs = (torch.randn(2, 3, 6, 6),)
        quantizer = X86InductorQuantizer().set_global(
            xiq.get_default_x86_inductor_quantization_config()
        )
        inplace_add_list = [True, False]
        with override_quantized_engine("x86"), torch.no_grad():
            for inplace_add in inplace_add_list:
                m = TestHelperModules.Conv2dAddModule2(inplace_add=inplace_add).eval()
                node_occurrence = {
                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
                    # quantize_per_channel for weights is const-propagated
                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
                }
                node_list = [
                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                    torch.ops.aten.conv2d.default,
                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
                    torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor,
                ]
                self._test_quantizer(
                    m,
                    example_inputs,
                    quantizer,
                    node_occurrence,
                    node_list,
                )

    @skipIfNoX86
    def test_conv2d_binary_unary(self):
        """
@@ -1006,6 +1070,48 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
            is_qat=True,
        )

    @skipIfNoX86
    def test_qat_conv2d_binary2(self):
        """
        Test QAT Pattern:
            tmp = bn1(conv2d_1(x))
            tmp2 = bn2(conv2d_2(tmp))
            return tmp + tmp2
        Since conv2d_1 has 2 users, we should annotate conv2d_2 for binary fusion instead of conv2d_1.
        """
        example_inputs = (torch.randn(2, 3, 6, 6),)
        quantizer = X86InductorQuantizer().set_global(
            xiq.get_default_x86_inductor_quantization_config(is_qat=True)
        )
        inplace_add_list = [True, False]
        with override_quantized_engine("x86"), torch.no_grad():
            for inplace_add in inplace_add_list:
                m = TestHelperModules.Conv2dAddModule2(inplace_add=inplace_add)
                node_occurrence = {
                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
                    # quantize_per_channel for weights is const-propagated
                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
                    # BN should be folded into conv
                    torch.ops.aten._native_batch_norm_legit.default: 0,
                }
                node_list = [
                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                    torch.ops.aten.conv2d.default,
                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
                    torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor,
                ]
                self._test_quantizer(
                    m,
                    example_inputs,
                    quantizer,
                    node_occurrence,
                    node_list,
                    is_qat=True,
                )

    @skipIfNoX86
    def test_qat_conv2d_binary_unary(self):
        """

File: torch/ao/quantization/quantizer/x86_inductor_quantizer.py

@@ -445,7 +445,9 @@ class X86InductorQuantizer(Quantizer):
            ) = self._get_output_nodes_of_partitions(
                [conv_partition, bn_partition, binary_partition, unary_partition]
            )
            if len(bn_output_node.users) != 1:
                # The conv-bn pattern output should only have 1 user.
                continue
            (
                bn_output_node_idx,
                extra_input_node_idx,
@@ -502,7 +504,9 @@ class X86InductorQuantizer(Quantizer):
            ) = self._get_output_nodes_of_partitions(
                [conv_partition, bn_partition, binary_partition]
            )
            if len(bn_output_node.users) != 1:
                # The conv-bn pattern output should only have 1 user.
                continue
            (
                bn_output_node_idx,
                extra_input_node_idx,
@@ -634,6 +638,9 @@ class X86InductorQuantizer(Quantizer):
            conv_node, binary_node, unary_node = self._get_output_nodes_of_partitions(
                [conv_partition, binary_partition, unary_partition]
            )
            if len(conv_node.users) != 1:
                # The conv node should only have 1 user.
                continue
            conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(
                conv_node, binary_node
            )
@@ -676,6 +683,9 @@ class X86InductorQuantizer(Quantizer):
            conv_node, binary_node = self._get_output_nodes_of_partitions(
                [conv_partition, binary_partition]
            )
            if len(conv_node.users) != 1:
                # The conv node should only have 1 user.
                continue
            conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(
                conv_node, binary_node
            )
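
For context (illustrative only, not part of the diff): the new guard relies on torch.fx's `node.users` map. A minimal, self-contained sketch of how the user count distinguishes the two convs in the pattern — the module and names below are hypothetical, and the quantizer itself runs on exported ATen graphs rather than a `symbolic_trace` result, but the users-count mechanism is the same:

```
import torch

class TwoUserConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(3, 3, 3, padding=1)

    def forward(self, x):
        tmp = self.conv(x)            # feeds both conv2 and add -> 2 users
        return tmp + self.conv2(tmp)  # conv2 feeds only add -> 1 user

gm = torch.fx.symbolic_trace(TwoUserConv())
for node in gm.graph.nodes:
    if node.op == "call_module":
        # prints: conv 2 (skipped by the new guard), conv2 1 (annotated)
        print(node.target, len(node.users))
```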