Adding quantized::conv2d function for pytorch mobile in c10 (#26152)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26152

This change adds support for calling QNNPACK through the refactored API for Conv2d operators.
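The op is exercised the same way as the existing FBGEMM path: prepack the quantized NHWC weights and bias, then call the quantized conv. A minimal call-sequence sketch, mirroring the test added below (tensor shapes, scales, and zero points here are illustrative; the test additionally selects the QNNPACK backend through its enable_mobile_quantized_engine() helper):

import torch

# NHWC activation and {out_c, kH, kW, in_c / groups} weight layout, as the op expects
X = torch.rand(1, 8, 8, 8) * 10
W = torch.rand(8, 3, 3, 8) * 10
b = torch.rand(8)

X_q = torch.quantize_linear(X, scale=1.5, zero_point=2, dtype=torch.quint8)
W_q = torch.quantize_linear(W, scale=0.4, zero_point=3, dtype=torch.quint8)
b_q = torch.quantize_linear(b, scale=1.5 * 0.4, zero_point=0, dtype=torch.qint32)

stride, padding, dilation, groups = [1, 1], [0, 0], [1, 1], 1
W_pack = torch.ops.quantized.conv_prepack(W_q, b_q, stride, padding, dilation, groups)
# last two arguments are the output scale and zero point
Y_q = torch.ops.quantized.conv2d(X_q, W_pack, stride, padding, dilation, groups, 2.0, 0)

At runtime the kernel picks the backend from at::globalContext().preferredQuantizedEngine(), so the same call reaches either the FBGEMM or the QNNPACK implementation.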

Test Plan:
python test/test_quantized.py TestQNNPackOps.test_qconv_qnnpack

Imported from OSS

Differential Revision: D17459892

fbshipit-source-id: d20b3e8b81dd403541cb2b9164731448ca229695
Supriya Rao 2019-09-18 16:46:59 -07:00 committed by Facebook Github Bot
parent 1f51051287
commit b23be95558
4 changed files with 410 additions and 39 deletions

View File

@@ -3,6 +3,7 @@
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <cmath>
namespace at {
@@ -64,8 +65,24 @@ SmallVector<int64_t, 4> convOutputShape(
template <bool ReluFused>
class QConv2dInt8 final : public c10::OperatorKernel {
public:
void conv_checks(
int64_t act_dims,
int64_t stride_dims,
int64_t padding_dims,
int64_t dilation_dims) {
TORCH_CHECK(
act_dims == 4,
"quantized::conv2d(): Expected activation tensor to have 4 dimensions.");
TORCH_CHECK(
stride_dims == 2, "quantized::conv2d(): Supports 2D convolution only");
TORCH_CHECK(
padding_dims == 2, "quantized::conv2d(): Supports 2D convolution only");
TORCH_CHECK(
dilation_dims == 2,
"quantized::conv2d(): Supports 2D convolution only");
}
#ifdef USE_FBGEMM
at::Tensor fbgemm_conv(
Tensor act,
Tensor packed_weight,
torch::List<int64_t> stride,
@@ -76,12 +93,9 @@ class QConv2dInt8 final : public c10::OperatorKernel {
int64_t output_zero_point) {
TORCH_CHECK(
fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
TORCH_CHECK(
act.ndimension() == 4,
"Activations are supposed to have 4 dimensions.");
TORCH_CHECK(stride.size() == 2, "2D convolution only");
TORCH_CHECK(padding.size() == 2, "2D convolution only");
TORCH_CHECK(dilation.size() == 2, "2D convolution only");
conv_checks(
act.ndimension(), stride.size(), padding.size(), dilation.size());
// inputs are in NHWC format
int N = act.size(0);
int H = act.size(1);
@@ -106,14 +120,14 @@
int stride_w = stride[1];
int kernel_h = kernel[0];
int kernel_w = kernel[1];
// clang-format off
TORCH_CHECK(C == (packB->inputChannels()),
"[QConv2D] Given groups=", groups, ", weight of size ",
K, ", ", kernel_h, ", ", kernel_w, ", ", packB->inputChannels(),
", expected input (NHWC) ", N, ", ", H, ", ", W, ", ", C,
" to have ", (packB->inputChannels() * groups),
" channels, but got ", C, " channels instead");
// clang-format on
fbgemm::conv_param_t<> conv_p(
N, // Batch size
C, // Number of input channels
@@ -236,23 +250,147 @@ class QConv2dInt8 final : public c10::OperatorKernel {
return output;
}
#else // USE_FBGEMM
Tensor operator()(
Tensor /* activation */,
Tensor /* packed_weight */,
torch::List<int64_t> /* stride */,
torch::List<int64_t> /* padding */,
torch::List<int64_t> /* dilation */,
torch::List<int64_t> /* output padding */,
int64_t /* groups */,
double /* output scale */,
int64_t /* output_zero_point */) {
TORCH_CHECK(
false,
"This PyTorch installation was not built "
"with FBGEMM operators");
}
#endif // USE_FBGEMM
#ifdef USE_PYTORCH_QNNPACK
at::Tensor qnnpack_conv(
Tensor act,
Tensor packed_weight,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> dilation,
int64_t groups,
double output_scale,
int64_t output_zero_point) {
conv_checks(
act.ndimension(), stride.size(), padding.size(), dilation.size());
PackedConvWeightsQnnp& pack_ptr =
cpp_custom_type_hack::cast<PackedConvWeightsQnnp>(packed_weight);
auto packB = pack_ptr.w.get();
auto& kernel = pack_ptr.kernel;
auto kernel_zp = pack_ptr.w_zp;
auto kernel_scale = pack_ptr.w_scale;
const uint32_t kernel_h = kernel[0];
const uint32_t kernel_w = kernel[1];
const auto out_ch = packB->getOutputChannels();
// inputs are in NHWC format
Tensor input_contig = act.contiguous();
int N = input_contig.size(0);
int H = input_contig.size(1);
int W = input_contig.size(2);
int in_ch = input_contig.size(3);
int K = out_ch; // output channels
uint32_t stride_h = stride[0];
uint32_t stride_w = stride[1];
uint32_t pad_t = padding[0];
uint32_t pad_l = padding[1];
uint32_t dilation_h = dilation[0];
uint32_t dilation_w = dilation[1];
auto output_min = ReluFused
? activationLimits(output_scale, output_zero_point, Activation::RELU)
.first
: std::numeric_limits<uint8_t>::min();
auto output_max = ReluFused
? activationLimits(output_scale, output_zero_point, Activation::RELU)
.second
: std::numeric_limits<uint8_t>::max();
qnnpack::conv_param_t conv_p(
{kernel_w, kernel_h},
{stride_w, stride_h},
{dilation_w, dilation_h},
{pad_t, pad_l, pad_t, pad_l},
groups,
in_ch,
out_ch,
kernel_zp,
kernel_scale,
output_min,
output_max);
auto outShape =
convOutputShape(N, H, W, K, kernel, stride, padding, dilation);
TORCH_CHECK(
std::all_of(
outShape.begin(), outShape.end(), [](int64_t i) { return i > 0; }),
"quantized::conv2d (qnnpack): each dimension of output tensor should be greater "
"than 0");
TORCH_CHECK(
(outShape[3] == out_ch),
"quantized::conv2d (qnnpack): Number of filters must be equal to number of "
"output channels");
// Allocate output Tensor and a buffer for QNNPACK to use
Tensor output = at::_empty_affine_quantized(
outShape,
at::device(kCPU).dtype(kQUInt8),
output_scale,
output_zero_point);
const pytorch_qnnp_status runStatus = qnnpack::qnnpackConv(
conv_p,
packB->getPackedWeights(),
N,
H,
W,
input_contig.q_scale(),
input_contig.q_zero_point(),
(uint8_t*)input_contig.data_ptr<c10::quint8>(),
output.q_scale(),
output.q_zero_point(),
(uint8_t*)output.data_ptr<c10::quint8>(),
nullptr);
TORCH_INTERNAL_ASSERT(
runStatus == pytorch_qnnp_status_success,
"failed to run quantized::conv2d (qnnpack) operator");
return output;
}
#endif
Tensor operator()(
Tensor act,
Tensor packed_weight,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> dilation,
int64_t groups,
double output_scale,
int64_t output_zero_point) {
auto& ctx = at::globalContext();
#ifdef USE_FBGEMM
if (ctx.preferredQuantizedEngine() == at::QEngine::FBGEMM) {
return fbgemm_conv(
act,
packed_weight,
stride,
padding,
dilation,
groups,
output_scale,
output_zero_point);
}
#endif
#ifdef USE_PYTORCH_QNNPACK
if (ctx.preferredQuantizedEngine() == at::QEngine::QNNPACK) {
return qnnpack_conv(
act,
packed_weight,
stride,
padding,
dilation,
groups,
output_scale,
output_zero_point);
}
#endif
TORCH_INTERNAL_ASSERT(
false,
"Didn't find engine for operation quantized::conv ",
toString(ctx.preferredQuantizedEngine()));
return at::Tensor();
}
};
static auto registry =

View File

@@ -2,6 +2,8 @@
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/quantized/Quantizer.h>
namespace caffe2 {
@@ -9,6 +11,10 @@ namespace caffe2 {
// Required for cpp_custom_type_hack to work
CAFFE_KNOWN_TYPE(PackedConvWeight);
#endif
#ifdef USE_PYTORCH_QNNPACK
// Required for cpp_custom_type_hack to work
CAFFE_KNOWN_TYPE(PackedConvWeightsQnnp);
#endif // USE_PYTORCH_QNNPACK
} // namespace caffe2
namespace at {
@@ -17,7 +23,7 @@ namespace {
class QConvPackWeightInt8 final : public c10::OperatorKernel {
public:
#ifdef USE_FBGEMM
Tensor fbgemm_conv_prepack(
Tensor weight,
c10::optional<Tensor> bias,
torch::List<int64_t> stride,
@@ -31,7 +37,7 @@ class QConvPackWeightInt8 final : public c10::OperatorKernel {
TORCH_CHECK(
padding.size() == 2,
"Specify top/left padding only. \
bottom/right padding assumed to be equal to top/left");
TORCH_CHECK(dilation.size() == 2, "2D convolution only");
// weights in KRS(C/G) format
int output_channels = weight.size(0);
@@ -125,19 +131,117 @@ class QConvPackWeightInt8 final : public c10::OperatorKernel {
// point.
return cpp_custom_type_hack::create(std::move(ret_ptr), weight.options());
}
#else // USE_FBGEMM
Tensor operator()(
Tensor, /* weight */
c10::optional<Tensor>, /* bias */
torch::List<int64_t>, /* stride */
torch::List<int64_t>, /* padding */
torch::List<int64_t>, /* dilation */
int64_t /* groups */
) {
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
}
#endif // USE_FBGEMM
#ifdef USE_PYTORCH_QNNPACK
at::Tensor qnnpack_conv_prepack(
Tensor weight,
c10::optional<Tensor> bias_in,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> dilation,
int64_t groups) {
TORCH_CHECK(
weight.ndimension() == 4,
"quantized::conv_prepack (qnnpack): Weights are expected to have 4 dimensions");
const auto qtype = weight.qscheme();
TORCH_CHECK(
qtype == kPerTensorAffine,
"quantized::conv_prepack (qnnpack): only supports Per Tensor Quantization Scheme");
TORCH_CHECK(
stride.size() == 2,
"quantized::conv_prepack (qnnpack): 2D convolution only");
TORCH_CHECK(
padding.size() == 2,
"quantized::conv_prepack (qnnpack): Specify top/left padding only. \
bottom/right padding assumed to be equal to top/left");
TORCH_CHECK(
dilation.size() == 2,
" quantized::conv_prepack (qnnpack): 2D convolution only");
initQNNPACK();
// QNNPACK expects weights to be of the format {out_c, kH, kW, in_c/groups}
const size_t out_ch = weight.size(0);
const uint32_t kernel_h = weight.size(1);
const uint32_t kernel_w = weight.size(2);
const size_t in_ch = weight.size(3) * groups;
Tensor bias;
if (bias_in.has_value()) {
bias = bias_in.value();
} else {
bias = at::empty(out_ch, at::kFloat);
bias = at::quantize_linear(bias, 1.0, 0, kQInt32);
}
TORCH_CHECK(
!bias.defined() || (bias.ndimension() == 1 && bias.size(0) == out_ch),
"quantized::conv_prepack (qnnpack): expected bias to be 1-dimensional with ",
out_ch,
" elements",
", but got bias of size ",
bias.sizes(),
" instead");
uint32_t stride_h = stride[0];
uint32_t stride_w = stride[1];
uint32_t pad_t = padding[0];
uint32_t pad_l = padding[1];
uint32_t dilation_h = dilation[0];
uint32_t dilation_w = dilation[1];
qnnpack::conv_param_t conv_p(
{kernel_w, kernel_h},
{stride_w, stride_h},
{dilation_w, dilation_h},
{pad_t, pad_l, pad_t, pad_l},
groups,
in_ch,
out_ch,
weight.q_zero_point(),
weight.q_scale(),
std::numeric_limits<uint8_t>::min(),
std::numeric_limits<uint8_t>::max());
auto weight_contig = weight.contiguous();
auto bias_contig = bias.contiguous();
auto wt_ptr =
guts::make_unique<PackedConvWeightsQnnp>(PackedConvWeightsQnnp{
guts::make_unique<qnnpack::PrePackConvWeights>(
conv_p,
(uint8_t*)weight_contig.data_ptr<c10::quint8>(),
(int32_t*)bias_contig.data_ptr<c10::qint32>()),
{kernel_h, kernel_w},
weight.q_scale(),
weight.q_zero_point()});
return cpp_custom_type_hack::create(std::move(wt_ptr), weight.options());
}
#endif // USE_PYTORCH_QNNPACK
Tensor operator()(
Tensor weight,
c10::optional<Tensor> bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> dilation,
int64_t groups) {
auto& ctx = at::globalContext();
#ifdef USE_FBGEMM
if (ctx.preferredQuantizedEngine() == at::QEngine::FBGEMM) {
return fbgemm_conv_prepack(
weight, bias, stride, padding, dilation, groups);
}
#endif
#ifdef USE_PYTORCH_QNNPACK
if (ctx.preferredQuantizedEngine() == at::QEngine::QNNPACK) {
return qnnpack_conv_prepack(
weight, bias, stride, padding, dilation, groups);
}
#endif
TORCH_INTERNAL_ASSERT(
false,
"Didn't find engine for operation quantized::conv_prepack ",
toString(ctx.preferredQuantizedEngine()));
return at::Tensor();
}
};
static auto registry = c10::RegisterOperators().op(

View File

@@ -15,6 +15,13 @@ struct PackedLinearWeightsQnnp {
int64_t w_zp;
};
struct PackedConvWeightsQnnp {
std::unique_ptr<qnnpack::PrePackConvWeights> w; // prepacked weights for QNNPACK
std::vector<int64_t> kernel; // {kernel_h, kernel_w}
double w_scale; // weight quantization scale
int64_t w_zp; // weight quantization zero point
};
enum class Activation : uint8_t { NONE = 0, RELU = 1 };
#if defined(__ANDROID__) && !defined(__NDK_MAJOR__)

View File

@@ -1012,7 +1012,6 @@ class TestQuantizedConv(unittest.TestCase):
use_relu,
use_channelwise
):
qconv = torch.ops.quantized.conv2d
if use_relu:
qconv = torch.ops.quantized.conv2d_relu
@@ -1325,6 +1324,129 @@ class TestQNNPackOps(TestCase):
np.testing.assert_equal(
Y_q_ref2.int_repr().numpy(), Y_q.int_repr().numpy())
@given(batch_size=st.integers(1, 3),
input_channels_per_group=st.sampled_from([8, 16, 32]),
height=st.integers(10, 16),
width=st.integers(7, 14),
output_channels_per_group=st.sampled_from([8, 16, 32]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 7),
kernel_w=st.integers(1, 7),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation_h=st.integers(1, 1),
X_scale=st.floats(1.2, 1.6),
X_zp=st.integers(0, 4),
W_scale=st.floats(0.2, 1.6),
W_zp=st.integers(2, 5),
Y_scale=st.floats(4.2, 5.6),
Y_zp=st.integers(0, 4),
use_relu=st.booleans())
def test_qconv_qnnpack(
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
X_scale,
X_zp,
W_scale,
W_zp,
Y_scale,
Y_zp,
use_relu):
with enable_mobile_quantized_engine():
# C
input_channels = input_channels_per_group * groups
# K
output_channels = output_channels_per_group * groups
stride = [stride_h, stride_w]
padding = [pad_h, pad_w]
kernel = [kernel_h, kernel_w]
dilation = [dilation_h, dilation_h]
W_value_min = 0
W_value_max = 10
W_init = torch.from_numpy(
np.random.randint(
W_value_min,
W_value_max,
(output_channels, int(input_channels / groups), kernel_h, kernel_w)),
)
b_init = torch.from_numpy(np.random.randint(0, 10, (output_channels,)))
X_value_min = 0
X_value_max = 10
X_init = torch.from_numpy(np.random.randint(
X_value_min, X_value_max, (batch_size, input_channels, height, width)))
# Existing floating point conv operator
conv_op = torch.nn.Conv2d(
input_channels,
output_channels,
(kernel_h, kernel_w),
(stride_h, stride_w),
(pad_h, pad_w),
(dilation_h, dilation_h),
groups,
)
X = X_scale * (X_init - X_zp).to(dtype=torch.float)
W = W_scale * (W_init - W_zp).to(dtype=torch.float)
b = X_scale * W_scale * (b_init - 0).to(dtype=torch.float)
# assign weights
conv_op.weight = torch.nn.Parameter(W, requires_grad=False)
conv_op.bias = torch.nn.Parameter(b, requires_grad=False)
result_ref = conv_op(X)
X_NHWC = X.permute([0, 2, 3, 1]).contiguous()
W_RSCK = W.permute([0, 2, 3, 1]).contiguous()
X_q = torch.quantize_linear(X_NHWC, scale=X_scale, zero_point=X_zp, dtype=torch.quint8)
W_q = torch.quantize_linear(W_RSCK, scale=W_scale, zero_point=W_zp, dtype=torch.quint8)
b_q = torch.quantize_linear(b, scale=X_scale * W_scale, zero_point=0, dtype=torch.qint32)
W_pack = torch.ops.quantized.conv_prepack(W_q, b_q, stride, padding, dilation, groups)
qconv = torch.ops.quantized.conv2d
if use_relu:
qconv = torch.ops.quantized.conv2d_relu
Y_q = qconv(
X_q,
W_pack,
stride,
padding,
dilation,
groups,
Y_scale,
Y_zp
)
result_NHWK = result_ref.permute([0, 2, 3, 1])
if use_relu:
relu = torch.nn.ReLU()
result_NHWK = relu(result_NHWK)
result_ref_q = torch.quantize_linear(result_NHWK, scale=Y_scale, zero_point=Y_zp, dtype=torch.quint8)
np.testing.assert_array_almost_equal(result_ref_q.int_repr().numpy(), Y_q.int_repr().numpy(), decimal=0)
"""Tests the correctness of the quantized::qnnpack_add op."""
@given(A=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
qparams=hu.qparams(dtypes=torch.quint8,