Summary:
Benchmark (run in IPython, since the script below uses the `%timeit` magic) on an NVIDIA GTX 1650 GPU and an AMD Ryzen Threadripper 3970X CPU:
```python
import torch
print(torch.__version__)

# Warm-up: initialize the CUDA context and allocator before timing
for i in range(1000):
    torch.randn(1024 * 128, device='cuda')

def cuda(e):
    a = torch.randn(2 ** e, 32, device='cuda')
    s = torch.randn(32, device='cuda')
    z = torch.randn(32, device='cuda')
    torch.cuda.synchronize()
    %timeit torch.fake_quantize_per_channel_affine(a, s, z, 1, -999, 999); torch.cuda.synchronize()

def cpu(e):
    a = torch.randn(2 ** e, 32, device='cpu')
    s = torch.randn(32, device='cpu')
    z = torch.randn(32, device='cpu')
    %timeit torch.fake_quantize_per_channel_affine(a, s, z, 1, -999, 999);

for i in range(10, 24):
    cuda(i)
print()
for i in range(10, 32):
    cpu(i)
```
Before
```
1.5.0a0+9bc922d
849 µs ± 44.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
817 µs ± 30.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
814 µs ± 2.93 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.11 ms ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.19 ms ± 4.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.6 ms ± 5.58 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
2.44 ms ± 14.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.14 ms ± 2.55 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.41 ms ± 2.46 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
13.9 ms ± 2.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
26.9 ms ± 254 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
52.6 ms ± 260 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
104 ms ± 176 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
207 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
249 µs ± 158 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
420 µs ± 230 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
766 µs ± 391 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.45 ms ± 574 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
2.84 ms ± 34.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.69 ms ± 83 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.29 ms ± 2.58 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.32 ms ± 13.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
17.4 ms ± 38.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
47.5 ms ± 264 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
187 ms ± 1.19 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
379 ms ± 5.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
652 ms ± 11.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.22 s ± 4.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.34 s ± 8.77 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4.56 s ± 7.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
8.97 s ± 33.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
17.8 s ± 32.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
35.2 s ± 167 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
After
```
1.5.0a0+a7ec8cc
92.5 µs ± 2.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
97.7 µs ± 469 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
109 µs ± 4.73 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
119 µs ± 6.17 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
146 µs ± 1.84 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
211 µs ± 2.45 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
347 µs ± 4.18 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
624 µs ± 14.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.17 ms ± 16.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
2.25 ms ± 48.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.43 ms ± 220 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
8.51 ms ± 44.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
16.9 ms ± 30.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
33.7 ms ± 7.64 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
201 µs ± 234 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
285 µs ± 465 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
287 µs ± 214 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
287 µs ± 221 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
287 µs ± 761 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
347 µs ± 399 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
675 µs ± 213 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.34 ms ± 643 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
4.82 ms ± 34.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
10.7 ms ± 88.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
20.3 ms ± 25.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
39.4 ms ± 242 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
78.8 ms ± 2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
153 ms ± 786 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
285 ms ± 911 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
541 ms ± 1.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.03 s ± 1.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.97 s ± 8.59 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.81 s ± 10.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
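Reading the two runs at matching sizes (assuming both sweeps cover the same range of `e`), the largest CUDA case improves from roughly 207 ms to 33.7 ms per call (about 6x), and the largest CPU case reported in both runs from roughly 35.2 s to 3.81 s (about 9x).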
Fixes https://github.com/pytorch/pytorch/issues/33647
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33772
Differential Revision: D20112531
Pulled By: ngimel
fbshipit-source-id: f90e3ef1b5be8276851637f3e1251cb8f1af411f
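For readers skimming the numbers above: `fake_quantize_per_channel_affine(X, scale, zero_point, axis, quant_min, quant_max)` simulates per-channel quantization in floating point by clamping `round(X / scale + zero_point)` to `[quant_min, quant_max]` and mapping the result back through the same scale and zero point; this is the same reference formula the tests below use. A minimal standalone sketch of that behaviour (not part of this PR or of the test file; shapes and quantization parameters are arbitrary, chosen purely for illustration):

```python
import torch

# Per-channel fake quantization along axis=1 (32 channels), checked against the
# reference formula used by the test file below. Shapes/qparams are arbitrary.
x = torch.randn(4, 32)
scale = torch.rand(32) + 0.1                 # one positive scale per channel
zero_point = torch.randint(0, 10, (32,))     # one integer zero point per channel
quant_min, quant_max = 0, 255

y = torch.fake_quantize_per_channel_affine(x, scale, zero_point, 1, quant_min, quant_max)
ref = (torch.clamp(torch.round(x * (1.0 / scale) + zero_point),
                   quant_min, quant_max) - zero_point) * scale
assert torch.allclose(y, ref, atol=1e-6)
```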
Test file (315 lines · 15 KiB · Python):
```python
import torch
import torch.cuda
import torch.jit
import numpy as np
from hypothesis import given
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.quantization import FakeQuantize
from torch.quantization import default_observer, default_per_channel_weight_observer
import io
import unittest

# Reference method for fake quantize
def _fake_quantize_per_tensor_affine_reference(X, scale, zero_point, quant_min, quant_max):
    res = (torch.clamp(torch.round(X * (1.0 / scale) + zero_point), quant_min, quant_max) - zero_point) * scale
    return res

# Reference method for the gradient of the fake quantize operator
def _fake_quantize_per_tensor_affine_grad_reference(dY, X, scale, zero_point, quant_min, quant_max):
    Xq = torch.round(X * (1.0 / scale) + zero_point)
    mask = (Xq >= quant_min) * (Xq <= quant_max)
    res = torch.zeros_like(dY)
    res[mask] = dY[mask]
    return res

# Helper function used to simulate per-channel fake-quant against any axis
def _permute_to_axis_zero(X, axis):
    new_axis_list = list(range(X.dim()))
    new_axis_list[axis] = 0
    new_axis_list[0] = axis
    y = X.permute(tuple(new_axis_list))
    return y, new_axis_list

# Reference method for per-channel fake quantize
def _fake_quantize_per_channel_affine_reference(X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    X, permute_axis_list = _permute_to_axis_zero(X, axis)
    res = torch.zeros_like(X)

    for i in range(X.size()[0]):
        res[i] = (torch.clamp(torch.round(X[i] * (1.0 / per_channel_scale[i]) +
                  per_channel_zero_point[i]), quant_min, quant_max) - per_channel_zero_point[i]) * per_channel_scale[i]

    out = res.permute(tuple(permute_axis_list))
    return out

# Reference method for the gradient of the per-channel fake quantize operator
def _fake_quantize_per_channel_affine_grad_reference(dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    X, permute_axis_list = _permute_to_axis_zero(X, axis)
    Xq = torch.zeros_like(X)
    for i in range(X.size()[0]):
        Xq[i] = torch.round(X[i] * (1.0 / per_channel_scale[i]) + per_channel_zero_point[i])
    Xq = Xq.permute(tuple(permute_axis_list))
    mask = (Xq >= quant_min) * (Xq <= quant_max)
    res = torch.zeros_like(dY)
    res[mask] = dY[mask]
    return res

def to_tensor(X, device):
    return torch.tensor(X).to(device=torch.device(device), dtype=torch.float32)

NP_RANDOM_SEED = 19
tolerance = 1e-6


class TestFakeQuantizePerTensor(TestCase):

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    def test_forward_per_tensor(self, device, X):
        r"""Tests the forward path of the FakeQuantizePerTensorAffine op.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        Y = _fake_quantize_per_tensor_affine_reference(X.cpu(), scale, zero_point, quant_min, quant_max)
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skip("temporarily disable the test")
    def test_backward_per_tensor(self, device, X):
        r"""Tests the backward method.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        Y = _fake_quantize_per_tensor_affine_reference(X.cpu(), scale, zero_point, quant_min, quant_max)
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        dout = torch.rand(X.shape, dtype=torch.float).to(device)
        dX = _fake_quantize_per_tensor_affine_grad_reference(
            dout, X, scale, zero_point, quant_min, quant_max)
        Y_prime.backward(dout)
        np.testing.assert_allclose(dX.cpu(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    # https://github.com/pytorch/pytorch/issues/30604
    @unittest.skip("temporarily disable the test")
    def test_numerical_consistency_per_tensor(self, device, X):
        r"""Compares numerical consistency between the CPU quantize/dequantize ops and the CPU fake quantize op.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        # quantize_per_tensor and dequantize are only implemented on CPU
        Y = torch.dequantize(torch.quantize_per_tensor(X.cpu(), scale, zero_point, torch_type))
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                       qparams=hu.qparams(dtypes=[torch.quint8])),
           )
    def test_fq_module(self, device, X):
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        fq_module = torch.quantization.default_fake_quant().to(device)
        Y_prime = fq_module(X)
        assert fq_module.scale is not None
        assert fq_module.zero_point is not None
        Y = _fake_quantize_per_tensor_affine_reference(X, fq_module.scale, fq_module.zero_point, quant_min, quant_max)
        np.testing.assert_allclose(Y.cpu().detach().numpy(), Y_prime.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

        # Test backward
        dout = torch.rand(X.shape, dtype=torch.float, device=device)
        Y_prime.backward(dout)
        dX = _fake_quantize_per_tensor_affine_grad_reference(dout, X, fq_module.scale, fq_module.zero_point, quant_min, quant_max)
        np.testing.assert_allclose(dX.cpu().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

    def test_fq_serializable(self):
        observer = default_observer
        quant_min = 0
        quant_max = 255
        fq_module = FakeQuantize(observer, quant_min, quant_max)
        X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)
        y_ref = fq_module(X)
        state_dict = fq_module.state_dict()
        self.assertEqual(state_dict['scale'], 0.094488)
        self.assertEqual(state_dict['zero_point'], 53)
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        loaded_fq_module = FakeQuantize(observer, quant_min, quant_max)
        loaded_fq_module.load_state_dict(loaded_dict)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_fq_module.state_dict()[key])

        self.assertEqual(loaded_fq_module.calculate_qparams(), fq_module.calculate_qparams())

    def test_fake_quant_control(self):
        torch.manual_seed(42)
        X = torch.rand(20, 10, dtype=torch.float32)
        fq_module = torch.quantization.default_fake_quant()
        # Output of fake quant is not identical to input
        Y = fq_module(X)
        self.assertNotEqual(Y, X)
        torch.quantization.disable_fake_quant(fq_module)
        X = torch.rand(20, 10, dtype=torch.float32)
        Y = fq_module(X)
        # Fake quant is disabled, output is identical to input
        self.assertEqual(Y, X)
        scale = fq_module.scale
        zero_point = fq_module.zero_point
        torch.quantization.disable_observer(fq_module)
        torch.quantization.enable_fake_quant(fq_module)
        X = 10.0 * torch.rand(20, 10, dtype=torch.float32) - 5.0
        Y = fq_module(X)
        self.assertNotEqual(Y, X)
        # Observer is disabled, scale and zero-point do not change
        self.assertEqual(fq_module.scale, scale)
        self.assertEqual(fq_module.zero_point, zero_point)
        torch.quantization.enable_observer(fq_module)
        Y = fq_module(X)
        self.assertNotEqual(Y, X)
        # Observer is enabled, scale and zero-point are different
        self.assertNotEqual(fq_module.scale, scale)
        self.assertNotEqual(fq_module.zero_point, zero_point)


class TestFakeQuantizePerChannel(TestCase):

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_forward_per_channel(self, device, X):
        r"""Tests the forward path of the FakeQuantizePerChannelAffine op.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int64, device=device)
        Y = _fake_quantize_per_channel_affine_reference(X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max)
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_backward_per_channel(self, device, X):
        r"""Tests the backward method.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int64, device=device)
        X.requires_grad_()
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        dout = torch.rand(X.shape, dtype=torch.float).to(device)
        dX = _fake_quantize_per_channel_affine_grad_reference(
            dout, X, scale, zero_point, axis, quant_min, quant_max)
        Y_prime.backward(dout)
        np.testing.assert_allclose(dX.cpu().detach().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skip("temporarily disable the test")
    def test_numerical_consistency_per_channel(self, device, X):
        r"""Compares numerical consistency between the CPU quantize/dequantize ops and the CPU fake quantize op.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int64, device=device)
        # quantize_per_channel and dequantize are only implemented on CPU
        Y = torch.dequantize(torch.quantize_per_channel(X.cpu(), scale.cpu(), zero_point.cpu(), axis, torch_type))
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(2, 5,),
                                   qparams=hu.qparams(dtypes=torch.qint8)))
    def test_fq_module(self, device, X):
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        fq_module = FakeQuantize(default_per_channel_weight_observer, quant_min, quant_max, ch_axis=axis).to(device)
        Y_prime = fq_module(X)
        assert fq_module.scale is not None
        assert fq_module.zero_point is not None
        Y = _fake_quantize_per_channel_affine_reference(X, fq_module.scale,
                                                        fq_module.zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y.cpu().detach().numpy(), Y_prime.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

        # Test backward
        dout = torch.rand(X.shape, dtype=torch.float, device=device)
        Y_prime.backward(dout)
        dX = _fake_quantize_per_channel_affine_grad_reference(dout, X, fq_module.scale,
                                                              fq_module.zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(dX.cpu().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

    def test_fq_serializable(self):
        observer = default_per_channel_weight_observer
        quant_min = -128
        quant_max = 127
        fq_module = FakeQuantize(observer, quant_min, quant_max)
        X = torch.tensor([[-5, -3.5, -2, 0, 3, 5, 7], [1, 3, 2, 5, 6.5, 8, 10]], dtype=torch.float32)
        y_ref = fq_module(X)
        state_dict = fq_module.state_dict()
        self.assertEqual(state_dict['scale'], [0.054902, 0.078431])
        self.assertEqual(state_dict['zero_point'], [0, 0])
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_dict[key])


if __name__ == '__main__':
    run_tests()
```