from __future__ import absolute_import, division, print_function, unicode_literals

import torch
import torch.jit
import numpy as np
import unittest
from common_utils import run_tests


# Reference method for quantizing a tensor.
def _fake_quantize_per_tensor_affine_reference(X, scale, zero_point, num_bits):
    quant_min, quant_max = 0, 2 ** num_bits - 1
    res = (np.clip(np.round(X / scale) + zero_point, quant_min, quant_max) - zero_point) * scale
    res = res.reshape(X.shape)
    return res


# Reference method for the gradient of the quantizer.
def _fake_quantize_per_tensor_affine_grad_reference(X, dY, scale, zero_point, num_bits):
    Xq = np.round(X / scale) + zero_point
    quant_min, quant_max = 0, 2 ** num_bits - 1
    mask = np.logical_and(Xq >= quant_min, Xq <= quant_max)
    # Straight-through estimator: pass the gradient through where the quantized
    # value falls inside [quant_min, quant_max], and zero it elsewhere.
    res = np.where(mask, dY, 0.0)
    return res


NP_RANDOM_SEED = 19


class TestFakeQuantizePerTensorAffine(unittest.TestCase):
    """Tests for the FakeQuantizePerTensorAffine forward and backward ops."""

    def test_forward(self):
        """Tests the forward path of the FakeQuantizePerTensorAffine op."""
        np.random.seed(NP_RANDOM_SEED)
        fake_quantize_per_tensor_affine_forward = \
            torch.ops.quantized.fake_quantize_per_tensor_affine_forward

        scale = 3
        zero_point = 2
        num_bits = 8
        X = np.random.rand(20, 20) * 125
        X_torch = torch.from_numpy(X).float()
        Y = _fake_quantize_per_tensor_affine_reference(X, scale, zero_point, num_bits)
        Y_prime = fake_quantize_per_tensor_affine_forward(
            X=X_torch, scale=scale, zero_point=zero_point, num_bits=num_bits,
            quant_delay=0, iter=0)
        tolerance = 1e-6
        np.testing.assert_allclose(Y, Y_prime, rtol=tolerance, atol=tolerance)

    def test_backward(self):
        """Tests the backward method.

        Note that this runs the reference quantization and thus the errors
        might be originating there.
        """
        np.random.seed(NP_RANDOM_SEED)
        fake_quantize_per_tensor_affine_backward = \
            torch.ops.quantized.fake_quantize_per_tensor_affine_backward

        scale = 3
        zero_point = 2
        num_bits = 8
        X = np.random.rand(20, 20) * 125
        Y = _fake_quantize_per_tensor_affine_reference(X, scale, zero_point, num_bits)
        dY = Y - X  # Fake gradient
        dX = _fake_quantize_per_tensor_affine_grad_reference(X, dY, scale, zero_point, num_bits)
        X_torch = torch.from_numpy(X).float()
        dY_torch = torch.from_numpy(dY).float()
        dX_prime = fake_quantize_per_tensor_affine_backward(
            X=X_torch, dY=dY_torch, scale=scale, zero_point=zero_point,
            num_bits=num_bits, quant_delay=0, iter=0)
        tolerance = 1e-6
        np.testing.assert_allclose(dX, dX_prime, rtol=tolerance, atol=tolerance)

    def test_numerical_consistency(self):
        """Checks the fake quantization op against quantize_linear + dequantize."""
        np.random.seed(NP_RANDOM_SEED)
        fake_quantize_per_tensor_affine_forward = \
            torch.ops.quantized.fake_quantize_per_tensor_affine_forward

        scale = 3
        zero_point = 2
        num_bits = 8
        X = np.random.rand(20, 20) * 125
        X_torch = torch.from_numpy(X).float()
        Y = X_torch.quantize_linear(scale, zero_point).dequantize()
        Y_prime = fake_quantize_per_tensor_affine_forward(
            X=X_torch, scale=scale, zero_point=zero_point, num_bits=num_bits,
            quant_delay=0, iter=0)
        tolerance = 1e-6
        np.testing.assert_allclose(Y, Y_prime, rtol=tolerance, atol=tolerance)


if __name__ == '__main__':
    run_tests()