Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00
Summary: Quantized SUM + ReLU (fused). The implementation is the same as the one in DNNLOWP.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19319
Reviewed By: jianyuh
Differential Revision: D14866442
fbshipit-source-id: c8c737a37e35b6ce3c1c2077c07546aba16e0612
126 lines | 4.1 KiB | Python
from __future__ import absolute_import, division, print_function, unicode_literals

import torch
import torch.jit
import numpy as np
import unittest

from common_utils import TestCase, run_tests, skipIfNotRegistered


def canonical(graph):
    """Returns the string form of a canonicalized TorchScript graph."""
    return str(torch._C._jit_pass_canonicalize(graph))


def _quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Quantizes a numpy array: clamp(round(x / scale + zero_point), qmin, qmax)."""
    qx = np.round(x / scale + zero_point)
    qx = np.clip(qx, qmin, qmax).astype(np.uint8)
    return qx


def _dequantize(qx, scale, zero_point):
    """Dequantizes a numpy array: (qx - zero_point) * scale."""
    x = (qx.astype(np.float64) - zero_point) * scale
    return x


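# A minimal round-trip sketch of the two helpers above, using the same values as
# TestQuantized.test_quantize below. The helper name _example_roundtrip is ours,
# for illustration only; it is kept as a function so importing this module has
# no side effects.
def _example_roundtrip():
    x = np.array([-0.01, 0.01, -0.04, 0.05])
    qx = _quantize(x, scale=0.01, zero_point=5)        # -> [4, 6, 1, 10] as uint8
    x_hat = _dequantize(qx, scale=0.01, zero_point=5)  # -> [-0.01, 0.01, -0.04, 0.05]
    np.testing.assert_almost_equal(x_hat, x)

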
@skipIfNotRegistered("Relu_ENGINE_DNNLOWP",
                     "fbgemm-based Caffe2 ops are not linked")
class TestQuantized(TestCase):
    def test_relu(self):
        # A quantized tensor is passed around as a (uint8 tensor, scale, zero_point) tuple.
        a = (torch.tensor([4, 6, 1, 10], dtype=torch.uint8), 0.01, 5)
        r = torch.ops.c10.quantized_relu(a)
        # Quantized ReLU clamps entries below the zero point (5) up to the zero point.
        np.testing.assert_equal(r[0].numpy(),
                                torch.tensor([5, 6, 5, 10], dtype=torch.uint8).numpy())
        np.testing.assert_almost_equal(0.01, r[1])
        self.assertEqual(5, r[2])

    def test_quantize(self):
        a = (torch.tensor([4, 6, 1, 10], dtype=torch.uint8), 0.01, 5)
        r = torch.ops.c10.dequantize(a)
        np.testing.assert_almost_equal(r.numpy(), [-0.01, 0.01, -0.04, 0.05])
        # Quantize back with default arguments.
        q_def = torch.ops.c10.quantize(r)
        # Quantize back with the scale and zero point specified explicitly.
        q = torch.ops.c10.quantize(r, scale=0.01, zero_point=5)
        np.testing.assert_equal(q[0].numpy(), a[0].numpy())
        np.testing.assert_almost_equal(q[1], a[1])
        self.assertEqual(q[2], a[2])

    def test_script(self):
        @torch.jit.script
        def foo(x):
            # type: (Tuple[Tensor, float, int]) -> Tuple[Tensor, float, int]
            return torch.ops.c10.quantized_relu(x)
        self.assertExpectedInline(canonical(foo.graph), '''\
graph(%x : (Tensor, float, int)):
  %1 : (Tensor, float, int) = c10::quantized_relu(%x)
  return (%1)
''')


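# A small sketch (ours) of why quantized ReLU is a clamp at the zero point: the
# real value 0.0 quantizes to zero_point, so max(x, 0) in float corresponds to
# max(q, zero_point) on the uint8 representation. _ref_quantized_relu is a
# hypothetical helper, not part of the op API.
def _ref_quantized_relu(qx, zero_point):
    return np.maximum(qx, zero_point)

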
class TestQuantizedOps(unittest.TestCase):
    def test_qrelu(self):
        """Tests the correctness of the quantized::relu op."""
        relu = torch.ops.quantized.relu

        X = torch.arange(-5, 5, dtype=torch.float)
        scale = 2.0
        zero_point = 1
        qX = X.quantize_linear(scale=scale, zero_point=zero_point)
        print("X:\n{}".format(X))
        print("\nQuantized:\n{}\nFake:\n{}".format(
            qX.int_repr(), _quantize(X.numpy(), scale, zero_point)))

        # Ground truth: ReLU in float, then quantize with the same qparams.
        Y = X.numpy().copy()
        Y[Y < 0] = 0
        qY = _quantize(Y, scale, zero_point)
        qY_hat = relu(qX)
        np.testing.assert_equal(qY, qY_hat.int_repr().numpy())

"""Tests the correctness of the quantized::sum_relu op."""
|
|
def test_qsumrelu_same_qparams(self):
|
|
sum_relu = torch.ops.quantized.sum_relu
|
|
|
|
A = torch.arange(-25, 25, dtype=torch.float)
|
|
B = torch.arange(-25, 25, dtype=torch.float)
|
|
scale = 2.0
|
|
zero_point = 127
|
|
qA = A.quantize_linear(scale=scale, zero_point=zero_point)
|
|
qB = A.quantize_linear(scale=scale, zero_point=zero_point)
|
|
|
|
# Sum + ReLU ground truth
|
|
C = (qA.dequantize() + qB.dequantize()).numpy()
|
|
C[C < 0] = 0
|
|
qC = _quantize(C, scale, zero_point)
|
|
|
|
qC_hat = sum_relu(qA, qB, scale=scale, zero_point=zero_point)
|
|
np.testing.assert_equal(qC, qC_hat.int_repr())
|
|
|
|
"""Tests the correctness of the quantized::sum_relu op."""
|
|
def test_qsumrelu_different_qparams(self):
|
|
sum_relu = torch.ops.quantized.sum_relu
|
|
|
|
A = torch.arange(-25, 25, dtype=torch.float)
|
|
B = torch.arange(-25, 25, dtype=torch.float)
|
|
scale_A = 3.0
|
|
zero_point_A = 7
|
|
scale_B = 5.0
|
|
zero_point_B = 127
|
|
|
|
scale_C = 0.5
|
|
zero_point_C = 5
|
|
|
|
qA = A.quantize_linear(scale=scale_A, zero_point=zero_point_A)
|
|
qB = A.quantize_linear(scale=scale_B, zero_point=zero_point_B)
|
|
|
|
# Sum + ReLU ground truth
|
|
C = (qA.dequantize() + qB.dequantize()).numpy()
|
|
C[C < 0] = 0
|
|
qC = _quantize(C, scale_C, zero_point_C)
|
|
|
|
qC_hat = sum_relu(qA, qB, scale=scale_C, zero_point=zero_point_C)
|
|
np.testing.assert_equal(qC, qC_hat.int_repr())
|
|
|
|
|
|
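# A numpy sketch (ours) of the reference semantics the sum_relu tests check:
# dequantize each input with its own qparams, add, apply ReLU, then requantize
# with the output qparams. _ref_sum_relu is a hypothetical helper mirroring the
# ground-truth computation above, not the actual fused DNNLOWP kernel, which
# stays in integer arithmetic.
def _ref_sum_relu(qa, scale_a, zp_a, qb, scale_b, zp_b, scale_out, zp_out):
    a = _dequantize(qa, scale_a, zp_a)
    b = _dequantize(qb, scale_b, zp_b)
    c = np.maximum(a + b, 0)  # fused sum + ReLU in float
    return _quantize(c, scale_out, zp_out)

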
if __name__ == '__main__':
    run_tests()