diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 17e8e36305c..cdde654b3ba 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -1,7 +1,7 @@ # Owner(s): ["oncall: quantization"] import copy import operator -from typing import Any, List, Optional, Tuple, Dict +from typing import Any, List, Optional, Tuple import torch import torch._dynamo as torchdynamo @@ -72,177 +72,14 @@ from torch.testing._internal.common_quantization import ( QuantizationTestCase, skip_if_no_torchvision, skipIfNoQNNPACK, + TestHelperModules, ) from torch.ao.quantization import ( default_dynamic_qconfig, ) from torch.testing._internal.common_quantized import override_quantized_engine -from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401 from torch._export import dynamic_dim -import unittest - -# TODO: Move to common utils or use existing quant utils to fetch model instances -class TestHelperModules: - class Conv2dPropAnnotaton(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 3, 3) - self.linear = torch.nn.Linear(3, 3) - - def forward(self, x): - x = self.conv(x) - x = x.view(-1, 3) - x = torch.nn.functional.hardtanh(x, -0.5, 0.5) - x = self.linear(x) - return x - - class Conv2dWithObsSharingOps(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 3, 3) - self.hardtanh = torch.nn.Hardtanh() - self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) - - def forward(self, x): - x = self.conv(x) - x = self.adaptive_avg_pool2d(x) - x = self.hardtanh(x) - x = torch.mean(x) - return x - - class Conv2dWithTwoLinearPermute(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 16, 3) - self.linear1 = torch.nn.Linear(16, 8, bias=False) - self.linear2 = torch.nn.Linear(8, 8) - - def forward(self, x): - conv_out = self.conv(x) - permute_out = torch.permute(conv_out, (0, 2, 3, 1)) - return self.linear2(self.linear1(permute_out)) - - class Conv2dWithTwoLinear(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 16, 3) - self.linear1 = torch.nn.Linear(64, 8, bias=False) - self.linear2 = torch.nn.Linear(8, 8) - - def forward(self, x): - conv_out = self.conv(x) - reshape_out = torch.reshape(conv_out, (2, 64)) - return self.linear2(self.linear1(reshape_out)) - - class ConvLinearWPermute(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 8, 3) - self.linear1 = torch.nn.Linear(8, 8) - - def forward(self, x): - conv_out = self.conv(x) - permute_out = torch.permute(conv_out, (0, 2, 3, 1)) - return self.linear1(permute_out) - - class TwoLinearModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(8, 16, bias=False) - self.linear2 = torch.nn.Linear(16, 8) - - def forward(self, x): - return self.linear2(self.linear1(x)) - - class ConvMaxPool2d(torch.nn.Module): - def __init__(self): - super(TestHelperModules.ConvMaxPool2d, self).__init__() - self.conv = torch.nn.Conv2d(2, 2, 1) - self.pool = torch.nn.MaxPool2d(1, 1) - - def forward(self, x): - x = self.conv(x) - x = self.pool(x) - return x - - class ConvWithAdaptiveAvgPool2d(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 3, 3) - self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) - - def forward(self, x): - x = self.conv(x) - x = 
self.adaptive_avg_pool2d(x) - return x - - class ConvWithBNRelu(torch.nn.Module): - def __init__(self, relu, bn=True, bias=True): - super().__init__() - self.conv = torch.nn.Conv2d(3, 3, 3, bias=bias) - if bn: - self.bn = torch.nn.BatchNorm2d(3) - else: - self.bn = torch.nn.Identity() - if relu: - self.relu = torch.nn.ReLU() - else: - self.relu = torch.nn.Identity() - - def forward(self, x): - x = self.conv(x) - x = self.bn(x) - return self.relu(x) - - class Conv2dWithCat(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv1 = torch.nn.Conv2d(3, 3, 3) - self.conv2 = torch.nn.Conv2d(3, 3, 3) - - def forward(self, x, y): - x = self.conv1(x) - y = self.conv2(y) - z = torch.cat([x, y], dim=1) - return z - - class EmbeddingModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) - - def forward(self, indices): - return self.emb(indices) - - class EmbeddingConvLinearModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=8) - self.conv = torch.nn.Conv2d(8, 16, (1, 3)) - self.linear = torch.nn.Linear(16, 8) - - def forward(self, indices): - embeddings = self.emb(indices) - embeddings = torch.unsqueeze(embeddings, dim=0) - embeddings = torch.permute(embeddings, (0, 3, 1, 2)) - conv_out = self.conv(embeddings) - conv_out = torch.permute(conv_out, (0, 2, 3, 1)) - conv_out = torch.squeeze(conv_out, dim=0) - return self.linear(conv_out) - - class AddInplaceAdd(torch.nn.Module): - def forward(self, x, y): - x = x + y - x += y - return x - - class MulInplaceMul(torch.nn.Module): - def forward(self, x, y): - x = x * y - x *= y - return x - class PT2EQuantizationTestCase(QuantizationTestCase): """ @@ -258,7 +95,6 @@ class PT2EQuantizationTestCase(QuantizationTestCase): } - def _test_quantizer( self, model, @@ -270,6 +106,8 @@ class PT2EQuantizationTestCase(QuantizationTestCase): fx_qconfig_mapping=None, export_with_dynamic_shape=False, ): + # resetting dynamo cache + torch._dynamo.reset() m_eager = model.eval() # program capture @@ -342,6 +180,8 @@ class PT2EQuantizationTestCase(QuantizationTestCase): Helper method to verify that the QAT numerics for PT2E quantization match those of FX graph mode quantization for symmetric qnnpack. """ + # resetting dynamo cache + torch._dynamo.reset() MANUAL_SEED = 100 # PT2 export @@ -427,6 +267,8 @@ class PT2EQuantizationTestCase(QuantizationTestCase): with fake quantizes inserted into the correct places. # TODO: also verify that metadata is copied over to the new nodes. 
""" + # resetting dynamo cache + torch._dynamo.reset() m = copy.deepcopy(m) quantizer = XNNPACKQuantizer() quantizer.set_global( @@ -549,60 +391,6 @@ class PT2EQuantizationTestCase(QuantizationTestCase): self.assertTrue("tensor_constant" in bn_running_var_node.target) self.assertEqual(eps, 1e-5) - def _test_representation( - self, - model: torch.nn.Module, - example_inputs: Tuple[Any, ...], - quantizer: Quantizer, - ref_node_occurrence: Dict[ns, int], - non_ref_node_occurrence: Dict[ns, int], - fixed_output_tol: float = None, - output_scale_idx: int = 3, - ) -> torch.nn.Module: - """ TODO: need to implement output checking based on output_scale once - torchdynamo issue is resolved - """ - # program capture - model = capture_pre_autograd_graph( - model, - example_inputs, - ) - model_copy = copy.deepcopy(model) - - model = prepare_pt2e(model, quantizer) - # Calibrate - model(*example_inputs) - model = convert_pt2e(model, use_reference_representation=True) - self.checkGraphModuleNodes(model, expected_node_occurrence=ref_node_occurrence) - # make sure it runs - pt2e_quant_output = model(*example_inputs) - - # TODO: torchdynamo times out when we do this, we can enable numerical checking - # after that is fixed - model_copy = prepare_pt2e(model_copy, quantizer) - # Calibrate - model_copy(*example_inputs) - model_copy = convert_pt2e(model_copy, use_reference_representation=False) - self.checkGraphModuleNodes(model_copy, expected_node_occurrence=non_ref_node_occurrence) - pt2e_quant_output_copy = model_copy(*example_inputs) - - - output_tol = None - if fixed_output_tol is not None: - output_tol = fixed_output_tol - else: - idx = 0 - for n in model_copy.graph.nodes: - if n.target == torch.ops.quantized_decomposed.quantize_per_tensor.default: - idx += 1 - if idx == output_scale_idx: - output_tol = n.args[1] - assert output_tol is not None - - # make sure the result is off by one at most in the quantized integer representation - self.assertTrue( - torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output)) <= (2 * output_tol + 1e-5) - ) @skipIfNoQNNPACK class TestQuantizePT2E(PT2EQuantizationTestCase): @@ -2195,242 +1983,6 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): example_inputs = (torch.randn(1, 3, 5, 5),) self._verify_symmetric_xnnpack_qat_numerics(M(), example_inputs) - def test_representation_linear(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(5, 5) - - def forward(self, x): - return self.linear(x) - - quantizer = XNNPACKQuantizer() - operator_config = get_symmetric_quantization_config(is_per_channel=False) - quantizer.set_global(operator_config) - example_inputs = (torch.randn(2, 5),) - - self._test_representation( - M().eval(), - example_inputs, - quantizer, - ref_node_occurrence={}, - non_ref_node_occurrence={} - ) - - def test_representation_dynamic_linear(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(5, 5) - - def forward(self, x): - return self.linear(x) - - quantizer = XNNPACKQuantizer() - operator_config = get_symmetric_quantization_config(is_per_channel=False, is_dynamic=True) - quantizer.set_global(operator_config) - example_inputs = (torch.randn(2, 5),) - - self._test_representation( - M().eval(), - example_inputs, - quantizer, - ref_node_occurrence={}, - non_ref_node_occurrence={}, - fixed_output_tol=1e-4, - ) - - def test_representation_conv2d(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv2d = 
torch.nn.Conv2d(3, 3, 3) - - def forward(self, x): - return self.conv2d(x) - - quantizer = XNNPACKQuantizer() - operator_config = get_symmetric_quantization_config(is_per_channel=False) - quantizer.set_global(operator_config) - example_inputs = (torch.randn(1, 3, 3, 3),) - - self._test_representation( - M().eval(), - example_inputs, - quantizer, - ref_node_occurrence={}, - non_ref_node_occurrence={} - ) - - def test_representation_add(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return x + y - - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - m_eager = M().eval() - - example_inputs = (torch.randn(1, 3, 3, 3), torch.randn(1, 3, 3, 3),) - - self._test_representation( - M().eval(), - example_inputs, - quantizer, - ref_node_occurrence={}, - non_ref_node_occurrence={} - ) - - def test_representation_add_relu(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - out = x + y - out = torch.nn.functional.relu(out) - return out - - quantizer = XNNPACKQuantizer() - operator_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(operator_config) - - example_inputs = (torch.randn(1, 3, 3, 3), torch.randn(1, 3, 3, 3),) - ref_node_occurrence = { - ns.call_function(out_dtype): 2, - } - - self._test_representation( - M().eval(), - example_inputs, - quantizer, - ref_node_occurrence=ref_node_occurrence, - non_ref_node_occurrence={} - ) - - def test_representation_maxpool2d(self): - quantizer = XNNPACKQuantizer() - operator_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(operator_config) - m_eager = TestHelperModules.ConvMaxPool2d().eval() - - example_inputs = (torch.randn(1, 2, 2, 2),) - - self._test_representation( - m_eager, - example_inputs, - quantizer, - ref_node_occurrence={}, - non_ref_node_occurrence={} - ) - - @unittest.skip("will fix later") - def test_representation_adaptive_avg_pool2d(self): - quantizer = XNNPACKQuantizer() - operator_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(operator_config) - m_eager = TestHelperModules.ConvWithAdaptiveAvgPool2d().eval() - - example_inputs = (torch.randn(1, 3, 3, 3),) - - self._test_representation( - m_eager, - example_inputs, - quantizer, - ref_node_occurrence={}, - non_ref_node_occurrence={} - ) - - def test_representation_quantize_dequantize_per_channel(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(5, 5) - - def forward(self, x): - return self.linear(x) - - quantizer = XNNPACKQuantizer() - # use per channel quantization for weight - operator_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(operator_config) - m_eager = M().eval() - - inputs = [ - (torch.randn(1, 5),), - (torch.randn(1, 3, 5),), - (torch.randn(1, 3, 3, 5),), - (torch.randn(1, 3, 3, 3, 5),), - ] - for example_inputs in inputs: - ref_node_occurrence = { - ns.call_function( - torch.ops.quantized_decomposed.quantize_per_channel.default - ): 0, - ns.call_function( - torch.ops.quantized_decomposed.dequantize_per_channel.default - ): 0, - } - non_ref_node_occurrence = { - ns.call_function( - torch.ops.quantized_decomposed.quantize_per_channel.default - ): 1, - ns.call_function( - torch.ops.quantized_decomposed.dequantize_per_channel.default - ): 1, - } - - 
self._test_representation( - M().eval(), - example_inputs, - quantizer, - ref_node_occurrence, - non_ref_node_occurrence, - output_scale_idx=2, - ) - - def test_representation_quantize_dequantize(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return x + y - - quantizer = XNNPACKQuantizer() - quantization_config = get_symmetric_quantization_config(is_per_channel=True) - quantizer.set_global(quantization_config) - m_eager = M().eval() - - example_inputs = (torch.randn(1, 3, 3, 3), torch.randn(1, 3, 3, 3),) - ref_node_occurrence = { - ns.call_function( - torch.ops.quantized_decomposed.quantize_per_tensor - ): 0, - ns.call_function( - torch.ops.quantized_decomposed.dequantize_per_tensor - ): 0, - } - non_ref_node_occurrence = { - ns.call_function( - torch.ops.quantized_decomposed.quantize_per_tensor.default - ): 3, - ns.call_function( - torch.ops.quantized_decomposed.dequantize_per_tensor.default - ): 3, - } - self._test_representation( - M().eval(), - example_inputs, - quantizer, - ref_node_occurrence, - non_ref_node_occurrence - ) - def test_move_exported_model_to_eval(self): class M(torch.nn.Module): def __init__(self): @@ -2611,6 +2163,7 @@ class TestQuantizePT2EOps(QuantizationTestCase): model_graph = convert_pt2e(model_graph) self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs)) + # TODO: express this using self._test_quantizer, add test for inception_v4 class TestQuantizePT2EModels(PT2EQuantizationTestCase): @skip_if_no_torchvision diff --git a/test/quantization/pt2e/test_representation.py b/test/quantization/pt2e/test_representation.py new file mode 100644 index 00000000000..ac3f01b552b --- /dev/null +++ b/test/quantization/pt2e/test_representation.py @@ -0,0 +1,328 @@ +# Owner(s): ["oncall: quantization"] +import copy +import unittest +from typing import Any, Dict, Tuple + +import torch +from torch._export import capture_pre_autograd_graph +from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401 +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.ao.quantization.quantizer import Quantizer +from torch.ao.quantization.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, + QuantizationTestCase, + skipIfNoQNNPACK, + TestHelperModules, +) + + +@skipIfNoQNNPACK +class TestPT2ERepresentation(QuantizationTestCase): + def _test_representation( + self, + model: torch.nn.Module, + example_inputs: Tuple[Any, ...], + quantizer: Quantizer, + ref_node_occurrence: Dict[ns, int], + non_ref_node_occurrence: Dict[ns, int], + fixed_output_tol: float = None, + output_scale_idx: int = 3, + ) -> torch.nn.Module: + # resetting dynamo cache + torch._dynamo.reset() + model = capture_pre_autograd_graph( + model, + example_inputs, + ) + model_copy = copy.deepcopy(model) + + model = prepare_pt2e(model, quantizer) + # Calibrate + model(*example_inputs) + model = convert_pt2e(model, use_reference_representation=True) + self.checkGraphModuleNodes(model, expected_node_occurrence=ref_node_occurrence) + # make sure it runs + pt2e_quant_output = model(*example_inputs) + + # TODO: torchdynamo times out when we do this, we can enable numerical checking + # after that is fixed + model_copy = prepare_pt2e(model_copy, quantizer) + # Calibrate + model_copy(*example_inputs) + model_copy = convert_pt2e(model_copy, use_reference_representation=False) + self.checkGraphModuleNodes( + 
model_copy, expected_node_occurrence=non_ref_node_occurrence + ) + pt2e_quant_output_copy = model_copy(*example_inputs) + + output_tol = None + if fixed_output_tol is not None: + output_tol = fixed_output_tol + else: + idx = 0 + for n in model_copy.graph.nodes: + if ( + n.target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ): + idx += 1 + if idx == output_scale_idx: + output_tol = n.args[1] + assert output_tol is not None + + # make sure the result is off by one at most in the quantized integer representation + self.assertTrue( + torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output)) + <= (2 * output_tol + 1e-5) + ) + + def test_static_linear(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=False) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 5),) + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + ) + + def test_dynamic_linear(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=False, is_dynamic=True + ) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 5),) + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + fixed_output_tol=1e-4, + ) + + def test_conv2d(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv2d = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + return self.conv2d(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=False) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(1, 3, 3, 3),) + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + ) + + def test_add(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return x + y + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m_eager = M().eval() + + example_inputs = ( + torch.randn(1, 3, 3, 3), + torch.randn(1, 3, 3, 3), + ) + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + ) + + def test_add_relu(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + out = x + y + out = torch.nn.functional.relu(out) + return out + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(operator_config) + + example_inputs = ( + torch.randn(1, 3, 3, 3), + torch.randn(1, 3, 3, 3), + ) + ref_node_occurrence = { + ns.call_function(out_dtype): 2, + } + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence=ref_node_occurrence, + non_ref_node_occurrence={}, + ) + + def test_maxpool2d(self): + quantizer = XNNPACKQuantizer() + operator_config = 
get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(operator_config) + m_eager = TestHelperModules.ConvMaxPool2d().eval() + + example_inputs = (torch.randn(1, 2, 2, 2),) + + self._test_representation( + m_eager, + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + ) + + @unittest.skip("will fix later") + def test_representation_adaptive_avg_pool2d(self): + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(operator_config) + m_eager = TestHelperModules.ConvWithAdaptiveAvgPool2d().eval() + + example_inputs = (torch.randn(1, 3, 3, 3),) + + self._test_representation( + m_eager, + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + ) + + def test_qdq_per_channel(self): + """Test representation for quantize_per_channel and dequantize_per_channel op""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + # use per channel quantization for weight + operator_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(operator_config) + m_eager = M().eval() + + inputs = [ + (torch.randn(1, 5),), + (torch.randn(1, 3, 5),), + (torch.randn(1, 3, 3, 5),), + (torch.randn(1, 3, 3, 3, 5),), + ] + for example_inputs in inputs: + ref_node_occurrence = { + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_channel.default + ): 0, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 0, + } + non_ref_node_occurrence = { + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_channel.default + ): 1, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 1, + } + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence, + non_ref_node_occurrence, + output_scale_idx=2, + ) + + def test_qdq(self): + """Test representation for quantize and dequantize op""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return x + y + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m_eager = M().eval() + + example_inputs = ( + torch.randn(1, 3, 3, 3), + torch.randn(1, 3, 3, 3), + ) + ref_node_occurrence = { + ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 0, + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 0, + } + non_ref_node_occurrence = { + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 3, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 3, + } + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence, + non_ref_node_occurrence, + ) diff --git a/test/test_quantization.py b/test/test_quantization.py index 366b12c33d9..02120c31835 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -86,6 +86,7 @@ try: from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2E # noqa: F401 from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2EOps # noqa: F401 from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2EModels # noqa: F401 + from quantization.pt2e.test_representation import TestPT2ERepresentation 
# noqa: F401 from quantization.pt2e.test_x86inductor_quantizer import TestQuantizePT2EX86Inductor # noqa: F401 from quantization.pt2e.test_quantize_pt2e_fx import TestQuantizePT2EFX # noqa: F401 from quantization.pt2e.test_quantize_pt2e_fx import TestQuantizePT2EFXModels # noqa: F401 diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 9654fd30d70..c28ffcc64ab 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -2431,3 +2431,163 @@ class SparseNNModel(nn.Module): out = self.dense_top(sparse_feature, dense) return out + +class TestHelperModules: + class Conv2dPropAnnotaton(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + x = self.conv(x) + x = x.view(-1, 3) + x = torch.nn.functional.hardtanh(x, -0.5, 0.5) + x = self.linear(x) + return x + + class Conv2dWithObsSharingOps(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + self.hardtanh = torch.nn.Hardtanh() + self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, x): + x = self.conv(x) + x = self.adaptive_avg_pool2d(x) + x = self.hardtanh(x) + x = torch.mean(x) + return x + + class Conv2dWithTwoLinearPermute(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3) + self.linear1 = torch.nn.Linear(16, 8, bias=False) + self.linear2 = torch.nn.Linear(8, 8) + + def forward(self, x): + conv_out = self.conv(x) + permute_out = torch.permute(conv_out, (0, 2, 3, 1)) + return self.linear2(self.linear1(permute_out)) + + class Conv2dWithTwoLinear(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3) + self.linear1 = torch.nn.Linear(64, 8, bias=False) + self.linear2 = torch.nn.Linear(8, 8) + + def forward(self, x): + conv_out = self.conv(x) + reshape_out = torch.reshape(conv_out, (2, 64)) + return self.linear2(self.linear1(reshape_out)) + + class ConvLinearWPermute(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 8, 3) + self.linear1 = torch.nn.Linear(8, 8) + + def forward(self, x): + conv_out = self.conv(x) + permute_out = torch.permute(conv_out, (0, 2, 3, 1)) + return self.linear1(permute_out) + + class TwoLinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(8, 16, bias=False) + self.linear2 = torch.nn.Linear(16, 8) + + def forward(self, x): + return self.linear2(self.linear1(x)) + + class ConvMaxPool2d(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(2, 2, 1) + self.pool = torch.nn.MaxPool2d(1, 1) + + def forward(self, x): + x = self.conv(x) + x = self.pool(x) + return x + + class ConvWithAdaptiveAvgPool2d(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, x): + x = self.conv(x) + x = self.adaptive_avg_pool2d(x) + return x + + class ConvWithBNRelu(torch.nn.Module): + def __init__(self, relu, bn=True, bias=True): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3, bias=bias) + if bn: + self.bn = torch.nn.BatchNorm2d(3) + else: + self.bn = torch.nn.Identity() + if relu: + self.relu = torch.nn.ReLU() + else: + self.relu = torch.nn.Identity() + + def forward(self, x): + x 
= self.conv(x) + x = self.bn(x) + return self.relu(x) + + class Conv2dWithCat(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.conv2 = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x, y): + x = self.conv1(x) + y = self.conv2(y) + z = torch.cat([x, y], dim=1) + return z + + class EmbeddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) + + def forward(self, indices): + return self.emb(indices) + + class EmbeddingConvLinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=8) + self.conv = torch.nn.Conv2d(8, 16, (1, 3)) + self.linear = torch.nn.Linear(16, 8) + + def forward(self, indices): + embeddings = self.emb(indices) + embeddings = torch.unsqueeze(embeddings, dim=0) + embeddings = torch.permute(embeddings, (0, 3, 1, 2)) + conv_out = self.conv(embeddings) + conv_out = torch.permute(conv_out, (0, 2, 3, 1)) + conv_out = torch.squeeze(conv_out, dim=0) + return self.linear(conv_out) + + class AddInplaceAdd(torch.nn.Module): + def forward(self, x, y): + x = x + y + x += y + return x + + class MulInplaceMul(torch.nn.Module): + def forward(self, x, y): + x = x * y + x *= y + return x
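
Usage sketch (not part of the diff): a minimal example of the PT2E reference-representation flow that the new TestPT2ERepresentation suite exercises. The names and import paths are taken directly from the code above for the PyTorch build this PR targets; they may move between releases.

    import torch
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 5)

        def forward(self, x):
            return self.linear(x)

    example_inputs = (torch.randn(2, 5),)
    # program capture
    model = capture_pre_autograd_graph(M().eval(), example_inputs)

    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))

    model = prepare_pt2e(model, quantizer)   # insert observers
    model(*example_inputs)                   # calibrate
    # use_reference_representation=True rewrites quantize/dequantize pairs into
    # reference integer arithmetic (e.g. out_dtype ops), which the new tests assert on.
    model = convert_pt2e(model, use_reference_representation=True)
    out = model(*example_inputs)

The new suite is registered in test/test_quantization.py, so it can be selected the usual unittest way, e.g. python test/test_quantization.py TestPT2ERepresentation.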