[Quant] Add fused ConvAddReLU2d module for onednn backend (#91154)

**Summary**
Post-op fusion reduces data-movement overhead and improves inference performance. This PR adds a fused ConvAddReLU2d module for the onednn backend, to be used for int8 inference with onednn. Calling this module with any other quantization backend raises an error.
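
Below is a minimal sketch of the eager-mode API this PR introduces (assuming a PyTorch build that includes this change; shapes and values are illustrative):

```python
import torch
import torch.nn as nn
import torch.ao.nn.intrinsic as nni

# Wrap a float Conv2d together with an add and a ReLU in the new
# intrinsic container (ConvAddReLU2d) added by this PR.
conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
fused = nni.ConvAddReLU2d(conv, torch.add, nn.ReLU())

# Unlike most fused modules, forward takes two inputs: the conv input
# and the extra operand of the add.
x1 = torch.randn(1, 3, 8, 8)
x2 = torch.randn(1, 3, 8, 8)  # must match conv(x1)'s output shape
assert torch.equal(fused(x1, x2), torch.relu(conv(x1) + x2))
```

During int8 conversion this container is intended to be replaced by the quantized counterpart `nniq.ConvAddReLU2d`, which calls a single fused onednn kernel; per the TODO in the test below, the lowering itself lands in a follow-up PR.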

**Test plan**
```
python -m pytest test_quantization.py -k test_conv2d_add_relu
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91154
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
commit e77f28a03d (parent ef4118e435)
Author: leslie-fang-intel, 2023-01-31 11:09:35 +08:00; committed by PyTorch MergeBot
7 changed files with 128 additions and 5 deletions


@@ -278,7 +278,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         example_input = [X, ]
         example_input_q = [X_q, ]
-        if post_op == "add":
+        if post_op in ["add", "add_relu"]:
             X2, X2_q = _make_conv_add_extra_input_tensor(X2_scale, X2_zero_point, conv_module[0](X).size())
             example_input = [X, X2]
             example_input_q = [X_q, X2_q]
@@ -290,7 +290,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
 
         qconv_module.scale = Y_scale
         qconv_module.zero_point = Y_zero_point
-        raw_conv_module = conv_module[0] if post_op in ["relu", "add"] else conv_module
+        raw_conv_module = conv_module[0] if post_op in ["relu", "add", "add_relu"] else conv_module
         raw_conv_module.weight.data = W
         if use_bias:
             raw_conv_module.bias.data = b
@@ -356,7 +356,6 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         self.assertEqual(qconv_module.scale, loaded_qconv_module.scale)
         self.assertEqual(qconv_module.zero_point,
                          loaded_qconv_module.zero_point)
-
         Y_loaded = loaded_qconv_module(*example_input_q)
         np.testing.assert_array_almost_equal(
             Y_exp.int_repr().numpy(), Y_loaded.int_repr().numpy(), decimal=0)
@@ -396,6 +395,12 @@ class TestStaticQuantizedModule(QuantizationTestCase):
             qconv_module, [example_input_q],
             check_save_load=True)
 
+        if post_op in ["add_relu"]:
+            # TODO(Leslie): remove this early return once the lowering is
+            # enabled in the follow-up PR; the checks below exercise that
+            # lowering path, which this PR does not yet enable.
+            return
+
         class _FusedModule_two_input_args(torch.nn.intrinsic._FusedModule):
             # Help Module for ConvAdd2d since torch.nn.intrinsic._FusedModule only support one input arg
             def forward(self, x1, x2):
@@ -829,6 +834,65 @@ class TestStaticQuantizedModule(QuantizationTestCase):
                     pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                     Y_scale, Y_zero_point, use_bias, "add", use_channelwise, X2_scale, X2_zero_point)
 
+    @skipIfNoONEDNN
+    def test_conv2d_add_relu(self):
+        """test API functionality for nn.intrinsic.quantized.ConvAddReLU2d"""
+        with override_quantized_engine('onednn'):
+            options = itertools.product(
+                ["zeros", "reflect"],  # pad_mode
+                [True, False],  # use_bias
+                [True, False],  # use_channelwise
+            )
+            batch_size = 2
+            in_channels_per_group = 2
+            H = 8
+            W = 8
+            out_channels_per_group = 2
+            groups = 3
+            kernel_h = 3
+            kernel_w = 3
+            stride_h = 2
+            stride_w = 2
+            pad_h = 1
+            pad_w = 1
+            dilation = 1
+            # Tests the correctness of the conv2d module.
+            in_channels = in_channels_per_group * groups
+            out_channels = out_channels_per_group * groups
+            input_feature_map_size = (H, W)
+            kernel_size = (kernel_h, kernel_w)
+            stride = (stride_h, stride_w)
+            padding = (pad_h, pad_w)
+            dilation = (dilation, dilation)
+            X_scale = 1.3
+            X_zero_point = 2
+            X2_scale = 1.2
+            X2_zero_point = 1
+            W_scale = [0.5]
+            W_zero_point = [0] if qengine_is_onednn() else [3]
+            Y_scale = 5.0
+            Y_zero_point = 4
+            qconv_cls = nniq.ConvAddReLU2d
+            module_name = "QuantizedConvAddReLU2d"
+            for pad_mode, use_bias, use_channelwise in options:
+                qconv_module = qconv_cls(
+                    in_channels, out_channels, kernel_size, stride, padding,
+                    dilation, groups, use_bias, padding_mode=pad_mode
+                )
+
+                conv_module = nn.Conv2d(
+                    in_channels, out_channels, kernel_size, stride, padding,
+                    dilation, groups, use_bias, padding_mode=pad_mode)
+                conv_module = torch.ao.nn.intrinsic.ConvAddReLU2d(conv_module, torch.add, nn.ReLU())
+                conv_module = conv_module.float()
+
+                self._test_conv_api_impl(
+                    module_name, qconv_module, conv_module, batch_size,
+                    in_channels_per_group, input_feature_map_size,
+                    out_channels_per_group, groups, kernel_size, stride, padding,
+                    pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
+                    Y_scale, Y_zero_point, use_bias, "add_relu", use_channelwise, X2_scale, X2_zero_point)
+
     def test_pool_api(self):
         """Tests the correctness of the pool module.
         The correctness is defined against the functional implementation.


@@ -22,6 +22,7 @@ __all__ = [
     'LinearLeakyReLU',
     'LinearTanh',
     'ConvAdd2d',
+    'ConvAddReLU2d',
 ]
 
 # We are exposing all subpackages to the end-user.


@@ -15,6 +15,7 @@ from .fused import LinearBn1d
 from .fused import LinearLeakyReLU
 from .fused import LinearTanh
 from .fused import ConvAdd2d
+from .fused import ConvAddReLU2d
 
 __all__ = [
     'ConvBn1d',

@@ -33,4 +34,5 @@ __all__ = [
     'LinearLeakyReLU',
     'LinearTanh',
     'ConvAdd2d',
+    'ConvAddReLU2d',
 ]


@@ -4,7 +4,7 @@ from torch.nn.utils.parametrize import type_before_parametrizations
 
 __all__ = ['ConvReLU1d', 'ConvReLU2d', 'ConvReLU3d', 'LinearReLU', 'ConvBn1d', 'ConvBn2d',
            'ConvBnReLU1d', 'ConvBnReLU2d', 'ConvBn3d', 'ConvBnReLU3d', 'BNReLU2d', 'BNReLU3d',
-           'LinearBn1d', 'LinearLeakyReLU', 'LinearTanh', 'ConvAdd2d']
+           'LinearBn1d', 'LinearLeakyReLU', 'LinearTanh', 'ConvAdd2d', 'ConvAddReLU2d']
 
 # Used for identifying intrinsic modules used in quantization
 class _FusedModule(torch.nn.Sequential):

@@ -155,3 +155,14 @@ class ConvAdd2d(_FusedModule):
     def forward(self, x1, x2):
         return self.add(self[0](x1), x2)
+
+class ConvAddReLU2d(_FusedModule):
+    r"""This is a sequential container which calls the Conv2d, add, Relu.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, conv, add, relu):
+        super().__init__(conv)
+        self.add = add
+        self.relu = relu
+
+    def forward(self, x1, x2):
+        return self.relu(self.add(self[0](x1), x2))
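
A note on the container's design, with a small illustrative check (assuming this PR is applied): only `conv` is passed to the `Sequential` base, so `self[0]` resolves to the wrapped conv (this is what the quantized module's `from_reference(ref_qconv[0], ...)` relies on), while `add` and `relu` are stored as attributes of the container; the two-argument `forward` is also why the test above needs the `_FusedModule_two_input_args` helper.

```python
import torch
import torch.nn as nn
import torch.ao.nn.intrinsic as nni

fused = nni.ConvAddReLU2d(nn.Conv2d(4, 4, 3), torch.add, nn.ReLU())
assert isinstance(fused[0], nn.Conv2d)  # self[0] is the wrapped conv
assert fused.add is torch.add           # a plain function, not a submodule
assert isinstance(fused.relu, nn.ReLU)  # registered as a child module
```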


@@ -10,4 +10,5 @@ __all__ = [
     'LinearLeakyReLU',
     'LinearTanh',
     'ConvAdd2d',
+    'ConvAddReLU2d',
 ]


@@ -1,7 +1,7 @@
 from .linear_relu import LinearReLU, LinearLeakyReLU, LinearTanh
 from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
 from .bn_relu import BNReLU2d, BNReLU3d
-from .conv_add import ConvAdd2d
+from .conv_add import ConvAdd2d, ConvAddReLU2d
 
 __all__ = [
     'LinearReLU',

@@ -13,4 +13,5 @@ __all__ = [
     'LinearLeakyReLU',
     'LinearTanh',
     'ConvAdd2d',
+    'ConvAddReLU2d',
 ]


@@ -48,3 +48,46 @@ class ConvAdd2d(nnq.Conv2d):
     @classmethod
     def from_reference(cls, ref_qconv, output_scale, output_zero_point):
         return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
+
+class ConvAddReLU2d(nnq.Conv2d):
+    r"""
+    A ConvAddReLU2d module is a fused module of Conv2d, Add and Relu
+
+    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.
+
+    Attributes:
+        Same as torch.ao.nn.quantized.Conv2d
+
+    """
+    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAddReLU2d  # type: ignore[assignment]
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros', device=None, dtype=None):
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride=stride,
+            padding=padding, dilation=dilation, groups=groups, bias=bias,
+            padding_mode=padding_mode, device=device, dtype=dtype)
+
+    def forward(self, input, extra_input):
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 4:
+            raise ValueError("Input shape must be `(N, C, H, W)`!")
+        if self.padding_mode != 'zeros':
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
+        return torch.ops.quantized.conv2d_add_relu(
+            input, extra_input, self._packed_params, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedConvAddReLU2d'
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
+
+    @classmethod
+    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
+        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
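
For completeness, a hedged sketch of exercising the quantized module directly, similar to what `_test_conv_api_impl` does in the test above (requires a PyTorch build with the onednn backend; all scales and zero points below are illustrative):

```python
import torch
import torch.ao.nn.intrinsic.quantized as nniq

torch.backends.quantized.engine = 'onednn'  # other engines raise an error

qconv = nniq.ConvAddReLU2d(3, 3, kernel_size=3, padding=1)
qconv.scale, qconv.zero_point = 0.5, 0  # illustrative output quantization params

x = torch.quantize_per_tensor(torch.randn(1, 3, 8, 8), 0.1, 0, torch.quint8)
x2 = torch.quantize_per_tensor(torch.randn(1, 3, 8, 8), 0.1, 0, torch.quint8)
y = qconv(x, x2)  # one fused quantized conv2d + add + relu kernel
```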