[Quant] Add fused ConvAddReLU2d module for onednn backend (#91154)
**Summary**
Post-op fusion can reduce data movement overhead and improve inference performance. This PR adds a fused ConvAddReLU2d module for the onednn backend, to be used for int8 inference with that backend. Calling this module with any other quantization backend raises an error.

**Test plan**
```
python -m pytest test_quantization.py -k test_conv2d_add_relu
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91154
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
This commit is contained in:
parent ef4118e435
commit e77f28a03d
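Before the diff itself, here is a minimal sketch (not part of the commit) of what the new float fused module computes; the layer sizes and tensor shapes below are arbitrary example values chosen for illustration:

```python
import torch
import torch.nn as nn
import torch.ao.nn.intrinsic as nni

# Float fused container added by this PR: conv2d -> add(extra input) -> relu.
# The channel counts, kernel size, and shapes are example values, not from the PR.
conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
fused = nni.ConvAddReLU2d(conv, torch.add, nn.ReLU())

x1 = torch.randn(1, 3, 8, 8)   # conv input
x2 = torch.randn(1, 3, 8, 8)   # extra input added to the conv output
y = fused(x1, x2)              # computes relu(conv(x1) + x2)
assert torch.allclose(y, torch.relu(conv(x1) + x2))
```

The quantized counterpart, nniq.ConvAddReLU2d, dispatches to torch.ops.quantized.conv2d_add_relu (see the last hunk below) and, per the summary above, is only usable with the onednn quantization engine.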
```diff
@@ -278,7 +278,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         example_input = [X, ]
         example_input_q = [X_q, ]
 
-        if post_op == "add":
+        if post_op in ["add", "add_relu"]:
             X2, X2_q = _make_conv_add_extra_input_tensor(X2_scale, X2_zero_point, conv_module[0](X).size())
             example_input = [X, X2]
             example_input_q = [X_q, X2_q]
@@ -290,7 +290,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         qconv_module.scale = Y_scale
         qconv_module.zero_point = Y_zero_point
 
-        raw_conv_module = conv_module[0] if post_op in ["relu", "add"] else conv_module
+        raw_conv_module = conv_module[0] if post_op in ["relu", "add", "add_relu"] else conv_module
         raw_conv_module.weight.data = W
         if use_bias:
             raw_conv_module.bias.data = b
@@ -356,7 +356,6 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         self.assertEqual(qconv_module.scale, loaded_qconv_module.scale)
         self.assertEqual(qconv_module.zero_point,
                          loaded_qconv_module.zero_point)
-
         Y_loaded = loaded_qconv_module(*example_input_q)
         np.testing.assert_array_almost_equal(
             Y_exp.int_repr().numpy(), Y_loaded.int_repr().numpy(), decimal=0)
@@ -396,6 +395,12 @@ class TestStaticQuantizedModule(QuantizationTestCase):
             qconv_module, [example_input_q],
             check_save_load=True)
 
+        if post_op in ["add_relu"]:
+            # **TODO Leslie** Remove this part when enabling the lowering in next PR.
+            # workaround in this PR to return from here, since the below lowering part enabled in next PR
+            # We will enable below check in next PR
+            return
+
         class _FusedModule_two_input_args(torch.nn.intrinsic._FusedModule):
             # Help Module for ConvAdd2d since torch.nn.intrinsic._FusedModule only support one input arg
             def forward(self, x1, x2):
@@ -829,6 +834,65 @@ class TestStaticQuantizedModule(QuantizationTestCase):
                 pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                 Y_scale, Y_zero_point, use_bias, "add", use_channelwise, X2_scale, X2_zero_point)
 
+    @skipIfNoONEDNN
+    def test_conv2d_add_relu(self):
+        """test API functionality for nn.intrinsic.quantized.ConvAdd2d"""
+        with override_quantized_engine('onednn'):
+            options = itertools.product(
+                ["zeros", "reflect"],  # pad_mode
+                [True, False],  # use_bias
+                [True, False],  # use_channelwise
+            )
+            batch_size = 2
+            in_channels_per_group = 2
+            H = 8
+            W = 8
+            out_channels_per_group = 2
+            groups = 3
+            kernel_h = 3
+            kernel_w = 3
+            stride_h = 2
+            stride_w = 2
+            pad_h = 1
+            pad_w = 1
+            dilation = 1
+            # Tests the correctness of the conv2d module.
+            in_channels = in_channels_per_group * groups
+            out_channels = out_channels_per_group * groups
+            input_feature_map_size = (H, W)
+            kernel_size = (kernel_h, kernel_w)
+            stride = (stride_h, stride_w)
+            padding = (pad_h, pad_w)
+            dilation = (dilation, dilation)
+            X_scale = 1.3
+            X_zero_point = 2
+            X2_scale = 1.2
+            X2_zero_point = 1
+            W_scale = [0.5]
+            W_zero_point = [0] if qengine_is_onednn() else [3]
+            Y_scale = 5.0
+            Y_zero_point = 4
+            qconv_cls = nniq.ConvAddReLU2d
+            module_name = "QuantizedConvAddReLU2d"
+            for pad_mode, use_bias, use_channelwise in options:
+                qconv_module = qconv_cls(
+                    in_channels, out_channels, kernel_size, stride, padding,
+                    dilation, groups, use_bias, padding_mode=pad_mode
+                )
+
+                conv_module = nn.Conv2d(
+                    in_channels, out_channels, kernel_size, stride, padding,
+                    dilation, groups, use_bias, padding_mode=pad_mode)
+                conv_module = torch.ao.nn.intrinsic.ConvAddReLU2d(conv_module, torch.add, nn.ReLU())
+                conv_module = conv_module.float()
+
+                self._test_conv_api_impl(
+                    module_name, qconv_module, conv_module, batch_size,
+                    in_channels_per_group, input_feature_map_size,
+                    out_channels_per_group, groups, kernel_size, stride, padding,
+                    pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
+                    Y_scale, Y_zero_point, use_bias, "add_relu", use_channelwise, X2_scale, X2_zero_point)
+
     def test_pool_api(self):
         """Tests the correctness of the pool module.
         The correctness is defined against the functional implementation.
```
```diff
@@ -22,6 +22,7 @@ __all__ = [
     'LinearLeakyReLU',
     'LinearTanh',
     'ConvAdd2d',
+    'ConvAddReLU2d',
 ]
 
 # We are exposing all subpackages to the end-user.
```
```diff
@@ -15,6 +15,7 @@ from .fused import LinearBn1d
 from .fused import LinearLeakyReLU
 from .fused import LinearTanh
 from .fused import ConvAdd2d
+from .fused import ConvAddReLU2d
 
 __all__ = [
     'ConvBn1d',
@@ -33,4 +34,5 @@ __all__ = [
     'LinearLeakyReLU',
     'LinearTanh',
     'ConvAdd2d',
+    'ConvAddReLU2d',
 ]
```
```diff
@@ -4,7 +4,7 @@ from torch.nn.utils.parametrize import type_before_parametrizations
 
 __all__ = ['ConvReLU1d', 'ConvReLU2d', 'ConvReLU3d', 'LinearReLU', 'ConvBn1d', 'ConvBn2d',
            'ConvBnReLU1d', 'ConvBnReLU2d', 'ConvBn3d', 'ConvBnReLU3d', 'BNReLU2d', 'BNReLU3d',
-           'LinearBn1d', 'LinearLeakyReLU', 'LinearTanh', 'ConvAdd2d']
+           'LinearBn1d', 'LinearLeakyReLU', 'LinearTanh', 'ConvAdd2d', 'ConvAddReLU2d']
 
 # Used for identifying intrinsic modules used in quantization
 class _FusedModule(torch.nn.Sequential):
@@ -155,3 +155,14 @@ class ConvAdd2d(_FusedModule):
 
     def forward(self, x1, x2):
         return self.add(self[0](x1), x2)
+
+class ConvAddReLU2d(_FusedModule):
+    r"""This is a sequential container which calls the Conv2d, add, Relu.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, conv, add, relu):
+        super().__init__(conv)
+        self.add = add
+        self.relu = relu
+
+    def forward(self, x1, x2):
+        return self.relu(self.add(self[0](x1), x2))
```
```diff
@@ -10,4 +10,5 @@ __all__ = [
     'LinearLeakyReLU',
     'LinearTanh',
     'ConvAdd2d',
+    'ConvAddReLU2d',
 ]
```
```diff
@@ -1,7 +1,7 @@
 from .linear_relu import LinearReLU, LinearLeakyReLU, LinearTanh
 from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
 from .bn_relu import BNReLU2d, BNReLU3d
-from .conv_add import ConvAdd2d
+from .conv_add import ConvAdd2d, ConvAddReLU2d
 
 __all__ = [
     'LinearReLU',
@@ -13,4 +13,5 @@ __all__ = [
     'LinearLeakyReLU',
     'LinearTanh',
     'ConvAdd2d',
+    'ConvAddReLU2d',
 ]
```
```diff
@@ -48,3 +48,46 @@ class ConvAdd2d(nnq.Conv2d):
     @classmethod
     def from_reference(cls, ref_qconv, output_scale, output_zero_point):
         return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
+
+class ConvAddReLU2d(nnq.Conv2d):
+    r"""
+    A ConvAddReLU2d module is a fused module of Conv2d, Add and Relu
+
+    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.
+
+    Attributes:
+        Same as torch.ao.nn.quantized.Conv2d
+
+    """
+    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAddReLU2d  # type: ignore[assignment]
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros', device=None, dtype=None):
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride=stride,
+            padding=padding, dilation=dilation, groups=groups, bias=bias,
+            padding_mode=padding_mode, device=device, dtype=dtype)
+
+    def forward(self, input, extra_input):
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 4:
+            raise ValueError("Input shape must be `(N, C, H, W)`!")
+        if self.padding_mode != 'zeros':
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
+        return torch.ops.quantized.conv2d_add_relu(
+            input, extra_input, self._packed_params, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedConvAddReLU2d'
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
+
+    @classmethod
+    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
+        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
```