mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Quant] Add unified x86 quant backend (#84329)
## Description Implement unified quantization backend 'X86' for x86 platforms. It combines the advantages of FBGEMM and ONEDNN. It selects kernels during weight prepacking and hide the details from end users. It will be the default backend in place of FBGEMM. For details, please refer to this RFC: [[RFC] Unified quantization backend for x86 CPU platforms](https://github.com/pytorch/pytorch/issues/83888) ## Validation **Correctness** Covered by UT **Accuracy** By running torchvision models on imagenet, no accuracy difference is found between FBGEMM and the unified X86 backend: [torchvision_accuracy_comparison_fbgemm_vs_x86.xlsx](https://github.com/pytorch/pytorch/files/9598114/torchvision_accuracy_comparison_fbgemm_vs_x86.xlsx) **Performance** Depends on https://github.com/pytorch/pytorch/pull/84470 which improves performance. For early PoC results, please refer to https://github.com/pytorch/pytorch/files/9399202/unified_qengine_poc_performance_bechmark.xlsx With the two PRs combined, we collected some data on Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz Method: Run multi-instances with 4 cores per instance on whole socket. Using JeMalloc and Intel OMP. Models/throughput | fbgemm | x86 | improvement -- | -- | -- | -- wide_resnet101_2 | 173.5675 | 241.815 | 39.32% resnext101_32x8d | 174.365 | 339.8175 | 94.89% resnet50 | 573.155 | 1174.14 | 104.86% vgg19_bn | 260.335 | 337.92 | 29.80% vgg19 | 257.935 | 333.265 | 29.21% inception_v3 | 601.1175 | 1309.33 | 117.82% densenet161 | 296.645 | 435.5625 | 46.83% mnasnet1_0 | 1216.7 | 4057.515 | 233.49% squeezenet1_0 | 1220.085 | 5153.3875 | 322.38% alexnet | 2294.91 | 2624.6375 | 14.37% fbnetc_100 | 976.2825 | 3110.1825 | 218.57% shufflenet_v2_x0_5 | 1555.76 | 3026.125 | 94.51% spnasnet_100 | 1059.065 | 3502.0975 | 230.68% pytorch-unet | 192.76 | 246.77 | 28.02% acgan | 257.32 | 333.7325 | 29.70% cgan | 7790.6925 | 7803.1025 | 0.16% sgan | 257.565 | 338.8875 | 31.57% se_resnet50 | 492.3725 | 916.5175 | 86.14% vggm | 300.2875 | 316.2075 | 5.30% Environment: - PyTorch version: 1.13.0a0+gitcdd625b - Is debug build: False - CUDA used to build PyTorch: None - ROCM used to build PyTorch: N/A - OS: Ubuntu 20.04.3 LTS (x86_64) - GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0 - Clang version: Could not collect - CMake version: version 3.22.5 - Libc version: glibc-2.31 - Python version: 3.9.12 (main, Jun 1 2022, 11:38:51) [GCC 7.5.0] (64-bit runtime) - Python platform: Linux-5.11.0-27-generic-x86_64-with-glibc2.31 - Is CUDA available: False - CUDA runtime version: No CUDA - GPU models and configuration: No CUDA - Nvidia driver version: No CUDA - cuDNN version: No CUDA - HIP runtime version: N/A - MIOpen runtime version: N/A - Is XNNPACK available: True Versions of relevant libraries: - [pip3] intel-extension-for-pytorch==1.13.0+cpu - [pip3] numpy==1.23.3 - [pip3] pytorch-widedeep==0.3.7 - [pip3] torch==1.13.0a0+git48b423b - [pip3] torchvision==0.14.0a0+ebb68f3 - [conda] blas 1.0 mkl - [conda] intel-extension-for-pytorch 1.13.0+cpu pypi_0 pypi - [conda] mkl 2021.4.0 h06a4308_640 - [conda] mkl-include 2022.1.0 pypi_0 pypi - [conda] mkl-service 2.4.0 py39h7f8727e_0 - [conda] mkl-static 2022.1.0 pypi_0 pypi - [conda] mkl_fft 1.3.1 py39hd3c417c_0 - [conda] mkl_random 1.2.2 py39h51133e4_0 - [conda] numpy 1.23.3 pypi_0 pypi - [conda] numpy-base 1.22.3 py39hf524024_0 - [conda] torch 1.13.0a0+git48b423b pypi_0 pypi - [conda] torchvision 0.14.0a0+ebb68f3 pypi_0 pypi Pull Request resolved: https://github.com/pytorch/pytorch/pull/84329 Approved by: https://github.com/jerryzh168
This commit is contained in:
parent
d542aab5c1
commit
3a3e2002d8
|
|
@ -309,8 +309,11 @@ const std::vector<at::QEngine>& Context::supportedQEngines() {
|
|||
#ifdef USE_FBGEMM
|
||||
if (fbgemm::fbgemmSupportedCPU()) {
|
||||
engines.push_back(at::kFBGEMM);
|
||||
// The X86 qengine is available if and only if FBGEMM is available
|
||||
engines.push_back(at::kX86);
|
||||
}
|
||||
#endif
|
||||
|
||||
return engines;
|
||||
}();
|
||||
return supported_qengines;
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#include <ATen/Tensor.h>
|
||||
#include <ATen/native/quantized/PackedParams.h>
|
||||
#include <ideep.hpp>
|
||||
#include <cpuinfo.h>
|
||||
|
||||
using PrimitiveCacheKey = std::tuple<
|
||||
double, // input_scale
|
||||
|
|
@ -349,6 +350,50 @@ static void try_reorder(
|
|||
t.set_scale(scales);
|
||||
}
|
||||
}
|
||||
|
||||
// ONEDNN requires symmetric quantization of weight
|
||||
// Use this util function to check.
|
||||
static bool is_weight_symmetric_quant(
|
||||
const at::Tensor& weight,
|
||||
bool is_transposed_conv) {
|
||||
bool is_symmetric = true;
|
||||
const auto qtype = weight.qscheme();
|
||||
if (qtype == c10::kPerTensorAffine) {
|
||||
is_symmetric &= (weight.q_zero_point() == 0);
|
||||
} else if (qtype == c10::kPerChannelAffine) {
|
||||
if (is_transposed_conv) {
|
||||
// This case is currently not supported in PyTorch
|
||||
// but we do not want to raise an error in this util function.
|
||||
is_symmetric = false;
|
||||
} else {
|
||||
auto output_channels = weight.size(0);
|
||||
for (int i = 0; i < output_channels; ++i) {
|
||||
auto zp = weight.q_per_channel_zero_points()[i].item<int32_t>();
|
||||
is_symmetric &= (zp == 0);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// This case is currently not supported in PyTorch
|
||||
// but we do not want to raise an error in this util function.
|
||||
is_symmetric = false;
|
||||
}
|
||||
return is_symmetric;
|
||||
}
|
||||
|
||||
// Check if onednn should be used w.r.t fbgemm
|
||||
static bool should_use_onednn_quant(
|
||||
const at::Tensor& weight,
|
||||
bool is_transposed_conv,
|
||||
int groups,
|
||||
torch::List<int64_t> output_padding) {
|
||||
bool vnni_available = cpuinfo_has_x86_avx512vnni();
|
||||
bool w_sym_quant =
|
||||
is_weight_symmetric_quant(weight, is_transposed_conv);
|
||||
bool opad_all_zero =
|
||||
std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i==0; });
|
||||
return vnni_available && (groups <= 100) && w_sym_quant && opad_all_zero;
|
||||
}
|
||||
|
||||
} // onednn_utils
|
||||
|
||||
#endif // #if AT_MKLDNN_ENABLED()
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
|
||||
#include <ATen/native/quantized/cpu/OnednnUtils.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <cpuinfo.h>
|
||||
|
||||
#include <tuple>
|
||||
|
||||
|
|
@ -330,6 +331,37 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> deserialize_conv(
|
|||
|
||||
auto& ctx = at::globalContext();
|
||||
|
||||
#ifdef USE_FBGEMM
|
||||
if (ctx.qEngine() == at::QEngine::X86) {
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
bool use_onednn = onednn_utils::should_use_onednn_quant(
|
||||
weight.value(), transpose, groups, output_padding);
|
||||
if (use_onednn) {
|
||||
return PackedConvWeightsOnednn<kSpatialDim>::prepack(
|
||||
weight.value(),
|
||||
bias,
|
||||
stride,
|
||||
padding,
|
||||
output_padding,
|
||||
dilation,
|
||||
groups,
|
||||
transpose
|
||||
);
|
||||
}
|
||||
#endif
|
||||
return PackedConvWeight<kSpatialDim>::prepack(
|
||||
weight.value(),
|
||||
bias,
|
||||
stride,
|
||||
padding,
|
||||
output_padding,
|
||||
dilation,
|
||||
groups,
|
||||
transpose
|
||||
);
|
||||
} // x86
|
||||
#endif
|
||||
|
||||
#ifdef USE_FBGEMM
|
||||
if (ctx.qEngine() == at::QEngine::FBGEMM) {
|
||||
return PackedConvWeight<kSpatialDim>::prepack(
|
||||
|
|
|
|||
|
|
@ -445,7 +445,8 @@ int register_linear_params() {
|
|||
bias = std::move(std::get<1>(state));
|
||||
|
||||
#ifdef USE_FBGEMM
|
||||
if (at::globalContext().qEngine() == at::QEngine::FBGEMM) {
|
||||
if (at::globalContext().qEngine() == at::QEngine::FBGEMM ||
|
||||
at::globalContext().qEngine() == at::QEngine::X86) {
|
||||
if (weight.scalar_type() == at::kQInt8) {
|
||||
return PackedLinearWeight::prepack(
|
||||
std::move(weight), std::move(bias));
|
||||
|
|
|
|||
|
|
@ -521,6 +521,21 @@ class QConvPackWeightInt8 final {
|
|||
int64_t groups,
|
||||
bool transpose) {
|
||||
auto& ctx = at::globalContext();
|
||||
#ifdef USE_FBGEMM
|
||||
if (ctx.qEngine() == at::QEngine::X86) {
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
bool use_onednn = onednn_utils::should_use_onednn_quant(
|
||||
weight, transpose, groups, output_padding);
|
||||
if (use_onednn) {
|
||||
return PackedConvWeightsOnednn<kSpatialDim>::prepack(
|
||||
weight, bias, stride, padding, output_padding, dilation, groups, transpose);
|
||||
}
|
||||
#endif
|
||||
return PackedConvWeight<kSpatialDim>::prepack(
|
||||
weight, bias, stride, padding, output_padding, dilation, groups, transpose);
|
||||
} // x86
|
||||
#endif // defined(USE_FBGEMM) || AT_MKLDNN_ENABLED()
|
||||
|
||||
#ifdef USE_FBGEMM
|
||||
if (ctx.qEngine() == at::QEngine::FBGEMM) {
|
||||
return PackedConvWeight<kSpatialDim>::prepack(
|
||||
|
|
@ -598,6 +613,25 @@ class QConv1dPackWeightInt8 final {
|
|||
padding = quant_utils::MakeArgForConv1d(padding, 0);
|
||||
output_padding = quant_utils::MakeArgForConv1d(output_padding, 0);
|
||||
dilation = quant_utils::MakeArgForConv1d(dilation, 1);
|
||||
|
||||
#ifdef USE_FBGEMM
|
||||
if (ctx.qEngine() == at::QEngine::X86) {
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
bool use_onednn = onednn_utils::should_use_onednn_quant(
|
||||
weight, transpose, groups, output_padding);
|
||||
if (use_onednn) {
|
||||
return PackedConvWeightsOnednn<2>::prepack(
|
||||
weight, bias, stride, padding, output_padding, dilation, groups,
|
||||
transpose);
|
||||
}
|
||||
#endif
|
||||
return PackedConvWeight<2>::prepack(
|
||||
weight, bias, stride, padding, output_padding, dilation, groups,
|
||||
transpose);
|
||||
|
||||
} // x86
|
||||
#endif
|
||||
|
||||
#ifdef USE_FBGEMM
|
||||
if (ctx.qEngine() == at::QEngine::FBGEMM) {
|
||||
return PackedConvWeight<2>::prepack(
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ template <int kSpatialDim>
|
|||
std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightsOnednn<
|
||||
kSpatialDim>::unpack() {
|
||||
return std::tuple<at::Tensor, c10::optional<at::Tensor>>(
|
||||
orig_weight_, orig_bias_);
|
||||
orig_weight_.clone(), orig_bias_);
|
||||
}
|
||||
|
||||
template std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightsOnednn<
|
||||
|
|
|
|||
|
|
@ -288,7 +288,8 @@ class QLinearPackWeightInt8 final {
|
|||
auto& ctx = at::globalContext();
|
||||
|
||||
#ifdef USE_FBGEMM
|
||||
if (ctx.qEngine() == at::QEngine::FBGEMM) {
|
||||
if (ctx.qEngine() == at::QEngine::FBGEMM ||
|
||||
ctx.qEngine() == at::QEngine::X86) {
|
||||
return PackedLinearWeight::prepack(std::move(weight), std::move(bias));
|
||||
}
|
||||
#endif
|
||||
|
|
@ -320,7 +321,8 @@ class QLinearPackWeightFp16 final {
|
|||
// temporarily convert weight back to fp32, needs to be fixed
|
||||
// after fbgemm fixes the interface for their prepacking op (take fp16 input0
|
||||
weight = weight.to(ScalarType::Float);
|
||||
if (ctx.qEngine() == at::QEngine::FBGEMM) {
|
||||
if (ctx.qEngine() == at::QEngine::FBGEMM ||
|
||||
ctx.qEngine() == at::QEngine::X86) {
|
||||
return PackedLinearWeightFp16::prepack(
|
||||
std::move(weight), std::move(bias));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,7 +36,8 @@ class QConvUnpackWeightsInt8 final {
|
|||
auto& ctx = at::globalContext();
|
||||
|
||||
#ifdef USE_FBGEMM
|
||||
if (ctx.qEngine() == at::QEngine::FBGEMM) {
|
||||
if (ctx.qEngine() == at::QEngine::FBGEMM ||
|
||||
ctx.qEngine() == at::QEngine::X86) {
|
||||
return packed_weight->unpack();
|
||||
}
|
||||
#endif
|
||||
|
|
@ -72,7 +73,8 @@ class QConv1dUnpackWeightsInt8 final {
|
|||
at::Tensor weight;
|
||||
c10::optional<at::Tensor> bias;
|
||||
#ifdef USE_FBGEMM
|
||||
if (ctx.qEngine() == at::QEngine::FBGEMM) {
|
||||
if (ctx.qEngine() == at::QEngine::FBGEMM ||
|
||||
ctx.qEngine() == at::QEngine::X86) {
|
||||
std::tie(weight, bias) = packed_weight->unpack();
|
||||
weight = weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2);
|
||||
return std::tuple<at::Tensor, c10::optional<at::Tensor>>(weight, bias);
|
||||
|
|
|
|||
|
|
@ -16,12 +16,14 @@ enum class QEngine : uint8_t {
|
|||
FBGEMM = 1,
|
||||
QNNPACK = 2,
|
||||
ONEDNN = 3,
|
||||
X86 = 4,
|
||||
};
|
||||
|
||||
constexpr auto kNoQEngine = QEngine::NoQEngine;
|
||||
constexpr auto kFBGEMM = QEngine::FBGEMM;
|
||||
constexpr auto kQNNPACK = QEngine::QNNPACK;
|
||||
constexpr auto kONEDNN = QEngine::ONEDNN;
|
||||
constexpr auto kX86 = QEngine::X86;
|
||||
|
||||
inline std::string toString(QEngine qengine) {
|
||||
switch (qengine) {
|
||||
|
|
@ -33,6 +35,8 @@ inline std::string toString(QEngine qengine) {
|
|||
return "QNNPACK";
|
||||
case kONEDNN:
|
||||
return "ONEDNN";
|
||||
case kX86:
|
||||
return "X86";
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false, "Unrecognized Quantized Engine: ", static_cast<int>(qengine));
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from torch.testing._internal.common_quantized import (
|
|||
qengine_is_qnnpack,
|
||||
qengine_is_fbgemm,
|
||||
qengine_is_onednn,
|
||||
qengine_is_x86,
|
||||
)
|
||||
|
||||
# TODO: Once more test files are created, move the contents to a ao folder.
|
||||
|
|
@ -48,8 +49,8 @@ class TestQuantizedSparseKernels(TestCase):
|
|||
# to other higher priority works.
|
||||
if qengine_is_qnnpack() and not (row_block_size == 1 and col_block_size == 4):
|
||||
return
|
||||
# ONEDNN does not support this yet
|
||||
if qengine_is_onednn():
|
||||
# ONEDNN and X86 do not support this yet
|
||||
if qengine_is_onednn() or qengine_is_x86():
|
||||
return
|
||||
|
||||
dense_prepack = torch.ops.quantized.linear_prepack
|
||||
|
|
|
|||
|
|
@ -2900,7 +2900,7 @@ class TestQuantizedOps(TestCase):
|
|||
]
|
||||
|
||||
q_data = []
|
||||
reduce_range = (qengine in ('fbgemm', 'onednn'))
|
||||
reduce_range = (qengine in ('x86', 'fbgemm', 'onednn'))
|
||||
for idx, x in enumerate(fp_data):
|
||||
scale, zero_point = _calculate_dynamic_qparams(
|
||||
x, dtype=dtype, reduce_range=reduce_range)
|
||||
|
|
@ -3018,7 +3018,7 @@ class TestDynamicQuantizedOps(TestCase):
|
|||
(b_value_max - b_value_min) + b_value_min
|
||||
).astype(np.int32) if use_bias else None
|
||||
|
||||
if torch.backends.quantized.engine in ('fbgemm', 'onednn'):
|
||||
if torch.backends.quantized.engine in ('x86', 'fbgemm', 'onednn'):
|
||||
avoid_vpmaddubsw_overflow_linear(
|
||||
batch_size,
|
||||
input_channels,
|
||||
|
|
@ -3590,7 +3590,7 @@ class TestQuantizedLinear(TestCase):
|
|||
np.random.rand(output_channels) *
|
||||
(b_value_max - b_value_min) + b_value_min
|
||||
).astype(np.int32) if use_bias else None
|
||||
if torch.backends.quantized.engine in ('fbgemm', 'onednn'):
|
||||
if torch.backends.quantized.engine in ('x86', 'fbgemm', 'onednn'):
|
||||
avoid_vpmaddubsw_overflow_linear(
|
||||
batch_size,
|
||||
input_channels,
|
||||
|
|
@ -4479,7 +4479,7 @@ class TestQuantizedConv(TestCase):
|
|||
height=st.integers(10, 16),
|
||||
width=st.integers(7, 14),
|
||||
output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
|
||||
groups=st.integers(1, 3),
|
||||
groups=st.integers(1, 300),
|
||||
kernel_h=st.integers(1, 7),
|
||||
kernel_w=st.integers(1, 7),
|
||||
stride_h=st.integers(1, 2),
|
||||
|
|
@ -4835,7 +4835,7 @@ class TestQuantizedConv(TestCase):
|
|||
height=st.integers(10, 16),
|
||||
width=st.integers(7, 14),
|
||||
output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
|
||||
groups=st.integers(1, 3),
|
||||
groups=st.integers(1, 300),
|
||||
kernel_h=st.integers(1, 7),
|
||||
kernel_w=st.integers(1, 7),
|
||||
stride_h=st.integers(1, 2),
|
||||
|
|
@ -4961,7 +4961,7 @@ class TestQuantizedConv(TestCase):
|
|||
height=st.integers(10, 16),
|
||||
width=st.integers(7, 14),
|
||||
output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
|
||||
groups=st.integers(1, 3),
|
||||
groups=st.integers(1, 300),
|
||||
kernel_t=st.integers(1, 7),
|
||||
kernel_h=st.integers(1, 7),
|
||||
kernel_w=st.integers(1, 7),
|
||||
|
|
|
|||
111
torch/ao/quantization/backend_config/x86.py
Normal file
111
torch/ao/quantization/backend_config/x86.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
import torch
|
||||
from ._common_operator_config_utils import (
|
||||
_get_binary_op_configs,
|
||||
_get_bn_configs,
|
||||
_get_cat_config,
|
||||
_get_conv_configs,
|
||||
_get_default_op_configs,
|
||||
_get_embedding_op_configs,
|
||||
_get_fixed_qparams_op_configs,
|
||||
_get_linear_configs,
|
||||
_get_rnn_op_configs,
|
||||
_get_share_qparams_op_configs,
|
||||
)
|
||||
from .backend_config import BackendConfig, DTypeConfig
|
||||
|
||||
|
||||
# ===================
|
||||
# | DTYPE CONFIGS |
|
||||
# ===================
|
||||
|
||||
# X86 aligns with FBGEMM for now
|
||||
|
||||
x86_weighted_op_int8_dtype_config = DTypeConfig(
|
||||
input_dtype=torch.quint8,
|
||||
output_dtype=torch.quint8,
|
||||
weight_dtype=torch.qint8,
|
||||
bias_dtype=torch.float,
|
||||
)
|
||||
|
||||
x86_default_op_quint8_dtype_config = DTypeConfig(
|
||||
input_dtype=torch.quint8,
|
||||
output_dtype=torch.quint8,
|
||||
)
|
||||
|
||||
x86_default_op_fp16_dtype_config = DTypeConfig(
|
||||
input_dtype=torch.float16,
|
||||
output_dtype=torch.float16,
|
||||
weight_dtype=torch.float16,
|
||||
bias_dtype=torch.float16,
|
||||
)
|
||||
|
||||
x86_default_dynamic_int8_dtype_config = DTypeConfig(
|
||||
input_dtype=torch.quint8,
|
||||
output_dtype=torch.float,
|
||||
weight_dtype=torch.qint8,
|
||||
bias_dtype=torch.float,
|
||||
is_dynamic=True,
|
||||
)
|
||||
|
||||
x86_default_dynamic_float16_dtype_config = DTypeConfig(
|
||||
input_dtype=torch.float16,
|
||||
output_dtype=torch.float,
|
||||
weight_dtype=torch.float16,
|
||||
bias_dtype=torch.float,
|
||||
is_dynamic=True,
|
||||
)
|
||||
|
||||
x86_weight_only_quint8_dtype_config = DTypeConfig(
|
||||
input_dtype=torch.float,
|
||||
output_dtype=torch.float,
|
||||
weight_dtype=torch.quint8,
|
||||
)
|
||||
|
||||
x86_weight_only_quint4x2_dtype_config = DTypeConfig(
|
||||
input_dtype=torch.float,
|
||||
output_dtype=torch.float,
|
||||
weight_dtype=torch.quint4x2,
|
||||
)
|
||||
|
||||
|
||||
# =====================
|
||||
# | BACKEND CONFIGS |
|
||||
# =====================
|
||||
|
||||
def get_x86_backend_config() -> BackendConfig:
|
||||
"""
|
||||
Return the `BackendConfig` for PyTorch's native x86 backend.
|
||||
"""
|
||||
conv_dtype_configs = [x86_weighted_op_int8_dtype_config]
|
||||
linear_dtype_configs = [
|
||||
x86_weighted_op_int8_dtype_config,
|
||||
x86_default_dynamic_int8_dtype_config,
|
||||
x86_default_dynamic_float16_dtype_config,
|
||||
]
|
||||
binary_op_dtype_configs = [x86_weighted_op_int8_dtype_config]
|
||||
default_op_dtype_configs = [x86_default_op_quint8_dtype_config]
|
||||
fixed_qparams_op_dtype_configs = [x86_weighted_op_int8_dtype_config]
|
||||
share_qparams_op_dtype_configs = [x86_default_op_quint8_dtype_config]
|
||||
rnn_op_dtype_configs = [
|
||||
x86_default_dynamic_int8_dtype_config,
|
||||
x86_default_dynamic_float16_dtype_config,
|
||||
]
|
||||
embedding_op_dtype_configs = [
|
||||
x86_weight_only_quint8_dtype_config,
|
||||
x86_weight_only_quint4x2_dtype_config,
|
||||
]
|
||||
return BackendConfig("x86") \
|
||||
.set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
|
||||
.set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
|
||||
.set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
|
||||
.set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
|
||||
.set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
|
||||
.set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
|
||||
.set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
|
||||
.set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
|
||||
.set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
|
||||
.set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
|
||||
|
||||
__all__ = [
|
||||
"get_x86_backend_config",
|
||||
]
|
||||
|
|
@ -222,6 +222,7 @@ class PerChannelDetector(DetectorBase):
|
|||
"fbgemm": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]),
|
||||
"qnnpack": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]),
|
||||
"onednn": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]),
|
||||
"x86": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]),
|
||||
}
|
||||
|
||||
def __init__(self, backend: str = torch.backends.quantized.engine):
|
||||
|
|
|
|||
|
|
@ -190,7 +190,7 @@ def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[
|
|||
quantize_op = torch.quantize_per_tensor_dynamic
|
||||
# TODO: get reduce range from observer
|
||||
# reduce_range = activation_post_process.reduce_range
|
||||
reduce_range = torch.backends.quantized.engine == "fbgemm"
|
||||
reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86")
|
||||
qparams = {"_dtype_": compute_dtype, "_reduce_range_": reduce_range}
|
||||
elif dtype == torch.float16:
|
||||
node_type = "call_method"
|
||||
|
|
|
|||
|
|
@ -229,7 +229,7 @@ def get_default_qconfig(backend='fbgemm', version=0):
|
|||
|
||||
Args:
|
||||
* `backend`: a string representing the target backend. Currently supports `fbgemm`,
|
||||
`qnnpack` and `onednn`.
|
||||
`qnnpack`, `onednn` and `x86`.
|
||||
|
||||
Return:
|
||||
qconfig
|
||||
|
|
@ -244,6 +244,9 @@ def get_default_qconfig(backend='fbgemm', version=0):
|
|||
elif backend == 'onednn':
|
||||
qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
|
||||
weight=default_per_channel_weight_observer)
|
||||
elif backend == 'x86':
|
||||
qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
|
||||
weight=default_per_channel_weight_observer)
|
||||
else:
|
||||
qconfig = default_qconfig
|
||||
else:
|
||||
|
|
@ -300,7 +303,7 @@ def get_default_qat_qconfig(backend='fbgemm', version=1):
|
|||
|
||||
Args:
|
||||
* `backend`: a string representing the target backend. Currently supports `fbgemm`,
|
||||
`qnnpack` and `onednn`.
|
||||
`qnnpack`, `onednn` and `x86`.
|
||||
* `version`: version, for backwards compatibility. Can be `None` or `1`.
|
||||
|
||||
Return:
|
||||
|
|
@ -325,6 +328,12 @@ def get_default_qat_qconfig(backend='fbgemm', version=1):
|
|||
quant_min=0,
|
||||
quant_max=255),
|
||||
weight=default_per_channel_weight_fake_quant)
|
||||
if backend == 'x86':
|
||||
qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
|
||||
quant_min=0,
|
||||
quant_max=255,
|
||||
reduce_range=True),
|
||||
weight=default_per_channel_weight_fake_quant)
|
||||
else:
|
||||
qconfig = default_qat_qconfig
|
||||
# Use the fused observe + fake_quant modules for doing QAT.
|
||||
|
|
@ -346,6 +355,12 @@ def get_default_qat_qconfig(backend='fbgemm', version=1):
|
|||
quant_min=0,
|
||||
quant_max=255),
|
||||
weight=default_fused_per_channel_wt_fake_quant)
|
||||
elif backend == 'x86':
|
||||
qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
|
||||
quant_min=0,
|
||||
quant_max=255,
|
||||
reduce_range=True),
|
||||
weight=default_fused_per_channel_wt_fake_quant)
|
||||
else:
|
||||
qconfig = default_qat_qconfig_v2
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int) -> QC
|
|||
# so we have to modify the weight observer to default_weight_observer or another
|
||||
# per tensor supported observer.
|
||||
# see https://github.com/pytorch/pytorch/issues/47535
|
||||
if backend == "fbgemm":
|
||||
if backend in ("fbgemm", "x86"):
|
||||
qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight)
|
||||
else:
|
||||
qconfig_transpose = qconfig
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ def _get_qengine_id(qengine: str) -> int:
|
|||
ret = 2
|
||||
elif qengine == 'onednn':
|
||||
ret = 3
|
||||
elif qengine == 'x86':
|
||||
ret = 4
|
||||
else:
|
||||
ret = -1
|
||||
raise RuntimeError("{} is not a valid value for quantized engine".format(qengine))
|
||||
|
|
@ -20,7 +22,7 @@ def _get_qengine_id(qengine: str) -> int:
|
|||
|
||||
# This function should correspond to the enums present in c10/core/QEngine.h
|
||||
def _get_qengine_str(qengine: int) -> str:
|
||||
all_engines = {0 : 'none', 1 : 'fbgemm', 2 : 'qnnpack', 3 : 'onednn'}
|
||||
all_engines = {0 : 'none', 1 : 'fbgemm', 2 : 'qnnpack', 3 : 'onednn', 4 : 'x86'}
|
||||
return all_engines.get(qengine, '*undefined')
|
||||
|
||||
class _QEngineProp(object):
|
||||
|
|
|
|||
|
|
@ -178,6 +178,8 @@ def qengine_is_qnnpack():
|
|||
return torch.backends.quantized.engine == 'qnnpack'
|
||||
def qengine_is_onednn():
|
||||
return torch.backends.quantized.engine == 'onednn'
|
||||
def qengine_is_x86():
|
||||
return torch.backends.quantized.engine == 'x86'
|
||||
|
||||
# Helper function used to simulate per-channel fake-quant against any axis
|
||||
def _permute_to_axis_zero(X, axis):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user