[Quant] [PT2] Fix an issue in Conv Binary Quantization Annotation (#114540)
**Summary**
To annotate a conv-binary pattern, we should skip the pattern if the conv node has more than one user.

**Test Plan**
```
python -m pytest test_x86inductor_quantizer.py -k test_conv2d_binary2
python -m pytest test_x86inductor_quantizer.py -k test_qat_conv2d_binary2
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114540
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
This commit is contained in:
parent b1fb591272
commit e592b9a469
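To see why the guard is needed, consider a minimal sketch (a hypothetical module, not code from this PR): in a `tmp + conv2(tmp)` chain, the first conv's output feeds both the second conv and the add, so its FX node has two users.

```python
import torch
import torch.fx

class TwoUserConv(torch.nn.Module):
    """Hypothetical module mirroring the pattern this PR tests."""

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(3, 3, 3, padding=1)

    def forward(self, x):
        tmp = self.conv1(x)          # tmp is consumed twice below
        return tmp + self.conv2(tmp)

gm = torch.fx.symbolic_trace(TwoUserConv())
for node in gm.graph.nodes:
    if node.op == "call_module":
        # conv1 has 2 users (conv2 and add); conv2 has 1 user (add).
        print(node.target, len(node.users))
```

Fusing conv1 with the add would treat `tmp` as internal to the fused pattern even though conv2 also consumes it; with this fix the annotator skips conv1 and fuses conv2 + add instead.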
test_x86inductor_quantizer.py:
```diff
@@ -258,6 +258,30 @@ class TestHelperModules:
         def forward(self, x):
             return self.postop(self.linear(x))
 
+    class Conv2dAddModule2(torch.nn.Module):
+        def __init__(self,
+                     inplace_add: bool = False,
+                     ) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(
+                in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1
+            )
+            self.conv2 = torch.nn.Conv2d(
+                in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1
+            )
+            self.inplace_add = inplace_add
+            self.bn = torch.nn.BatchNorm2d(3)
+            self.bn2 = torch.nn.BatchNorm2d(3)
+
+        def forward(self, x):
+            if self.inplace_add:
+                tmp = self.bn(self.conv(x))
+                tmp += self.bn2(self.conv2(tmp))
+                return tmp
+            else:
+                tmp = self.bn(self.conv(x))
+                return tmp + self.bn2(self.conv2(tmp))
+
     class X86InductorQuantTestCase(QuantizationTestCase):
         def _test_quantizer(
             self,
```
```diff
@@ -418,6 +442,46 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
                     node_list,
                 )
 
+
+    @skipIfNoX86
+    def test_conv2d_binary2(self):
+        """
+        Test Pattern:
+            tmp = conv2d_1(x)
+            tmp2 = conv2d_2(tmp)
+            return tmp + tmp2
+        Since conv2d_1 has 2 users, we should annotate conv2d_2 for binary fusion instead of conv2d_1.
+        """
+        example_inputs = (torch.randn(2, 3, 6, 6),)
+        quantizer = X86InductorQuantizer().set_global(
+            xiq.get_default_x86_inductor_quantization_config()
+        )
+        inplace_add_list = [True, False]
+        with override_quantized_engine("x86"), torch.no_grad():
+            for inplace_add in inplace_add_list:
+                m = TestHelperModules.Conv2dAddModule2(inplace_add=inplace_add).eval()
+                node_occurrence = {
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+                    # quantize_per_channel for weights are const propagated
+                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
+                }
+                node_list = [
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.aten.conv2d.default,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor,
+                ]
+                self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                )
+
     @skipIfNoX86
     def test_conv2d_binary_unary(self):
         """
```
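For readers outside the test harness, a hedged sketch of the post-training flow this test exercises; `_test_quantizer` is assumed to wrap roughly these steps, and `capture_pre_autograd_graph` was the PT2E export entry point around the time of this commit:

```python
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch._export import capture_pre_autograd_graph

example_inputs = (torch.randn(2, 3, 6, 6),)
# Conv2dAddModule2 is the helper added in the diff above (defined in the test file).
m = TestHelperModules.Conv2dAddModule2(inplace_add=False).eval()

exported = capture_pre_autograd_graph(m, example_inputs)
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
prepared = prepare_pt2e(exported, quantizer)  # inserts observers per annotation
prepared(*example_inputs)                     # calibration pass
converted = convert_pt2e(prepared)            # materializes the q/dq ops counted above
```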
```diff
@@ -1006,6 +1070,48 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
                     is_qat=True,
                 )
 
+    @skipIfNoX86
+    def test_qat_conv2d_binary2(self):
+        """
+        Test QAT Pattern:
+            tmp = bn1(conv2d_1(x))
+            tmp2 = bn2(conv2d_2(tmp))
+            return tmp + tmp2
+        Since conv2d_1 has 2 users, we should annotate conv2d_2 for binary fusion instead of conv2d_1.
+        """
+        example_inputs = (torch.randn(2, 3, 6, 6),)
+        quantizer = X86InductorQuantizer().set_global(
+            xiq.get_default_x86_inductor_quantization_config(is_qat=True)
+        )
+        inplace_add_list = [True, False]
+        with override_quantized_engine("x86"), torch.no_grad():
+            for inplace_add in inplace_add_list:
+                m = TestHelperModules.Conv2dAddModule2(inplace_add=inplace_add)
+                node_occurrence = {
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
+                    # quantize_per_channel for weights are const propagated
+                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
+                    # BN should be folded into Conv
+                    torch.ops.aten._native_batch_norm_legit.default: 0,
+                }
+                node_list = [
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.aten.conv2d.default,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor,
+                ]
+                self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                    is_qat=True,
+                )
+
     @skipIfNoX86
     def test_qat_conv2d_binary_unary(self):
         """
```
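The QAT test expects one more quantize/dequantize pair than the PTQ test (3/4 vs 2/3), consistent with QAT also observing the fused pattern's output, and `_native_batch_norm_legit` must not appear because convert folds BN into the conv weights. A hedged sketch of the corresponding standalone flow, continuing from the snippet above:

```python
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e

m = TestHelperModules.Conv2dAddModule2(inplace_add=False)  # training mode: BN stays live
exported = capture_pre_autograd_graph(m, example_inputs)
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config(is_qat=True)
)
prepared = prepare_qat_pt2e(exported, quantizer)  # fuses conv+bn under fake-quant
prepared(*example_inputs)                         # stand-in for the QAT training loop
converted = convert_pt2e(prepared)                # folds BN into the conv weights
```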
x86_inductor_quantizer.py:
```diff
@@ -445,7 +445,9 @@ class X86InductorQuantizer(Quantizer):
             ) = self._get_output_nodes_of_partitions(
                 [conv_partition, bn_partition, binary_partition, unary_partition]
             )
-
+            if len(bn_output_node.users) != 1:
+                # Conv BN pattern should only have 1 user.
+                continue
             (
                 bn_output_node_idx,
                 extra_input_node_idx,
```
```diff
@@ -502,7 +504,9 @@ class X86InductorQuantizer(Quantizer):
             ) = self._get_output_nodes_of_partitions(
                 [conv_partition, bn_partition, binary_partition]
             )
-
+            if len(bn_output_node.users) != 1:
+                # Conv BN pattern should only have 1 user.
+                continue
             (
                 bn_output_node_idx,
                 extra_input_node_idx,
```
```diff
@@ -634,6 +638,9 @@ class X86InductorQuantizer(Quantizer):
             conv_node, binary_node, unary_node = self._get_output_nodes_of_partitions(
                 [conv_partition, binary_partition, unary_partition]
             )
+            if len(conv_node.users) != 1:
+                # Conv Node should only have 1 user node.
+                continue
             conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(
                 conv_node, binary_node
             )
```
```diff
@@ -676,6 +683,9 @@ class X86InductorQuantizer(Quantizer):
             conv_node, binary_node = self._get_output_nodes_of_partitions(
                 [conv_partition, binary_partition]
             )
+            if len(conv_node.users) != 1:
+                # Conv Node should only have 1 user node.
+                continue
             conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(
                 conv_node, binary_node
             )
```
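All four annotation sites guard the same invariant; only the node checked differs (the bn output for the QAT conv+bn patterns, since BN is still in the graph before folding, and the conv output otherwise). A hedged distillation with a hypothetical helper name:

```python
from torch.fx import Node

def _has_single_user(pattern_output_node: Node) -> bool:
    # Hypothetical helper, not part of X86InductorQuantizer: a pattern whose
    # output is consumed by more than one node cannot be fused with the
    # binary op, because the extra users still need the unfused output.
    return len(pattern_output_node.users) == 1
```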