[quant][be] Add a test for per channel quant for groupwise conv (#115224)

Summary:
just making sure this works

Test Plan:
python test/test_quantization.py -k test_groupwise_per_channel_quant

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115224
Approved by: https://github.com/andrewor14
This commit is contained in:
Jerry Zhang 2023-12-05 17:12:29 -08:00 committed by PyTorch MergeBot
parent b7eb9b1e7e
commit a93b9ee9d8
2 changed files with 21 additions and 0 deletions

View File

@ -1752,3 +1752,13 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
self.checkGraphModuleNodes(
m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
)
def test_groupwise_per_channel_quant(self):
m = TestHelperModules.GroupwiseConv2d()
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
example_inputs = m.example_inputs()
m = self._quantize(m, quantizer, example_inputs)
# make sure it runs
m(*example_inputs)

View File

@ -2667,3 +2667,14 @@ class TestHelperModules:
permute_out = torch.permute(x, (0, 2, 3, 1))
linear_out = self.linear(permute_out)
return linear_out
class GroupwiseConv2d(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(4, 4, 3, groups=2)
def forward(self, x):
return self.conv(x)
def example_inputs(self):
return (torch.randn(2, 4, 10, 10),)