Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Add onednn quant backend (#74137)
Summary: Resolves the conflicts in https://github.com/pytorch/pytorch/pull/69820. jerryzh168, please review. Thanks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74137
Reviewed By: samdow
Differential Revision: D34840477
Pulled By: jerryzh168
fbshipit-source-id: 8aa60981ff7be211a1609644f273b16d18efd425
(cherry picked from commit de76bb808b315e9a2e45d8c5f1c1233a47d669c4)
This commit is contained in:
Parent: deae5950ba
Commit: 060f1b822a
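Usage sketch (not part of the diff below): the backend added by this commit is selected through PyTorch's quantized-engine switch. The Python-side engine string "onednn" is an assumption here, inferred from the qengine_is_onednn test helper and mirroring how "fbgemm" and "qnnpack" are chosen; the C++ enum value added below is at::QEngine::ONEDNN.

    import torch

    # Only engines compiled into this build are listed; ONEDNN requires AT_MKLDNN_ENABLED().
    print(torch.backends.quantized.supported_engines)

    # Assumed engine string; routes quantized ops to the oneDNN kernels added here.
    if 'onednn' in torch.backends.quantized.supported_engines:
        torch.backends.quantized.engine = 'onednn'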
@@ -236,6 +236,10 @@ const std::vector<at::QEngine>& Context::supportedQEngines() {
   engines.push_back(at::kNoQEngine);
 #endif // C10_MOBILE

+#if AT_MKLDNN_ENABLED()
+  engines.push_back(at::kONEDNN);
+#endif
+
 #ifdef USE_FBGEMM
   if (fbgemm::fbgemmSupportedCPU()) {
     engines.push_back(at::kFBGEMM);
@@ -4,6 +4,7 @@
 #include <ATen/core/List.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
+#include <ATen/native/quantized/cpu/onednn_utils.h>
 #include <c10/util/irange.h>

 #include <tuple>
@@ -358,6 +359,20 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> deserialize_conv(
     );
   }
 #endif // USE_PYTORCH_QNNPACK
+#if AT_MKLDNN_ENABLED()
+  if (ctx.qEngine() == at::QEngine::ONEDNN) {
+    return PackedConvWeightsOnednn<kSpatialDim>::prepack(
+      weight.value(),
+      bias,
+      stride,
+      padding,
+      output_padding,
+      dilation,
+      groups,
+      transpose
+    );
+  }
+#endif // AT_MKLDNN_ENABLED()
   TORCH_CHECK(
     false,
     "Didn't find engine for when deserializing ConvPackedParams: ",
@@ -4,6 +4,7 @@
 #include <ATen/native/quantized/cpu/embedding_packed_params.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
+#include <ATen/native/quantized/cpu/onednn_utils.h>
 #include <ATen/native/TensorFactories.h>
 #include <ATen/quantized/QTensorImpl.h>
 #include <ATen/quantized/Quantizer.h>
@@ -470,6 +471,16 @@ int register_linear_params() {
               std::move(weight), std::move(bias));
         }
 #endif // USE_PYTORCH_QNNPACK
+#if AT_MKLDNN_ENABLED()
+        if (at::globalContext().qEngine() == at::QEngine::ONEDNN) {
+          TORCH_CHECK(
+              weight.scalar_type() == at::kQInt8,
+              "ONEDNN only supports INT8 bit width currently. Got ",
+              c10::toString(weight.scalar_type()));
+          return PackedLinearWeightsOnednn::prepack(
+              std::move(weight), std::move(bias));
+        }
+#endif // #if AT_MKLDNN_ENABLED()
         TORCH_CHECK(false, "Unknown qengine");
       })
       .def("bias", [](const c10::intrusive_ptr<LinearPackedParamsBase>& self) {
aten/src/ATen/native/quantized/cpu/onednn_utils.h (new file, 151 lines)
@@ -0,0 +1,151 @@
#pragma once

#include <ATen/Config.h>
#if AT_MKLDNN_ENABLED()
#include <ATen/Tensor.h>
#include <ATen/native/quantized/packed_params.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>

struct PackedLinearWeightsOnednn : public LinearPackedParamsBase {
  PackedLinearWeightsOnednn(
      std::unique_ptr<ideep::tensor> weight,
      c10::optional<ideep::tensor> bias,
      at::Tensor orig_weight,
      c10::optional<at::Tensor> orig_bias)
      : weight_(std::move(weight)),
        bias_(std::move(bias)),
        orig_weight_(std::move(orig_weight)),
        orig_bias_(std::move(orig_bias)) {}
  std::unique_ptr<ideep::tensor> weight_;
  c10::optional<ideep::tensor> bias_;
  at::Tensor orig_weight_;
  c10::optional<at::Tensor> orig_bias_;

  at::Tensor apply(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;
  at::Tensor apply_relu(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override;
  at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override;

  std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;

  c10::optional<at::Tensor> bias() override {
    return orig_bias_;
  }

  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
      at::Tensor weight,
      c10::optional<at::Tensor> bias);

 private:
  template <bool ReluFused>
  at::Tensor apply_impl(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point);

  template <bool ReluFused>
  at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range=false);
};

template <int kSpatialDim = 2>
struct PackedConvWeightsOnednn : public ConvPackedParamsBase<kSpatialDim> {
  PackedConvWeightsOnednn(
      std::unique_ptr<ideep::tensor> weight,
      c10::optional<ideep::tensor> bias,
      at::Tensor orig_weight,
      c10::optional<at::Tensor> orig_bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      uint8_t transpose)
      : weight_(std::move(weight)),
        bias_(std::move(bias)),
        orig_weight_(std::move(orig_weight)),
        orig_bias_(std::move(orig_bias)),
        stride_(std::move(stride)),
        padding_(std::move(padding)),
        output_padding_(std::move(output_padding)),
        dilation_(std::move(dilation)),
        groups_(groups),
        transpose_(transpose) {}

  std::unique_ptr<ideep::tensor> weight_;
  c10::optional<ideep::tensor> bias_;
  at::Tensor orig_weight_;
  c10::optional<at::Tensor> orig_bias_;
  torch::List<int64_t> stride_;
  torch::List<int64_t> padding_;
  torch::List<int64_t> output_padding_;
  torch::List<int64_t> dilation_;
  int64_t groups_;
  uint8_t transpose_;

  at::Tensor apply(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_relu(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_dynamic(
      const at::Tensor& input,
      bool reduce_range) override;

  std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;

  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
      at::Tensor weight,
      c10::optional<at::Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose);

  torch::List<int64_t> stride() const override {
    return stride_;
  }

  torch::List<int64_t> padding() const override {
    return padding_;
  }

  torch::List<int64_t> output_padding() const override {
    return output_padding_;
  }

  torch::List<int64_t> dilation() const override {
    return dilation_;
  }

  int64_t groups() const override {
    return groups_;
  }

  bool transpose() const override {
    return (bool)transpose_;
  }

 private:
  template <bool ReluFused>
  at::Tensor apply_impl(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);
};

#endif // #if AT_MKLDNN_ENABLED()
@@ -9,6 +9,8 @@
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
 #include <ATen/native/quantized/cpu/xnnpack_utils.h>
+#include <ATen/native/quantized/cpu/onednn_utils.h>
+#include <ATen/native/ConvUtils.h>
 #include <ATen/native/quantized/cpu/quant_utils.h>
 #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
 #include <torch/library.h>
@@ -1148,6 +1150,177 @@ template at::Tensor PackedConvWeightsQnnp<3>::apply_impl<false>(

 #endif // USE_PYTORCH_QNNPACK

+#if AT_MKLDNN_ENABLED()
+template <int kSpatialDim>
+at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply(
+    const at::Tensor& input,
+    double output_scale,
+    int64_t output_zero_point) {
+  return apply_impl<false>(input, output_scale, output_zero_point);
+}
+
+template <int kSpatialDim>
+at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_relu(
+    const at::Tensor& input,
+    double output_scale,
+    int64_t output_zero_point) {
+  return apply_impl<true>(input, output_scale, output_zero_point);
+}
+
+template <int kSpatialDim>
+template <bool kReluFused>
+at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
+    const at::Tensor& act,
+    double output_scale,
+    int64_t output_zero_point) {
+  std::string func_name = "quantized::conv";
+  if (transpose()) {
+    func_name += "_transpose";
+  }
+  func_name += std::to_string(kSpatialDim) + "d";
+  if (kReluFused) {
+    func_name += "_relu";
+  }
+  ConvDimChecks<kSpatialDim>(
+      act.ndimension(), stride().size(), padding().size(),
+      output_padding().size(), dilation().size(), func_name, transpose());
+  TORCH_CHECK(act.scalar_type() == c10::ScalarType::QUInt8,
+      func_name, " (ONEDNN): data type of input should be QUint8.");
+
+  // src
+  auto act_contig = act.contiguous(kSpatialDim == 2 ? c10::MemoryFormat::ChannelsLast : c10::MemoryFormat::ChannelsLast3d);
+  auto src_dims = act_contig.sizes().vec();
+  auto src_data_type = dnnl::memory::data_type::u8;
+  auto src_desc = ideep::tensor::desc(src_dims, src_data_type,
+      kSpatialDim == 2 ? ideep::format_tag::nhwc : ideep::format_tag::ndhwc);
+  ideep::tensor src;
+  src.init(src_desc, act_contig.data_ptr());
+  // weights & bias
+  ideep::tensor& weights = *(weight_.get());
+  bool with_bias = bias_.has_value();
+  const auto& kernel_size = weights.get_dims();
+  // dst
+  const std::vector<int64_t>& input_size = src.get_dims();
+  std::vector<int64_t> output_sizes;
+  if (transpose()) {
+    // Prepacked weight format: [o, i, ...]
+    const int N = act.size(0); // batch size
+    const int C = act.size(1); // input channels
+    const int M = weights.get_dim(0); // output channels
+    const int D = kSpatialDim == 2 ? 1 : act.size(2); // input depth
+    const int H = act.size(kSpatialDim); // input height
+    const int W = act.size(kSpatialDim + 1); // input width
+    const int KH = weights.get_dim(kSpatialDim); // kernel height
+    const int KW = weights.get_dim(kSpatialDim + 1); // kernel width
+    const int KD = kSpatialDim == 2 ? 1 : weights.get_dim(2); // kernel depth
+    TORCH_CHECK(C == groups() * weights.get_dim(1), // weight: [o, i, ...]
+        func_name, " (ONEDNN): input channel number should be ",
+        groups() * weights.get_dim(1), ", but got ", C);
+    auto output_shape = MakeDeConvOutputShape<kSpatialDim>(
+        N,
+        M,
+        kSpatialDim == 2 ? std::vector<int64_t>{H, W} : std::vector<int64_t>{D, H, W},
+        kSpatialDim == 2 ? std::vector<int64_t>{KH, KW} : std::vector<int64_t>{KD, KH, KW},
+        stride(),
+        padding(),
+        output_padding(),
+        dilation());
+    output_sizes = c10::IntArrayRef(output_shape).vec();
+  } else {
+    output_sizes = at::native::conv_output_size(input_size, kernel_size, padding().vec(), stride().vec(), dilation().vec());
+  }
+  ideep::dims dst_dims = ideep::dims({output_sizes.cbegin(), output_sizes.cend()});
+  at::Tensor output = at::_empty_affine_quantized(
+      dst_dims,
+      device(c10::kCPU)
+          .dtype(c10::kQUInt8)
+          .memory_format(kSpatialDim == 2 ?
+              c10::MemoryFormat::ChannelsLast :
+              c10::MemoryFormat::ChannelsLast3d),
+      output_scale,
+      output_zero_point,
+      c10::nullopt);
+  if (output.numel() == 0) {
+    return output;
+  }
+  ideep::tensor dst({dst_dims, ideep::tensor::data_type::u8, {output.strides().cbegin(), output.strides().cend()}},
+                    output.data_ptr());
+  // Parameters
+  const ideep::dims& strides = stride().vec();
+  const ideep::dims& dilates = dilation().vec();
+  const ideep::dims& padding_l = padding().vec();
+  const ideep::dims& padding_r = padding().vec();
+  const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/act.q_scale()); // Scales of ONEDNN and PyTorch are reciprocal
+  const ideep::scale_t& weights_scales = weights.get_scale();
+  const ideep::scale_t& dst_scales = ideep::scale_t(weights_scales.size(), 1.0/output_scale); // Scales of ONEDNN and PyTorch are reciprocal
+  const ideep::zero_point_t src_zero_points = ideep::zero_point_t(1, act.q_zero_point());
+  const ideep::zero_point_t dst_zero_points = ideep::zero_point_t(1, output_zero_point);
+  ideep::attr_t op_attr = kReluFused ? ideep::attr_t::fuse_relu() : ideep::attr_t();
+  op_attr.set_zero_points(DNNL_ARG_SRC, ideep::utils::tensor_zp_mask(1), {DNNL_RUNTIME_S32_VAL}); // runtime src zero point
+  if (with_bias) {
+    // Bias might be modified outside (e.g. by quantization bias correction).
+    // If so, update the prepacked bias as well.
+    if (bias_.value().get_data_handle() != orig_bias_.value().data_ptr()) {
+      bias_.value().init(bias_.value().get_desc(), orig_bias_.value().data_ptr());
+    }
+    const auto& b = bias_.value();
+    if (transpose()) {
+      ideep::convolution_transpose_forward::compute_v2(
+          src, weights, b, dst_dims, dst,
+          strides, padding_l, padding_r, dilates,
+          groups(), src_scales, weights_scales, dst_scales, src_zero_points, dst_zero_points,
+          op_attr, dnnl::algorithm::deconvolution_direct, dnnl::prop_kind::forward_inference,
+          ideep::u8s8, ideep::engine::cpu_engine());
+    } else {
+      ideep::convolution_forward::compute_v2(
+          src, weights, b, dst_dims, dst,
+          strides, dilates, padding_l, padding_r, groups(),
+          src_scales, weights_scales, dst_scales, src_zero_points, dst_zero_points,
+          op_attr, dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference,
+          ideep::u8s8, ideep::engine::cpu_engine());
+    }
+  } else {
+    if (transpose()) {
+      ideep::convolution_transpose_forward::compute_v2(
+          src, weights, dst_dims, dst,
+          strides, padding_l, padding_r, dilates,
+          groups(), src_scales, weights_scales, dst_scales, src_zero_points, dst_zero_points,
+          op_attr, dnnl::algorithm::deconvolution_direct, dnnl::prop_kind::forward_inference,
+          ideep::u8s8, ideep::engine::cpu_engine());
+    } else {
+      ideep::convolution_forward::compute_v2(
+          src, weights, dst_dims, dst,
+          strides, dilates, padding_l, padding_r, groups(),
+          src_scales, weights_scales, dst_scales, src_zero_points, dst_zero_points,
+          op_attr, dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference,
+          ideep::u8s8, ideep::engine::cpu_engine());
+    }
+  }
+  return output;
+}
+
+template at::Tensor PackedConvWeightsOnednn<2>::apply(
+    const at::Tensor& act,
+    double output_scale,
+    int64_t output_zero_point);
+
+template at::Tensor PackedConvWeightsOnednn<2>::apply_relu(
+    const at::Tensor& act,
+    double output_scale,
+    int64_t output_zero_point);
+
+template at::Tensor PackedConvWeightsOnednn<3>::apply(
+    const at::Tensor& act,
+    double output_scale,
+    int64_t output_zero_point);
+
+template at::Tensor PackedConvWeightsOnednn<3>::apply_relu(
+    const at::Tensor& act,
+    double output_scale,
+    int64_t output_zero_point);
+
+#endif // #if AT_MKLDNN_ENABLED()
+
 namespace at {
 namespace native {
 namespace {
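For the non-transposed branch above, output sizes come from at::native::conv_output_size; as a reference, a minimal sketch of the standard per-dimension formula it is assumed to implement:

    # Assumed convolution output-size arithmetic, per spatial dimension.
    def conv_out_size(in_size, kernel, stride, padding, dilation=1):
        return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

    # Example: a 3x3 kernel with stride 1 and padding 1 preserves the spatial size.
    assert conv_out_size(56, kernel=3, stride=1, padding=1) == 56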
@@ -8,6 +8,7 @@
 #include <ATen/native/quantized/packed_params.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
+#include <ATen/native/quantized/cpu/onednn_utils.h>
 #include <ATen/native/quantized/cpu/quant_utils.h>
 #include <c10/util/irange.h>
 #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
@@ -118,6 +119,57 @@ template at::Tensor PackedConvWeightsQnnp<3>::apply_dynamic(

 #endif // USE_PYTORCH_QNNPACK

+#if AT_MKLDNN_ENABLED()
+
+template <int kSpatialDim>
+at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_dynamic(
+    const at::Tensor& input,
+    bool reduce_range) {
+
+  // Find min/max of input
+  float x_max = 0, x_min = 0;
+  if (input.numel() > 0) {
+    x_min = input.min().item<float>();
+    x_max = input.max().item<float>();
+  }
+
+  // Input tensor is quantized as 8-bit unsigned values
+  static constexpr int precision = 8;
+  static constexpr bool is_signed = false;
+
+  // Calculate scale and zero point for quantization of input tensor
+  auto q_params = quant_utils::ChooseQuantizationParams(
+      /*min=*/x_min,
+      /*max=*/x_max,
+      /*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
+      /*qmax=*/
+      is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
+      /*preserve_sparsity=*/false,
+      /*force_scale_power_of_two=*/false,
+      /*reduce_range=*/reduce_range);
+
+  // Quantize input
+  at::Tensor q_input = at::quantize_per_tensor(
+      input, q_params.scale, q_params.zero_point, c10::kQUInt8);
+
+  at::Tensor out =
+      apply_impl<false>(q_input, q_params.scale, q_params.zero_point);
+
+  // TODO: Modify ideep to allow fp32 input & output
+  // to avoid explicit `quantize - dequantize`
+  return at::dequantize(out);
+}
+
+template at::Tensor PackedConvWeightsOnednn<2>::apply_dynamic(
+    const at::Tensor& input,
+    bool reduce_range);
+
+template at::Tensor PackedConvWeightsOnednn<3>::apply_dynamic(
+    const at::Tensor& input,
+    bool reduce_range);
+
+#endif // AT_MKLDNN_ENABLED()
+
 namespace at {
 namespace native {
 namespace {
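The dynamic path above derives per-call quantization parameters from the input's min/max before quantizing to quint8. A minimal sketch of the asymmetric affine mapping that quant_utils::ChooseQuantizationParams is assumed to compute (ignoring its preserve_sparsity, power-of-two, and reduce_range options):

    # Hedged sketch, not the exact library routine.
    def choose_qparams(x_min, x_max, qmin=0, qmax=255):
        x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)  # the range must contain zero
        scale = (x_max - x_min) / (qmax - qmin)
        if scale == 0.0:
            scale = 1.0  # degenerate all-zero input
        zero_point = int(round(qmin - x_min / scale))
        return scale, min(qmax, max(qmin, zero_point))

    scale, zero_point = choose_qparams(-1.5, 3.0)  # roughly scale ~0.0176, zero_point 85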
@@ -6,6 +6,7 @@
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/init_qnnpack.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
+#include <ATen/native/quantized/cpu/onednn_utils.h>
 #include <ATen/native/quantized/cpu/quant_utils.h>
 #include <ATen/quantized/Quantizer.h>
 #include <torch/library.h>
@@ -314,6 +315,165 @@ c10::intrusive_ptr<ConvPackedParamsBase<2>> PackedConvWeightsQnnp<
     bool transpose);
 #endif // USE_PYTORCH_QNNPACK

+#if AT_MKLDNN_ENABLED()
+template <int kSpatialDim>
+c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightsOnednn<
+    kSpatialDim>::
+    prepack(
+        at::Tensor weight,
+        c10::optional<at::Tensor> bias,
+        torch::List<int64_t> stride,
+        torch::List<int64_t> padding,
+        torch::List<int64_t> output_padding,
+        torch::List<int64_t> dilation,
+        int64_t groups,
+        bool transpose) {
+  TORCH_CHECK(
+      weight.ndimension() == kSpatialDim + 2,
+      "Weights are expected to have ", kSpatialDim + 2, " dimensions");
+  TORCH_CHECK(
+      stride.size() == kSpatialDim,
+      "stride should contain ", kSpatialDim, " elements for ",
+      kSpatialDim, "D convolution.");
+  TORCH_CHECK(
+      padding.size() == kSpatialDim,
+      "Specify front/top/left padding only. "
+      "end/bottom/right padding assumed to be equal to front/top/left");
+  TORCH_CHECK(
+      !transpose || output_padding.size() == kSpatialDim,
+      "quantized::conv_prepack: Specify top/left output padding "
+      "only. bottom/right padding assumed to be equal to top/left");
+  TORCH_CHECK(
+      dilation.size() == kSpatialDim,
+      "dilation should contain ", kSpatialDim, " elements for ",
+      kSpatialDim, "D convolution.");
+  TORCH_CHECK(
+      !transpose || std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i==0; }),
+      "quantized::conv_prepack: ONEDNN only supports zero output_padding.");
+
+  // Weight
+  // Format: [OC IC//group KH KW] for conv; [IC OC//group KH KW] for deconv
+  auto dims = weight.sizes().vec();
+  auto strides = stride.vec();
+  auto padding_l = padding.vec();
+  auto padding_r = padding.vec();
+  auto dilates = dilation.vec();
+  auto op_attr = ideep::attr_t();
+  std::vector<int32_t> wgt_zero_points;
+  ideep::scale_t wgt_scales;
+  const int output_channels = transpose ? weight.size(1) * groups
+                                        : weight.size(0);
+  const auto qtype = weight.qscheme();
+  if (qtype == c10::kPerTensorAffine) {
+    TORCH_CHECK(
+        weight.q_zero_point()==0,
+        "quantized::qconv_prepack: ONEDNN only supports symmetric quantization of weight,"
+        " whose zero point must be 0.");
+    wgt_zero_points = std::vector<int32_t>(1, weight.q_zero_point());
+    wgt_scales = ideep::scale_t(1, 1.0/weight.q_scale()); // Scales of ONEDNN and PyTorch are reciprocal
+  } else if (qtype == c10::kPerChannelAffine) {
+    TORCH_CHECK(
+        !transpose,
+        "Per Channel Quantization is currently disabled for transposed conv");
+    wgt_zero_points.resize(output_channels);
+    wgt_scales.resize(output_channels);
+    for (int i = 0; i < output_channels; ++i) {
+      wgt_zero_points[i] = weight.q_per_channel_zero_points()[i].item<int32_t>();
+      TORCH_CHECK(
+          wgt_zero_points[i]==0,
+          "quantized::qconv_prepack: ONEDNN only supports symmetric quantization of weight,"
+          " whose zero point must be 0.");
+      wgt_scales[i] = 1.0f / weight.q_per_channel_scales()[i].item<float>(); // Scales of ONEDNN and PyTorch are reciprocal
+    }
+  } else {
+    TORCH_CHECK(false, "Unsupported qscheme: ", toString(qtype));
+  }
+
+  // Set runtime src zero point
+  auto src_zero_point = {DNNL_RUNTIME_S32_VAL};
+  op_attr.set_zero_points(DNNL_ARG_SRC,
+                          ideep::utils::tensor_zp_mask(src_zero_point.size()),
+                          src_zero_point);
+  at::Tensor weight_copy;
+  ideep::tensor::desc w_desc;
+  ideep::dims dims_iohw, dims_giohw;
+  ideep::tag w_tag = ideep::tag::any;
+  const bool with_groups = groups > 1;
+  if (transpose) {
+    w_desc = ideep::convolution_transpose_forward::expected_weights_desc(
+        dims, dnnl::memory::data_type::s8,
+        strides, padding_l, padding_r, dilates, groups,
+        dnnl::algorithm::deconvolution_direct, dnnl::prop_kind::forward_inference,
+        ideep::dims(), op_attr);
+    // convolution_transpose_forward::expected_weights_desc() gives format [i, o, ...],
+    // but ONEDNN requires [o, i, ...] for computation
+    dims_iohw = w_desc.get_dims();
+    dims_giohw = with_groups ? ideep::utils::group_dims(dims_iohw, groups) : dims_iohw;
+    std::vector<int64_t> perms(dims_giohw.size(), 0); // for permutation of weight
+    std::iota(perms.begin(), perms.end(), 0);
+    w_desc = w_desc.transpose(with_groups, with_groups + 1);
+    std::swap(perms[with_groups], perms[with_groups + 1]);
+    weight_copy = weight.reshape(dims_giohw).permute(c10::IntArrayRef(perms)).clone();
+  } else {
+    w_desc = ideep::convolution_forward::expected_weights_desc(
+        dims, dnnl::memory::data_type::s8,
+        strides, padding_l, padding_r, dilates, groups,
+        dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference,
+        dnnl::memory::data_type::u8, ideep::dims(), op_attr);
+    weight_copy = weight.clone();
+  }
+  if (with_groups) {
+    w_tag = kSpatialDim == 2 ? ideep::tag::goihw : ideep::tag::goidhw;
+  } else {
+    w_tag = kSpatialDim == 2 ? ideep::tag::oihw : ideep::tag::oidhw;
+  }
+  ideep::dims w_dims = with_groups ? ideep::utils::group_dims(w_desc.get_dims(), groups)
+                                   : w_desc.get_dims();
+  ideep::tensor wgt = ideep::tensor(
+      ideep::tensor::desc({w_dims, dnnl::memory::data_type::s8, w_tag}, groups),
+      weight_copy.data_ptr());
+  wgt.set_scale(wgt_scales); // Scales are needed for feed_from().
+  ideep::tensor exp_wgt;
+  exp_wgt.init(w_desc);
+  exp_wgt.set_scale(wgt_scales); // Also for feed_from()
+  exp_wgt.feed_from(wgt, transpose); // expect wgt to be in [OC IC KH KW] format
+  ideep::tensor * packed_weight_p = new ideep::tensor(exp_wgt);
+  packed_weight_p->set_scale(wgt_scales);
+  packed_weight_p->set_zero_point(wgt_zero_points);
+  std::unique_ptr<ideep::tensor> weight_ptr(packed_weight_p);
+  // Bias
+  c10::optional<ideep::tensor> onednn_bias{c10::nullopt};
+  if (bias.has_value()) {
+    at::Tensor bias_vec = bias.value();
+    TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");
+    TORCH_CHECK(
+        bias_vec.size(0) == output_channels,
+        "bias should have K elements: " + std::to_string(output_channels));
+    auto bias_desc = ideep::tensor::desc(bias.value().sizes().vec(), dnnl::memory::data_type::f32);
+    ideep::tensor packed_bias;
+    packed_bias.init(bias_desc, bias.value().data_ptr());
+    onednn_bias = c10::optional<ideep::tensor>(packed_bias);
+  }
+  auto ret_ptr = c10::make_intrusive<PackedConvWeightsOnednn<kSpatialDim>>(
+      PackedConvWeightsOnednn<kSpatialDim>{
+          std::move(weight_ptr),
+          onednn_bias,
+          weight,
+          bias,
+          stride,
+          padding,
+          output_padding,
+          dilation,
+          groups,
+          transpose
+      });
+  return ret_ptr;
+}
+
+template struct PackedConvWeightsOnednn<2>;
+template struct PackedConvWeightsOnednn<3>;
+#endif // #if AT_MKLDNN_ENABLED()
+
 namespace at {
 namespace native {
 namespace {
@@ -377,6 +537,14 @@ class QConvPackWeightInt8 final {
     }
 #endif

+#if AT_MKLDNN_ENABLED()
+    if (ctx.qEngine() == at::QEngine::ONEDNN) {
+      return PackedConvWeightsOnednn<kSpatialDim>::prepack(
+          weight, bias, stride, padding, output_padding, dilation, groups,
+          transpose);
+    }
+#endif
+
     TORCH_CHECK(
         false,
         "Didn't find engine for operation quantized::conv2d_prepack ",
@@ -438,8 +606,6 @@ class QConv1dPackWeightInt8 final {
     }
 #endif
-
-

 #ifdef USE_PYTORCH_QNNPACK
     if (ctx.qEngine() == at::QEngine::QNNPACK) {
       return PackedConvWeightsQnnp<2>::prepack(
@@ -447,6 +613,15 @@ class QConv1dPackWeightInt8 final {
           transpose);
     }
 #endif

+#if AT_MKLDNN_ENABLED()
+    if (ctx.qEngine() == at::QEngine::ONEDNN) {
+      return PackedConvWeightsOnednn<2>::prepack(
+          weight, bias, stride, padding, output_padding, dilation, groups,
+          transpose);
+    }
+#endif
+
     TORCH_CHECK(
         false,
         "Didn't find engine for operation quantized::conv1d_prepack ",
@@ -5,6 +5,7 @@
 #include <torch/library.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
+#include <ATen/native/quantized/cpu/onednn_utils.h>
 #include <ATen/native/quantized/cpu/quant_utils.h>
 #include <ATen/native/quantized/packed_params.h>

@@ -120,6 +121,20 @@ template std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightsQnnp
     3>::unpack();
 #endif // USE_PYTORCH_QNNPACK

+#if AT_MKLDNN_ENABLED()
+template <int kSpatialDim>
+std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightsOnednn<
+    kSpatialDim>::unpack() {
+  return std::tuple<at::Tensor, c10::optional<at::Tensor>>(
+      orig_weight_, orig_bias_);
+}
+
+template std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightsOnednn<
+    2>::unpack();
+template std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightsOnednn<
+    3>::unpack();
+#endif // #if AT_MKLDNN_ENABLED()
+
 namespace at {
 namespace native {
 namespace {
@@ -154,6 +169,12 @@ class QConvUnpackWeightsInt8 final {
     }
 #endif

+#if AT_MKLDNN_ENABLED()
+    if (ctx.qEngine() == at::QEngine::ONEDNN) {
+      return packed_weight->unpack();
+    }
+#endif
+
     TORCH_CHECK(
         false,
         "Didn't find engine for operation quantized::conv2d_unpack ",
@@ -185,6 +206,15 @@ class QConv1dUnpackWeightsInt8 final {
     }
 #endif

+#if AT_MKLDNN_ENABLED()
+    if (ctx.qEngine() == at::QEngine::ONEDNN) {
+      std::tie(weight, bias) = packed_weight->unpack();
+      at::Tensor new_weight = weight.clone();
+      new_weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2);
+      return std::tuple<at::Tensor, c10::optional<at::Tensor>>(new_weight, bias);
+    }
+#endif
+
     TORCH_CHECK(
         false,
         "Didn't find engine for operation quantized::conv1d_unpack ",
@@ -5,6 +5,7 @@
 #include <ATen/native/quantized/packed_params.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
 #include <ATen/native/quantized/cpu/xnnpack_utils.h>
+#include <ATen/native/quantized/cpu/onednn_utils.h>
 #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
 #include <torch/custom_class.h>
 #include <torch/library.h>
@@ -617,6 +618,81 @@ at::Tensor PackedLinearWeightsQnnp::apply_relu(

 #endif // USE_PYTORCH_QNNPACK

+#if AT_MKLDNN_ENABLED()
+template <bool ReluFused>
+at::Tensor PackedLinearWeightsOnednn::apply_impl(
+    at::Tensor input,
+    double output_scale,
+    int64_t output_zero_point) {
+  const int64_t dim = input.dim();
+  TORCH_CHECK(
+      dim != 0,
+      "qlinear (ONEDNN): input dim should be at least 1, but got 0");
+  TORCH_CHECK(input.scalar_type() == c10::ScalarType::QUInt8,
+      "qlinear (ONEDNN): data type of input should be QUint8.");
+
+  auto input_contig = input.expect_contiguous();
+  auto& w = *(weight_.get());
+  auto K = input.size(dim - 1), M = input.numel() / K, N = w.get_dim(1);
+  auto input_dims = {M, K};
+  auto input_data_type = dnnl::memory::data_type::u8;
+  auto input_desc = ideep::tensor::desc(input_dims, input_data_type);
+  ideep::attr_t op_attr = ReluFused ? ideep::attr_t::fuse_relu() : ideep::attr_t();
+  ideep::tensor x(input_desc, input_contig->data_ptr<c10::quint8>());
+  auto dst_dims = {M, N};
+  const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/input.q_scale());
+  const ideep::scale_t& weights_scales = w.get_scale();
+  const ideep::scale_t& dst_scales = ideep::scale_t(1, 1.0/output_scale); // Scales of ONEDNN and PyTorch are reciprocal
+  const ideep::zero_point_t& src_zero_point = ideep::zero_point_t(1, input.q_zero_point());
+  const ideep::zero_point_t& dst_zero_point = ideep::zero_point_t(1, output_zero_point);
+  // Compute: Use ideep::matmul_forward to support asymmetric quantization
+  // Allocate output Tensor
+  at::Tensor output = at::_empty_affine_quantized(
+      dst_dims,
+      at::device(c10::kCPU).dtype(c10::kQUInt8),
+      output_scale,
+      output_zero_point);
+  if (output.numel() == 0) {
+    return output;
+  }
+  ideep::tensor y({dst_dims, ideep::tensor::data_type::u8, {output.strides().cbegin(), output.strides().cend()}},
+                  output.data_ptr());
+  if (bias_.has_value()) {
+    // Bias might be modified outside (e.g. by quantization bias correction).
+    // If so, update the prepacked bias as well.
+    if (bias_.value().get_data_handle() != orig_bias_.value().data_ptr()) {
+      bias_.value().init(bias_.value().get_desc(), orig_bias_.value().data_ptr());
+    }
+    const auto& b = bias_.value();
+    ideep::matmul_forward::compute_v2(x, w, b, y, 1.0f, 1.0f, src_scales, weights_scales, dst_scales,
+                                      src_zero_point, dst_zero_point, op_attr);
+  } else {
+    ideep::matmul_forward::compute_v2(x, w, y, 1.0f, 1.0f, src_scales, weights_scales, dst_scales,
+                                      src_zero_point, dst_zero_point, op_attr);
+  }
+  auto out_sizes = input.sizes().vec();
+  out_sizes.back() = N;
+  if (output.sizes().vec() == out_sizes)
+    return output;
+  return output.reshape(out_sizes);
+}
+
+at::Tensor PackedLinearWeightsOnednn::apply(
+    at::Tensor input,
+    double output_scale,
+    int64_t output_zero_point) {
+  return apply_impl<false>(std::move(input), output_scale, output_zero_point);
+}
+
+at::Tensor PackedLinearWeightsOnednn::apply_relu(
+    at::Tensor input,
+    double output_scale,
+    int64_t output_zero_point) {
+  return apply_impl<true>(std::move(input), output_scale, output_zero_point);
+}
+
+#endif // #if AT_MKLDNN_ENABLED()
+
 namespace at {
 namespace native {
 namespace {
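Several lines above repeat the comment "Scales of ONEDNN and PyTorch are reciprocal". The relationship assumed throughout these kernels is that PyTorch stores the dequantization scale, while the ideep/oneDNN calls here expect its inverse:

    # PyTorch convention: x_fp32 = q_scale * (x_q - zero_point)
    # What the code passes to ideep, e.g. ideep::scale_t(1, 1.0 / input.q_scale()):
    q_scale = 0.05                 # hypothetical PyTorch tensor scale
    onednn_scale = 1.0 / q_scale   # reciprocal value handed to oneDNN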
@@ -4,6 +4,7 @@
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/packed_params.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
+#include <ATen/native/quantized/cpu/onednn_utils.h>
 #include <ATen/native/quantized/cpu/quant_utils.h>
 #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
 #include <torch/library.h>
@@ -463,6 +464,99 @@ void PackedLinearWeightFp16::set_bias(c10::optional<at::Tensor> bias) {

 #endif // USE_FBGEMM

+#if AT_MKLDNN_ENABLED()
+template <bool ReluFused>
+at::Tensor PackedLinearWeightsOnednn::apply_dynamic_impl(
+    at::Tensor input,
+    bool reduce_range) {
+  // Dynamic: fp32 * int8 -> fp32
+  using at::Tensor;
+
+  TORCH_CHECK(
+      input.dim() >= 2,
+      "The dimension of input tensor should be larger than or equal to 2");
+  TORCH_CHECK(input.scalar_type() == c10::ScalarType::Float,
+      "qlinear_dynamic (ONEDNN): data type of input should be float.");
+
+  // Input -> uint8
+  auto input_contig = input.contiguous();
+  const int64_t dim = input.dim();
+  auto input_reshaped =
+      dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)});
+  auto input_dims = input_reshaped.sizes().vec();
+  auto input_data_type = dnnl::memory::data_type::f32;
+  auto input_desc = ideep::tensor::desc(input_dims, input_data_type);
+  ideep::attr_t op_attr = ReluFused ? ideep::attr_t::fuse_relu() : ideep::attr_t();
+  ideep::tensor x;
+  x.init(input_desc, input_contig.data_ptr());
+  // Find quantization parameters
+  float x_max = 0, x_min = 0;
+  if (input.numel() > 0) {
+    x_min = input_contig.min().item<float>();
+    x_max = input_contig.max().item<float>();
+  }
+  const int precision = 8;
+  auto q_params = quant_utils::ChooseQuantizationParams(
+      /*min=*/x_min,
+      /*max=*/x_max,
+      /*qmin=*/0,
+      /*qmax=*/(1 << precision) - 1,
+      /*preserve_sparsity=*/false,
+      /*force_scale_power_of_two=*/false,
+      /*reduce_range=*/reduce_range);
+  const std::vector<int32_t>& src_zero_point = std::vector<int32_t>(1, q_params.zero_point);
+  // weights, dst
+  auto w = *(weight_.get());
+  auto dst_dims = {x.get_dim(0), w.get_dim(1)};
+  const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/q_params.scale);
+  const ideep::scale_t& weights_scales = w.get_scale();
+  // Compute -> f32
+  // Use ideep::matmul_forward instead of ideep::inner_product_forward,
+  // since the latter does not support asymmetric quantization
+  // Allocate output Tensor
+  at::Tensor output = at::empty(dst_dims, input.options().dtype(at::kFloat));
+  if (output.numel() == 0) return output;
+  ideep::tensor y({dst_dims, ideep::tensor::data_type::f32,
+                   {output.strides().cbegin(), output.strides().cend()}},
+                  output.data_ptr());
+  if (bias_.has_value()) {
+    // Bias might be modified outside (e.g. by quantization bias correction).
+    // If so, update the prepacked bias as well.
+    if (bias_.value().get_data_handle() != orig_bias_.value().data_ptr()) {
+      bias_.value().init(bias_.value().get_desc(), orig_bias_.value().data_ptr());
+    }
+    const ideep::tensor b = bias_.value();
+    ideep::matmul_forward::compute_v2(x, w, b, y, 1.0f, 1.0f,
+                                      src_scales, weights_scales, ideep::scale_t(),
+                                      src_zero_point, ideep::zero_point_t(), op_attr);
+  } else {
+    ideep::matmul_forward::compute_v2(x, w, y, 1.0f, 1.0f,
+                                      src_scales, weights_scales, ideep::scale_t(),
+                                      src_zero_point, ideep::zero_point_t(), op_attr);
+  }
+  auto out_sizes = input.sizes().vec();
+  out_sizes.back() = w.get_dim(1);
+  if (output.sizes().vec() == out_sizes)
+    return output;
+  return output.reshape(out_sizes);
+}
+
+at::Tensor PackedLinearWeightsOnednn::apply_dynamic(
+    at::Tensor input,
+    bool reduce_range) {
+  return apply_dynamic_impl</*ReluFused=*/false>(
+      std::move(input), reduce_range);
+}
+
+at::Tensor PackedLinearWeightsOnednn::apply_dynamic_relu(
+    at::Tensor input,
+    bool reduce_range) {
+  return apply_dynamic_impl</*ReluFused=*/true>(
+      std::move(input), reduce_range);
+}
+
+#endif // #if AT_MKLDNN_ENABLED()
+
 namespace at {
 namespace native {
 namespace {
@@ -4,6 +4,7 @@
 #include <ATen/native/quantized/cpu/init_qnnpack.h>
 #include <ATen/native/quantized/packed_params.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
+#include <ATen/native/quantized/cpu/onednn_utils.h>
 #include <ATen/native/quantized/cpu/quant_utils.h>
 #include <ATen/quantized/Quantizer.h>
 #include <torch/custom_class.h>
@@ -194,6 +195,80 @@ c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightFp16::prepack(
 }
 #endif // USE_FBGEMM

+#if AT_MKLDNN_ENABLED()
+c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightsOnednn::prepack(
+    at::Tensor weight,
+    c10::optional<at::Tensor> bias) {
+  TORCH_CHECK(
+      weight.dim() == 2,
+      "The weight tensor for quantized::linear_prepack (onednn) should"
+      " be 2-dimensional.");
+  // Weight
+  std::vector<int64_t> dims = weight.sizes().vec();
+  auto N = weight.size(0);
+  std::vector<int32_t> wgt_zero_points;
+  ideep::scale_t wgt_scales;
+  const auto qtype = weight.qscheme();
+  if (qtype == c10::kPerTensorAffine) {
+    TORCH_CHECK(
+        weight.q_zero_point() == 0,
+        "quantized::linear_prepack: ONEDNN only supports symmetric quantization of weight,"
+        " whose zero point must be 0, but got ", weight.q_zero_point());
+    wgt_zero_points = std::vector<int32_t>(1, weight.q_zero_point());
+    wgt_scales = ideep::scale_t(1, 1.0/weight.q_scale()); // Scales of ONEDNN and PyTorch are reciprocal
+  } else if (qtype == c10::kPerChannelAffine) {
+    wgt_zero_points.resize(N);
+    wgt_scales.resize(N);
+    for (int i = 0; i < N; ++i) {
+      wgt_zero_points[i] = weight.q_per_channel_zero_points()[i].item<int32_t>();
+      TORCH_CHECK(
+          wgt_zero_points[i] == 0,
+          "quantized::linear_prepack: ONEDNN only supports symmetric quantization of weight,"
+          " whose zero point must be 0, but got ", wgt_zero_points[i], ", at index ", i);
+      wgt_scales[i] = 1.0f / weight.q_per_channel_scales()[i].item<float>(); // Scales of ONEDNN and PyTorch are reciprocal
+    }
+  } else {
+    TORCH_CHECK(false, "Unsupported qscheme: ", toString(qtype));
+  }
+
+  // Prepack weight
+  auto weight_copy = weight.clone();
+  ideep::tensor wgt = ideep::tensor({dims, dnnl::memory::data_type::s8}, weight_copy.data_ptr());
+  wgt.transpose_(0, 1); // ONEDNN requires transposed weight
+  auto w_desc = ideep::matmul_forward::expected_weights_desc(wgt.get_dims(), dnnl::memory::data_type::s8,
+                                                             dnnl::memory::data_type::u8);
+  ideep::tensor exp_wgt(w_desc);
+  exp_wgt.feed_from(wgt);
+  ideep::tensor * packed_weight_p = new ideep::tensor(exp_wgt);
+  packed_weight_p->set_scale(wgt_scales);
+  packed_weight_p->set_zero_point(wgt_zero_points);
+  std::unique_ptr<ideep::tensor> weight_ptr(packed_weight_p);
+  // Bias
+  c10::optional<ideep::tensor> onednn_bias{c10::nullopt};
+  if (bias.has_value()) {
+    auto& b = bias.value();
+    auto bias_size = b.sizes().vec();
+    bias_size.insert(bias_size.begin(), 1);
+    TORCH_CHECK(
+        bias_size[1] == weight_ptr->get_dim(1),
+        "bias should have N elements: ",
+        std::to_string(weight_ptr->get_dim(1)),
+        ", but got ", bias_size[1]);
+    auto bias_desc = ideep::tensor::desc(bias_size, dnnl::memory::data_type::f32);
+    ideep::tensor packed_bias;
+    packed_bias.init(bias_desc, b.data_ptr());
+    onednn_bias = c10::optional<ideep::tensor>(packed_bias);
+  }
+  auto ret_ptr = c10::make_intrusive<PackedLinearWeightsOnednn>(
+      PackedLinearWeightsOnednn{
+          std::move(weight_ptr),
+          onednn_bias,
+          weight,
+          bias});
+  return ret_ptr;
+}
+#endif // #if AT_MKLDNN_ENABLED()
+
 namespace at {
 namespace native {
@@ -224,6 +299,11 @@ class QLinearPackWeightInt8 final {
           std::move(weight), std::move(bias));
     }
 #endif
+#if AT_MKLDNN_ENABLED()
+    if (ctx.qEngine() == at::QEngine::ONEDNN) {
+      return PackedLinearWeightsOnednn::prepack(std::move(weight), std::move(bias));
+    }
+#endif // #if AT_MKLDNN_ENABLED()
     TORCH_CHECK(
         false,
         "Didn't find engine for operation quantized::linear_prepack ",
@@ -254,6 +334,14 @@ class QLinearPackWeightFp16 final {
           "not supported by QNNPACK");
     }
 #endif // USE_PYTORCH_QNNPACK
+#if AT_MKLDNN_ENABLED()
+    if (ctx.qEngine() == at::QEngine::ONEDNN) {
+      TORCH_CHECK(
+          false,
+          "quantized::linear_prepack_fp16 is currently "
+          "not supported by ONEDNN");
+    }
+#endif // #if AT_MKLDNN_ENABLED()
     TORCH_CHECK(
         false,
         "Didn't find engine for operation quantized::linear_prepack_fp16 ",
@@ -287,6 +375,16 @@ class QLinearPackWeightInt8Legacy final {
       return cpp_custom_type_hack::create(std::move(wrapped), options);
     }
 #endif // USE_PYTORCH_QNNPACK
+#if AT_MKLDNN_ENABLED()
+    if (ctx.qEngine() == at::QEngine::ONEDNN) {
+      auto prepacked =
+          PackedLinearWeightsOnednn::prepack(std::move(weight), std::move(bias));
+      auto wrapped =
+          std::make_unique<c10::intrusive_ptr<LinearPackedParamsBase>>(
+              std::move(prepacked));
+      return cpp_custom_type_hack::create(std::move(wrapped), options);
+    }
+#endif // #if AT_MKLDNN_ENABLED()
     TORCH_CHECK(
         false,
         "Didn't find engine for operation quantized::linear_prepack ",
@@ -317,6 +415,14 @@ class QLinearPackWeightFp16Legacy final {
           "not supported by QNNPACK");
     }
 #endif // USE_PYTORCH_QNNPACK
+#if AT_MKLDNN_ENABLED()
+    if (ctx.qEngine() == at::QEngine::ONEDNN) {
+      TORCH_CHECK(
+          false,
+          "quantized::linear_prepack_fp16 is currently "
+          "not supported by ONEDNN");
+    }
+#endif // #if AT_MKLDNN_ENABLED()
     TORCH_CHECK(
         false,
         "Didn't find engine for operation quantized::linear_prepack_fp16 ",
@@ -3,6 +3,7 @@
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/packed_params.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
+#include <ATen/native/quantized/cpu/onednn_utils.h>
 #include <torch/custom_class.h>
 #include <torch/library.h>

@@ -74,6 +75,13 @@ std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedLinearWeightFp16::
 }
 #endif // USE_FBGEMM

+#if AT_MKLDNN_ENABLED()
+std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedLinearWeightsOnednn::unpack() {
+  return std::tuple<at::Tensor, c10::optional<at::Tensor>>(
+      orig_weight_, orig_bias_);
+}
+#endif // #if AT_MKLDNN_ENABLED()
+
 namespace at {
 namespace native {
 namespace {
@@ -15,11 +15,13 @@ enum class QEngine : uint8_t {
   NoQEngine = 0,
   FBGEMM = 1,
   QNNPACK = 2,
+  ONEDNN = 3,
 };

 constexpr auto kNoQEngine = QEngine::NoQEngine;
 constexpr auto kFBGEMM = QEngine::FBGEMM;
 constexpr auto kQNNPACK = QEngine::QNNPACK;
+constexpr auto kONEDNN = QEngine::ONEDNN;

 inline std::string toString(QEngine qengine) {
   switch (qengine) {
@@ -29,6 +31,8 @@ inline std::string toString(QEngine qengine) {
       return "FBGEMM";
     case kQNNPACK:
       return "QNNPACK";
+    case kONEDNN:
+      return "ONEDNN";
     default:
       TORCH_CHECK(
           false, "Unrecognized Quantized Engine: ", static_cast<int>(qengine));
@@ -22,6 +22,7 @@ from torch.testing._internal.common_quantized import (
     override_qengines,
     qengine_is_qnnpack,
     qengine_is_fbgemm,
+    qengine_is_onednn,
 )

 # TODO: Once more test files are created, move the contents to a ao folder.
@@ -48,6 +49,9 @@ class TestQuantizedSparseKernels(TestCase):
         # to other higher priority works.
         if qengine_is_qnnpack() and not (row_block_size == 1 and col_block_size == 4):
             return
+        # ONEDNN does not support this yet
+        if qengine_is_onednn():
+            return

         dense_prepack = torch.ops.quantized.linear_prepack
         dense_qlinear = torch.ops.quantized.linear
@ -215,6 +219,10 @@ class TestQuantizedSparseLayers(TestCase):
|
||||||
Y_hat = sqmodel(X_fp32)
|
Y_hat = sqmodel(X_fp32)
|
||||||
self.assertEqual(Y_ref, Y_hat)
|
self.assertEqual(Y_ref, Y_hat)
|
||||||
|
|
||||||
|
# ONEDNN does not support this yet
|
||||||
|
elif qengine_is_onednn():
|
||||||
|
return
|
||||||
|
|
||||||
row_block_size, col_block_size = sqmodel.linear._packed_params._weight_bias()[2:]
|
row_block_size, col_block_size = sqmodel.linear._packed_params._weight_bias()[2:]
|
||||||
assert row_block_size == 1 and col_block_size == 4
|
assert row_block_size == 1 and col_block_size == 4
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -27,6 +27,7 @@ from torch.testing._internal.common_quantized import (
     override_quantized_engine,
     override_qengines,
     qengine_is_qnnpack,
+    qengine_is_onednn,
 )
 from hypothesis import assume, given
 from hypothesis import strategies as st

@@ -99,7 +100,9 @@ class TestStaticQuantizedModule(QuantizationTestCase):
                 zero_points=zero_point_tensor,
                 axis=0, dtype=torch.qint8)
         else:
-            W_q = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8)
+            # ONEDNN only supports symmetric quantization of weight
+            W_zp = 0 if qengine_is_onednn() else 4
+            W_q = torch.quantize_per_tensor(W, 0.1, W_zp, torch.qint8)
 
         X = torch.rand(batch_size, in_features).float()
         X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
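
The pattern repeated throughout these tests is the same: when the active engine is onednn, weights have to be quantized symmetrically, i.e. with a zero point of 0. A standalone sketch of the difference (the qengine_is_onednn helper is the one added by this patch):

    import torch
    from torch.testing._internal.common_quantized import qengine_is_onednn

    W = torch.randn(32, 64)
    # Affine weight quantization (fbgemm/qnnpack tests) may use a non-zero zero point ...
    W_q_affine = torch.quantize_per_tensor(W, scale=0.1, zero_point=4, dtype=torch.qint8)
    # ... while the onednn backend expects a symmetric layout with zero_point == 0.
    W_zp = 0 if qengine_is_onednn() else 4
    W_q = torch.quantize_per_tensor(W, scale=0.1, zero_point=W_zp, dtype=torch.qint8)
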
@@ -434,7 +437,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         X_scale = 1.3
         X_zero_point = 2
         W_scale = [0.5]
-        W_zero_point = [3]
+        W_zero_point = [0] if qengine_is_onednn() else [3]
         Y_scale = 5.0
         Y_zero_point = 4
         if torch.backends.quantized.engine == 'qnnpack':

@@ -501,7 +504,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         X_scale = 1.3
         X_zero_point = 2
         W_scale = [0.5]
-        W_zero_point = [3]
+        W_zero_point = [0] if qengine_is_onednn() else [3]
         Y_scale = 5.0
         Y_zero_point = 4
         # use_fused -> quantized class

@@ -570,7 +573,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         X_scale = 1.3
         X_zero_point = 2
         W_scale = [0.5]
-        W_zero_point = [3]
+        W_zero_point = [0] if qengine_is_onednn() else [3]
         Y_scale = 5.0
         Y_zero_point = 4
         # use_fused -> quantized class
@@ -1200,7 +1203,8 @@ class TestDynamicQuantizedModule(QuantizationTestCase):
     def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_default_observer):
         """test API functionality for nn.quantized.dynamic.Linear"""
         W = torch.rand(out_features, in_features).float()
-        W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8)
+        qscheme = torch.per_tensor_symmetric if qengine_is_onednn() else torch.per_tensor_affine
+        W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8, qscheme=qscheme)
         W_q = torch.quantize_per_tensor(W, W_scale, W_zp, torch.qint8)
         X = torch.rand(batch_size, in_features).float()
         B = torch.rand(out_features).float() if use_bias else None

@@ -1311,8 +1315,8 @@ class TestDynamicQuantizedModule(QuantizationTestCase):
             bias_keys.append(key_name1)
             bias_keys.append(key_name2)
 
-        if not (dtype == torch.float16 and torch.backends.quantized.engine == "qnnpack"):
-            # fp16 dynamic quant is not supported for qnnpack
+        if not (dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn")):
+            # fp16 dynamic quant is not supported for qnnpack or onednn
            x = torch.randn(seq_len, batch, input_size)
            h = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
            c = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)

@@ -1362,8 +1366,8 @@ class TestDynamicQuantizedModule(QuantizationTestCase):
         # instantiated for all engines and dtypes
 
         for dtype in [torch.qint8, torch.float16]:
-            if dtype == torch.float16 and torch.backends.quantized.engine == "qnnpack":
-                # fp16 dynamic quant is not supported for qnnpack
+            if dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn"):
+                # fp16 dynamic quant is not supported for qnnpack or onednn
                 continue
             # Test default instantiation
             seq_len = 4

@@ -1435,8 +1439,8 @@ class TestDynamicQuantizedModule(QuantizationTestCase):
                      'RNNReLU': torch.ops.quantized.quantized_rnn_relu_cell_dynamic}
 
         for rnn_type in cell_dict.keys():
-            if not (dtype == torch.float16 and torch.backends.quantized.engine == "qnnpack"):
-                # fp16 dynamic quant is not supported for qnnpack
+            if not (dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn")):
+                # fp16 dynamic quant is not supported for qnnpack or onednn
                 kwargs = {'input_size': input_size, 'hidden_size': hidden_size, 'bias': bias, 'dtype': dtype}
                 if rnn_type == 'RNNReLU':
                     kwargs['nonlinearity'] = "relu"
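
These dynamic-module tests only gate the fp16 path; int8 dynamic quantization is expected to work under onednn. A minimal sketch of dynamic quantization with the engine switched over (model and shapes are placeholders):

    import torch
    import torch.nn as nn

    torch.backends.quantized.engine = 'onednn'  # assumes an oneDNN-enabled build

    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10)).eval()
    # qint8 dynamic quant is supported; fp16 dynamic quant is not for qnnpack or onednn
    qmodel = torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
    out = qmodel(torch.randn(4, 64))
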
@@ -26,7 +26,10 @@ from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MAC
 from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK
 from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \
     override_quantized_engine, supported_qengines, override_qengines, _snr
-from torch.testing._internal.common_quantized import qengine_is_qnnpack
+from torch.testing._internal.common_quantized import (
+    qengine_is_qnnpack,
+    qengine_is_onednn,
+)
 from torch.ao.quantization import PerChannelMinMaxObserver
 from torch.testing._internal.common_cuda import TEST_CUDNN
 import torch.backends.xnnpack

@@ -2658,7 +2661,7 @@ class TestQuantizedOps(TestCase):
         ]
 
         q_data = []
-        reduce_range = (qengine == 'fbgemm')
+        reduce_range = (qengine in ('fbgemm', 'onednn'))
         for idx, x in enumerate(fp_data):
             scale, zero_point = _calculate_dynamic_qparams(
                 x, dtype=dtype, reduce_range=reduce_range)

@@ -2679,7 +2682,13 @@ class TestQuantizedOps(TestCase):
         mha.eval()
 
         # Prepare
-        mha.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
+        if qengine_is_onednn():
+            # `reduce_range` is False by default for ONEDNN backend
+            # but the test fails on earlier CPUs without VNNI.
+            # So we use a default qconfig with `reduce_range=True` here
+            mha.qconfig = torch.ao.quantization.get_default_qconfig()
+        else:
+            mha.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
         mha_prepared = torch.ao.quantization.prepare(
             mha, prepare_custom_config_dict=custom_module_config)
 
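
The comment in the hunk above captures the workaround: the onednn qconfig keeps the full 8-bit activation range (reduce_range=False), which can overflow the int8 GEMM path on older x86 CPUs without VNNI. A hedged sketch of forcing the reduced range explicitly instead of falling back to the fbgemm default; the observer choice here is illustrative:

    import torch
    from torch.ao.quantization import (
        QConfig,
        HistogramObserver,
        default_per_channel_weight_observer,
    )

    # Assumption for illustration: keep the onednn-style per-channel weight observer,
    # but reduce the activation range for CPUs without VNNI.
    reduced_range_qconfig = QConfig(
        activation=HistogramObserver.with_args(reduce_range=True),
        weight=default_per_channel_weight_observer,
    )
    model = torch.nn.Sequential(torch.nn.Linear(16, 16)).eval()
    model.qconfig = reduced_range_qconfig
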
@@ -2772,7 +2781,7 @@ class TestDynamicQuantizedOps(TestCase):
             (b_value_max - b_value_min) + b_value_min
         ).astype(np.int32) if use_bias else None
 
-        if torch.backends.quantized.engine == 'fbgemm':
+        if torch.backends.quantized.engine in ('fbgemm', 'onednn'):
             avoid_vpmaddubsw_overflow_linear(
                 batch_size,
                 input_channels,

@@ -3009,8 +3018,8 @@ class TestDynamicQuantizedOps(TestCase):
 
         for rnn_type in ['LSTM', 'GRU']:
             for dtype in [torch.qint8, torch.float16]:
-                # Fp16 quantization is not supported for qnnpack
-                if torch.backends.quantized.engine == 'qnnpack' and dtype == torch.float16:
+                # Fp16 quantization is not supported for qnnpack or onednn
+                if torch.backends.quantized.engine in ('qnnpack', 'onednn') and dtype == torch.float16:
                     continue
 
                 if torch.backends.quantized.engine == 'qnnpack':

@@ -3143,8 +3152,8 @@ class TestDynamicQuantizedOps(TestCase):
 
         for rnn_type in ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU']:
             for dtype in [torch.qint8, torch.float16]:
-                # Fp16 quantization is not supported for qnnpack
-                if torch.backends.quantized.engine == 'qnnpack' and dtype == torch.float16:
+                # Fp16 quantization is not supported for qnnpack or onednn
+                if torch.backends.quantized.engine in ('qnnpack', 'onednn') and dtype == torch.float16:
                     continue
 
                 if torch.backends.quantized.engine == 'qnnpack':
@@ -3299,7 +3308,8 @@ class TestQuantizedLinear(TestCase):
 
         for dtype in dtypes:
             # No support for channelwise in xnnpack (int8)
-            if dtype == torch.qint8 and use_channelwise:
+            # ONEDNN does not support qint8
+            if dtype == torch.qint8 and (use_channelwise or qengine_is_onednn()):
                 return
 
             nptype = np_dtype[dtype]

@@ -3322,7 +3332,8 @@ class TestQuantizedLinear(TestCase):
 
             W_scales = np.random.rand(output_channels)
             # xnnpack forces W_zp to 0 when using symmetric quantization
-            if dtype == torch.qint8:
+            # ONEDNN only supports symmetric quantization of weight
+            if dtype == torch.qint8 or qengine_is_onednn():
                 W_zps = np.zeros(output_channels).astype(np.int)
             else:
                 W_zps = np.round(np.random.rand(output_channels) * 100 - 50).astype(np.int)

@@ -3342,7 +3353,7 @@ class TestQuantizedLinear(TestCase):
                 np.random.rand(output_channels) *
                 (b_value_max - b_value_min) + b_value_min
             ).astype(np.int32) if use_bias else None
-            if torch.backends.quantized.engine == 'fbgemm':
+            if torch.backends.quantized.engine in ('fbgemm', 'onednn'):
                 avoid_vpmaddubsw_overflow_linear(
                     batch_size,
                     input_channels,

@@ -3429,6 +3440,13 @@ class TestQuantizedLinear(TestCase):
         qlinear_prepack = torch.ops.quantized.linear_prepack
         qlinear_unpack = torch.ops.quantized.linear_unpack
 
+        # ONEDNN only supports symmetric quantization of weight
+        if qengine_is_onednn():
+            if use_channelwise:
+                W_zps = torch.zeros(output_channels).to(torch.int64)
+            else:
+                W_zp = 0
+
         W = torch.from_numpy(W)
         if use_channelwise:
             W_q = torch.quantize_per_channel(
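
For per-channel weights the same symmetry rule applies: under onednn every channel's zero point must be 0. A short sketch, with hypothetical shapes, of preparing a per-channel weight for linear_prepack:

    import torch

    torch.backends.quantized.engine = 'onednn'  # assumes an oneDNN-enabled build

    out_channels, in_channels = 8, 16
    W = torch.randn(out_channels, in_channels)
    scales = torch.rand(out_channels) * 0.1 + 0.01
    zero_points = torch.zeros(out_channels, dtype=torch.int64)  # symmetric: all zero points are 0

    W_q = torch.quantize_per_channel(W, scales, zero_points, axis=0, dtype=torch.qint8)
    packed = torch.ops.quantized.linear_prepack(W_q, torch.randn(out_channels))
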
@@ -3892,6 +3910,10 @@ class TestQuantizedConv(TestCase):
         if channelwise and transposed:
             # currently transposed conv and per-channel per quantization does not work
             return
+        # ONEDNN only supports symmetric quantization of weight and zero output padding
+        if qengine_is_onednn():
+            W_zero_point = 0
+            o_pads = len(o_pads) * [0] if o_pads is not None else None
         if channelwise:
             if transposed:
                 output_channels = W.shape[1]  # IC OC/G

@@ -4030,6 +4052,9 @@ class TestQuantizedConv(TestCase):
             weight_dtype=torch.qint8,
             output_dtype=torch.quint8,
     ):
+        # ONEDNN only supports symmetric quantization of weight
+        if qengine_is_onednn() and W_zero_point is not None:
+            W_zero_point = len(W_zero_point) * [0]
         (X, W), (X_q, W_q), bias_float = self._make_qconv_tensors(
             batch_size, input_channels_per_group, input_feature_map_shape,
             output_channels_per_group, groups, kernels,

@@ -4512,6 +4537,9 @@ class TestQuantizedConv(TestCase):
                                  use_bias):
         if qengine_is_qnnpack() and (IS_PPC or TEST_WITH_UBSAN):
             return # QNNPACK doesn't support these
+        # ONEDNN does not support output paddings
+        if qengine_is_onednn() and (o_pad_h, o_pad_w) != (0, 0):
+            return
         assume(o_pad_h < stride_h and o_pad_h < dilation)
         assume(o_pad_w < stride_w and o_pad_w < dilation)
 

@@ -4641,6 +4669,9 @@ class TestQuantizedConv(TestCase):
                                  use_bias):
         if qengine_is_qnnpack():
             return # QNNPACK doesn't support this
+        # ONEDNN doesn't support output paddings
+        if qengine_is_onednn() and (o_pad_t, o_pad_h, o_pad_w) != (0, 0, 0):
+            return
         assume(o_pad_t < stride_t or o_pad_t < dilation)
         assume(o_pad_h < stride_h or o_pad_h < dilation)
         assume(o_pad_w < stride_w or o_pad_w < dilation)
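
The conv tests encode two onednn-specific constraints: the weight must be quantized symmetrically and, for transposed convolutions, output_padding must be zero. A hypothetical sketch of packing a conv weight under those constraints (arguments passed positionally as weight, bias, stride, padding, dilation, groups):

    import torch

    torch.backends.quantized.engine = 'onednn'  # assumes an oneDNN-enabled build

    # 2D conv weight layout: (out_channels, in_channels / groups, kH, kW)
    W = torch.randn(16, 3, 3, 3)
    W_q = torch.quantize_per_tensor(W, scale=0.05, zero_point=0, dtype=torch.qint8)  # symmetric
    bias = torch.randn(16)

    packed = torch.ops.quantized.conv2d_prepack(W_q, bias, [1, 1], [1, 1], [1, 1], 1)
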
@@ -184,8 +184,8 @@ def get_default_qconfig(backend='fbgemm', version=0):
     Returns the default PTQ qconfig for the specified backend.
 
     Args:
-      * `backend`: a string representing the target backend. Currently supports `fbgemm`
-        and `qnnpack`.
+      * `backend`: a string representing the target backend. Currently supports `fbgemm`,
+        `qnnpack` and `onednn`.
 
     Return:
         qconfig

@@ -197,6 +197,9 @@ def get_default_qconfig(backend='fbgemm', version=0):
         elif backend == 'qnnpack':
             qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
                               weight=default_weight_observer)
+        elif backend == 'onednn':
+            qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
+                              weight=default_per_channel_weight_observer)
         else:
             qconfig = default_qconfig
     else:
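
With the qconfig defaults in place, post-training static quantization can target onednn through the usual eager-mode prepare/convert flow. A minimal sketch; the module and calibration data are placeholders:

    import torch
    import torch.nn as nn

    torch.backends.quantized.engine = 'onednn'  # assumes an oneDNN-enabled build

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = torch.ao.quantization.QuantStub()
            self.fc = nn.Linear(16, 16)
            self.dequant = torch.ao.quantization.DeQuantStub()

        def forward(self, x):
            return self.dequant(self.fc(self.quant(x)))

    m = M().eval()
    m.qconfig = torch.ao.quantization.get_default_qconfig('onednn')
    prepared = torch.ao.quantization.prepare(m)
    prepared(torch.randn(8, 16))  # calibration pass with placeholder data
    quantized = torch.ao.quantization.convert(prepared)
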
@@ -216,8 +219,8 @@ def get_default_qat_qconfig(backend='fbgemm', version=1):
     Returns the default QAT qconfig for the specified backend.
 
     Args:
-      * `backend`: a string representing the target backend. Currently supports `fbgemm`
-        and `qnnpack`.
+      * `backend`: a string representing the target backend. Currently supports `fbgemm`,
+        `qnnpack` and `onednn`.
       * `version`: version, for backwards compatibility. Can be `None` or `1`.
 
     Return:

@@ -237,6 +240,11 @@ def get_default_qat_qconfig(backend='fbgemm', version=1):
                                                                 quant_max=255,
                                                                 reduce_range=False),
                               weight=default_weight_fake_quant)
+        elif backend == 'onednn':
+            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+                                                                quant_min=0,
+                                                                quant_max=255),
+                              weight=default_per_channel_weight_fake_quant)
         else:
             qconfig = default_qat_qconfig
     # Use the fused observe + fake_quant modules for doing QAT.

@@ -253,6 +261,11 @@ def get_default_qat_qconfig(backend='fbgemm', version=1):
                                                                                  quant_max=255,
                                                                                  reduce_range=False),
                               weight=default_fused_wt_fake_quant)
+        elif backend == 'onednn':
+            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+                                                                                 quant_min=0,
+                                                                                 quant_max=255),
+                              weight=default_fused_per_channel_wt_fake_quant)
         else:
             qconfig = default_qat_qconfig_v2
     else:
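
The QAT variant follows the same pattern. A brief sketch of attaching the new default QAT qconfig; the training loop is elided:

    import torch
    import torch.nn as nn

    torch.backends.quantized.engine = 'onednn'  # assumes an oneDNN-enabled build

    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU()).train()
    model.qconfig = torch.ao.quantization.get_default_qat_qconfig('onednn')
    prepared = torch.ao.quantization.prepare_qat(model)
    # ... fine-tune `prepared` with fake-quant modules inserted ...
    quantized = torch.ao.quantization.convert(prepared.eval())
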
@@ -11,6 +11,8 @@ def _get_qengine_id(qengine: str) -> int:
         ret = 1
     elif qengine == 'qnnpack':
         ret = 2
+    elif qengine == 'onednn':
+        ret = 3
     else:
         ret = -1
         raise RuntimeError("{} is not a valid value for quantized engine".format(qengine))

@@ -18,7 +20,7 @@ def _get_qengine_id(qengine: str) -> int:
 
 # This function should correspond to the enums present in c10/core/QEngine.h
 def _get_qengine_str(qengine: int) -> str:
-    all_engines = {0 : 'none', 1 : 'fbgemm', 2 : 'qnnpack'}
+    all_engines = {0 : 'none', 1 : 'fbgemm', 2 : 'qnnpack', 3 : 'onednn'}
     return all_engines.get(qengine, '*undefined')
 
 class _QEngineProp(object):
@@ -46,9 +46,12 @@ def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
     qx = np.clip(qx, qmin, qmax).astype(qtype)
     return qx
 
-def _calculate_dynamic_qparams(X, dtype, reduce_range=False):
+def _calculate_dynamic_qparams(X, dtype, reduce_range=False, qscheme=torch.per_tensor_affine):
     """Calculate the dynamic quantization parameters (scale, zero_point)
     according to the min and max element of the tensor"""
+    assert qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric)
+    if qscheme == torch.per_tensor_symmetric:
+        assert dtype == torch.qint8
     if isinstance(X, torch.Tensor):
         X = X.numpy()
     if dtype == torch.qint8:

@@ -63,17 +66,25 @@ def _calculate_dynamic_qparams(X, dtype, reduce_range=False):
             qmin, qmax = 0, 255
     min_val = X.min()
     max_val = X.max()
+    is_symmetric = (qscheme == torch.per_tensor_symmetric)
     if min_val == max_val:
         scale = 1.0
         zero_point = 0
     else:
-        max_val = max(max_val, 0.0)
-        min_val = min(min_val, 0.0)
-        scale = (max_val - min_val) / (qmax - qmin)
-        scale = max(scale, np.finfo(np.float32).eps)
-        zero_point = qmin - round(min_val / scale)
-        zero_point = max(qmin, zero_point)
-        zero_point = min(qmax, zero_point)
+        if is_symmetric:
+            max_val = max(max_val, -min_val)
+            min_val = -max_val
+            scale = (max_val - min_val) / (qmax - qmin)
+            scale = max(scale, np.finfo(np.float32).eps)
+            zero_point = 0
+        else:
+            max_val = max(max_val, 0.0)
+            min_val = min(min_val, 0.0)
+            scale = (max_val - min_val) / (qmax - qmin)
+            scale = max(scale, np.finfo(np.float32).eps)
+            zero_point = qmin - round(min_val / scale)
+            zero_point = max(qmin, zero_point)
+            zero_point = min(qmax, zero_point)
     return [float(scale), int(zero_point)]
 
 def _calculate_dynamic_per_channel_qparams(X, dtype):
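
The symmetric branch pins the zero point to 0 and mirrors the range around zero. A quick worked example of the arithmetic, following the helper above with qint8 (so qmin, qmax = -128, 127):

    import numpy as np

    qmin, qmax = -128, 127           # torch.qint8 range
    min_val, max_val = -0.5, 2.0     # hypothetical tensor statistics

    # Symmetric (what the onednn weight path needs): range mirrored around 0, zero_point == 0
    sym_max = max(max_val, -min_val)                      # 2.0
    sym_scale = (sym_max - (-sym_max)) / (qmax - qmin)    # 4.0 / 255 ~= 0.0157
    sym_zp = 0

    # Affine (the previous default): zero_point shifts so [min_val, max_val] is covered exactly
    aff_scale = (max_val - min_val) / (qmax - qmin)       # 2.5 / 255 ~= 0.0098
    aff_zp = int(np.clip(qmin - round(min_val / aff_scale), qmin, qmax))   # -77
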
@@ -165,6 +176,8 @@ def qengine_is_fbgemm():
     return torch.backends.quantized.engine == 'fbgemm'
 def qengine_is_qnnpack():
     return torch.backends.quantized.engine == 'qnnpack'
+def qengine_is_onednn():
+    return torch.backends.quantized.engine == 'onednn'
 
 # Helper function used to simulate per-channel fake-quant against any axis
 def _permute_to_axis_zero(X, axis):
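
Tests typically pair these helpers with the override_quantized_engine context manager so one case can run under every available backend and skip what a backend does not support. A hedged sketch of that pattern:

    import torch
    from torch.testing._internal.common_quantized import (
        override_quantized_engine,
        supported_qengines,
        qengine_is_onednn,
    )

    for engine in supported_qengines:
        with override_quantized_engine(engine):
            if qengine_is_onednn():
                # e.g. skip fp16 dynamic quant or non-zero output padding cases here
                pass
            # ... run the backend-specific checks ...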