diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index 1a4327aaa6b..a08a5013c37 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -309,8 +309,11 @@ const std::vector<at::QEngine>& Context::supportedQEngines() {
 #ifdef USE_FBGEMM
     if (fbgemm::fbgemmSupportedCPU()) {
       engines.push_back(at::kFBGEMM);
+      // The X86 qengine is available if and only if FBGEMM is available
+      engines.push_back(at::kX86);
     }
 #endif
+
     return engines;
   }();
   return supported_qengines;
diff --git a/aten/src/ATen/native/quantized/cpu/OnednnUtils.h b/aten/src/ATen/native/quantized/cpu/OnednnUtils.h
index 65294b18145..fefc2426e2d 100644
--- a/aten/src/ATen/native/quantized/cpu/OnednnUtils.h
+++ b/aten/src/ATen/native/quantized/cpu/OnednnUtils.h
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include <cpuinfo.h>
 
 using PrimitiveCacheKey = std::tuple<
     double, // input_scale
@@ -349,6 +350,50 @@ static void try_reorder(
       t.set_scale(scales);
     }
   }
+
+// ONEDNN requires symmetric quantization of weight
+// Use this util function to check.
+static bool is_weight_symmetric_quant(
+    const at::Tensor& weight,
+    bool is_transposed_conv) {
+  bool is_symmetric = true;
+  const auto qtype = weight.qscheme();
+  if (qtype == c10::kPerTensorAffine) {
+    is_symmetric &= (weight.q_zero_point() == 0);
+  } else if (qtype == c10::kPerChannelAffine) {
+    if (is_transposed_conv) {
+      // This case is currently not supported in PyTorch
+      // but we do not want to raise an error in this util function.
+      is_symmetric = false;
+    } else {
+      auto output_channels = weight.size(0);
+      for (int i = 0; i < output_channels; ++i) {
+        auto zp = weight.q_per_channel_zero_points()[i].item<int32_t>();
+        is_symmetric &= (zp == 0);
+      }
+    }
+  } else {
+    // This case is currently not supported in PyTorch
+    // but we do not want to raise an error in this util function.
+    is_symmetric = false;
+  }
+  return is_symmetric;
 }
+// Check if onednn should be used w.r.t fbgemm
+static bool should_use_onednn_quant(
+    const at::Tensor& weight,
+    bool is_transposed_conv,
+    int groups,
+    torch::List<int64_t> output_padding) {
+  bool vnni_available = cpuinfo_has_x86_avx512vnni();
+  bool w_sym_quant =
+      is_weight_symmetric_quant(weight, is_transposed_conv);
+  bool opad_all_zero =
+      std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i==0; });
+  return vnni_available && (groups <= 100) && w_sym_quant && opad_all_zero;
+}
+
+} // onednn_utils
+
 
 #endif // #if AT_MKLDNN_ENABLED()
diff --git a/aten/src/ATen/native/quantized/cpu/conv_serialization.h b/aten/src/ATen/native/quantized/cpu/conv_serialization.h
index 9e4edb8f9a8..293aa50856c 100644
--- a/aten/src/ATen/native/quantized/cpu/conv_serialization.h
+++ b/aten/src/ATen/native/quantized/cpu/conv_serialization.h
@@ -6,6 +6,7 @@
 #include
 #include
 #include
+#include <ATen/native/quantized/cpu/OnednnUtils.h>
 
 #include
 
@@ -330,6 +331,37 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> deserialize_conv(
 
   auto& ctx = at::globalContext();
 
+#ifdef USE_FBGEMM
+  if (ctx.qEngine() == at::QEngine::X86) {
+#if AT_MKLDNN_ENABLED()
+    bool use_onednn = onednn_utils::should_use_onednn_quant(
+        weight.value(), transpose, groups, output_padding);
+    if (use_onednn) {
+      return PackedConvWeightsOnednn<kSpatialDim>::prepack(
+        weight.value(),
+        bias,
+        stride,
+        padding,
+        output_padding,
+        dilation,
+        groups,
+        transpose
+      );
+    }
+#endif
+    return PackedConvWeight<kSpatialDim>::prepack(
+      weight.value(),
+      bias,
+      stride,
+      padding,
+      output_padding,
+      dilation,
+      groups,
+      transpose
+    );
+  } // x86
+#endif
+
 #ifdef USE_FBGEMM
   if (ctx.qEngine() == at::QEngine::FBGEMM) {
     return PackedConvWeight<kSpatialDim>::prepack(
diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
index 33d8bd88b85..0850ba6ebb5 100644
--- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
+++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
@@ -445,7 +445,8 @@ int register_linear_params() {
           bias = std::move(std::get<1>(state));
 
 #ifdef USE_FBGEMM
-          if (at::globalContext().qEngine() == at::QEngine::FBGEMM) {
+          if (at::globalContext().qEngine() == at::QEngine::FBGEMM ||
+              at::globalContext().qEngine() == at::QEngine::X86) {
             if (weight.scalar_type() == at::kQInt8) {
               return PackedLinearWeight::prepack(
                   std::move(weight), std::move(bias));
diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp
index fd31c2e7088..5509928cf8e 100644
--- a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp
@@ -521,6 +521,21 @@ class QConvPackWeightInt8 final {
       int64_t groups,
       bool transpose) {
     auto& ctx = at::globalContext();
+#ifdef USE_FBGEMM
+    if (ctx.qEngine() == at::QEngine::X86) {
+#if AT_MKLDNN_ENABLED()
+      bool use_onednn = onednn_utils::should_use_onednn_quant(
+          weight, transpose, groups, output_padding);
+      if (use_onednn) {
+        return PackedConvWeightsOnednn<kSpatialDim>::prepack(
+            weight, bias, stride, padding, output_padding, dilation, groups, transpose);
+      }
+#endif
+      return PackedConvWeight<kSpatialDim>::prepack(
+          weight, bias, stride, padding, output_padding, dilation, groups, transpose);
+    } // x86
+#endif // USE_FBGEMM
+
 #ifdef USE_FBGEMM
     if (ctx.qEngine() == at::QEngine::FBGEMM) {
       return PackedConvWeight<kSpatialDim>::prepack(
@@ -598,6 +613,25 @@ class QConv1dPackWeightInt8 final {
     padding = quant_utils::MakeArgForConv1d(padding, 0);
     output_padding = quant_utils::MakeArgForConv1d(output_padding, 0);
     dilation = quant_utils::MakeArgForConv1d(dilation, 1);
+
+#ifdef USE_FBGEMM
+    if (ctx.qEngine() == at::QEngine::X86) {
+#if AT_MKLDNN_ENABLED()
+      bool use_onednn = onednn_utils::should_use_onednn_quant(
+          weight, transpose, groups, output_padding);
+      if (use_onednn) {
+        return PackedConvWeightsOnednn<2>::prepack(
+            weight, bias, stride, padding, output_padding, dilation, groups,
+            transpose);
+      }
+#endif
+      return PackedConvWeight<2>::prepack(
+          weight, bias, stride, padding, output_padding, dilation, groups,
+          transpose);
+
+    } // x86
+#endif
+
 #ifdef USE_FBGEMM
     if (ctx.qEngine() == at::QEngine::FBGEMM) {
       return PackedConvWeight<2>::prepack(
diff --git a/aten/src/ATen/native/quantized/cpu/qconv_unpack_impl.cpp b/aten/src/ATen/native/quantized/cpu/qconv_unpack_impl.cpp
index ad32d9b16a2..8af8d62f2f8 100644
--- a/aten/src/ATen/native/quantized/cpu/qconv_unpack_impl.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qconv_unpack_impl.cpp
@@ -126,7 +126,7 @@ template <int kSpatialDim>
 std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightsOnednn<
     kSpatialDim>::unpack() {
   return std::tuple<at::Tensor, c10::optional<at::Tensor>>(
-      orig_weight_, orig_bias_);
+      orig_weight_.clone(), orig_bias_);
 }
 
 template std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightsOnednn<
diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
index b4f0f4c41f4..c741f12b5ee 100644
--- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
@@ -288,7 +288,8 @@ class QLinearPackWeightInt8 final {
     auto& ctx = at::globalContext();
 
 #ifdef USE_FBGEMM
-    if (ctx.qEngine() == at::QEngine::FBGEMM) {
+    if (ctx.qEngine() == at::QEngine::FBGEMM ||
+        ctx.qEngine() == at::QEngine::X86) {
       return PackedLinearWeight::prepack(std::move(weight), std::move(bias));
     }
 #endif
@@ -320,7 +321,8 @@ class QLinearPackWeightFp16 final {
     // temporarily convert weight back to fp32, needs to be fixed
     // after fbgemm fixes the interface for their prepacking op (take fp16 input0
     weight = weight.to(ScalarType::Float);
-    if (ctx.qEngine() == at::QEngine::FBGEMM) {
+    if (ctx.qEngine() == at::QEngine::FBGEMM ||
+        ctx.qEngine() == at::QEngine::X86) {
       return PackedLinearWeightFp16::prepack(
           std::move(weight), std::move(bias));
     }
diff --git a/aten/src/ATen/native/quantized/qconv_unpack.cpp b/aten/src/ATen/native/quantized/qconv_unpack.cpp
index 41f4754e8f1..8077c183e50 100644
--- a/aten/src/ATen/native/quantized/qconv_unpack.cpp
+++ b/aten/src/ATen/native/quantized/qconv_unpack.cpp
@@ -36,7 +36,8 @@ class QConvUnpackWeightsInt8 final {
     auto& ctx = at::globalContext();
 
 #ifdef USE_FBGEMM
-    if (ctx.qEngine() == at::QEngine::FBGEMM) {
+    if (ctx.qEngine() == at::QEngine::FBGEMM ||
+        ctx.qEngine() == at::QEngine::X86) {
       return packed_weight->unpack();
     }
 #endif
@@ -72,7 +73,8 @@ class QConv1dUnpackWeightsInt8 final {
     at::Tensor weight;
     c10::optional<at::Tensor> bias;
 #ifdef USE_FBGEMM
-    if (ctx.qEngine() == at::QEngine::FBGEMM) {
+    if (ctx.qEngine() == at::QEngine::FBGEMM ||
+        ctx.qEngine() == at::QEngine::X86) {
       std::tie(weight, bias) = packed_weight->unpack();
       weight = weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2);
       return std::tuple<at::Tensor, c10::optional<at::Tensor>>(weight, bias);
diff --git a/c10/core/QEngine.h b/c10/core/QEngine.h
index 60c21361f15..71eb4b34ac9 100644
--- a/c10/core/QEngine.h
+++ b/c10/core/QEngine.h
@@ -16,12 +16,14 @@ enum class QEngine : uint8_t {
   FBGEMM = 1,
   QNNPACK = 2,
   ONEDNN = 3,
+  X86 = 4,
 };
 
 constexpr auto kNoQEngine = QEngine::NoQEngine;
 constexpr auto kFBGEMM = QEngine::FBGEMM;
 constexpr auto kQNNPACK = QEngine::QNNPACK;
 constexpr auto kONEDNN = QEngine::ONEDNN;
+constexpr auto kX86 = QEngine::X86;
 
 inline std::string toString(QEngine qengine) {
   switch (qengine) {
@@ -33,6 +35,8 @@ inline std::string toString(QEngine qengine) {
       return "QNNPACK";
     case kONEDNN:
       return "ONEDNN";
+    case kX86:
+      return "X86";
     default:
       TORCH_CHECK(
           false, "Unrecognized Quantized Engine: ", static_cast<int>(qengine));
diff --git a/test/ao/sparsity/test_kernels.py b/test/ao/sparsity/test_kernels.py
index 3bcbed09c59..8e88545167a 100644
--- a/test/ao/sparsity/test_kernels.py
+++ b/test/ao/sparsity/test_kernels.py
@@ -22,6 +22,7 @@ from torch.testing._internal.common_quantized import (
     qengine_is_qnnpack,
     qengine_is_fbgemm,
     qengine_is_onednn,
+    qengine_is_x86,
 )
 
 # TODO: Once more test files are created, move the contents to a ao folder.
@@ -48,8 +49,8 @@ class TestQuantizedSparseKernels(TestCase):
         # to other higher priority works.
         if qengine_is_qnnpack() and not (row_block_size == 1 and col_block_size == 4):
             return
-        # ONEDNN does not support this yet
-        if qengine_is_onednn():
+        # ONEDNN and X86 do not support this yet
+        if qengine_is_onednn() or qengine_is_x86():
             return
 
         dense_prepack = torch.ops.quantized.linear_prepack
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index 62663909996..1a39c6925ba 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -2900,7 +2900,7 @@ class TestQuantizedOps(TestCase):
         ]
 
         q_data = []
-        reduce_range = (qengine in ('fbgemm', 'onednn'))
+        reduce_range = (qengine in ('x86', 'fbgemm', 'onednn'))
         for idx, x in enumerate(fp_data):
             scale, zero_point = _calculate_dynamic_qparams(
                 x, dtype=dtype, reduce_range=reduce_range)
@@ -3018,7 +3018,7 @@ class TestDynamicQuantizedOps(TestCase):
             (b_value_max - b_value_min) + b_value_min
         ).astype(np.int32) if use_bias else None
 
-        if torch.backends.quantized.engine in ('fbgemm', 'onednn'):
+        if torch.backends.quantized.engine in ('x86', 'fbgemm', 'onednn'):
             avoid_vpmaddubsw_overflow_linear(
                 batch_size,
                 input_channels,
@@ -3590,7 +3590,7 @@ class TestQuantizedLinear(TestCase):
             np.random.rand(output_channels) *
             (b_value_max - b_value_min) + b_value_min
         ).astype(np.int32) if use_bias else None
-        if torch.backends.quantized.engine in ('fbgemm', 'onednn'):
+        if torch.backends.quantized.engine in ('x86', 'fbgemm', 'onednn'):
             avoid_vpmaddubsw_overflow_linear(
                 batch_size,
                 input_channels,
@@ -4479,7 +4479,7 @@ class TestQuantizedConv(TestCase):
            height=st.integers(10, 16),
            width=st.integers(7, 14),
            output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
-           groups=st.integers(1, 3),
+           groups=st.integers(1, 300),
            kernel_h=st.integers(1, 7),
            kernel_w=st.integers(1, 7),
            stride_h=st.integers(1, 2),
@@ -4835,7 +4835,7 @@ class TestQuantizedConv(TestCase):
            height=st.integers(10, 16),
            width=st.integers(7, 14),
            output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
-           groups=st.integers(1, 3),
+           groups=st.integers(1, 300),
            kernel_h=st.integers(1, 7),
            kernel_w=st.integers(1, 7),
            stride_h=st.integers(1, 2),
@@ -4961,7 +4961,7 @@ class TestQuantizedConv(TestCase):
            height=st.integers(10, 16),
            width=st.integers(7, 14),
            output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
-           groups=st.integers(1, 3),
+           groups=st.integers(1, 300),
            kernel_t=st.integers(1, 7),
            kernel_h=st.integers(1, 7),
            kernel_w=st.integers(1, 7),
diff --git a/torch/ao/quantization/backend_config/x86.py b/torch/ao/quantization/backend_config/x86.py
new file mode 100644
index 00000000000..ce92ed9bc42
--- /dev/null
+++ b/torch/ao/quantization/backend_config/x86.py
@@ -0,0 +1,111 @@
+import torch
+from ._common_operator_config_utils import (
+    _get_binary_op_configs,
+    _get_bn_configs,
+    _get_cat_config,
+    _get_conv_configs,
+    _get_default_op_configs,
+    _get_embedding_op_configs,
+    _get_fixed_qparams_op_configs,
+    _get_linear_configs,
+    _get_rnn_op_configs,
+    _get_share_qparams_op_configs,
+)
+from .backend_config import BackendConfig, DTypeConfig
+
+
+# ===================
+# |  DTYPE CONFIGS  |
+# ===================
+
+# X86 aligns with FBGEMM for now
+
+x86_weighted_op_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+)
+
+x86_default_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+)
+
+x86_default_op_fp16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float16,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float16,
+)
+
+x86_default_dynamic_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.float,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+x86_default_dynamic_float16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+x86_weight_only_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint8,
+)
+
+x86_weight_only_quint4x2_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint4x2,
+)
+
+
+# =====================
+# |  BACKEND CONFIGS  |
+# =====================
+
+def get_x86_backend_config() -> BackendConfig:
+    """
+    Return the `BackendConfig` for PyTorch's native x86 backend.
+ """ + conv_dtype_configs = [x86_weighted_op_int8_dtype_config] + linear_dtype_configs = [ + x86_weighted_op_int8_dtype_config, + x86_default_dynamic_int8_dtype_config, + x86_default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [x86_weighted_op_int8_dtype_config] + default_op_dtype_configs = [x86_default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [x86_weighted_op_int8_dtype_config] + share_qparams_op_dtype_configs = [x86_default_op_quint8_dtype_config] + rnn_op_dtype_configs = [ + x86_default_dynamic_int8_dtype_config, + x86_default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + x86_weight_only_quint8_dtype_config, + x86_weight_only_quint4x2_dtype_config, + ] + return BackendConfig("x86") \ + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \ + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \ + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \ + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs)) + +__all__ = [ + "get_x86_backend_config", +] diff --git a/torch/ao/quantization/fx/_model_report/detector.py b/torch/ao/quantization/fx/_model_report/detector.py index 30fad8b45f5..239137aaaab 100644 --- a/torch/ao/quantization/fx/_model_report/detector.py +++ b/torch/ao/quantization/fx/_model_report/detector.py @@ -222,6 +222,7 @@ class PerChannelDetector(DetectorBase): "fbgemm": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]), "qnnpack": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]), "onednn": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]), + "x86": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]), } def __init__(self, backend: str = torch.backends.quantized.engine): diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index 0006f385e3d..fb7346a17af 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -190,7 +190,7 @@ def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[ quantize_op = torch.quantize_per_tensor_dynamic # TODO: get reduce range from observer # reduce_range = activation_post_process.reduce_range - reduce_range = torch.backends.quantized.engine == "fbgemm" + reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86") qparams = {"_dtype_": compute_dtype, "_reduce_range_": reduce_range} elif dtype == torch.float16: node_type = "call_method" diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index 17310fd3aec..e0a1a52f567 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -229,7 +229,7 @@ def get_default_qconfig(backend='fbgemm', version=0): Args: * `backend`: a string representing the target backend. 
Currently supports `fbgemm`, - `qnnpack` and `onednn`. + `qnnpack`, `onednn` and `x86`. Return: qconfig @@ -244,6 +244,9 @@ def get_default_qconfig(backend='fbgemm', version=0): elif backend == 'onednn': qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False), weight=default_per_channel_weight_observer) + elif backend == 'x86': + qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True), + weight=default_per_channel_weight_observer) else: qconfig = default_qconfig else: @@ -300,7 +303,7 @@ def get_default_qat_qconfig(backend='fbgemm', version=1): Args: * `backend`: a string representing the target backend. Currently supports `fbgemm`, - `qnnpack` and `onednn`. + `qnnpack`, `onednn` and `x86`. * `version`: version, for backwards compatibility. Can be `None` or `1`. Return: @@ -325,6 +328,12 @@ def get_default_qat_qconfig(backend='fbgemm', version=1): quant_min=0, quant_max=255), weight=default_per_channel_weight_fake_quant) + if backend == 'x86': + qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True), + weight=default_per_channel_weight_fake_quant) else: qconfig = default_qat_qconfig # Use the fused observe + fake_quant modules for doing QAT. @@ -346,6 +355,12 @@ def get_default_qat_qconfig(backend='fbgemm', version=1): quant_min=0, quant_max=255), weight=default_fused_per_channel_wt_fake_quant) + elif backend == 'x86': + qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True), + weight=default_fused_per_channel_wt_fake_quant) else: qconfig = default_qat_qconfig_v2 else: diff --git a/torch/ao/quantization/qconfig_mapping.py b/torch/ao/quantization/qconfig_mapping.py index fa1d6d057ae..ac863409001 100644 --- a/torch/ao/quantization/qconfig_mapping.py +++ b/torch/ao/quantization/qconfig_mapping.py @@ -69,7 +69,7 @@ def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int) -> QC # so we have to modify the weight observer to default_weight_observer or another # per tensor supported observer. 
     # see https://github.com/pytorch/pytorch/issues/47535
-    if backend == "fbgemm":
+    if backend in ("fbgemm", "x86"):
         qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight)
     else:
         qconfig_transpose = qconfig
diff --git a/torch/backends/quantized/__init__.py b/torch/backends/quantized/__init__.py
index 6f7d479e90c..2db2b672f1b 100644
--- a/torch/backends/quantized/__init__.py
+++ b/torch/backends/quantized/__init__.py
@@ -13,6 +13,8 @@ def _get_qengine_id(qengine: str) -> int:
         ret = 2
     elif qengine == 'onednn':
         ret = 3
+    elif qengine == 'x86':
+        ret = 4
     else:
         ret = -1
         raise RuntimeError("{} is not a valid value for quantized engine".format(qengine))
@@ -20,7 +22,7 @@ def _get_qengine_id(qengine: str) -> int:
 
 # This function should correspond to the enums present in c10/core/QEngine.h
 def _get_qengine_str(qengine: int) -> str:
-    all_engines = {0 : 'none', 1 : 'fbgemm', 2 : 'qnnpack', 3 : 'onednn'}
+    all_engines = {0 : 'none', 1 : 'fbgemm', 2 : 'qnnpack', 3 : 'onednn', 4 : 'x86'}
     return all_engines.get(qengine, '*undefined')
 
 class _QEngineProp(object):
diff --git a/torch/testing/_internal/common_quantized.py b/torch/testing/_internal/common_quantized.py
index 597fd774e32..0db312f2c20 100644
--- a/torch/testing/_internal/common_quantized.py
+++ b/torch/testing/_internal/common_quantized.py
@@ -178,6 +178,8 @@ def qengine_is_qnnpack():
     return torch.backends.quantized.engine == 'qnnpack'
 def qengine_is_onednn():
     return torch.backends.quantized.engine == 'onednn'
+def qengine_is_x86():
+    return torch.backends.quantized.engine == 'x86'
 
 # Helper function used to simulate per-channel fake-quant against any axis
 def _permute_to_axis_zero(X, axis):
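
Usage sketch (not part of the patch above): a minimal example of how the new 'x86' engine would be exercised from Python once this change is built, assuming an x86 CPU build with FBGEMM enabled. It only relies on entry points touched by the diff (torch.backends.quantized.engine, get_default_qconfig('x86')) plus the standard eager-mode prepare/convert flow; the toy module M and its shapes are invented for illustration.

import torch
import torch.ao.quantization as tq

# 'x86' is registered alongside 'fbgemm' (Context.cpp / QEngine.h hunks above)
print(torch.backends.quantized.supported_engines)
torch.backends.quantized.engine = 'x86'

# Default PTQ qconfig for the new backend: reduce_range=True histogram observer
# for activations, per-channel observer for weights (qconfig.py hunk above).
qconfig = tq.get_default_qconfig('x86')

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.fc = torch.nn.Linear(8, 8)
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

m = M().eval()
m.qconfig = qconfig
tq.prepare(m, inplace=True)
m(torch.randn(4, 8))           # calibration pass
tq.convert(m, inplace=True)    # weight prepacking goes through the X86 branches above
print(m(torch.randn(4, 8)).shape)

Whether a given op's packed weights end up on the FBGEMM or oneDNN kernels is decided per prepack call by onednn_utils::should_use_onednn_quant (VNNI availability, symmetric weight quantization, groups <= 100, all-zero output_padding), as added in OnednnUtils.h above.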