import torch import torch.nn as nn import torch.nn.quantized as nnq import torch.nn.intrinsic as nni import torch.nn.intrinsic.quantized as nniq import torch.nn.intrinsic.qat as nniqat from torch.nn.utils.rnn import PackedSequence from torch.quantization import ( quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules, quantize_dynamic, QuantWrapper, QConfig, default_qconfig, default_qat_qconfig, default_dynamic_qconfig, per_channel_dynamic_qconfig, float16_dynamic_qconfig, ) from torch.testing._internal.common_quantization import ( QuantizationTestCase, AnnotatedSingleLayerLinearModel, QuantStubModel, ModelForFusion, ModelWithSequentialFusion, ManualLinearQATModel, ManualConvLinearQATModel, ModelWithFunctionals, ModelMultipleOps, ModelMultipleOpsNoAvgPool, SingleLayerLinearDynamicModel, TwoLayerLinearModel, NestedModel, ResNetBase, RNNDynamicModel, RNNCellDynamicModel, ModelForFusionWithBias, ActivationsTestModel, NormalizationTestModel, test_only_eval_fn, test_only_train_fn, prepare_dynamic, convert_dynamic, skipIfNoFBGEMM, ) # annotated models from torch.testing._internal.common_quantization import ( AnnotatedTwoLayerLinearModel, AnnotatedNestedModel, AnnotatedSubNestedModel, AnnotatedCustomConfigNestedModel, AnnotatedSkipQuantModel, ) from torch.testing._internal.common_quantized import ( override_quantized_engine, supported_qengines, override_qengines, ) from torch.testing._internal.common_utils import TemporaryFileName from torch.testing._internal.common_utils import suppress_warnings from torch.testing._internal.jit_utils import JitTestCase from hypothesis import given from hypothesis import strategies as st import torch.testing._internal.hypothesis_utils as hu hu.assert_deadline_disabled() # Standard library import copy import io import unittest import numpy as np class TestPostTrainingStatic(QuantizationTestCase): def test_single_layer(self): r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped to nnq.Linear which is the quantized version of the module """ for qengine in supported_qengines: with override_quantized_engine(qengine): qconfig = torch.quantization.get_default_qconfig(qengine) model = AnnotatedSingleLayerLinearModel(qengine) model.qconfig = qconfig model = prepare(model) # Check if observers and quant/dequant nodes are inserted self.checkNoPrepModules(model) self.checkHasPrepModules(model.fc1) self.checkObservers(model) test_only_eval_fn(model, self.calib_data) model = convert(model) def checkQuantized(model): self.checkNoPrepModules(model) self.checkHasPrepModules(model.fc1) self.checkWrappedQuantizedLinear(model.fc1) test_only_eval_fn(model, self.calib_data) self.checkScriptable(model, self.calib_data) checkQuantized(model) # test one line API - out of place version base = AnnotatedSingleLayerLinearModel(qengine) base.qconfig = qconfig keys_before = set(list(base.state_dict().keys())) model = quantize(base, test_only_eval_fn, self.calib_data) checkQuantized(model) keys_after = set(list(base.state_dict().keys())) self.assertEqual(keys_before, keys_after) # simple check that nothing changed # in-place version model = AnnotatedSingleLayerLinearModel(qengine) model.qconfig = qconfig quantize(model, test_only_eval_fn, self.calib_data, inplace=True) checkQuantized(model) @skipIfNoFBGEMM def test_two_layers(self): r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one `fc2`, and `fc1`is not quantized """ with override_quantized_engine('fbgemm'): model = AnnotatedTwoLayerLinearModel() model = prepare(model) self.checkNoPrepModules(model) self.checkObservers(model) self.checkNoPrepModules(model.fc1) self.checkHasPrepModules(model.fc2) test_only_eval_fn(model, self.calib_data) model = convert(model) def checkQuantized(model): self.checkNoPrepModules(model) self.checkNoPrepModules(model.fc1) self.checkHasPrepModules(model.fc2) self.assertEqual(type(model.fc1), torch.nn.Linear) self.checkWrappedQuantizedLinear(model.fc2) test_only_eval_fn(model, self.calib_data) self.checkScriptable(model, self.calib_data) checkQuantized(model) # test one line API model = quantize(AnnotatedTwoLayerLinearModel(), test_only_eval_fn, self.calib_data) checkQuantized(model) def test_nested1(self): r"""Test quantization for nested model, top level 'fc3' and 'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized """ for qengine in supported_qengines: with override_quantized_engine(qengine): model = AnnotatedNestedModel(qengine) def checkPrepModules(model, before_calib=False): if before_calib: self.checkObservers(model) self.checkNoPrepModules(model) self.checkNoPrepModules(model.sub1) self.checkNoPrepModules(model.sub1.fc) self.checkNoPrepModules(model.sub1.relu) self.checkNoPrepModules(model.sub2) self.checkHasPrepModules(model.sub2.fc1) self.checkNoPrepModules(model.sub2.fc2) self.checkHasPrepModules(model.fc3) model = prepare(model) checkPrepModules(model, True) test_only_eval_fn(model, self.calib_data) model = convert(model) def checkQuantized(model): checkPrepModules(model) self.checkLinear(model.sub1.fc) self.checkWrappedQuantizedLinear(model.fc3) self.checkWrappedQuantizedLinear(model.sub2.fc1) self.checkLinear(model.sub2.fc2) test_only_eval_fn(model, self.calib_data) self.checkScriptable(model, self.calib_data) checkQuantized(model) # test one line API model = quantize(AnnotatedNestedModel(qengine), test_only_eval_fn, self.calib_data) checkQuantized(model) @skipIfNoFBGEMM def test_nested2(self): model = AnnotatedSubNestedModel() model = prepare(model) def checkPrepModules(model, before_calib=False): if before_calib: self.checkObservers(model) self.checkNoPrepModules(model) self.checkNoPrepModules(model.sub1) self.checkNoPrepModules(model.sub1.fc) self.checkNoPrepModules(model.sub1.relu) self.checkHasPrepModules(model.sub2) self.checkNoPrepModules(model.sub2.module.fc1) self.checkNoPrepModules(model.sub2.module.fc2) self.checkHasPrepModules(model.fc3) checkPrepModules(model, True) test_only_eval_fn(model, self.calib_data) model = convert(model) def checkQuantized(model): checkPrepModules(model) self.checkLinear(model.sub1.fc) self.assertEqual(type(model.sub1.relu), torch.nn.ReLU) self.checkQuantizedLinear(model.sub2.module.fc1) self.checkQuantizedLinear(model.sub2.module.fc2) self.checkWrappedQuantizedLinear(model.fc3) test_only_eval_fn(model, self.calib_data) self.checkScriptable(model, self.calib_data) checkQuantized(model) # test one line API model = quantize(AnnotatedSubNestedModel(), test_only_eval_fn, self.calib_data) checkQuantized(model) def test_nested3(self): r"""More complicated nested test case with child qconfig overrides parent qconfig """ for qengine in supported_qengines: with override_quantized_engine(qengine): model = AnnotatedCustomConfigNestedModel() model = prepare(model) def checkPrepModules(model, before_calib=False): if before_calib: self.checkObservers(model) self.checkNoPrepModules(model) self.checkNoPrepModules(model.sub1) self.checkNoPrepModules(model.sub1.fc) self.checkNoPrepModules(model.sub1.relu) self.checkNoPrepModules(model.sub2) self.checkHasPrepModules(model.sub2.fc1) self.checkHasPrepModules(model.sub2.fc2) self.checkHasPrepModules(model.fc3) checkPrepModules(model, True) test_only_eval_fn(model, self.calib_data) model = convert(model) def checkQuantized(model): checkPrepModules(model) self.checkWrappedQuantizedLinear(model.sub2.fc1) self.checkWrappedQuantizedLinear(model.sub2.fc2) self.checkWrappedQuantizedLinear(model.fc3) test_only_eval_fn(model, self.calib_data) self.checkScriptable(model, self.calib_data) checkQuantized(model) # test one line API model = quantize(AnnotatedCustomConfigNestedModel(), test_only_eval_fn, self.calib_data) checkQuantized(model) def test_skip_quant(self): r"""The case when we want to skip quantizing some layers """ for qengine in supported_qengines: with override_quantized_engine(qengine): model = AnnotatedSkipQuantModel(qengine) model = prepare(model) self.checkObservers(model) test_only_eval_fn(model, self.calib_data) model = convert(model) def checkQuantized(model): self.checkLinear(model.fc) self.checkQuantDequant(model.sub) self.checkQuantizedLinear(model.sub.module.fc1) self.checkQuantizedLinear(model.sub.module.fc2) self.assertEqual(type(model.sub.module.relu1), nnq.ReLU) self.assertEqual(type(model.sub.module.relu2), nnq.ReLU) self.checkScriptable(model, self.calib_data) checkQuantized(model) # test one line API model = quantize(AnnotatedSkipQuantModel(qengine), test_only_eval_fn, self.calib_data) checkQuantized(model) @skipIfNoFBGEMM def test_manual(self): r"""User inserts QuantStub and DeQuantStub in model code and call the quantization utility functions. """ model = QuantStubModel() # propagate the qconfig of parents to children, model is changed # inplace model = prepare(model) self.checkObservers(model) test_only_eval_fn(model, self.calib_data) model = convert(model) def checkQuantized(model): self.assertEqual(type(model.fc), nnq.Linear) test_only_eval_fn(model, self.calib_data) self.checkScriptable(model, self.calib_data) checkQuantized(model) # test one line API model = quantize(QuantStubModel(), test_only_eval_fn, self.calib_data) checkQuantized(model) def test_resnet_base(self): r"""Test quantization for bottleneck topology used in resnet/resnext and add coverage for conversion of average pool and float functional """ for qengine in supported_qengines: with override_quantized_engine(qengine): qconfig = torch.quantization.get_default_qconfig(qengine) model = ResNetBase().float().eval() model = QuantWrapper(model) model.qconfig = qconfig fuse_list = ['module.conv1', 'module.bn1', 'module.relu1'] fuse_modules(model, fuse_list, inplace=True) model = prepare(model) self.checkObservers(model) test_only_eval_fn(model, self.img_data_2d) model = convert(model) def checkQuantized(model): self.assertEqual(type(model.module.conv1), nn.intrinsic.quantized.ConvReLU2d) self.assertEqual(type(model.module.myop), nn.quantized.QFunctional) self.assertEqual(type(model.module.avgpool), nn.AdaptiveAvgPool2d) test_only_eval_fn(model, self.img_data_2d) checkQuantized(model) @skipIfNoFBGEMM def test_normalization(self): r""" Test quantization of normalization layers """ model = NormalizationTestModel() model.qconfig = torch.quantization.get_default_qconfig('fbgemm') prepare(model, inplace=True) self.checkObservers(model) test_only_eval_fn(model, self.calib_data) model = convert(model) def checkQuantized(model): self.checkNoPrepModules(model.layer_norm) self.checkNoPrepModules(model.group_norm) self.checkNoPrepModules(model.instance_norm1d) self.checkNoPrepModules(model.instance_norm2d) self.checkNoPrepModules(model.instance_norm3d) self.assertEqual(type(model.layer_norm), nnq.LayerNorm) self.assertEqual(type(model.group_norm), nnq.GroupNorm) self.assertEqual(type(model.instance_norm1d), nnq.InstanceNorm1d) self.assertEqual(type(model.instance_norm2d), nnq.InstanceNorm2d) self.assertEqual(type(model.instance_norm3d), nnq.InstanceNorm3d) test_only_eval_fn(model, self.calib_data) self.checkScriptable(model, self.calib_data) checkQuantized(model) model_oneline = quantize( NormalizationTestModel(), test_only_eval_fn, self.calib_data) checkQuantized(model) def test_save_load_state_dict(self): r"""Test PTQ flow of creating a model and quantizing it and saving the quantized state_dict Load the quantized state_dict for eval and compare results against original model """ for qengine in supported_qengines: with override_quantized_engine(qengine): model = TwoLayerLinearModel() model = torch.quantization.QuantWrapper(model) model.qconfig = torch.quantization.get_default_qconfig(qengine) model = prepare(model) # calibrate test_only_eval_fn(model, self.calib_data) model = convert(model) x = torch.rand(2, 5, dtype=torch.float) ref = model(x) quant_state_dict = model.state_dict() # Create model again for eval model = TwoLayerLinearModel() model = torch.quantization.QuantWrapper(model) model.qconfig = torch.quantization.get_default_qconfig(qengine) model = prepare(model) model = convert(model) new_state_dict = model.state_dict() # Check to make sure the state dict keys match original model after convert. self.assertEqual(set(new_state_dict.keys()), set(quant_state_dict.keys())) model.load_state_dict(quant_state_dict) out = model(x) self.assertEqual(ref, out) @skipIfNoFBGEMM def test_activations(self): r""" Test quantization of activations """ model = ActivationsTestModel() model.qconfig = torch.quantization.get_default_qconfig('fbgemm') prepare(model, inplace=True) self.checkObservers(model) test_only_eval_fn(model, self.calib_data) model = convert(model) def checkQuantized(model): self.checkNoPrepModules(model.hardswish) self.assertEqual(type(model.hardswish), nnq.Hardswish) self.assertEqual(type(model.elu), nnq.ELU) test_only_eval_fn(model, self.calib_data) self.checkScriptable(model, self.calib_data) checkQuantized(model) # test one line API model_oneline = quantize(ActivationsTestModel(), test_only_eval_fn, self.calib_data) checkQuantized(model_oneline) @override_qengines def test_forward_hooks_preserved(self): r"""Test post-training static quantization on preserving pre forward and post forward hooks of original model """ qengine = torch.backends.quantized.engine model = QuantStubModel() counter = { 'pre_forwards': 0, 'forwards': 0, } def fw_pre_hook(h_module, input): counter['pre_forwards'] += 1 def fw_hook(h_module, input, output): counter['forwards'] += 1 model.fc.register_forward_pre_hook(fw_pre_hook) model.fc.register_forward_hook(fw_hook) model.qconfig = torch.quantization.get_default_qconfig(qengine) model = prepare(model) def checkHooksIsPresent(model, before_convert=True): num_fwd_hooks = 1 if before_convert: self.assertEqual(len(model.quant._forward_hooks.values()), 1, "Quantization observer hook has disappeared") num_fwd_hooks = 2 self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values()) self.assertObjectIn(fw_hook, model.fc._forward_hooks.values()) self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1, "Extra pre forward hooks have appeared on a layer") # During static quantization non stub layers are provided with quantization observer hook too self.assertEqual(len(model.fc._forward_hooks.values()), num_fwd_hooks, "Extra post forward hooks have appeared on a layer") # Implicitly check that fw_hook goes after _observer_forward_hook self.assertEqual(list(model.fc._forward_hooks.values())[-1], fw_hook, "_observer_forward_hook is not a first entry of the hooks list") checkHooksIsPresent(model, True) test_only_eval_fn(model, self.calib_data) torch.quantization.convert(model, inplace=True) checkHooksIsPresent(model, False) @skipIfNoFBGEMM class TestPostTrainingDynamic(QuantizationTestCase): def test_single_layer(self): r"""Dynamic Quantize SingleLayerLinearDynamicModel which has one Linear module, make sure it is swapped to nnqd.Linear which is the quantized version of the module """ for dtype in [torch.qint8, torch.float16]: model = SingleLayerLinearDynamicModel().eval() qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig qconfig_dict = { 'fc1': qconfig } prepare_dynamic(model, qconfig_dict) convert_dynamic(model) def checkQuantized(model): self.checkDynamicQuantizedLinear(model.fc1, dtype) self.checkScriptable(model, self.calib_data, check_save_load=True) checkQuantized(model) # test one line API - out of place version base = SingleLayerLinearDynamicModel() keys_before = set(list(base.state_dict().keys())) model = quantize_dynamic(base, qconfig_dict) checkQuantized(model) keys_after = set(list(base.state_dict().keys())) self.assertEqual(keys_before, keys_after) # simple check that nothing changed # in-place version model = SingleLayerLinearDynamicModel() quantize_dynamic(model, qconfig_dict, inplace=True) checkQuantized(model) # Test set qconfig model = SingleLayerLinearDynamicModel() quantize_dynamic(model, set([nn.Linear]), inplace=True, dtype=dtype) checkQuantized(model) def test_two_layers(self): r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one `fc2`, and `fc1`is not quantized """ for dtype in [torch.qint8, torch.float16]: model = TwoLayerLinearModel().eval() qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig qconfig_dict = { 'fc2': qconfig } prepare_dynamic(model, qconfig_dict) convert_dynamic(model) def checkQuantized(model): self.assertEqual(type(model.fc1), torch.nn.Linear) self.checkDynamicQuantizedLinear(model.fc2, dtype=dtype) self.checkScriptable(model, self.calib_data, check_save_load=True) checkQuantized(model) # test one line API model = quantize_dynamic(TwoLayerLinearModel().eval(), qconfig_dict) checkQuantized(model) # Test set API model = quantize_dynamic(TwoLayerLinearModel().eval(), {'fc2'}, dtype=dtype) checkQuantized(model) def test_nested1(self): r"""Test quantization for nested model, top level 'fc3' and 'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized """ for dtype in [torch.qint8, torch.float16]: model = NestedModel().eval() qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig qconfig_dict = { 'fc3': qconfig, 'sub2.fc1': qconfig } prepare_dynamic(model, qconfig_dict) convert_dynamic(model) def checkQuantized(model): self.checkLinear(model.sub1.fc) self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype) self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype) self.checkLinear(model.sub2.fc2) self.checkScriptable(model, self.calib_data, check_save_load=True) checkQuantized(model) # test one line API model = quantize_dynamic(NestedModel().eval(), qconfig_dict) checkQuantized(model) model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2.fc1'}, dtype=dtype) checkQuantized(model) def test_nested2(self): r"""Another test case for quantized, we will quantize all submodules of submodule sub2 """ for dtype in [torch.qint8, torch.float16]: model = NestedModel().eval() qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig qconfig_dict = { 'fc3': qconfig, 'sub2': qconfig } prepare_dynamic(model, qconfig_dict) convert_dynamic(model) def checkQuantized(model): self.checkLinear(model.sub1.fc) self.assertEqual(type(model.sub1.relu), torch.nn.ReLU) self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype) self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype) self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype) self.checkScriptable(model, self.calib_data, check_save_load=True) checkQuantized(model) # test one line API model = quantize_dynamic(NestedModel().eval(), qconfig_dict, dtype=dtype) checkQuantized(model) # Test set API model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2'}, dtype=dtype) checkQuantized(model) def test_nested3(self): r"""More complicated nested test case with child qconfig overrides parent qconfig """ for dtype in [torch.qint8, torch.float16]: model = NestedModel().eval() qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig qconfig_dynamic_dict = { 'fc3': qconfig, 'sub2': qconfig, 'sub2.fc1': qconfig } prepare_dynamic(model, qconfig_dynamic_dict) convert_dynamic(model) def checkQuantized(model): self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype) self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype) self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype) self.checkScriptable(model, self.calib_data, check_save_load=True) checkQuantized(model) # test one line API model = quantize_dynamic(NestedModel().eval(), qconfig_dynamic_dict) checkQuantized(model) # Test set API model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2', 'sub2.fc1'}, dtype=dtype) checkQuantized(model) def test_type_match_rule(self): r"""Test quantization for nested model, top level 'fc3' and 'fc1' of submodule 'sub2', All 'torch.nn.Linear' modules are quantized """ for dtype in [torch.qint8, torch.float16]: model = NestedModel().eval() qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig qconfig_dict = { 'fc3': None, 'sub2.fc1': None, torch.nn.Linear: qconfig } prepare_dynamic(model, qconfig_dict) test_only_eval_fn(model, self.calib_data) convert_dynamic(model) def checkQuantized(model): self.checkDynamicQuantizedLinear(model.sub1.fc, dtype=dtype) self.checkLinear(model.fc3) self.checkLinear(model.sub2.fc1) self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype) test_only_eval_fn(model, self.calib_data) self.checkScriptable(model, self.calib_data, check_save_load=True) checkQuantized(model) # test one line API model = quantize_dynamic(NestedModel().eval(), qconfig_dict, dtype=dtype) checkQuantized(model) def test_per_channel_linear_quantize(self): r"""Test quantization for per_channel dynamic quantization """ model = NestedModel().eval() qconfig_dict = { torch.nn.Linear: per_channel_dynamic_qconfig } prepare_dynamic(model, qconfig_dict) test_only_eval_fn(model, self.calib_data) convert_dynamic(model) def checkQuantized(model): self.checkDynamicQuantizedLinear(model.sub1.fc, dtype=torch.qint8) self.checkDynamicQuantizedLinear(model.fc3, dtype=torch.qint8) self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=torch.qint8) self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=torch.qint8) test_only_eval_fn(model, self.calib_data) self.checkScriptable(model, self.calib_data, check_save_load=True) checkQuantized(model) # test one line API model = quantize_dynamic(NestedModel().eval(), qconfig_dict) checkQuantized(model) @given(qconfig=st.sampled_from([per_channel_dynamic_qconfig, default_dynamic_qconfig]), dtype=st.sampled_from([torch.qint8, torch.float16])) def test_quantized_rnn(self, qconfig, dtype): r"""Test dynamic quantization, scriptability and serialization for dynamic quantized lstm modules on int8 and fp16 """ model = RNNDynamicModel('LSTM').eval() niter = 10 x = torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1) qconfig_dict = { torch.nn.LSTM : qconfig } if dtype == torch.float16: model_quantized = quantize_dynamic(model=model, dtype=dtype) else: model_quantized = quantize_dynamic(model=model, qconfig_spec=qconfig_dict, dtype=dtype) # Smoke test extra reprs self.assertTrue('DynamicQuantizedLSTM' in str(model_quantized)) self.checkDynamicQuantizedModule(model_quantized.mod, torch.nn.quantized.dynamic.LSTM, dtype) self.checkScriptable(model_quantized, [[x]], check_save_load=True) class ScriptWrapperPacked(torch.nn.Module): def __init__(self, cell): super(ScriptWrapperPacked, self).__init__() self.cell = cell def forward(self, x # type: PackedSequence ): # type: (...) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]] return self.cell(x) packed_input = torch.nn.utils.rnn.pack_padded_sequence(x, torch.tensor([10, 5, 2])) model_with_packed_input = ScriptWrapperPacked(model_quantized.mod) scripted = torch.jit.script(model_with_packed_input) # We cannot trace with input dtype being a packed sequence self._checkScriptable(model_with_packed_input, scripted, [[packed_input]], True) @given(qconfig=st.sampled_from([per_channel_dynamic_qconfig, default_dynamic_qconfig]), dtype=st.sampled_from([torch.qint8, torch.float16])) def test_quantized_rnn_cell(self, qconfig, dtype): r"""Test dynamic quantization, scriptability and serialization for dynamic quantized rnn cell modules on int8 and fp16 """ qconfig_dict = { torch.nn.LSTMCell : qconfig, torch.nn.GRUCell : qconfig, torch.nn.RNNCell : qconfig } for module_type in ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU']: model = RNNCellDynamicModel(module_type).eval() x = torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float) if torch.backends.quantized.engine == 'qnnpack' and dtype == torch.float16: continue # fp16 dynamic quant is not supported for qnnpack if dtype == torch.float16: model_quantized = quantize_dynamic(model=model, dtype=dtype) else: model_quantized = quantize_dynamic(model=model, qconfig_spec=qconfig_dict, dtype=dtype) def checkQuantized(model, module_type): mod_type_map = {'LSTMCell': torch.nn.quantized.dynamic.LSTMCell, 'GRUCell': torch.nn.quantized.dynamic.GRUCell, 'RNNTanh': torch.nn.quantized.dynamic.RNNCell, 'RNNReLU': torch.nn.quantized.dynamic.RNNCell} mod_repr_map = {'LSTMCell': 'DynamicQuantizedLSTMCell', 'GRUCell': 'DynamicQuantizedGRUCell', 'RNNTanh': 'DynamicQuantizedRNNCell', 'RNNReLU': 'DynamicQuantizedRNNCell'} self.assertTrue(mod_repr_map[module_type] in str(model_quantized)) self.checkDynamicQuantizedModule(model_quantized.mod, mod_type_map[module_type], dtype) # Smoke test extra reprs checkQuantized(model_quantized, module_type) self.checkScriptable(model_quantized, [[x]], check_save_load=True) def test_forward_hooks_preserved(self): r"""Test post-training dynamic quantization on preserving pre forward and post forward hooks of original model """ for dtype in [torch.qint8, torch.float16]: model = SingleLayerLinearDynamicModel().eval() qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig qconfig_dict = { 'fc1': qconfig } convert_dynamic(model) counter = { 'pre_forwards': 0, 'forwards': 0, } def fw_pre_hook(h_module, input): counter['pre_forwards'] += 1 def fw_hook(h_module, input, output): counter['forwards'] += 1 model.fc1.register_forward_pre_hook(fw_pre_hook) model.fc1.register_forward_hook(fw_hook) prepare_dynamic(model, qconfig_dict) def checkHooksIsPresent(model): self.assertObjectIn(fw_pre_hook, model.fc1._forward_pre_hooks.values()) self.assertObjectIn(fw_hook, model.fc1._forward_hooks.values()) self.assertEqual(len(model.fc1._forward_pre_hooks.values()), 1, "Extra pre forward hooks have appeared on a layer") self.assertEqual(len(model.fc1._forward_hooks.values()), 1, "Extra post forward hooks have appeared on a layer") checkHooksIsPresent(model) test_only_eval_fn(model, self.calib_data) convert_dynamic(model) checkHooksIsPresent(model) class TestQuantizationAwareTraining(QuantizationTestCase): def test_manual(self): for qengine in supported_qengines: with override_quantized_engine(qengine): model = ManualLinearQATModel(qengine) model = prepare_qat(model) self.checkObservers(model) test_only_train_fn(model, self.train_data) model = convert(model) def checkQuantized(model): self.assertEqual(type(model.fc1), nnq.Linear) self.assertEqual(type(model.fc2), nnq.Linear) test_only_eval_fn(model, self.calib_data) self.checkScriptable(model, self.calib_data) checkQuantized(model) model = quantize_qat(ManualLinearQATModel(qengine), test_only_train_fn, self.train_data) checkQuantized(model) def test_eval_only_fake_quant(self): r"""Using FakeQuant in evaluation only mode, this is useful for estimating accuracy loss when we quantize the network """ for qengine in supported_qengines: with override_quantized_engine(qengine): model = ManualLinearQATModel(qengine) model = prepare_qat(model) self.checkObservers(model) model.eval() test_only_eval_fn(model, self.calib_data) def test_conv_linear(self): for qengine in supported_qengines: with override_quantized_engine(qengine): model = ManualConvLinearQATModel() model = prepare_qat(model) self.checkObservers(model) test_only_train_fn(model, self.img_data_2d_train) model = convert(model) def checkQuantized(model): self.assertEqual(type(model.conv), nnq.Conv2d) self.assertEqual(type(model.fc1), nnq.Linear) self.assertEqual(type(model.fc2), nnq.Linear) test_only_eval_fn(model, self.img_data_2d) self.checkScriptable(model, self.img_data_2d) checkQuantized(model) model = ManualConvLinearQATModel() model = quantize_qat(model, test_only_train_fn, self.img_data_2d_train) checkQuantized(model) def test_train_save_load_eval(self): r"""Test QAT flow of creating a model, doing QAT and saving the quantized state_dict During eval, we first call prepare_qat and conver on the model and then load the state_dict and compare results against original model """ for qengine in supported_qengines: with override_quantized_engine(qengine): model = TwoLayerLinearModel() model = torch.quantization.QuantWrapper(model) model.qconfig = torch.quantization.get_default_qat_qconfig(qengine) model = prepare_qat(model) fq_state_dict = model.state_dict() test_only_train_fn(model, self.train_data) model = convert(model) quant_state_dict = model.state_dict() x = torch.rand(2, 5, dtype=torch.float) ref = model(x) # Create model again for eval. Check result using quantized state_dict model = TwoLayerLinearModel() model = torch.quantization.QuantWrapper(model) model.qconfig = torch.quantization.get_default_qat_qconfig(qengine) torch.quantization.prepare_qat(model, inplace=True) new_state_dict = model.state_dict() # Check to make sure the model after prepare_qat has the same state_dict as original. self.assertEqual(set(fq_state_dict.keys()), set(new_state_dict.keys())) torch.quantization.convert(model, inplace=True) model.eval() model.load_state_dict(quant_state_dict) out = model(x) self.assertEqual(ref, out) # Check model created using prepare has same state dict as quantized state_dict model = TwoLayerLinearModel() model.eval() model = torch.quantization.QuantWrapper(model) model.qconfig = torch.quantization.get_default_qconfig(qengine) torch.quantization.prepare(model, inplace=True) torch.quantization.convert(model, inplace=True) self.assertEqual(set(model.state_dict().keys()), set(quant_state_dict.keys())) model.eval() model.load_state_dict(quant_state_dict) out = model(x) self.assertEqual(ref, out) @override_qengines def test_forward_hooks_preserved(self): r"""Test QAT on preserving pre forward and post forward hooks of original model """ qengine = torch.backends.quantized.engine model = QuantStubModel() counter = { 'pre_forwards': 0, 'forwards': 0, } def fw_pre_hook(h_module, input): counter['pre_forwards'] += 1 def fw_hook(h_module, input, output): counter['forwards'] += 1 model.fc.register_forward_pre_hook(fw_pre_hook) model.fc.register_forward_hook(fw_hook) model.qconfig = torch.quantization.get_default_qat_qconfig(qengine) model = prepare_qat(model) def checkHooksIsPresent(model, before_convert=True): if before_convert: self.assertEqual(len(model.quant._forward_hooks.values()), 1, "Quantization observer hook has disappeared") self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values()) self.assertObjectIn(fw_hook, model.fc._forward_hooks.values()) self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1, "Extra pre forward hooks have appeared on a layer") self.assertEqual(len(model.fc._forward_hooks.values()), 1, "Extra post forward hooks have appeared on a layer") checkHooksIsPresent(model, True) x = torch.rand(2, 5, dtype=torch.float) model(x) torch.quantization.convert(model, inplace=True) checkHooksIsPresent(model, False) class TestFunctionalModule(QuantizationTestCase): # Histogram Observers are slow, so have no-deadline to ensure test doesn't time out @given(train_mode=st.booleans()) def test_functional_module(self, train_mode): model = ModelWithFunctionals() x = torch.rand(10, 1, dtype=torch.float) xq = torch.quantize_per_tensor(x, 0.01, 30, torch.quint8) self.checkScriptable(model, [[x]], check_save_load=True) if train_mode: model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') model = prepare_qat(model) else: model.qconfig = torch.quantization.get_default_qconfig('qnnpack') model = prepare(model) # Check if observers and quant/dequant nodes are inserted self.checkNoPrepModules(model) self.checkObservers(model) # Calibrate model(xq.dequantize()) model = convert(model) def checkQuantized(model): self.checkNoPrepModules(model) self.assertEqual(type(model.myadd), torch.nn.quantized.QFunctional) self.assertEqual(type(model.mycat), torch.nn.quantized.QFunctional) self.assertEqual(type(model.myadd_relu), torch.nn.quantized.QFunctional) checkQuantized(model) self.checkScriptable(model, [[xq]], check_save_load=True) @skipIfNoFBGEMM class TestFusion(QuantizationTestCase): def test_fuse_module_train(self): model = ModelForFusion(default_qat_qconfig).train() # Test step by step fusion model = fuse_modules(model, ['conv1', 'bn1', 'relu1']) model = fuse_modules(model, ['sub1.conv', 'sub1.bn']) self.assertEqual(type(model.conv1), nni.ConvBnReLU2d, msg="Fused Conv + BN + Relu first layer") self.assertEqual(type(model.bn1), torch.nn.Identity, msg="Fused Conv + BN + Relu (skipped BN)") self.assertEqual(type(model.relu1), torch.nn.Identity, msg="Fused Conv + BN + Relu (skipped Relu)") self.assertEqual(type(model.sub1.conv), nni.ConvBn2d, msg="Fused submodule Conv + BN") self.assertEqual(type(model.sub1.bn), torch.nn.Identity, msg="Fused submodule Conv + BN (skipped BN)") self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d, msg="Non-fused submodule Conv") self.assertEqual(type(model.sub2.relu), torch.nn.ReLU, msg="Non-fused submodule ReLU") model = prepare_qat(model) self.checkObservers(model) def checkQAT(model): self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d) self.assertEqual(type(model.bn1), nn.Identity) self.assertEqual(type(model.relu1), nn.Identity) self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d) self.assertEqual(type(model.sub1.bn), nn.Identity) self.assertEqual(type(model.sub2.conv), nn.Conv2d) self.assertEqual(type(model.sub2.relu), nn.ReLU) checkQAT(model) test_only_train_fn(model, self.img_data_1d_train) model = convert(model) def checkQuantized(model): self.assertEqual(type(model.conv1), nniq.ConvReLU2d) self.assertEqual(type(model.bn1), nn.Identity) self.assertEqual(type(model.relu1), nn.Identity) self.assertEqual(type(model.sub1.conv), nnq.Conv2d) self.assertEqual(type(model.sub1.bn), nn.Identity) self.assertEqual(type(model.sub2.conv), nn.Conv2d) self.assertEqual(type(model.sub2.relu), nn.ReLU) test_only_eval_fn(model, self.img_data_1d) with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"): checkQuantized(model) model = ModelForFusion(default_qat_qconfig).train() model = fuse_modules(model, [['conv1', 'bn1', 'relu1'], ['sub1.conv', 'sub1.bn']]) model = quantize_qat(model, test_only_train_fn, self.img_data_1d_train) with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"): checkQuantized(model) def test_fuse_module_eval(self): model = ModelForFusion(default_qconfig) model.eval() model = fuse_modules(model, [['conv3', 'bn3', 'relu4'], ['conv1', 'bn1', 'relu1'], ['conv2', 'relu2'], ['bn2', 'relu3'], ['sub1.conv', 'sub1.bn']]) self.assertEqual(type(model.conv1), nni.ConvReLU2d, msg="Fused Conv + BN + Relu first layer (BN is folded)") self.assertEqual(type(model.conv1[0]), nn.Conv2d, msg="Fused Conv + BN + Relu (Conv + folded BN only)") self.assertEqual(type(model.conv1[1]), nn.ReLU, msg="Fused Conv + BN + Relu second layer (Relu only)") self.assertEqual(type(model.bn1), nn.Identity, msg="Fused Conv + BN + Relu second layer (Skipped BN)") self.assertEqual(type(model.relu1), nn.Identity, msg="Fused Conv + BN + Relu second layer (Skipped Relu)") self.assertEqual(type(model.conv2), nni.ConvReLU3d, msg="Fused Conv + BN + Relu first layer (BN is folded)") self.assertEqual(type(model.bn2), nni.BNReLU3d, msg="Fused BN + Relu first layer (Relu is folded))") self.assertEqual(type(model.relu3), nn.Identity, msg="Fused BN + Relu second layer (Skipped Relu)") self.assertEqual(type(model.conv2[0]), nn.Conv3d, msg="Fused Conv + BN + Relu (Conv + folded BN only)") self.assertEqual(type(model.conv2[1]), nn.ReLU, msg="Fused Conv + BN + Relu second layer (Relu only)") self.assertEqual(type(model.relu2), nn.Identity, msg="Fused Conv + BN + Relu second layer (Skipped Relu)") self.assertEqual(type(model.conv3), nni.ConvReLU1d, msg="Fused Conv + Relu for Conv1d (folded BN)") self.assertEqual(type(model.conv3[0]), nn.Conv1d, msg="Fused Conv + Relu for Conv1d ") self.assertEqual(type(model.conv3[1]), nn.ReLU, msg="Fused Conv + Relu for Conv1d") self.assertEqual(type(model.bn3), nn.Identity, msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)") self.assertEqual(type(model.sub1.conv), nn.Conv2d, msg="Fused submodule Conv + folded BN") self.assertEqual(type(model.sub1.bn), nn.Identity, msg="Fused submodule (skipped BN)") self.assertEqual(type(model.sub2.conv), nn.Conv2d, msg="Non-fused submodule Conv") self.assertEqual(type(model.sub2.relu), torch.nn.ReLU, msg="Non-fused submodule ReLU") model = prepare(model) self.checkObservers(model) test_only_eval_fn(model, self.img_data_1d) model = convert(model) def checkQuantized(model): self.assertEqual(type(model.conv3), nniq.ConvReLU1d) self.assertEqual(type(model.conv1), nniq.ConvReLU2d) self.assertEqual(type(model.bn1), nn.Identity) self.assertEqual(type(model.relu1), nn.Identity) self.assertEqual(type(model.sub1.conv), nnq.Conv2d) self.assertEqual(type(model.sub1.bn), nn.Identity) self.assertEqual(type(model.sub2.conv), nn.Conv2d) self.assertEqual(type(model.sub2.relu), nn.ReLU) self.assertEqual(type(model.bn2), nniq.BNReLU3d) test_only_eval_fn(model, self.img_data_1d) checkQuantized(model) model = ModelForFusion(default_qconfig).eval() model = fuse_modules(model, [['conv1', 'bn1', 'relu1'], ['conv2', 'relu2'], ['bn2', 'relu3'], ['sub1.conv', 'sub1.bn'], ['conv3', 'bn3', 'relu4']]) model = quantize(model, test_only_eval_fn, self.img_data_1d) checkQuantized(model) def test_fusion_sequential_model_train(self): for qengine in supported_qengines: with override_quantized_engine(qengine): model = ModelWithSequentialFusion().train() model.to(torch.float) fuse_modules(model, [['conv1', 'relu1'] , ['features.0.0', 'features.0.1', 'features.0.2'], ['features.1.0', 'features.1.1', 'features.1.2'], ['features.2.0', 'features.2.1', 'features.2.2'], ['classifier.0', 'classifier.1']], inplace=True) self.assertEqual(type(model.conv1), nni.ConvReLU2d, msg="Fused Conv + Relu: nni.ConvReLU2d") self.assertEqual(type(model.conv1[0]), nn.Conv2d, msg="Fused Conv + Relu: Conv2d") self.assertEqual(type(model.conv1[1]), nn.ReLU, msg="Fused Conv + Relu: Relu") self.assertEqual(type(model.relu1), nn.Identity, msg="Fused Conv + Relu: Identity") for i in range(3): self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d, msg="Fused submodule Conv + folded BN") self.assertEqual(type(model.features[i][1]), nn.Identity, msg="Fused submodule (skipped BN)") self.assertEqual(type(model.features[i][2]), nn.Identity, msg="Non-fused submodule Conv") self.assertEqual(type(model.classifier[0]), nni.LinearReLU) self.assertEqual(type(model.classifier[1]), nn.Identity) model.qconfig = torch.quantization.get_default_qat_qconfig(qengine) prepare_qat(model, inplace=True) self.checkObservers(model) model(self.img_data_2d[0][0]) def checkQAT(model): self.assertEqual(type(model.conv1), nniqat.ConvReLU2d) self.assertEqual(type(model.relu1), nn.Identity) for i in range(3): self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d, msg="Fused submodule Conv + folded BN") self.assertEqual(type(model.features[i][1]), nn.Identity, msg="Fused submodule (skipped BN)") self.assertEqual(type(model.features[i][2]), nn.Identity, msg="Non-fused submodule Conv") self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU) self.assertEqual(type(model.classifier[1]), nn.Identity) checkQAT(model) model(self.img_data_2d[1][0]) convert(model, inplace=True) model(self.img_data_2d[1][0]) self.checkModelWithSequentialQuantized(model) def test_fusion_sequential_model_eval(self): for qengine in supported_qengines: with override_quantized_engine(qengine): model = ModelWithSequentialFusion().eval() model.to(torch.float) fuse_modules(model, [['conv1', 'relu1'] , ['features.0.0', 'features.0.1', 'features.0.2'], ['features.1.0', 'features.1.1', 'features.1.2'], ['features.2.0', 'features.2.1', 'features.2.2'], ['classifier.0', 'classifier.1']], inplace=True) self.assertEqual(type(model.conv1), nni.ConvReLU2d, msg="Fused Conv + Relu: nni.ConvReLU2d") self.assertEqual(type(model.conv1[0]), nn.Conv2d, msg="Fused Conv + Relu: Conv2d") self.assertEqual(type(model.conv1[1]), nn.ReLU, msg="Fused Conv + Relu: Relu") self.assertEqual(type(model.relu1), nn.Identity, msg="Fused Conv + Relu: Identity") for i in range(3): self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d, msg="Fused submodule Conv + folded BN") self.assertEqual(type(model.features[i][1]), nn.Identity, msg="Fused submodule (skipped BN)") self.assertEqual(type(model.features[i][2]), nn.Identity, msg="Non-fused submodule Conv") self.assertEqual(type(model.classifier[0]), nni.LinearReLU) self.assertEqual(type(model.classifier[1]), nn.Identity) model.qconfig = torch.quantization.get_default_qconfig(qengine) prepare(model, inplace=True) self.checkObservers(model) model(self.img_data_2d[0][0]) convert(model, inplace=True) model(self.img_data_2d[1][0]) self.checkModelWithSequentialQuantized(model) def checkModelWithSequentialQuantized(self, model): self.assertEqual(type(model.conv1), nniq.ConvReLU2d) self.assertEqual(type(model.relu1), nn.Identity) for i in range(3): self.assertEqual(type(model.features[i][0]), nniq.ConvReLU2d) self.assertEqual(type(model.features[i][1]), nn.Identity) self.assertEqual(type(model.features[i][2]), nn.Identity) self.assertEqual(type(model.classifier[0]), nniq.LinearReLU) self.assertEqual(type(model.classifier[1]), nn.Identity) def test_fusion_conv_with_bias(self): for qengine in supported_qengines: with override_quantized_engine(qengine): model = ModelForFusionWithBias().train() # output with no fusion. out_ref = model(self.img_data_2d[0][0]) model.qconfig = QConfig(activation=torch.nn.Identity, weight=torch.nn.Identity) model = fuse_modules(model, [["conv1", "bn1", "relu1"], ["conv2", "bn2"]]) prep_model = prepare_qat(model, inplace=False) # output with fusion but no observers. out_fused = prep_model(self.img_data_2d[0][0]) self.assertEqual(out_ref, out_fused) model.qconfig = torch.quantization.get_default_qconfig(qengine) prepare_qat(model, inplace=True) model(self.img_data_2d[0][0]) def checkQAT(model): self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d) self.assertEqual(type(model.bn1), nn.Identity) self.assertEqual(type(model.relu1), nn.Identity) self.assertEqual(type(model.conv2), nniqat.ConvBn2d) self.assertEqual(type(model.bn2), nn.Identity) checkQAT(model) def test_forward_hooks_preserved(self): r"""Test case that checks whether forward pre hooks of the first module and post forward hooks of the last module in modules list passed to fusion function preserved. (e.g. before fusion: [nn.Conv2d (with pre forward hooks), nn.BatchNorm2d, nn.ReLU (with post forward hooks)] after fusion: [nni.ConvBnReLU2d (with pre and post hooks), nn.Identity, nn.Identity]) """ model = ModelForFusion(default_qat_qconfig).train() counter = { 'pre_forwards': 0, 'forwards': 0, } fused = False def fw_pre_hook(fused_module_class, h_module, input): if fused: self.assertEqual(type(h_module), fused_module_class, "After fusion owner of the first module's forward pre hook is not a fused module") counter['pre_forwards'] += 1 def fw_hook(fused_module_class, h_module, input, output): if fused: self.assertEqual(type(h_module), fused_module_class, "After fusion owner of the last module's forward hook is not a fused module") counter['forwards'] += 1 # Registering two pre and two post forward hooks, thus expecting counter increment by two each inference model.conv1.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args)) model.sub1.conv.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBn2d, *args)) model.relu1.register_forward_hook(lambda *args: fw_hook(nni.ConvBnReLU2d, *args)) model.sub1.bn.register_forward_hook(lambda *args: fw_hook(nni.ConvBn2d, *args)) test_only_eval_fn(model, self.img_data_1d) self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d)) self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d)) model = fuse_modules(model, ['conv1', 'bn1', 'relu1']) model = fuse_modules(model, ['sub1.conv', 'sub1.bn']) fused = True before_fusion_pre_count = counter['pre_forwards'] before_fusion_post_count = counter['forwards'] test_only_eval_fn(model, self.img_data_1d) self.assertEqual(counter['pre_forwards'] - before_fusion_pre_count, 2 * len(self.img_data_1d)) self.assertEqual(counter['forwards'] - before_fusion_post_count, 2 * len(self.img_data_1d)) class TestModelNumerics(QuantizationTestCase): def test_float_quant_compare_per_tensor(self): for qengine in supported_qengines: with override_quantized_engine(qengine): torch.manual_seed(42) my_model = ModelMultipleOps().to(torch.float32) my_model.eval() calib_data = torch.rand(1024, 3, 15, 15, dtype=torch.float32) eval_data = torch.rand(1, 3, 15, 15, dtype=torch.float32) out_ref = my_model(eval_data) qModel = torch.quantization.QuantWrapper(my_model) qModel.eval() qModel.qconfig = torch.quantization.default_qconfig torch.quantization.fuse_modules(qModel.module, [['conv1', 'bn1', 'relu1']], inplace=True) torch.quantization.prepare(qModel, inplace=True) qModel(calib_data) torch.quantization.convert(qModel, inplace=True) out_q = qModel(eval_data) SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_q)) # Quantized model output should be close to floating point model output numerically # Setting target SQNR to be 30 dB so that relative error is 1e-3 below the desired # output self.assertGreater(SQNRdB, 30, msg='Quantized model numerics diverge from float, expect SQNR > 30 dB') def test_float_quant_compare_per_channel(self): # Test for per-channel Quant torch.manual_seed(67) my_model = ModelMultipleOps().to(torch.float32) my_model.eval() calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32) eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32) out_ref = my_model(eval_data) q_model = torch.quantization.QuantWrapper(my_model) q_model.eval() q_model.qconfig = torch.quantization.default_per_channel_qconfig torch.quantization.fuse_modules(q_model.module, [['conv1', 'bn1', 'relu1']], inplace=True) torch.quantization.prepare(q_model) q_model(calib_data) torch.quantization.convert(q_model) out_q = q_model(eval_data) SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_q)) # Quantized model output should be close to floating point model output numerically # Setting target SQNR to be 35 dB self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB') def test_fake_quant_true_quant_compare(self): for qengine in supported_qengines: with override_quantized_engine(qengine): torch.manual_seed(67) my_model = ModelMultipleOpsNoAvgPool().to(torch.float32) calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32) eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32) my_model.eval() out_ref = my_model(eval_data) fq_model = torch.quantization.QuantWrapper(my_model) fq_model.train() fq_model.qconfig = torch.quantization.default_qat_qconfig torch.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True) torch.quantization.prepare_qat(fq_model) fq_model.eval() fq_model.apply(torch.quantization.disable_fake_quant) fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) fq_model(calib_data) fq_model.apply(torch.quantization.enable_fake_quant) fq_model.apply(torch.quantization.disable_observer) out_fq = fq_model(eval_data) SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq)) # Quantized model output should be close to floating point model output numerically # Setting target SQNR to be 35 dB self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB') torch.quantization.convert(fq_model) out_q = fq_model(eval_data) SQNRdB = 20 * torch.log10(torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10)) self.assertGreater(SQNRdB, 60, msg='Fake quant and true quant numerics diverge, expect SQNR > 60 dB') # Test to compare weight only quantized model numerics and # activation only quantized model numerics with float def test_weight_only_activation_only_fakequant(self): for qengine in supported_qengines: with override_quantized_engine(qengine): torch.manual_seed(67) calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32) eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32) qconfigset = set([torch.quantization.default_weight_only_qconfig, torch.quantization.default_activation_only_qconfig]) SQNRTarget = [35, 45] for idx, qconfig in enumerate(qconfigset): my_model = ModelMultipleOpsNoAvgPool().to(torch.float32) my_model.eval() out_ref = my_model(eval_data) fq_model = torch.quantization.QuantWrapper(my_model) fq_model.train() fq_model.qconfig = qconfig torch.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True) torch.quantization.prepare_qat(fq_model) fq_model.eval() fq_model.apply(torch.quantization.disable_fake_quant) fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) fq_model(calib_data) fq_model.apply(torch.quantization.enable_fake_quant) fq_model.apply(torch.quantization.disable_observer) out_fq = fq_model(eval_data) SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq)) self.assertGreater(SQNRdB, SQNRTarget[idx], msg='Quantized model numerics diverge from float') class TestQuantizeONNXExport(JitTestCase): def _test_lower_graph_impl(self, model, data): model.qconfig = torch.quantization.default_qconfig model = torch.quantization.prepare(model) model = torch.quantization.convert(model) outputs = model(data) input_names = ["x"] def export_to_onnx(model, input, input_names): outputs = model(input) traced = torch.jit.trace(model, input) buf = io.BytesIO() torch.jit.save(traced, buf) buf.seek(0) model = torch.jit.load(buf) f = io.BytesIO() torch.onnx.export(model, input, f, input_names=input_names, example_outputs=outputs, operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK) onnx_model = export_to_onnx(model, data, input_names) @skipIfNoFBGEMM def test_lower_graph_linear(self): model = torch.quantization.QuantWrapper(torch.nn.Linear(5, 10, bias=True)).to(dtype=torch.float) data_numpy = np.random.rand(1, 2, 5).astype(np.float32) data = torch.from_numpy(data_numpy).to(dtype=torch.float) self._test_lower_graph_impl(model, data) @skipIfNoFBGEMM def test_lower_graph_conv2d(self): model = torch.quantization.QuantWrapper(torch.nn.Conv2d(3, 5, 2, bias=True)).to(dtype=torch.float) data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32) data = torch.from_numpy(data_numpy).to(dtype=torch.float) self._test_lower_graph_impl(model, data) @skipIfNoFBGEMM @unittest.skip("onnx opset9 does not support quantize_per_tensor and caffe2 \ does not support conv3d") def test_lower_graph_conv3d(self): model = torch.quantization.QuantWrapper(torch.nn.Conv3d(3, 5, 2, bias=True)).to(dtype=torch.float) data_numpy = np.random.rand(1, 3, 6, 6, 6).astype(np.float32) data = torch.from_numpy(data_numpy).to(dtype=torch.float) self._test_lower_graph_impl(model, data) class TestDeprecatedJitQuantized(JitTestCase): @skipIfNoFBGEMM def test_rnn_cell_quantized(self): d_in, d_hid = 2, 2 for cell in [ torch.nn.LSTMCell(d_in, d_hid).float(), torch.nn.GRUCell(d_in, d_hid).float(), torch.nn.RNNCell(d_in, d_hid).float(), ]: if isinstance(cell, torch.nn.LSTMCell): num_chunks = 4 elif isinstance(cell, torch.nn.GRUCell): num_chunks = 3 elif isinstance(cell, torch.nn.RNNCell): num_chunks = 1 # Replace parameter values s.t. the range of values is exactly # 255, thus we will have 0 quantization error in the quantized # GEMM call. This i s for testing purposes. # # Note that the current implementation does not support # accumulation values outside of the range representable by a # 16 bit integer, instead resulting in a saturated value. We # must take care that in our test we do not end up with a dot # product that overflows the int16 range, e.g. # (255*127+255*127) = 64770. So, we hardcode the test values # here and ensure a mix of signedness. vals = [[100, -155], [100, -155], [-155, 100], [-155, 100], [100, -155], [-155, 100], [-155, 100], [100, -155]] vals = vals[:d_hid * num_chunks] cell.weight_ih = torch.nn.Parameter( torch.tensor(vals, dtype=torch.float), requires_grad=False) cell.weight_hh = torch.nn.Parameter( torch.tensor(vals, dtype=torch.float), requires_grad=False) ref = copy.deepcopy(cell) cell = torch.jit.quantized.quantize_rnn_cell_modules(cell) x = torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float) h0_vals = [[-155, 100], [-155, 155], [100, -155]] hx = torch.tensor(h0_vals, dtype=torch.float) if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell): cx = torch.tensor(h0_vals, dtype=torch.float) hiddens = (hx, cx) else: hiddens = hx if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell): class ScriptWrapper(torch.jit.ScriptModule): def __init__(self, cell): super(ScriptWrapper, self).__init__() self.cell = cell @torch.jit.script_method def forward(self, x, hiddens): # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor] return self.cell(x, hiddens) else: class ScriptWrapper(torch.jit.ScriptModule): def __init__(self, cell): super(ScriptWrapper, self).__init__() self.cell = cell @torch.jit.script_method def forward(self, x, hiddens): # type: (torch.Tensor, torch.Tensor) -> torch.Tensor return self.cell(x, hiddens) cell = ScriptWrapper(cell) outs = cell(x, hiddens) cell = self.getExportImportCopyWithPacking(cell) outs = cell(x, hiddens) ref_outs = ref(x, hiddens) self.assertEqual(len(outs), len(ref_outs)) for out, ref_out in zip(outs, ref_outs): torch.testing.assert_allclose(out, ref_out) @skipIfNoFBGEMM def test_rnn_quantized(self): d_in, d_hid = 2, 2 for cell in [ torch.nn.LSTM(d_in, d_hid).float(), torch.nn.GRU(d_in, d_hid).float(), ]: # Replace parameter values s.t. the range of values is exactly # 255, thus we will have 0 quantization error in the quantized # GEMM call. This i s for testing purposes. # # Note that the current implementation does not support # accumulation values outside of the range representable by a # 16 bit integer, instead resulting in a saturated value. We # must take care that in our test we do not end up with a dot # product that overflows the int16 range, e.g. # (255*127+255*127) = 64770. So, we hardcode the test values # here and ensure a mix of signedness. vals = [[100, -155], [100, -155], [-155, 100], [-155, 100], [100, -155], [-155, 100], [-155, 100], [100, -155]] if isinstance(cell, torch.nn.LSTM): num_chunks = 4 elif isinstance(cell, torch.nn.GRU): num_chunks = 3 vals = vals[:d_hid * num_chunks] cell.weight_ih_l0 = torch.nn.Parameter( torch.tensor(vals, dtype=torch.float), requires_grad=False) cell.weight_hh_l0 = torch.nn.Parameter( torch.tensor(vals, dtype=torch.float), requires_grad=False) ref = copy.deepcopy(cell) cell_int8 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.int8) cell_fp16 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.float16) niter = 10 x = torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1) h0_vals = [[-155, 100], [-155, 155], [100, -155]] hx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0) cx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0) if isinstance(ref, torch.nn.LSTM): hiddens = (hx, cx) elif isinstance(ref, torch.nn.GRU): hiddens = hx ref_out, ref_hid = ref(x, hiddens) # Compare int8 quantized to unquantized output_int8, final_hiddens_int8 = cell_int8(x, hiddens) torch.testing.assert_allclose(output_int8, ref_out) for out, ref in zip(final_hiddens_int8, ref_hid): torch.testing.assert_allclose(out, ref) # Compare fp16 quantized to unquantized output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens) torch.testing.assert_allclose(output_fp16, ref_out) for out, ref in zip(final_hiddens_fp16, ref_hid): torch.testing.assert_allclose(out, ref) def compare_quantized_unquantized(ScriptWrapper, cell): wrapper = ScriptWrapper(cell) # Compare quantize scripted module to unquantized script_out, script_hid = wrapper(x, hiddens) torch.testing.assert_allclose(script_out, ref_out) for out, ref in zip(script_hid, ref_hid): torch.testing.assert_allclose(out, ref) # Compare export/import to unquantized export_import_wrapper = self.getExportImportCopyWithPacking(wrapper) ei_out, ei_hid = export_import_wrapper(x, hiddens) torch.testing.assert_allclose(ei_out, ref_out) for out, ref in zip(ei_hid, ref_hid): torch.testing.assert_allclose(out, ref) if isinstance(cell, torch.jit.quantized.QuantizedGRU): class ScriptWrapper(torch.jit.ScriptModule): def __init__(self, cell): super(ScriptWrapper, self).__init__() self.cell = cell @torch.jit.script_method def forward(self, x, hiddens): # type: (torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] return self.cell(x, hiddens) compare_quantized_unquantized(ScriptWrapper, cell) elif isinstance(cell, torch.jit.quantized.QuantizedLSTM): for cell in [cell_int8, cell_fp16]: class ScriptWrapper(torch.jit.ScriptModule): def __init__(self, cell): super(ScriptWrapper, self).__init__() self.cell = cell @torch.jit.script_method def forward(self, x, hiddens): # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) # -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] return self.cell(x, hiddens) compare_quantized_unquantized(ScriptWrapper, cell) if 'fbgemm' in torch.backends.quantized.supported_engines: # Suppression: using deprecated quant api @suppress_warnings def test_quantization_modules(self): K1, N1 = 2, 2 class FooBar(torch.nn.Module): def __init__(self): super(FooBar, self).__init__() self.linear1 = torch.nn.Linear(K1, N1).float() def forward(self, x): x = self.linear1(x) return x fb = FooBar() fb.linear1.weight = torch.nn.Parameter( torch.tensor([[-150, 100], [100, -150]], dtype=torch.float), requires_grad=False) fb.linear1.bias = torch.nn.Parameter(torch.zeros_like(fb.linear1.bias), requires_grad=False) x = (torch.rand(1, K1).float() - 0.5) / 10.0 value = torch.tensor([[100, -150]], dtype=torch.float) y_ref = fb(value) fb_int8 = torch.jit.quantized.quantize_linear_modules(fb) traced_int8 = torch.jit.trace(fb_int8, (x,)) fb_int8 = self.getExportImportCopyWithPacking(traced_int8) y_int8 = fb_int8(value) fb_fp16 = torch.jit.quantized.quantize_linear_modules(fb, torch.float16) traced_fp16 = torch.jit.trace(fb_fp16, (x,)) fb_fp16 = self.getExportImportCopyWithPacking(traced_fp16) y_fp16 = fb_fp16(value) torch.testing.assert_allclose(y_int8, y_ref, rtol=0.0001, atol=1e-3) torch.testing.assert_allclose(y_fp16, y_ref, rtol=0.0001, atol=1e-3) def _test_pickle_checkpoint_qtensor(self, device): with TemporaryFileName() as fname: class M(torch.jit.ScriptModule): __constants__ = ['fname'] def __init__(self): super(M, self).__init__() self.fname = fname @torch.jit.script_method def forward(self, x, y): torch.save((x, y), self.fname) return y q = torch.quantize_per_tensor( torch.rand(2, 3, dtype=torch.float), scale=0.1, zero_point=10, dtype=torch.quint8).to(device) qc = torch.quantize_per_channel( torch.rand(2, 3, dtype=torch.float), scales=torch.tensor([0.1, 0.5, 0.01]), zero_points=torch.tensor([10, 0, 20]), axis=1, dtype=torch.quint8).to(device) m = M() m(q, qc) with open(fname, "rb") as handle: loaded_q, loaded_qc = torch.load(fname) self.assertEqual(loaded_q, q) self.assertEqual(loaded_qc, qc) def test_pickle_checkpoint_qtensor(self): self._test_pickle_checkpoint_qtensor('cpu') def test_serialize_qtensor(self): class SimpleQTensor(torch.jit.ScriptModule): def __init__(self, per_channel): super(SimpleQTensor, self).__init__() x = torch.rand(5, 5).float() if not per_channel: x_q = torch.quantize_per_tensor(x, 0.2, 10, torch.quint8) else: s = torch.rand(5, dtype=torch.float64) + 0.1 zp = torch.randint(5, 15, (5,)) x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8) self.register_buffer('x', x_q) @torch.jit.script_method def forward(self): return self.x for per_channel in [False, True]: model = SimpleQTensor(per_channel) buffer = io.BytesIO() torch.jit.save(model, buffer) buffer.seek(0) model_loaded = torch.jit.load(buffer) self.assertEqual(model_loaded(), model()) @skipIfNoFBGEMM def test_erase_class_tensor_shapes(self): class Linear(torch.nn.Module): def __init__(self, in_features, out_features): super(Linear, self).__init__() qweight = torch._empty_affine_quantized( [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8) self._packed_weight = torch.ops.quantized.linear_prepack(qweight) @torch.jit.export def __getstate__(self): return (torch.ops.quantized.linear_unpack(self._packed_weight)[0], self.training) def forward(self): return self._packed_weight @torch.jit.export def __setstate__(self, state): self._packed_weight = torch.ops.quantized.linear_prepack(state[0]) self.training = state[1] @property def weight(self): return torch.ops.quantized.linear_unpack(self._packed_weight)[0] @weight.setter def weight(self, w): self._packed_weight = torch.ops.quantized.linear_prepack(w) with torch._jit_internal._disable_emit_hooks(): x = torch.jit.script(Linear(10, 10)) torch._C._jit_pass_erase_shape_information(x.graph) if __name__ == '__main__': raise RuntimeError("This test file is not meant to be run directly, use:\n\n" "\tpython test/test_quantization.py TESTNAME\n\n" "instead.")