import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nnq_fused
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.quantized.functional as qF
import torch.quantization

from common_quantization import QuantizationTestCase, prepare_dynamic
from common_quantized import _calculate_dynamic_qparams, override_quantized_engine
from common_utils import run_tests, IS_PPC, TEST_WITH_UBSAN

from hypothesis import assume, given
from hypothesis import strategies as st
from hypothesis_utils import no_deadline

import io
import numpy as np
import unittest

'''
Note that the tests in this file are only API tests: they make sure the
quantized operator implementations are wrapped correctly in the user-facing
APIs. They are not correctness tests for the underlying quantized operators.
For correctness tests, please see `caffe2/test/test_quantized.py`.
'''


def _make_conv_test_input(
    batch_size, in_channels_per_group, input_feature_map_size,
    out_channels_per_group, groups, kernel_size, X_scale, X_zero_point,
    W_scale, W_zero_point, use_bias, use_channelwise,
):
    in_channels = in_channels_per_group * groups
    out_channels = out_channels_per_group * groups

    (X_value_min, X_value_max) = (0, 4)
    X_init = torch.randint(
        X_value_min, X_value_max,
        (batch_size, in_channels,) + input_feature_map_size)
    X = X_scale * (X_init - X_zero_point).float()
    X_q = torch.quantize_per_tensor(
        X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8)

    W_scale = W_scale * out_channels
    W_zero_point = W_zero_point * out_channels
    # Resize the W_scale and W_zero_point arrays to length out_channels
    W_scale = W_scale[:out_channels]
    W_zero_point = W_zero_point[:out_channels]
    # For testing, we use small values for weights and for activations so that
    # no overflow occurs in the vpmaddubsw instruction. If the overflow occurs
    # in the qconv implementation but not in the reference, we cannot exactly
    # match the results with the reference.
    # Please see the comment in the qconv implementation file
    # aten/src/ATen/native/quantized/cpu/qconv.cpp for more details.
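    # With activation values drawn from [0, 4) and weight values from [-5, 5),
    # the quantized integer representations stay equally small, so each
    # uint8 * int8 product and its pairwise sum remain far below the int16
    # saturation limit of vpmaddubsw.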
    (W_value_min, W_value_max) = (-5, 5)
    # The operator expects them in the format
    # (out_channels, in_channels/groups,) + kernel_size
    W_init = torch.randint(
        W_value_min, W_value_max,
        (out_channels, in_channels_per_group,) + kernel_size)
    b_init = torch.randint(0, 10, (out_channels,))

    if use_channelwise:
        W_shape = (-1, 1) + (1,) * len(kernel_size)
        W_scales_tensor = torch.tensor(W_scale, dtype=torch.float)
        W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float)
        W = W_scales_tensor.reshape(*W_shape) * (
            W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float()
        b = X_scale * W_scales_tensor * b_init.float()
        W_q = torch.quantize_per_channel(
            W, W_scales_tensor, W_zero_points_tensor.long(), 0,
            dtype=torch.qint8)
    else:
        W = W_scale[0] * (W_init - W_zero_point[0]).float()
        b = X_scale * W_scale[0] * b_init.float()
        W_q = torch.quantize_per_tensor(
            W, scale=W_scale[0], zero_point=W_zero_point[0],
            dtype=torch.qint8)

    return (X, X_q, W, W_q, b if use_bias else None)


class FunctionalAPITest(QuantizationTestCase):
    def test_relu_api(self):
        X = torch.arange(-5, 5, dtype=torch.float)
        scale = 2.0
        zero_point = 1
        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                       dtype=torch.quint8)
        qY = torch.relu(qX)
        qY_hat = qF.relu(qX)
        self.assertEqual(qY, qY_hat)

    def _test_conv_api_impl(
        self, qconv_fn, conv_fn, batch_size, in_channels_per_group,
        input_feature_map_size, out_channels_per_group, groups, kernel_size,
        stride, padding, dilation, X_scale, X_zero_point, W_scale,
        W_zero_point, Y_scale, Y_zero_point, use_bias, use_channelwise,
    ):
        for i in range(len(kernel_size)):
            assume(input_feature_map_size[i] + 2 * padding[i]
                   >= dilation[i] * (kernel_size[i] - 1) + 1)

        (X, X_q, W, W_q, b) = _make_conv_test_input(
            batch_size, in_channels_per_group, input_feature_map_size,
            out_channels_per_group, groups, kernel_size, X_scale,
            X_zero_point, W_scale, W_zero_point, use_bias, use_channelwise)

        Y_exp = conv_fn(X, W, b, stride, padding, dilation, groups)
        Y_exp = torch.quantize_per_tensor(
            Y_exp, scale=Y_scale, zero_point=Y_zero_point, dtype=torch.quint8)
        Y_act = qconv_fn(
            X_q, W_q, b, stride, padding, dilation, groups,
            padding_mode="zeros", scale=Y_scale, zero_point=Y_zero_point)

        # Make sure the results match
        # assert_array_almost_equal compares using the following formula:
        #     abs(desired - actual) < 1.5 * 10**(-decimal)
        # (https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_almost_equal.html)
        # We use decimal = 0 to ignore off-by-1 differences between reference
        # and test. Off-by-1 differences arise due to the order of round and
        # zero_point addition operation, i.e., if addition followed by round is
        # used by reference and round followed by addition is used by test, the
        # results may differ by 1.
        # For example, the result of round(2.5) + 1 is 3 while round(2.5 + 1) is
        # 4 assuming the rounding mode is round-to-nearest, ties-to-even.
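        # With decimal=0 the tolerance above evaluates to
        #     abs(desired - actual) < 1.5 * 10**0 == 1.5
        # so an off-by-1 difference between the integer representations is
        # accepted, while a difference of 2 or more still fails.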
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_act.int_repr().numpy(), decimal=0)

    @no_deadline
    @given(batch_size=st.integers(1, 3),
           in_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           H=st.integers(4, 16),
           W=st.integers(4, 16),
           out_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           groups=st.integers(1, 4),
           kernel_h=st.integers(1, 7),
           kernel_w=st.integers(1, 7),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 2),
           X_scale=st.floats(1.2, 1.6),
           X_zero_point=st.integers(0, 4),
           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
           W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2),
           Y_scale=st.floats(4.2, 5.6),
           Y_zero_point=st.integers(0, 4),
           use_bias=st.booleans(),
           use_channelwise=st.booleans(),
           qengine=st.sampled_from(("qnnpack", "fbgemm")))
    def test_conv2d_api(
        self, batch_size, in_channels_per_group, H, W,
        out_channels_per_group, groups, kernel_h, kernel_w, stride_h,
        stride_w, pad_h, pad_w, dilation, X_scale, X_zero_point, W_scale,
        W_zero_point, Y_scale, Y_zero_point, use_bias, use_channelwise,
        qengine,
    ):
        # Tests the correctness of the conv2d function.
        if qengine not in torch.backends.quantized.supported_engines:
            return
        if qengine == 'qnnpack':
            if IS_PPC or TEST_WITH_UBSAN:
                return
            use_channelwise = False

        input_feature_map_size = (H, W)
        kernel_size = (kernel_h, kernel_w)
        stride = (stride_h, stride_w)
        padding = (pad_h, pad_w)
        dilation = (dilation, dilation)

        with override_quantized_engine(qengine):
            qconv_fn = qF.conv2d
            conv_fn = F.conv2d
            self._test_conv_api_impl(
                qconv_fn, conv_fn, batch_size, in_channels_per_group,
                input_feature_map_size, out_channels_per_group, groups,
                kernel_size, stride, padding, dilation, X_scale, X_zero_point,
                W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
                use_channelwise)

    @no_deadline
    @given(batch_size=st.integers(1, 3),
           in_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           D=st.integers(4, 8),
           H=st.integers(4, 8),
           W=st.integers(4, 8),
           out_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           groups=st.integers(1, 4),
           kernel_d=st.integers(1, 4),
           kernel_h=st.integers(1, 4),
           kernel_w=st.integers(1, 4),
           stride_d=st.integers(1, 2),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_d=st.integers(0, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 2),
           X_scale=st.floats(1.2, 1.6),
           X_zero_point=st.integers(0, 4),
           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
           W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2),
           Y_scale=st.floats(4.2, 5.6),
           Y_zero_point=st.integers(0, 4),
           use_bias=st.booleans(),
           use_channelwise=st.booleans(),
           qengine=st.sampled_from(("fbgemm",)))
    def test_conv3d_api(
        self, batch_size, in_channels_per_group, D, H, W,
        out_channels_per_group, groups, kernel_d, kernel_h, kernel_w,
        stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, dilation, X_scale,
        X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
        use_channelwise, qengine,
    ):
        # Tests the correctness of the conv3d function.
        # Currently conv3d only supports the FbGemm engine
        if qengine not in torch.backends.quantized.supported_engines:
            return

        input_feature_map_size = (D, H, W)
        kernel_size = (kernel_d, kernel_h, kernel_w)
        stride = (stride_d, stride_h, stride_w)
        padding = (pad_d, pad_h, pad_w)
        dilation = (dilation, dilation, dilation)

        with override_quantized_engine(qengine):
            qconv_fn = qF.conv3d
            conv_fn = F.conv3d
            self._test_conv_api_impl(
                qconv_fn, conv_fn, batch_size, in_channels_per_group,
                input_feature_map_size, out_channels_per_group, groups,
                kernel_size, stride, padding, dilation, X_scale, X_zero_point,
                W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
                use_channelwise)


class DynamicModuleAPITest(QuantizationTestCase):
    @no_deadline
    @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                         " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                         " with instruction set support avx2 or newer.")
    @given(
        batch_size=st.integers(1, 5),
        in_features=st.integers(16, 32),
        out_features=st.integers(4, 8),
        use_bias=st.booleans(),
        use_default_observer=st.booleans(),
    )
    def test_linear_api(self, batch_size, in_features, out_features,
                        use_bias, use_default_observer):
        """test API functionality for nn.quantized.dynamic.Linear"""
        W = torch.rand(out_features, in_features).float()
        W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8)
        W_q = torch.quantize_per_tensor(W, W_scale, W_zp, torch.qint8)
        X = torch.rand(batch_size, in_features).float()
        B = torch.rand(out_features).float() if use_bias else None
        qlinear = nnqd.Linear(in_features, out_features)
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear.set_weight_bias(W_q, B)
        qlinear(X)

        # Simple round-trip test to ensure weight()/set_weight() API
        self.assertEqual(qlinear.weight(), W_q)
        W_pack = qlinear._packed_params._packed_params
        Z_dq = qlinear(X)

        # Check if the module implementation matches calling the
        # ops directly
        Z_ref = torch.ops.quantized.linear_dynamic(X, W_pack)
        self.assertEqual(Z_ref, Z_dq)

        # Test serialization of dynamic quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        self.assertEqual(model_dict['_packed_params.weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['_packed_params.bias'], B)
        b = io.BytesIO()
        torch.save(model_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in model_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])
        loaded_qlinear = nnqd.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

        linear_unpack = torch.ops.quantized.linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
                         linear_unpack(loaded_qlinear._packed_params._packed_params))
        if use_bias:
            self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_params'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
        self.assertTrue(hasattr(qlinear, '_weight_bias'))
        self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))

        self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
        self.assertEqual(qlinear._weight_bias(),
                         torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
        Z_dq2 = qlinear(X)
        self.assertEqual(Z_dq, Z_dq2)

        # The below check is meant to ensure that `torch.save` and `torch.load`
        # serialization works, however it is currently broken by the following:
        # https://github.com/pytorch/pytorch/issues/24045
        #
        # Instead, we currently check that the proper exception is thrown on
        # save.
        #
        # b = io.BytesIO()
        # torch.save(qlinear, b)
        # b.seek(0)
        # loaded = torch.load(b)
        # self.assertEqual(qlinear.weight(), loaded.weight())
        # self.assertEqual(qlinear.zero_point, loaded.zero_point)
        #
        with self.assertRaisesRegex(RuntimeError, r'torch.save\(\) is not currently supported'):
            b = io.BytesIO()
            torch.save(qlinear, b)

        # Test JIT
        self.checkScriptable(qlinear, list(zip([X], [Z_ref])),
                             check_save_load=True)

        # Test from_float
        float_linear = torch.nn.Linear(in_features, out_features).float()
        if use_default_observer:
            float_linear.qconfig = torch.quantization.default_dynamic_qconfig
        prepare_dynamic(float_linear)
        float_linear(X.float())
        quantized_float_linear = nnqd.Linear.from_float(float_linear)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X)

        # Smoke test extra_repr
        self.assertTrue('QuantizedLinear' in str(quantized_float_linear))


class ModuleAPITest(QuantizationTestCase):
    def test_relu(self):
        relu_module = nnq.ReLU()
        relu6_module = nnq.ReLU6()

        x = torch.arange(-10, 10, dtype=torch.float)
        y_ref = torch.relu(x)
        y6_ref = torch.nn.modules.ReLU6()(x)

        qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.qint32)
        qy = relu_module(qx)
        qy6 = relu6_module(qx)

        self.assertEqual(y_ref, qy.dequantize(),
                         message="ReLU module API failed")
        self.assertEqual(y6_ref, qy6.dequantize(),
                         message="ReLU6 module API failed")

    @no_deadline
    @given(
        batch_size=st.integers(1, 5),
        in_features=st.integers(16, 32),
        out_features=st.integers(4, 8),
        use_bias=st.booleans(),
        use_fused=st.booleans(),
        per_channel=st.booleans(),
        qengine=st.sampled_from(("qnnpack", "fbgemm"))
    )
    def test_linear_api(self, batch_size, in_features, out_features,
                        use_bias, use_fused, per_channel, qengine):
        """test API functionality for nn.quantized.linear and
        nn.intrinsic.quantized.linear_relu"""
        if qengine not in torch.backends.quantized.supported_engines:
            return
        if qengine == 'qnnpack':
            if IS_PPC or TEST_WITH_UBSAN:
                return
            per_channel = False

        with override_quantized_engine(qengine):
            W = torch.rand(out_features, in_features).float()
            if per_channel:
                scale_tensor = torch.ones(out_features, dtype=torch.double)
                zero_point_tensor = torch.zeros(out_features, dtype=torch.long)
                for i in range(len(scale_tensor)):
                    scale_tensor[i] = (i + 1.0) / 255.0
                W_q = torch.quantize_per_channel(W, scales=scale_tensor,
                                                 zero_points=zero_point_tensor,
                                                 axis=0, dtype=torch.qint8)
            else:
                W_q = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8)

            X = torch.rand(batch_size, in_features).float()
            X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
            B = torch.rand(out_features).float() if use_bias else None
            scale = 0.5
            zero_point = 3
            if use_fused:
                qlinear = nnq_fused.LinearReLU(in_features, out_features)
            else:
                qlinear = nnq.Linear(in_features, out_features)

            # Run module with default-initialized parameters.
            # This tests that the constructor is correct.
            qlinear(X_q)

            qlinear.set_weight_bias(W_q, B)
            # Simple round-trip test to ensure weight()/set_weight() API
            self.assertEqual(qlinear.weight(), W_q)
            W_pack = qlinear._packed_params._packed_params

            qlinear.scale = float(scale)
            qlinear.zero_point = int(zero_point)
            Z_q = qlinear(X_q)

            # Check if the module implementation matches calling the
            # ops directly
            if use_fused:
                Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point)
                self.assertTrue('QuantizedLinearReLU' in str(qlinear))
            else:
                Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point)
                self.assertTrue('QuantizedLinear' in str(qlinear))
            self.assertEqual(Z_ref, Z_q)

            # Test serialization of quantized Linear Module using state_dict
            model_dict = qlinear.state_dict()
            self.assertEqual(model_dict['_packed_params.weight'], W_q)
            if use_bias:
                self.assertEqual(model_dict['_packed_params.bias'], B)
            b = io.BytesIO()
            torch.save(model_dict, b)
            b.seek(0)
            loaded_dict = torch.load(b)
            for key in model_dict:
                self.assertEqual(model_dict[key], loaded_dict[key])
            if use_fused:
                loaded_qlinear = nnq_fused.LinearReLU(in_features, out_features)
            else:
                loaded_qlinear = nnq.Linear(in_features, out_features)
            loaded_qlinear.load_state_dict(loaded_dict)

            linear_unpack = torch.ops.quantized.linear_unpack
            self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
                             linear_unpack(loaded_qlinear._packed_params._packed_params))
            if use_bias:
                self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
            self.assertEqual(qlinear.scale, loaded_qlinear.scale)
            self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
            self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
            self.assertTrue(hasattr(qlinear, '_packed_params'))
            self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
            self.assertTrue(hasattr(qlinear, '_weight_bias'))
            self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))

            self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
            self.assertEqual(qlinear._weight_bias(),
                             torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))

            Z_q2 = loaded_qlinear(X_q)
            self.assertEqual(Z_q, Z_q2)

            # The below check is meant to ensure that `torch.save` and `torch.load`
            # serialization works, however it is currently broken by the following:
            # https://github.com/pytorch/pytorch/issues/24045
            #
            # Instead, we currently check that the proper exception is thrown on save.
            #
            # b = io.BytesIO()
            # torch.save(qlinear, b)
            # b.seek(0)
            # loaded = torch.load(b)
            # self.assertEqual(qlinear.weight(), loaded.weight())
            # self.assertEqual(qlinear.scale, loaded.scale)
            # self.assertEqual(qlinear.zero_point, loaded.zero_point)
            #
            with self.assertRaisesRegex(RuntimeError, r'torch.save\(\) is not currently supported'):
                b = io.BytesIO()
                torch.save(qlinear, b)

            # Test JIT
            self.checkScriptable(qlinear, list(zip([X_q], [Z_ref])),
                                 check_save_load=True)

            # Test from_float.
            float_linear = torch.nn.Linear(in_features, out_features).float()
            float_linear.qconfig = torch.quantization.default_qconfig
            torch.quantization.prepare(float_linear, inplace=True)
            float_linear(X.float())
            # Sequential allows swapping using "convert".
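            # convert() replaces the child modules of the module it is given,
            # so wrapping the float Linear in a Sequential makes it a child
            # that can be swapped for its quantized counterpart.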
            quantized_float_linear = torch.nn.Sequential(float_linear)
            quantized_float_linear = torch.quantization.convert(quantized_float_linear, inplace=True)

            # Smoke test to make sure the module actually runs
            quantized_float_linear(X_q)

            # Smoke test extra_repr
            self.assertTrue('QuantizedLinear' in str(quantized_float_linear))

    def test_quant_dequant_api(self):
        r = torch.tensor([[1., -1.], [1., -1.]], dtype=torch.float)
        scale, zero_point, dtype = 1.0, 2, torch.qint8
        # testing Quantize API
        qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
        quant_m = nnq.Quantize(scale, zero_point, dtype)
        qr2 = quant_m(r)
        self.assertEqual(qr, qr2)
        # testing Dequantize API
        rqr = qr.dequantize()
        dequant_m = nnq.DeQuantize()
        rqr2 = dequant_m(qr2)
        self.assertEqual(rqr, rqr2)

    def _test_conv_api_impl(
        self, module_name, qconv_module, conv_module, batch_size,
        in_channels_per_group, input_feature_map_size,
        out_channels_per_group, groups, kernel_size, stride, padding,
        dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
        Y_zero_point, use_bias, use_fused, use_channelwise,
    ):
        for i in range(len(kernel_size)):
            assume(input_feature_map_size[i] + 2 * padding[i]
                   >= dilation[i] * (kernel_size[i] - 1) + 1)

        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        (X, X_q, W, W_q, b) = _make_conv_test_input(
            batch_size, in_channels_per_group, input_feature_map_size,
            out_channels_per_group, groups, kernel_size, X_scale,
            X_zero_point, W_scale, W_zero_point, use_bias, use_channelwise)

        qconv_module.set_weight_bias(W_q, b)
        qconv_module.scale = Y_scale
        qconv_module.zero_point = Y_zero_point

        if use_fused:
            conv_module[0].weight.data = W
            if use_bias:
                conv_module[0].bias.data = b
        else:
            conv_module.weight.data = W
            if use_bias:
                conv_module.bias.data = b

        # Test members
        self.assertTrue(module_name in str(qconv_module))
        self.assertTrue(hasattr(qconv_module, '_packed_params'))
        self.assertTrue(hasattr(qconv_module, 'scale'))
        self.assertTrue(hasattr(qconv_module, 'zero_point'))

        # Test properties
        self.assertEqual(W_q, qconv_module.weight())
        if use_bias:
            self.assertEqual(b, qconv_module.bias())
        self.assertEqual(Y_scale, qconv_module.scale)
        self.assertEqual(Y_zero_point, qconv_module.zero_point)

        # Test forward
        Y_exp = conv_module(X)
        Y_exp = torch.quantize_per_tensor(
            Y_exp, scale=Y_scale, zero_point=Y_zero_point, dtype=torch.quint8)
        Y_act = qconv_module(X_q)

        # Make sure the results match
        # assert_array_almost_equal compares using the following formula:
        #     abs(desired - actual) < 1.5 * 10**(-decimal)
        # (https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_almost_equal.html)
        # We use decimal = 0 to ignore off-by-1 differences between reference
        # and test. Off-by-1 differences arise due to the order of round and
        # zero_point addition operation, i.e., if addition followed by round is
        # used by reference and round followed by addition is used by test, the
        # results may differ by 1.
        # For example, the result of round(2.5) + 1 is 3 while round(2.5 + 1) is
        # 4 assuming the rounding mode is round-to-nearest, ties-to-even.
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_act.int_repr().numpy(), decimal=0)

        # Test serialization of quantized Conv Module using state_dict
        model_dict = qconv_module.state_dict()
        self.assertEqual(W_q, model_dict['weight'])
        if use_bias:
            self.assertEqual(b, model_dict['bias'])
        bytes_io = io.BytesIO()
        torch.save(model_dict, bytes_io)
        bytes_io.seek(0)
        loaded_dict = torch.load(bytes_io)
        for key in loaded_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])

        loaded_qconv_module = type(qconv_module)(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, use_bias, padding_mode="zeros")
        loaded_qconv_module.load_state_dict(loaded_dict)

        self.assertTrue(dir(loaded_qconv_module) == dir(qconv_module))
        self.assertTrue(module_name in str(loaded_qconv_module))
        self.assertTrue(hasattr(loaded_qconv_module, '_packed_params'))
        self.assertTrue(hasattr(loaded_qconv_module, '_weight_bias'))

        self.assertEqual(qconv_module.weight(), loaded_qconv_module.weight())
        if use_bias:
            self.assertEqual(qconv_module.bias(), loaded_qconv_module.bias())
        self.assertEqual(qconv_module.scale, loaded_qconv_module.scale)
        self.assertEqual(qconv_module.zero_point,
                         loaded_qconv_module.zero_point)
        Y_loaded = loaded_qconv_module(X_q)
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_loaded.int_repr().numpy(), decimal=0)

        # The below check is meant to ensure that `torch.save` and `torch.load`
        # serialization works, however it is currently broken by the following:
        # https://github.com/pytorch/pytorch/issues/24045
        #
        # Instead, we currently check that the proper exception is thrown on
        # save.
        #
        # b = io.BytesIO()
        # torch.save(conv_under_test, b)
        # b.seek(0)
        # loaded_conv = torch.load(b)
        #
        # self.assertEqual(loaded_qconv_module.bias(), qconv_module.bias())
        # self.assertEqual(loaded_qconv_module.scale, qconv_module.scale)
        # self.assertEqual(loaded_qconv_module.zero_point,
        #                  qconv_module.zero_point)
        #
        with self.assertRaisesRegex(
            RuntimeError, r'torch.save\(\) is not currently supported'
        ):
            bytes_io = io.BytesIO()
            torch.save(qconv_module, bytes_io)

        # JIT testing
        self.checkScriptable(
            qconv_module, list(zip([X_q], [Y_exp])),
            check_save_load=True)

        # Test from_float
        conv_module.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(conv_module, inplace=True)
        conv_module(X.float())
        converted_qconv_module = torch.nn.Sequential(conv_module)
        torch.quantization.convert(converted_qconv_module, inplace=True)

        # Smoke test to make sure the module actually runs
        if use_bias:
            if use_fused:
                self.assertEqual(conv_module[0].bias,
                                 converted_qconv_module[0].bias())
            else:
                self.assertEqual(conv_module.bias,
                                 converted_qconv_module[0].bias())
        # Smoke test extra_repr
        self.assertTrue(module_name in str(converted_qconv_module))

    @no_deadline
    @given(batch_size=st.integers(1, 3),
           in_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           H=st.integers(4, 16),
           W=st.integers(4, 16),
           out_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           groups=st.integers(1, 4),
           kernel_h=st.integers(1, 7),
           kernel_w=st.integers(1, 7),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 2),
           X_scale=st.floats(1.2, 1.6),
           X_zero_point=st.integers(0, 4),
           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
           W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2),
           Y_scale=st.floats(4.2, 5.6),
           Y_zero_point=st.integers(0, 4),
           use_bias=st.booleans(),
           use_fused=st.booleans(),
           use_channelwise=st.booleans(),
           qengine=st.sampled_from(("qnnpack", "fbgemm")))
    def test_conv2d_api(
        self, batch_size, in_channels_per_group, H, W,
        out_channels_per_group, groups, kernel_h, kernel_w, stride_h,
        stride_w, pad_h, pad_w, dilation, X_scale, X_zero_point, W_scale,
        W_zero_point, Y_scale, Y_zero_point, use_bias, use_fused,
        use_channelwise, qengine,
    ):
        # Tests the correctness of the conv2d module.
        if qengine not in torch.backends.quantized.supported_engines:
            return
        if qengine == 'qnnpack':
            if IS_PPC or TEST_WITH_UBSAN:
                return
            use_channelwise = False

        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (H, W)
        kernel_size = (kernel_h, kernel_w)
        stride = (stride_h, stride_w)
        padding = (pad_h, pad_w)
        dilation = (dilation, dilation)

        with override_quantized_engine(qengine):
            if use_fused:
                module_name = "QuantizedConvReLU2d"
                qconv_module = nnq_fused.ConvReLU2d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode="zeros")
            else:
                module_name = "QuantizedConv2d"
                qconv_module = nnq.Conv2d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode="zeros")

            conv_module = nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode="zeros")
            if use_fused:
                relu_module = nn.ReLU()
                conv_module = nni.ConvReLU2d(conv_module, relu_module)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, padding,
                dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise)

    @given(batch_size=st.integers(1, 3),
           in_channels_per_group=st.sampled_from([2, 4, 5, 8, 16]),
           D=st.integers(3, 6),
           H=st.integers(3, 6),
           W=st.integers(3, 6),
           out_channels_per_group=st.sampled_from([2, 4, 5, 8, 16]),
           groups=st.integers(1, 4),
           kernel_d=st.integers(1, 3),
           kernel_h=st.integers(1, 3),
           kernel_w=st.integers(1, 3),
           stride_d=st.integers(1, 2),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_d=st.integers(0, 1),
           pad_h=st.integers(0, 1),
           pad_w=st.integers(0, 1),
           dilation=st.integers(1, 2),
           X_scale=st.floats(1.2, 1.6),
           X_zero_point=st.integers(0, 4),
           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
           W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2),
           Y_scale=st.floats(4.2, 5.6),
           Y_zero_point=st.integers(0, 4),
           use_bias=st.booleans(),
           use_fused=st.booleans(),
           use_channelwise=st.booleans(),
           qengine=st.sampled_from(("fbgemm",)))
    def test_conv3d_api(
        self, batch_size, in_channels_per_group, D, H, W,
        out_channels_per_group, groups, kernel_d, kernel_h, kernel_w,
        stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, dilation, X_scale,
        X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
        use_channelwise, use_fused, qengine,
    ):
        # Tests the correctness of the conv3d module.
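        # conv3d is currently only supported by the FBGEMM engine, which is
        # why the qengine strategy above samples only "fbgemm".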
        if qengine not in torch.backends.quantized.supported_engines:
            return

        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (D, H, W)
        kernel_size = (kernel_d, kernel_h, kernel_w)
        stride = (stride_d, stride_h, stride_w)
        padding = (pad_d, pad_h, pad_w)
        dilation = (dilation, dilation, dilation)

        with override_quantized_engine(qengine):
            if use_fused:
                module_name = "QuantizedConvReLU3d"
                qconv_module = nnq_fused.ConvReLU3d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode="zeros")
            else:
                module_name = "QuantizedConv3d"
                qconv_module = nnq.Conv3d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode="zeros")

            conv_module = nn.Conv3d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode="zeros")
            if use_fused:
                relu_module = nn.ReLU()
                conv_module = nni.ConvReLU3d(conv_module, relu_module)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, padding,
                dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise)

    def test_pool_api(self):
        """Tests the correctness of the pool module.

        The correctness is defined against the functional implementation.
        """
        N, C, H, W = 10, 10, 10, 3
        kwargs = {
            'kernel_size': 2,
            'stride': None,
            'padding': 0,
            'dilation': 1
        }
        scale, zero_point = 1.0 / 255, 128
        X = torch.randn(N, C, H, W, dtype=torch.float32)
        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                       dtype=torch.quint8)
        qX_expect = torch.nn.functional.max_pool2d(qX, **kwargs)

        pool_under_test = torch.nn.quantized.MaxPool2d(**kwargs)
        qX_hat = pool_under_test(qX)
        self.assertEqual(qX_expect, qX_hat)

        # JIT Testing
        self.checkScriptable(pool_under_test, list(zip([X], [qX_expect])))


if __name__ == '__main__':
    run_tests()