pytorch/test/test_quantized_conv.py
Daya Khudia f510409281 Enable FBGEMM tests under UBSAN as well (#23570)
Summary:
Enabling tests under UBSAN as well
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23570

Test Plan:
buck test mode/dev caffe2/test:quantized
```
Running 29 tests
Started new test run: https://our.intern.facebook.com/intern/testinfra/testrun/3940649677415136
      ✓ caffe2/test:quantized - test_qtensor (test_quantized_tensor.TestQuantizedTensor) 0.536 1/29 (passed)
      ✓ caffe2/test:quantized - test_qtensor_per_channel_affine (test_quantized_tensor.TestQuantizedTensor) 0.453 2/29 (passed)
      ✓ caffe2/test:quantized - test_qtensor_reshape (test_quantized_tensor.TestQuantizedTensor) 0.302 3/29 (passed)
      ✓ caffe2/test:quantized - test_qadd_relu_same_qparams (test_quantized.TestQuantizedOps) 0.332 4/29 (passed)
      ✓ caffe2/test:quantized - test_qtensor_view (test_quantized_tensor.TestQuantizedTensor) 0.351 5/29 (passed)
      ✓ caffe2/test:quantized - test_qadd_relu_different_qparams (test_quantized.TestQuantizedOps) 0.348 6/29 (passed)
      ✓ caffe2/test:quantized - test_qtensor_dequantize_linear (test_quantized_tensor.TestQuantizedTensor) 0.338 7/29 (passed)
      ✓ caffe2/test:quantized - test_qtensor_copy (test_quantized_tensor.TestQuantizedTensor) 0.267 8/29 (passed)
      ✓ caffe2/test:quantized - test_qtensor_clone (test_quantized_tensor.TestQuantizedTensor) 0.330 9/29 (passed)
      ✓ caffe2/test:quantized - test_qrelu (test_quantized.TestQuantizedOps) 1.774 10/29 (passed)
      ✓ caffe2/test:quantized - test_pool_api (test_nn_quantized.ModuleAPITest) 0.418 11/29 (passed)
      ✓ caffe2/test:quantized - test_qtensor_load_save (test_quantized_tensor.TestQuantizedTensor) 0.724 12/29 (passed)
      ✓ caffe2/test:quantized - test_relu_api (test_nn_quantized.FunctionalAPITest) 1.013 13/29 (passed)
      ✓ caffe2/test:quantized - test_qtensor_quant_dequant (test_quantized_tensor.TestQuantizedTensor) 1.055 14/29 (passed)
      ✓ caffe2/test:quantized - test_qtensor_permute (test_quantized_tensor.TestQuantizedTensor) 0.696 15/29 (passed)
      ✓ caffe2/test:quantized - test_qtensor_dtypes (test_quantized_tensor.TestQuantizedTensor) 0.841 16/29 (passed)
      ✓ caffe2/test:quantized - test_quant_dequant_api (test_nn_quantized.ModuleAPITest) 0.616 17/29 (passed)
      ✓ caffe2/test:quantized - test_qtensor_creation (test_quantized_tensor.TestQuantizedTensor) 0.698 18/29 (passed)
      ✓ caffe2/test:quantized - test_qconv (test_quantized.TestQuantizedConv) 4.743 19/29 (passed)
      ✓ caffe2/test:quantized - test_cat (test_quantized.TestQuantizedOps) 6.992 20/29 (passed)
      ✓ caffe2/test:quantized - test_linear_api (test_nn_quantized.ModuleAPITest) 8.970 21/29 (passed)
      ✓ caffe2/test:quantized - test_conv_api (test_quantized_conv.QuantizedConvTest) 9.403 22/29 (passed)
      ↷ caffe2/test:quantized - test_qnnpack_linear (test_quantized.TestQNNPackOps) 0.000 23/29 (skipped)
Test output:
> Skipped: QNNPACK does not play well with UBSAN at the moment, so we skip the test if we are in a UBSAN environment.
> test_qnnpack_linear (test_quantized.TestQNNPackOps) ... skipped 'QNNPACK does not play well with UBSAN at the moment, so we skip the test if we are in a UBSAN environment.'
>
> ----------------------------------------------------------------------
> Ran 1 test in 0.000s
>
> OK (skipped=1)
      ↷ caffe2/test:quantized - test_qnnpack_relu (test_quantized.TestQNNPackOps) 0.000 24/29 (skipped)
Test output:
> Skipped: QNNPACK does not play well with UBSAN at the moment, so we skip the test if we are in a UBSAN environment.
> test_qnnpack_relu (test_quantized.TestQNNPackOps) ... skipped 'QNNPACK does not play well with UBSAN at the moment, so we skip the test if we are in a UBSAN environment.'
>
> ----------------------------------------------------------------------
> Ran 1 test in 0.000s
>
> OK (skipped=1)
      ✓ caffe2/test:quantized - test_max_pool2d (test_quantized.TestQuantizedOps) 8.453 25/29 (passed)
      ✓ caffe2/test:quantized - test_qlinear_unpack (test_quantized.TestQuantizedLinear) 0.664 26/29 (passed)
      ✓ caffe2/test:quantized - test_qconv_unpack (test_quantized.TestQuantizedConv) 2.965 27/29 (passed)
      ✓ caffe2/test:quantized - test_qlinear (test_quantized.TestQuantizedLinear) 1.915 28/29 (passed)
      ✓ caffe2/test:quantized - test_conv_api (test_nn_quantized.ModuleAPITest) 60.804 29/29 (passed)
      ✓ caffe2/test:quantized - main 0.000 (passed)
Finished test run: https://our.intern.facebook.com/intern/testinfra/testrun/3940649677415136
Summary (total time 68.66s):
  PASS: 28
  FAIL: 0
  SKIP: 2
    caffe2/test:quantized - test_qnnpack_linear (test_quantized.TestQNNPackOps)
    caffe2/test:quantized - test_qnnpack_relu (test_quantized.TestQNNPackOps)
  FATAL: 0
  TIMEOUT: 0
  OMIT: 0
```
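The two skips above presumably come from a conditional skip guard on the QNNPACK tests. The sketch below shows one way to write such a guard with `unittest.skipIf`. It assumes a hypothetical `HYPOTHETICAL_UBSAN_BUILD` environment variable as the UBSAN signal, because this commit does not show how the real suite detects a UBSAN build. `running_under_ubsan` and `QNNPackStyleTest` are also illustrative names.

```python
import os
import unittest


# Hypothetical signal: assume the CI exports this variable for UBSAN builds.
# The real PyTorch suite may detect a UBSAN build in a different way.
def running_under_ubsan(env=os.environ):
    return env.get("HYPOTHETICAL_UBSAN_BUILD", "0") == "1"


UBSAN_SKIP_REASON = ("QNNPACK does not play well with UBSAN at the moment, so we "
                     "skip the test if we are in a UBSAN environment.")


class QNNPackStyleTest(unittest.TestCase):
    # The skip condition is evaluated once, when the class body executes.
    @unittest.skipIf(running_under_ubsan(), UBSAN_SKIP_REASON)
    def test_qnnpack_kernel(self):
        # Placeholder body standing in for a real QNNPACK kernel check.
        self.assertEqual(2 + 2, 4)
```

Run it with `python -m unittest`. A skipped test is still counted as run, and it is reported as `skipped` rather than failed, which matches the `OK (skipped=1)` lines in the log.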

Reviewed By: jianyuh

Differential Revision: D16569166

Pulled By: dskhudia

fbshipit-source-id: 53522b4162eb1ebb35b408a1503d9664305c85b0
2019-08-12 17:59:22 -07:00

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import unittest
import numpy as np
import torch
import torch.nn.quantized.functional as qF
from hypothesis import assume, given
from hypothesis import strategies as st
import hypothesis_utils as hu
from common_quantized import _conv_output_shape
from common_utils import TestCase, run_tests


@unittest.skipIf(
    not torch.fbgemm_is_cpu_supported(),
    " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
    " with instruction set support avx2 or newer.",
)
class QuantizedConvTest(TestCase):
    @given(X=hu.tensor_conv2d(min_batch=1, max_batch=3,
                              min_in_channels=1, max_in_channels=7,
                              min_out_channels=1, max_out_channels=7,
                              H_range=(6, 12), W_range=(6, 12),
                              kH_range=(3, 5), kW_range=(3, 5),
                              max_groups=4,
                              qparams=[hu.qparams(dtypes=torch.quint8,
                                                  zero_point_min=0,
                                                  zero_point_max=0),
                                       hu.qparams(dtypes=torch.qint8,
                                                  zero_point_min=0,
                                                  zero_point_max=0),
                                       hu.qparams(dtypes=torch.qint32,
                                                  zero_point_min=0,
                                                  zero_point_max=0)]),
           padH=st.integers(1, 3), padW=st.integers(1, 3),
           sH=st.integers(1, 3), sW=st.integers(1, 3),
           dH=st.integers(1, 2), dW=st.integers(1, 2))
    def test_conv_api(self, X, padH, padW, sH, sW, dH, dW):
        """Tests the correctness of the conv functional.

        Correctness is defined as matching the output of the reference
        `torch.ops.quantized.fbgemm_conv2d` op.
        """
        # Random inputs. Each of inputs, filters and bias arrives as a
        # (numpy array, (scale, zero_point, qtype)) pair.
        (inputs, filters, bias, groups) = X
        inputs, (inputs_scale, inputs_zero_point, inputs_qtype) = inputs
        filters, (filters_scale, filters_zero_point, filters_qtype) = filters
        bias, (bias_scale, bias_zero_point, bias_qtype) = bias
        # The output is requantized with the input's scale and zero point.
        scale, zero_point = inputs_scale, inputs_zero_point
        torch_type = inputs_qtype

        iC, oC = inputs.shape[1], filters.shape[0]
        iH, iW = inputs.shape[2:]
        kH, kW = filters.shape[2:]
        # Discard examples whose padding exceeds half the kernel or whose
        # output would be empty.
        assume(kH // 2 >= padH)
        assume(kW // 2 >= padW)
        oH = _conv_output_shape(iH, kH, padH, sH, dH)
        assume(oH > 0)
        oW = _conv_output_shape(iW, kW, padW, sW, dW)
        assume(oW > 0)

        inputs = torch.from_numpy(inputs).to(torch.float)
        filters = torch.from_numpy(filters).to(torch.float)
        bias = torch.from_numpy(bias).to(torch.float)

        kernel_size = (kH, kW)
        stride = (sH, sW)
        i_padding = (padH, padW)
        dilation = (dH, dW)

        # Quantized inputs
        q_inputs = torch.quantize_linear(inputs, inputs_scale,
                                         inputs_zero_point, inputs_qtype)
        q_filters = torch.quantize_linear(filters, filters_scale,
                                          filters_zero_point, filters_qtype)
        q_bias = torch.quantize_linear(bias, bias_scale, bias_zero_point,
                                       bias_qtype)

        # Reference op
        ref_op = torch.ops.quantized.fbgemm_conv2d

        # Results check. The reference ops are called on channels-last (NHWC)
        # tensors, so inputs and filters are permuted from NCHW and the
        # reference result is permuted back.
        try:
            q_filters_ref = torch.ops.quantized.fbgemm_conv_prepack(
                q_filters.permute([0, 2, 3, 1]),
                stride,
                i_padding,
                dilation,
                groups)
            ref_result = ref_op(q_inputs.permute([0, 2, 3, 1]), q_filters_ref,
                                q_bias, stride,
                                i_padding, dilation,
                                groups, scale, zero_point).permute([0, 3, 1, 2])
        except RuntimeError as e:
            # If the reference op rejects this configuration, the functional
            # must raise the same exception type, with a message matching the
            # first line of the reference message (up to the first "(").
            e_msg = str(e).split("\n")[0].split("(")[0].strip()
            np.testing.assert_raises_regex(
                type(e), e_msg, qF.conv2d,
                q_inputs, q_filters, bias=q_bias,
                scale=scale, zero_point=zero_point,
                stride=stride, padding=i_padding, dilation=dilation,
                groups=groups, dtype=torch_type)
        else:
            q_result = qF.conv2d(q_inputs,
                                 q_filters,
                                 bias=q_bias, scale=scale,
                                 zero_point=zero_point,
                                 stride=stride, padding=i_padding,
                                 dilation=dilation, groups=groups,
                                 dtype=torch_type)
            self.assertEqual(ref_result, q_result)


if __name__ == "__main__":
    run_tests()
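`test_conv_api` discards Hypothesis examples in two cases: when the padding exceeds half the kernel, or when the convolution output would be empty. `_conv_output_shape` is imported from `common_quantized`, and its body is not shown here. The sketch below assumes it implements the standard Conv2d output-size arithmetic, out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1, applied along one spatial axis. `conv_output_size` and `case_is_kept` are hypothetical helper names.

```python
def conv_output_size(input_size, kernel_size, padding, stride, dilation):
    """Output length of a convolution along one spatial axis.

    Standard Conv2d arithmetic:
        floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
    """
    # A dilated kernel spans dilation*(kernel-1)+1 input positions.
    effective_kernel = dilation * (kernel_size - 1) + 1
    # Python's // floors, so an oversized kernel yields 0 or less.
    return (input_size + 2 * padding - effective_kernel) // stride + 1


def case_is_kept(input_size, kernel_size, padding, stride, dilation):
    """Mirror of the two assume(...) filters test_conv_api applies per axis."""
    return (kernel_size // 2 >= padding and
            conv_output_size(input_size, kernel_size, padding,
                             stride, dilation) > 0)


# Worked examples within the test's sampling ranges.
print(conv_output_size(6, 3, 1, 1, 1))   # 6: padding preserves the size
print(conv_output_size(12, 5, 2, 2, 1))  # 6
print(case_is_kept(6, 5, 1, 3, 2))       # False: the output would be empty
```

In the last case the dilated 5-wide kernel spans 9 positions, but the padded input is only 8 wide. The formula therefore gives 0, and the test's `assume(oH > 0)` rejects the example instead of calling the op.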
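`torch.quantize_linear` converts the float inputs, filters and bias to quantized tensors using a single per-tensor scale and zero point. The test pins every zero point to 0. The NumPy sketch below assumes the usual affine round-then-saturate scheme. PyTorch's exact rounding behavior is not verified here. `QRANGES`, `quantize_affine` and `dequantize_affine` are illustrative names, not PyTorch APIs.

```python
import numpy as np

# Assumed integer ranges of the three quantized dtypes sampled by the test.
QRANGES = {
    "quint8": (0, 255),
    "qint8": (-128, 127),
    "qint32": (-2**31, 2**31 - 1),
}


def quantize_affine(x, scale, zero_point, qtype):
    """Per-tensor affine quantization: clamp(round(x / scale) + zero_point)."""
    qmin, qmax = QRANGES[qtype]
    # np.round rounds halves to the nearest even integer.
    q = np.round(np.asarray(x, dtype=np.float64) / scale) + zero_point
    # Saturate to the dtype's range instead of wrapping around.
    return np.clip(q, qmin, qmax).astype(np.int64)


def dequantize_affine(q, scale, zero_point):
    """Inverse mapping back to floats: (q - zero_point) * scale."""
    return (np.asarray(q, dtype=np.float64) - zero_point) * scale


x = [-1.0, 0.0, 0.26, 1.0, 100.0]
q = quantize_affine(x, scale=0.5, zero_point=0, qtype="qint8")
print(q.tolist())                             # [-2, 0, 1, 2, 127]
print(dequantize_affine(q, 0.5, 0).tolist())  # [-1.0, 0.0, 0.5, 1.0, 63.5]
```

The last element shows why a round trip through a quantized dtype is lossy. 100.0 saturates at 127, so it dequantizes to 63.5. For this reason the test compares two quantized results with each other rather than comparing against a float convolution.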