Add onednn quant backend (#74137)

Summary:
Resolves the conflicts in https://github.com/pytorch/pytorch/pull/69820, which adds the onednn quantized backend.
jerryzh168, please review. Thanks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74137

Reviewed By: samdow

Differential Revision: D34840477

Pulled By: jerryzh168

fbshipit-source-id: 8aa60981ff7be211a1609644f273b16d18efd425
(cherry picked from commit de76bb808b315e9a2e45d8c5f1c1233a47d669c4)
Authored by Weiwen Xia on 2022-03-14 18:23:08 -07:00; committed by PyTorch MergeBot
parent deae5950ba
commit 060f1b822a
19 changed files with 1007 additions and 37 deletions
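
The commit registers ONEDNN as a third quantized engine next to FBGEMM and QNNPACK. A minimal sketch of selecting it from Python, assuming a build with MKL-DNN enabled; the engine string and the backends API already exist, only the usage below is illustrative:

import torch

# 'onednn' is only listed when PyTorch was built with MKL-DNN support
# (AT_MKLDNN_ENABLED), so check availability before selecting it.
print(torch.backends.quantized.supported_engines)
if 'onednn' in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = 'onednn'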

View File

@ -236,6 +236,10 @@ const std::vector<at::QEngine>& Context::supportedQEngines() {
engines.push_back(at::kNoQEngine);
#endif // C10_MOBILE
#if AT_MKLDNN_ENABLED()
engines.push_back(at::kONEDNN);
#endif
#ifdef USE_FBGEMM
if (fbgemm::fbgemmSupportedCPU()) {
engines.push_back(at::kFBGEMM);

View File

@ -4,6 +4,7 @@
#include <ATen/core/List.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/onednn_utils.h>
#include <c10/util/irange.h>
#include <tuple>
@ -358,6 +359,20 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> deserialize_conv(
);
}
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
if (ctx.qEngine() == at::QEngine::ONEDNN) {
return PackedConvWeightsOnednn<kSpatialDim>::prepack(
weight.value(),
bias,
stride,
padding,
output_padding,
dilation,
groups,
transpose
);
}
#endif // AT_MKLDNN_ENABLED()
TORCH_CHECK(
false,
"Didn't find engine for when deserializing ConvPackedParams: ",

View File

@ -4,6 +4,7 @@
#include <ATen/native/quantized/cpu/embedding_packed_params.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/onednn_utils.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/quantized/QTensorImpl.h>
#include <ATen/quantized/Quantizer.h>
@ -470,6 +471,16 @@ int register_linear_params() {
std::move(weight), std::move(bias));
}
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
if (at::globalContext().qEngine() == at::QEngine::ONEDNN) {
TORCH_CHECK(
weight.scalar_type() == at::kQInt8,
"ONEDNN only supports INT8 bit width currently. Got ",
c10::toString(weight.scalar_type()));
return PackedLinearWeightsOnednn::prepack(
std::move(weight), std::move(bias));
}
#endif // #if AT_MKLDNN_ENABLED()
TORCH_CHECK(false, "Unknown qengine");
})
.def("bias", [](const c10::intrusive_ptr<LinearPackedParamsBase>& self) {

View File

@ -0,0 +1,151 @@
#pragma once
#include <ATen/Config.h>
#if AT_MKLDNN_ENABLED()
#include <ATen/Tensor.h>
#include <ATen/native/quantized/packed_params.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>
struct PackedLinearWeightsOnednn : public LinearPackedParamsBase {
PackedLinearWeightsOnednn(
std::unique_ptr<ideep::tensor> weight,
c10::optional<ideep::tensor> bias,
at::Tensor orig_weight,
c10::optional<at::Tensor> orig_bias)
: weight_(std::move(weight)),
bias_(std::move(bias)),
orig_weight_(std::move(orig_weight)),
orig_bias_(std::move(orig_bias)) {}
std::unique_ptr<ideep::tensor> weight_;
c10::optional<ideep::tensor> bias_;
at::Tensor orig_weight_;
c10::optional<at::Tensor> orig_bias_;
at::Tensor apply(
at::Tensor input,
double output_scale,
int64_t output_zero_point) override;
at::Tensor apply_relu(
at::Tensor input,
double output_scale,
int64_t output_zero_point) override;
at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override;
at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override;
std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;
c10::optional<at::Tensor> bias() override {
return orig_bias_;
}
static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
at::Tensor weight,
c10::optional<at::Tensor> bias);
private:
template <bool ReluFused>
at::Tensor apply_impl(
at::Tensor input,
double output_scale,
int64_t output_zero_point);
template <bool ReluFused>
at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range=false);
};
template <int kSpatialDim = 2>
struct PackedConvWeightsOnednn : public ConvPackedParamsBase<kSpatialDim> {
PackedConvWeightsOnednn(
std::unique_ptr<ideep::tensor> weight,
c10::optional<ideep::tensor> bias,
at::Tensor orig_weight,
c10::optional<at::Tensor> orig_bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
uint8_t transpose)
: weight_(std::move(weight)),
bias_(std::move(bias)),
orig_weight_(std::move(orig_weight)),
orig_bias_(std::move(orig_bias)),
stride_(std::move(stride)),
padding_(std::move(padding)),
output_padding_(std::move(output_padding)),
dilation_(std::move(dilation)),
groups_(groups),
transpose_(transpose) {}
std::unique_ptr<ideep::tensor> weight_;
c10::optional<ideep::tensor> bias_;
at::Tensor orig_weight_;
c10::optional<at::Tensor> orig_bias_;
torch::List<int64_t> stride_;
torch::List<int64_t> padding_;
torch::List<int64_t> output_padding_;
torch::List<int64_t> dilation_;
int64_t groups_;
uint8_t transpose_;
at::Tensor apply(
const at::Tensor& input,
double output_scale,
int64_t output_zero_point) override;
at::Tensor apply_relu(
const at::Tensor& input,
double output_scale,
int64_t output_zero_point) override;
at::Tensor apply_dynamic(
const at::Tensor& input,
bool reduce_range) override;
std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;
static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
at::Tensor weight,
c10::optional<at::Tensor> bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
bool transpose);
torch::List<int64_t> stride() const override {
return stride_;
}
torch::List<int64_t> padding() const override {
return padding_;
}
torch::List<int64_t> output_padding() const override {
return output_padding_;
}
torch::List<int64_t> dilation() const override {
return dilation_;
}
int64_t groups() const override {
return groups_;
}
bool transpose() const override {
return (bool)transpose_;
}
private:
template <bool ReluFused>
at::Tensor apply_impl(
const at::Tensor& input,
double output_scale,
int64_t output_zero_point);
};
#endif // #if AT_MKLDNN_ENABLED()

View File

@ -9,6 +9,8 @@
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/xnnpack_utils.h>
#include <ATen/native/quantized/cpu/onednn_utils.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/quantized/cpu/quant_utils.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <torch/library.h>
@ -1148,6 +1150,177 @@ template at::Tensor PackedConvWeightsQnnp<3>::apply_impl<false>(
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
template <int kSpatialDim>
at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply(
const at::Tensor& input,
double output_scale,
int64_t output_zero_point) {
return apply_impl<false>(input, output_scale, output_zero_point);
}
template <int kSpatialDim>
at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_relu(
const at::Tensor& input,
double output_scale,
int64_t output_zero_point) {
return apply_impl<true>(input, output_scale, output_zero_point);
}
template <int kSpatialDim>
template <bool kReluFused>
at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
const at::Tensor& act,
double output_scale,
int64_t output_zero_point) {
std::string func_name = "quantized::conv";
if (transpose()) {
func_name += "_transpose";
}
func_name += std::to_string(kSpatialDim) + "d";
if (kReluFused) {
func_name += "_relu";
}
ConvDimChecks<kSpatialDim>(
act.ndimension(), stride().size(), padding().size(),
output_padding().size(), dilation().size(), func_name, transpose());
TORCH_CHECK(act.scalar_type() == c10::ScalarType::QUInt8,
func_name, " (ONEDNN): data type of input should be QUint8.");
// src
auto act_contig = act.contiguous(kSpatialDim == 2 ? c10::MemoryFormat::ChannelsLast : c10::MemoryFormat::ChannelsLast3d);
auto src_dims = act_contig.sizes().vec();
auto src_data_type = dnnl::memory::data_type::u8;
auto src_desc = ideep::tensor::desc(src_dims, src_data_type,
kSpatialDim == 2 ? ideep::format_tag::nhwc : ideep::format_tag::ndhwc);
ideep::tensor src;
src.init(src_desc, act_contig.data_ptr());
// weights & bias
ideep::tensor& weights = *(weight_.get());
bool with_bias = bias_.has_value();
const auto& kernel_size = weights.get_dims();
// dst
const std::vector<int64_t>& input_size = src.get_dims();
std::vector<int64_t> output_sizes;
if (transpose()) {
// Prepacked weight format: [o, i, ...]
const int N = act.size(0); // batch size
const int C = act.size(1); // input channels
const int M = weights.get_dim(0); // output channels
const int D = kSpatialDim == 2 ? 1 : act.size(2); // input depth
const int H = act.size(kSpatialDim); // input height
const int W = act.size(kSpatialDim + 1); // input width
const int KH = weights.get_dim(kSpatialDim); // kernel height
const int KW = weights.get_dim(kSpatialDim + 1); // kernel width
const int KD = kSpatialDim == 2 ? 1 : weights.get_dim(2); // kernel depth
TORCH_CHECK(C == groups() * weights.get_dim(1), // weight: [o, i, ...]
func_name, " (ONEDNN): input channel number should be ",
groups() * weights.get_dim(1), ", but got ", C);
auto output_shape = MakeDeConvOutputShape<kSpatialDim>(
N,
M,
kSpatialDim == 2 ? std::vector<int64_t>{H, W} : std::vector<int64_t>{D, H, W},
kSpatialDim == 2 ? std::vector<int64_t>{KH, KW} : std::vector<int64_t>{KD, KH, KW},
stride(),
padding(),
output_padding(),
dilation());
output_sizes = c10::IntArrayRef(output_shape).vec();
} else {
output_sizes = at::native::conv_output_size(input_size, kernel_size, padding().vec(), stride().vec(), dilation().vec());
}
ideep::dims dst_dims = ideep::dims({output_sizes.cbegin(), output_sizes.cend()});
at::Tensor output = at::_empty_affine_quantized(
dst_dims,
device(c10::kCPU)
.dtype(c10::kQUInt8)
.memory_format(kSpatialDim == 2 ?
c10::MemoryFormat::ChannelsLast :
c10::MemoryFormat::ChannelsLast3d),
output_scale,
output_zero_point,
c10::nullopt);
if (output.numel() == 0) {
return output;
}
ideep::tensor dst({dst_dims, ideep::tensor::data_type::u8, {output.strides().cbegin(), output.strides().cend()}},
output.data_ptr());
// Parameters
const ideep::dims& strides = stride().vec();
const ideep::dims& dilates = dilation().vec();
const ideep::dims& padding_l = padding().vec();
const ideep::dims& padding_r = padding().vec();
const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/act.q_scale()); // Scales of ONEDNN and PyTorch are reciprocal
const ideep::scale_t& weights_scales = weights.get_scale();
const ideep::scale_t& dst_scales = ideep::scale_t(weights_scales.size(), 1.0/output_scale); // Scales of ONEDNN and PyTorch are reciprocal
const ideep::zero_point_t src_zero_points = ideep::zero_point_t(1, act.q_zero_point());
const ideep::zero_point_t dst_zero_points = ideep::zero_point_t(1, output_zero_point);
ideep::attr_t op_attr = kReluFused ? ideep::attr_t::fuse_relu() : ideep::attr_t();
op_attr.set_zero_points(DNNL_ARG_SRC, ideep::utils::tensor_zp_mask(1), {DNNL_RUNTIME_S32_VAL}); // runtime src zero point
if (with_bias) {
// Bias might be modified outside (e.g. by quantization bias correction).
// If so, update the prepacked bias as well.
if (bias_.value().get_data_handle() != orig_bias_.value().data_ptr()) {
bias_.value().init(bias_.value().get_desc(), orig_bias_.value().data_ptr());
}
const auto& b = bias_.value();
if (transpose()) {
ideep::convolution_transpose_forward::compute_v2(
src, weights, b, dst_dims, dst,
strides, padding_l, padding_r, dilates,
groups(), src_scales, weights_scales, dst_scales, src_zero_points, dst_zero_points,
op_attr, dnnl::algorithm::deconvolution_direct, dnnl::prop_kind::forward_inference,
ideep::u8s8, ideep::engine::cpu_engine());
} else {
ideep::convolution_forward::compute_v2(
src, weights, b, dst_dims, dst,
strides, dilates, padding_l, padding_r, groups(),
src_scales, weights_scales, dst_scales, src_zero_points, dst_zero_points,
op_attr, dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference,
ideep::u8s8, ideep::engine::cpu_engine());
}
} else {
if (transpose()) {
ideep::convolution_transpose_forward::compute_v2(
src, weights, dst_dims, dst,
strides, padding_l, padding_r, dilates,
groups(), src_scales, weights_scales, dst_scales, src_zero_points, dst_zero_points,
op_attr, dnnl::algorithm::deconvolution_direct, dnnl::prop_kind::forward_inference,
ideep::u8s8, ideep::engine::cpu_engine());
} else {
ideep::convolution_forward::compute_v2(
src, weights, dst_dims, dst,
strides, dilates, padding_l, padding_r, groups(),
src_scales, weights_scales, dst_scales, src_zero_points, dst_zero_points,
op_attr, dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference,
ideep::u8s8, ideep::engine::cpu_engine());
}
}
return output;
}
template at::Tensor PackedConvWeightsOnednn<2>::apply(
const at::Tensor& act,
double output_scale,
int64_t output_zero_point);
template at::Tensor PackedConvWeightsOnednn<2>::apply_relu(
const at::Tensor& act,
double output_scale,
int64_t output_zero_point);
template at::Tensor PackedConvWeightsOnednn<3>::apply(
const at::Tensor& act,
double output_scale,
int64_t output_zero_point);
template at::Tensor PackedConvWeightsOnednn<3>::apply_relu(
const at::Tensor& act,
double output_scale,
int64_t output_zero_point);
#endif // #if AT_MKLDNN_ENABLED()
namespace at {
namespace native {
namespace {
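
For reference, a hedged Python sketch (not part of the diff) of exercising this static conv path through the existing quantized ops, respecting the ONEDNN constraints visible above: quint8 activations, symmetrically quantized qint8 weights (zero point 0), and zero output padding. Shapes and scales are arbitrary.

import torch

torch.backends.quantized.engine = 'onednn'
x_q = torch.quantize_per_tensor(torch.randn(1, 3, 8, 8), scale=0.1, zero_point=64, dtype=torch.quint8)
# Weight zero point must be 0 for the ONEDNN backend (symmetric quantization).
w_q = torch.quantize_per_tensor(torch.randn(4, 3, 3, 3), scale=0.05, zero_point=0, dtype=torch.qint8)
packed = torch.ops.quantized.conv2d_prepack(w_q, None, [1, 1], [0, 0], [1, 1], 1)  # stride, padding, dilation, groups
y_q = torch.ops.quantized.conv2d(x_q, packed, 0.2, 64)  # output_scale, output_zero_point
print(y_q.dequantize().shape)  # torch.Size([1, 4, 6, 6])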

View File

@ -8,6 +8,7 @@
#include <ATen/native/quantized/packed_params.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/onednn_utils.h>
#include <ATen/native/quantized/cpu/quant_utils.h>
#include <c10/util/irange.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
@ -118,6 +119,57 @@ template at::Tensor PackedConvWeightsQnnp<3>::apply_dynamic(
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
template <int kSpatialDim>
at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_dynamic(
const at::Tensor& input,
bool reduce_range) {
// Find min/max of input
float x_max = 0, x_min = 0;
if (input.numel() > 0) {
x_min = input.min().item<float>();
x_max = input.max().item<float>();
}
// Input tensor is quantized as 8-bit unsigned values
static constexpr int precision = 8;
static constexpr bool is_signed = false;
// Calculate scale and zero point for quantization of input tensor
auto q_params = quant_utils::ChooseQuantizationParams(
/*min=*/x_min,
/*max=*/x_max,
/*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
/*qmax=*/
is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
/*preserve_sparsity=*/false,
/*force_scale_power_of_two=*/false,
/*reduce_range=*/reduce_range);
// Quantize input
at::Tensor q_input = at::quantize_per_tensor(
input, q_params.scale, q_params.zero_point, c10::kQUInt8);
at::Tensor out =
apply_impl<false>(q_input, q_params.scale, q_params.zero_point);
// TODO: Modify ideep to allow fp32 input & output
// to avoid explicit `quantize - dequantize`
return at::dequantize(out);
}
template at::Tensor PackedConvWeightsOnednn<2>::apply_dynamic(
const at::Tensor& input,
bool reduce_range);
template at::Tensor PackedConvWeightsOnednn<3>::apply_dynamic(
const at::Tensor& input,
bool reduce_range);
#endif // AT_MKLDNN_ENABLED()
namespace at {
namespace native {
namespace {

View File

@ -6,6 +6,7 @@
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/onednn_utils.h>
#include <ATen/native/quantized/cpu/quant_utils.h>
#include <ATen/quantized/Quantizer.h>
#include <torch/library.h>
@ -314,6 +315,165 @@ c10::intrusive_ptr<ConvPackedParamsBase<2>> PackedConvWeightsQnnp<
bool transpose);
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
template <int kSpatialDim>
c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightsOnednn<
kSpatialDim>::
prepack(
at::Tensor weight,
c10::optional<at::Tensor> bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
bool transpose) {
TORCH_CHECK(
weight.ndimension() == kSpatialDim + 2,
"Weights are expected to have ", kSpatialDim + 2, " dimensions");
TORCH_CHECK(
stride.size() == kSpatialDim,
"stride should contain ", kSpatialDim, " elements for ",
kSpatialDim, "D convolution.");
TORCH_CHECK(
padding.size() == kSpatialDim,
"Specify front/top/left padding only. "
"end/bottom/right padding assumed to be equal to front/top/left");
TORCH_CHECK(
!transpose || output_padding.size() == kSpatialDim,
"quantized::conv_prepack: Specify top/left output padding "
"only. bottom/right padding assumed to be equal to top/left");
TORCH_CHECK(
dilation.size() == kSpatialDim,
"dilation should contain ", kSpatialDim, " elements for ",
kSpatialDim, "D convolution.");
TORCH_CHECK(
!transpose || std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i==0; }),
"quantized::conv_prepack: ONEDNN only supports zero output_padding.");
// Weight
// Format: [OC IC//group KH KW] for conv; [IC OC//group KH KW] for deconv
auto dims = weight.sizes().vec();
auto strides = stride.vec();
auto padding_l = padding.vec();
auto padding_r = padding.vec();
auto dilates = dilation.vec();
auto op_attr = ideep::attr_t();
std::vector<int32_t> wgt_zero_points;
ideep::scale_t wgt_scales;
const int output_channels = transpose ? weight.size(1) * groups
: weight.size(0);
const auto qtype = weight.qscheme();
if (qtype == c10::kPerTensorAffine) {
TORCH_CHECK(
weight.q_zero_point()==0,
"quantized::qconv_prepack: ONEDNN only supports symmetric quantization of weight,"
" whose zero point must be 0.");
wgt_zero_points = std::vector<int32_t>(1, weight.q_zero_point());
wgt_scales = ideep::scale_t(1, 1.0/weight.q_scale()); // Scales of ONEDNN and PyTorch are reciprocal
} else if (qtype == c10::kPerChannelAffine) {
TORCH_CHECK(
!transpose,
"Per Channel Quantization is currently disabled for transposed conv");
wgt_zero_points.resize(output_channels);
wgt_scales.resize(output_channels);
for (int i = 0; i < output_channels; ++i) {
wgt_zero_points[i] = weight.q_per_channel_zero_points()[i].item<int32_t>();
TORCH_CHECK(
wgt_zero_points[i]==0,
"quantized::qconv_prepack: ONEDNN only supports symmetric quantization of weight,"
" whose zero point must be 0.");
wgt_scales[i] = 1.0f / weight.q_per_channel_scales()[i].item<float>(); // Scales of ONEDNN and PyTorch are reciprocal
}
} else {
TORCH_CHECK(false, "Unsupported qscheme: ", toString(qtype));
}
// Set runtime src zero point
auto src_zero_point = {DNNL_RUNTIME_S32_VAL};
op_attr.set_zero_points(DNNL_ARG_SRC,
ideep::utils::tensor_zp_mask(src_zero_point.size()),
src_zero_point);
at::Tensor weight_copy;
ideep::tensor::desc w_desc;
ideep::dims dims_iohw, dims_giohw;
ideep::tag w_tag = ideep::tag::any;
const bool with_groups = groups > 1;
if (transpose) {
w_desc = ideep::convolution_transpose_forward::expected_weights_desc(
dims, dnnl::memory::data_type::s8,
strides, padding_l, padding_r, dilates, groups,
dnnl::algorithm::deconvolution_direct, dnnl::prop_kind::forward_inference,
ideep::dims(), op_attr);
// convolution_transpose_forward::expected_weights_desc() gives format [i, o, ...],
// but ONEDNN requires [o, i, ...] for computation
dims_iohw = w_desc.get_dims();
dims_giohw = with_groups ? ideep::utils::group_dims(dims_iohw, groups) : dims_iohw;
std::vector<int64_t> perms(dims_giohw.size(), 0); // for permutation of weight
std::iota(perms.begin(), perms.end(), 0);
w_desc = w_desc.transpose(with_groups, with_groups + 1);
std::swap(perms[with_groups], perms[with_groups + 1]);
weight_copy = weight.reshape(dims_giohw).permute(c10::IntArrayRef(perms)).clone();
} else {
w_desc = ideep::convolution_forward::expected_weights_desc(
dims, dnnl::memory::data_type::s8,
strides, padding_l, padding_r, dilates, groups,
dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference,
dnnl::memory::data_type::u8, ideep::dims(), op_attr);
weight_copy = weight.clone();
}
if (with_groups) {
w_tag = kSpatialDim == 2 ? ideep::tag::goihw : ideep::tag::goidhw;
} else {
w_tag = kSpatialDim == 2 ? ideep::tag::oihw : ideep::tag::oidhw;
}
ideep::dims w_dims = with_groups ? ideep::utils::group_dims(w_desc.get_dims(), groups)
: w_desc.get_dims();
ideep::tensor wgt = ideep::tensor(
ideep::tensor::desc({w_dims, dnnl::memory::data_type::s8, w_tag}, groups),
weight_copy.data_ptr());
wgt.set_scale(wgt_scales); // Scales are needed for feed_from().
ideep::tensor exp_wgt;
exp_wgt.init(w_desc);
exp_wgt.set_scale(wgt_scales); // Also for feed_from()
exp_wgt.feed_from(wgt, transpose); // expect wgt to be in [OC IC KH KW] format
ideep::tensor * packed_weight_p = new ideep::tensor(exp_wgt);
packed_weight_p->set_scale(wgt_scales);
packed_weight_p->set_zero_point(wgt_zero_points);
std::unique_ptr<ideep::tensor> weight_ptr(packed_weight_p);
// Bias
c10::optional<ideep::tensor> onednn_bias{c10::nullopt};
if (bias.has_value()) {
at::Tensor bias_vec = bias.value();
TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");
TORCH_CHECK(
bias_vec.size(0) == output_channels,
"bias should have K elements: " + std::to_string(output_channels));
auto bias_desc = ideep::tensor::desc(bias.value().sizes().vec(), dnnl::memory::data_type::f32);
ideep::tensor packed_bias;
packed_bias.init(bias_desc, bias.value().data_ptr());
onednn_bias = c10::optional<ideep::tensor>(packed_bias);
}
auto ret_ptr = c10::make_intrusive<PackedConvWeightsOnednn<kSpatialDim>>(
PackedConvWeightsOnednn<kSpatialDim>{
std::move(weight_ptr),
onednn_bias,
weight,
bias,
stride,
padding,
output_padding,
dilation,
groups,
transpose
});
return ret_ptr;
}
template struct PackedConvWeightsOnednn<2>;
template struct PackedConvWeightsOnednn<3>;
#endif // #if AT_MKLDNN_ENABLED()
namespace at {
namespace native {
namespace {
@ -377,6 +537,14 @@ class QConvPackWeightInt8 final {
}
#endif
#if AT_MKLDNN_ENABLED()
if (ctx.qEngine() == at::QEngine::ONEDNN) {
return PackedConvWeightsOnednn<kSpatialDim>::prepack(
weight, bias, stride, padding, output_padding, dilation, groups,
transpose);
}
#endif
TORCH_CHECK(
false,
"Didn't find engine for operation quantized::conv2d_prepack ",
@ -438,8 +606,6 @@ class QConv1dPackWeightInt8 final {
}
#endif
#ifdef USE_PYTORCH_QNNPACK
if (ctx.qEngine() == at::QEngine::QNNPACK) {
return PackedConvWeightsQnnp<2>::prepack(
@ -447,6 +613,15 @@ class QConv1dPackWeightInt8 final {
transpose);
}
#endif
#if AT_MKLDNN_ENABLED()
if (ctx.qEngine() == at::QEngine::ONEDNN) {
return PackedConvWeightsOnednn<2>::prepack(
weight, bias, stride, padding, output_padding, dilation, groups,
transpose);
}
#endif
TORCH_CHECK(
false,
"Didn't find engine for operation quantized::conv1d_prepack ",

View File

@ -5,6 +5,7 @@
#include <torch/library.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/onednn_utils.h>
#include <ATen/native/quantized/cpu/quant_utils.h>
#include <ATen/native/quantized/packed_params.h>
@ -120,6 +121,20 @@ template std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightsQnnp
3>::unpack();
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
template <int kSpatialDim>
std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightsOnednn<
kSpatialDim>::unpack() {
return std::tuple<at::Tensor, c10::optional<at::Tensor>>(
orig_weight_, orig_bias_);
}
template std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightsOnednn<
2>::unpack();
template std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightsOnednn<
3>::unpack();
#endif // #if AT_MKLDNN_ENABLED()
namespace at {
namespace native {
namespace {
@ -154,6 +169,12 @@ class QConvUnpackWeightsInt8 final {
}
#endif
#if AT_MKLDNN_ENABLED()
if (ctx.qEngine() == at::QEngine::ONEDNN) {
return packed_weight->unpack();
}
#endif
TORCH_CHECK(
false,
"Didn't find engine for operation quantized::conv2d_unpack ",
@ -185,6 +206,15 @@ class QConv1dUnpackWeightsInt8 final {
}
#endif
#if AT_MKLDNN_ENABLED()
if (ctx.qEngine() == at::QEngine::ONEDNN) {
std::tie(weight, bias) = packed_weight->unpack();
at::Tensor new_weight = weight.clone();
new_weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2);
return std::tuple<at::Tensor, c10::optional<at::Tensor>>(new_weight, bias);
}
#endif
TORCH_CHECK(
false,
"Didn't find engine for operation quantized::conv1d_unpack ",

View File

@ -5,6 +5,7 @@
#include <ATen/native/quantized/packed_params.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/xnnpack_utils.h>
#include <ATen/native/quantized/cpu/onednn_utils.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <torch/custom_class.h>
#include <torch/library.h>
@ -617,6 +618,81 @@ at::Tensor PackedLinearWeightsQnnp::apply_relu(
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
template <bool ReluFused>
at::Tensor PackedLinearWeightsOnednn::apply_impl(
at::Tensor input,
double output_scale,
int64_t output_zero_point) {
const int64_t dim = input.dim();
TORCH_CHECK(
dim != 0,
"qlinear (ONEDNN): input dim should be at least 1, but got 0");
TORCH_CHECK(input.scalar_type() == c10::ScalarType::QUInt8,
"qlinear (ONEDNN): data type of input should be QUint8.");
auto input_contig = input.expect_contiguous();
auto& w = *(weight_.get());
auto K = input.size(dim - 1), M = input.numel() / K, N = w.get_dim(1);
auto input_dims = {M, K};
auto input_data_type = dnnl::memory::data_type::u8;
auto input_desc = ideep::tensor::desc(input_dims, input_data_type);
ideep::attr_t op_attr = ReluFused ? ideep::attr_t::fuse_relu() : ideep::attr_t();
ideep::tensor x(input_desc, input_contig->data_ptr<c10::quint8>());
auto dst_dims = {M, N};
const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/input.q_scale());
const ideep::scale_t& weights_scales = w.get_scale();
const ideep::scale_t& dst_scales = ideep::scale_t(1, 1.0/output_scale); // Scales of ONEDNN and PyTorch are reciprocal
const ideep::zero_point_t& src_zero_point = ideep::zero_point_t(1, input.q_zero_point());
const ideep::zero_point_t& dst_zero_point = ideep::zero_point_t(1, output_zero_point);
// Compute: Use ideep::matmul_forward to support asymmetric quantization
// Allocate output Tensor
at::Tensor output = at::_empty_affine_quantized(
dst_dims,
at::device(c10::kCPU).dtype(c10::kQUInt8),
output_scale,
output_zero_point);
if (output.numel() == 0) {
return output;
}
ideep::tensor y({dst_dims, ideep::tensor::data_type::u8, {output.strides().cbegin(), output.strides().cend()}},
output.data_ptr());
if (bias_.has_value()) {
// Bias might be modified outside (e.g. by quantization bias correction).
// If so, update the prepacked bias as well.
if (bias_.value().get_data_handle() != orig_bias_.value().data_ptr()) {
bias_.value().init(bias_.value().get_desc(), orig_bias_.value().data_ptr());
}
const auto& b = bias_.value();
ideep::matmul_forward::compute_v2(x, w, b, y, 1.0f, 1.0f, src_scales, weights_scales, dst_scales,
src_zero_point, dst_zero_point, op_attr);
} else {
ideep::matmul_forward::compute_v2(x, w, y, 1.0f, 1.0f, src_scales, weights_scales, dst_scales,
src_zero_point, dst_zero_point, op_attr);
}
auto out_sizes = input.sizes().vec();
out_sizes.back() = N;
if (output.sizes().vec() == out_sizes)
return output;
return output.reshape(out_sizes);
}
at::Tensor PackedLinearWeightsOnednn::apply(
at::Tensor input,
double output_scale,
int64_t output_zero_point) {
return apply_impl<false>(std::move(input), output_scale, output_zero_point);
}
at::Tensor PackedLinearWeightsOnednn::apply_relu(
at::Tensor input,
double output_scale,
int64_t output_zero_point) {
return apply_impl<true>(std::move(input), output_scale, output_zero_point);
}
#endif // #if AT_MKLDNN_ENABLED()
namespace at {
namespace native {
namespace {
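
A hedged Python sketch (not part of the diff) of the corresponding static linear call sequence; the weight must be qint8 with zero point 0 (symmetric), as the ONEDNN prepack path requires. Values are illustrative.

import torch

torch.backends.quantized.engine = 'onednn'
x_q = torch.quantize_per_tensor(torch.randn(4, 16), scale=0.2, zero_point=10, dtype=torch.quint8)
w_q = torch.quantize_per_tensor(torch.randn(8, 16), scale=0.1, zero_point=0, dtype=torch.qint8)
packed = torch.ops.quantized.linear_prepack(w_q, torch.randn(8))
y_q = torch.ops.quantized.linear(x_q, packed, 0.5, 0)  # output_scale, output_zero_point
print(y_q.dequantize().shape)  # torch.Size([4, 8])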

View File

@ -4,6 +4,7 @@
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/packed_params.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/onednn_utils.h>
#include <ATen/native/quantized/cpu/quant_utils.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <torch/library.h>
@ -463,6 +464,99 @@ void PackedLinearWeightFp16::set_bias(c10::optional<at::Tensor> bias) {
#endif // USE_FBGEMM
#if AT_MKLDNN_ENABLED()
template <bool ReluFused>
at::Tensor PackedLinearWeightsOnednn::apply_dynamic_impl(
at::Tensor input,
bool reduce_range) {
// Dynamic: fp32 * int8 -> fp32
using at::Tensor;
TORCH_CHECK(
input.dim() >= 2,
"The dimension of input tensor should be larger than or equal to 2");
TORCH_CHECK(input.scalar_type() == c10::ScalarType::Float,
"qlinear_dynamic (ONEDNN): data type of input should be float.");
// Input -> uint8
auto input_contig = input.contiguous();
const int64_t dim = input.dim();
auto input_reshaped =
dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)});
auto input_dims = input_reshaped.sizes().vec();
auto input_data_type = dnnl::memory::data_type::f32;
auto input_desc = ideep::tensor::desc(input_dims, input_data_type);
ideep::attr_t op_attr = ReluFused ? ideep::attr_t::fuse_relu() : ideep::attr_t();
ideep::tensor x;
x.init(input_desc, input_contig.data_ptr());
// Find quantization parameters
float x_max = 0, x_min = 0;
if (input.numel() > 0) {
x_min = input_contig.min().item<float>();
x_max = input_contig.max().item<float>();
}
const int precision = 8;
auto q_params = quant_utils::ChooseQuantizationParams(
/*min=*/x_min,
/*max=*/x_max,
/*qmin=*/0,
/*qmax=*/(1 << precision) - 1,
/*preserve_sparsity=*/false,
/*force_scale_power_of_two=*/false,
/*reduce_range=*/reduce_range);
const std::vector<int32_t>& src_zero_point = std::vector<int32_t>(1, q_params.zero_point);
// weights, dst
auto w = *(weight_.get());
auto dst_dims = {x.get_dim(0), w.get_dim(1)};
const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/q_params.scale);
const ideep::scale_t& weights_scales = w.get_scale();
// Compute -> f32
// Use ideep::matmul_forward instead of ideep::inner_product_forward,
// since the latter does not support asymmetric quantization
// Allocate output Tensor
at::Tensor output = at::empty(dst_dims, input.options().dtype(at::kFloat));
if (output.numel() == 0) return output;
ideep::tensor y({dst_dims, ideep::tensor::data_type::f32,
{output.strides().cbegin(), output.strides().cend()}},
output.data_ptr());
if (bias_.has_value()) {
// Bias might be modified outside (e.g. by quantization bias correction).
// If so, update the prepacked bias as well.
if (bias_.value().get_data_handle() != orig_bias_.value().data_ptr()) {
bias_.value().init(bias_.value().get_desc(), orig_bias_.value().data_ptr());
}
const ideep::tensor b = bias_.value();
ideep::matmul_forward::compute_v2(x, w, b, y, 1.0f, 1.0f,
src_scales, weights_scales, ideep::scale_t(),
src_zero_point, ideep::zero_point_t(), op_attr);
} else {
ideep::matmul_forward::compute_v2(x, w, y, 1.0f, 1.0f,
src_scales, weights_scales, ideep::scale_t(),
src_zero_point, ideep::zero_point_t(), op_attr);
}
auto out_sizes = input.sizes().vec();
out_sizes.back() = w.get_dim(1);
if (output.sizes().vec() == out_sizes)
return output;
return output.reshape(out_sizes);
}
at::Tensor PackedLinearWeightsOnednn::apply_dynamic(
at::Tensor input,
bool reduce_range) {
return apply_dynamic_impl</*ReluFused=*/false>(
std::move(input), reduce_range);
}
at::Tensor PackedLinearWeightsOnednn::apply_dynamic_relu(
at::Tensor input,
bool reduce_range) {
return apply_dynamic_impl</*ReluFused=*/true>(
std::move(input), reduce_range);
}
#endif // #if AT_MKLDNN_ENABLED()
namespace at {
namespace native {
namespace {
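
The apply_dynamic path above takes fp32 input, picks activation qparams from the batch min/max, and returns fp32 output. A hedged sketch of driving it through the standard dynamic-quantization API; the module and shapes are placeholders.

import torch
import torch.nn as nn

torch.backends.quantized.engine = 'onednn'
model = nn.Sequential(nn.Linear(16, 8)).eval()
# Weights are quantized ahead of time (qint8); activations are quantized
# on the fly inside each forward call, then the output is dequantized.
dq_model = torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
out = dq_model(torch.randn(4, 16))  # fp32 in, fp32 out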

View File

@ -4,6 +4,7 @@
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/packed_params.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/onednn_utils.h>
#include <ATen/native/quantized/cpu/quant_utils.h>
#include <ATen/quantized/Quantizer.h>
#include <torch/custom_class.h>
@ -194,6 +195,80 @@ c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightFp16::prepack(
}
#endif // USE_FBGEMM
#if AT_MKLDNN_ENABLED()
c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightsOnednn::prepack(
at::Tensor weight,
c10::optional<at::Tensor> bias) {
TORCH_CHECK(
weight.dim() == 2,
"The weight tensor for quantized::linear_prepack (onednn) should"
" be 2-dimensional.");
// Weight
std::vector<int64_t> dims = weight.sizes().vec();
auto N = weight.size(0);
std::vector<int32_t> wgt_zero_points;
ideep::scale_t wgt_scales;
const auto qtype = weight.qscheme();
if (qtype == c10::kPerTensorAffine) {
TORCH_CHECK(
weight.q_zero_point() == 0,
"quantized::linear_prepack: ONEDNN only supports symmetric quantization of weight,"
" whose zero point must be 0, but got ", weight.q_zero_point());
wgt_zero_points = std::vector<int32_t>(1, weight.q_zero_point());
wgt_scales = ideep::scale_t(1, 1.0/weight.q_scale()); // Scales of ONEDNN and PyTorch are reciprocal
} else if (qtype == c10::kPerChannelAffine) {
wgt_zero_points.resize(N);
wgt_scales.resize(N);
for (int i = 0; i < N; ++i) {
wgt_zero_points[i] = weight.q_per_channel_zero_points()[i].item<int32_t>();
TORCH_CHECK(
wgt_zero_points[i] == 0,
"quantized::linear_prepack: ONEDNN only supports symmetric quantization of weight,"
" whose zero point must be 0, but got ", wgt_zero_points[i], ", at index ", i);
wgt_scales[i] = 1.0f / weight.q_per_channel_scales()[i].item<float>(); // Scales of ONEDNN and PyTorch are reciprocal
}
} else {
TORCH_CHECK(false, "Unsupported qscheme: ", toString(qtype));
}
// Prepack weight
auto weight_copy = weight.clone();
ideep::tensor wgt = ideep::tensor({dims, dnnl::memory::data_type::s8}, weight_copy.data_ptr());
wgt.transpose_(0, 1); // ONEDNN requires transposed weight
auto w_desc = ideep::matmul_forward::expected_weights_desc(wgt.get_dims(), dnnl::memory::data_type::s8,
dnnl::memory::data_type::u8);
ideep::tensor exp_wgt(w_desc);
exp_wgt.feed_from(wgt);
ideep::tensor * packed_weight_p = new ideep::tensor(exp_wgt);
packed_weight_p->set_scale(wgt_scales);
packed_weight_p->set_zero_point(wgt_zero_points);
std::unique_ptr<ideep::tensor> weight_ptr(packed_weight_p);
// Bias
c10::optional<ideep::tensor> onednn_bias{c10::nullopt};
if (bias.has_value()) {
auto& b = bias.value();
auto bias_size = b.sizes().vec();
bias_size.insert(bias_size.begin(), 1);
TORCH_CHECK(
bias_size[1] == weight_ptr->get_dim(1),
"bias should have N elements: ",
std::to_string(weight_ptr->get_dim(1)),
", but got ", bias_size[1]);
auto bias_desc = ideep::tensor::desc(bias_size, dnnl::memory::data_type::f32);
ideep::tensor packed_bias;
packed_bias.init(bias_desc, b.data_ptr());
onednn_bias = c10::optional<ideep::tensor>(packed_bias);
}
auto ret_ptr = c10::make_intrusive<PackedLinearWeightsOnednn>(
PackedLinearWeightsOnednn{
std::move(weight_ptr),
onednn_bias,
weight,
bias});
return ret_ptr;
}
#endif // #if AT_MKLDNN_ENABLED()
namespace at {
namespace native {
@ -224,6 +299,11 @@ class QLinearPackWeightInt8 final {
std::move(weight), std::move(bias));
}
#endif
#if AT_MKLDNN_ENABLED()
if (ctx.qEngine() == at::QEngine::ONEDNN) {
return PackedLinearWeightsOnednn::prepack(std::move(weight), std::move(bias));
}
#endif // #if AT_MKLDNN_ENABLED()
TORCH_CHECK(
false,
"Didn't find engine for operation quantized::linear_prepack ",
@ -254,6 +334,14 @@ class QLinearPackWeightFp16 final {
"not supported by QNNPACK");
}
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
if (ctx.qEngine() == at::QEngine::ONEDNN) {
TORCH_CHECK(
false,
"quantized::linear_prepack_fp16 is currently "
"not supported by ONEDNN");
}
#endif // #if AT_MKLDNN_ENABLED()
TORCH_CHECK(
false,
"Didn't find engine for operation quantized::linear_prepack_fp16 ",
@ -287,6 +375,16 @@ class QLinearPackWeightInt8Legacy final {
return cpp_custom_type_hack::create(std::move(wrapped), options);
}
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
if (ctx.qEngine() == at::QEngine::ONEDNN) {
auto prepacked =
PackedLinearWeightsOnednn::prepack(std::move(weight), std::move(bias));
auto wrapped =
std::make_unique<c10::intrusive_ptr<LinearPackedParamsBase>>(
std::move(prepacked));
return cpp_custom_type_hack::create(std::move(wrapped), options);
}
#endif // #if AT_MKLDNN_ENABLED()
TORCH_CHECK(
false,
"Didn't find engine for operation quantized::linear_prepack ",
@ -317,6 +415,14 @@ class QLinearPackWeightFp16Legacy final {
"not supported by QNNPACK");
}
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
if (ctx.qEngine() == at::QEngine::ONEDNN) {
TORCH_CHECK(
false,
"quantized::linear_prepack_fp16 is currently "
"not supported by ONEDNN");
}
#endif // #if AT_MKLDNN_ENABLED()
TORCH_CHECK(
false,
"Didn't find engine for operation quantized::linear_prepack_fp16 ",

View File

@ -3,6 +3,7 @@
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/packed_params.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/onednn_utils.h>
#include <torch/custom_class.h>
#include <torch/library.h>
@ -74,6 +75,13 @@ std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedLinearWeightFp16::
}
#endif // USE_FBGEMM
#if AT_MKLDNN_ENABLED()
std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedLinearWeightsOnednn::unpack() {
return std::tuple<at::Tensor, c10::optional<at::Tensor>>(
orig_weight_, orig_bias_);
}
#endif // #if AT_MKLDNN_ENABLED()
namespace at {
namespace native {
namespace {

View File

@ -15,11 +15,13 @@ enum class QEngine : uint8_t {
NoQEngine = 0,
FBGEMM = 1,
QNNPACK = 2,
ONEDNN = 3,
};
constexpr auto kNoQEngine = QEngine::NoQEngine;
constexpr auto kFBGEMM = QEngine::FBGEMM;
constexpr auto kQNNPACK = QEngine::QNNPACK;
constexpr auto kONEDNN = QEngine::ONEDNN;
inline std::string toString(QEngine qengine) {
switch (qengine) {
@ -29,6 +31,8 @@ inline std::string toString(QEngine qengine) {
return "FBGEMM";
case kQNNPACK:
return "QNNPACK";
case kONEDNN:
return "ONEDNN";
default:
TORCH_CHECK(
false, "Unrecognized Quantized Engine: ", static_cast<int>(qengine));

View File

@ -22,6 +22,7 @@ from torch.testing._internal.common_quantized import (
override_qengines,
qengine_is_qnnpack,
qengine_is_fbgemm,
qengine_is_onednn,
)
# TODO: Once more test files are created, move the contents to a ao folder.
@ -48,6 +49,9 @@ class TestQuantizedSparseKernels(TestCase):
# to other higher priority works.
if qengine_is_qnnpack() and not (row_block_size == 1 and col_block_size == 4):
return
# ONEDNN does not support this yet
if qengine_is_onednn():
return
dense_prepack = torch.ops.quantized.linear_prepack
dense_qlinear = torch.ops.quantized.linear
@ -215,6 +219,10 @@ class TestQuantizedSparseLayers(TestCase):
Y_hat = sqmodel(X_fp32)
self.assertEqual(Y_ref, Y_hat)
# ONEDNN does not support this yet
elif qengine_is_onednn():
return
row_block_size, col_block_size = sqmodel.linear._packed_params._weight_bias()[2:]
assert row_block_size == 1 and col_block_size == 4

View File

@ -27,6 +27,7 @@ from torch.testing._internal.common_quantized import (
override_quantized_engine,
override_qengines,
qengine_is_qnnpack,
qengine_is_onednn,
)
from hypothesis import assume, given
from hypothesis import strategies as st
@ -99,7 +100,9 @@ class TestStaticQuantizedModule(QuantizationTestCase):
zero_points=zero_point_tensor,
axis=0, dtype=torch.qint8)
else:
W_q = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8)
# ONEDNN only supports symmetric quantization of weight
W_zp = 0 if qengine_is_onednn() else 4
W_q = torch.quantize_per_tensor(W, 0.1, W_zp, torch.qint8)
X = torch.rand(batch_size, in_features).float()
X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
@ -434,7 +437,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
X_scale = 1.3
X_zero_point = 2
W_scale = [0.5]
W_zero_point = [3]
W_zero_point = [0] if qengine_is_onednn() else [3]
Y_scale = 5.0
Y_zero_point = 4
if torch.backends.quantized.engine == 'qnnpack':
@ -501,7 +504,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
X_scale = 1.3
X_zero_point = 2
W_scale = [0.5]
W_zero_point = [3]
W_zero_point = [0] if qengine_is_onednn() else [3]
Y_scale = 5.0
Y_zero_point = 4
# use_fused -> quantized class
@ -570,7 +573,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
X_scale = 1.3
X_zero_point = 2
W_scale = [0.5]
W_zero_point = [3]
W_zero_point = [0] if qengine_is_onednn() else [3]
Y_scale = 5.0
Y_zero_point = 4
# use_fused -> quantized class
@ -1200,7 +1203,8 @@ class TestDynamicQuantizedModule(QuantizationTestCase):
def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_default_observer):
"""test API functionality for nn.quantized.dynamic.Linear"""
W = torch.rand(out_features, in_features).float()
W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8)
qscheme = torch.per_tensor_symmetric if qengine_is_onednn() else torch.per_tensor_affine
W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8, qscheme=qscheme)
W_q = torch.quantize_per_tensor(W, W_scale, W_zp, torch.qint8)
X = torch.rand(batch_size, in_features).float()
B = torch.rand(out_features).float() if use_bias else None
@ -1311,8 +1315,8 @@ class TestDynamicQuantizedModule(QuantizationTestCase):
bias_keys.append(key_name1)
bias_keys.append(key_name2)
if not (dtype == torch.float16 and torch.backends.quantized.engine == "qnnpack"):
# fp16 dynamic quant is not supported for qnnpack
if not (dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn")):
# fp16 dynamic quant is not supported for qnnpack or onednn
x = torch.randn(seq_len, batch, input_size)
h = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
c = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
@ -1362,8 +1366,8 @@ class TestDynamicQuantizedModule(QuantizationTestCase):
# instantiated for all engines and dtypes
for dtype in [torch.qint8, torch.float16]:
if dtype == torch.float16 and torch.backends.quantized.engine == "qnnpack":
# fp16 dynamic quant is not supported for qnnpack
if dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn"):
# fp16 dynamic quant is not supported for qnnpack or onednn
continue
# Test default instantiation
seq_len = 4
@ -1435,8 +1439,8 @@ class TestDynamicQuantizedModule(QuantizationTestCase):
'RNNReLU': torch.ops.quantized.quantized_rnn_relu_cell_dynamic}
for rnn_type in cell_dict.keys():
if not (dtype == torch.float16 and torch.backends.quantized.engine == "qnnpack"):
# fp16 dynamic quant is not supported for qnnpack
if not (dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn")):
# fp16 dynamic quant is not supported for qnnpack or onednn
kwargs = {'input_size': input_size, 'hidden_size': hidden_size, 'bias': bias, 'dtype': dtype}
if rnn_type == 'RNNReLU':
kwargs['nonlinearity'] = "relu"

View File

@ -26,7 +26,10 @@ from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MAC
from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK
from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \
override_quantized_engine, supported_qengines, override_qengines, _snr
from torch.testing._internal.common_quantized import qengine_is_qnnpack
from torch.testing._internal.common_quantized import (
qengine_is_qnnpack,
qengine_is_onednn,
)
from torch.ao.quantization import PerChannelMinMaxObserver
from torch.testing._internal.common_cuda import TEST_CUDNN
import torch.backends.xnnpack
@ -2658,7 +2661,7 @@ class TestQuantizedOps(TestCase):
]
q_data = []
reduce_range = (qengine == 'fbgemm')
reduce_range = (qengine in ('fbgemm', 'onednn'))
for idx, x in enumerate(fp_data):
scale, zero_point = _calculate_dynamic_qparams(
x, dtype=dtype, reduce_range=reduce_range)
@ -2679,7 +2682,13 @@ class TestQuantizedOps(TestCase):
mha.eval()
# Prepare
mha.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
if qengine_is_onednn():
# `reduce_range` is False by default for ONEDNN backend
# but the test fails on earlier CPUs without VNNI.
# So we use a default qconfig with `reduce_range=True` here
mha.qconfig = torch.ao.quantization.get_default_qconfig()
else:
mha.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
mha_prepared = torch.ao.quantization.prepare(
mha, prepare_custom_config_dict=custom_module_config)
@ -2772,7 +2781,7 @@ class TestDynamicQuantizedOps(TestCase):
(b_value_max - b_value_min) + b_value_min
).astype(np.int32) if use_bias else None
if torch.backends.quantized.engine == 'fbgemm':
if torch.backends.quantized.engine in ('fbgemm', 'onednn'):
avoid_vpmaddubsw_overflow_linear(
batch_size,
input_channels,
@ -3009,8 +3018,8 @@ class TestDynamicQuantizedOps(TestCase):
for rnn_type in ['LSTM', 'GRU']:
for dtype in [torch.qint8, torch.float16]:
# Fp16 quantization is not supported for qnnpack
if torch.backends.quantized.engine == 'qnnpack' and dtype == torch.float16:
# Fp16 quantization is not supported for qnnpack or onednn
if torch.backends.quantized.engine in ('qnnpack', 'onednn') and dtype == torch.float16:
continue
if torch.backends.quantized.engine == 'qnnpack':
@ -3143,8 +3152,8 @@ class TestDynamicQuantizedOps(TestCase):
for rnn_type in ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU']:
for dtype in [torch.qint8, torch.float16]:
# Fp16 quantization is not supported for qnnpack
if torch.backends.quantized.engine == 'qnnpack' and dtype == torch.float16:
# Fp16 quantization is not supported for qnnpack or onednn
if torch.backends.quantized.engine in ('qnnpack', 'onednn') and dtype == torch.float16:
continue
if torch.backends.quantized.engine == 'qnnpack':
@ -3299,7 +3308,8 @@ class TestQuantizedLinear(TestCase):
for dtype in dtypes:
# No support for channelwise in xnnpack (int8)
if dtype == torch.qint8 and use_channelwise:
# ONEDNN does not support qint8
if dtype == torch.qint8 and (use_channelwise or qengine_is_onednn()):
return
nptype = np_dtype[dtype]
@ -3322,7 +3332,8 @@ class TestQuantizedLinear(TestCase):
W_scales = np.random.rand(output_channels)
# xnnpack forces W_zp to 0 when using symmetric quantization
if dtype == torch.qint8:
# ONEDNN only supports symmetric quantization of weight
if dtype == torch.qint8 or qengine_is_onednn():
W_zps = np.zeros(output_channels).astype(np.int)
else:
W_zps = np.round(np.random.rand(output_channels) * 100 - 50).astype(np.int)
@ -3342,7 +3353,7 @@ class TestQuantizedLinear(TestCase):
np.random.rand(output_channels) *
(b_value_max - b_value_min) + b_value_min
).astype(np.int32) if use_bias else None
if torch.backends.quantized.engine == 'fbgemm':
if torch.backends.quantized.engine in ('fbgemm', 'onednn'):
avoid_vpmaddubsw_overflow_linear(
batch_size,
input_channels,
@ -3429,6 +3440,13 @@ class TestQuantizedLinear(TestCase):
qlinear_prepack = torch.ops.quantized.linear_prepack
qlinear_unpack = torch.ops.quantized.linear_unpack
# ONEDNN only supports symmetric quantization of weight
if qengine_is_onednn():
if use_channelwise:
W_zps = torch.zeros(output_channels).to(torch.int64)
else:
W_zp = 0
W = torch.from_numpy(W)
if use_channelwise:
W_q = torch.quantize_per_channel(
@ -3892,6 +3910,10 @@ class TestQuantizedConv(TestCase):
if channelwise and transposed:
# currently transposed conv and per-channel per quantization does not work
return
# ONEDNN only supports symmetric quantization of weight and zero output padding
if qengine_is_onednn():
W_zero_point = 0
o_pads = len(o_pads) * [0] if o_pads is not None else None
if channelwise:
if transposed:
output_channels = W.shape[1] # IC OC/G
@ -4030,6 +4052,9 @@ class TestQuantizedConv(TestCase):
weight_dtype=torch.qint8,
output_dtype=torch.quint8,
):
# ONEDNN only supports symmetric quantization of weight
if qengine_is_onednn() and W_zero_point is not None:
W_zero_point = len(W_zero_point) * [0]
(X, W), (X_q, W_q), bias_float = self._make_qconv_tensors(
batch_size, input_channels_per_group, input_feature_map_shape,
output_channels_per_group, groups, kernels,
@ -4512,6 +4537,9 @@ class TestQuantizedConv(TestCase):
use_bias):
if qengine_is_qnnpack() and (IS_PPC or TEST_WITH_UBSAN):
return # QNNPACK doesn't support these
# ONEDNN does not support output paddings
if qengine_is_onednn() and (o_pad_h, o_pad_w) != (0, 0):
return
assume(o_pad_h < stride_h and o_pad_h < dilation)
assume(o_pad_w < stride_w and o_pad_w < dilation)
@ -4641,6 +4669,9 @@ class TestQuantizedConv(TestCase):
use_bias):
if qengine_is_qnnpack():
return # QNNPACK doesn't support this
# ONEDNN doesn't support output paddings
if qengine_is_onednn() and (o_pad_t, o_pad_h, o_pad_w) != (0, 0, 0):
return
assume(o_pad_t < stride_t or o_pad_t < dilation)
assume(o_pad_h < stride_h or o_pad_h < dilation)
assume(o_pad_w < stride_w or o_pad_w < dilation)

View File

@ -184,8 +184,8 @@ def get_default_qconfig(backend='fbgemm', version=0):
Returns the default PTQ qconfig for the specified backend.
Args:
* `backend`: a string representing the target backend. Currently supports `fbgemm`
and `qnnpack`.
* `backend`: a string representing the target backend. Currently supports `fbgemm`,
`qnnpack` and `onednn`.
Return:
qconfig
@ -197,6 +197,9 @@ def get_default_qconfig(backend='fbgemm', version=0):
elif backend == 'qnnpack':
qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
weight=default_weight_observer)
elif backend == 'onednn':
qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
weight=default_per_channel_weight_observer)
else:
qconfig = default_qconfig
else:
@ -216,8 +219,8 @@ def get_default_qat_qconfig(backend='fbgemm', version=1):
Returns the default QAT qconfig for the specified backend.
Args:
* `backend`: a string representing the target backend. Currently supports `fbgemm`
and `qnnpack`.
* `backend`: a string representing the target backend. Currently supports `fbgemm`,
`qnnpack` and `onednn`.
* `version`: version, for backwards compatibility. Can be `None` or `1`.
Return:
@ -237,6 +240,11 @@ def get_default_qat_qconfig(backend='fbgemm', version=1):
quant_max=255,
reduce_range=False),
weight=default_weight_fake_quant)
elif backend == 'onednn':
qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
quant_min=0,
quant_max=255),
weight=default_per_channel_weight_fake_quant)
else:
qconfig = default_qat_qconfig
# Use the fused observe + fake_quant modules for doing QAT.
@ -253,6 +261,11 @@ def get_default_qat_qconfig(backend='fbgemm', version=1):
quant_max=255,
reduce_range=False),
weight=default_fused_wt_fake_quant)
elif backend == 'onednn':
qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
quant_min=0,
quant_max=255),
weight=default_fused_per_channel_wt_fake_quant)
else:
qconfig = default_qat_qconfig_v2
else:
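
A hedged end-to-end sketch of eager-mode post-training static quantization using the new backend string; the model, shapes, and calibration data are placeholders.

import torch
import torch.nn as nn

class TinyModel(nn.Module):  # placeholder module for illustration
    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.fc = nn.Linear(16, 8)
        self.relu = nn.ReLU()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.relu(self.fc(self.quant(x))))

torch.backends.quantized.engine = 'onednn'
m = TinyModel().eval()
# Per the 'onednn' qconfig above: reduce_range=False activations, per-channel symmetric weights.
m.qconfig = torch.ao.quantization.get_default_qconfig('onednn')
prepared = torch.ao.quantization.prepare(m)
prepared(torch.randn(32, 16))                        # calibration pass
quantized = torch.ao.quantization.convert(prepared)
out = quantized(torch.randn(4, 16))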

View File

@ -11,6 +11,8 @@ def _get_qengine_id(qengine: str) -> int:
ret = 1
elif qengine == 'qnnpack':
ret = 2
elif qengine == 'onednn':
ret = 3
else:
ret = -1
raise RuntimeError("{} is not a valid value for quantized engine".format(qengine))
@ -18,7 +20,7 @@ def _get_qengine_id(qengine: str) -> int:
# This function should correspond to the enums present in c10/core/QEngine.h
def _get_qengine_str(qengine: int) -> str:
all_engines = {0 : 'none', 1 : 'fbgemm', 2 : 'qnnpack'}
all_engines = {0 : 'none', 1 : 'fbgemm', 2 : 'qnnpack', 3 : 'onednn'}
return all_engines.get(qengine, '*undefined')
class _QEngineProp(object):

View File

@ -46,9 +46,12 @@ def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
qx = np.clip(qx, qmin, qmax).astype(qtype)
return qx
def _calculate_dynamic_qparams(X, dtype, reduce_range=False):
def _calculate_dynamic_qparams(X, dtype, reduce_range=False, qscheme=torch.per_tensor_affine):
"""Calculate the dynamic quantization parameters (scale, zero_point)
according to the min and max element of the tensor"""
assert qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric)
if qscheme == torch.per_tensor_symmetric:
assert dtype == torch.qint8
if isinstance(X, torch.Tensor):
X = X.numpy()
if dtype == torch.qint8:
@ -63,17 +66,25 @@ def _calculate_dynamic_qparams(X, dtype, reduce_range=False):
qmin, qmax = 0, 255
min_val = X.min()
max_val = X.max()
is_symmetric = (qscheme == torch.per_tensor_symmetric)
if min_val == max_val:
scale = 1.0
zero_point = 0
else:
max_val = max(max_val, 0.0)
min_val = min(min_val, 0.0)
scale = (max_val - min_val) / (qmax - qmin)
scale = max(scale, np.finfo(np.float32).eps)
zero_point = qmin - round(min_val / scale)
zero_point = max(qmin, zero_point)
zero_point = min(qmax, zero_point)
if is_symmetric:
max_val = max(max_val, -min_val)
min_val = -max_val
scale = (max_val - min_val) / (qmax - qmin)
scale = max(scale, np.finfo(np.float32).eps)
zero_point = 0
else:
max_val = max(max_val, 0.0)
min_val = min(min_val, 0.0)
scale = (max_val - min_val) / (qmax - qmin)
scale = max(scale, np.finfo(np.float32).eps)
zero_point = qmin - round(min_val / scale)
zero_point = max(qmin, zero_point)
zero_point = min(qmax, zero_point)
return [float(scale), int(zero_point)]
def _calculate_dynamic_per_channel_qparams(X, dtype):
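
A hedged worked example of the symmetric branch added above, for torch.qint8 (qmin = -128, qmax = 127): the range is mirrored around zero and the zero point is pinned to 0, matching the weight constraint the ONEDNN backend enforces. Numbers are illustrative.

import numpy as np

x_min, x_max = -0.8, 0.5
qmin, qmax = -128, 127
# Symmetric: widen the range so it is centered on zero.
x_max = max(x_max, -x_min)   # 0.8
x_min = -x_max               # -0.8
scale = max((x_max - x_min) / (qmax - qmin), np.finfo(np.float32).eps)  # ~0.00627
zero_point = 0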
@ -165,6 +176,8 @@ def qengine_is_fbgemm():
return torch.backends.quantized.engine == 'fbgemm'
def qengine_is_qnnpack():
return torch.backends.quantized.engine == 'qnnpack'
def qengine_is_onednn():
return torch.backends.quantized.engine == 'onednn'
# Helper function used to simulate per-channel fake-quant against any axis
def _permute_to_axis_zero(X, axis):