mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23858 Pull Request resolved: https://github.com/pytorch/pytorch/pull/23718 Changes: - Enable tests for quantization test files in `run_tests.py` - Remove `__future__` imports from `torch/nn/qat/modules/__init__.py`, since `unicode_literals` messes up imports on python2 because the elements in `__all__` will be Unicode and not string - Skip PostTrainingQuantTests if the build doesn't have FBGEMM (only a small subset of targets in tests) or if testing under UBSAN (the suppression file doesn't seem to work) Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D16639467 Pulled By: jamesr66a fbshipit-source-id: 532766797c216976dd7e07d751f768ff8e0fc207
120 lines · 5.2 KiB · Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import unittest
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.quantized.functional as qF
|
|
|
|
from hypothesis import assume, given
|
|
from hypothesis import strategies as st
|
|
import hypothesis_utils as hu
|
|
|
|
from common_quantized import _conv_output_shape
|
|
from common_utils import TestCase, run_tests, TEST_WITH_UBSAN
|
|
|
|
@unittest.skipIf(TEST_WITH_UBSAN or not torch.fbgemm_is_cpu_supported(),
                 'Quantization requires FBGEMM. FBGEMM does not play'
                 ' well with UBSAN at the moment, so we skip the test if'
                 ' we are in a UBSAN environment.')
class QuantizedConvTest(TestCase):
    """Property-based tests for the quantized functional ``qF.conv2d``.

    Skipped when running under UBSAN or when FBGEMM is not supported on the
    current CPU (see the ``skipIf`` condition above).
    """

    @given(X=hu.tensor_conv2d(min_batch=1, max_batch=3,
                              min_in_channels=1, max_in_channels=7,
                              min_out_channels=1, max_out_channels=7,
                              H_range=(6, 12), W_range=(6, 12),
                              kH_range=(3, 5), kW_range=(3, 5),
                              max_groups=4,
                              qparams=[hu.qparams(dtypes=torch.quint8,
                                                  zero_point_min=0,
                                                  zero_point_max=0),
                                       hu.qparams(dtypes=torch.qint8,
                                                  zero_point_min=0,
                                                  zero_point_max=0),
                                       hu.qparams(dtypes=torch.qint32,
                                                  zero_point_min=0,
                                                  zero_point_max=0)]),
           padH=st.integers(1, 3), padW=st.integers(1, 3),
           sH=st.integers(1, 3), sW=st.integers(1, 3),
           dH=st.integers(1, 2), dW=st.integers(1, 2))
    def test_conv_api(self, X, padH, padW, sH, sW, dH, dW):
        """Check ``qF.conv2d`` against the fbgemm reference conv op.

        The reference is ``torch.ops.quantized.fbgemm_conv2d`` run on filters
        prepacked by ``torch.ops.quantized.fbgemm_conv_prepack``. Inputs and
        filters are permuted with ``[0, 2, 3, 1]`` before the reference call,
        and its result is permuted back with ``[0, 3, 1, 2]`` so that it can
        be compared directly with the output of ``qF.conv2d``.

        If the reference raises ``RuntimeError`` for a generated
        configuration, ``qF.conv2d`` must raise the same exception type with
        a message containing the reference error's leading text (first line,
        truncated at the first ``"("``).
        """
        # ``X`` bundles (inputs, filters, bias, groups); each tensor comes
        # paired with its (scale, zero_point, qtype).
        (inputs, filters, bias, groups) = X
        inputs, (inputs_scale, inputs_zero_point, inputs_qtype) = inputs
        filters, (filters_scale, filters_zero_point, filters_qtype) = filters
        bias, (bias_scale, bias_zero_point, bias_qtype) = bias

        # The scale / zero_point / dtype passed to both conv implementations
        # reuse the input tensor's quantization parameters.
        scale, zero_point = inputs_scale, inputs_zero_point
        torch_type = inputs_qtype

        iH, iW = inputs.shape[2:]
        kH, kW = filters.shape[2:]
        # Only keep examples whose padding is at most half the kernel size.
        assume(kH // 2 >= padH)
        assume(kW // 2 >= padW)
        # Discard examples whose output would have no spatial extent.
        oH = _conv_output_shape(iH, kH, padH, sH, dH)
        assume(oH > 0)
        oW = _conv_output_shape(iW, kW, padW, sW, dW)
        assume(oW > 0)

        inputs = torch.from_numpy(inputs).to(torch.float)
        filters = torch.from_numpy(filters).to(torch.float)
        bias = torch.from_numpy(bias).to(torch.float)

        stride = (sH, sW)
        i_padding = (padH, padW)
        dilation = (dH, dW)

        # Quantized inputs
        q_inputs = torch.quantize_linear(inputs, inputs_scale,
                                         inputs_zero_point, inputs_qtype)
        q_filters = torch.quantize_linear(filters, filters_scale,
                                          filters_zero_point, filters_qtype)
        q_bias = torch.quantize_linear(bias, bias_scale, bias_zero_point,
                                       bias_qtype)

        def run_qconv():
            # The functional API under test, invoked identically on both
            # the expected-error path and the success path.
            return qF.conv2d(q_inputs, q_filters, bias=q_bias,
                             scale=scale, zero_point=zero_point,
                             stride=stride, padding=i_padding,
                             dilation=dilation, groups=groups,
                             dtype=torch_type)

        # Reference op
        ref_op = torch.ops.quantized.fbgemm_conv2d

        # Results check
        try:
            q_filters_ref = torch.ops.quantized.fbgemm_conv_prepack(
                q_filters.permute([0, 2, 3, 1]), stride, i_padding,
                dilation, groups)
            ref_result = ref_op(q_inputs.permute([0, 2, 3, 1]), q_filters_ref,
                                q_bias, stride, i_padding, dilation,
                                groups, scale, zero_point).permute([0, 3, 1, 2])
        except RuntimeError as e:
            # The reference rejected this configuration, so qF.conv2d must
            # reject it as well. The message is compared as a literal
            # substring rather than a regex: error texts can contain regex
            # metacharacters that would otherwise break the match.
            e_msg = str(e).split("\n")[0].split("(")[0].strip()
            with self.assertRaises(type(e)) as ctx:
                run_qconv()
            self.assertIn(e_msg, str(ctx.exception))
        else:
            # Compare outputs only when the reference succeeded; on the
            # error path above there is no reference result to compare.
            self.assertEqual(ref_result, run_qconv())
|
|
|
|
if "__main__" == __name__:
    # Executed directly (not imported): hand control to ``run_tests``
    # from common_utils to run the tests defined in this module.
    run_tests()
|