mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[quant][be] Add a test for per channel quant for groupwise conv (#115224)
Summary: just making sure this works Test Plan: python test/test_quantization.py -k test_groupwise_per_channel_quant Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/115224 Approved by: https://github.com/andrewor14
This commit is contained in:
parent
b7eb9b1e7e
commit
a93b9ee9d8
|
|
@ -1752,3 +1752,13 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
|||
self.checkGraphModuleNodes(
|
||||
m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
|
||||
)
|
||||
|
||||
def test_groupwise_per_channel_quant(self):
|
||||
m = TestHelperModules.GroupwiseConv2d()
|
||||
quantizer = XNNPACKQuantizer()
|
||||
operator_config = get_symmetric_quantization_config(is_per_channel=True)
|
||||
quantizer.set_global(operator_config)
|
||||
example_inputs = m.example_inputs()
|
||||
m = self._quantize(m, quantizer, example_inputs)
|
||||
# make sure it runs
|
||||
m(*example_inputs)
|
||||
|
|
|
|||
|
|
@ -2667,3 +2667,14 @@ class TestHelperModules:
|
|||
permute_out = torch.permute(x, (0, 2, 3, 1))
|
||||
linear_out = self.linear(permute_out)
|
||||
return linear_out
|
||||
|
||||
class GroupwiseConv2d(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(4, 4, 3, groups=2)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
def example_inputs(self):
|
||||
return (torch.randn(2, 4, 10, 10),)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user