# Owner(s): ["oncall: quantization"] # ruff: noqa: F841 from collections import OrderedDict import contextlib import torch import torch.nn.functional as F import torch.nn as nn import torch.ao.nn.quantized as nnq import torch.ao.nn.quantized.reference as nnqr import torch.ao.nn.quantized.dynamic as nnqd import torch.ao.nn.intrinsic as nni import torch.ao.nn.intrinsic.quantized as nniq import torch.ao.nn.intrinsic.quantized.dynamic as nniqd import torch.multiprocessing as mp from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY # graph mode quantization based on fx from torch.ao.quantization.quantize_fx import ( prepare_fx, convert_fx, convert_to_reference_fx, _convert_to_reference_decomposed_fx, prepare_qat_fx, fuse_fx, ) from torch.ao.quantization.fx.quantize_handler import DefaultNodeQuantizeHandler from torch.ao.quantization.fx.match_utils import ( _is_match, MatchAllNode, ) from torch.ao.quantization import ( QuantType, ) from torch.ao.quantization.quant_type import _get_quant_type_to_str from torch.ao.quantization import ( QuantStub, DeQuantStub, QuantWrapper, default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, default_qat_qconfig, default_reuse_input_qconfig, default_symmetric_qnnpack_qconfig, default_symmetric_qnnpack_qat_qconfig, per_channel_dynamic_qconfig, float16_dynamic_qconfig, float16_static_qconfig, float_qparams_weight_only_qconfig, float_qparams_weight_only_qconfig_4bit, get_default_qconfig, get_default_qat_qconfig, get_default_qconfig_mapping, get_default_qat_qconfig_mapping, fuse_modules, fuse_modules_qat, prepare, prepare_qat, convert, quantize_dynamic, default_placeholder_observer, default_weight_observer, PerChannelMinMaxObserver, FixedQParamsFakeQuantize, FixedQParamsObserver, FusedMovingAvgObsFakeQuantize, FakeQuantize, MovingAverageMinMaxObserver, HistogramObserver, ReuseInputObserver, QConfig, default_embedding_qat_qconfig, ) from torch.ao.quantization.backend_config import ( get_fbgemm_backend_config, 
get_qnnpack_backend_config, BackendConfig, BackendPatternConfig, DTypeConfig, DTypeWithConstraints, ObservationType ) from torch.ao.quantization.backend_config.native import ( get_test_only_legacy_native_backend_config, ) from torch.ao.quantization.qconfig_mapping import ( _get_symmetric_qnnpack_qconfig_mapping, _get_symmetric_qnnpack_qat_qconfig_mapping, _GLOBAL_DICT_KEY, _MODULE_NAME_DICT_KEY, _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, _MODULE_NAME_REGEX_DICT_KEY, _OBJECT_TYPE_DICT_KEY, QConfigMapping, ) from torch.ao.quantization.fx.qconfig_mapping_utils import ( _get_object_type_qconfig, _get_module_name_qconfig, _get_module_name_regex_qconfig, _maybe_adjust_qconfig_for_module_name_object_type_order, ) from torch.ao.quantization.fx.pattern_utils import ( _DEFAULT_FUSION_PATTERNS, _DEFAULT_QUANTIZATION_PATTERNS, _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP, _DEFAULT_OUTPUT_OBSERVER_MAP, _register_fusion_pattern, _register_quant_pattern, get_default_output_activation_post_process_map ) from torch.ao.quantization.fx.custom_config import ( STANDALONE_MODULE_NAME_DICT_KEY, STANDALONE_MODULE_CLASS_DICT_KEY, FLOAT_TO_OBSERVED_DICT_KEY, OBSERVED_TO_QUANTIZED_DICT_KEY, NON_TRACEABLE_MODULE_NAME_DICT_KEY, NON_TRACEABLE_MODULE_CLASS_DICT_KEY, INPUT_QUANTIZED_INDEXES_DICT_KEY, OUTPUT_QUANTIZED_INDEXES_DICT_KEY, PRESERVED_ATTRIBUTES_DICT_KEY, FuseCustomConfig, ConvertCustomConfig, PrepareCustomConfig, StandaloneModuleConfigEntry, ) import torch.ao.quantization.fx.lstm_utils from torch.ao.quantization.fx.utils import ( _reroute_tuple_getitem_pattern, NodeInfo, ) from torch.ao.quantization.fake_quantize import ( default_fixed_qparams_range_0to1_fake_quant, default_fixed_qparams_range_neg1to1_fake_quant, ) from torch.ao.quantization.observer import ( default_fixed_qparams_range_0to1_observer, default_fixed_qparams_range_neg1to1_observer, MinMaxObserver, _is_activation_post_process, ) # test utils from hypothesis import given, settings from hypothesis import strategies as st from 
class BinaryOp(torch.nn.Module):
    """Toy model that combines two 1x1 conv outputs with a binary operator.

    The operator is applied twice, once as ``op(x, y)`` and once as
    ``op(y, x)``, so both operand orders are exercised.
    """

    def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar):
        """Select the operator and build the two convolutions.

        Args:
            binary_op: out-of-place binary callable.
            ibinary_op: in-place binary callable; used only when it is truthy
                and ``is_inplace`` is set.
            is_inplace: prefer ``ibinary_op`` over ``binary_op``.
            is_scalar: use the constant ``3`` as the second operand instead
                of ``conv2(y)``.
        """
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
        self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
        self.is_scalar = is_scalar
        if ibinary_op and is_inplace:
            self.op = ibinary_op
        else:
            self.op = binary_op

    def forward(self, x, y):
        lhs = self.conv1(x)
        if self.is_scalar:
            rhs = 3
        else:
            rhs = self.conv2(y)
        # first application: lhs <op> rhs
        combined = self.op(lhs, rhs)
        # second application with the operands swapped: rhs <op> combined
        return self.op(rhs, combined)
@torch.fx.wrap
def _user_func_with_complex_return_type(x):
    """Split ``x`` into size-1 chunks along dim 1 and return them as a list.

    Registered with ``torch.fx.wrap``.
    """
    pieces = torch.split(x, 1, dim=1)
    return list(pieces)
def test_fuse_linear_bn_eval(self):
    """Linear followed by BatchNorm1d: after eval-mode ``fuse_fx`` the graph
    must contain a Linear module call and no BatchNorm1d module call.
    """

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = nn.Linear(1, 1)
            self.bn1d = nn.BatchNorm1d(1)

        def forward(self, x):
            return self.bn1d(self.linear(x))

    # fuse_fx is a top level api and only supports eval mode
    fused = fuse_fx(M().eval())

    self.checkGraphModuleNodes(
        fused,
        expected_node_list=[ns.call_module(nn.Linear)],
        expected_node_occurrence={ns.call_module(nn.BatchNorm1d): 0},
    )
def test_linear_tanh_not_fused_by_default(self):
    """Make sure linear - tanh is not fused when no backend_config is given."""
    # fuse_fx is a top level api and only supports eval mode
    fused = fuse_fx(LinearTanhModel().eval())

    # Both original modules must still be present, in order ...
    still_present = [
        ns.call_module(nn.Linear),
        ns.call_module(nn.Tanh),
    ]
    # ... and the fused module must not have been introduced.
    must_be_absent = {ns.call_module(nni.LinearTanh): 0}

    self.checkGraphModuleNodes(
        fused,
        expected_node_list=still_present,
        expected_node_occurrence=must_be_absent,
    )
@skipIfNoONEDNN
def test_fuse_conv_bn_add_relu_lowering(self):
    """Test fusion and lowering of Conv2d - (bn -) add - (ReLU) by FX.

    For the onednn backend only. Every combination of the five model flags
    is prepared, converted both to the lowered model and to the reference
    model, and the node counts of each result are checked.
    """
    from torch.ao.quantization.backend_config import get_onednn_backend_config

    qconfig_mapping = get_default_qconfig_mapping('onednn')
    quantize = ns.call_function(torch.quantize_per_tensor)
    dequantize = ns.call_method("dequantize")

    with override_quantized_engine('onednn'):
        # flag order: with_bn, with_relu, left_conv, two_conv, use_torch_add
        for flags in itertools.product([True, False], repeat=5):
            with_bn, with_relu, left_conv, two_conv, use_torch_add = flags

            fused_type = nniq.ConvAddReLU2d if with_relu else nniq.ConvAdd2d
            node_occurrence = {
                quantize: 1 if two_conv else 2,
                dequantize: 1,
                ns.call_module(fused_type): 1,
                ns.call_module(nn.Conv2d): 0,
                ns.call_module(nn.ReLU): 0,
            }
            node_occurrence_ref = {
                quantize: 3,
                dequantize: 3,
            }

            # test eval mode
            model = ConvBnAddReluModel(
                with_bn=with_bn,
                with_relu=with_relu,
                left_conv=left_conv,
                two_conv=two_conv,
                use_torch_add=use_torch_add,
            ).eval()
            example_x = model.get_example_inputs()

            prepared = prepare_fx(
                model,
                qconfig_mapping,
                example_inputs=example_x,
                backend_config=get_onednn_backend_config(),
            )
            # keep an untouched copy of the prepared model for the reference path
            prepared_copy = copy.deepcopy(prepared)
            lowered = convert_fx(prepared, backend_config=get_onednn_backend_config())
            reference = convert_to_reference_fx(prepared_copy)

            self.checkGraphModuleNodes(lowered, expected_node_occurrence=node_occurrence)
            self.checkGraphModuleNodes(reference, expected_node_occurrence=node_occurrence_ref)
            # the lowered model must also run on the example inputs
            lowered(*example_x)
@skipIfNoFBGEMM
def test_qconfig_fused_module(self):
    """Quantize Linear+ReLU models with per-object-type qconfigs and check
    the resulting node sequence.

    TODO: add test for all fused modules
    """
    qconfig_dict = {
        "": None,
        "object_type": [
            (nn.Linear, default_qconfig),
            (nn.ReLU, default_qconfig),
            (F.relu, default_qconfig),
        ],
    }

    quantize = ns.call_function(torch.quantize_per_tensor)
    linear_relu = ns.call_module(nniq.LinearReLU)
    dequantize = ns.call_method('dequantize')

    # (model class, expected node sequence after convert)
    cases = (
        (LinearReluModel, [quantize, linear_relu, dequantize]),
        (
            LinearReluLinearModel,
            [quantize, linear_relu, ns.call_module(nnq.Linear), dequantize],
        ),
    )

    for model_cls, expected_sequence in cases:
        float_model = model_cls().eval()
        example_inputs = (torch.rand(5, 5),)
        prepared = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs)
        # run once on the example inputs before converting
        prepared(*example_inputs)
        quantized = convert_fx(prepared)
        self.checkGraphModuleNodes(quantized, expected_node_list=expected_sequence)
def test_fuse_custom_pattern(self):
    """Fuse a user-defined ``relu(add(maxpool(x), bn(conv(x))))`` pattern.

    Two complex-format patterns are registered, one for ``torch.add`` and one
    for ``operator.add``; the model is built once with each add flavour.
    The fuser method returns the conv, so after ``fuse_fx`` the model must
    still expose a plain ``Conv2d`` as ``conv`` and no longer have the
    ``bn`` / ``relu`` attributes.
    """

    class M(torch.nn.Module):
        def __init__(self, use_torch_add=True):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.bn = torch.nn.BatchNorm2d(3)
            self.relu = torch.nn.ReLU()
            self.maxpool = torch.nn.MaxPool2d(3)
            self.add = torch.add if use_torch_add else operator.add

        def forward(self, x):
            # residual branch (the original also had a dead `y = x` store
            # right before this line; it was overwritten immediately)
            y = self.maxpool(x)
            x = self.conv(x)
            x = self.bn(x)
            x = self.add(y, x)
            x = self.relu(x)
            return x

    def fuse_conv_bn_relu(is_qat, relu, add_pattern):
        # add_pattern is (add, <MatchAllNode match>, (bn, conv)); keep the conv
        _, _, (_, conv) = add_pattern
        return conv

    for use_torch_add in [True, False]:
        # subTest makes a failure report which add flavour broke
        with self.subTest(use_torch_add=use_torch_add):
            m = M(use_torch_add).eval()

            conv_bn_res_relu_config1 = BackendPatternConfig() \
                ._set_pattern_complex_format((nn.ReLU, (torch.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \
                .set_fuser_method(fuse_conv_bn_relu)
            conv_bn_res_relu_config2 = BackendPatternConfig() \
                ._set_pattern_complex_format((nn.ReLU, (operator.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \
                .set_fuser_method(fuse_conv_bn_relu)
            backend_config = BackendConfig() \
                .set_backend_pattern_config(conv_bn_res_relu_config1) \
                .set_backend_pattern_config(conv_bn_res_relu_config2)

            m = fuse_fx(m, backend_config=backend_config)

            self.assertEqual(type(m.conv), torch.nn.Conv2d)
            # check bn and relu are gone since we replaced the whole pattern to conv
            self.assertFalse(hasattr(m, "bn"))
            self.assertFalse(hasattr(m, "relu"))
def test_fusion_pattern_with_matchallnode(self):
    """Check that a node matched by MatchAllNode is regarded as an input
    rather than as a module to be fused.

    Two patterns are registered:

        (nn.ReLU, (torch.add, nn.Conv2d, MatchAllNode))
        (nn.Conv2d, nn.ReLU)

    and the model to fuse is:

        x -> Conv2d -> ReLU -----+
        x -> Conv2d ----------- Add -> ReLU

    The ReLU in the first row is matched as MatchAllNode in the residual
    pattern, but it must not be fused as part of that pattern. It needs to
    be fused with its own upstream Conv2d instead.
    """

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.relu1 = torch.nn.ReLU()
            self.conv2 = torch.nn.Conv2d(3, 3, 3)
            self.relu2 = torch.nn.ReLU()

        def forward(self, x):
            skip = self.relu1(self.conv1(x))
            main = self.conv2(x)
            return self.relu2(torch.add(main, skip))

    def fuse_conv_relu(is_qat, conv, relu):
        # plain conv - relu: keep the conv
        return conv

    def fuse_conv_res_relu(is_qat, relu, add_pattern):
        # add_pattern is (add, conv, <MatchAllNode match>): keep the conv
        _add, conv, _other = add_pattern
        return conv

    def conv_res_relu_root_node_getter(pattern):
        # pattern is (relu, (add, conv, extra)); the conv is the root node
        _relu, (_add, conv, _extra) = pattern
        return conv

    def conv_res_relu_extra_inputs_getter(pattern):
        # the MatchAllNode slot is returned as the single extra input
        _relu, (_add, _conv, extra_input) = pattern
        return [extra_input]

    conv_relu_config = BackendPatternConfig((nn.Conv2d, nn.ReLU))
    conv_relu_config = conv_relu_config.set_fuser_method(fuse_conv_relu)

    residual_pattern = (nn.ReLU, (torch.add, nn.Conv2d, MatchAllNode))
    conv_res_relu_config = (
        BackendPatternConfig()
        ._set_pattern_complex_format(residual_pattern)
        .set_fuser_method(fuse_conv_res_relu)
        ._set_root_node_getter(conv_res_relu_root_node_getter)
        ._set_extra_inputs_getter(conv_res_relu_extra_inputs_getter)
    )

    backend_config = (
        BackendConfig()
        .set_backend_pattern_config(conv_relu_config)
        .set_backend_pattern_config(conv_res_relu_config)
    )

    fused = fuse_fx(M().eval(), backend_config=backend_config)

    self.assertEqual(type(fused.conv1), torch.nn.Conv2d)
    self.assertEqual(type(fused.conv2), torch.nn.Conv2d)
    # check relu are gone since we replaced both patterns to conv
    self.assertFalse(hasattr(fused, "relu1"))
    self.assertFalse(hasattr(fused, "relu2"))
def _get_conv_linear_test_cases(self, is_reference):
    """Build the conv / linear test matrix shared by the quantization tests.

    Returns a list of tuples with the layout:
        (is_dynamic, ModuleClass, module_constructor_inputs,
         inputs, quantized_node, weight_prepack_node)

    ``is_reference`` selects which node is expected for each case: the
    float functional / ``nnqr`` reference module when True, the
    ``torch.ops.quantized`` op / ``nnq`` / ``nnqd`` module when False.
    """

    # NOTE: the functional calls below deliberately pass every argument
    # positionally, and the torch.rand calls keep their original order.

    class FunctionalConv1d(torch.nn.Module):
        def __init__(self, weight):
            super().__init__()
            self.weight = torch.nn.Parameter(weight)
            self.stride, self.padding, self.dilation, self.groups = 1, 0, 1, 1

        def forward(self, x):
            return F.conv1d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)

    class Conv1d(torch.nn.Module):
        def __init__(self, *args):
            super().__init__()
            self.conv = torch.nn.Conv1d(*args)

        def forward(self, x):
            return self.conv(x)

    conv1d_input = torch.rand(1, 3, 224)
    conv1d_weight = torch.rand(3, 3, 3)

    class FunctionalConv2d(torch.nn.Module):
        def __init__(self, weight):
            super().__init__()
            self.weight = torch.nn.Parameter(weight)
            self.stride, self.padding, self.dilation = (1, 1), (0, 0), (1, 1)
            self.groups = 1

        def forward(self, x):
            return F.conv2d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)

    class Conv2d(torch.nn.Module):
        def __init__(self, *args):
            super().__init__()
            self.conv = torch.nn.Conv2d(*args)

        def forward(self, x):
            return self.conv(x)

    conv2d_input = torch.rand(1, 3, 224, 224)
    conv2d_weight = torch.rand(3, 3, 3, 3)

    class FunctionalConv3d(torch.nn.Module):
        def __init__(self, weight):
            super().__init__()
            self.weight = torch.nn.Parameter(weight)
            self.stride, self.padding, self.dilation = (1, 1, 1), (0, 0, 0), (1, 1, 1)
            self.groups = 1

        def forward(self, x):
            return F.conv3d(
                x,
                self.weight,
                None,
                self.stride,
                self.padding,
                self.dilation,
                self.groups,
            )

    class Conv3d(torch.nn.Module):
        def __init__(self, *args):
            super().__init__()
            self.conv = torch.nn.Conv3d(*args)

        def forward(self, x):
            return self.conv(x)

    conv3d_input = torch.rand(1, 3, 32, 224, 224)
    conv3d_weight = torch.rand(3, 3, 3, 3, 3)

    class Linear(torch.nn.Module):
        def __init__(self, weight):
            super().__init__()
            self.weight = torch.nn.Parameter(weight)

        def forward(self, x):
            return F.linear(x, self.weight)

    linear_input = torch.rand(8, 5)
    linear_weight = torch.rand(10, 5)

    class LinearModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    linear_module_input = torch.rand(8, 5)

    # All conv modules are constructed with the same (in, out, kernel) args.
    conv_module_args = (3, 3, 3)

    # Resolve the expected targets once, depending on the conversion flavour.
    if is_reference:
        conv_fn_targets = (F.conv1d, F.conv2d, F.conv3d)
        conv_module_types = (nnqr.Conv1d, nnqr.Conv2d, nnqr.Conv3d)
        dynamic_linear_fn_node = None
        static_linear_fn_node = ns.call_function(F.linear)
        dynamic_linear_module_node = ns.call_module(nnqr.Linear)
        static_linear_module_node = ns.call_module(nnqr.Linear)
    else:
        conv_fn_targets = (
            torch.ops.quantized.conv1d,
            torch.ops.quantized.conv2d,
            torch.ops.quantized.conv3d,
        )
        conv_module_types = (nnq.Conv1d, nnq.Conv2d, nnq.Conv3d)
        dynamic_linear_fn_node = ns.call_function(torch.ops.quantized.linear_dynamic)
        static_linear_fn_node = ns.call_function(torch.ops.quantized.linear)
        dynamic_linear_module_node = ns.call_module(nnqd.Linear)
        static_linear_module_node = ns.call_module(nnq.Linear)

    functional_conv_specs = (
        (FunctionalConv1d, conv1d_weight, conv1d_input, torch.ops.quantized.conv1d_prepack),
        (FunctionalConv2d, conv2d_weight, conv2d_input, torch.ops.quantized.conv2d_prepack),
        (FunctionalConv3d, conv3d_weight, conv3d_input, torch.ops.quantized.conv3d_prepack),
    )
    module_conv_specs = (
        (Conv1d, conv1d_input),
        (Conv2d, conv2d_input),
        (Conv3d, conv3d_input),
    )

    # is_dynamic, ModuleClass, module_constructor_inputs,
    # inputs, quantized_node, weight_prepack_node
    tests = []
    for (fn_cls, weight, sample, prepack_op), fn_target in zip(functional_conv_specs, conv_fn_targets):
        tests.append(
            (False, fn_cls, (weight,), (sample,), ns.call_function(fn_target), ns.call_function(prepack_op))
        )
    for (mod_cls, sample), mod_type in zip(module_conv_specs, conv_module_types):
        tests.append(
            (False, mod_cls, conv_module_args, (sample,), ns.call_module(mod_type), None)
        )

    linear_prepack = torch.ops.quantized.linear_prepack
    tests.append(
        (True, Linear, (linear_weight,), (linear_input,), dynamic_linear_fn_node, ns.call_function(linear_prepack))
    )
    tests.append(
        (False, Linear, (linear_weight,), (linear_input,), static_linear_fn_node, ns.call_function(linear_prepack))
    )
    tests.append(
        (True, LinearModule, (), (linear_module_input,), dynamic_linear_module_node, None)
    )
    tests.append(
        (False, LinearModule, (), (linear_module_input,), static_linear_module_node, None)
    )
    return tests
@skipIfNoFBGEMM
def test_conv_linear_reference(self):
    """Test quantizing functional conv and linear with reference option.

    For every case from ``_get_conv_linear_test_cases(is_reference=True)``
    the reference model is checked three ways: the weight qparams exist on
    its ``linear`` / ``conv`` submodule, they survive ``copy.deepcopy``, and
    ``load_state_dict`` restores ``weight_scale`` after it is cleared.
    """
    tests = self._get_conv_linear_test_cases(is_reference=True)
    submodule_names = ("linear", "conv")

    def _get_keys(prefix, is_dynamic):
        # state_dict keys expected for the submodule called `prefix`
        all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]]
        if not is_dynamic:
            all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]])
        return all_keys

    # The helpers below do not depend on loop state, so they are defined once
    # here instead of on every iteration.

    def checkWeightQParams(model):
        # Inspect `model` itself. The earlier version read the enclosing
        # `qr` variable and silently ignored the argument it was given.
        for module_name in submodule_names:
            if hasattr(model, module_name):
                submodule = model.get_submodule(module_name)
                self.assertTrue(hasattr(submodule, "weight_qscheme"))
                self.assertTrue(hasattr(submodule, "weight_scale"))
                self.assertTrue(hasattr(submodule, "weight_zero_point"))
                self.assertIn("Reference", submodule._get_name())

    def checkSerDeser(model, is_dynamic):
        for module_name in submodule_names:
            if hasattr(model, module_name):
                # make sure serialization works
                state_dict = copy.deepcopy(model.state_dict())
                for key in _get_keys(module_name, is_dynamic):
                    self.assertIn(key, state_dict)
                # check load_state_dict restores states
                module = getattr(model, module_name)
                prev_scale = module.weight_scale
                module.weight_scale = None
                model.load_state_dict(state_dict)
                module = getattr(model, module_name)
                self.assertTrue(torch.equal(prev_scale, module.weight_scale))

    for (is_dynamic, ModuleClass, module_constructor_inputs,
         inputs, quantized_node, weight_prepack_node) in tests:
        quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
        node_occurrence = {}
        if weight_prepack_node:
            # the prepack op must not appear in the graph (expected count 0)
            node_occurrence[weight_prepack_node] = 0
        result_dict = self.checkGraphModeFxOp(
            ModuleClass(*module_constructor_inputs),
            inputs, quant_type,
            expected_node=quantized_node,
            expected_node_occurrence=node_occurrence,
            is_reference=True)
        qr = result_dict["quantized_reference"]

        checkWeightQParams(qr)
        qr = copy.deepcopy(qr)
        # make sure the qparams are preserved after copy
        checkWeightQParams(qr)

        checkSerDeser(qr, is_dynamic)
forward(self, x): y = self.deconv(x) if use_relu: y = self.relu(y) return y conv_transpose2d_input = torch.rand(1, 3, 224, 224) conv_transpose2d_weight = torch.rand(3, 3, 3, 3) conv_transpose2d_module_args = (3, 3, 3) class FunctionalConvTranspose3d(torch.nn.Module): def __init__(self, weight): super().__init__() self.weight = torch.nn.Parameter(weight) self.stride = (1, 1, 1) self.padding = (0, 0, 0) self.output_padding = (0, 0, 0) self.dilation = (1, 1, 1) self.groups = 1 def forward(self, x): y = F.conv_transpose3d( x, self.weight, None, self.stride, self.padding, self.output_padding, self.groups, self.dilation ) if use_relu: y = F.relu(y) return y class ConvTranspose3d(torch.nn.Module): def __init__(self, *args): super().__init__() self.deconv = torch.nn.ConvTranspose3d(*args) self.relu = torch.nn.ReLU() def forward(self, x): y = self.deconv(x) if use_relu: y = self.relu(y) return y conv_transpose3d_input = torch.rand(1, 3, 32, 224, 224) conv_transpose3d_weight = torch.rand(3, 3, 3, 3, 3) conv_transpose3d_module_args = (3, 3, 3) # is_dynamic, ModuleClass, module_constructor_inputs, # inputs, quantized_node, weight_prepack_node tests = [ ( False, FunctionalConvTranspose1d, (conv_transpose1d_weight,), (conv_transpose1d_input,), ns.call_function( torch.nn.functional.conv_transpose1d if is_reference else torch.ops.quantized.conv_transpose1d ), ns.call_function(torch.ops.quantized.conv_transpose1d_prepack), ), ( False, FunctionalConvTranspose2d, (conv_transpose2d_weight,), (conv_transpose2d_input,), ns.call_function( torch.nn.functional.conv_transpose2d if is_reference else torch.ops.quantized.conv_transpose2d ), ns.call_function(torch.ops.quantized.conv_transpose2d_prepack), ), ( False, FunctionalConvTranspose3d, (conv_transpose3d_weight,), (conv_transpose3d_input,), ns.call_function( torch.nn.functional.conv_transpose3d if is_reference else torch.ops.quantized.conv_transpose3d), ns.call_function(torch.ops.quantized.conv_transpose3d_prepack), ), ( False, 
ConvTranspose1d, conv_transpose1d_module_args, (conv_transpose1d_input,), ns.call_module(nnqr.ConvTranspose1d if is_reference else nnq.ConvTranspose1d), None ), ( False, ConvTranspose2d, conv_transpose2d_module_args, (conv_transpose2d_input,), ns.call_module(nnqr.ConvTranspose2d if is_reference else nnq.ConvTranspose2d), None ), ( False, ConvTranspose3d, conv_transpose3d_module_args, (conv_transpose3d_input,), ns.call_module(nnqr.ConvTranspose3d if is_reference else nnq.ConvTranspose3d), None ), ] return tests @skipIfNoFBGEMM def test_conv_transpose_not_reference(self): """ Test quantizing transposed conv """ tests = self._get_conv_transpose_test_cases(use_relu=False, is_reference=False) for (is_dynamic, ModuleClass, module_constructor_inputs, inputs, quantized_node, weight_prepack_node) in tests: quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC node_occurrence = {} if weight_prepack_node: node_occurrence[weight_prepack_node] = 0 self.checkGraphModeFxOp( ModuleClass(*module_constructor_inputs), inputs, quant_type, expected_node=quantized_node, expected_node_occurrence=node_occurrence, is_reference=False) @skipIfNoFBGEMM def test_conv_transpose_reference(self): """ Test quantizing transposed conv with reference option """ tests = self._get_conv_transpose_test_cases(use_relu=False, is_reference=True) def _get_keys(prefix, is_dynamic): all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]] if not is_dynamic: all_keys.extend([prefix + "." 
+ k for k in ["weight_scale", "weight_zero_point"]]) return all_keys for (is_dynamic, ModuleClass, module_constructor_inputs, inputs, quantized_node, weight_prepack_node) in tests: quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC node_occurrence = {} if weight_prepack_node: node_occurrence[weight_prepack_node] = 0 result_dict = self.checkGraphModeFxOp( ModuleClass(*module_constructor_inputs), inputs, quant_type, expected_node=quantized_node, expected_node_occurrence=node_occurrence, is_reference=True) qr = result_dict["quantized_reference"] def checkWeightQParams(model): module_name = "deconv" if hasattr(model, module_name): self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme")) self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale")) self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_zero_point")) self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name()) def checkSerDeser(model, is_dynamic): module_name = "deconv" if hasattr(model, module_name): # make sure serialization works state_dict = copy.deepcopy(model.state_dict()) all_keys = _get_keys(module_name, is_dynamic) for key in all_keys: self.assertTrue(key in state_dict) # check load_state_dict restores states module = getattr(model, module_name) prev_scale = module.weight_scale module.weight_scale = None model.load_state_dict(state_dict) module = getattr(model, module_name) self.assertTrue(torch.equal(prev_scale, module.weight_scale)) checkWeightQParams(qr) qr = copy.deepcopy(qr) # make sure the qparams are preserved after copy checkWeightQParams(qr) checkSerDeser(qr, is_dynamic) def test_conv_transpose_relu_not_reference(self): """ Test quantizing transposed conv + relu Fusion with relu is not supported. 
""" tests = self._get_conv_transpose_test_cases(use_relu=True, is_reference=False) for (is_dynamic, ModuleClass, module_constructor_inputs, inputs, quantized_node, weight_prepack_node) in tests: quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC node_occurrence = {} if weight_prepack_node: node_occurrence[weight_prepack_node] = 0 if quantized_node.op == 'call_module': node_occurrence[ns.call_module(nn.ReLU)] = 1 else: node_occurrence[ns.call_function(F.relu)] = 1 self.checkGraphModeFxOp( ModuleClass(*module_constructor_inputs), inputs, quant_type, expected_node=quantized_node, expected_node_occurrence=node_occurrence, is_reference=False) @skipIfNoFBGEMM def test_conv_transpose_relu_reference(self): """ Test quantizing transposed conv with reference option Fusion with relu is not supported. """ tests = self._get_conv_transpose_test_cases(use_relu=True, is_reference=True) def _get_keys(prefix, is_dynamic): all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]] if not is_dynamic: all_keys.extend([prefix + "." 
+ k for k in ["weight_scale", "weight_zero_point"]]) return all_keys for (is_dynamic, ModuleClass, module_constructor_inputs, inputs, quantized_node, weight_prepack_node) in tests: quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC node_occurrence = {} if weight_prepack_node: node_occurrence[weight_prepack_node] = 0 if quantized_node.op == 'call_module': node_occurrence[ns.call_module(nn.ReLU)] = 1 else: node_occurrence[ns.call_function(F.relu)] = 1 result_dict = self.checkGraphModeFxOp( ModuleClass(*module_constructor_inputs), inputs, quant_type, expected_node=quantized_node, expected_node_occurrence=node_occurrence, is_reference=True) qr = result_dict["quantized_reference"] def checkWeightQParams(model): module_name = "deconv" if hasattr(model, module_name): self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme")) self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale")) self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_zero_point")) self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name()) def checkSerDeser(model, is_dynamic): module_name = "deconv" if hasattr(model, module_name): # make sure serialization works state_dict = copy.deepcopy(model.state_dict()) all_keys = _get_keys(module_name, is_dynamic) for key in all_keys: self.assertTrue(key in state_dict) # check load_state_dict restores states module = getattr(model, module_name) prev_scale = module.weight_scale module.weight_scale = None model.load_state_dict(state_dict) module = getattr(model, module_name) self.assertTrue(torch.equal(prev_scale, module.weight_scale)) checkWeightQParams(qr) qr = copy.deepcopy(qr) # make sure the qparams are preserved after copy checkWeightQParams(qr) checkSerDeser(qr, is_dynamic) @skipIfNoFBGEMM def test_dynamic_quant_weight_observer(self): ''' Test that weight observer is run in convert step ''' class M(torch.nn.Module): def __init__(self, weight): super().__init__() self.weight = 
torch.nn.Parameter(weight) def forward(self, x): return F.linear(x, self.weight) m = M(torch.rand(1, 1)).eval() qconfig = default_dynamic_qconfig qconfig_dict = {'': qconfig} example_inputs = (torch.rand(1, 1),) prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) quantized = convert_to_reference_fx(prepared) qparams = (quantized._scale_0, quantized._zero_point_0) weight_obs = qconfig.weight() weight_obs(quantized.weight) # Get the actual value to avoid tensor size mismatch error, torch.Size([]) vs torch.Size([1]) ref_qparams = (weight_obs.calculate_qparams()[0].item(), weight_obs.calculate_qparams()[1].item()) self.assertEqual(qparams, ref_qparams) def test_conv_bn_relu(self): """ Tests fusion and quantization for "Conv - Bn" and "Conv - Bn - ReLU" """ convs = { 1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d, } bns = { 1: nn.BatchNorm1d, 2: nn.BatchNorm2d, 3: nn.BatchNorm3d, } quantized_convs = { 1: nnq.Conv1d, 2: nnq.Conv2d, 3: nnq.Conv3d, } quantized_conv_relus = { 1: nniq.ConvReLU1d, 2: nniq.ConvReLU2d, 3: nniq.ConvReLU3d, } class M(torch.nn.Module): def __init__(self, dim, has_relu): super().__init__() self.conv = convs[dim](3, 3, 3) self.bn = bns[dim](3) self.relu = nn.ReLU() if has_relu else nn.Identity() self.has_relu = has_relu self.quant = QuantStub() self.dequant = DeQuantStub() def forward(self, x): x = self.quant(x) x = self.conv(x) x = self.bn(x) if self.has_relu: x = self.relu(x) x = self.dequant(x) return x options = itertools.product([1, 2, 3], [True, False], self.static_quant_types) for dim, has_relu, quant_type in options: expected_node = ns.call_module( quantized_conv_relus[dim] if has_relu else quantized_convs[dim]) m = M(dim, has_relu) m_eager = copy.deepcopy(m) result_dict = self.checkGraphModeFxOp( m, self.img_data_dict[dim], quant_type, expected_node=expected_node, ) result = result_dict["quantized_output"] # check numerics qengine = torch.backends.quantized.engine if quant_type == QuantType.STATIC: m_eager.eval() qconfig = 
get_default_qconfig(qengine) prepare_fn = prepare is_qat = False else: m_eager.train() qconfig = get_default_qat_qconfig(qengine) prepare_fn = prepare_qat is_qat = True fuse_list = ["conv", "bn"] if has_relu: fuse_list.append("relu") if is_qat: fuse_modules_qat(m_eager, fuse_list, inplace=True) else: fuse_modules(m_eager, fuse_list, inplace=True) m_eager.qconfig = qconfig m_eager = prepare_fn(m_eager) prepared_fx = result_dict["prepared"] m_eager(*self.img_data_dict[dim][0]) m_eager = convert(m_eager) result_eager = m_eager(*self.img_data_dict[dim][0]) self.assertEqual(result, result_eager) def test_linear_bn(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = nn.Linear(4, 4) self.bn = nn.BatchNorm1d(4) self.quant = QuantStub() self.dequant = DeQuantStub() def forward(self, x): x = self.quant(x) x = self.linear(x) x = self.bn(x) x = self.dequant(x) return x data = (torch.randn(4, 4),) for quant_type in self.static_quant_types: expected_node = ns.call_module(nnq.Linear) m = M() m_eager = copy.deepcopy(m) result_dict = self.checkGraphModeFxOp(m, data, quant_type, expected_node=expected_node) result = result_dict["quantized_output"] # check numerics vs eager mode fuse_list = ["linear", "bn"] qengine = torch.backends.quantized.engine if quant_type == QuantType.STATIC: m_eager.eval() qconfig = get_default_qconfig(qengine) prepare_fn = prepare fuse_modules(m_eager, fuse_list, inplace=True) else: m_eager.train() qconfig = get_default_qat_qconfig(qengine) prepare_fn = prepare_qat fuse_modules_qat(m_eager, fuse_list, inplace=True) m_eager.qconfig = qconfig m_eager = prepare_fn(m_eager) m_eager(*data) m_eager = convert(m_eager) result_eager = m_eager(*data) self.assertEqual(result, result_eager) @skipIfNoFBGEMM def test_dynamic_quant_fp16(self): with override_quantized_engine('fbgemm'): class Linear(torch.nn.Module): def __init__(self, weight): super().__init__() self.weight = torch.nn.Parameter(weight) def forward(self, x): return 
F.linear(x, self.weight) linear_input = torch.rand(8, 5) linear_weight = torch.rand(10, 5) class LinearModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(5, 10) def forward(self, x): return self.linear(x) linear_module_input = torch.rand(8, 5) tests = [ (Linear, (linear_weight,), (linear_input,), ns.call_function(torch.ops.quantized.linear_dynamic_fp16), ns.call_function(torch.ops.quantized.linear_prepack_fp16)), (LinearModule, (), (linear_module_input,), ns.call_module(nnqd.Linear), None), ] for (ModuleClass, module_constructor_inputs, inputs, quantized_node, weight_prepack_node) in tests: for is_reference in [True, False]: node_occurrence = {} if weight_prepack_node: node_occurrence[weight_prepack_node] = 0 m = ModuleClass(*module_constructor_inputs).eval() qconfig_dict = {"": float16_dynamic_qconfig} m = prepare_fx(m, qconfig_dict, example_inputs=inputs) convert_fn = convert_to_reference_fx if is_reference else convert_fx m = convert_fn(m) self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") @override_qengines def test_qat_prepare_device_affinity(self): """ Tests that FX QAT prepare pass respects device affinity """ class Model(nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(1, 1, 1) self.bn = nn.BatchNorm2d(1) self.relu = nn.ReLU() def forward(self, x): x = self.conv(x) x = self.bn(x) x = self.relu(x) return x model = Model() qengine = torch.backends.quantized.engine qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig(qengine)} device = torch.device('cuda:0') model.to(device) example_inputs = (torch.randn(4, 1, 4, 4, device=device),) # QAT prepare model = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) # ensure that running an input on CUDA works without any needed changes model(*example_inputs) # ensure all 
buffers and parameters are on the device we expect model_devices = {p.device for p in model.parameters()} | \ {p.device for p in model.buffers()} self.assertEqual(len(model_devices), 1) model_device = next(iter(model_devices)) self.assertEqual(model_device, device) @skipIfNoFBGEMM def test_dict_output(self): """ Make sure quantization runs for models with dictionary output """ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 1, 1) def forward(self, x): return {"output": self.conv(x["input"])} example_inputs = ({"input": torch.randn(1, 1, 1, 1)},) m = M().eval() qconfig_dict = {"": default_qconfig} m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) m(*example_inputs) m = convert_fx(m) m(*example_inputs) @override_qengines def test_attention(self): """ Make sure quantization runs for a corner case in attention module """ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 1, 1) def forward(self, x): x = self.conv(x) q, k, v = x.chunk(3, dim=0) q = q.contiguous().view(-1, 1).transpose(0, 1) k = k.contiguous().view(-1, 1).transpose(0, 1) v = v.contiguous().view(-1, 1).transpose(0, 1) torch._assert( k.size(1) == 1, "key size should be equal to 1" ) r = torch.mm(k, v) return q * k + r example_inputs = (torch.randn(3, 1, 1, 1),) m = M().eval() qconfig_dict = { "": None, "object_type": [ (nn.Conv2d, default_qconfig), ] } # make sure it runs m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) m(*example_inputs) m = convert_fx(m) m(*example_inputs) def _test_standalone_module( self, interface_config, prepare_count_check, standalone_prepare_count_check, convert_count_check, standalone_convert_count_check): """ Test standalone module with different quantized input/quantized output configurations """ class StandaloneModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 1, 1) def forward(self, x): 
return self.conv(x) class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 1, 1) self.standalone = StandaloneModule() def forward(self, x): x = self.conv(x) x = self.standalone(x) return x class RefM(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = torch.nn.Conv2d(1, 1, 1) self.conv2 = torch.nn.Conv2d(1, 1, 1) def forward(self, x): x = self.conv1(x) x = self.conv2(x) return x example_inputs = (torch.randn(1, 1, 1, 1),) # instantiate M and RefM and align the parameters original_m = M().eval() original_ref_m = RefM().eval() original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach()) original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach()) original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach()) original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach()) for is_name in [True, False]: sm_example_inputs = example_inputs if is_name: prepare_config = { "standalone_module_name": [("standalone", None, sm_example_inputs, interface_config, None)] } else: prepare_config = { "standalone_module_class": [(StandaloneModule, None, sm_example_inputs, interface_config, None)] } original_m_copy = copy.deepcopy(original_m) original_ref_m_copy = copy.deepcopy(original_ref_m) qconfig_dict = {"": default_qconfig} # check prepared model m = prepare_fx( original_m_copy, qconfig_dict, example_inputs=example_inputs, prepare_custom_config=prepare_config) # calibration m(*example_inputs) self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check) self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check) # check converted/quantized model m = convert_fx(m) self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check) self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check) res = m(*example_inputs) # 
quantize the reference model ref_m = prepare_fx( original_ref_m_copy, qconfig_dict, example_inputs=example_inputs, ) ref_m(*example_inputs) ref_m = convert_fx(ref_m) ref_res = ref_m(*example_inputs) self.assertEqual(res, ref_res) def test_standalone_module_float_interface(self): float_interface_config = { "input_quantized_idxs": [], # float input "output_quantized_idxs": [], # float output } interface_config = float_interface_config # input and output of first conv, observer for standalone module # will be inserted in the standalone module itself prepare_count_check = { ns.call_module(torch.ao.quantization.MinMaxObserver): 2 } # for input and output of conv in the standalone module standalone_prepare_count_check = { ns.call_module(torch.ao.quantization.MinMaxObserver): 2 } convert_count_check = { ns.call_function(torch.quantize_per_tensor) : 1, ns.call_module(nnq.Conv2d) : 1, ns.call_method("dequantize") : 1, } standalone_convert_count_check = { # standalone module will take float as input and output # so we'll see quantize and dequantize in the module ns.call_function(torch.quantize_per_tensor) : 1, ns.call_module(nnq.Conv2d): 1, ns.call_method("dequantize") : 1, } self._test_standalone_module( interface_config, prepare_count_check, standalone_prepare_count_check, convert_count_check, standalone_convert_count_check) def test_standalone_module_quantized_interface(self): quantized_interface_config = { "input_quantized_idxs": [0], # quantized input "output_quantized_idxs": [0], # quantized output } interface_config = quantized_interface_config # observer for input and output of first conv prepare_count_check = { ns.call_module(torch.ao.quantization.MinMaxObserver): 2 } # for output of conv in the standalone module standalone_prepare_count_check = { ns.call_module(torch.ao.quantization.MinMaxObserver): 1 } convert_count_check = { # quantizing input for conv ns.call_function(torch.quantize_per_tensor) : 1, ns.call_module(nnq.Conv2d) : 1, # dequantizing output of 
standalone module ns.call_method("dequantize") : 1, } standalone_convert_count_check = { # quantization of input happens in parent module # quantization of output happens in the quantized conv module ns.call_function(torch.quantize_per_tensor) : 0, ns.call_module(nnq.Conv2d): 1, # dequantization for output happens in parent module ns.call_method("dequantize") : 0, } self._test_standalone_module( interface_config, prepare_count_check, standalone_prepare_count_check, convert_count_check, standalone_convert_count_check) @skipIfNoFBGEMM def test_qconfig_none(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = nn.Conv2d(1, 1, 1) self.conv2 = nn.Conv2d(1, 1, 1) def forward(self, x): x = self.conv1(x) x = self.conv2(x) return x m = M().eval() qconfig_dict = {"": default_qconfig, "module_name": [("conv2", None)]} example_inputs = (torch.randn(1, 1, 1, 1),) m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) m(*example_inputs) m = convert_fx(m) m(*example_inputs) # first conv is quantized, second conv is not quantized node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_module(nnq.Conv2d), ns.call_method("dequantize"), ns.call_module(nn.Conv2d), ] self.checkGraphModuleNodes(m, expected_node_list=node_list) def test_qconfig_module_type(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(1, 1, 1) self.linear = nn.Linear(9, 3) def forward(self, x): x = self.conv(x) x = x.reshape((1, -1)) x = self.linear(x) return x m = M().eval() qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]} example_inputs = (torch.randn(1, 1, 3, 3),) m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) m(*example_inputs) m = convert_fx(m) m(*example_inputs) # conv is quantized, linear is not quantized node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_module(nnq.Conv2d), ns.call_method("dequantize"), ns.call_module(nn.Linear), ] 
self.checkGraphModuleNodes(m, expected_node_list=node_list) def test_qconfig_qat_module_type(self): class LinearRelu(nn.Sequential): def __init__(self) -> None: super().__init__( nn.Linear(5, 5), nn.ReLU(), ) class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.lin_relu = LinearRelu() self.linear = nn.Linear(5, 5) def forward(self, x): x = self.lin_relu(x) x = self.linear(x) return x model = M().train() qconfig_dict = { "": None, "object_type": [ (torch.nn.Linear, default_qat_qconfig), (torch.nn.ReLU, default_qat_qconfig), ], } example_inputs = (torch.rand(5, 5),) m = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) m(*example_inputs) m = convert_fx(m) m(*example_inputs) node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_module(nniq.LinearReLU), ns.call_module(nnq.Linear), ns.call_method("dequantize"), ] self.checkGraphModuleNodes(m, expected_node_list=node_list) def test_qconfig_function(self): class M(torch.nn.Module): def forward(self, x, y): return x + y m = M().eval() qconfig_dict = {"object_type": [(operator.add, default_qconfig)]} data = torch.randn(1, 1, 1, 1) example_inputs = (data, data) m = prepare_fx(m, qconfig_dict, example_inputs) m(*example_inputs) m = convert_fx(m) m(*example_inputs) # operator.add is quantized via the object_type qconfig node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_function(torch.ops.quantized.add), ns.call_method("dequantize"), ] self.checkGraphModuleNodes(m, expected_node_list=node_list) def test_qconfig_module_name_regex(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = nn.Conv2d(1, 1, 1) self.conv2 = nn.Conv2d(1, 1, 1) def forward(self, x): x = self.conv1(x) x = self.conv2(x) return x m = M().eval() qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]} example_inputs = (torch.randn(1, 1, 1, 1),) m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) m(*example_inputs) m = convert_fx(m) 
m(*example_inputs) # both convs match the regex, so both are quantized node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_module(nnq.Conv2d), ns.call_module(nnq.Conv2d), ns.call_method("dequantize"), ] self.checkGraphModuleNodes(m, expected_node_list=node_list) def test_qconfig_precedence(self): for device in get_supported_device_types(): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = nn.Linear(1, 1) self.conv = nn.Conv2d(1, 1, 1) self.module_conv1 = nn.Conv2d(1, 1, 1) self.module_conv2 = nn.Conv2d(1, 1, 1) def forward(self, x): # global x = self.linear(x) # global + object_type --> object_type x = self.conv(x) # global + object_type + module_name_regex --> module_name_regex x = self.module_conv1(x) # global + object_type + module_name_regex + module_name --> module_name x = self.module_conv2(x) return x m = M().to(device).eval() global_qconfig = default_qconfig object_type_qconfig = default_dynamic_qconfig module_name_regex_qconfig = float16_dynamic_qconfig module_name_qconfig = default_qat_qconfig qconfig_dict = { "": global_qconfig, "object_type": [(nn.Conv2d, object_type_qconfig)], "module_name_regex": [("module_conv*", module_name_regex_qconfig)], "module_name": [("module_conv2", module_name_qconfig)]} m_prep = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1),)) self.assertEqual(m_prep.linear.qconfig.activation.p.func, global_qconfig.activation.p.func) self.assertEqual(m_prep.linear.qconfig.weight.p.func, global_qconfig.weight.p.func) self.assertEqual(m_prep.conv.qconfig.activation.p.func, object_type_qconfig.activation.p.func) self.assertEqual(m_prep.conv.qconfig.weight.p.func, object_type_qconfig.weight.p.func) self.assertEqual(m_prep.module_conv1.qconfig.activation.p.func, module_name_regex_qconfig.activation.p.func) self.assertEqual(m_prep.module_conv1.qconfig.weight.p.func, module_name_regex_qconfig.weight.p.func) self.assertEqual(m_prep.module_conv2.qconfig.activation.p.func, 
module_name_qconfig.activation.p.func) self.assertEqual(m_prep.module_conv2.qconfig.weight.p.func, module_name_qconfig.weight.p.func) def test_qconfig_module_name_object_type_order(self): class M1(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc1 = nn.Linear(1, 1) self.fc2 = nn.Linear(1, 1) def forward(self, x): x = self.fc1(x) x = self.fc2(x) x = torch.add(x, x) x = torch.add(x, x) return x class M2(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc1 = nn.Linear(1, 1) self.fc2 = nn.Linear(1, 1) self.m1 = M1() def forward(self, x): x = self.fc1(x) x = self.fc2(x) x = torch.add(x, x) x = torch.add(x, x) x = self.m1(x) return x class M3(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc1 = nn.Linear(1, 1) self.fc2 = nn.Linear(1, 1) self.m2 = M2() def forward(self, x): x = self.fc1(x) x = self.fc2(x) x = torch.add(x, x) x = torch.add(x, x) x = self.m2(x) return x m = M3().eval() qconfig_dict = { "module_name_object_type_order": [ # test various FQNs: global, single child, multiple children ("", nn.Linear, 0, torch.ao.quantization.default_qconfig), ("", torch.add, 0, torch.ao.quantization.default_qconfig), ("m2", nn.Linear, 1, torch.ao.quantization.default_qconfig), ("m2", torch.add, 1, torch.ao.quantization.default_qconfig), ("m2.m1", nn.Linear, 0, torch.ao.quantization.default_qconfig), ("m2.m1", torch.add, 0, torch.ao.quantization.default_qconfig), ], } example_inputs = (torch.randn(1, 1, 1, 1),) m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) m(*example_inputs) m = convert_fx(m) m(*example_inputs) node_list = [ # m3 ns.call_function(torch.quantize_per_tensor), ns.call_module(nnq.Linear), ns.call_method("dequantize"), ns.call_module(nn.Linear), ns.call_function(torch.quantize_per_tensor), ns.call_function(torch.ops.quantized.add), ns.call_method("dequantize"), ns.call_function(torch.add), # m2 ns.call_module(nn.Linear), ns.call_function(torch.quantize_per_tensor), 
ns.call_module(nnq.Linear), ns.call_method("dequantize"), ns.call_function(torch.add), ns.call_function(torch.quantize_per_tensor), ns.call_function(torch.ops.quantized.add), # m1 ns.call_module(nnq.Linear), ns.call_method("dequantize"), ns.call_module(nn.Linear), ns.call_function(torch.quantize_per_tensor), ns.call_function(torch.ops.quantized.add), ns.call_method("dequantize"), ns.call_function(torch.add), ] self.checkGraphModuleNodes(m, expected_node_list=node_list) # test that function order overrides global qconfig class M4(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc1 = nn.Linear(1, 1) self.fc2 = nn.Linear(1, 1) def forward(self, x): x = self.fc1(x) x = self.fc2(x) x = torch.add(x, x) x = torch.add(x, x) return x m = M4().eval() qconfig_dict = { "": torch.ao.quantization.default_qconfig, "module_name_object_type_order": [ ("", nn.Linear, 1, None), ("", torch.add, 1, None), ], } example_inputs = (torch.randn(1, 1, 1, 1),) m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) m(*example_inputs) m = convert_fx(m) m(*example_inputs) node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_module(nnq.Linear), ns.call_method("dequantize"), ns.call_module(nn.Linear), ns.call_function(torch.quantize_per_tensor), ns.call_function(torch.ops.quantized.add), ns.call_method("dequantize"), ns.call_function(torch.add), ] self.checkGraphModuleNodes(m, expected_node_list=node_list) @override_qengines def test_qconfig_dict_with_fused_modules(self): class LinearReLUModel(torch.nn.Module): def __init__(self, relu): super().__init__() self.linear = torch.nn.Linear(3, 3) self.relu = relu def forward(self, x): x = self.linear(x) x = self.relu(x) return x class ConvReLUModel(torch.nn.Module): def __init__(self, relu): super().__init__() self.conv = torch.nn.Conv1d(3, 3, 3) self.relu = relu def forward(self, x): x = self.conv(x) x = self.relu(x) return x class ConvBnReLUModel(torch.nn.Module): def __init__(self, relu): super().__init__() 
self.conv = torch.nn.Conv1d(3, 3, 3) self.bn = torch.nn.BatchNorm1d(3) self.relu = relu def forward(self, x): x = self.conv(x) x = self.bn(x) x = self.relu(x) return x for model in [LinearReLUModel, ConvReLUModel, ConvBnReLUModel]: for relu in [torch.nn.ReLU(), torch.nn.functional.relu, torch.relu]: m = model(relu).eval() qengine = torch.backends.quantized.engine qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping(qengine) # should not crash as in https://github.com/pytorch/pytorch/issues/75825 prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),)) # TODO: move QConfigMapping tests to test/quantization/core def test_qconfig_mapping_set_global(self): qconfig = get_default_qconfig() qconfig_mapping = QConfigMapping() self.assertEqual(qconfig_mapping.global_qconfig, None) qconfig_mapping.set_global(qconfig) self.assertEqual(qconfig_mapping.global_qconfig, qconfig) def test_qconfig_mapping_set_object_type(self): qconfig1 = get_default_qconfig() qconfig2 = get_default_qconfig() qconfig3 = get_default_qconfig() self.assertNotEqual(qconfig1, qconfig2) self.assertNotEqual(qconfig1, qconfig3) qconfig_mapping = QConfigMapping() self.assertEqual(len(qconfig_mapping.object_type_qconfigs), 0) # Insert some entries qconfig_mapping.set_object_type(torch.nn.Linear, qconfig1) qconfig_mapping.set_object_type(torch.nn.ReLU, qconfig2) self.assertEqual(len(qconfig_mapping.object_type_qconfigs), 2) self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.Linear], qconfig1) self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.ReLU], qconfig2) # Override existing key qconfig_mapping.set_object_type(torch.nn.Linear, qconfig3) self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.Linear], qconfig3) self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.ReLU], qconfig2) self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.Linear, None), qconfig3) self.assertEqual(_get_object_type_qconfig(qconfig_mapping, 
torch.nn.ReLU, None), qconfig2) self.assertEqual(_get_object_type_qconfig(qconfig_mapping, "nomatch", None), None) def test_qconfig_mapping_set_module_name_regex(self): qconfig1 = get_default_qconfig() qconfig2 = get_default_qconfig() qconfig3 = get_default_qconfig() self.assertNotEqual(qconfig1, qconfig2) self.assertNotEqual(qconfig1, qconfig3) qconfig_mapping = QConfigMapping() self.assertEqual(len(qconfig_mapping.module_name_regex_qconfigs), 0) # Insert some entries qconfig_mapping.set_module_name_regex("foo.*bar", qconfig1) qconfig_mapping.set_module_name_regex("foo.*", qconfig2) self.assertEqual(len(qconfig_mapping.module_name_regex_qconfigs), 2) self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*bar"], qconfig1) self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*"], qconfig2) # Override existing key qconfig_mapping.set_module_name_regex("foo.*bar", qconfig3) self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*bar"], qconfig3) self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*"], qconfig2) self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo123bar", None), qconfig3) self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobar", None), qconfig3) self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobaz", None), qconfig2) self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo", None), qconfig2) self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "nomatch", None), None) def test_qconfig_mapping_set_module_name(self): qconfig1 = get_default_qconfig() qconfig2 = get_default_qconfig() qconfig3 = get_default_qconfig() self.assertNotEqual(qconfig1, qconfig2) self.assertNotEqual(qconfig1, qconfig3) qconfig_mapping = QConfigMapping() self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 0) # Insert some entries qconfig_mapping.set_module_name("mod1", qconfig1) qconfig_mapping.set_module_name("mod2", qconfig2) 
self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 2) self.assertEqual(qconfig_mapping.module_name_qconfigs["mod1"], qconfig1) self.assertEqual(qconfig_mapping.module_name_qconfigs["mod2"], qconfig2) # Override existing key qconfig_mapping.set_module_name("mod1", qconfig3) self.assertEqual(qconfig_mapping.module_name_qconfigs["mod1"], qconfig3) self.assertEqual(qconfig_mapping.module_name_qconfigs["mod2"], qconfig2) self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod1", None), qconfig3) self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod2", None), qconfig2) self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "nomatch", None), None) def test_qconfig_mapping_set_module_name_object_type_order(self): qconfig1 = get_default_qconfig() qconfig2 = get_default_qconfig() qconfig3 = get_default_qconfig() self.assertNotEqual(qconfig1, qconfig2) self.assertNotEqual(qconfig1, qconfig3) qconfig_mapping = QConfigMapping() self.assertEqual(len(qconfig_mapping.module_name_object_type_order_qconfigs), 0) # Insert some entries qconfig_mapping.set_module_name_object_type_order("mod1", torch.nn.Linear, 0, qconfig1) qconfig_mapping.set_module_name_object_type_order("mod2", torch.nn.ReLU, 1, qconfig2) self.assertEqual(len(qconfig_mapping.module_name_object_type_order_qconfigs), 2) key1 = ("mod1", torch.nn.Linear, 0) key2 = ("mod2", torch.nn.ReLU, 1) self.assertEqual(next(iter(qconfig_mapping.module_name_object_type_order_qconfigs)), key1) self.assertEqual(list(qconfig_mapping.module_name_object_type_order_qconfigs)[1], key2) self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key1], qconfig1) self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key2], qconfig2) self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( qconfig_mapping, "mod1", torch.nn.Linear, 0, None), qconfig1) self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( qconfig_mapping, "mod2", torch.nn.ReLU, 1, 
None), qconfig2) # Override existing key qconfig_mapping.set_module_name_object_type_order("mod1", torch.nn.Linear, 0, qconfig3) self.assertEqual(len(qconfig_mapping.module_name_object_type_order_qconfigs), 2) self.assertEqual(next(iter(qconfig_mapping.module_name_object_type_order_qconfigs)), key1) self.assertEqual(list(qconfig_mapping.module_name_object_type_order_qconfigs)[1], key2) self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key1], qconfig3) self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key2], qconfig2) self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( qconfig_mapping, "mod1", torch.nn.Linear, 0, None), qconfig3) self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( qconfig_mapping, "mod2", torch.nn.ReLU, 1, None), qconfig2) # No match self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( qconfig_mapping, "mod123", torch.nn.Linear, 0, None), None) self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( qconfig_mapping, "mod1", torch.nn.Linear, 35, None), None) self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( qconfig_mapping, "mod2", torch.nn.Conv2d, 1, None), None) def _get_qconfig_dict_for_qconfig_mapping_test(self, global_qconfig, qconfig1, qconfig2): """ Return a dummy qconfig_dict to test QConfigMapping's to_dict and from_dict methods. 
""" return { _GLOBAL_DICT_KEY: global_qconfig, _OBJECT_TYPE_DICT_KEY: [ (torch.nn.Linear, qconfig1), (torch.nn.ReLU, qconfig2), ], _MODULE_NAME_REGEX_DICT_KEY: [ ("foo.*bar", qconfig1), ("foo.*", qconfig2), ], _MODULE_NAME_DICT_KEY: [ ("bazbaz", qconfig1), ("borbor", qconfig2), ], _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ ("bazbaz", torch.nn.Linear, 0, qconfig1), ("foofoo", torch.nn.ReLU, 1, qconfig2), ], } with self.assertRaises(ValueError) as context: m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),)) # noqa: F821 self.assertTrue( 'Expected qconfig_dict to have the following keys:' in str(context.exception) ) self.assertTrue('But found \'object_typo\' instead.' in str(context.exception)) def test_qconfig_mapping_from_dict(self): global_qconfig = QConfig(123, "global") qconfig1 = QConfig(1, "one") qconfig2 = QConfig(2, "two") qconfig_dict = self._get_qconfig_dict_for_qconfig_mapping_test(global_qconfig, qconfig1, qconfig2) qconfig_dict["undefined_dict_key"] = [(123, qconfig1), (234, qconfig2)] qconfig_mapping = QConfigMapping.from_dict(qconfig_dict) self.assertEqual(qconfig_mapping.global_qconfig, global_qconfig) self.assertEqual(qconfig_mapping.object_type_qconfigs, OrderedDict({ torch.nn.Linear: qconfig1, torch.nn.ReLU: qconfig2, })) self.assertEqual(qconfig_mapping.module_name_regex_qconfigs, OrderedDict({ "foo.*bar": qconfig1, "foo.*": qconfig2, })) self.assertEqual(qconfig_mapping.module_name_qconfigs, OrderedDict({ "bazbaz": qconfig1, "borbor": qconfig2, })) self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs, OrderedDict({ ("bazbaz", torch.nn.Linear, 0): qconfig1, ("foofoo", torch.nn.ReLU, 1): qconfig2, })) def test_qconfig_mapping_to_dict(self): global_qconfig = QConfig(123, "global") qconfig1 = QConfig(1, "one") qconfig2 = QConfig(2, "two") qconfig_mapping = QConfigMapping().set_global(global_qconfig) \ .set_object_type(torch.nn.Linear, qconfig1) \ .set_object_type(torch.nn.ReLU, qconfig2) \ 
.set_module_name_regex("foo.*bar", qconfig1) \ .set_module_name_regex("foo.*", qconfig2) \ .set_module_name("bazbaz", qconfig1) \ .set_module_name("borbor", qconfig2) \ .set_module_name_object_type_order("bazbaz", torch.nn.Linear, 0, qconfig1) \ .set_module_name_object_type_order("foofoo", torch.nn.ReLU, 1, qconfig2) qconfig_dict = self._get_qconfig_dict_for_qconfig_mapping_test(global_qconfig, qconfig1, qconfig2) self.assertEqual(qconfig_mapping.to_dict(), qconfig_dict) def test_qconfig_mapping_repr(self): self.assertTrue(isinstance(get_default_qconfig_mapping().__repr__(), str)) def test_default_qconfig_mapping_override_global(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 1, 1) def forward(self, x): return self.conv(x) m = M().eval() my_qconfig = QConfig(activation=MinMaxObserver, weight=default_weight_observer) qconfig_mapping = get_default_qconfig_mapping() # Override global qconfig old_global_qconfig = qconfig_mapping.global_qconfig qconfig_mapping.set_global(my_qconfig) # Verify the correct qconfig was used example_inputs = (torch.randn(1, 1, 1, 1),) m = prepare_fx(m, qconfig_mapping, example_inputs) self.assertTrue(isinstance(old_global_qconfig.activation(), HistogramObserver)) self.assertTrue(isinstance(my_qconfig.activation(), MinMaxObserver)) self.assertTrue(hasattr(m, "activation_post_process_0")) self.assertTrue(hasattr(m, "activation_post_process_1")) self.assertTrue(isinstance(m.activation_post_process_0, MinMaxObserver)) self.assertTrue(isinstance(m.activation_post_process_1, MinMaxObserver)) # Dummy classes for PrepareCustomConfig testing class _DummyStandaloneModule: pass class _DummyFloatModule: pass class _DummyObservedModule: pass class _DummyQuantizedModule: pass class _DummyNonTraceableModule1: pass class _DummyNonTraceableModule2: pass def test_prepare_custom_config_set_standalone_module_name(self): qconfig_mapping = QConfigMapping() example_inputs = (torch.randn(3),) 
child_prepare_custom_config = PrepareCustomConfig() backend_config = BackendConfig("my_backend") config_entry = StandaloneModuleConfigEntry( qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config) prepare_custom_config = PrepareCustomConfig() self.assertEqual(len(prepare_custom_config.standalone_module_names), 0) prepare_custom_config.set_standalone_module_name( "module1", qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config) self.assertEqual(list(prepare_custom_config.standalone_module_names.keys()), ["module1"]) self.assertEqual(prepare_custom_config.standalone_module_names["module1"], config_entry) def test_prepare_custom_config_set_standalone_module_class(self): qconfig_mapping = QConfigMapping() example_inputs = (torch.randn(3),) child_prepare_custom_config = PrepareCustomConfig() backend_config = BackendConfig("my_backend") config_entry = StandaloneModuleConfigEntry( qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config) prepare_custom_config = PrepareCustomConfig() self.assertEqual(len(prepare_custom_config.standalone_module_classes), 0) prepare_custom_config.set_standalone_module_class( self._DummyStandaloneModule, qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config) self.assertEqual(len(prepare_custom_config.standalone_module_classes), 1) self.assertTrue(self._DummyStandaloneModule in prepare_custom_config.standalone_module_classes) self.assertEqual(prepare_custom_config.standalone_module_classes[self._DummyStandaloneModule], config_entry) def test_prepare_custom_config_set_float_to_observed_mapping(self): prepare_custom_config = PrepareCustomConfig() self.assertEqual(len(prepare_custom_config.float_to_observed_mapping), 0) prepare_custom_config.set_float_to_observed_mapping(self._DummyFloatModule, self._DummyObservedModule, QuantType.STATIC) self.assertEqual(len(prepare_custom_config.float_to_observed_mapping), 1) 
self.assertEqual(list(prepare_custom_config.float_to_observed_mapping.keys()), [QuantType.STATIC]) self.assertEqual(len(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC]), 1) self.assertTrue(self._DummyFloatModule in prepare_custom_config.float_to_observed_mapping[QuantType.STATIC]) self.assertEqual(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC][self._DummyFloatModule], self._DummyObservedModule) def test_prepare_custom_config_set_non_traceable_module_names(self): prepare_custom_config = PrepareCustomConfig() self.assertEqual(len(prepare_custom_config.non_traceable_module_names), 0) prepare_custom_config.set_non_traceable_module_names(["module1", "module2"]) self.assertEqual(prepare_custom_config.non_traceable_module_names, ["module1", "module2"]) def test_prepare_custom_config_set_non_traceable_module_classes(self): prepare_custom_config = PrepareCustomConfig() self.assertEqual(len(prepare_custom_config.non_traceable_module_classes), 0) prepare_custom_config.set_non_traceable_module_classes([self._DummyNonTraceableModule1, self._DummyNonTraceableModule2]) self.assertEqual(prepare_custom_config.non_traceable_module_classes, [self._DummyNonTraceableModule1, self._DummyNonTraceableModule2]) def test_prepare_custom_config_set_input_quantized_indexes(self): prepare_custom_config = PrepareCustomConfig() self.assertEqual(len(prepare_custom_config.input_quantized_indexes), 0) prepare_custom_config.set_input_quantized_indexes([0, 1]) self.assertEqual(prepare_custom_config.input_quantized_indexes, [0, 1]) def test_prepare_custom_config_set_output_quantized_indexes(self): prepare_custom_config = PrepareCustomConfig() self.assertEqual(len(prepare_custom_config.output_quantized_indexes), 0) prepare_custom_config.set_output_quantized_indexes([0, 1]) self.assertEqual(prepare_custom_config.output_quantized_indexes, [0, 1]) def test_prepare_custom_config_set_preserved_attributes(self): prepare_custom_config = PrepareCustomConfig() 
self.assertEqual(len(prepare_custom_config.preserved_attributes), 0) prepare_custom_config.set_preserved_attributes(["attr1", "attr2"]) self.assertEqual(prepare_custom_config.preserved_attributes, ["attr1", "attr2"]) def _get_dummy_prepare_custom_config_dict(self): """ Return a dummy prepare_custom_config_dict to test PrepareCustomConfig's to_dict and from_dict methods. """ return { STANDALONE_MODULE_NAME_DICT_KEY: [( "module1", QConfigMapping(), (torch.randn(3),), PrepareCustomConfig(), BackendConfig("my_backend"), )], STANDALONE_MODULE_CLASS_DICT_KEY: [( self._DummyStandaloneModule, QConfigMapping(), (torch.randn(10),), PrepareCustomConfig(), BackendConfig("my_backend"), )], FLOAT_TO_OBSERVED_DICT_KEY: { "static": { self._DummyFloatModule: self._DummyObservedModule }, }, NON_TRACEABLE_MODULE_NAME_DICT_KEY: ["module2", "module3"], NON_TRACEABLE_MODULE_CLASS_DICT_KEY: [self._DummyNonTraceableModule1, self._DummyNonTraceableModule2], INPUT_QUANTIZED_INDEXES_DICT_KEY: [0, 1], OUTPUT_QUANTIZED_INDEXES_DICT_KEY: [0, 1], PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"] } def test_prepare_custom_config_from_dict(self): prepare_custom_config_dict = self._get_dummy_prepare_custom_config_dict() (sm_name, qm1, ei1, pcc1, bcd1) = prepare_custom_config_dict[STANDALONE_MODULE_NAME_DICT_KEY][0] (sm_class, qm2, ei2, pcc2, bcd2) = prepare_custom_config_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0] sm_config_entry1 = StandaloneModuleConfigEntry(qm1, ei1, pcc1, bcd1) sm_config_entry2 = StandaloneModuleConfigEntry(qm2, ei2, pcc2, bcd2) prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config_dict) # Standalone modules self.assertEqual(len(prepare_custom_config.standalone_module_names), 1) self.assertTrue(sm_name in prepare_custom_config.standalone_module_names) self.assertEqual(prepare_custom_config.standalone_module_names[sm_name], sm_config_entry1) self.assertEqual(len(prepare_custom_config.standalone_module_classes), 1) self.assertTrue(sm_class in 
prepare_custom_config.standalone_module_classes) self.assertEqual(prepare_custom_config.standalone_module_classes[sm_class], sm_config_entry2) # Float to observed mapping self.assertEqual(len(prepare_custom_config.float_to_observed_mapping), 1) self.assertEqual(list(prepare_custom_config.float_to_observed_mapping.keys()), [QuantType.STATIC]) self.assertEqual(len(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC]), 1) self.assertTrue(self._DummyFloatModule in prepare_custom_config.float_to_observed_mapping[QuantType.STATIC]) self.assertEqual(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC][self._DummyFloatModule], self._DummyObservedModule) # Other self.assertEqual(prepare_custom_config.non_traceable_module_names, ["module2", "module3"]) self.assertEqual(prepare_custom_config.non_traceable_module_classes, [self._DummyNonTraceableModule1, self._DummyNonTraceableModule2]) self.assertEqual(prepare_custom_config.input_quantized_indexes, [0, 1]) self.assertEqual(prepare_custom_config.output_quantized_indexes, [0, 1]) self.assertEqual(prepare_custom_config.preserved_attributes, ["attr1", "attr2"]) def test_prepare_custom_config_to_dict(self): prepare_custom_config_dict = self._get_dummy_prepare_custom_config_dict() (sm_name, qm1, ei1, pcc1, bcd1) = prepare_custom_config_dict[STANDALONE_MODULE_NAME_DICT_KEY][0] (sm_class, qm2, ei2, pcc2, bcd2) = prepare_custom_config_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0] prepare_custom_config = PrepareCustomConfig() \ .set_standalone_module_name(sm_name, qm1, ei1, pcc1, bcd1) \ .set_standalone_module_class(sm_class, qm2, ei2, pcc2, bcd2) \ .set_float_to_observed_mapping(self._DummyFloatModule, self._DummyObservedModule) \ .set_non_traceable_module_names(["module2", "module3"]) \ .set_non_traceable_module_classes([self._DummyNonTraceableModule1, self._DummyNonTraceableModule2]) \ .set_input_quantized_indexes([0, 1]) \ .set_output_quantized_indexes([0, 1]) \ .set_preserved_attributes(["attr1", "attr2"]) # 
PrepareCustomConfig.to_dict also converts internal QConfigMappings and PrepareCustomConfigs to dicts prepare_custom_config_dict[STANDALONE_MODULE_NAME_DICT_KEY][0] = (sm_name, qm1.to_dict(), ei1, pcc1.to_dict(), bcd1) prepare_custom_config_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0] = (sm_class, qm2.to_dict(), ei2, pcc2.to_dict(), bcd2) self.assertEqual(prepare_custom_config.to_dict(), prepare_custom_config_dict) def test_convert_custom_config_set_observed_to_quantized_mapping(self): convert_custom_config = ConvertCustomConfig() self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping), 0) convert_custom_config.set_observed_to_quantized_mapping( self._DummyObservedModule, self._DummyQuantizedModule, QuantType.STATIC) self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping), 1) self.assertEqual(list(convert_custom_config.observed_to_quantized_mapping.keys()), [QuantType.STATIC]) self.assertTrue(self._DummyObservedModule in convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC]) self.assertEqual(convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC][self._DummyObservedModule], self._DummyQuantizedModule) def test_convert_custom_config_set_preserved_attributes(self): convert_custom_config = ConvertCustomConfig() self.assertEqual(len(convert_custom_config.preserved_attributes), 0) convert_custom_config.set_preserved_attributes(["attr1", "attr2"]) self.assertEqual(convert_custom_config.preserved_attributes, ["attr1", "attr2"]) def _get_dummy_convert_custom_config_dict(self): """ Return a dummy convert_custom_config_dict to test ConvertCustomConfig's to_dict and from_dict methods. 
""" return { OBSERVED_TO_QUANTIZED_DICT_KEY: { "static": { self._DummyObservedModule: self._DummyQuantizedModule }, }, PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"] } def test_convert_custom_config_from_dict(self): convert_custom_config_dict = self._get_dummy_convert_custom_config_dict() convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config_dict) self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping), 1) self.assertEqual(list(convert_custom_config.observed_to_quantized_mapping.keys()), [QuantType.STATIC]) self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC]), 1) self.assertTrue(self._DummyObservedModule in convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC]) self.assertEqual(convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC][self._DummyObservedModule], self._DummyQuantizedModule) self.assertEqual(convert_custom_config.preserved_attributes, ["attr1", "attr2"]) def test_convert_custom_config_to_dict(self): convert_custom_config = ConvertCustomConfig() \ .set_observed_to_quantized_mapping(self._DummyObservedModule, self._DummyQuantizedModule) \ .set_preserved_attributes(["attr1", "attr2"]) self.assertEqual(convert_custom_config.to_dict(), self._get_dummy_convert_custom_config_dict()) def test_fuse_custom_config_set_preserved_attributes(self): fuse_custom_config = FuseCustomConfig() self.assertEqual(len(fuse_custom_config.preserved_attributes), 0) fuse_custom_config.set_preserved_attributes(["attr1", "attr2"]) self.assertEqual(fuse_custom_config.preserved_attributes, ["attr1", "attr2"]) def test_fuse_custom_config_from_dict(self): fuse_custom_config_dict = {PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"]} fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config_dict) self.assertEqual(fuse_custom_config.preserved_attributes, ["attr1", "attr2"]) def test_fuse_custom_config_to_dict(self): fuse_custom_config_dict = 
{PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"]} fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"]) self.assertEqual(fuse_custom_config.to_dict(), fuse_custom_config_dict) def test_remove_qconfig(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.avg_pool = torch.nn.AvgPool2d(1) def forward(self, x): return self.avg_pool(x) m = M().eval() qconfig_dict = {'': default_qconfig} example_inputs = (torch.randn(1, 1, 1, 1),) m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) m(*example_inputs) m = convert_fx(m) m(*example_inputs) for name, module in m.named_modules(): self.assertFalse(hasattr(module, 'qconfig'), 'qconfig is not removed for ' + name) def test_return_none(self): class M(torch.nn.Module): def forward(self, x): pass m = M().eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1),)) m = convert_fx(m) def test_default_quant_after_none_qconfig(self): """ Make sure default quant is inserted properly""" class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = torch.nn.Conv2d(1, 1, 1) self.conv2 = torch.nn.Conv2d(1, 1, 1) def forward(self, x): x = self.conv1(x) x = x.transpose(1, 2) x = self.conv2(x) m = M().eval() qconfig_dict = { "": default_qconfig, "module_name": [ ("conv1", None) ] } m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),)) m = convert_fx(m) def test_qconfig_for_call_method(self): class Sub(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 1, 1) def forward(self, x): x = x.transpose(2, 3) x = self.conv(x) return x.transpose(2, 3) class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.sub = Sub() self.conv1 = torch.nn.Conv2d(1, 1, 1) self.conv2 = torch.nn.Conv2d(1, 1, 1) def forward(self, x): x = self.conv1(x) x = self.sub(x) x = self.conv2(x) return x.transpose(2, 3) 
qconfig_dict1 = {"": default_qconfig, "module_name": [("sub", None)]} # since sub is configured to have qconfig None, we should dequantize the output # of self.conv1 and quantize the input of self.conv2 # dequantize after conv2 should happen after transpose since # it is configured with default_qconfig # nodes in Sub module instance is not quantized node_list1 = [ ns.call_function(torch.quantize_per_tensor), ns.call_module(nnq.Conv2d), ns.call_method("dequantize"), ns.call_method("transpose"), ns.call_module(nn.Conv2d), ns.call_method("transpose"), ns.call_function(torch.quantize_per_tensor), ns.call_module(nnq.Conv2d), ns.call_method("transpose"), ns.call_method("dequantize") ] qconfig_dict2 = {"": None, "module_name": [("sub", default_qconfig)]} # Only nodes in Sub module instance are quantized # the first transpose is not quantized because the input is not quantized node_list2 = [ ns.call_module(nn.Conv2d), ns.call_function(torch.quantize_per_tensor), ns.call_method("transpose"), ns.call_module(nnq.Conv2d), ns.call_method("transpose"), ns.call_method("dequantize"), ns.call_module(nn.Conv2d), ns.call_method("transpose"), ] for qconfig_dict, node_list in [ (qconfig_dict1, node_list1), (qconfig_dict2, node_list2) ]: example_inputs = (torch.randn(2, 1, 3, 3),) m = M().eval() m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) m(torch.randn(2, 1, 3, 3)) m = convert_fx(m) self.checkGraphModuleNodes(m, expected_node_list=node_list) # make sure it runs m(*example_inputs) def test_qconfig_for_call_func(self): class Linear(torch.nn.Module): def __init__(self) -> None: super().__init__() self.w = torch.ones(5, 5) self.b = torch.zeros(5) def forward(self, x): return torch.nn.functional.linear(x, self.w, self.b) class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.mods1 = torch.nn.Sequential( Linear(), Linear() ) self.mods2 = Linear() def forward(self, x): x = self.mods1(x) x = self.mods2(x) return x model = M().eval() example_inputs = 
(torch.rand(5, 5),) qconfig_dict = {"": default_qconfig, "module_name": [("mods2", None)]} m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) m(*example_inputs) m = convert_fx(m) node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_function(torch.ops.quantized.linear), ns.call_function(torch.ops.quantized.linear), ns.call_method('dequantize'), ns.call_function(torch.nn.functional.linear) ] self.checkGraphModuleNodes(m, expected_node_list=node_list) m(torch.rand(5, 5)) def test_preserve_attributes(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 1, 1) def forward(self, x): return self.conv(x) m = M() m.eval() m.preserved_attr = 3 prepare_custom_config_dict = { "preserved_attributes": ["preserved_attr"] } example_inputs = (torch.randn(1, 1, 1, 1),) m = prepare_fx( m, {"": default_qconfig}, example_inputs=example_inputs, prepare_custom_config=prepare_custom_config_dict) def assertAttrPreserved(m): self.assertTrue(hasattr(m, "preserved_attr")) self.assertEqual(m.preserved_attr, 3) assertAttrPreserved(m) convert_custom_config_dict = { "preserved_attributes": ["preserved_attr"] } m = convert_fx(m, convert_custom_config=convert_custom_config_dict) assertAttrPreserved(m) @skipIfNoFBGEMM def test_qat_and_script(self): model = LinearModelWithSubmodule().train() qengine = torch.backends.quantized.engine qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig(qengine)} x = torch.randn(5, 5) example_inputs = (x,) model = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) # ensure scripting works scripted = torch.jit.script(model) # run one round to make sure model runs scripted(x) FileCheck().check_count('FakeQuantize = prim::GetAttr[name="', 4, exactly=True) \ .run(scripted.graph) # disable fake_quant and observer for epoch in range(3): if epoch == 1: scripted.apply(torch.ao.quantization.disable_observer) if epoch == 2: 
scripted.apply(torch.ao.quantization.disable_fake_quant) # ensure the fake_quant and observer have been disabled. matches = ['.fake_quant_enabled', '.observer_enabled'] for key, v in scripted.state_dict().items(): if any(x in key for x in matches): self.assertEqual(v, torch.tensor([0], dtype=torch.int64)) # enable them back scripted.apply(torch.ao.quantization.enable_fake_quant) scripted.apply(torch.ao.quantization.enable_observer) for key, v in scripted.state_dict().items(): if any(x in key for x in matches): self.assertEqual(v, torch.tensor([1], dtype=torch.int64)) @skipIfNoFBGEMM def test_save_observer_state_dict(self): orig = LinearModelWithSubmodule().eval() model = orig qconfig_dict = {'': torch.ao.quantization.get_default_qconfig('fbgemm')} x = torch.randn(5, 5) model = prepare_fx(model, qconfig_dict, example_inputs=(x,)) # run it through input model(x) # save state_dict of model obs_dict = torch.ao.quantization.get_observer_state_dict(model) quant = convert_fx(model) b = io.BytesIO() torch.save(obs_dict, b) # Load the stats into new model for weights_only in [True, False]: b.seek(0) model_2 = orig model_2 = prepare_fx(model_2, qconfig_dict, example_inputs=(x,)) loaded_dict = torch.load(b, weights_only=weights_only) torch.ao.quantization.load_observer_state_dict(model_2, loaded_dict) quant_2 = convert_fx(model_2) # Verify that loaded state dict produces same results. 
self.assertEqual(quant(x), quant_2(x)) @skipIfNoFBGEMM def test_custom_module_class(self): class CustomModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(3, 3) def forward(self, x): return self.linear(x) class ObservedCustomModule(torch.nn.Module): def __init__(self, linear): super().__init__() self.linear = linear def forward(self, x): return self.linear(x) @classmethod def from_float(cls, float_module): assert hasattr(float_module, 'qconfig') observed = cls(float_module.linear) observed.qconfig = float_module.qconfig return observed class StaticQuantCustomModule(torch.nn.Module): def __init__(self, linear): super().__init__() self.linear = linear def forward(self, x): return self.linear(x) @classmethod def from_observed(cls, observed_module): assert hasattr(observed_module, 'qconfig') assert hasattr(observed_module, 'activation_post_process') observed_module.linear.activation_post_process = \ observed_module.activation_post_process quantized = cls(nnq.Linear.from_float(observed_module.linear)) return quantized class DynamicQuantCustomModule(torch.nn.Module): def __init__(self, linear): super().__init__() self.linear = linear def forward(self, x): return self.linear(x) @classmethod def from_observed(cls, observed_module): assert hasattr(observed_module, 'qconfig') observed_module.linear.qconfig = observed_module.qconfig quantized = cls(nnqd.Linear.from_float(observed_module.linear)) return quantized class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(3, 3) self.custom = CustomModule() def forward(self, x): x = self.linear(x) x = self.custom(x) return x class RefM(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear1 = torch.nn.Linear(3, 3) self.linear2 = torch.nn.Linear(3, 3) def forward(self, x): x = self.linear1(x) x = self.linear2(x) return x # instantiate M and RefM and align the parameters original_m = M().eval() original_ref_m = 
RefM().eval() original_ref_m.linear1.weight = torch.nn.Parameter(original_m.linear.weight.detach()) original_ref_m.linear1.bias = torch.nn.Parameter(original_m.linear.bias.detach()) original_ref_m.linear2.weight = torch.nn.Parameter(original_m.custom.linear.weight.detach()) original_ref_m.linear2.bias = torch.nn.Parameter(original_m.custom.linear.bias.detach()) a16_qconfig = QConfig( activation=MinMaxObserver.with_args(dtype=torch.qint32, quant_min=0, quant_max=65536), weight=default_weight_observer, ) test_configs = { "static": (default_qconfig, StaticQuantCustomModule, 3), "static_a16": (a16_qconfig, StaticQuantCustomModule, 3), "dynamic": (default_dynamic_qconfig, DynamicQuantCustomModule, 0) } for quant_type in [QuantType.STATIC, QuantType.DYNAMIC]: key = _get_quant_type_to_str(quant_type) qconfig, quantized_module_class, num_observers = test_configs[key] qconfig_dict = {"": qconfig} if key == "static": prepare_custom_config_dict = { "float_to_observed_custom_module_class": { "static": { CustomModule: ObservedCustomModule } } } convert_custom_config_dict = { "observed_to_quantized_custom_module_class": { "static": { ObservedCustomModule: quantized_module_class } } } else: prepare_custom_config_dict = { "non_traceable_module_class": [ CustomModule ] } convert_custom_config_dict = { "observed_to_quantized_custom_module_class": { "dynamic": { CustomModule: quantized_module_class } } } example_inputs = (torch.randn(3, 3),) # check prepared model m = prepare_fx( copy.deepcopy(original_m), qconfig_dict, example_inputs=example_inputs, prepare_custom_config=prepare_custom_config_dict) # calibration m(*example_inputs) # all activation observers are inserted in the top level module count_check = { ns.call_module(torch.ao.quantization.MinMaxObserver): num_observers } self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) # check converted/quantized model m = convert_fx( m, convert_custom_config=convert_custom_config_dict) if quant_type == QuantType.STATIC: 
count_check = { ns.call_function(torch.quantize_per_tensor) : 1, ns.call_module(nnq.Linear) : 1, ns.call_method('dequantize') : 1, } self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) self.assertEqual(type(m.custom), quantized_module_class) res = m(*example_inputs) # quantize the reference model ref_m = prepare_fx( copy.deepcopy(original_ref_m), qconfig_dict, example_inputs=example_inputs) ref_m(*example_inputs) ref_m = convert_fx(ref_m) ref_res = ref_m(*example_inputs) self.assertEqual(res, ref_res) @skipIfNoFBGEMM def test_custom_module_class_input_has_multiple_users(self): """ Tests that the flow still works when the input of custom module has multiple users """ class CustomModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(3, 3) def forward(self, x): return self.linear(x) class ObservedCustomModule(torch.nn.Module): def __init__(self, linear): super().__init__() self.linear = linear def forward(self, x): return self.linear(x) @classmethod def from_float(cls, float_module): assert hasattr(float_module, 'qconfig') observed = cls(float_module.linear) observed.qconfig = float_module.qconfig return observed class StaticQuantCustomModule(torch.nn.Module): def __init__(self, linear): super().__init__() self.linear = linear def forward(self, x): return self.linear(x) @classmethod def from_observed(cls, observed_module): assert hasattr(observed_module, 'qconfig') assert hasattr(observed_module, 'activation_post_process') observed_module.linear.activation_post_process = \ observed_module.activation_post_process quantized = cls(nnq.Linear.from_float(observed_module.linear)) return quantized class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(3, 3) self.custom = CustomModule() def forward(self, x0): x1 = self.custom(x0) x2 = self.linear(x0) return x1 + x2 prepare_custom_config_dict = { "float_to_observed_custom_module_class": { "static": { CustomModule: 
@skipIfNoFBGEMM
def test_custom_module_class_input_has_duplicate_nodes(self):
    """Custom-module swapping works when one module is called twice.

    ``M.forward`` invokes the same ``CustomModule`` instance two times, so the
    traced graph contains two nodes that share one target.  The test passes
    when prepare, convert and a forward pass all complete without raising.
    """
    class CustomModule(torch.nn.Module):
        # Float module that is swapped as a unit via the custom config.
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            return self.linear(x)

    class ObservedCustomModule(torch.nn.Module):
        # Observed counterpart; reuses the float module's linear layer.
        def __init__(self, linear):
            super().__init__()
            self.linear = linear

        def forward(self, x):
            return self.linear(x)

        @classmethod
        def from_float(cls, float_module):
            assert hasattr(float_module, 'qconfig')
            wrapped = cls(float_module.linear)
            wrapped.qconfig = float_module.qconfig
            return wrapped

    class StaticQuantCustomModule(torch.nn.Module):
        # Quantized counterpart built from the observed module.
        def __init__(self, linear):
            super().__init__()
            self.linear = linear

        def forward(self, x):
            return self.linear(x)

        @classmethod
        def from_observed(cls, observed_module):
            assert hasattr(observed_module, 'qconfig')
            assert hasattr(observed_module, 'activation_post_process')
            # Hand the module-level observer to the inner linear so that
            # nnq.Linear.from_float can read it.
            inner = observed_module.linear
            inner.activation_post_process = observed_module.activation_post_process
            return cls(nnq.Linear.from_float(inner))

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.custom = CustomModule()

        def forward(self, x0):
            # Two calls to the same submodule -> two nodes, one target.
            return self.custom(x0) + self.custom(x0)

    float_to_observed = {"static": {CustomModule: ObservedCustomModule}}
    observed_to_quantized = {"static": {ObservedCustomModule: StaticQuantCustomModule}}
    prepare_custom_config_dict = {
        "float_to_observed_custom_module_class": float_to_observed,
    }
    convert_custom_config_dict = {
        "observed_to_quantized_custom_module_class": observed_to_quantized,
    }

    model = M().eval()
    example_inputs = (torch.randn(3, 3),)
    prepared = prepare_fx(
        model,
        {"": default_qconfig},
        example_inputs=example_inputs,
        prepare_custom_config=prepare_custom_config_dict)
    # conversion must succeed ...
    quantized = convert_fx(
        prepared,
        convert_custom_config=convert_custom_config_dict)
    # ... and the converted model must be runnable
    quantized(*example_inputs)
""" class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 1, 1) self._foobar = 'foobar' self.foobar2 = 'foobar2' def forward(self, x): x = self.conv(x) return x m = M() m.eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} example_inputs = (torch.randn(4, 1, 4, 4),) prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) # calibrate prepared(*example_inputs) # copy prepared_copy = copy.deepcopy(prepared) # quantize, should run with no errors quantized = convert_fx(prepared_copy) def test_quantized_model_type(self): """ Test state_dict and deepcopy works properly in the quantized model """ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(5, 5) def forward(self, x): return self.linear(x) example_inputs = (torch.rand(8, 5),) m = M().eval() m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) m = convert_fx(m) # test deepcopy m_copy = copy.deepcopy(m) self.assertEqual(m_copy(*example_inputs), m(*example_inputs)) # test state_dict state_dict = m.state_dict() m_new = M().eval() m_new = prepare_fx(m_new, {"": default_qconfig}, example_inputs=example_inputs) m_new = convert_fx(m_new) m_new.load_state_dict(state_dict) self.assertEqual(m_new(*example_inputs), m(*example_inputs)) def test_dequantize(self): r""" Test to make sure dequantize node are placed before non-quantizable node """ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 1, 1) self.act = torch.nn.GELU() def forward(self, x): x = self.conv(x) return self.act(x) data = torch.rand(5, 1, 3, 3, dtype=torch.float) for quant_type in self.static_quant_types: node_list = [ ns.call_module(nnq.Conv2d), ns.call_method("dequantize"), ns.call_module(nn.GELU), ] self.checkGraphModeFxOp( M().eval(), (data,), quant_type, expected_node_list=node_list) def test_sequential(self): class M(torch.nn.Module): def __init__(self) -> 
None: super().__init__() self.convs = torch.nn.Sequential( torch.nn.Conv2d(1, 1, 1), torch.nn.Conv2d(1, 1, 1) ) def forward(self, x): x = self.convs(x) return x data = torch.rand(5, 1, 3, 3, dtype=torch.float) for quant_type in self.static_quant_types: node_list = [ ns.call_module(nnq.Conv2d), ns.call_module(nnq.Conv2d), ] self.checkGraphModeFxOp( M().eval(), (data,), quant_type, expected_node_list=node_list) def _test_quantized_inputs_outputs( self, prepare_custom_config_dict, prepare_count_check, convert_count_check): """ Test the option to have inputs and outputs of the graph quantized """ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = torch.nn.Conv2d(1, 1, 1) self.conv2 = torch.nn.Conv2d(1, 1, 1) def forward(self, x): x = self.conv1(x) x = self.conv2(x) return x # quantized input, quantized output m = M() qconfig_dict = {'': torch.ao.quantization.default_qconfig} example_inputs = (torch.randn(1, 1, 4, 4),) m.eval() mp = torch.ao.quantization.quantize_fx.prepare_fx( m, qconfig_dict, example_inputs=example_inputs, prepare_custom_config=prepare_custom_config_dict) self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check) mp(*example_inputs) mq = torch.ao.quantization.quantize_fx.convert_fx(mp) self.checkGraphModuleNodes(mq, expected_node_occurrence=convert_count_check) def test_quantized_input_quantized_output(self): prepare_custom_config_dict = { 'input_quantized_idxs': [0], 'output_quantized_idxs': [0]} prepare_count_check = { ns.call_module(torch.ao.quantization.MinMaxObserver): 2, } convert_count_check = { ns.call_function(torch.quantize_per_tensor): 0, ns.call_method('dequantize'): 0, } self._test_quantized_inputs_outputs( prepare_custom_config_dict, prepare_count_check, convert_count_check) def test_fp32_input_quantized_output(self): prepare_custom_config_dict = { 'output_quantized_idxs': [0]} prepare_count_check = { ns.call_module(torch.ao.quantization.MinMaxObserver): 3, } convert_count_check = { 
@skipIfNoFBGEMM
def test_convtranspose_per_channel_fails_early(self):
    r"""
    Verifies that attempting to quantize a ConvTranspose module with per-Channel
    weight observers fails in the prepare step, as opposed to the convert step.
    """
    model = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
    model.eval()
    qconfig_dict = {'': torch.ao.quantization.get_default_qconfig('fbgemm')}
    example_inputs = (torch.randn(1, 1, 1, 1),)

    # prepare_fx itself must raise; convert_fx is never reached.
    with self.assertRaises(AssertionError) as context:
        prepare_fx(model, qconfig_dict, example_inputs=example_inputs)

    # Checked after the ``with`` block so the comparison is guaranteed to run.
    # assertEqual (rather than assertTrue(a == b)) prints both strings when
    # the message drifts.
    self.assertEqual(
        str(context.exception),
        'Per channel weight observer is not supported yet for ConvTranspose{n}d.')
@skipIfNoFBGEMM
def test_packed_weight_fused_op(self):
    """Each functional linear gets its own packed-weight attribute.

    The model chains three ``F.linear`` calls followed by ``F.relu``.  After
    prepare, calibration and convert, the converted module must expose one
    ``*_packed_weight_0`` attribute per linear.
    """
    class Linear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.w = torch.ones(5, 5)
            self.b = torch.zeros(5)

        def forward(self, x):
            return F.linear(x, self.w, self.b)

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mods1 = torch.nn.Sequential(
                Linear(),
                Linear()
            )
            self.mods2 = Linear()
            self.relu = F.relu

        def forward(self, x):
            x = self.mods1(x)
            x = self.mods2(x)
            x = self.relu(x)
            return x

    model = M().eval()
    example_inputs = (torch.rand(5, 5),)
    qconfig_dict = {"": default_qconfig}
    m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
    m(*example_inputs)
    m = convert_fx(m)

    # Use unittest assertions instead of bare ``assert`` statements: the
    # latter are removed under ``python -O`` and carry no failure message.
    for attr_name in (
            "mods1_0_packed_weight_0",
            "mods1_1_packed_weight_0",
            "mods2_packed_weight_0"):
        self.assertTrue(hasattr(m, attr_name), attr_name + " not found.")
example_inputs=example_inputs) m = convert_fx(m) # make sure it runs m(*example_inputs) def test_getattr_with_nontensor_result(self): """ Verifies that binary ops get quantized correctly if some of the args are nodes but not Tensors, such as an `x.ndim` pattern. """ class M1(torch.nn.Module): def forward(self, x): dims = x.ndim dims_sub = dims - 1 dims_sub2 = dims_sub - 1 x = torch.add(x, dims_sub2) return x class M2(torch.nn.Module): def forward(self, x): dims = x.ndim dims_sub = dims - 2 mul = [1] * dims_sub dims_list = [-1, x.size(1)] + mul x = x.view(dims_list) return x class M3(torch.nn.Module): def forward(self, x): shape = x.shape x = x.view(shape) return x for cls in (M1, M2, M3): m = cls().eval() example_inputs = (torch.rand(4, 4, 4, 4),) m(*example_inputs) qconfig_dict = {'': torch.ao.quantization.default_qconfig} mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) mp(torch.rand(4, 4, 4, 4)) mc = convert_fx(mp) class _NonReferenceTestModel(nn.Module): def __init__(self, func, lin_in, lin_out): super().__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.lin = nn.Linear(lin_in, lin_out) self.func = func def forward(self, x, y, z): x = self.pool(F.relu(self.conv1(x))) x = torch.flatten(x, 1) x = self.func(x, y, z) x = self.lin(x) return x # This function looks at the node specified by the NodeInfo in the key of # node_info_to_non_tensor_args and checks that the args at specified indices # are not observed (since they are non tensors). If the args at those indices # are a tuple/list (which do not show up as nodes) the function checks the # individual elements of the tuple/list recursively. 
def _check_not_observed(self, model, node_info_to_non_tensor_args): # this is a helper function (for easier recursion) that checks whether # arg_node is observed def _check_node_not_observed(model, arg_node, node): if isinstance(arg_node, (tuple, list)): for new_node in arg_node: _check_node_not_observed(model, new_node, node) elif arg_node.op == "call_module": self.assertTrue( not _is_activation_post_process(getattr(model, arg_node.target)), f"Arg: {arg_node} of node: {node} is observed but is not a float tensor", ) for node in model.graph.nodes: indices = node_info_to_non_tensor_args.get( NodeInfo(node.op, node.target), [] ) for index in indices: if index < len(node.args): arg_node = node.args[index] _check_node_not_observed(model, arg_node, node) # This test checks that the model gets prepared correct, doesn't have observers # on specific ops (see _check_not_observed) and that the prepared model runs def _test_dtype_propagation(self, model, node_info_to_non_tensor_args, *args): model.eval() qconfig_dict = {"": torch.ao.quantization.get_default_qconfig("fbgemm")} prepared_model = prepare_fx(model, qconfig_dict, example_inputs=tuple(args)) self._check_not_observed(prepared_model, node_info_to_non_tensor_args) prepared_model(*args) def test_masked_fill_nontensor_args_not_observed(self): def func(x, y, z): return x.masked_fill(y, z) model = self._NonReferenceTestModel(func, 1176, 1) args = [torch.randn(5, 3, 32, 32), torch.randn(1176) > 0, 0.1] node_info_to_non_tensor_args = {NodeInfo("call_method", "masked_fill"): [1, 2]} self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) def test_permute_nontensor_args_not_observed(self): def func(x, y, z): return x.permute(y, z) model = self._NonReferenceTestModel(func, 1176, 1) args = [torch.randn(5, 3, 32, 32), 0, 1] node_info_to_non_tensor_args = {NodeInfo("call_method", "permute"): [1, 2]} self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) def 
def test_torch_transpose_nontensor_args_not_observed(self):
    """Dtype-propagation check for a model that calls ``torch.transpose``."""
    # TODO: make torch.transpose traceable by fx when using
    # variable nontensor arguments
    # func = lambda x, y, z: torch.transpose(x, y, z) # error
    def transpose_fixed_dims(x, y, z):
        # The dims are hard-coded; ``y`` and ``z`` are accepted but unused.
        return torch.transpose(x, 0, 1)

    model = self._NonReferenceTestModel(transpose_fixed_dims, 5, 1)
    # NOTE(review): the key pairs op "call_method" with the function object
    # torch.transpose -- confirm FX records this call under that op/target,
    # otherwise the lookup in _check_not_observed never matches.
    non_tensor_arg_positions = {
        NodeInfo("call_method", torch.transpose): [1, 2]
    }
    call_args = [torch.randn(5, 3, 32, 32), 0, 1]
    self._test_dtype_propagation(model, non_tensor_arg_positions, *call_args)
def test_propagate_dtypes_for_known_nodes_list_args(self):
    """``reshape`` receiving its target shape as a single list argument."""
    def reshape_to(x, y, z):
        # ``y`` is the whole shape list; ``z`` is unused.
        return x.reshape(y)

    model = self._NonReferenceTestModel(reshape_to, 5, 1)
    image = torch.randn(5, 3, 32, 32)
    shape = [-1, 5]
    # Argument 1 of the reshape call is the shape and must stay unobserved.
    non_tensor_arg_positions = {NodeInfo("call_method", "reshape"): [1]}
    self._test_dtype_propagation(model, non_tensor_arg_positions, image, shape, None)
def test_propagate_dtypes_for_known_nodes_dict_split_tuple_args(self):
    """``reshape`` whose shape tuple is assembled from two dict lookups."""
    def func(x, y, z):
        # ``y`` is a dict; both shape entries are read from it.  ``z`` is unused.
        return x.reshape((y["first"], y["second"]))

    model = self._NonReferenceTestModel(func, 5, 1)
    args = [torch.randn(5, 3, 32, 32), {"first": -1, "second": 5}, None]
    # The method called by ``func`` is reshape, so that is the node whose
    # argument 1 (the shape tuple) is checked.  Keying on "transpose" matched
    # no node in this model and left the check vacuous.
    node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]}
    self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
def test_fp32_sum(self):
    """
    Checks that a model containing ``torch.sum`` can be prepared, calibrated,
    converted and run, both when the sum is the last op (M1) and when it sits
    between two convolutions (M2).
    """
    class M1(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 1, 1)

        def forward(self, x):
            # conv -> stack -> full reduction
            return torch.sum(torch.stack([self.conv1(x)]))

    class M2(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 1, 1)
            self.conv2 = nn.Conv2d(1, 1, 1)

        def forward(self, x):
            # conv -> stack -> sum over the stacked dim -> conv
            stacked = torch.stack([self.conv1(x)])
            reduced = torch.sum(stacked, dim=0)
            return self.conv2(reduced)

    for model_cls in (M1, M2):
        float_model = model_cls().eval()
        example_inputs = (torch.rand(4, 1, 4, 4),)
        # sanity-run the float model first
        float_model(*example_inputs)

        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
        prepared = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs)
        prepared(*example_inputs)

        converted = convert_fx(prepared)
        converted(*example_inputs)
""" class Child(torch.nn.Module): def __init__(self) -> None: super().__init__() self.relu = nn.ReLU() def forward(self, x): x = torch.add(x, 1.0) x = torch.nn.functional.relu(x) return x class Parent(torch.nn.Module): def __init__(self) -> None: super().__init__() self.child = Child() self.conv = nn.Conv2d(1, 1, 1) def forward(self, x): x = self.child(x) x = self.conv(x) return x m = Parent().eval() qconfig_dict = { '': torch.ao.quantization.default_qconfig, 'module_name': [ ('child', None), ], } example_inputs = (torch.rand(1, 1, 1, 1),) mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) mp(*example_inputs) mc = convert_fx(mp) def test_state_dict(self): """ Make sure packed params appear in state_dict """ # test linear packed weight class M1(torch.nn.Module): def __init__(self) -> None: super().__init__() self.w = torch.rand(4, 30) self.b = torch.rand(4) def forward(self, x): return F.linear(x, self.w, self.b) m = M1().eval() qconfig_dict = {"": default_qconfig} m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 30),)) m = convert_fx(m) state_dict = m.state_dict() self.assertTrue("_packed_weight_0" in state_dict) # test conv packed weight class M2(torch.nn.Module): def __init__(self) -> None: super().__init__() self.w = torch.rand(3, 3, 3, 3) self.b = torch.rand(3) self.stride = (1, 1) self.padding = (0, 0) self.dilation = (1, 1) self.groups = 1 def forward(self, x): return F.conv2d(x, self.w, self.b, self.stride, self.padding, self.dilation, self.groups) m = M2().eval() qconfig_dict = {"": default_qconfig} m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),)) m = convert_fx(m) state_dict = m.state_dict() self.assertTrue("_packed_weight_0" in state_dict) # test load ref_weight, ref_bias = torch.ops.quantized.conv2d_unpack(state_dict["_packed_weight_0"]) data = torch.rand(1, 3, 5, 5) ref_res = m(data) m = M2().eval() m = prepare_fx(m, qconfig_dict, (data,)) m = convert_fx(m) res = m(data) weight, bias = 
@skipIfNoFBGEMM
def test_preserve_qconfig(self):
    """
    With ``_remove_qconfig=False`` passed to ``convert_fx``, the converted
    model's submodule must still carry a ``qconfig`` attribute.
    """
    class Linear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.w = torch.ones(5, 5)
            self.b = torch.zeros(5)

        def forward(self, x):
            return torch.nn.functional.linear(x, self.w, self.b)

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mods1 = torch.nn.Sequential(
                Linear(),
                Linear()
            )
            self.mods2 = torch.nn.Sigmoid()

        def forward(self, x):
            return self.mods2(self.mods1(x))

    # Only functional linear is given a qconfig (fp16 dynamic).
    linear_only_qconfig = {
        "object_type": [
            (torch.nn.functional.linear, float16_dynamic_qconfig),
        ],
    }

    with override_quantized_engine('fbgemm'):
        float_model = M().eval()
        example_inputs = (torch.rand(5, 5),)
        prepared = prepare_fx(float_model, linear_only_qconfig, example_inputs=example_inputs)
        prepared(*example_inputs)
        converted = convert_fx(prepared, _remove_qconfig=False)

        self.assertTrue(hasattr(converted.mods2, 'qconfig'))
def test_qparams_fqn(self):
    """Quantization parameters are attached under names derived from the
    first linear that uses them.

    After convert, the top-level module must expose the input scale/zero-point
    and per-linear scale/zero-point attributes listed below.
    """
    class Linear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.w = torch.ones(5, 5)
            self.b = torch.zeros(5)

        def forward(self, x):
            return torch.nn.functional.linear(x, self.w, self.b)

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mods1 = torch.nn.Sequential(
                Linear(),
                Linear()
            )

        def forward(self, x):
            x = torch.cat((x,), 1)
            shape = x.size()
            x = self.mods1(x)
            scaled = x * shape[0]
            return scaled

    functional = torch.nn.functional
    # Global default is "do not quantize"; only linear and relu are opted in.
    qconfig_dict = {
        "": None,
        "object_type": [
            (functional.linear, default_qconfig),
            (functional.relu, default_qconfig),
        ],
    }
    expected_attrs = (
        "mods1_0_input_scale_0",
        "mods1_0_input_zero_point_0",
        "mods1_0_scale_0",
        "mods1_0_zero_point_0",
        "mods1_1_scale_0",
        "mods1_1_zero_point_0",
    )

    float_model = M().eval()
    example_inputs = (torch.rand(5, 5),)
    prepared = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs)
    prepared(*example_inputs)
    converted = convert_fx(prepared)

    # The state dict must be retrievable and the converted model runnable.
    converted.state_dict().keys()
    converted(torch.randn(5, 5))

    # TODO: probably don't want to hardcode the attribute names, since they are generated
    for attr_name in expected_attrs:
        self.assertTrue(hasattr(converted, attr_name), attr_name + " not found.")
""" class M(nn.Module): def __init__(self) -> None: super().__init__() self.relu = nn.ReLU() def forward(self, x): x = _user_func_with_complex_return_type(x) x1 = x[0] + 1 return x1, x[1] m = M().eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} example_inputs = (torch.randn(4, 4, 4, 4),) mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) # if an observer is inserted after _user_func_with_complex_return_type, # the following call will fail mp(*example_inputs) mc = convert_fx(mp) mc(*example_inputs) def test_fold_quant_dequant(self): """ Test that the sequence of quant-dequant nodes in the graph, get folded and we erase the extra dequant nodes. """ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.w = torch.ones(5, 5) self.b = torch.zeros(5) def forward(self, x): x = torch.cat((x,), 1) tmp = x.size() x = torch.nn.functional.linear(x, self.w, self.b) y = x * tmp[0] return y model = M().eval() qconfig_dict = { "": None, "object_type": [ (torch.nn.functional.linear, default_qconfig), ], } example_inputs = (torch.rand(5, 5),) m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) m(*example_inputs) m = convert_fx(m) keys = m.state_dict().keys() m(*example_inputs) dequant = 0 quant = 0 for n in m.graph.nodes: if n.op == "call_method" and n.target == "dequantize": dequant = dequant + 1 if n.op == "call_function" and n.target == torch.quantize_per_tensor: quant = quant + 1 self.assertEqual(dequant, 1) self.assertEqual(quant, 1) def test_quant_output_always_observed(self): """ If the output is hardcoded to be quantized, ensure that there is always an observer, even if the last non-output node is not quantizeable. 
""" qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} prepare_custom_config_dict = {'output_quantized_idxs': [0]} example_inputs = (torch.randn(4, 1, 4, 4),) # non-quantizeable node, quantized output class M1(torch.nn.Module): def __init__(self) -> None: super().__init__() self.identity = torch.nn.Identity() def forward(self, x): x = self.identity(x) return x m1 = M1() self.checkGraphModeFxOp( m1, example_inputs, QuantType.QAT, prepare_expected_node_occurrence={ ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2, }, expected_node_occurrence={ ns.call_function(torch.quantize_per_tensor): 1, }, prepare_custom_config=prepare_custom_config_dict) # quantizeable node, quantized output class M2(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 1, 1) def forward(self, x): x = self.conv(x) return x m2 = M2() self.checkGraphModeFxOp( m2, example_inputs, QuantType.QAT, prepare_expected_node_occurrence={ # one for weights, one for activations ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2, }, expected_node_occurrence={ ns.call_function(torch.quantize_per_tensor): 1, }, prepare_custom_config=prepare_custom_config_dict) # quantizeable node, quantized dictionary output class M3(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 1, 1) def forward(self, x): x = self.conv(x) return {"output": x} m3 = M3() self.checkGraphModeFxOp( m3, example_inputs, QuantType.QAT, prepare_expected_node_occurrence={ # one for weights, one for activations ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2, }, expected_node_occurrence={ ns.call_function(torch.quantize_per_tensor): 1, }, prepare_custom_config=prepare_custom_config_dict) def test_deepcopy_preserve_attributes(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.attr = 3 def forward(self, x): return x m = M().eval() m = 
prepare_fx( m, {"": default_qconfig}, example_inputs=(torch.randn(1),), prepare_custom_config={"preserved_attributes": ["attr"]}) # preserved attributes are also stored in meta so that it doesn't get lost # during deepcopy self.assertTrue(hasattr(m, "attr")) self.assertTrue("attr" in m.meta[_USER_PRESERVED_ATTRIBUTES_KEY]) m2 = copy.deepcopy(m) self.assertTrue(hasattr(m2, "attr")) self.assertTrue("attr" in m2.meta[_USER_PRESERVED_ATTRIBUTES_KEY]) m = convert_fx(m, convert_custom_config={"preserved_attributes": ["attr"]}) self.assertTrue(hasattr(m, "attr")) self.assertTrue("attr" in m.meta[_USER_PRESERVED_ATTRIBUTES_KEY]) m2 = copy.deepcopy(m) self.assertTrue(hasattr(m2, "attr")) self.assertTrue("attr" in m2.meta[_USER_PRESERVED_ATTRIBUTES_KEY]) def test_output_lists_and_dicts(self): """Verify that specifying complicated output types does not crash. """ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(1, 1, 1) def forward(self, x): x = self.conv(x) return {'foo': [x]}, [{'foo': [[x]]}] m = M().eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} mp = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),)) mc = convert_fx(mp) def test_shape_followed_by_quantized_op(self): """ Make sure that shape does not dequantize the Tensor before the next operator """ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = torch.nn.Conv2d(2, 2, 2) self.conv2 = torch.nn.Conv2d(2, 2, 2) def forward(self, x): x = self.conv1(x) s = x.shape torch._assert(s == x.shape, "") x = self.conv2(x) return x # make sure quantization runs m = M().eval() example_inputs = (torch.randn(2, 2, 4, 4),) m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) m = convert_fx(m) m(*example_inputs) node_occurrence = { ns.call_function(torch.quantize_per_tensor): 1, ns.call_method("dequantize"): 1 } self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) def 
test_trace_quantize_per_tensor(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 1, 1) def forward(self, x): x = self.conv(x) return x m = M().eval() m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.randn(1, 1, 3, 3),)) m = convert_fx(m) # Make sure this runs without error m = torch.fx.Transformer(m).transform() def test_copy_node_has_shared_actpp_instance(self): """ Test the output of CopyNode to have the same observer/fake_quant instance as the input """ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.avgpool2d = torch.nn.AvgPool2d(kernel_size=3) def forward(self, x): x = self.avgpool2d(x) return x for quant_type in self.static_quant_types: m = M() # Checks that we have an observer for both input and output occurrence_map = { QuantType.STATIC: { ns.call_module(torch.ao.quantization.MinMaxObserver): 2 }, QuantType.QAT: { ns.call_module(torch.ao.quantization.FakeQuantize): 2 } } if quant_type == QuantType.QAT: m.train() prepare = prepare_qat_fx qconfig = default_qat_qconfig actpp_module_class = torch.ao.quantization.FakeQuantize else: m.eval() prepare = prepare_fx qconfig = default_qconfig actpp_module_class = torch.ao.quantization.MinMaxObserver example_inputs = (torch.randn(1, 3, 3, 3),) m = prepare(m, {"": qconfig}, example_inputs=example_inputs) # check that there is a duplicated observer instance actpp_module_count = 0 for name, module in m.named_modules(remove_duplicate=False): if isinstance(module, actpp_module_class): actpp_module_count += 1 self.assertEqual(actpp_module_count, 2) actpp_module_count = 0 for name, module in m.named_modules(): if isinstance(module, actpp_module_class): actpp_module_count += 1 self.assertEqual(actpp_module_count, 1) m_copy = copy.deepcopy(m) m = convert_fx(m) m_reference = convert_to_reference_fx(m_copy) # checks for non-reference quantized model node_occurrence = { ns.call_function(torch.quantize_per_tensor): 1, 
ns.call_method("dequantize"): 1 } node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_module(torch.nn.AvgPool2d), ns.call_method("dequantize"), ] self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence, expected_node_list=node_list) # checks for reference quantized model, for copy nodes we'll have # dequant - copy_node - quant patterns which will be fused later # in the backend lowering step node_occurrence = { ns.call_function(torch.quantize_per_tensor): 2, ns.call_method("dequantize"): 2 } node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_method("dequantize"), ns.call_module(torch.nn.AvgPool2d), ns.call_function(torch.quantize_per_tensor), ns.call_method("dequantize"), ] self.checkGraphModuleNodes(m_reference, expected_node_occurrence=node_occurrence, expected_node_list=node_list) def test_linear_qint8_activation(self): """Test support for qint8 activation in reference pattern """ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 2, 2, 2) self.linear = torch.nn.Linear(8, 5) def forward(self, x): x = self.conv(x) x = torch.flatten(x, 1) x = self.linear(x) return x m = M().eval() example_inputs = (torch.rand(2, 1, 5, 5),) m = prepare_fx( m, {"": torch.ao.quantization.QConfig( activation=torch.ao.quantization.HistogramObserver.with_args( qscheme=torch.per_tensor_symmetric, dtype=torch.qint8 ), weight=torch.ao.quantization.default_per_channel_weight_observer)}, example_inputs=example_inputs) m = convert_to_reference_fx(m) m(*example_inputs) def test_preserve_tuple(self): """ Test tuple input type is preserved """ class LSTM(nn.Module): def __init__(self) -> None: super().__init__() self.lstm = nn.LSTM(50, 50, 1) def forward(self, inputs: torch.Tensor, state: list[torch.Tensor]): h = state[0] c = state[1] return self.lstm(inputs, (h, c)) m = LSTM().eval() example_inputs = (torch.randn(5, 3, 50), torch.randn(2, 3, 50), torch.randn(2, 3, 50)) m = prepare_fx(m, {"": 
default_qconfig}, example_inputs=example_inputs) # make sure the arg[1] of lstm module is a tuple for n in m.graph.nodes: if n.target == "lstm": self.assertEqual(type(n.args[1]), tuple) def _test_static_lstm_helper(self, model, prepare_node_occurrence, convert_node_occurrence): """ Helper method to validate the graph of a model with static LSTM. """ qconfig_mapping = get_default_qconfig_mapping() prepare_custom_config = PrepareCustomConfig() \ .set_float_to_observed_mapping(torch.nn.LSTM, torch.ao.nn.quantizable.LSTM) convert_custom_config = ConvertCustomConfig() \ .set_observed_to_quantized_mapping(torch.ao.nn.quantizable.LSTM, torch.ao.nn.quantized.LSTM) example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50)) model = prepare_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=prepare_custom_config) self.checkGraphModuleNodes(model, expected_node_occurrence=prepare_node_occurrence) model(*example_inputs) model = convert_fx(model, convert_custom_config=convert_custom_config) self.checkGraphModuleNodes(model, expected_node_occurrence=convert_node_occurrence) model(*example_inputs) def test_static_lstm(self): """ Test statically quantized custom module LSTM followed by ops that consume individual tensors of the output tuple. 
""" class MyModel(nn.Module): def __init__(self) -> None: super().__init__() self.lstm = nn.LSTM(50, 50, 1) self.linear1 = nn.Linear(50, 10) self.linear2 = nn.Linear(50, 10) self.linear3 = nn.Linear(50, 10) def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor): (out, (h0_out, c0_out)) = self.lstm(inputs, (h0, c0)) out = self.linear1(out) h0_out = self.linear2(h0_out) c0_out = self.linear3(c0_out) return (out, (h0_out, c0_out)) m = MyModel() prepare_node_occurrence = { ns.call_module(torch.ao.nn.quantizable.LSTM): 1, } convert_node_occurrence = { ns.call_module(torch.ao.nn.quantized.LSTM): 1, ns.call_function(torch.quantize_per_tensor): 3, # lstm[0].dequantize() # lstm[1][0].dequantize() # lstm[1][1].dequantize() ns.call_method("dequantize"): 3, # lstm[0], lstm[1], lstm[1][0], lstm[1][1] ns.call_function(operator.getitem): 4, # No tuples are consumed ns.call_function(tuple): 0, } self._test_static_lstm_helper(m, prepare_node_occurrence, convert_node_occurrence) def test_static_lstm_consume_tuple(self): """ Test statically quantized custom module LSTM followed by a module that consumes the output tuple, either as a whole or part of it. 
""" class ModuleAfterLSTM(nn.Module): def __init__(self) -> None: super().__init__() self.identity = torch.nn.Identity() def forward(self, x): return self.identity(x) class ConsumeWholeTuple(nn.Module): def __init__(self) -> None: super().__init__() self.lstm = nn.LSTM(50, 50, 1) self.module_after_lstm = ModuleAfterLSTM() def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor): x = self.lstm(inputs, (h0, c0)) x = self.module_after_lstm(x) # consume tuple (output, (hidden0, hidden1)) return x class ConsumeHiddenTuple(ConsumeWholeTuple): def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor): x = self.lstm(inputs, (h0, c0)) x = self.module_after_lstm(x[1]) # consume tuple (hidden0, hidden1) return x # Test consuming the whole tuple (output, (hidden0, hidden1)) m1 = ConsumeWholeTuple() prepare_node_occurrence = { ns.call_module(torch.ao.nn.quantizable.LSTM): 1, } convert_node_occurrence1 = { ns.call_module(torch.ao.nn.quantized.LSTM): 1, ns.call_function(torch.quantize_per_tensor): 3, # lstm[0].dequantize() # lstm[1][0].dequantize() # lstm[1][1].dequantize() ns.call_method("dequantize"): 3, # lstm[0], lstm[1], lstm[1][0], lstm[1][1] ns.call_function(operator.getitem): 4, # tuple(output_dq, tuple(hidden0_dq, hidden1_dq)) ns.call_function(tuple): 2, } self._test_static_lstm_helper(m1, prepare_node_occurrence, convert_node_occurrence1) # Test consuming just the hidden tuple (hidden0, hidden1) m2 = ConsumeHiddenTuple() convert_node_occurrence2 = { ns.call_module(torch.ao.nn.quantized.LSTM): 1, ns.call_function(torch.quantize_per_tensor): 3, # lstm[1][0].dequantize() # lstm[1][1].dequantize() ns.call_method("dequantize"): 2, # lstm[1], lstm[1][0], lstm[1][1] ns.call_function(operator.getitem): 3, # tuple(hidden0_dq, hidden1_dq) ns.call_function(tuple): 1, } self._test_static_lstm_helper(m2, prepare_node_occurrence, convert_node_occurrence2) def test_static_lstm_with_custom_fixed_qparams(self): """ Test statically quantized LSTM with 
custom fixed qparams assigned to each of the inner submodules. This flow requires users to extend `torch.ao.nn.quantizable.LSTM` and use the child class in the custom module mapping. """ class MyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.my_lstm = torch.nn.LSTM(50, 50, 1) def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor): x = self.my_lstm(inputs, (h0, c0)) return x # Construct a BackendConfig that supports qint32 for certain ops # TODO: build a BackendConfig from scratch instead of modifying an existing one qint32_dtype_config = DTypeConfig(input_dtype=torch.qint32, output_dtype=torch.qint32) my_backend_config = get_qnnpack_backend_config() for config in my_backend_config.configs: if config.pattern in [torch.nn.Sigmoid, torch.nn.Tanh, torch.add, torch.mul]: config.add_dtype_config(qint32_dtype_config) class UserObservedLSTM(torch.ao.nn.quantizable.LSTM): """ Example of user provided LSTM implementation that assigns fixed qparams to the inner ops. 
""" @classmethod def from_float(cls, float_lstm): assert isinstance(float_lstm, cls._FLOAT_MODULE) # uint16, [-16, 16) linear_output_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=2 ** 15, dtype=torch.qint32) # uint16, [0, 1) sigmoid_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -16, zero_point=0, dtype=torch.qint32) # uint16, [-1, 1) tanh_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -15, zero_point=2 ** 15, dtype=torch.qint32) # int16, [-16, 16) cell_state_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=0, dtype=torch.qint32) # uint8, [-1, 1) hidden_state_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -7, zero_point=2 ** 7, dtype=torch.quint8) example_inputs = (torch.rand(5, 3, 50), (torch.rand(1, 3, 50), torch.randn(1, 3, 50))) return torch.ao.quantization.fx.lstm_utils._get_lstm_with_individually_observed_parts( float_lstm=float_lstm, example_inputs=example_inputs, backend_config=my_backend_config, linear_output_obs_ctr=linear_output_obs_ctr, sigmoid_obs_ctr=sigmoid_obs_ctr, tanh_obs_ctr=tanh_obs_ctr, cell_state_obs_ctr=cell_state_obs_ctr, hidden_state_obs_ctr=hidden_state_obs_ctr, ) class UserQuantizedLSTM(torch.ao.nn.quantized.LSTM): """ Example of user provided LSTM implementation that produces a reference quantized module from a `UserObservedLSTM`. 
""" @classmethod def from_observed(cls, observed_lstm): assert isinstance(observed_lstm, cls._FLOAT_MODULE) return torch.ao.quantization.fx.lstm_utils._get_reference_quantized_lstm_module( observed_lstm=observed_lstm, backend_config=my_backend_config, ) # FX graph mode quantization m = MyModel() qconfig_mapping = get_default_qconfig_mapping("qnnpack") example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50)) prepare_custom_config = PrepareCustomConfig() \ .set_float_to_observed_mapping(torch.nn.LSTM, UserObservedLSTM) convert_custom_config = ConvertCustomConfig() \ .set_observed_to_quantized_mapping(torch.ao.nn.quantizable.LSTM, UserQuantizedLSTM) prepared = prepare_fx( m, qconfig_mapping, example_inputs, prepare_custom_config, backend_config=my_backend_config, ) prepared(*example_inputs) converted = convert_fx( prepared, convert_custom_config, backend_config=my_backend_config, ) converted(*example_inputs) # Find the patterns [dq - op - q_to_specific_dtype] in the graph and # verify that qparams and dtypes are set correctly in the quantize ops node_name_to_expected_quantize_args = { "igates": (None, None, torch.quint8), "hgates": (None, None, torch.quint8), "add": (2 ** -11, 2 ** 15, torch.qint32), # gates.add "input_gate": (2 ** -16, 0, torch.qint32), "forget_gate": (2 ** -16, 0, torch.qint32), "cell_gate": (2 ** -15, 2 ** 15, torch.qint32), "output_gate": (2 ** -16, 0, torch.qint32), "mul": (2 ** -11, 0, torch.qint32), # fgate_cx.mul "mul_1": (2 ** -11, 0, torch.qint32), # igate_cgate.mul "add_1": (2 ** -11, 0, torch.qint32), # fgate_cx_igate_cgate.add "mul_2": (2 ** -7, 2 ** 7, torch.quint8), # ogate_cy.mul } cell = converted.my_lstm.layers.get_submodule("0").layer_fw.cell matched_names = set() for node in cell.graph.nodes: if node.name not in node_name_to_expected_quantize_args: continue matched_names.add(node.name) # Match preceding dequantize self.assertTrue(all(arg.target == "dequantize" for arg in node.args)) # Match following 
quantize with the specific qparams and dtypes expected_scale, expected_zp, expected_dtype = node_name_to_expected_quantize_args[node.name] for user in node.users.keys(): self.assertEqual(user.target, torch.quantize_per_tensor) if expected_scale is not None: self.assertEqual(getattr(cell, user.args[1].target), expected_scale) if expected_zp is not None: self.assertEqual(getattr(cell, user.args[2].target), expected_zp) self.assertEqual(user.args[-1], expected_dtype) # Ensure all patterns were matched self.assertEqual(matched_names, set(node_name_to_expected_quantize_args.keys())) def test_reroute_tuple_getitem_patterns(self): """ The following graph should redirect the output to `b`. After the transformation, all other nodes, including the inputs `a` and `c`, are no longer needed. a b c | \\ / \\ tuple \\ / tuple / \\ / \\ | \\ | \\ | \\ getitem0 getitem1 | / \\ | getitem0 getitem1 | \\ / \\ tuple \\ / \\ / tuple | getitem1 | getitem0 | output """ # Construct graph manually because symbolic_trace does not insert tuple and getitem nodes graph = torch.fx.Graph() a = graph.create_node("placeholder", "a") b = graph.create_node("placeholder", "b") c = graph.create_node("placeholder", "c") bc = graph.call_function(tuple, args=([b, c],)) abc = graph.call_function(tuple, args=([a, bc],)) # Break down tuple and reconstruct it again a2 = graph.call_function(operator.getitem, args=(abc, 0)) bc2 = graph.call_function(operator.getitem, args=(abc, 1)) b2 = graph.call_function(operator.getitem, args=(bc2, 0)) c2 = graph.call_function(operator.getitem, args=(bc2, 1)) bc3 = graph.call_function(tuple, args=([b2, c2],)) abc2 = graph.call_function(tuple, args=([a2, bc3],)) # Output tuple[1][0] bc4 = graph.call_function(operator.getitem, args=(abc2, 1)) b3 = graph.call_function(operator.getitem, args=(bc4, 0)) output = graph.output(b3) # Do reroute _reroute_tuple_getitem_pattern(graph) # Assert that output reroutes to `b` directly, and all other nodes can be removed output_ancestors = [] 
def gather_ancestors(current_node): # noqa: E306 for arg in current_node.args: output_ancestors.append(arg) gather_ancestors(arg) gather_ancestors(output) self.assertEqual(output_ancestors, [b]) self.assertEqual(output.args[0], b) def test_relu_lowering(self): class M(torch.nn.Module): def forward(self, x): return torch.nn.functional.relu(x) m = M().eval() m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.randn(1),)) m_copy = copy.deepcopy(m) m = convert_fx(m) m_ref = convert_to_reference_fx(m_copy) node_occurrence = { ns.call_function(torch.quantize_per_tensor): 1, ns.call_method("dequantize"): 1 } node_occurrence_ref = { ns.call_function(torch.quantize_per_tensor): 2, ns.call_method("dequantize"): 2 } self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref) @skipIfNoFBGEMM def test_dynamic_with_fusion(self): """ Tests that dynamic quantization APIs work with Linear + Relu fusion """ with override_quantized_engine('fbgemm'): class LinearRelu(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(5, 5) self.relu = torch.nn.ReLU() def forward(self, x): x = self.linear(x) return self.relu(x) class Linear(torch.nn.Module): def __init__(self) -> None: super().__init__() self.w = torch.ones(5, 5) self.b = torch.zeros(5) def forward(self, x): return torch.nn.functional.linear(x, self.w, self.b) class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.mods1 = torch.nn.Sequential(LinearRelu(), LinearRelu()) self.mods2 = Linear() self.relu = F.relu def forward(self, x): x = self.mods1(x) x = self.mods2(x) x = self.relu(x) return x dynamic_quantized_ops = { float16_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic_fp16, default_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic } for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]: model = M().eval() qconfig_dict = { "": qconfig } 
example_inputs = (torch.rand(5, 5),) m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) m = convert_fx(m) m(*example_inputs) node_list = [ ns.call_module(nniqd.LinearReLU), ns.call_module(nniqd.LinearReLU), ns.call_function(dynamic_quantized_ops[qconfig]), ] self.checkGraphModuleNodes(m, expected_node_list=node_list) @skipIfNoFBGEMM def test_dynamic_with_fusion_multiple_uses(self): """ Tests that dynamic quantization APIs work with Linear + Relu fusion """ with override_quantized_engine('fbgemm'): class LinearRelu(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(5, 5) self.relu = torch.nn.ReLU() def forward(self, x): x = self.linear(x) return self.relu(x) class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear_relu = LinearRelu() def forward(self, x): x = self.linear_relu(x) x = self.linear_relu(x) return x for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]: model = M().eval() qconfig_dict = { "": qconfig } example_inputs = (torch.randn(5, 5),) m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) m = convert_fx(m) m(*example_inputs) node_list = [ ns.call_module(nniqd.LinearReLU), ns.call_module(nniqd.LinearReLU), ] self.checkGraphModuleNodes(m, expected_node_list=node_list) @skipIfNoFBGEMM def test_dynamic_linear_input_multiple_use(self): """ Tests input for dynamic linear being used by multiple ops """ with override_quantized_engine('fbgemm'): class LinearRelu(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(5, 5) self.relu = torch.nn.ReLU() def forward(self, x): x = self.linear(x) return self.relu(x) class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.mod1 = LinearRelu() self.mod2 = LinearRelu() def forward(self, x): y1 = self.mod1(x) y2 = self.mod2(x) return y1 + y2 for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]: model = M().eval() qconfig_dict = { "": 
qconfig } example_inputs = (torch.rand(5, 5, 5),) m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) m = convert_fx(m) m(*example_inputs) node_list = [ ns.call_module(nniqd.LinearReLU), ns.call_module(nniqd.LinearReLU), ] self.checkGraphModuleNodes(m, expected_node_list=node_list) def test_ref_linear_module(self): """ Make sure the numerics for models with ref linear module matches models with fbgemm/qnnpack module """ class M1(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(10, 5) def forward(self, x): return self.linear(x) class M2(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(10, 5) self.relu = torch.nn.ReLU() def forward(self, x): return self.relu(self.linear(x)) for M in [M1, M2]: m = M().eval() example_inputs = (torch.randn(5, 10),) m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) m_copy = copy.deepcopy(m) m = convert_fx(m) m_ref = convert_to_reference_fx(m_copy) result = m(*example_inputs) result_ref = m_ref(*example_inputs) self.assertTrue(torch.equal(result, result_ref)) def test_ref_conv_module(self): """ Make sure the numerics for models with ref conv module matches models with fbgemm/qnnpack module """ convs = { 1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d, } class M1(torch.nn.Module): def __init__(self, dim): super().__init__() self.conv = convs[dim](3, 3, 3) def forward(self, x): return self.conv(x) class M2(torch.nn.Module): def __init__(self, dim): super().__init__() self.conv = convs[dim](3, 3, 3) self.relu = torch.nn.ReLU() def forward(self, x): return self.relu(self.conv(x)) for dim, M in itertools.product([1, 2, 3], [M1, M2]): m = M(dim).eval() data = self.img_data_dict[dim][0][0] m = prepare_fx(m, {"": default_qconfig}, example_inputs=(data,)) m_copy = copy.deepcopy(m) m = convert_fx(m) m_ref = convert_to_reference_fx(m_copy) result = m(data) result_ref = m_ref(data) self.assertTrue(torch.equal(result, result_ref)) 
def test_sub_scalar(self): class M(torch.nn.Module): def forward(self, x): x = x + 1 x = x - 1 x = x + 3 x = x - 4 return x m = M().eval() m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.rand(3),)) m = convert_fx(m) occurrence = { ns.call_function(torch.quantize_per_tensor): 2, ns.call_method("dequantize"): 2 } self.checkGraphModuleNodes(m, expected_node_occurrence=occurrence) def test_observer_fqn(self): """ Test to make sure the observer FQN is based on the quantizable op/module that it is observing and uses the modules FQN to determine the observer name. """ class Linear(torch.nn.Module): def __init__(self) -> None: super().__init__() self.w = torch.ones(5, 5) self.b = torch.zeros(5) def forward(self, x): return torch.nn.functional.linear(x, self.w, self.b) class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.mods1 = torch.nn.Sequential( Linear(), Linear() ) self.mods2 = Linear() self.mods3 = torch.nn.Linear(5, 5) def forward(self, x): x = self.mods1(x) x = torch.add(x, 4) x = self.mods2(x) y = torch.add(x, 2) z = torch.mul(x, 5) a = self.mods3(y) return a, z model = M().eval() prepared = prepare_fx(model, {"": default_qconfig}, example_inputs=(torch.randn(1, 5))) name_list = [] for name, mod in prepared.named_modules(): if isinstance(mod, torch.ao.quantization.observer.MinMaxObserver): name_list.append(name) expected_name_list = ['activation_post_process_0', 'activation_post_process_1', 'activation_post_process_2', 'activation_post_process_3', 'activation_post_process_4', 'activation_post_process_6', 'activation_post_process_7', 'activation_post_process_10'] assert name_list == expected_name_list def test_conv_lowering(self): convs = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d} qconvs = {1: nn.quantized.Conv1d, 2: nn.quantized.Conv2d, 3: nn.quantized.Conv3d} class M(torch.nn.Module): def __init__(self, dim): super().__init__() self.conv = convs[dim](3, 3, 3) def forward(self, x): return self.conv(x) for dim in range(1, 
len(convs) + 1): m = M(dim).eval() data = self.img_data_dict[dim][0][0] m = prepare_fx(m, {"": default_qconfig}, example_inputs=(data,)) m_ref = copy.deepcopy(m) m_ref = convert_to_reference_fx(m_ref) m = convert_fx(m) out_ref = m_ref(data) out = m(data) # check that reference pattern for quantized conv module is fused expected_node_occurrence = { ns.call_function(torch.quantize_per_tensor): 1, ns.call_module(qconvs[dim]): 1, ns.call_method("dequantize"): 1 } self.checkGraphModuleNodes(m, expected_node_occurrence=expected_node_occurrence) # checking result match self.assertTrue(torch.equal(out_ref, out)) def test_convert_qconfig_mapping(self): class Linear(torch.nn.Module): def __init__(self) -> None: super().__init__() self.w = torch.ones(5, 5) self.b = torch.zeros(5) def forward(self, x): return torch.nn.functional.linear(x, self.w, self.b) class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.mods1 = torch.nn.Sequential( Linear(), Linear() ) self.mods3 = torch.nn.Linear(5, 5) def forward(self, x): x = self.mods1(x) x = torch.add(x, 4) z = torch.mul(x, 5) x = self.mods3(z) return x model = M().train() for check in ["module_name", "object_type"]: qconfig_dict = {"": None, "object_type": [ (nn.functional.linear, get_default_qat_qconfig("fbgemm")), (torch.add, get_default_qat_qconfig("fbgemm")), (nn.Linear, get_default_qat_qconfig("fbgemm")), ], } example_inputs = (torch.rand(5, 5),) prepared = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) prepared(*example_inputs) if check == "module_name": convert_qconfig_dict = {"": None, "object_type": [ (nn.functional.linear, get_default_qat_qconfig("fbgemm")), (torch.add, get_default_qat_qconfig("fbgemm")), (nn.Linear, get_default_qat_qconfig("fbgemm")), ], "module_name": [("mods1.0", None)]} node_occurrence = { ns.call_function(torch.quantize_per_tensor): 2, ns.call_function(torch.nn.functional.linear): 1, ns.call_function(torch.ops.quantized.linear): 1, 
ns.call_function(torch.ops.quantized.add): 1, ns.call_method("dequantize"): 2 } order_check = [ ns.call_function(torch.nn.functional.linear), ns.call_function(torch.quantize_per_tensor), ns.call_function(torch.ops.quantized.linear), ns.call_function(torch.ops.quantized.add), ns.call_method("dequantize"), ns.call_function(torch.quantize_per_tensor), ns.call_module(nnq.Linear), ns.call_method("dequantize"), ] elif check == "object_type": convert_qconfig_dict = {"": None, "object_type": [ (nn.functional.linear, get_default_qat_qconfig("fbgemm")), (torch.add, get_default_qat_qconfig("fbgemm")), (nn.Linear, None), ]} node_occurrence = { ns.call_function(torch.quantize_per_tensor): 1, ns.call_function(torch.ops.quantized.linear): 2, ns.call_function(torch.ops.quantized.add): 1, ns.call_function(torch.mul): 1, ns.call_method("dequantize"): 1 } order_check = [ ns.call_function(torch.quantize_per_tensor), ns.call_function(torch.ops.quantized.linear), ns.call_function(torch.ops.quantized.linear), ns.call_function(torch.ops.quantized.add), ns.call_method("dequantize"), ns.call_function(torch.mul), ns.call_module(nn.Linear), ] converted = convert_fx(prepared, qconfig_mapping=convert_qconfig_dict) converted(torch.rand(5, 5)) self.checkGraphModuleNodes( converted, expected_node_occurrence=node_occurrence, expected_node_list=order_check) def _assertFixedQParamsFakeQuantizeEqual(self, fq1, fq2): self.assertEqual(fq1()._observer_ctr, fq2()._observer_ctr) def test_register_patterns(self): def cleanUp(): del _DEFAULT_FUSION_PATTERNS["dummy_fusion"] del _DEFAULT_QUANTIZATION_PATTERNS["dummy_quant"] del _DEFAULT_QUANTIZATION_PATTERNS["dummy_quant2"] del _DEFAULT_QUANTIZATION_PATTERNS["dummy_quant3"] del _DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant2"] del _DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant3"] del _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant2"] del _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant3"] self.addCleanup(cleanUp) @_register_fusion_pattern("dummy_fusion") class DummyFusion: 
pass @_register_quant_pattern("dummy_quant") class DummyQuant: pass @_register_quant_pattern("dummy_quant2", default_fixed_qparams_range_0to1_observer) class DummyQuant2: pass @_register_quant_pattern("dummy_quant3", default_fixed_qparams_range_neg1to1_observer) class DummyQuant3: pass self.assertEqual(_DEFAULT_FUSION_PATTERNS["dummy_fusion"], DummyFusion) self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant"], DummyQuant) self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant2"], DummyQuant2) self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant3"], DummyQuant3) self.assertEqual(_DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant2"], default_fixed_qparams_range_0to1_observer) self.assertEqual(_DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant3"], default_fixed_qparams_range_neg1to1_observer) self._assertFixedQParamsFakeQuantizeEqual(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant2"], default_fixed_qparams_range_0to1_fake_quant) self._assertFixedQParamsFakeQuantizeEqual(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant3"], default_fixed_qparams_range_neg1to1_fake_quant) output_fake_quantize_map = get_default_output_activation_post_process_map(is_training=True) output_observer_map = get_default_output_activation_post_process_map(is_training=False) self.assertEqual(output_observer_map.get("dummy_quant3"), default_fixed_qparams_range_neg1to1_observer) self._assertFixedQParamsFakeQuantizeEqual(output_fake_quantize_map.get("dummy_quant3"), default_fixed_qparams_range_neg1to1_fake_quant) def test_reuse_input_qconfig(self): class M1(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(3, 3, 3) def forward(self, x): x = self.conv(x) x = x.reshape() return x class M2(torch.nn.Module): def forward(self, x): x = x.reshape() return x options = itertools.product([M1, M2], [True, False]) for M, is_qat in options: m = M1().eval() example_inputs = (torch.randn(1, 3, 3, 3),) m = prepare_fx(m, get_default_qconfig_mapping(), 
example_inputs=example_inputs) m = convert_fx(m) node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_module(nnq.Conv2d), ns.call_method("reshape"), ns.call_method("dequantize"), ] self.checkGraphModuleNodes( m, expected_node_list=node_list) m = M2().eval() m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs) m = convert_fx(m) node_occurrence = { ns.call_function(torch.quantize_per_tensor): 0, ns.call_method("dequnatize"): 0, } node_list = [ ns.call_method("reshape"), ] self.checkGraphModuleNodes( m, expected_node_occurrence=node_occurrence, expected_node_list=node_list) def test_stack_trace_preserved_linear(self): class M(nn.Module): def __init__(self) -> None: super().__init__() self.linear = nn.Linear(1, 1) def forward(self, x): x = self.linear(x) return x m = M().eval() mp = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=(torch.randn(1, 1),)) found_stack_trace = False for n in mp.graph.nodes: if n.op == 'call_module' and n.target == 'linear': found_stack_trace = n.stack_trace is not None break self.assertTrue(found_stack_trace) # test reference model mq = convert_to_reference_fx(copy.deepcopy(mp)) found_stack_trace = False for n in mq.graph.nodes: if n.op == 'call_module' and n.target == 'linear': found_stack_trace = n.stack_trace is not None break self.assertTrue(found_stack_trace, f"stack trace not found, node: {n.format_node()}, is_reference: True") # test quantized model mq = convert_fx(mp) found_stack_trace = False for n in mq.graph.nodes: if n.op == 'call_module' and n.target == 'linear': found_stack_trace = n.stack_trace is not None break self.assertTrue(found_stack_trace, f"stack trace not found, node: {n.format_node()}, is_reference: False") def test_qat_skip_untraced(self): class UnTraceableModuleClass(nn.Module): def __init__(self) -> None: super().__init__() self.linear = nn.Linear(2, 2) def forward(self, x): return self.linear(x) class UnTraceableModuleName(nn.Module): def __init__(self) -> 
None: super().__init__() self.linear = nn.Linear(2, 2) def forward(self, x): return self.linear(x) class M(nn.Module): def __init__(self) -> None: super().__init__() self.untraceable_module_class = UnTraceableModuleClass() self.untraceable_module_name = UnTraceableModuleClass() def forward(self, x): x = self.untraceable_module_class(x) x = self.untraceable_module_name(x) return x mod = M() qconfig_dict = {"": torch.ao.quantization.get_default_qat_qconfig()} prepare_custom_config_dict = { "non_traceable_module_class": [UnTraceableModuleClass], "non_traceable_module_name": ["untraceable_module_name"], } example_inputs = (torch.randn(2, 2),) mod_prep = torch.ao.quantization.quantize_fx.prepare_qat_fx( mod.train(), qconfig_dict, example_inputs=example_inputs, prepare_custom_config=prepare_custom_config_dict ) mod_prep = torch.ao.quantization.quantize_fx.prepare_qat_fx( mod.train(), qconfig_dict, example_inputs=example_inputs, prepare_custom_config=prepare_custom_config_dict ) self.assertTrue( isinstance(mod_prep.untraceable_module_class.linear, torch.nn.Linear) ) self.assertTrue( isinstance(mod_prep.untraceable_module_name.linear, torch.nn.Linear) ) self.assertTrue( type(mod_prep.untraceable_module_class.linear) is not torch.ao.nn.qat.modules.linear.Linear, "prepare_qat_fx shold not convert anything inside untraced module classes", ) self.assertTrue( type(mod_prep.untraceable_module_name.linear) is not torch.ao.nn.qat.modules.linear.Linear, "prepare_qat_fx shold not convert anything inside modules named in untraced_module_names", ) def test_qconfig_dict_setup(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.Conv1d = torch.nn.Conv1d(1, 1, 1) self.Conv2d = torch.nn.Conv2d(1, 1, 1) self.Conv3d = torch.nn.Conv3d(1, 1, 1) self.ConvTranspose1d = torch.nn.ConvTranspose1d(1, 1, 1) self.ConvTranspose2d = torch.nn.ConvTranspose2d(1, 1, 1) self.ConvTranspose3d = torch.nn.ConvTranspose3d(1, 1, 1) self.Linear = torch.nn.Linear(1, 1, 1) def 
forward(self, x): x = self.Conv1d(x) x = self.Conv2d(x) x = self.Conv3d(x) x = self.ConvTranspose1d(x) x = self.ConvTranspose2d(x) x = self.ConvTranspose3d(x) x = self.Linear(x) x = torch.nn.functional.conv1d(x, torch.rand(2, 2)) x = torch.nn.functional.conv2d(x, torch.rand(2, 2)) x = torch.nn.functional.conv3d(x, torch.rand(2, 2)) x = torch.nn.functional.linear(x, torch.rand(2, 2)) return x backends = ["qnnpack", "fbgemm"] for func in [get_default_qconfig_mapping, get_default_qat_qconfig_mapping]: for backend in backends: m = M().eval() qconfig_dict = func(backend) m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1))) for name, mod in m.named_modules(): if _is_activation_post_process(mod) and mod.dtype == torch.quint8: if backend == "fbgemm": lower_bnd = 0 upper_bnd = 127 else: lower_bnd = 0 upper_bnd = 255 if issubclass(type(mod), FakeQuantize): self.assertEqual(mod.activation_post_process.quant_min, lower_bnd) self.assertEqual(mod.activation_post_process.quant_max, upper_bnd) else: self.assertEqual(mod.quant_min, lower_bnd) self.assertEqual(mod.quant_max, upper_bnd) def test_prepare_mode(self): class LinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(5, 10) def forward(self, x): return self.linear(x) def _test(prepare_fn, qconfig_dict): m = LinearModel() m1 = copy.deepcopy(m) m1.train() example_inputs = (torch.randn(1, 5),) prepare_fn(m1, qconfig_dict, example_inputs=example_inputs) m2 = copy.deepcopy(m) m2.eval() prepare_fn(m2, qconfig_dict, example_inputs=example_inputs) # Ensure prepare_fx and prepare_qat_fx work in both training and eval modes _test(prepare_fx, get_default_qconfig_mapping()) _test(prepare_qat_fx, get_default_qat_qconfig_mapping()) def _validate_qconfig_against_backend_config_constraints( self, model: torch.nn.Module, qconfig: QConfig, backend_config: BackendConfig, satisfies_constraints: bool, qconfig_name: Optional[str] = None): """ Helper method to validate whether 
`qconfig` satisfies the constraints specified in `backend_config`. """ qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig) example_inputs = (torch.rand((1, 30), dtype=torch.float),) model = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config) model(*example_inputs) model = convert_fx(model, backend_config=backend_config) if satisfies_constraints: expected_node_occurrence = { ns.call_module(torch.ao.nn.quantized.Linear) : 1, ns.call_module(torch.nn.Linear) : 0, } else: expected_node_occurrence = { ns.call_module(torch.ao.nn.quantized.Linear) : 0, ns.call_module(torch.nn.Linear) : 1, } try: self.checkGraphModuleNodes(model, expected_node_occurrence=expected_node_occurrence) except AssertionError as e: if qconfig_name is not None: print(f"ERROR: Validation for QConfig '{qconfig_name}' failed") raise e def test_backend_config_quantization_range(self): """ Check that quantization ranges specified through the BackendConfig are reflected in the observers inserted into the model. 
""" class MyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(30, 4).float() def forward(self, x): return self.linear(x) dtype_config = DTypeConfig( input_dtype=DTypeWithConstraints( dtype=torch.quint8, quant_min_lower_bound=0, quant_max_upper_bound=31, ), output_dtype=DTypeWithConstraints( dtype=torch.quint8, quant_min_lower_bound=0, quant_max_upper_bound=31, ), weight_dtype=DTypeWithConstraints( dtype=torch.qint8, quant_min_lower_bound=-64, quant_max_upper_bound=63, ), bias_dtype=torch.float, ) backend_config = BackendConfig() \ .set_backend_pattern_config(BackendPatternConfig(torch.nn.Linear) .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E128 .add_dtype_config(dtype_config) .set_root_module(torch.nn.Linear) .set_reference_quantized_module(nnqr.Linear)) def validate_qconfig(qconfig: QConfig, satisfies_constraints: bool): self._validate_qconfig_against_backend_config_constraints( MyModel(), qconfig, backend_config, satisfies_constraints) # Case 1: QConfig ranges fit within backend ranges, OK qconfig1 = QConfig( activation=MinMaxObserver.with_args(quant_min=0, quant_max=15, dtype=torch.quint8), weight=MinMaxObserver.with_args(quant_min=-32, quant_max=31, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)) validate_qconfig(qconfig1, satisfies_constraints=True) # Case 2: QConfig activation range falls outside backend range, should fail qconfig2 = QConfig( activation=MinMaxObserver.with_args(quant_min=0, quant_max=63, dtype=torch.quint8), weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)) validate_qconfig(qconfig2, satisfies_constraints=False) # Case 3: QConfig weight range falls outside backend range, should fail qconfig3 = QConfig( activation=MinMaxObserver.with_args(dtype=torch.quint8), weight=MinMaxObserver.with_args(quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)) validate_qconfig(qconfig3, 
satisfies_constraints=False) # Case 4: QConfig doesn't specify range, should fail qconfig4 = QConfig(activation=ReuseInputObserver, weight=ReuseInputObserver) validate_qconfig(qconfig4, satisfies_constraints=False) def test_backend_config_scale_min(self): """ Test QConfig eps validation against the BackendConfig's min scale value. """ class MyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(30, 4).float() def forward(self, x): return self.linear(x) dtype_config = DTypeConfig( input_dtype=DTypeWithConstraints(dtype=torch.quint8, scale_min_lower_bound=2 ** -12), output_dtype=DTypeWithConstraints(dtype=torch.quint8, scale_min_lower_bound=2 ** -12), weight_dtype=DTypeWithConstraints(dtype=torch.qint8, scale_min_lower_bound=2 ** -12), bias_dtype=torch.float, ) backend_config = BackendConfig() \ .set_backend_pattern_config(BackendPatternConfig(torch.nn.Linear) .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E128 .add_dtype_config(dtype_config) .set_root_module(torch.nn.Linear) .set_reference_quantized_module(nnqr.Linear)) def validate_qconfig(qconfig: QConfig, satisfies_constraints: bool): self._validate_qconfig_against_backend_config_constraints( MyModel(), qconfig, backend_config, satisfies_constraints) # Case 1: QConfig min scale value == backend min scale value, OK qconfig1 = QConfig( activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -12), weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, eps=2 ** -12)) validate_qconfig(qconfig1, satisfies_constraints=True) # Case 2: QConfig min scale value > backend min scale value, OK qconfig2 = QConfig( activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -10), weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, eps=2 ** -10)) validate_qconfig(qconfig2, satisfies_constraints=True) # Case 3: QConfig activation min scale value < backend min scale 
value, should fail qconfig3 = QConfig( activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -14), weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)) validate_qconfig(qconfig3, satisfies_constraints=False) # Case 3: QConfig weight min scale value < backend min scale value, should fail qconfig4 = QConfig( activation=MinMaxObserver.with_args(dtype=torch.quint8), weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, eps=2 ** -14)) validate_qconfig(qconfig4, satisfies_constraints=False) # Case 5: QConfig doesn't specify eps, should fail qconfig5 = QConfig( activation=FixedQParamsObserver.with_args(scale=1.0, zero_point=0), weight=FixedQParamsObserver.with_args(scale=1.0, zero_point=0)) validate_qconfig(qconfig5, satisfies_constraints=False) def test_qnnpack_backend_config(self): """ Test whether default QNNPACK QConfigs are compatible with the QNNPACK BackendConfig. """ class MyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(30, 4).float() def forward(self, x): return self.linear(x) all_qconfigs: list[tuple[QConfig, str]] = [ (get_default_qconfig("qnnpack", version=0), "default_qnnpack_qconfig_v0"), (get_default_qat_qconfig("qnnpack", version=0), "default_qat_qnnpack_qconfig_v0"), (get_default_qat_qconfig("qnnpack", version=1), "default_qat_qnnpack_qconfig_v1"), (default_symmetric_qnnpack_qconfig, "default_symmetric_qnnpack_qconfig"), (default_symmetric_qnnpack_qat_qconfig, "default_symmetric_qnnpack_qat_qconfig"), # TODO: Test these QConfigs once they are fixed, see https://github.com/pytorch/pytorch/issues/85862 # (default_per_channel_symmetric_qnnpack_qconfig, "default_per_channel_symmetric_qnnpack_qconfig"), # (default_per_channel_symmetric_qnnpack_qat_qconfig, "default_per_channel_symmetric_qnnpack_qat_qconfig"), ] backend_config = get_qnnpack_backend_config() for qconfig, qconfig_name in all_qconfigs: 
self._validate_qconfig_against_backend_config_constraints( MyModel(), qconfig, backend_config, satisfies_constraints=True, qconfig_name=qconfig_name) def test_symmetric_qnnpack_qconfig_mapping(self): """ Test whether `torch.ao.quantization.qconfig_mapping._get_symmetric_qnnpack_qconfig_mapping` works with the QNNPACK BackendConfig. """ if "qnnpack" not in supported_qengines: return class MyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(30, 4).float() def forward(self, x): return self.linear(x) with override_quantized_engine("qnnpack"): qconfig_mapping = _get_symmetric_qnnpack_qconfig_mapping() example_inputs = (torch.rand((1, 30), dtype=torch.float),) backend_config = get_qnnpack_backend_config() model = MyModel() model = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config) model(*example_inputs) model = convert_fx(model, backend_config=backend_config) expected_node_occurrence = { ns.call_module(torch.ao.nn.quantized.Linear) : 1, ns.call_module(torch.nn.Linear) : 0, } self.checkGraphModuleNodes(model, expected_node_occurrence=expected_node_occurrence) model(*example_inputs) def test_symmetric_qnnpack_qat_qconfig_mapping(self): """ Test whether `torch.ao.quantization.qconfig_mapping._get_symmetric_qnnpack_qat_qconfig_mapping` works with the QNNPACK BackendConfig. 
""" if "qnnpack" not in supported_qengines: return class MyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(30, 4).float() def forward(self, x): return self.linear(x) with override_quantized_engine("qnnpack"): qconfig_mapping = _get_symmetric_qnnpack_qat_qconfig_mapping() example_inputs = (torch.rand((1, 30), dtype=torch.float),) backend_config = get_qnnpack_backend_config() model = MyModel() model = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config) model(*example_inputs) model = convert_fx(model, backend_config=backend_config) expected_node_occurrence = { ns.call_module(torch.ao.nn.quantized.Linear) : 1, ns.call_module(torch.nn.Linear) : 0, } self.checkGraphModuleNodes(model, expected_node_occurrence=expected_node_occurrence) model(*example_inputs) def test_get_executorch_backend_config(self): from torch.ao.quantization.backend_config import get_executorch_backend_config # make sure this runs executorch_backend_config = get_executorch_backend_config() def test_backend_config_check_for_weight_and_bias(self): """ Test to make sure the backend_config check for weight and bias runs when the qconfig is None for the ops with weight and bias previously the error was not hit because we first check input, and the check for weight and bias are skipped. 
""" class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.weight = torch.tensor((5, 5)) self.bias = torch.tensor((5,)) def forward(self, x): return torch.addmm(self.bias, x, self.weight) m = M().eval() qconfig_mapping = QConfigMapping() observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT weighted_op_quint8_dtype_config = DTypeConfig( input_dtype=torch.quint8, output_dtype=torch.quint8, weight_dtype=torch.qint8, bias_dtype=torch.float, ) dtype_configs = [weighted_op_quint8_dtype_config] backend_pattern_config = BackendPatternConfig(torch.addmm) \ .set_observation_type(observation_type) \ .set_dtype_configs(dtype_configs) \ ._set_input_type_to_index({"weight": 2, "bias": 0}) backend_config = BackendConfig() \ .set_backend_pattern_config(backend_pattern_config) example_inputs = (torch.rand(1, 5),) # make sure this runs m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config) def test_get_default_qconfig_valid_backend(self): """ Checks that AssertionError is raised when non expected backend input is specified """ invalid_backends = ["imaginary_backend", 3] for invalid_backend in invalid_backends: with self.assertRaisesRegex(AssertionError, "not supported"): qconfig = get_default_qconfig(invalid_backend) with self.assertRaisesRegex(AssertionError, "not supported"): qconfig = get_default_qat_qconfig(invalid_backend) with self.assertRaisesRegex(AssertionError, "not supported"): qconfig_mapping = get_default_qconfig_mapping(invalid_backend) with self.assertRaisesRegex(AssertionError, "not supported"): qconfig_mapping = get_default_qat_qconfig_mapping(invalid_backend) def test__convert_to_reference_decomposed_fx(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(5, 10) def forward(self, x): return self.linear(x) m = M().eval() qconfig_mapping = get_default_qconfig_mapping("fbgemm") example_inputs = (torch.randn(1, 5),) m = prepare_fx(m, 
qconfig_mapping, example_inputs) m_ref = copy.deepcopy(m) m_ref = convert_to_reference_fx(m_ref) m = _convert_to_reference_decomposed_fx(m) expected_occurrence = { ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2, ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2, } self.checkGraphModuleNodes( m, expected_node_occurrence=expected_occurrence) # make sure it runs res_ref = m_ref(*example_inputs) res = m(*example_inputs) self.assertEqual(res, res_ref) @skipIfNoQNNPACK def test__convert_to_reference_decomposed_fx_dynamic_quant(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(5, 10) def forward(self, x): return self.linear(x) # to avoid reduce_range with override_quantized_engine("qnnpack"): m = M().eval() qconfig_mapping = get_default_qconfig_mapping("fbgemm") \ .set_object_type(torch.nn.Linear, default_dynamic_qconfig) example_inputs = (torch.randn(1, 5),) m = prepare_fx(m, qconfig_mapping, example_inputs) m(*example_inputs) m_ref = copy.deepcopy(m) m_ref = convert_to_reference_fx(m_ref) m = _convert_to_reference_decomposed_fx(m) expected_occurrence = { ns.call_function(torch.ops.quantized_decomposed.choose_qparams.tensor): 1, ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 1, ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor): 1, } self.checkGraphModuleNodes( m, expected_node_occurrence=expected_occurrence) # make sure it runs res_ref = m_ref(*example_inputs) res = m(*example_inputs) self.assertEqual(res, res_ref) def test__convert_to_reference_decomposed_fx_per_channel_quant(self): class M(torch.nn.Module): def forward(self, x, weight, bias): return F.linear(x, weight, bias) m = M().eval() qconfig_mapping = get_default_qconfig_mapping("fbgemm") \ .set_object_type(F.linear, default_per_channel_qconfig) example_inputs = (torch.randn(1, 5), torch.randn(10, 5), torch.randn(10,)) m = 
prepare_fx(m, qconfig_mapping, example_inputs) m(*example_inputs) m_ref = copy.deepcopy(m) m_ref = convert_to_reference_fx(m_ref) m = _convert_to_reference_decomposed_fx(m) expected_occurrence = { # for input and output activations ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2, ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2, # for weight ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1, ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1, } self.checkGraphModuleNodes( m, expected_node_occurrence=expected_occurrence) # make sure it runs res_ref = m_ref(*example_inputs) res = m(*example_inputs) self.assertEqual(res, res_ref) def test_change_backend_config_for_fixed_qparam_ops(self): """ Making sure we can skip validation of qconfigs for fixedqparam ops based on BackendConfig """ class M(nn.Module): def __init__(self) -> None: super().__init__() self.tanh = torch.nn.Tanh() def forward(self, x: torch.Tensor): x = self.tanh(x) return x model = M().eval() # we set a global default_qconfig, which will be ignored since the backend # we defined doesn't support anything # this is to make sure we don't validate the qconfig when BackendConfig does not # have fixed qparam op related configurations qconfig_mapping = QConfigMapping().set_global(default_qconfig) backend_config = BackendConfig() # make sure this runs model = prepare_fx( model, qconfig_mapping=qconfig_mapping, example_inputs=(torch.randn(1, 2, 3, 4),), backend_config=backend_config ) def test_channel_shuffle_lowering(self): # Three versions of channel shuffle class M1(torch.nn.Module): def __init__(self) -> None: super().__init__() self.op = torch.nn.ChannelShuffle(2) def forward(self, x): return self.op(x + x) + x class M2(torch.nn.Module): def forward(self, x): return torch.channel_shuffle(x + x, 2) + x class M3(torch.nn.Module): def forward(self, x): return 
torch.nn.functional.channel_shuffle(x + x, 2) + x x = torch.randn(4, 4, 4, 4) # torch.channel_shuffle is equivalent to torch.nn.functional.channel_shuffle model_node_pairs = [ (M1().eval(), ns.call_module(torch.nn.ChannelShuffle)), (M2().eval(), ns.call_function(torch.channel_shuffle)), (M3().eval(), ns.call_function(torch.channel_shuffle)) ] for m, node in model_node_pairs: m = prepare_fx(m, {"": default_qconfig}, example_inputs=(x,)) m_copy = copy.deepcopy(m) m = convert_fx(m) m_ref = convert_to_reference_fx(m_copy) node_occurrence = { node: 1, ns.call_function(torch.quantize_per_tensor): 1, ns.call_method("dequantize"): 1 } node_occurrence_ref = { node: 1, ns.call_function(torch.quantize_per_tensor): 4, ns.call_method("dequantize"): 4 } self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref) def test_match_pattern_with_multiple_args(self): """ Test that we can match a pattern that has multiple arguments Pattern: shape \ transpose (observed) -> reshape -> output (observed) -> where `reshape` has two arguments """ def _get_pattern_configs(): backend_pattern_configs = [] observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT weighted_op_quint8_dtype_config = DTypeConfig( input_dtype=torch.quint8, output_dtype=torch.quint8, weight_dtype=torch.qint8, bias_dtype=torch.float, ) dtype_configs = [weighted_op_quint8_dtype_config] def root_node_getter(node_pattern): reshape, transpose, shape = node_pattern return transpose backend_pattern_configs.append( BackendPatternConfig() ._set_pattern_complex_format((torch.reshape, torch.transpose, MatchAllNode)) # noqa: E131 .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ._set_root_node_getter(root_node_getter) ) return backend_pattern_configs backend_config = BackendConfig().set_backend_pattern_configs(_get_pattern_configs()) class M(torch.nn.Module): def forward(self, x): x = torch.transpose(x, 0, 
1) x = torch.reshape(x, (-1,)) return x m = M().eval() qconfig_mapping = QConfigMapping().set_global(default_qconfig) example_inputs = (torch.randn(1, 3, 3, 3),) m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config) node_occurrence = { # one for input of the pattern and one for output of the pattern ns.call_module(MinMaxObserver): 2 } self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) def _test_linear_activation_fusion_lowering_helper( self, module, example_inputs, qconfig_mapping, backend_config, fused_module, root_module, activation_module): node_occurrence = { ns.call_function(torch.quantize_per_tensor): 1, ns.call_method("dequantize"): 1, ns.call_module(fused_module): 1, ns.call_module(root_module): 0, ns.call_module(activation_module): 0, } node_occurrence_ref = { ns.call_function(torch.quantize_per_tensor): 2, ns.call_method("dequantize"): 2, } m = module.eval() m = prepare_fx(m, qconfig_mapping, example_inputs=example_inputs, backend_config=backend_config) m_copy = copy.deepcopy(m) m = convert_fx(m, backend_config=backend_config) m_ref = convert_to_reference_fx(m_copy) self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref) m(*example_inputs) @skipIfNoONEDNN def test_linear_leaky_relu_lowering(self): """ Test fusion and lowering of Linear - (bn -) LeakyReLU by FX. For onednn backedn only. """ from torch.ao.quantization.backend_config import get_onednn_backend_config qconfig_mapping = get_default_qconfig_mapping('onednn') with override_quantized_engine('onednn'): for with_bn in [True, False]: m = LinearBnLeakyReluModel(with_bn) self._test_linear_activation_fusion_lowering_helper( m, m.get_example_inputs(), qconfig_mapping, get_onednn_backend_config(), nniq.LinearLeakyReLU, nn.Linear, nn.LeakyReLU) @skipIfNoONEDNN def test_linear_tanh_lowering(self): """ Test fusion and lowering of Linear - Tanh by FX. 
@override_qengines
def test_linear_size_view(self):
    """Check Linear(+ReLU) lowering when the output is reshaped via ``x.size``.

    The model applies ``torch.nn.Linear`` (optionally followed by
    ``torch.nn.ReLU``) and then ``x.view(x.size(0), 1, 4, 8)``. After
    ``convert_fx`` the graph is expected to contain exactly one quantized
    linear module (``nnq.Linear`` or ``nniq.LinearReLU``), one
    ``quantize_per_tensor`` and one ``dequantize``.
    """
    class M(torch.nn.Module):
        def __init__(self, use_relu=False):
            super().__init__()
            self.linear = torch.nn.Linear(16, 32)
            self.relu = torch.nn.ReLU()
            self.use_relu = use_relu

        def forward(self, x):
            x = self.linear(x)
            if self.use_relu:
                x = self.relu(x)
            return x.view(x.size(0), 1, 4, 8)

    for use_relu in [False, True]:
        model_fp32 = M(use_relu).eval()
        qengine = torch.backends.quantized.engine
        qconfig_mapping = get_default_qconfig_mapping(qengine)
        x = torch.randn((5, 16))
        # prepare_fx expects a tuple of positional inputs, not a bare tensor
        example_inputs = (x,)
        model_fp32(*example_inputs)
        prepared_model = prepare_fx(model_fp32, qconfig_mapping, example_inputs)
        prepared_model(*example_inputs)  # calibrate
        quantized_model = convert_fx(prepared_model)
        node_occurrence = {
            ns.call_module(nnq.Linear): 0 if use_relu else 1,
            ns.call_module(nniq.LinearReLU): 1 if use_relu else 0,
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 1
        }
        self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)
def test_mixed_dtypes(self):
    """
    Check that different layers of one model can use different quantized
    dtypes, and that the reference model converts between those dtypes
    in between the layers.
    """
    class MyModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear1 = torch.nn.Linear(5, 5)
            self.linear2 = torch.nn.Linear(5, 5)
            self.sigmoid = torch.nn.Sigmoid()
            self.tanh = torch.nn.Tanh()
            self.float_functional = torch.ao.nn.quantized.FloatFunctional()

        def forward(self, x: torch.Tensor):
            x = self.linear1(x)  # qint32
            linear2 = self.linear2(x)  # quint8
            x = self.sigmoid(linear2)  # back to qint32
            x = self.tanh(x)  # back to quint8
            # adding two quint8's together
            return self.float_functional.add(linear2, x)

    def fixed_qparams_qconfig(scale, zero_point, dtype):
        activation = FixedQParamsObserver.with_args(
            scale=scale, zero_point=zero_point, dtype=dtype)
        return QConfig(
            activation=activation,
            weight=torch.ao.quantization.default_weight_observer)

    # QConfigMapping with distinct qparams and dtypes per layer
    qconfig_mapping = QConfigMapping()
    qconfig_mapping.set_global(get_default_qconfig("qnnpack"))
    qconfig_mapping.set_module_name("linear1", fixed_qparams_qconfig(1234, 11, torch.qint32))
    qconfig_mapping.set_module_name("linear2", fixed_qparams_qconfig(2345, 22, torch.quint8))
    qconfig_mapping.set_object_type(torch.nn.Sigmoid, fixed_qparams_qconfig(3456, 33, torch.qint32))
    qconfig_mapping.set_object_type(torch.nn.Tanh, fixed_qparams_qconfig(4567, 44, torch.quint8))

    # BackendConfig extended with the dtypes requested by the mapping above
    linear_qint32 = DTypeConfig(
        input_dtype=torch.qint32,
        output_dtype=torch.qint32,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float,
    )
    fixed_qparams_quint8 = DTypeConfig(
        input_dtype=torch.quint8,
        output_dtype=torch.quint8,
    )
    fixed_qparams_qint32 = DTypeConfig(
        input_dtype=torch.qint32,
        output_dtype=torch.qint32,
    )
    backend_config = get_qnnpack_backend_config()
    for pattern_config in backend_config.configs:
        if pattern_config.pattern == torch.nn.Linear:
            pattern_config.add_dtype_config(linear_qint32)
        elif pattern_config.pattern in (torch.nn.Sigmoid, torch.nn.Tanh):
            pattern_config.add_dtype_config(fixed_qparams_quint8)
            pattern_config.add_dtype_config(fixed_qparams_qint32)

    # Build the reference quantized model
    float_model = MyModule()
    example_inputs = (torch.rand(5, 5),)
    prepared = prepare_fx(float_model, qconfig_mapping, example_inputs, backend_config=backend_config)
    prepared(*example_inputs)  # calibrate
    converted = convert_to_reference_fx(prepared, backend_config=backend_config)
    converted(*example_inputs)

    # Verify that the reference model is correct
    #
    # Reference model until add should be:
    # fp32_input -> q_to_int32 -> [dq -> linear1_fp32 -> q_to_int32] -> dq ->
    # q_to_uint8 -> [dq -> linear2_fp32 -> q_to_uint8] -> dq (linear2_dq) ->
    # q_to_int32 -> [dq -> sigmoid_fp32 -> q_to_int32] -> dq ->
    # q_to_uint8 -> [dq -> tanh_fp32 -> q_to_uint8] -> dq (tanh_dq)
    #
    # Complete reference model with add should be:
    # [(linear2_dq, tanh_dq) -> add_fp32 -> q_to_uint8] -> dq -> fp32_output

    expected_dtype_by_target = {
        "linear1": torch.qint32,
        "linear2": torch.quint8,
        "sigmoid": torch.qint32,
        "tanh": torch.quint8,
        torch.add: torch.quint8,
    }
    # Walk the graph looking for [dq - op_fp32 - q_to_specific_dtype]
    linear2_node = None
    tanh_node = None
    for node in converted.graph.nodes:
        target = node.target
        if target not in expected_dtype_by_target:
            continue

        # Every input of the op must be a dequantize (one or two inputs)
        self.assertIn(len(node.args), (1, 2))
        for arg in node.args:
            self.assertEqual(arg.target, "dequantize")

        # The single consumer must quantize to the expected dtype
        self.assertEqual(len(node.users), 1)
        user = next(iter(node.users))
        self.assertEqual(user.target, torch.quantize_per_tensor)
        self.assertEqual(user.args[-1], expected_dtype_by_target[target])

        # Match [dq - torch.add(linear2_dq, tanh_dq) - q]
        if target == "linear2":
            linear2_node = node
        elif target == "tanh":
            tanh_node = node
        elif target == torch.add:
            linear2_dq, tanh_dq = node.args
            self.assertEqual(tanh_dq.args[0].args[0], tanh_node)
            self.assertEqual(linear2_dq.args[0].args[0], linear2_node)
def test_lowering_functional_conv_transpose_with_kwargs(self):
    """Check lowering of functional conv_transpose{1,2,3}d called with kwargs.

    For each dimension the model calls ``F.conv_transposeNd`` with all
    optional arguments passed by keyword. After ``convert_fx`` the graph is
    expected to contain exactly one ``torch.ops.quantized.conv_transposeNd``
    call, and the quantized model must run on the example input.
    """
    dim_to_op = {
        1: F.conv_transpose1d,
        2: F.conv_transpose2d,
        3: F.conv_transpose3d,
    }
    dim_to_qop = {
        1: torch.ops.quantized.conv_transpose1d,
        2: torch.ops.quantized.conv_transpose2d,
        3: torch.ops.quantized.conv_transpose3d,
    }

    class Mod(nn.Module):
        def __init__(self, in_channels, out_channels, kernel_size, dimension):
            super().__init__()
            self.dim = dimension
            self.op = dim_to_op[dimension]
            kernel_sizes = [kernel_size] * self.dim
            self.weight = nn.Parameter(torch.randn(in_channels, out_channels, *kernel_sizes))

        def forward(self, input):
            return self.op(
                input,
                self.weight,
                bias=None,
                stride=[1] * self.dim,
                padding=[0] * self.dim,
                output_padding=[0] * self.dim,
                dilation=[1] * self.dim,
                groups=1,
            )

    for dimension in [1, 2, 3]:
        model = Mod(3, 16, 3, dimension)
        model.eval()
        qconfig_mapping = get_default_qconfig_mapping()
        input_shape = (1, 3, *([8] * dimension))
        # prepare_fx expects a tuple of positional inputs, not a bare tensor
        example_inputs = (torch.randn(input_shape),)
        prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
        prepared_model(*example_inputs)  # calibrate
        quantized_model = convert_fx(prepared_model)

        # This should pass
        quantized_model(*example_inputs)

        # Ensure the quantized model has the expected op
        node_occurrence = {
            ns.call_function(dim_to_qop[dimension]): 1,
        }
        self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)
""" def __init__(self, input_dim, hidden_dim, output_dim): super().__init__() self.submodule_1 = SubModule(hidden_dim, input_dim) setattr(self, 'submodule|2', SubModule(hidden_dim, hidden_dim)) setattr(self, 'submodule/3', SubModule(hidden_dim, hidden_dim)) setattr(self, 'submodule:4', SubModule(hidden_dim, hidden_dim)) setattr(self, 'submodule: 5', SubModule(hidden_dim, hidden_dim)) self._w = nn.Parameter(torch.randn(output_dim, hidden_dim)) def forward(self, x): x1 = self.submodule_1(x) x2 = getattr(self, 'submodule|2')(x1) x3 = getattr(self, 'submodule/3')(x2) x4 = getattr(self, 'submodule:4')(x3) x5 = getattr(self, 'submodule: 5')(x4) x6 = F.linear(x5, self._w) return x6 input_dim = 10 hidden_dim = 20 output_dim = 5 model = MainModule(input_dim, hidden_dim, output_dim) model.eval() example_inputs = torch.randn(1, input_dim) _ = model(*example_inputs) qconfig_mapping = QConfigMapping().set_object_type(nn.functional.linear, float16_dynamic_qconfig) prepared_model = prepare_fx(model, qconfig_mapping, example_inputs) prepared_model(example_inputs) quantized_model = convert_fx(prepared_model, keep_original_weights=True) self.assertTrue(len(quantized_model.original_weights_lookup) == 6) self.assertTrue("submodule_1_packed_weight_0" in quantized_model.original_weights_lookup) torch.testing.assert_close( quantized_model.original_weights_lookup["submodule_1_packed_weight_0"][0], model.submodule_1.w ) torch.testing.assert_close( quantized_model.original_weights_lookup["submodule_1_packed_weight_0"][1], model.submodule_1.b ) self.assertTrue("submodule_2_packed_weight_0" in quantized_model.original_weights_lookup) torch.testing.assert_close( quantized_model.original_weights_lookup["submodule_2_packed_weight_0"][0], getattr(model, "submodule|2").w ) torch.testing.assert_close( quantized_model.original_weights_lookup["submodule_2_packed_weight_0"][1], getattr(model, "submodule|2").b ) self.assertTrue("submodule_3_packed_weight_0" in quantized_model.original_weights_lookup) 
def setUp(self):
    """Create the shared qconfig and quantization-pattern table for the tests."""
    super().setUp()

    # Symmetric per-tensor qint8 activations, per-channel weights
    activation_observer = torch.ao.quantization.observer.HistogramObserver.with_args(
        qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
    )
    self.custom_qconfig = torch.ao.quantization.QConfig(
        activation=activation_observer,
        weight=torch.ao.quantization.default_per_channel_weight_observer,
    )

    # Every pattern below maps to the same default handler
    default_handler_patterns = (
        torch.nn.ConvTranspose1d,
        torch.nn.ConvTranspose2d,
        torch.nn.ELU,
        torch.nn.LeakyReLU,
        torch.nn.Hardswish,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.SiLU,
        torch.nn.Mish,
        torch.nn.GELU,
        torch.nn.Softmax,
        torch.nn.functional.elu,
        torch.nn.functional.hardswish,
        torch.nn.functional.instance_norm,
        torch.nn.functional.layer_norm,
        torch.nn.functional.leaky_relu,
        torch.nn.functional.silu,
        torch.nn.functional.mish,
        torch.nn.functional.gelu,
        torch.nn.functional.softmax,
        torch.sum,
    )
    self.common_quant_patterns = dict.fromkeys(
        default_handler_patterns, DefaultNodeQuantizeHandler
    )
@skipIfNoFBGEMM
def test_functional_linear(self):
    """Check lowering of ``F.linear`` (optionally followed by relu) on fbgemm.

    Iterates over every quant type in ``self.all_quant_types`` and over
    bias / relu / functional-relu combinations, checking the observers
    inserted by prepare and the quantized ops produced by convert. For
    non-dynamic quant types it also compares the quantized output with the
    reference output and checks that the packed weight was folded into the
    state dict.
    """
    with override_quantized_engine('fbgemm'):
        class FuncLinear(torch.nn.Module):
            def __init__(self, use_bias, has_relu, f_relu):
                super().__init__()
                self.w = torch.randn(4, 30)
                self.b = torch.randn(4)
                self.use_bias = use_bias
                if has_relu:
                    if f_relu:
                        self.relu_or_id = F.relu
                    else:
                        self.relu_or_id = torch.nn.ReLU()
                else:
                    self.relu_or_id = torch.nn.Identity()

            def forward(self, x):
                if self.use_bias:
                    x = F.linear(x, self.w, self.b)
                else:
                    x = F.linear(x, self.w)
                x = self.relu_or_id(x)
                return x

        data = (torch.rand((1, 30), dtype=torch.float),)
        quant_type_to_qlinear_fun = {
            QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic),
            QuantType.STATIC: ns.call_function(torch.ops.quantized.linear),
            QuantType.QAT: ns.call_function(torch.ops.quantized.linear),
        }
        # fused linear + relu op expected for each quant type
        quant_type_to_qlinear_relu_fun = {
            QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_relu_dynamic),
            QuantType.STATIC: ns.call_function(torch.ops.quantized.linear_relu),
            QuantType.QAT: ns.call_function(torch.ops.quantized.linear_relu),
        }

        options = itertools.product(
            self.all_quant_types,
            (True, False),  # use_bias
            (True, False),  # has_relu
            (True, False),  # functional relu
        )
        for quant_type, use_bias, has_relu, f_relu in options:
            # when has_relu is False, we are using an nn.Identity and
            # we will insert observer/fake_quant for the output of nn.Identity since
            # it is a copy node, that's why we have extra observer/fake_quant
            # when has_relu is False
            quant_type_to_prepare_expected_node_occurrence = {
                QuantType.DYNAMIC: {
                    ns.call_module(torch.ao.quantization.PlaceholderObserver): 1,
                    ns.call_module(torch.ao.quantization.MinMaxObserver): 1,
                },
                # There should be 3 observers: after input, weight and activation.
                # one more observer for torch.nn.Identity when there is no relu
                QuantType.STATIC: {
                    ns.call_module(torch.ao.quantization.HistogramObserver): 2 if has_relu else 3,
                    ns.call_module(torch.ao.quantization.PerChannelMinMaxObserver): 1,
                },
                # There should be 3 observers: after input, weight and activation.
                QuantType.QAT: {
                    ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 3 if has_relu else 4,
                },
            }
            model = FuncLinear(use_bias, has_relu, f_relu)
            if has_relu:
                qlinear_fun = quant_type_to_qlinear_relu_fun[quant_type]
            else:
                qlinear_fun = quant_type_to_qlinear_fun[quant_type]

            # Static/QAT graphs have one quantize_per_tensor and one dequantize;
            # dynamic graphs are expected to contain neither.
            is_static = quant_type != QuantType.DYNAMIC
            convert_node_occurrence = {
                ns.call_function(torch.quantize_per_tensor): 1 if is_static else 0,
                qlinear_fun: 1,
                ns.call_method("dequantize"): 1 if is_static else 0,
            }
            prepare_expected_node_occurrence = \
                quant_type_to_prepare_expected_node_occurrence[quant_type]
            result_dict = self.checkGraphModeFxOp(
                model, data, quant_type, qlinear_fun,
                prepare_expected_node_occurrence=prepare_expected_node_occurrence,
                expected_node_occurrence=convert_node_occurrence)
            if is_static:
                self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"])
                # Ensure packed weights in lowered models are folded
                self.assertIn("_packed_weight_0", result_dict["quantized"].state_dict().keys())
@skipIfNoFBGEMM
def test_conv_module(self):
    """Check that Conv1d/2d/3d modules are lowered to their nnq counterparts."""
    float_conv_by_dim = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

    class ConvWrapper(torch.nn.Module):
        def __init__(self, dim):
            super().__init__()
            conv_cls = float_conv_by_dim[dim]
            self.conv = conv_cls(3, 3, 3).float()

        def forward(self, x):
            return self.conv(x)

    # dim -> quantized module node expected after convert
    expected_node_by_dim = {
        1: ns.call_module(nnq.Conv1d),
        2: ns.call_module(nnq.Conv2d),
        3: ns.call_module(nnq.Conv3d),
    }
    for dim, quant_type in itertools.product([1, 2, 3], self.static_quant_types):
        model = ConvWrapper(dim)
        inputs = self.img_data_dict[dim]
        self.checkGraphModeFxOp(model, inputs, quant_type, expected_node_by_dim[dim])
ns.call_function(torch.ops.quantized.conv3d_relu) }, QuantType.QAT: { 1: ns.call_function(torch.ops.quantized.conv1d_relu), 2: ns.call_function(torch.ops.quantized.conv2d_relu), 3: ns.call_function(torch.ops.quantized.conv3d_relu) }, } options = itertools.product( [1, 2, 3], # dims self.static_quant_types, (True, False), # use_bias (True, False), # has_relu (True, False), # functional relu ) for dim, quant_type, use_bias, has_relu, f_relu in options: # when has_relu is False, we are using an nn.Identity and # we will insert observer/fake_quant for the output of nn.Identity since # it is a copy node, that's why we have extra observer/fake_quant # when has_relu is False quant_type_to_prepare_expected_node_occurrence = { QuantType.DYNAMIC: {}, # There should be 3 observers: after input, weight and activation. QuantType.STATIC: { ns.call_module(torch.ao.quantization.HistogramObserver): 2 if has_relu else 3, ns.call_module(torch.ao.quantization.PerChannelMinMaxObserver): 1, }, # There should be 3 observers: after input, weight and activation. 
QuantType.QAT: { ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 3 if has_relu else 4, }, } data_dims = [2, 3] + [4] * dim data = (torch.randn(tuple(data_dims), dtype=torch.float),) model = FuncConv(dim, use_bias, has_relu, f_relu) if has_relu: qconv_fun = quant_type_to_qconv_relu_fun[quant_type][dim] else: qconv_fun = quant_type_to_qconv_fun[quant_type][dim] convert_node_occurrence = { ns.call_function(torch.quantize_per_tensor): 1, qconv_fun: 1, ns.call_method("dequantize"): 1 } prepare_expected_node_occurrence = \ quant_type_to_prepare_expected_node_occurrence[quant_type] result_dict = self.checkGraphModeFxOp( model, data, quant_type, qconv_fun, prepare_expected_node_occurrence=prepare_expected_node_occurrence, expected_node_occurrence=convert_node_occurrence) if quant_type != QuantType.DYNAMIC: self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"]) # Ensure packed weights in lowered models are folded self.assertIn("_packed_weight_0", result_dict["quantized"].state_dict().keys()) @skipIfNoFBGEMM def test_quantized_conv_relu(self): """tests for conv1d_relu/conv2d_relu/conv3d_relu""" conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d} class ConvNdRelu(torch.nn.Module): def __init__(self, dim, inplace): super().__init__() self.conv = conv_module[dim](3, 3, 3).float() self.relu = torch.nn.ReLU(inplace) def forward(self, x): return self.relu(self.conv(x)) class ConvNdFunctionalRelu(torch.nn.Module): def __init__(self, dim): super().__init__() self.conv = conv_module[dim](3, 3, 3).float() def forward(self, x): return F.relu(self.conv(x)) class ConvNdInplaceFunctionalRelu(torch.nn.Module): def __init__(self, dim): super().__init__() self.conv = conv_module[dim](3, 3, 3).float() def forward(self, x): return F.relu(self.conv(x), True) options = itertools.product([1, 2, 3], self.static_quant_types) quantized_nodes = { # dim 1: ns.call_module(nniq.ConvReLU1d), 2: 
def _test_binary_op_int8_impl(self, binary_op, ibinary_op, quantized_op):
    """Run the int8 static-quant checks for a binary op.

    Covers in-place / scalar / reference combinations for both the
    ``BinaryOp`` and ``BinaryOpNonQuantizedInput`` models. In reference mode
    the expected graph is a dequantize -> ``binary_op`` -> quantize node
    sequence; otherwise a single call to ``quantized_op`` is expected.
    """
    lhs = torch.randn(1, 1, 1, 1, dtype=torch.float)
    rhs = torch.randn(1, 1, 1, 1, dtype=torch.float)
    data = (lhs, rhs)
    # testing for default int8 static quant
    quant_type = QuantType.STATIC
    for is_inplace, is_scalar, is_reference in itertools.product([True, False], repeat=3):
        if is_reference:
            quantized_node = None
            node_list = [
                ns.call_method("dequantize"),
                ns.call_function(binary_op),
                ns.call_function(torch.quantize_per_tensor),
            ]
        else:
            quantized_node = ns.call_function(quantized_op)
            node_list = None

        # BinaryOpNonQuantizedInput tests that the binary op should be quantized
        # even when it is not fed with a quantized input
        for model_cls in (BinaryOp, BinaryOpNonQuantizedInput):
            model = model_cls(binary_op, ibinary_op, is_inplace, is_scalar)
            self.checkGraphModeFxOp(
                model,
                data,
                quant_type,
                quantized_node,
                expected_node_list=node_list,
                is_reference=is_reference,
            )
def _test_binary_op_relu_int8_impl(self, binary_op, ibinary_op, quantized_op):
    """Run the int8 static-quant checks for a binary op followed by relu.

    Every combination of in-place op, relu flavour (``nn.ReLU``, ``F.relu``,
    ``torch.relu``) and scalar operand is expected to lower to a single
    call of ``quantized_op``.
    """
    inputs = tuple(torch.rand((1, 1, 1, 1), dtype=torch.float) for _ in range(2))
    expected_node = ns.call_function(quantized_op)
    relu_flavours = [nn.ReLU, F.relu, torch.relu]
    for is_inplace_op in (True, False):
        for relu_callable in relu_flavours:
            for is_scalar in (True, False):
                model = BinaryOpRelu(
                    binary_op, ibinary_op, is_inplace_op, relu_callable, is_scalar)
                self.checkGraphModeFxOp(
                    model, inputs, QuantType.STATIC, expected_node)
backend_config=backend_config) @skipIfNoFBGEMM def test_add(self): self._test_binary_op_int8_impl( operator.add, operator.iadd, torch.ops.quantized.add) self._test_binary_op_float16_impl( operator.add, operator.iadd) @unittest.skip("This is no longer needed right now, can enable later with new api") def test_sub(self): self._test_binary_op_float16_impl(operator.sub, operator.isub) self._test_binary_op_float16_impl(torch.sub, None) @unittest.skip("This is no longer needed right now, can enable later with new api") def test_div(self): self._test_binary_op_float16_impl(operator.truediv, operator.itruediv) self._test_binary_op_float16_impl(torch.div, None) @skipIfNoFBGEMM def test_mul(self): self._test_binary_op_int8_impl( operator.mul, operator.imul, torch.ops.quantized.mul) self._test_binary_op_float16_impl(operator.mul, operator.imul) @unittest.skip("This is no longer needed right now, can enable later with new api") def test_sum(self): class Sum(torch.nn.Module): def forward(self, x): x = torch.sum(x, [1], keepdim=True) x = torch.sum(x, [1]) return x data = torch.randn(1, 2, 3, 4, dtype=torch.float) quant_type = QuantType.STATIC # testing for fp16 static quant # we are producing fp16 patterns custom_qconfig_dict = { "object_type": [(torch.sum, float16_static_qconfig)] } node_occurrence = { # input_sum1, output_sum1, output_sum2 ns.call_method("to"): 3 } self.checkGraphModeFxOp( Sum(), data, quant_type, expected_node_occurrence=node_occurrence, custom_qconfig_dict=custom_qconfig_dict) @unittest.skip("This is no longer needed right now, can enable later with new api") def test_bmm(self): class BMMMethod(torch.nn.Module): def forward(self, x, y): return x.bmm(y) data = (torch.randn(1, 1, 1, dtype=torch.float), torch.randn(1, 1, 1, dtype=torch.float)) quant_type = QuantType.STATIC # testing for fp16 static quant # we are producing fp16 patterns custom_qconfig_dict = { "object_type": [(torch.bmm, float16_static_qconfig), ("bmm", float16_static_qconfig)] } 
node_occurrence = { # input_bmm1, input_bmm2, output_bmm ns.call_method("to"): 3 } self.checkGraphModeFxOp( BinaryOpNonQuantizedInput(torch.bmm, None, False, False), data, quant_type, expected_node_occurrence=node_occurrence, custom_qconfig_dict=custom_qconfig_dict) # TODO: support call_method("bmm") # we can transform call_method("bmm") to call_function(torch.bmm) # self.checkGraphModeFxOp( # BMMMethod(), data, quant_type, # expected_node_occurrence=node_occurrence, # custom_qconfig_dict=custom_qconfig_dict, # print_debug_info=True) @skipIfNoFBGEMM def test_add_relu(self): self._test_binary_op_relu_int8_impl( operator.add, operator.iadd, torch.ops.quantized.add_relu) self._test_binary_op_relu_float16_impl( operator.add, operator.iadd) @skipIfNoFBGEMM def test_add_relu_multiple_uses_of_relu(self): class Sub(torch.nn.Module): def __init__(self) -> None: super().__init__() self.relu = torch.nn.ReLU(inplace=True) class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.sub = Sub() def forward(self, x, y): x = x + y x = self.sub.relu(x) x = x + y x = self.sub.relu(x) return x m = M().eval() example_inputs = (torch.randn(3), torch.randn(3)) m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) m = convert_fx(m) node_occurrence = { ns.call_function(torch.quantize_per_tensor): 2, ns.call_function(torch.ops.quantized.add_relu): 2, ns.call_method("dequantize"): 1, } self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) # check the model is scriptable m = torch.jit.script(m) # check the model is runnable m(*example_inputs) @skipIfNoFBGEMM def test_mul_relu(self): self._test_binary_op_relu_int8_impl( operator.mul, operator.imul, torch.ops.quantized.mul_relu) self._test_binary_op_relu_float16_impl( operator.mul, operator.imul) # TODO(future PR): make more generic def _test_quantized_add_mul_qat(self, model, example_inputs, expected_node_occurrence): qconfig_dict = {'': 
torch.ao.quantization.get_default_qat_qconfig('fbgemm')} mp = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) self.checkGraphModuleNodes( mp, expected_node_occurrence=expected_node_occurrence) @skipIfNoFBGEMM def test_quantized_add_qat(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = torch.nn.Conv2d(1, 1, 1) self.conv2 = torch.nn.Conv2d(1, 1, 1) def forward(self, x): x = torch.add(x, 1.0) x = self.conv1(x) x = torch.add(x, 1.0) x = torch.relu(x) x = self.conv2(x) return x m = M() example_inputs = (torch.randn(1, 1, 1, 1),) expected_node_occurrence = { ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 5, } self._test_quantized_add_mul_qat(m, example_inputs, expected_node_occurrence) @skipIfNoFBGEMM def test_quantized_mul_qat(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = torch.nn.Conv2d(1, 1, 1) self.conv2 = torch.nn.Conv2d(1, 1, 1) def forward(self, x): x = torch.mul(x, 1.0) x = self.conv1(x) x = torch.mul(x, 1.0) x = torch.relu(x) x = self.conv2(x) return x m = M() example_inputs = (torch.randn(1, 1, 1, 1),) expected_node_occurrence = { ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 5, } self._test_quantized_add_mul_qat(m, example_inputs, expected_node_occurrence) def test_int8_input_no_unnecessary_fq(self): """ If the inputs to the graph are quantized and the only node does not need an activation observer, verifies that the activation observer is not inserted. 
""" class M(nn.Module): def __init__(self, scalar): super().__init__() self.scalar = scalar self.add_func = torch.ao.nn.quantized.FloatFunctional() def forward(self, x): return self.add_func.add_scalar(x, self.scalar) m = M(0.5) mp = torch.ao.quantization.quantize_fx.prepare_qat_fx( m, {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}, example_inputs=(torch.randn(1),), prepare_custom_config={"input_quantized_idxs": [0]}) expected_node_occurrence = { ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 1, } self.checkGraphModuleNodes( mp, expected_node_occurrence=expected_node_occurrence) @skipIfNoFBGEMM def test_cat(self): """ quantization of the output of cat will depend on the input of cat. we only quantize the output of cat when its inputs are quantized. """ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = torch.nn.Conv2d(2, 2, 2).float() self.conv2 = torch.nn.Conv2d(2, 2, 2).float() def forward(self, x, y): x = self.conv1(x) y = self.conv2(y) return torch.cat([x, y], 1) example_inputs = (torch.randn(1, 2, 5, 5, dtype=torch.float), torch.randn(1, 2, 5, 5, dtype=torch.float)) quantized_node = ns.call_function(torch.cat) options = itertools.product(self.static_quant_types, [True, False]) for quant_type, is_reference in options: if is_reference: converted_node_list = [ ns.call_method("dequantize"), ns.call_function(torch.cat), ns.call_function(torch.quantize_per_tensor) ] converted_node_occurrence = { # inputs and outputs of the two conv, and output of cat ns.call_method("dequantize"): 5, ns.call_function(torch.cat): 1, # inputs and outputs of the two conv, and output of cat ns.call_function(torch.quantize_per_tensor): 5, } else: converted_node_list = None converted_node_occurrence = { # output of cat ns.call_method("dequantize"): 1, ns.call_function(torch.cat): 1, # for two inputs ns.call_function(torch.quantize_per_tensor): 2, } self.checkGraphModeFxOp( M(), example_inputs, quant_type, 
quantized_node, expected_node_list=converted_node_list, expected_node_occurrence=converted_node_occurrence, is_reference=is_reference) # check cat is using the same observer for input and output m = M().eval() m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) # two inputs and one output of torch.cat are using same observer, so we have # 2 observers that's replicated all_observers = len(dict(m.named_modules(remove_duplicate=False))) distinct_observers = len(dict(m.named_modules())) self.assertEqual(all_observers, distinct_observers + 2) # make sure the converted model runs m = convert_fx(m) m(*example_inputs) @skipIfNoFBGEMM def test_qbatch_norm(self): bn_module = { # TODO: quantized batchnorm 1d module is missing # 1 : torch.nn.BatchNorm1d, 2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d, } class M(torch.nn.Module): def __init__(self, dim): super().__init__() self.bn = bn_module[dim](3).to(torch.float) def forward(self, x): return self.bn(x) options = itertools.product(self.static_quant_types, [2, 3], [True, False]) quantized_nodes = { False: { # 1: ns.call_module(nnq.BatchNorm1d), 2: ns.call_module(nnq.BatchNorm2d), 3: ns.call_module(nnq.BatchNorm3d), }, True: { # 1: ns.call_module(nn.BatchNorm1d), 2: ns.call_module(nn.BatchNorm2d), 3: ns.call_module(nn.BatchNorm3d), } } for quant_type, dim, is_reference in options: self.checkGraphModeFxOp( M(dim), self.img_data_dict[dim], quant_type, quantized_nodes[is_reference][dim], is_reference=is_reference) @skipIfNoFBGEMM def test_qbatch_norm_relu(self): bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d} class BNRelu(torch.nn.Module): def __init__(self, dim, inplace): super().__init__() self.bn = bn_module[dim](3).to(torch.float) self.relu = torch.nn.ReLU(inplace=inplace) def forward(self, x): return self.relu(self.bn(x)) class BNFuncRelu(torch.nn.Module): def __init__(self, dim): super().__init__() self.bn = bn_module[dim](3).to(torch.float) def forward(self, x): return 
F.relu(self.bn(x), False) class BNFuncInplaceRelu(torch.nn.Module): def __init__(self, dim): super().__init__() self.bn = bn_module[dim](3).to(torch.float) def forward(self, x): return F.relu(self.bn(x), True) options = itertools.product(self.static_quant_types, [2, 3], [True, False]) quantized_nodes = { True: { 2: ns.call_module(nni.BNReLU2d), 3: ns.call_module(nni.BNReLU3d), }, False: { 2: ns.call_module(nniq.BNReLU2d), 3: ns.call_module(nniq.BNReLU3d), } } for quant_type, dim, is_reference in options: for instance in [BNRelu(dim, True), BNRelu(dim, False), BNFuncRelu(dim), BNFuncInplaceRelu(dim)]: self.checkGraphModeFxOp( instance, self.img_data_dict[dim], quant_type, quantized_nodes[is_reference][dim], is_reference=is_reference) def _test_activation_impl( self, float_module, float_op, quantized_module, quantized_op): ''' Test for activation op(with inplace options), float_op can be torch op or functional op ''' class M(torch.nn.Module): def __init__(self, is_module, inplace): super().__init__() self.is_module = is_module self.inplace = inplace if self.is_module: self.op = float_module(self.inplace) else: self.op = float_op def forward(self, input): if self.is_module: return self.op(input) else: return self.op(input, self.inplace) options = itertools.product([True, False], [True, False], self.static_quant_types, [True, False]) quantized_nodes = { # is_module True: { # is_reference True: ns.call_module(float_module), False: ns.call_module(quantized_module), }, False: { True: ns.call_function(float_op), False: ns.call_function(quantized_op), } } for is_module, is_inplace, quant_type, is_reference in options: self.checkGraphModeFxOp( M(is_module, is_inplace), self.img_data_2d, quant_type, quantized_nodes[is_module][is_reference], is_reference=is_reference) def test_hardswish(self): self._test_activation_impl(nn.Hardswish, F.hardswish, nnq.Hardswish, torch.ops.quantized.hardswish) def test_elu(self): self._test_activation_impl(nn.ELU, F.elu, nnq.ELU, 
torch.ops.quantized.elu) def test_leaky_relu(self): self._test_activation_impl(nn.LeakyReLU, F.leaky_relu, nnq.LeakyReLU, torch.ops.quantized.leaky_relu) def test_prelu(self): class M(torch.nn.Module): def __init__(self, num_param: int): super().__init__() self.op = torch.nn.PReLU(num_parameters=num_param) def forward(self, input): return self.op(input) X = [[torch.randn(4, 4, 4, 4, dtype=torch.float)]] options = itertools.product([1, 4], self.static_quant_types, [True, False]) quantized_nodes = { # is_reference True: ns.call_module(torch.nn.PReLU), False: ns.call_module(torch.ao.nn.quantized.PReLU), } for num_parameter, quant_type, is_reference in options: self.checkGraphModeFxOp( M(num_parameter), X, quant_type, quantized_nodes[is_reference], is_reference=is_reference) def _test_norm_impl( self, float_module, float_op, op_args, data, quantized_module, quantized_op, skip_op_arg_for_functional=False): ''' Test for normalization op, float_op can be torch op or functional op, op_args is a list of positional argument for the module/op ''' class M(torch.nn.Module): def __init__(self, is_module): super().__init__() self.is_module = is_module if self.is_module: self.op = float_module(*op_args) else: self.op = float_op def forward(self, input): if self.is_module: return self.op(input) else: args = [input] if not skip_op_arg_for_functional: args += op_args return self.op(*args) options = itertools.product([True, False], self.static_quant_types) quantized_nodes = { # is_module True: ns.call_module(quantized_module), False: ns.call_function(quantized_op), } for is_module, quant_type in options: self.checkGraphModeFxOp( M(is_module), data, quant_type, quantized_nodes[is_module]) def _test_norm_float16_impl( self, float_module, float_op, op_args, data, skip_op_arg_for_functional=False): ''' Test for normalization op, float_op can be torch op or functional op, op_args is a list of positional argument for the module/op ''' class M(torch.nn.Module): def __init__(self, is_module): 
super().__init__() self.is_module = is_module if self.is_module: self.op = float_module(*op_args) else: self.op = float_op def forward(self, input): if self.is_module: return self.op(input) else: args = [input] if not skip_op_arg_for_functional: args += op_args return self.op(*args) options = itertools.product([True, False], self.static_quant_types) qconfig_dict = { "object_type": [ (float_module, float16_static_qconfig), (float_op, float16_static_qconfig) ] } node_occurrence = { ns.call_method("to"): 2 } for is_module, quant_type in options: self.checkGraphModeFxOp( M(is_module), data, quant_type, custom_qconfig_dict=qconfig_dict, expected_node_occurrence=node_occurrence) def test_layer_norm(self): data = (torch.rand((1, 2, 5, 5), dtype=torch.float),) self._test_norm_impl( nn.LayerNorm, F.layer_norm, [[2, 5, 5]], data, nnq.LayerNorm, torch.ops.quantized.layer_norm) def test_instance_norm(self): data_1d = (torch.rand((1, 4, 5), dtype=torch.float),) data_2d = (torch.rand((1, 4, 5, 1), dtype=torch.float),) data_3d = (torch.rand((1, 4, 5, 1, 1), dtype=torch.float),) data_dict = {1 : data_1d, 2 : data_2d, 3 : data_3d} instance_norm_modules = {1 : nn.InstanceNorm1d, 2 : nn.InstanceNorm2d, 3 : nn.InstanceNorm3d} quantized_instance_norm_modules = { 1 : nnq.InstanceNorm1d, 2 : nnq.InstanceNorm2d, 3 : nnq.InstanceNorm3d } for dim in [1, 2, 3]: data = data_dict[dim] module = instance_norm_modules[dim] quantized_module = quantized_instance_norm_modules[dim] self._test_norm_impl( module, F.instance_norm, [4], data, quantized_module, torch.ops.quantized.instance_norm, skip_op_arg_for_functional=True) def test_norm_weight_bias(self): class Linear(torch.nn.Module): def __init__(self) -> None: super().__init__() self.w = torch.ones(5, 5) self.b = torch.zeros(5) def forward(self, x): return torch.nn.functional.linear(x, self.w, self.b) class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.mods1 = Linear() self.scale = torch.randn(5, 5) self.bias = 
torch.randn(5, 5) def forward(self, x): x1 = self.mods1(x) y = F.layer_norm(x1, [5, 5], weight=self.scale, bias=self.bias) return y model = M() expected_occurrence = { ns.call_function(torch.quantize_per_tensor): 1, ns.call_function(torch.ops.quantized.linear): 1, ns.call_function(torch.ops.quantized.layer_norm): 1, ns.call_method("dequantize"): 1, } self.checkGraphModeFxOp( model, (torch.rand(5, 5),), QuantType.STATIC, expected_node_occurrence=expected_occurrence, custom_qconfig_dict=get_default_qconfig_mapping().to_dict() ) def _test_default_node_quant_handler_ops( self, module, functional, qconfig, is_reference=True, node_list=None, additional_quant_pattern_dict=None ): class M(torch.nn.Module): def __init__(self, mod, func): super().__init__() self.module = mod() self.functional = func def forward(self, x): x = self.module(x) x = self.functional(x) return x if node_list is None: node_list = [] if additional_quant_pattern_dict is None: additional_quant_pattern_dict = {} data = torch.randn((2, 2, 2, 2)) quant_type = QuantType.STATIC prepare_custom_qconfig_dict = {"additional_quant_pattern": additional_quant_pattern_dict} qconfig_dict = {"": qconfig} m = M(module, functional).eval() m_prep = prepare_fx(m, qconfig_dict, prepare_custom_qconfig_dict) m_prep(data) convert_fn = convert_to_reference_fx if is_reference else convert_fx m_quant = convert_fn(m_prep, is_reference=is_reference) m_quant(data) self.checkGraphModuleNodes(m_quant, expected_node_list=node_list) @unittest.skip("TODO: reenable with backend_config api") def test_gelu_normal(self): module = torch.nn.GELU functional = torch.nn.functional.gelu qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") is_reference = False node_list = [ ns.call_module(module), ns.call_function(functional), ] self._test_default_node_quant_handler_ops( module, functional, qconfig, is_reference, node_list) @unittest.skip("TODO: reenable with backend_config api") def test_softmax_normal(self): module = torch.nn.Softmax 
functional = torch.nn.functional.softmax qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") is_reference = False node_list = [ ns.call_module(torch.ao.nn.quantized.Softmax), ns.call_function(functional), ] self._test_default_node_quant_handler_ops( module, functional, qconfig, is_reference, node_list) @unittest.skip("This is no longer needed right now, can enable later with new api") def test_gelu_reference(self): module = torch.nn.GELU functional = torch.nn.functional.gelu qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") is_reference = True node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_method("dequantize"), ns.call_module(module), ns.call_function(torch.quantize_per_tensor), ns.call_method('dequantize'), ns.call_function(functional), ns.call_function(torch.quantize_per_tensor), ns.call_method('dequantize') ] # TODO: change these to use backend_config additional_patterns = {torch.nn.GELU: DefaultNodeQuantizeHandler, torch.nn.functional.gelu: DefaultNodeQuantizeHandler} self._test_default_node_quant_handler_ops( module, functional, qconfig, is_reference, node_list, additional_patterns) self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list, additional_quant_pattern_dict=self.common_quant_patterns) @unittest.skip("This is no longer needed right now, can enable later with new api") def test_softmax_reference(self): module = torch.nn.Softmax functional = torch.nn.functional.softmax qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") is_reference = True node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_method("dequantize"), ns.call_module(module), ns.call_function(torch.quantize_per_tensor), ns.call_method('dequantize'), ns.call_function(functional), ns.call_function(torch.quantize_per_tensor), ns.call_method('dequantize') ] additional_patterns = {torch.nn.Softmax: DefaultNodeQuantizeHandler, torch.nn.functional.softmax: DefaultNodeQuantizeHandler} 
self._test_default_node_quant_handler_ops( module, functional, qconfig, is_reference, node_list, additional_patterns) self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list, additional_quant_pattern_dict=self.common_quant_patterns) @unittest.skip("This is no longer needed right now, can enable later with new api") def test_silu_reference(self): module = torch.nn.SiLU functional = torch.nn.functional.silu qconfig = float16_static_qconfig is_reference = True node_list = [ ns.call_method("to"), ns.call_method("dequantize"), ns.call_module(module), ns.call_method("to"), ns.call_method('dequantize'), ns.call_function(functional), ns.call_method("to"), ns.call_method('dequantize') ] self._test_default_node_quant_handler_ops( module, functional, qconfig, is_reference, node_list) node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_method("dequantize"), ns.call_module(module), ns.call_function(torch.quantize_per_tensor), ns.call_method("dequantize"), ns.call_function(functional), ns.call_function(torch.quantize_per_tensor), ns.call_method("dequantize") ] self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list, additional_quant_pattern_dict=self.common_quant_patterns) @unittest.skip("This is no longer needed right now, can enable later with new api") def test_mish_reference(self): module = torch.nn.Mish functional = torch.nn.functional.mish qconfig = float16_static_qconfig is_reference = True node_list = [ ns.call_method("to"), ns.call_method("dequantize"), ns.call_module(module), ns.call_method("to"), ns.call_method('dequantize'), ns.call_function(functional), ns.call_method("to"), ns.call_method('dequantize') ] self._test_default_node_quant_handler_ops( module, functional, qconfig, is_reference, node_list) node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_method("dequantize"), ns.call_module(module), 
ns.call_function(torch.quantize_per_tensor), ns.call_method("dequantize"), ns.call_function(functional), ns.call_function(torch.quantize_per_tensor), ns.call_method("dequantize") ] self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list, additional_quant_pattern_dict=self.common_quant_patterns) def test_bmm_int_reference(self): """ int8 is not supported for bmm so we won't produce reference pattern for it """ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.bmm = torch.bmm def forward(self, x, y): out = self.bmm(x, y) return out data_x = torch.randn((2, 2, 2,)) data_y = torch.randn((2, 2, 2,)) example_inputs = (data_x, data_y) qconfig_dict = {"": torch.ao.quantization.get_default_qconfig("fbgemm")} is_reference = True node_list = [ ns.call_function(torch.bmm), ] m = M().eval() m_prep = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) m_prep(*example_inputs) convert_fn = convert_to_reference_fx if is_reference else convert_fx m_quant = convert_fn(m_prep) m_quant(*example_inputs) self.checkGraphModuleNodes(m_quant, expected_node_list=node_list) @skipIfNoFBGEMM def test_clamp(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(2, 2, 2).float() self.relu6 = torch.nn.ReLU6() self.relu6_ = torch.nn.ReLU6(True) self.hardtanh = torch.nn.Hardtanh() self.hardtanh_ = torch.nn.Hardtanh(inplace=True) def forward(self, x): x = self.conv(x) x = self.relu6(x) self.relu6_(x) x = F.relu6(x) x = torch.clamp(x, -3, 3) x = x.clamp(-2.5, 2.5) # x = x.clamp_(-2, 2) # Enable when quantized `clamp_` is ready x = self.hardtanh(x) self.hardtanh_(x) x = F.hardtanh(x) return x data = (torch.rand((1, 2, 5, 5), dtype=torch.float),) # list of node that should occur in order node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_module(nnq.Conv2d), ns.call_method('dequantize') ] for quant_type in self.static_quant_types: 
self.checkGraphModeFxOp( M(), data, quant_type, expected_node_list=node_list) def test_fixed_qparams_ops_fp16(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.sigmoid = torch.nn.Sigmoid() self.tanh = torch.nn.Tanh() def forward(self, x): x = self.sigmoid(x) x = torch.sigmoid(x) x = x.sigmoid() x = self.tanh(x) x = torch.tanh(x) x = x.tanh() return x data = (torch.randn((2, 2, 2, 2), dtype=torch.float),) quant_type = QuantType.STATIC # TODO: use get_default_qconfig_mapping once it handles fp16 qconfig_mapping = QConfigMapping().set_global(float16_static_qconfig) backend_config = get_test_only_legacy_native_backend_config() node_occurrence = { ns.call_method("to"): 7 } self.checkGraphModeFxOp( M(), data, quant_type, custom_qconfig_dict=qconfig_mapping, expected_node_occurrence=node_occurrence, backend_config=backend_config) def test_fixed_qparams_ops_qint8(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.sigmoid = torch.nn.Sigmoid() self.tanh = torch.nn.Tanh() def forward(self, x): x = self.sigmoid(x) x = torch.sigmoid(x) x = x.sigmoid() x = self.tanh(x) x = torch.tanh(x) x = x.tanh() return x data = (torch.randn((2, 2, 2, 2), dtype=torch.float),) quant_type = QuantType.STATIC qconfig = torch.ao.quantization.QConfig( activation=HistogramObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.quint8), weight=default_weight_observer) qconfig_mapping = get_default_qconfig_mapping().set_global(qconfig) node_occurrence = { ns.call_function(torch.quantize_per_tensor): 7, ns.call_method("dequantize"): 7 } self.checkGraphModeFxOp( M(), data, quant_type, custom_qconfig_dict=qconfig_mapping, expected_node_occurrence=node_occurrence, is_reference=True) def test_fixed_qparams_ops_wrong_qconfig(self): """ Test that wrong qconfigs for fixed qparams ops results in the ops not being quantized. 
""" class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.sigmoid = torch.nn.Sigmoid() self.tanh = torch.nn.Tanh() def forward(self, x): x = self.sigmoid(x) x = torch.sigmoid(x) x = x.sigmoid() x = self.tanh(x) x = torch.tanh(x) x = x.tanh() return x data = (torch.randn((2, 2, 2, 2), dtype=torch.float),) qconfig_mapping = QConfigMapping().set_global(default_qconfig) m = M().eval() node_occurrence = { ns.call_function(torch.quantize_per_tensor): 0, ns.call_method("dequantize"): 0, } self.checkGraphModeFxOp( m, data, QuantType.STATIC, custom_qconfig_dict=qconfig_mapping, expected_node_occurrence=node_occurrence, is_reference=True) self.assertTrue(isinstance(m.sigmoid, torch.nn.Sigmoid)) self.assertTrue(isinstance(m.tanh, torch.nn.Tanh)) @skipIfNoFBGEMM def test_general_shape_ops(self): """ A test that checks dequantize will be swapped for all supported general shape ops like aten::flatten without actually checking for execution of these ops """ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3) self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3) self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3) self.dropout = torch.nn.Dropout() self.conv1 = torch.nn.Conv2d(3, 3, 3) self.conv2 = torch.nn.Conv2d(3, 3, 3) self.relu = torch.nn.ReLU() def forward(self, x): x = self.conv1(x) # add_scalar x = x + 3 # mul_scalar x = x * 3 # add_scalar_out x += 3 # mul_scalar_out x *= 3 # add_scalar_relu x = x + 3 x = F.relu(x) # add_scalar_relu_out x += 3 x = F.relu(x) # mul_scalar_relu x = x * 3 x = F.relu(x) # mul_scalar_relu_out x *= 3 x = F.relu(x) x = self.maxpool1d(x) x = self.maxpool2d(x) x = self.maxpool3d(x) x = torch.flatten(x) x = x.reshape([-1]) x = x.resize_(1, 1, x) x = x.view(-1) # prim::ListConstruct xs = [x, x] # prim::ListUnpack x, y = xs # prim::TupleConstruct xs = (x, x) # prim::TupleUnpack x, y = xs x = x.transpose(1, 2) x = x.contiguous() # chunk is not supported since 
observer only supports # observing single Tensor currently x, y = torch.chunk(x, 2) x = F.dropout(x) x = self.dropout(x) x = x.permute(0, 2, 3, 1) x = x.repeat_interleave(3, 1) x = torch.repeat_interleave(x, 3, 1) x = self.relu(x) x = F.relu(x) x = F.relu(x, inplace=True) x = x.relu() x.relu_() x = x.squeeze(0) x.squeeze_(0) x = torch.squeeze(x, 0) x = x.unsqueeze(0) x.unsqueeze_(0) x = torch.unsqueeze(x, 0) x = x.detach() x.detach_() x = x.repeat(4, 2) y = [] y.append(x) z = torch.stack(y, 0) z = [z, z] x, _ = z x = self.conv2(x) return x example_inputs = (torch.rand(1, 3, 10, 10),) # This model is not executable since we just put all ops # in the same forward m = M().eval() qconfig_dict = {'': default_qconfig} prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) # not runnable quantized = convert_fx(prepared) # This checks that the dequantize from the output of first conv # is being propagated to the end, so that we don't insert extra # observers and also successfully fused two quantized::conv2d # patterns # one quantize_per_tensor for input # check exact counts of quantize and dequantize count_check = { # input of conv and two outputs of getitem ns.call_function(torch.quantize_per_tensor) : 2, # output of the model and two outputs of getitem ns.call_method('dequantize') : 2 } order_check = [ ns.call_function(torch.quantize_per_tensor), ns.call_module(nnq.Conv2d), ns.call_module(nnq.Conv2d), ns.call_method('dequantize'), ] self.checkGraphModuleNodes( quantized, expected_node_occurrence=count_check, expected_node_list=order_check) # Checking the is_reference output m = M().eval() qconfig_dict = {'': default_qconfig} prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) # not runnable quantized = convert_to_reference_fx(prepared) @skipIfNoFBGEMM def test_ave_pool_with_custom_cfg(self): """ A test that checks correct patterns are produced for avg_pool2d with customized config """ class M(torch.nn.Module): def __init__(self) -> None: 
super().__init__() self.avg_pool2d = torch.nn.AvgPool2d(3) def forward(self, x): x = self.avg_pool2d(x) return x # This model is not executable since we just put all ops # in the same forward m = M().eval() # nothing to fuse so skipping the fuse step qconfig_dict = {'': default_qconfig} example_inputs = (torch.randn(1, 3, 3, 3),) prepared = prepare_fx( m, qconfig_dict, example_inputs=example_inputs, prepare_custom_config={"input_quantized_idxs": [0]}) # not runnable quantized = convert_fx(prepared) # This checks that the dequantize from the output of first conv # is being propagated to the end, so that we don't insert extra # observers # check exact counts of quantize and dequantize count_check = { ns.call_method('dequantize') : 1 } order_check = [ ns.call_module(nn.AvgPool2d), ns.call_method('dequantize'), ] self.checkGraphModuleNodes( quantized, expected_node_occurrence=count_check, expected_node_list=order_check) @skipIfNoFBGEMM def test_general_value_ops(self): """ A test that checks correct patterns are produced for all supported general value ops like aten::avg_pool2d \ without actually checking for execution of these ops """ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(3, 3, 3) self.avg_pool1d = torch.nn.AvgPool1d(3) self.avg_pool2d = torch.nn.AvgPool2d(3) self.avg_pool3d = torch.nn.AvgPool3d(3) self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d(1) self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1)) def forward(self, x): x = self.conv(x) x = self.avg_pool1d(x) x = self.avg_pool2d(x) x = self.avg_pool3d(x) x = self.adaptive_avg_pool1d(x) x = self.adaptive_avg_pool2d(x) x = self.adaptive_avg_pool3d(x) x = F.avg_pool1d(x, 3) x = F.avg_pool2d(x, 3) x = F.avg_pool3d(x, 3) x = F.adaptive_avg_pool1d(x, (1)) x = F.adaptive_avg_pool2d(x, (1, 1)) x = F.adaptive_avg_pool3d(x, (1, 1, 1)) x = torch.mean(x) x = torch.mean(x, [2, 3], False) 
x = x.mean() x = x.mean([2, 3], True) x = F.interpolate(x, 4, mode='nearest') x = F.interpolate(x, 4, mode='linear') x = self.conv(x) return x # This model is not executable since we just put all ops # in the same forward m = M().eval() # nothing to fuse so skipping the fuse step qconfig_dict = {'': default_qconfig} example_inputs = (torch.randn(1, 3, 3, 3),) prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) # not runnable quantized = convert_fx(prepared) # This checks that the dequantize from the output of first conv # is being propagated to the end, so that we don't insert extra # observers # check exact counts of quantize and dequantize count_check = { ns.call_function(torch.quantize_per_tensor) : 1, ns.call_method('dequantize') : 1 } order_check = [ ns.call_function(torch.quantize_per_tensor), ns.call_module(nnq.Conv2d), ns.call_module(nnq.Conv2d), ns.call_method('dequantize'), ] self.checkGraphModuleNodes( quantized, expected_node_occurrence=count_check, expected_node_list=order_check) def test_copy_node_fp32_input(self): """ CopyNode works for both fp32 and int8 inputs, this is a test to make sure that a CopyNode can be successfully quantized in both cases """ class M(torch.nn.Module): def forward(self, x): x = x.relu() return x m = M().eval() m = prepare_fx(m, {"": default_reuse_input_qconfig}, example_inputs=(torch.randn(1),)) m = convert_fx(m) # make sure it runs m(torch.rand(1)) def test_getitem(self): """ Make sure we only insert observer for getitem if the following node is matched or needs to be quantized """ class M(torch.nn.Module): def forward(self, xs): x = xs[0] return x m = M().eval() example_inputs = (torch.rand(1, 2),) qconfig_mapping = get_default_qconfig_mapping() m = prepare_fx(m, qconfig_mapping, example_inputs=example_inputs) self.checkGraphModuleNodes(m, expected_node_occurrence={ ns.call_module(torch.ao.quantization.MinMaxObserver): 0 }) m = convert_fx(m) m(*example_inputs) class M2(torch.nn.Module): def forward(self, 
@skipIfNoFBGEMM
def test_fixed_qparams_ops(self):
    """Check observer insertion and conversion for fixed-qparams ops.

    A model chaining conv, several sigmoid/hardsigmoid/tanh variants and a
    softmax module is prepared in both PTQ (eval) and QAT (train) mode.
    The test then verifies:

    * the number of fixed-qparams observers / fake-quantizes after prepare,
    * the quantize/dequantize counts and node order after ``convert_fx``,
    * the same for ``convert_to_reference_fx``,
    * the scale and zero_point stored on the lowered softmax module.
    """
    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.sigmoid = torch.nn.Sigmoid()
            self.hardsigmoid = torch.nn.Hardsigmoid()
            self.tanh = torch.nn.Tanh()
            self.softmax = torch.nn.Softmax(dim=0)

        def forward(self, x):
            x = self.conv(x)
            # F.sigmoid is deprecated
            x = self.sigmoid(x)
            x = torch.sigmoid(x)
            x = x.sigmoid()
            x = self.hardsigmoid(x)
            x = F.hardsigmoid(x)
            x = F.hardsigmoid(x, inplace=True)
            x = self.tanh(x)
            # F.tanh is deprecated
            x = torch.tanh(x)
            x = x.tanh()
            # TODO(future PR): handle F.softmax
            x = self.softmax(x)
            return x

    # The expected number of fixed-qparams post-processes is the same in
    # PTQ and QAT mode.
    fq_count = 10

    for eval_mode in [True, False]:
        # All ops are chained in a single forward purely to exercise
        # pattern matching.
        m = M()
        if eval_mode:
            m.eval()
            qconfig_mapping = get_default_qconfig_mapping()
            # Named prepare_fn so the eager-mode ``prepare`` imported at
            # module level is not shadowed.
            prepare_fn = prepare_fx
        else:
            m.train()
            qconfig_mapping = get_default_qat_qconfig_mapping()
            prepare_fn = prepare_qat_fx
        # nothing to fuse so skipping the fuse step
        m_copy = copy.deepcopy(m)
        example_inputs = (torch.rand(3, 3, 3, 3),)
        prepared = prepare_fn(m, qconfig_mapping, example_inputs=example_inputs)
        prepared_copy = copy.deepcopy(prepared)
        # check that prepare does not change model result
        if eval_mode:
            self.assertEqual(m_copy(*example_inputs), prepared_copy(*example_inputs))
        # check the correct number of activation_post_process is inserted
        expected_activation_post_process = (
            FixedQParamsObserver if eval_mode else FixedQParamsFakeQuantize
        )
        count_check = {
            ns.call_module(expected_activation_post_process): fq_count,
        }
        self.checkGraphModuleNodes(
            prepared,
            expected_node_occurrence=count_check)

        quantized = convert_fx(prepared)
        quantized_reference = convert_to_reference_fx(prepared_copy)

        # This checks that the dequantize from the output of first conv
        # is being propagated to the end, so that we don't insert extra
        # observers
        # check exact counts of quantize and dequantize
        count_check = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method('dequantize'): 1
        }
        order_check = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nn.Sigmoid),
            ns.call_module(nnq.Softmax),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(
            quantized,
            expected_node_occurrence=count_check,
            expected_node_list=order_check)

        reference_count_check = {
            ns.call_function(torch.quantize_per_tensor): 12,
            ns.call_method('dequantize'): 12
        }
        reference_order_check = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_method('dequantize'),
            ns.call_module(nnqr.Conv2d),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_method('dequantize'),
            ns.call_module(nn.Sigmoid),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_method('dequantize'),
            ns.call_module(nn.Softmax),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(
            quantized_reference,
            expected_node_occurrence=reference_count_check,
            expected_node_list=reference_order_check)

        # Verify that softmax scale and zero_point are correct.  The
        # difference is compared in absolute value: a one-sided check
        # would accept any scale smaller than 1/256.
        self.assertLessEqual(abs(quantized.softmax.scale - (1.0 / 256)), 1e-8)
        self.assertEqual(quantized.softmax.zero_point, 0)
def test_embedding_bag(self):
    """Check FX quantization of ``nn.EmbeddingBag``.

    Part 1: with a float-qparams weight-only qconfig (quint8 and quint4x2)
    the module must be lowered to ``nnq.EmbeddingBag`` in dynamic mode.

    Part 2: with no qconfig (``None``) and with the static
    ``default_qconfig`` the module must stay a float ``nn.EmbeddingBag``,
    no ``MinMaxObserver`` may be inserted, and the result must run.
    """
    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True)

        def forward(self, indices, offsets):
            return self.emb(indices, offsets)

    indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
    offsets = torch.tensor([0, 19, 20, 28, 28, 32])
    quantized_node = ns.call_module(nnq.EmbeddingBag)
    example_inputs = (indices, offsets)

    for dtype in [torch.quint8, torch.quint4x2]:
        model = M().eval()
        float_qparams_observer = PerChannelMinMaxObserver.with_args(
            dtype=dtype,
            qscheme=torch.per_channel_affine_float_qparams,
            ch_axis=0)
        float_qparams_qconfig = QConfig(
            activation=default_placeholder_observer,
            weight=float_qparams_observer)
        self.checkGraphModeFxOp(
            model,
            example_inputs,
            QuantType.DYNAMIC,
            quantized_node,
            custom_qconfig_dict={"": float_qparams_qconfig}
        )

    # check it works in None and static qconfig
    for qconfig in [None, default_qconfig]:
        # Use the loop's qconfig (previously default_qconfig was hard-coded,
        # so the None case was never exercised) and prepare a fresh float
        # model rather than the one left over from the loop above.
        qconfig_dict = {"": qconfig}
        m = prepare_fx(M().eval(), qconfig_dict, example_inputs=example_inputs)
        self.checkGraphModuleNodes(m, expected_node_occurrence={
            ns.call_module(torch.ao.quantization.MinMaxObserver): 0
        })
        m = convert_fx(m)
        self.checkGraphModuleNodes(m, expected_node=ns.call_module(nn.EmbeddingBag))
        # make sure it runs
        m(*example_inputs)
@override_qengines
def test_rnn_cell(self):
    """Run the shared dynamic-quantization RNN check on the RNN cell modules.

    Nothing is executed unless the active quantized engine is fbgemm or
    qnnpack.
    """
    supported_engines = ('fbgemm', 'qnnpack')
    if torch.backends.quantized.engine in supported_engines:
        sample_input = torch.tensor(
            [[100, -155],
             [-155, 100],
             [100, -155]],
            dtype=torch.float,
        )
        self._test_rnn_impl(
            qconfigs=[
                per_channel_dynamic_qconfig,
                default_dynamic_qconfig,
                float16_dynamic_qconfig,
            ],
            M=RNNCellDynamicModel,
            module_type_strs=['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU'],
            module_types=[torch.nn.LSTMCell, torch.nn.GRUCell, torch.nn.RNNCell],
            sample_input=sample_input,
        )
@unittest.skipUnless('qnnpack' in supported_qengines,
                     "This Pytorch Build has not been built with or does not support QNNPACK")
def test_conv_transpose_1d(self):
    """Run the shared ConvTranspose check for the 1d float/quantized module pair."""
    sample = torch.randn(4, 1, 4)
    self._test_conv_transpose_impl(
        float_cls=torch.nn.ConvTranspose1d,
        q_cls=nnq.ConvTranspose1d,
        data=sample,
    )
def test_multiple_qconfigs_for_single_value(self):
    """Test multiple qconfigs for a single value.

    The linear op gets ``default_qconfig`` while everything else (here the
    sigmoid that consumes the linear output) falls under the global
    ``float16_static_qconfig``.
    """
    class M(torch.nn.Module):
        def __init__(self, w, b):
            super().__init__()
            self.w = w
            self.b = b

        def forward(self, x):
            return torch.sigmoid(torch.nn.functional.linear(x, self.w))

    weight = torch.randn(4, 4)
    bias = torch.randn(4)
    float_model = M(weight, bias).eval()

    # TODO: use get_default_qconfig_mapping once it handles fp16
    qconfig_mapping = (
        QConfigMapping()
        .set_global(float16_static_qconfig)
        .set_object_type(torch.nn.functional.linear, default_qconfig)
    )
    example_inputs = (torch.randn(1, 4),)
    backend_config = get_test_only_legacy_native_backend_config()

    prepared = prepare_fx(
        float_model,
        qconfig_mapping,
        example_inputs=example_inputs,
        backend_config=backend_config,
    )
    self.checkGraphModuleNodes(
        prepared,
        expected_node_occurrence={
            # input and weight of linear, output of linear
            ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
            # input and output of sigmoid
            ns.call_module(torch.ao.quantization.PlaceholderObserver): 2,
        },
    )

    # make sure conversion succeeds
    converted = convert_fx(prepared)
    self.checkGraphModuleNodes(
        converted,
        expected_node_occurrence={
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 3,
            ns.call_method("to"): 2,
        },
    )
def test_qmatmul(self):
    """Check that ``torch.matmul`` is replaced by the quantized matmul op.

    After prepare, one calibration run and convert, the graph must contain
    no ``torch.matmul`` call and exactly one ``torch.ops.quantized.matmul``
    call, and the converted module must run.
    """
    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.matmul(x, y)

    float_model = M().eval()
    example_inputs = (torch.randn(2, 2), torch.randn(2, 2))
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")

    prepared = prepare_fx(float_model, qconfig_mapping, example_inputs=example_inputs)
    # calibrate
    prepared(*example_inputs)
    quantized = convert_fx(prepared)

    self.checkGraphModuleNodes(
        quantized,
        expected_node_occurrence={
            ns.call_function(torch.matmul): 0,
            ns.call_function(torch.ops.quantized.matmul): 1,
        },
    )
    # verify no crash
    quantized(*example_inputs)
def test_pixel_shuffle_module(self) -> None:
    """Check node counts when ``nn.PixelShuffle`` sits between a conv and an add.

    With the qnnpack backend config, the converted graph is expected to
    contain two ``quantize_per_tensor`` calls, one ``dequantize`` and the
    original ``nn.PixelShuffle`` module.
    """
    class MyBias(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.bias = nn.Parameter(torch.randn(8))

    class MyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = nn.Conv2d(8, 8, 1, bias=False)
            self.ps = nn.PixelShuffle(upscale_factor=2)
            self.bias = MyBias()

        def forward(self, x):
            shuffled = self.ps(self.conv(x))
            reshaped = shuffled.view(-1, 8, 2, 2)
            return reshaped + self.bias.bias

    float_model = MyModel()
    prepared = prepare_fx(
        float_model,
        qconfig_mapping=get_default_qconfig_mapping("qnnpack"),
        example_inputs=(torch.randn(1, 8, 3, 3),),
        backend_config=get_qnnpack_backend_config(),
    )
    converted = convert_fx(prepared)
    self.checkGraphModuleNodes(
        converted,
        expected_node_occurrence={
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_method("dequantize"): 1,
            ns.call_module(nn.PixelShuffle): 1,
        },
    )
def test_pixel_unshuffle_module(self) -> None:
    """Check node counts when ``nn.PixelUnshuffle`` sits between a conv and an add.

    Runs once per backend (fbgemm, then qnnpack).  The converted graph is
    expected to contain two ``quantize_per_tensor`` calls, one
    ``dequantize`` and the original ``nn.PixelUnshuffle`` module.
    """
    class MyBias(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.bias = nn.Parameter(torch.randn(64))

    class MyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = nn.Conv2d(8, 8, 1, bias=False)
            self.unshuffle = nn.PixelUnshuffle(downscale_factor=2)
            self.bias = MyBias()

        def forward(self, x):
            unshuffled = self.unshuffle(self.conv(x))
            return unshuffled + self.bias.bias

    # Backend name -> factory for the matching backend config.
    backend_config_factories = {
        "fbgemm": get_fbgemm_backend_config,
        "qnnpack": get_qnnpack_backend_config,
    }
    for backend, make_backend_config in backend_config_factories.items():
        backend_config = make_backend_config()
        qconfig_mapping = get_default_qconfig_mapping(backend)
        float_model = MyModel()
        prepared = prepare_fx(
            float_model,
            qconfig_mapping=qconfig_mapping,
            example_inputs=(torch.randn(1, 8, 6, 6),),
            backend_config=backend_config,
        )
        converted = convert_fx(prepared)
        self.checkGraphModuleNodes(
            converted,
            expected_node_occurrence={
                ns.call_function(torch.quantize_per_tensor): 2,
                ns.call_method("dequantize"): 1,
                ns.call_module(nn.PixelUnshuffle): 1,
            },
        )
@skipIfNoFBGEMM
@unittest.skipIf(not TEST_CUDA, "gpu is not available.")
def test_static_gpu_convert_basic(self):
    """Prepare, calibrate and reference-convert a small conv/linear net on CUDA.

    The output of the converted model must still live on a CUDA device.
    """
    class Net(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.relu1 = nn.ReLU()
            self.conv1 = nn.Conv2d(1, 6, 5)
            self.linear1 = nn.Linear(120, 1)

        def forward(self, x):
            activated = self.relu1(self.conv1(x))
            return self.linear1(activated.view(-1))

    cuda_input = torch.randn((5, 1, 6, 6)).to('cuda')
    example_inputs = (cuda_input,)
    float_model = Net().to('cuda').eval()
    qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')}

    prepared = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs)
    # calibrate
    prepared(*example_inputs)
    reference_model = convert_to_reference_fx(prepared)

    result = reference_model(*example_inputs)
    self.assertEqual(result.device.type, 'cuda')
def _test_model_impl(
        self, mode, name, model, eager_quantizable_model,
        check_with_eager=True,
        diff_of_quant=None,
        diff_from_eager=None):
    """Quantize ``model`` with FX graph mode and compare against eager mode.

    Args:
        mode: ``'static'``, ``'qat'`` or ``'ddp'``.  ``'static'`` uses
            ``default_qconfig`` and ``prepare_fx``; any other value uses
            ``default_qat_qconfig`` and ``prepare_qat_fx``.
        name: model name; ``'inception_v3'`` selects a 299x299 input,
            anything else a 224x224 input.  Also used as the key in the
            diff dictionaries and in the failure message.
        model: float model to quantize with FX graph mode.
        eager_quantizable_model: optional eager-mode quantizable model
            (must provide ``fuse_model()``); when given, it is quantized
            with the eager API and compared with the FX result.
        check_with_eager: when true, assert that the FX and eager outputs
            are identical (max abs difference of 0).
        diff_of_quant: optional ``{mode: {name: diff}}`` dict that receives
            the max abs difference between float and FX-quantized outputs.
        diff_from_eager: optional ``{mode: {name: diff}}`` dict that
            receives the max abs difference between eager and FX outputs.

    NOTE(review): the ``'ddp'`` branches reference ``run_ddp`` and
    ``world_size``, which are not defined in the visible part of this
    file (hence the ``noqa: F821`` markers) — confirm before enabling.
    """
    # Initialise each accumulator independently so a dict supplied by the
    # caller is never silently replaced because the other one is missing.
    if diff_of_quant is None:
        diff_of_quant = {}
    if diff_from_eager is None:
        diff_from_eager = {}
    quant_diffs = diff_of_quant.setdefault(mode, {})
    eager_diffs = diff_from_eager.setdefault(mode, {})

    input_tensor = torch.rand(1, 3, 224, 224)
    input_tensor_inception = torch.rand(1, 3, 299, 299)
    output_value = torch.randint(0, 1, (1,))

    if name == 'inception_v3':
        input_value = input_tensor_inception
    else:
        input_value = input_tensor

    qconfig = default_qconfig if mode == 'static' else default_qat_qconfig
    qconfig_dict = {'': qconfig}
    script = torch.jit.script(model)

    # make sure graph module and script module are both runnable
    original_out = model(input_value)
    is_not_tuple_out = not isinstance(original_out, tuple)
    script(input_value)

    # set to train just before quantization
    prepare_fx_fn = prepare_fx
    if mode != 'static':
        model.train()
        prepare_fx_fn = prepare_qat_fx

    # example_inputs is a required argument of prepare_fx / prepare_qat_fx.
    prepared = prepare_fx_fn(model, qconfig_dict, example_inputs=(input_value,))

    if mode == 'ddp':
        mp.spawn(run_ddp,  # noqa: F821
                 args=(world_size, prepared),  # noqa: F821
                 nprocs=world_size,  # noqa: F821
                 join=True)
    elif mode == 'qat':
        self.assertTrue(prepared.training, 'prepared must be in training mode for qat')
        optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001)
        criterion = nn.CrossEntropyLoss()
        train_one_epoch(prepared, criterion, optimizer,
                        [(input_value, output_value)], torch.device('cpu'), 1)
    else:
        # calibration
        for _ in range(10):
            prepared(input_value)

    qgraph = convert_fx(prepared)
    qgraph.eval()
    qgraph_script = torch.jit.script(qgraph)

    qgraph_out = qgraph(input_value)
    # Keep the scripted module and its output under separate names.
    qgraph_script_out = qgraph_script(input_value)

    if is_not_tuple_out:
        quant_diffs[name] = (original_out - qgraph_out).abs().max()
        self.assertTrue(torch.allclose(qgraph_out, qgraph_script_out), 'graph, scripted graph')
    else:
        print('tuple output')

    if eager_quantizable_model is None:
        return

    # comparing to eager mode quantization
    qeager = eager_quantizable_model
    # make sure the float eager model is runnable
    qeager(input_value)
    qeager.qconfig = qconfig
    if mode == 'static':
        qeager.fuse_model()
        prepare(qeager, inplace=True)
    else:
        qeager.train()
        qeager.fuse_model()
        prepare_qat(qeager, inplace=True)

    # calibration
    if mode == 'ddp':
        mp.spawn(run_ddp,  # noqa: F821
                 args=(world_size, qeager),  # noqa: F821
                 nprocs=world_size,  # noqa: F821
                 join=True)
    elif mode == 'qat':
        self.assertTrue(qeager.training, 'qeager should be in training mode for qat')
        optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001)
        # criterion was created in the FX 'qat' branch above.
        train_one_epoch(qeager, criterion, optimizer,
                        [(input_value, output_value)], torch.device('cpu'), 1)
    else:
        for _ in range(10):
            qeager(input_value)

    convert(qeager, inplace=True)
    qeager.eval()

    qeager_out = qeager(input_value)
    # make sure the quantized eager model is scriptable and runnable
    qeager_script = torch.jit.script(qeager)
    qeager_script(input_value)
    if is_not_tuple_out:
        eager_diffs[name] = (qeager_out - qgraph_out).abs().max()
        if check_with_eager:
            self.assertEqual(eager_diffs[name], 0,
                             'Result of graph mode quantization and ' +
                             'eager mode quantization on model: ' + name +
                             ' should match. Mode: ' + mode +
                             ' diff:' + str(eager_diffs[name]))
quantized_model_list = get_available_classification_models(quantized_models) quantized_model_list = set(quantized_model_list) # test eager and graph consistency model_list = quantized_model_list # mobilenet/inception_v3/googlenet qat is not working due to AdaptiveAveragePool qat # we might observe the output of AdaptiveAveragePool in the future # and re-enable the test fx_eager_not_matching = [ ("mobilenet_v2", "qat"), ("inception_v3", "qat"), ("googlenet", "qat") ] # because relu6 is replaced as relu in mobilenetv2 diff_of_quant = {} diff_from_eager = {} modes = ['static', 'qat'] options = itertools.product(modes, model_list) for mode, name in options: pretrained = name in quantized_model_list # load pretrained model to compare with quantized model kwargs = {} # turn off transform input for inception_v3 since # it's not quantized in eager mode and in fx graph # mode we can't skip quantizing a method right now # (might be supported in the future) if name in ["inception_v3", "googlenet"]: kwargs["transform_input"] = False eager_quantizable_model = None if name in quantized_model_list: eager_quantizable_model = quantized_models.__dict__[name](pretrained=False, quantize=False, **kwargs).eval().float() # compare with eager mode quantized model when it is available pretrained = eager_quantizable_model is not None model = models.__dict__[name](pretrained=pretrained, **kwargs).eval().float() if name == "mobilenet_v2": _replace_relu(model) # disable aux logits if hasattr(model, "aux_logits"): model.aux_logits = False model.AuxLogits = None if eager_quantizable_model: eager_quantizable_model.aux_logits = False eager_quantizable_model.AuxLogits = None check_with_eager = (name, mode) not in fx_eager_not_matching self._test_model_impl( mode, name, model, eager_quantizable_model, check_with_eager, diff_of_quant, diff_from_eager) def print_diffs(diffs): for mode, diffs_for_mode in diffs.items(): print('mode:', mode) for name, diff in diffs_for_mode.items(): print(name, ':', 
diff) # print('differences between float and quantized') # print_diffs(diff_of_quant) # print('----------------------') # print('differences between graph mode and eager mode') # print_diffs(diff_from_eager) # print('----------------------') @skip_if_no_torchvision @skipIfNoFBGEMM @unittest.skip("TODO: Test is always failing - https://github.com/pytorch/pytorch/issues/54979") def test_resnet18_ddp(self): from torchvision import models from torchvision.models import quantization as quantized_models eager_quantizable_model = quantized_models.__dict__[name](pretrained=False, quantize=False).eval().float() # noqa: F821 model = models.__dict__[name](pretrained=False).eval().float() # noqa: F821 self._test_model_impl( 'ddp', 'resnet18', model, eager_quantizable_model) @override_qengines def test_qat_embeddingbag_linear(self): for device in get_supported_device_types(): class EmbeddingBagLinear(torch.nn.Module): def __init__(self) -> None: super().__init__() self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum') self.linear = torch.nn.Linear(12, 1).to(dtype=torch.float) def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None, per_sample_weights: Optional[torch.Tensor] = None): x = self.emb(input, offsets, per_sample_weights) x = self.linear(x) return x qengine = torch.backends.quantized.engine qconfig_dict = QConfigMapping() \ .set_global(get_default_qat_qconfig(qengine)) \ .set_object_type(torch.nn.EmbeddingBag, default_embedding_qat_qconfig) train_indices = [[torch.randint(0, 10, (12, 12)), torch.randn((12, 1))] for _ in range(2)] eval_output = [[torch.randint(0, 10, (12, 1))]] model = EmbeddingBagLinear().train() prepared_fx_model = prepare_qat_fx(model, qconfig_dict, example_inputs=(train_indices[0][0],)) test_only_train_fn(prepared_fx_model, train_indices) quant_model = convert_fx(prepared_fx_model, qconfig_mapping=qconfig_dict) def checkQuantized(model): # Make sure EmbeddingBag is now a quantized EmbeddingBag. 
self.assertTrue(type(model.emb), nn.quantized.EmbeddingBag) # Also test that Linear has been quantized. self.assertTrue(type(model.linear), nnq.Linear) test_only_eval_fn(model, eval_output) self.checkScriptable(model, eval_output) self.checkNoQconfig(model) checkQuantized(quant_model) @override_qengines def test_qat_embedding_linear(self): for device in get_supported_device_types(): class EmbeddingLinear(torch.nn.Module): def __init__(self) -> None: super().__init__() self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) self.linear = torch.nn.Linear(12, 1).to(dtype=torch.float) def forward(self, input: torch.Tensor): x = torch.sum(self.emb(input), dim=1) x = self.linear(x) return x qengine = torch.backends.quantized.engine qconfig_dict = {"": get_default_qat_qconfig(qengine), "object_type": [(torch.nn.Embedding, default_embedding_qat_qconfig)]} train_indices = [[torch.randint(0, 10, (12, 12)), torch.randn((12, 1))] for _ in range(2)] eval_output = [[torch.randint(0, 10, (12, 1))]] model = EmbeddingLinear().train() prepared_fx_model = prepare_qat_fx(model, qconfig_dict, example_inputs=(train_indices[0][0],)) test_only_train_fn(prepared_fx_model, train_indices) quant_model = convert_fx(prepared_fx_model, qconfig_mapping=qconfig_dict) def checkQuantized(model): # Make sure EmbeddingBag is now a quantized EmbeddingBag. self.assertTrue(type(model.emb), nn.quantized.Embedding) # Also test that Linear has been quantized. 
self.assertTrue(type(model.linear), nnq.Linear) test_only_eval_fn(model, eval_output) self.checkScriptable(model, eval_output) self.checkNoQconfig(model) checkQuantized(quant_model) @given( device=st.sampled_from( ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] ) ) @settings(deadline=None) @override_qengines def test_qat_functional_linear(self, device): if torch.backends.quantized.engine not in ('fbgemm', 'qnnpack'): return class Linear(torch.nn.Module): def __init__(self) -> None: super().__init__() self.w = torch.ones(5, 5) self.b = torch.zeros(5) def forward(self, x): return torch.nn.functional.linear(x, self.w, self.b) class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.mods1 = torch.nn.Sequential(Linear(), Linear()) self.mods2 = Linear() def forward(self, x): x = self.mods1(x) x = self.mods2(x) return x model = M().train() ref_fake_quant = FakeQuantize.with_args( observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, dtype=torch.quint8, reduce_range=False, ) ref_weight_fake_quant = FakeQuantize.with_args( observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127, dtype=torch.qint8, reduce_range=False, ) ref_qat_qconfig = QConfig( activation=ref_fake_quant, weight=ref_weight_fake_quant ) qconfig_dict = {"": ref_qat_qconfig} example_inputs = (torch.randn(1, 5),) prepared_ref = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) custom_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, dtype=torch.quint8, reduce_range=False, ) custom_weight_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127, dtype=torch.qint8, reduce_range=False, ) custom_qconfig = QConfig( activation=custom_fake_quant, weight=custom_weight_fake_quant ) custom_qconfig_dict = {"": custom_qconfig} prepared = prepare_qat_fx(model, custom_qconfig_dict, example_inputs=example_inputs) 
prepared.to(device) prepared_ref.to(device) prepared.apply(torch.ao.quantization.disable_fake_quant) prepared.apply(torch.ao.quantization.disable_observer) prepared_ref.apply(torch.ao.quantization.disable_fake_quant) prepared_ref.apply(torch.ao.quantization.disable_observer) inp = torch.randn(5, 5, device=device, requires_grad=True) for i in range(10): if i == 2: prepared.apply(torch.ao.quantization.enable_observer) prepared_ref.apply(torch.ao.quantization.enable_observer) if i == 4: prepared.apply(torch.ao.quantization.enable_fake_quant) prepared_ref.apply(torch.ao.quantization.enable_fake_quant) inp = torch.randn(5, 5, device=device, requires_grad=True) out_ref = prepared_ref(inp) out = prepared(inp) torch.testing.assert_close(out, out_ref) # try backward pass labels = torch.randn(5, 5, device=device) loss = (out - labels).sum() grad = torch.autograd.grad(loss, [inp]) loss_ref = (out_ref - labels).sum() grad_ref = torch.autograd.grad(loss_ref, [inp]) torch.testing.assert_close(grad[0], grad_ref[0]) if 'fbgemm' in torch.backends.quantized.supported_engines: # During the lowering step in convert, fold_weight calls quantized::linear_prepack # which doesn't support QuantizedCuda backend prepared.cpu() prepared_ref.cpu() converted = convert_fx(prepared) converted_ref = convert_fx(prepared_ref) inp = torch.rand(5, 5) out = converted(inp) out_ref = converted_ref(inp) torch.testing.assert_close(out, out_ref) if __name__ == '__main__': raise RuntimeError("This test file is not meant to be run directly, use:\n\n" "\tpython test/test_quantization.py TESTNAME\n\n" "instead.")