## Description

Implement a unified quantization backend, 'X86', for x86 platforms. It combines the advantages of FBGEMM and ONEDNN: it selects kernels during weight prepacking and hides the details from end users. It will become the default backend in place of FBGEMM. For details, please refer to this RFC: [[RFC] Unified quantization backend for x86 CPU platforms](https://github.com/pytorch/pytorch/issues/83888)

## Validation

**Correctness**

Covered by unit tests.

**Accuracy**

Running torchvision models on ImageNet shows no accuracy difference between FBGEMM and the unified X86 backend: [torchvision_accuracy_comparison_fbgemm_vs_x86.xlsx](https://github.com/pytorch/pytorch/files/9598114/torchvision_accuracy_comparison_fbgemm_vs_x86.xlsx)

**Performance**

Depends on https://github.com/pytorch/pytorch/pull/84470, which improves performance. For early PoC results, please refer to https://github.com/pytorch/pytorch/files/9399202/unified_qengine_poc_performance_bechmark.xlsx

With the two PRs combined, we collected the following data on an Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz.

Method: run multiple instances with 4 cores per instance on a whole socket, using JeMalloc and Intel OMP.

Model (throughput) | fbgemm | x86 | improvement
-- | -- | -- | --
wide_resnet101_2 | 173.5675 | 241.815 | 39.32%
resnext101_32x8d | 174.365 | 339.8175 | 94.89%
resnet50 | 573.155 | 1174.14 | 104.86%
vgg19_bn | 260.335 | 337.92 | 29.80%
vgg19 | 257.935 | 333.265 | 29.21%
inception_v3 | 601.1175 | 1309.33 | 117.82%
densenet161 | 296.645 | 435.5625 | 46.83%
mnasnet1_0 | 1216.7 | 4057.515 | 233.49%
squeezenet1_0 | 1220.085 | 5153.3875 | 322.38%
alexnet | 2294.91 | 2624.6375 | 14.37%
fbnetc_100 | 976.2825 | 3110.1825 | 218.57%
shufflenet_v2_x0_5 | 1555.76 | 3026.125 | 94.51%
spnasnet_100 | 1059.065 | 3502.0975 | 230.68%
pytorch-unet | 192.76 | 246.77 | 28.02%
acgan | 257.32 | 333.7325 | 29.70%
cgan | 7790.6925 | 7803.1025 | 0.16%
sgan | 257.565 | 338.8875 | 31.57%
se_resnet50 | 492.3725 | 916.5175 | 86.14%
vggm | 300.2875 | 316.2075 | 5.30%

Environment:
- PyTorch version: 1.13.0a0+gitcdd625b
- Is debug build: False
- CUDA used to build PyTorch: None
- ROCM used to build PyTorch: N/A
- OS: Ubuntu 20.04.3 LTS (x86_64)
- GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
- Clang version: Could not collect
- CMake version: version 3.22.5
- Libc version: glibc-2.31
- Python version: 3.9.12 (main, Jun 1 2022, 11:38:51) [GCC 7.5.0] (64-bit runtime)
- Python platform: Linux-5.11.0-27-generic-x86_64-with-glibc2.31
- Is CUDA available: False
- CUDA runtime version: No CUDA
- GPU models and configuration: No CUDA
- Nvidia driver version: No CUDA
- cuDNN version: No CUDA
- HIP runtime version: N/A
- MIOpen runtime version: N/A
- Is XNNPACK available: True

Versions of relevant libraries:
- [pip3] intel-extension-for-pytorch==1.13.0+cpu
- [pip3] numpy==1.23.3
- [pip3] pytorch-widedeep==0.3.7
- [pip3] torch==1.13.0a0+git48b423b
- [pip3] torchvision==0.14.0a0+ebb68f3
- [conda] blas 1.0 mkl
- [conda] intel-extension-for-pytorch 1.13.0+cpu pypi_0 pypi
- [conda] mkl 2021.4.0 h06a4308_640
- [conda] mkl-include 2022.1.0 pypi_0 pypi
- [conda] mkl-service 2.4.0 py39h7f8727e_0
- [conda] mkl-static 2022.1.0 pypi_0 pypi
- [conda] mkl_fft 1.3.1 py39hd3c417c_0
- [conda] mkl_random 1.2.2 py39h51133e4_0
- [conda] numpy 1.23.3 pypi_0 pypi
- [conda] numpy-base 1.22.3 py39hf524024_0
- [conda] torch 1.13.0a0+git48b423b pypi_0 pypi
- [conda] torchvision 0.14.0a0+ebb68f3 pypi_0 pypi

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84329
Approved by: https://github.com/jerryzh168
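For reference, the new backend is selected through the existing qengine API; a minimal sketch (the `'x86'` engine name is the one introduced by this PR):

```python
import torch

# Select the unified 'x86' qengine before preparing/converting a model;
# kernel selection (FBGEMM vs. ONEDNN) then happens during weight prepacking.
torch.backends.quantized.engine = 'x86'
print(torch.backends.quantized.engine)  # 'x86'
```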
r"""Importing this file includes common utility methods for checking quantized
|
|
tensors and modules.
|
|
"""
|
|
import numpy as np
|
|
import torch
|
|
from contextlib import contextmanager
|
|
from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_TSAN, TEST_WITH_UBSAN, IS_PPC, IS_MACOS, IS_WINDOWS
|
|
|
|
supported_qengines = torch.backends.quantized.supported_engines
supported_qengines.remove('none')

# Note: We currently do not run QNNPACK tests on WINDOWS and MACOS as it is flaky. Issue #29326
# QNNPACK is not supported on PPC
# QNNPACK throws ASAN heap-buffer-overflow error.
if 'qnnpack' in supported_qengines and any([IS_PPC, TEST_WITH_ASAN, TEST_WITH_TSAN, TEST_WITH_UBSAN, IS_MACOS, IS_WINDOWS]):
    supported_qengines.remove('qnnpack')

def _conv_output_shape(input_size, kernel_size, padding, stride, dilation,
                       output_padding=0):
    """Computes the output shape given convolution parameters."""
    return np.floor((input_size + 2 * padding - kernel_size - (kernel_size - 1)
                     * (dilation - 1)) / stride) + 2 * output_padding + 1

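# Worked example (illustrative, not part of the original file): a 3x3 kernel
# with padding=1, stride=2, dilation=1 on a 224-wide input gives
# floor((224 + 2*1 - 3 - 0) / 2) + 1 = 112.
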
# Quantization references
def _quantize(x, scale, zero_point, qmin=None, qmax=None, dtype=np.uint8):
    """Quantizes a numpy array."""
    if qmin is None:
        qmin = np.iinfo(dtype).min
    if qmax is None:
        qmax = np.iinfo(dtype).max
    qx = np.round(x / scale + zero_point).astype(np.int64)
    qx = np.clip(qx, qmin, qmax)
    qx = qx.astype(dtype)
    return qx

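# Illustrative example (not part of the original file): with scale=0.5 and
# zero_point=10, _quantize(np.array([0.0, 1.0]), 0.5, 10) gives
# array([10, 12], dtype=uint8), and _dequantize below recovers [0.0, 1.0] exactly.
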
def _dequantize(qx, scale, zero_point):
    """Dequantizes a numpy array."""
    x = (qx.astype(float) - zero_point) * scale
    return x

def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
    """Requantizes a numpy array, i.e., intermediate int32 or int16 values are
    converted back to the given type."""
    qx = (x * multiplier).round() + zero_point
    qx = np.clip(qx, qmin, qmax).astype(qtype)
    return qx

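# Illustrative example (not part of the original file): in quantized GEMM/conv
# the multiplier is typically input_scale * weight_scale / output_scale, so an
# int32 accumulator of 1000 with multiplier=0.02 and zero_point=5 requantizes as
#   _requantize(np.array([1000]), 0.02, 5) -> array([25], dtype=uint8)
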
def _calculate_dynamic_qparams(X, dtype, reduce_range=False, qscheme=torch.per_tensor_affine):
    """Calculate the dynamic quantization parameters (scale, zero_point)
    according to the min and max element of the tensor"""
    assert qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric)
    if qscheme == torch.per_tensor_symmetric:
        assert dtype == torch.qint8
    if isinstance(X, torch.Tensor):
        X = X.numpy()
    if dtype == torch.qint8:
        if reduce_range:
            qmin, qmax = -64, 63
        else:
            qmin, qmax = -128, 127
    else:  # dtype == torch.quint8
        if reduce_range:
            qmin, qmax = 0, 127
        else:
            qmin, qmax = 0, 255
    min_val = X.min()
    max_val = X.max()
    is_symmetric = (qscheme == torch.per_tensor_symmetric)
    if min_val == max_val:
        scale = 1.0
        zero_point = 0
    else:
        if is_symmetric:
            max_val = max(max_val, -min_val)
            min_val = -max_val
            scale = (max_val - min_val) / (qmax - qmin)
            scale = max(scale, np.finfo(np.float32).eps)
            zero_point = 0
        else:
            max_val = max(max_val, 0.0)
            min_val = min(min_val, 0.0)
            scale = (max_val - min_val) / (qmax - qmin)
            scale = max(scale, np.finfo(np.float32).eps)
            zero_point = qmin - round(min_val / scale)
            zero_point = max(qmin, zero_point)
            zero_point = min(qmax, zero_point)
    return [float(scale), int(zero_point)]

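# Illustrative example (not part of the original file): for
# X = torch.tensor([-1.0, 3.0]) with dtype=torch.quint8 (qmin=0, qmax=255),
# scale = (3.0 - (-1.0)) / 255 ≈ 0.0157 and
# zero_point = 0 - round(-1.0 / scale) = 64.
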
def _calculate_dynamic_per_channel_qparams(X, dtype):
    """Calculate the dynamic quantization parameters (scale, zero_point)
    according to the min and max element of each channel"""
    if isinstance(X, torch.Tensor):
        X = X.numpy()
    qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    n_levels = qmax - qmin
    scale = np.zeros(X.shape[0], dtype=np.float64)
    zero_point = np.zeros(X.shape[0], dtype=np.int64)
    for i in range(zero_point.shape[0]):
        # Use the statistics of channel i, not of the whole tensor,
        # so each channel gets its own (scale, zero_point).
        min_val = X[i].min()
        max_val = X[i].max()
        if min_val == max_val:
            scale[i] = 1.0
            zero_point[i] = 0
        else:
            max_val = max(max_val, 0.0)
            min_val = min(min_val, 0.0)
            scale[i] = (max_val - min_val) / n_levels
            scale[i] = max(scale[i], np.finfo(np.float32).eps)
            zero_point[i] = qmin - round(min_val / scale[i])
            zero_point[i] = max(qmin, zero_point[i])
            zero_point[i] = min(qmax, zero_point[i])

    return scale, zero_point

def _snr(x, x_hat):
    """Calculates the signal to noise ratio and returns the signal and noise
    power, as well as the SNR in dB.
    If the input is a list/tuple this function is called recursively on each
    element. The result will have the same nested structure as the inputs.

    Args:
        x, x_hat: Either a tensor or a nested list/tuple of tensors.
    Returns:
        signal, noise, SNR(in dB): Either floats or a nested list of floats
    """
    if isinstance(x, (list, tuple)):
        assert len(x) == len(x_hat)
        res = []
        for idx in range(len(x)):
            res.append(_snr(x[idx], x_hat[idx]))
        return res
    if x_hat.is_quantized:
        x_hat = x_hat.dequantize()
    if x.is_quantized:
        x = x.dequantize()
    noise = (x - x_hat).norm()
    if noise == 0:
        return 0.0, float('inf'), float('inf')
    signal = x.norm()
    snr = signal / noise
    snr_db = 20 * snr.log10()
    return signal, noise, snr_db

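# Illustrative usage (not part of the original file): compare a tensor against
# a quantized copy of itself.
#   x = torch.randn(100)
#   xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)
#   signal, noise, snr_db = _snr(x, xq)
# A larger snr_db means quantization introduced less error.
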
@contextmanager
def override_quantized_engine(qengine):
    previous = torch.backends.quantized.engine
    torch.backends.quantized.engine = qengine
    try:
        yield
    finally:
        torch.backends.quantized.engine = previous

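# Illustrative usage (not part of the original file): run a block under a
# specific qengine; the previous engine is restored afterwards, even on error.
#   with override_quantized_engine('fbgemm'):
#       assert qengine_is_fbgemm()
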
@contextmanager
def override_cpu_allocator_for_qnnpack(qengine_is_qnnpack):
    try:
        if qengine_is_qnnpack:
            torch._C._set_default_mobile_cpu_allocator()
        yield
    finally:
        if qengine_is_qnnpack:
            torch._C._unset_default_mobile_cpu_allocator()

# TODO: Update all quantization tests to use this decorator.
# Currently for some of the tests it seems to have inconsistent params
# for fbgemm vs qnnpack.
def override_qengines(qfunction):
    def test_fn(*args, **kwargs):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                # qfunction should not return anything.
                qfunction(*args, **kwargs)
    return test_fn

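# Illustrative usage (not part of the original file): the decorator runs the
# same test body once per engine in supported_qengines.
#   @override_qengines
#   def test_linear(self):
#       ...  # exercised under each supported engine
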
def qengine_is_fbgemm():
    return torch.backends.quantized.engine == 'fbgemm'

def qengine_is_qnnpack():
    return torch.backends.quantized.engine == 'qnnpack'

def qengine_is_onednn():
    return torch.backends.quantized.engine == 'onednn'

def qengine_is_x86():
    return torch.backends.quantized.engine == 'x86'

# Helper function used to simulate per-channel fake-quant against any axis
def _permute_to_axis_zero(X, axis):
    new_axis_list = list(range(X.dim()))
    new_axis_list[axis] = 0
    new_axis_list[0] = axis
    y = X.permute(tuple(new_axis_list))
    return y, new_axis_list

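# Illustrative example (not part of the original file): for a 4-D tensor and
# axis=1, new_axis_list is [1, 0, 2, 3]. The permutation is a swap of two axes,
# so it is its own inverse: applying it again restores the original layout.
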
# Reference method for fake quantize
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_reference(X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    dtype = X.dtype
    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
    res = torch.zeros_like(X)

    for i in range(X.size()[0]):
        res[i] = (torch.clamp(torch.round(X[i] * (1.0 / per_channel_scale[i]) +
                  per_channel_zero_point[i]), quant_min, quant_max) - per_channel_zero_point[i]) * per_channel_scale[i]

    out = res.permute(tuple(permute_axis_list))
    return out.to(dtype)

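# Illustrative example (not part of the original file): with scale=0.5,
# zero_point=0, quant_min=0, quant_max=255, an input value of 1.26 maps to
# round(1.26 / 0.5) = 3 after clamping, then back to 3 * 0.5 = 1.5.
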
# Reference method for the gradient of the fake quantize operator
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_grad_reference(dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    dtype = X.dtype
    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
    Xq = torch.zeros_like(X)
    for i in range(X.size()[0]):
        Xq[i] = torch.round(X[i] * (1.0 / per_channel_scale[i]) + per_channel_zero_point[i])
    Xq = Xq.permute(tuple(permute_axis_list))
    mask = (Xq >= quant_min) * (Xq <= quant_max)
    res = torch.zeros_like(dY)
    res[mask] = dY[mask]
    return res.to(dtype)

def to_tensor(X, device):
    if not isinstance(X, torch.Tensor):
        X = torch.tensor(X)
    else:
        X = X.clone().detach()
    return X.to(device=torch.device(device), dtype=torch.float32)