Adding quantized::conv2d function for pytorch mobile in c10 (#26152)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26152

This change adds support for calling QNNPACK through the refactored API for Conv2d operators.

Test Plan: python test/test_quantized.py TestQNNPackOps.test_qconv_qnnpack

Imported from OSS

Differential Revision: D17459892

fbshipit-source-id: d20b3e8b81dd403541cb2b9164731448ca229695

This commit is contained in:
parent 1f51051287
commit b23be95558
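For orientation before the diff: the new path is exercised end to end as prepack followed by conv2d. Below is a minimal sketch condensed from the test added in this commit; shapes, scales, and zero points are illustrative only, and it assumes a build with USE_PYTORCH_QNNPACK and the QNNPACK engine selected (the test uses its enable_mobile_quantized_engine() helper for that).

import torch

# Sketch of the new QNNPACK quantized conv path; values are illustrative.
X = torch.rand(1, 10, 10, 8)   # NHWC activation
W = torch.rand(16, 3, 3, 8)    # {out_c, kH, kW, in_c/groups} weight layout
b = torch.rand(16)

X_q = torch.quantize_linear(X, scale=1.0, zero_point=0, dtype=torch.quint8)
W_q = torch.quantize_linear(W, scale=0.5, zero_point=3, dtype=torch.quint8)
# Bias is quantized with scale X_scale * W_scale, as the test below does.
b_q = torch.quantize_linear(b, scale=1.0 * 0.5, zero_point=0, dtype=torch.qint32)

stride, padding, dilation, groups = [1, 1], [0, 0], [1, 1], 1

# Pack the weights once, then run the quantized convolution.
W_pack = torch.ops.quantized.conv_prepack(W_q, b_q, stride, padding, dilation, groups)
Y_q = torch.ops.quantized.conv2d(
    X_q, W_pack, stride, padding, dilation, groups, 1.0, 0)  # output scale/zp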
@@ -3,6 +3,7 @@
 #include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
+#include <ATen/native/quantized/cpu/qnnpack_utils.h>
 #include <cmath>

 namespace at {
@@ -64,8 +65,24 @@ SmallVector<int64_t, 4> convOutputShape(
 template <bool ReluFused>
 class QConv2dInt8 final : public c10::OperatorKernel {
  public:
+  void conv_checks(
+      int64_t act_dims,
+      int64_t stride_dims,
+      int64_t padding_dims,
+      int64_t dilation_dims) {
+    TORCH_CHECK(
+        act_dims == 4,
+        "quantized::conv2d(): Expected activation tensor to have 4 dimensions.");
+    TORCH_CHECK(
+        stride_dims == 2, "quantized::conv2d(): Supports 2D convolution only");
+    TORCH_CHECK(
+        padding_dims == 2, "quantized::conv2d(): Supports 2D convolution only");
+    TORCH_CHECK(
+        dilation_dims == 2,
+        "quantized::conv2d(): Supports 2D convolution only");
+  }
 #ifdef USE_FBGEMM
-  Tensor operator()(
+  at::Tensor fbgemm_conv(
       Tensor act,
       Tensor packed_weight,
       torch::List<int64_t> stride,
@@ -76,12 +93,9 @@ class QConv2dInt8 final : public c10::OperatorKernel {
       int64_t output_zero_point) {
     TORCH_CHECK(
         fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
-    TORCH_CHECK(
-        act.ndimension() == 4,
-        "Activations are supposed to have 4 dimensions.");
-    TORCH_CHECK(stride.size() == 2, "2D convolution only");
-    TORCH_CHECK(padding.size() == 2, "2D convolution only");
-    TORCH_CHECK(dilation.size() == 2, "2D convolution only");
+    conv_checks(
+        act.ndimension(), stride.size(), padding.size(), dilation.size());
+
     // inputs are in NHWC format
     int N = act.size(0);
     int H = act.size(1);
@@ -106,14 +120,14 @@ class QConv2dInt8 final : public c10::OperatorKernel {
     int stride_w = stride[1];
     int kernel_h = kernel[0];
     int kernel_w = kernel[1];

     // clang-format off
     TORCH_CHECK(C == (packB->inputChannels()),
         "[QConv2D] Given groups=", groups, ", weight of size ",
         K, ", ", kernel_h, ", ", kernel_w, ", ", packB->inputChannels(),
         ", expected input (NHWC) ", N, ", ", H, ", ", W, ", ", C,
         " to have ", (packB->inputChannels() * groups),
         " channels, but got ", C, " channels instead");

     // clang-format on
     fbgemm::conv_param_t<> conv_p(
         N, // Batch size
         C, // Number of input channels
@@ -236,23 +250,147 @@ class QConv2dInt8 final : public c10::OperatorKernel {

     return output;
   }
-#else // USE_FBGEMM
-  Tensor operator()(
-      Tensor /* activation */,
-      Tensor /* packed_weight */,
-      torch::List<int64_t> /* stride */,
-      torch::List<int64_t> /* padding */,
-      torch::List<int64_t> /* dilation */,
-      torch::List<int64_t> /* output padding */,
-      int64_t /* groups */,
-      double /* output scale */,
-      int64_t /* output_zero_point */) {
-    TORCH_CHECK(
-        false,
-        "This PyTorch installation was not built "
-        "with FBGEMM operators");
-  }
-#endif // USE_FBGEMM
+#endif
+#ifdef USE_PYTORCH_QNNPACK
+  at::Tensor qnnpack_conv(
+      Tensor act,
+      Tensor packed_weight,
+      torch::List<int64_t> stride,
+      torch::List<int64_t> padding,
+      torch::List<int64_t> dilation,
+      int64_t groups,
+      double output_scale,
+      int64_t output_zero_point) {
+    conv_checks(
+        act.ndimension(), stride.size(), padding.size(), dilation.size());
+
+    PackedConvWeightsQnnp& pack_ptr =
+        cpp_custom_type_hack::cast<PackedConvWeightsQnnp>(packed_weight);
+    auto packB = pack_ptr.w.get();
+    auto& kernel = pack_ptr.kernel;
+    auto kernel_zp = pack_ptr.w_zp;
+    auto kernel_scale = pack_ptr.w_scale;
+
+    const uint32_t kernel_h = kernel[0];
+    const uint32_t kernel_w = kernel[1];
+    const auto out_ch = packB->getOutputChannels();
+    // inputs are in NHWC format
+    Tensor input_contig = act.contiguous();
+    int N = input_contig.size(0);
+    int H = input_contig.size(1);
+    int W = input_contig.size(2);
+    int in_ch = input_contig.size(3);
+    int K = out_ch; // output channels
+
+    uint32_t stride_h = stride[0];
+    uint32_t stride_w = stride[1];
+    uint32_t pad_t = padding[0];
+    uint32_t pad_l = padding[1];
+    uint32_t dilation_h = dilation[0];
+    uint32_t dilation_w = dilation[1];
+
+    auto output_min = ReluFused
+        ? activationLimits(output_scale, output_zero_point, Activation::RELU)
+              .first
+        : std::numeric_limits<uint8_t>::min();
+    auto output_max = ReluFused
+        ? activationLimits(output_scale, output_zero_point, Activation::RELU)
+              .second
+        : std::numeric_limits<uint8_t>::max();
+    qnnpack::conv_param_t conv_p(
+        {kernel_w, kernel_h},
+        {stride_w, stride_h},
+        {dilation_w, dilation_h},
+        {pad_t, pad_l, pad_t, pad_l},
+        groups,
+        in_ch,
+        out_ch,
+        kernel_zp,
+        kernel_scale,
+        output_min,
+        output_max);
+
+    auto outShape =
+        convOutputShape(N, H, W, K, kernel, stride, padding, dilation);
+    TORCH_CHECK(
+        std::all_of(
+            outShape.begin(), outShape.end(), [](int64_t i) { return i > 0; }),
+        "quantized::conv2d (qnnpack): each dimension of output tensor should be greater "
+        "than 0");
+    TORCH_CHECK(
+        (outShape[3] == out_ch),
+        "quantized::conv2d (qnnpack): Number of filters must be equal to number of "
+        "output channels");
+
+    // Allocate output Tensor and a buffer for QNNPACK to use
+    Tensor output = at::_empty_affine_quantized(
+        outShape,
+        at::device(kCPU).dtype(kQUInt8),
+        output_scale,
+        output_zero_point);
+
+    const pytorch_qnnp_status runStatus = qnnpack::qnnpackConv(
+        conv_p,
+        packB->getPackedWeights(),
+        N,
+        H,
+        W,
+        input_contig.q_scale(),
+        input_contig.q_zero_point(),
+        (uint8_t*)input_contig.data_ptr<c10::quint8>(),
+        output.q_scale(),
+        output.q_zero_point(),
+        (uint8_t*)output.data_ptr<c10::quint8>(),
+        nullptr);
+
+    TORCH_INTERNAL_ASSERT(
+        runStatus == pytorch_qnnp_status_success,
+        "failed to run quantized::conv2d (qnnpack) operator");
+
+    return output;
+  }
+#endif
+  Tensor operator()(
+      Tensor act,
+      Tensor packed_weight,
+      torch::List<int64_t> stride,
+      torch::List<int64_t> padding,
+      torch::List<int64_t> dilation,
+      int64_t groups,
+      double output_scale,
+      int64_t output_zero_point) {
+    auto& ctx = at::globalContext();
+#ifdef USE_FBGEMM
+    if (ctx.preferredQuantizedEngine() == at::QEngine::FBGEMM) {
+      return fbgemm_conv(
+          act,
+          packed_weight,
+          stride,
+          padding,
+          dilation,
+          groups,
+          output_scale,
+          output_zero_point);
+    }
+#endif
+#ifdef USE_PYTORCH_QNNPACK
+    if (ctx.preferredQuantizedEngine() == at::QEngine::QNNPACK) {
+      return qnnpack_conv(
+          act,
+          packed_weight,
+          stride,
+          padding,
+          dilation,
+          groups,
+          output_scale,
+          output_zero_point);
+    }
+#endif
+    TORCH_INTERNAL_ASSERT(
+        false,
+        "Didn't find engine for operation quantized::conv ",
+        toString(ctx.preferredQuantizedEngine()));
+    return at::Tensor();
+  }
 };

 static auto registry =
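Both engine paths size the output with convOutputShape before allocating it. For reference, that is the standard convolution output arithmetic; the following is a minimal Python rendering of the same computation for NHWC, assuming symmetric top/bottom and left/right padding as in the kernels above.

def conv_output_shape(n, h, w, out_ch, kernel, stride, padding, dilation):
    # Standard 2D convolution output arithmetic (floor division), NHWC layout.
    # Symmetric padding (top == bottom, left == right) is assumed, matching
    # the {pad_t, pad_l, pad_t, pad_l} construction above.
    h_out = (h + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) // stride[0] + 1
    w_out = (w + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) // stride[1] + 1
    return [n, h_out, w_out, out_ch]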
@@ -2,6 +2,8 @@
 #include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
+#include <ATen/native/quantized/cpu/init_qnnpack.h>
+#include <ATen/native/quantized/cpu/qnnpack_utils.h>
 #include <ATen/quantized/Quantizer.h>

 namespace caffe2 {
@@ -9,6 +11,10 @@ namespace caffe2 {
 // Required for cpp_custom_type_hack to work
 CAFFE_KNOWN_TYPE(PackedConvWeight);
 #endif
+#ifdef USE_PYTORCH_QNNPACK
+// Required for cpp_custom_type_hack to work
+CAFFE_KNOWN_TYPE(PackedConvWeightsQnnp);
+#endif // USE_PYTORCH_QNNPACK
 } // namespace caffe2

 namespace at {
@@ -17,7 +23,7 @@ namespace {
 class QConvPackWeightInt8 final : public c10::OperatorKernel {
  public:
 #ifdef USE_FBGEMM
-  Tensor operator()(
+  Tensor fbgemm_conv_prepack(
       Tensor weight,
       c10::optional<Tensor> bias,
       torch::List<int64_t> stride,
@@ -31,7 +37,7 @@ class QConvPackWeightInt8 final : public c10::OperatorKernel {
     TORCH_CHECK(
         padding.size() == 2,
         "Specify top/left padding only. \
         bottom/right padding assumed to be equal to top/left");
     TORCH_CHECK(dilation.size() == 2, "2D convolution only");
     // weights in KRS(C/G) format
     int output_channels = weight.size(0);
@@ -125,19 +131,117 @@ class QConvPackWeightInt8 final : public c10::OperatorKernel {
     // point.
     return cpp_custom_type_hack::create(std::move(ret_ptr), weight.options());
   }
-#else // USE_FBGEMM
-  Tensor operator()(
-      Tensor, /* weight */
-      c10::optional<Tensor>, /* bias */
-      torch::List<int64_t>, /* stride */
-      torch::List<int64_t>, /* padding */
-      torch::List<int64_t>, /* dilation */
-      int64_t /* groups */
-  ) {
-    TORCH_CHECK(
-        false, "This PyTorch installation was not built with FBGEMM operators");
-  }
-#endif // USE_FBGEMM
+#endif // USE_FBGEMM
+#ifdef USE_PYTORCH_QNNPACK
+  at::Tensor qnnpack_conv_prepack(
+      Tensor weight,
+      c10::optional<Tensor> bias_in,
+      torch::List<int64_t> stride,
+      torch::List<int64_t> padding,
+      torch::List<int64_t> dilation,
+      int64_t groups) {
+    TORCH_CHECK(
+        weight.ndimension() == 4,
+        "quantized::conv_prepack (qnnpack): Weights are expected to have 4 dimensions");
+    const auto qtype = weight.qscheme();
+    TORCH_CHECK(
+        qtype == kPerTensorAffine,
+        "quantized::conv_prepack (qnnpack): only supports Per Tensor Quantization Scheme");
+    TORCH_CHECK(
+        stride.size() == 2,
+        "quantized::conv_prepack (qnnpack): 2D convolution only");
+    TORCH_CHECK(
+        padding.size() == 2,
+        "quantized::conv_prepack (qnnpack): Specify top/left padding only. \
+        bottom/right padding assumed to be equal to top/left");
+    TORCH_CHECK(
+        dilation.size() == 2,
+        "quantized::conv_prepack (qnnpack): 2D convolution only");
+
+    initQNNPACK();
+
+    // QNNPACK expects weights to be of the format {out_c, kH, kW, in_c/groups}
+    const size_t out_ch = weight.size(0);
+    const uint32_t kernel_h = weight.size(1);
+    const uint32_t kernel_w = weight.size(2);
+    const size_t in_ch = weight.size(3) * groups;
+
+    Tensor bias;
+    if (bias_in.has_value()) {
+      bias = bias_in.value();
+    } else {
+      bias = at::empty(out_ch, at::kFloat);
+      bias = at::quantize_linear(bias, 1.0, 0, kQInt32);
+    }
+    TORCH_CHECK(
+        !bias.defined() || (bias.ndimension() == 1 && bias.size(0) == out_ch),
+        "quantized::conv_prepack (qnnpack): expected bias to be 1-dimensional with ",
+        out_ch,
+        " elements",
+        ", but got bias of size ",
+        bias.sizes(),
+        " instead");
+
+    uint32_t stride_h = stride[0];
+    uint32_t stride_w = stride[1];
+    uint32_t pad_t = padding[0];
+    uint32_t pad_l = padding[1];
+    uint32_t dilation_h = dilation[0];
+    uint32_t dilation_w = dilation[1];
+
+    qnnpack::conv_param_t conv_p(
+        {kernel_w, kernel_h},
+        {stride_w, stride_h},
+        {dilation_w, dilation_h},
+        {pad_t, pad_l, pad_t, pad_l},
+        groups,
+        in_ch,
+        out_ch,
+        weight.q_zero_point(),
+        weight.q_scale(),
+        std::numeric_limits<uint8_t>::min(),
+        std::numeric_limits<uint8_t>::max());
+
+    auto weight_contig = weight.contiguous();
+    auto bias_contig = bias.contiguous();
+    auto wt_ptr =
+        guts::make_unique<PackedConvWeightsQnnp>(PackedConvWeightsQnnp{
+            guts::make_unique<qnnpack::PrePackConvWeights>(
+                conv_p,
+                (uint8_t*)weight_contig.data_ptr<c10::quint8>(),
+                (int32_t*)bias_contig.data_ptr<c10::qint32>()),
+            {kernel_h, kernel_w},
+            weight.q_scale(),
+            weight.q_zero_point()});
+
+    return cpp_custom_type_hack::create(std::move(wt_ptr), weight.options());
+  }
+#endif // USE_PYTORCH_QNNPACK
+  Tensor operator()(
+      Tensor weight,
+      c10::optional<Tensor> bias,
+      torch::List<int64_t> stride,
+      torch::List<int64_t> padding,
+      torch::List<int64_t> dilation,
+      int64_t groups) {
+    auto& ctx = at::globalContext();
+#ifdef USE_FBGEMM
+    if (ctx.preferredQuantizedEngine() == at::QEngine::FBGEMM) {
+      return fbgemm_conv_prepack(
+          weight, bias, stride, padding, dilation, groups);
+    }
+#endif
+#ifdef USE_PYTORCH_QNNPACK
+    if (ctx.preferredQuantizedEngine() == at::QEngine::QNNPACK) {
+      return qnnpack_conv_prepack(
+          weight, bias, stride, padding, dilation, groups);
+    }
+#endif
+    TORCH_INTERNAL_ASSERT(
+        false,
+        "Didn't find engine for operation quantized::conv_prepack ",
+        toString(ctx.preferredQuantizedEngine()));
+    return at::Tensor();
+  }
 };

 static auto registry = c10::RegisterOperators().op(
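As the comment in qnnpack_conv_prepack notes, QNNPACK expects weights as {out_c, kH, kW, in_c/groups}, whereas torch.nn.Conv2d stores them as {out_c, in_c/groups, kH, kW}. A short sketch of the permutation a caller performs before quantizing and prepacking (the test below does the same; values are illustrative):

import torch

# Conv2d weight is KCRS = {out_c, in_c/groups, kH, kW}; QNNPACK prepack wants
# KRSC = {out_c, kH, kW, in_c/groups}. Permute, make contiguous, then quantize.
W = torch.rand(16, 8, 3, 3)                    # KCRS
W_RSCK = W.permute([0, 2, 3, 1]).contiguous()  # KRSC
W_q = torch.quantize_linear(W_RSCK, scale=0.5, zero_point=3, dtype=torch.quint8)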
@@ -15,6 +15,13 @@ struct PackedLinearWeightsQnnp {
   int64_t w_zp;
 };

+struct PackedConvWeightsQnnp {
+  std::unique_ptr<qnnpack::PrePackConvWeights> w;
+  std::vector<int64_t> kernel;
+  double w_scale;
+  int64_t w_zp;
+};
+
 enum class Activation : uint8_t { NONE = 0, RELU = 1 };

 #if defined(__ANDROID__) && !defined(__NDK_MAJOR__)
@@ -1012,7 +1012,6 @@ class TestQuantizedConv(unittest.TestCase):
             use_relu,
             use_channelwise
     ):
-
         qconv = torch.ops.quantized.conv2d
         if use_relu:
             qconv = torch.ops.quantized.conv2d_relu
@@ -1325,6 +1324,129 @@ class TestQNNPackOps(TestCase):
         np.testing.assert_equal(
             Y_q_ref2.int_repr().numpy(), Y_q.int_repr().numpy())

+    @given(batch_size=st.integers(1, 3),
+           input_channels_per_group=st.sampled_from([8, 16, 32]),
+           height=st.integers(10, 16),
+           width=st.integers(7, 14),
+           output_channels_per_group=st.sampled_from([8, 16, 32]),
+           groups=st.integers(1, 3),
+           kernel_h=st.integers(1, 7),
+           kernel_w=st.integers(1, 7),
+           stride_h=st.integers(1, 2),
+           stride_w=st.integers(1, 2),
+           pad_h=st.integers(0, 2),
+           pad_w=st.integers(0, 2),
+           dilation_h=st.integers(1, 1),
+           X_scale=st.floats(1.2, 1.6),
+           X_zp=st.integers(0, 4),
+           W_scale=st.floats(0.2, 1.6),
+           W_zp=st.integers(2, 5),
+           Y_scale=st.floats(4.2, 5.6),
+           Y_zp=st.integers(0, 4),
+           use_relu=st.booleans())
+    def test_qconv_qnnpack(
+            self,
+            batch_size,
+            input_channels_per_group,
+            height,
+            width,
+            output_channels_per_group,
+            groups,
+            kernel_h,
+            kernel_w,
+            stride_h,
+            stride_w,
+            pad_h,
+            pad_w,
+            dilation_h,
+            X_scale,
+            X_zp,
+            W_scale,
+            W_zp,
+            Y_scale,
+            Y_zp,
+            use_relu):
+        with enable_mobile_quantized_engine():
+            # C
+            input_channels = input_channels_per_group * groups
+            # K
+            output_channels = output_channels_per_group * groups
+
+            stride = [stride_h, stride_w]
+            padding = [pad_h, pad_w]
+            kernel = [kernel_h, kernel_w]
+            dilation = [dilation_h, dilation_h]
+
+            W_value_min = 0
+            W_value_max = 10
+            W_init = torch.from_numpy(
+                np.random.randint(
+                    W_value_min,
+                    W_value_max,
+                    (output_channels, int(input_channels / groups), kernel_h, kernel_w)),
+            )
+            b_init = torch.from_numpy(np.random.randint(0, 10, (output_channels,)))
+
+            X_value_min = 0
+            X_value_max = 10
+            X_init = torch.from_numpy(np.random.randint(
+                X_value_min, X_value_max, (batch_size, input_channels, height, width)))
+
+            # Existing floating point conv operator
+            conv_op = torch.nn.Conv2d(
+                input_channels,
+                output_channels,
+                (kernel_h, kernel_w),
+                (stride_h, stride_w),
+                (pad_h, pad_w),
+                (dilation_h, dilation_h),
+                groups,
+            )
+
+            X = X_scale * (X_init - X_zp).to(dtype=torch.float)
+
+            W = W_scale * (W_init - W_zp).to(dtype=torch.float)
+
+            b = X_scale * W_scale * (b_init - 0).to(dtype=torch.float)
+
+            # assign weights
+            conv_op.weight = torch.nn.Parameter(W, requires_grad=False)
+            conv_op.bias = torch.nn.Parameter(b, requires_grad=False)
+
+            result_ref = conv_op(X)
+
+            X_NHWC = X.permute([0, 2, 3, 1]).contiguous()
+            W_RSCK = W.permute([0, 2, 3, 1]).contiguous()
+
+            X_q = torch.quantize_linear(X_NHWC, scale=X_scale, zero_point=X_zp, dtype=torch.quint8)
+            W_q = torch.quantize_linear(W_RSCK, scale=W_scale, zero_point=W_zp, dtype=torch.quint8)
+            b_q = torch.quantize_linear(b, scale=X_scale * W_scale, zero_point=0, dtype=torch.qint32)
+
+            W_pack = torch.ops.quantized.conv_prepack(W_q, b_q, stride, padding, dilation, groups)
+            qconv = torch.ops.quantized.conv2d
+            if use_relu:
+                qconv = torch.ops.quantized.conv2d_relu
+
+            Y_q = qconv(
+                X_q,
+                W_pack,
+                stride,
+                padding,
+                dilation,
+                groups,
+                Y_scale,
+                Y_zp
+            )
+
+            result_NHWK = result_ref.permute([0, 2, 3, 1])
+            if use_relu:
+                relu = torch.nn.ReLU()
+                result_NHWK = relu(result_NHWK)
+
+            result_ref_q = torch.quantize_linear(result_NHWK, scale=Y_scale, zero_point=Y_zp, dtype=torch.quint8)
+
+            np.testing.assert_array_almost_equal(result_ref_q.int_repr().numpy(), Y_q.int_repr().numpy(), decimal=0)
+
     """Tests the correctness of the quantized::qnnpack_add op."""
     @given(A=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                        qparams=hu.qparams(dtypes=torch.quint8,