[quant] Add quantized Sigmoid module (#45883)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45883

Test Plan:
python test/test_quantization.py TestStaticQuantizedModule.test_sigmoid

Imported from OSS

Reviewed By: z-a-f

Differential Revision: D24129116

fbshipit-source-id: aa960549509c60374012f35b1f5be39e90418099
Jerry Zhang 2020-10-07 10:24:57 -07:00 committed by Facebook GitHub Bot
parent 30bf799f9c
commit 83d2c9a232
3 changed files with 26 additions and 1 deletion


@@ -716,6 +716,9 @@ class TestStaticQuantizedModule(QuantizationTestCase):
     def test_leaky_relu(self):
         self._test_activation_module_impl("LeakyReLU", nn.LeakyReLU, nnq.LeakyReLU, {"negative_slope": 0.2})

+    def test_sigmoid(self):
+        self._test_activation_module_impl("Sigmoid", nn.Sigmoid, nnq.Sigmoid, {})
+
     @given(
         num_embeddings=st.integers(10, 50),
         embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
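
For reference, here is a rough standalone sketch of the kind of check the new test performs. This is not the actual _test_activation_module_impl helper; the input shape and quantization parameters below are arbitrary, and the comparison tolerance is informal.

import torch
import torch.nn.quantized as nnq

# Quantize a random float tensor (scale/zero_point chosen arbitrarily here).
x = torch.randn(4, 8)
qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=64, dtype=torch.quint8)

# Quantized Sigmoid with an explicit output scale/zero_point.
qmod = nnq.Sigmoid(1.0 / 256, 0)
qy = qmod(qx)

# Reference: run the float op on the dequantized input, requantize with the
# same output qparams, then compare both results after dequantizing.
ref = torch.quantize_per_tensor(torch.sigmoid(qx.dequantize()),
                                scale=1.0 / 256, zero_point=0, dtype=torch.quint8)
print((qy.dequantize() - ref.dequantize()).abs().max())  # should be within a quantization step or so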


@@ -2,7 +2,7 @@
 import torch
 from torch.nn.modules.pooling import MaxPool2d

-from .activation import ReLU, ReLU6, Hardswish, ELU, LeakyReLU
+from .activation import ReLU, ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid
 from .batchnorm import BatchNorm2d, BatchNorm3d
 from .normalization import LayerNorm, GroupNorm, InstanceNorm1d, \
     InstanceNorm2d, InstanceNorm3d
@@ -100,6 +100,7 @@ __all__ = [
     'Hardswish',
     'ELU',
     'LeakyReLU',
+    'Sigmoid',
     'LayerNorm',
     'GroupNorm',
     'InstanceNorm1d',
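
With the re-export above, the new module is reachable from the package root, e.g.:

import torch.nn.quantized as nnq
from torch.nn.quantized import Sigmoid

assert nnq.Sigmoid is Sigmoid  # exported via __all__ in torch/nn/quantized/modules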


@@ -149,3 +149,24 @@ class LeakyReLU(torch.nn.LeakyReLU):
     def from_float(cls, mod):
         scale, zero_point = mod.activation_post_process.calculate_qparams()
         return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)
+
+class Sigmoid(torch.nn.Sigmoid):
+    r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`.
+
+    Args:
+        scale: quantization scale of the output tensor
+        zero_point: quantization zero point of the output tensor
+    """
+
+    def __init__(self, output_scale: float, output_zero_point: int):
+        super().__init__()
+        self.output_scale = output_scale
+        self.output_zero_point = output_zero_point
+
+    def forward(self, input):
+        return torch.ops.quantized.sigmoid(input, self.output_scale, self.output_zero_point)
+
+    @classmethod
+    def from_float(cls, mod):
+        output_scale, output_zero_point = mod.activation_post_process.calculate_qparams()
+        return cls(float(output_scale), int(output_zero_point))
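
As a hedged usage sketch of the from_float path: attach an observer as activation_post_process on a float nn.Sigmoid, run some calibration data through it, and let the classmethod pick up the computed qparams. The observer type and calibration tensor here are arbitrary choices for illustration; in the eager-mode workflow this wiring is done by the prepare step rather than by hand.

import torch
import torch.nn as nn
import torch.nn.quantized as nnq
from torch.quantization import MinMaxObserver

# Float module with an observer hanging off it, roughly as prepare would
# leave it (observer type and calibration data are arbitrary here).
float_mod = nn.Sigmoid()
float_mod.activation_post_process = MinMaxObserver(dtype=torch.quint8)
float_mod.activation_post_process(float_mod(torch.randn(16)))  # record output range

# from_float reads the observer's qparams and builds the quantized module.
qmod = nnq.Sigmoid.from_float(float_mod)

qx = torch.quantize_per_tensor(torch.randn(16), scale=0.1, zero_point=64,
                               dtype=torch.quint8)
print(qmod(qx))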