mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[quant] Add quantized Sigmoid module (#45883)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45883 Test Plan: python test/test_quantization.py TestStaticQuantizedModule.test_sigmoid Imported from OSS Reviewed By: z-a-f Differential Revision: D24129116 fbshipit-source-id: aa960549509c60374012f35b1f5be39e90418099
This commit is contained in:
parent
30bf799f9c
commit
83d2c9a232
|
|
@ -716,6 +716,9 @@ class TestStaticQuantizedModule(QuantizationTestCase):
|
|||
def test_leaky_relu(self):
|
||||
self._test_activation_module_impl("LeakyReLU", nn.LeakyReLU, nnq.LeakyReLU, {"negative_slope": 0.2})
|
||||
|
||||
def test_sigmoid(self):
|
||||
self._test_activation_module_impl("Sigmoid", nn.Sigmoid, nnq.Sigmoid, {})
|
||||
|
||||
@given(
|
||||
num_embeddings=st.integers(10, 50),
|
||||
embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
import torch
|
||||
from torch.nn.modules.pooling import MaxPool2d
|
||||
|
||||
from .activation import ReLU, ReLU6, Hardswish, ELU, LeakyReLU
|
||||
from .activation import ReLU, ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid
|
||||
from .batchnorm import BatchNorm2d, BatchNorm3d
|
||||
from .normalization import LayerNorm, GroupNorm, InstanceNorm1d, \
|
||||
InstanceNorm2d, InstanceNorm3d
|
||||
|
|
@ -100,6 +100,7 @@ __all__ = [
|
|||
'Hardswish',
|
||||
'ELU',
|
||||
'LeakyReLU',
|
||||
'Sigmoid',
|
||||
'LayerNorm',
|
||||
'GroupNorm',
|
||||
'InstanceNorm1d',
|
||||
|
|
|
|||
|
|
@ -149,3 +149,24 @@ class LeakyReLU(torch.nn.LeakyReLU):
|
|||
def from_float(cls, mod):
|
||||
scale, zero_point = mod.activation_post_process.calculate_qparams()
|
||||
return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)
|
||||
|
||||
class Sigmoid(torch.nn.Sigmoid):
|
||||
r"""This is the quantized equivalent of :class:`~torch.nn.LeakyReLU`.
|
||||
|
||||
Args:
|
||||
scale: quantization scale of the output tensor
|
||||
zero_point: quantization zero point of the output tensor
|
||||
"""
|
||||
|
||||
def __init__(self, output_scale: float, output_zero_point: int):
|
||||
super().__init__()
|
||||
self.output_scale = output_scale
|
||||
self.output_zero_point = output_zero_point
|
||||
|
||||
def forward(self, input):
|
||||
return torch.ops.quantized.sigmoid(input, self.output_scale, self.output_zero_point)
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, mod):
|
||||
output_scale, output_zero_point = mod.activation_post_process.calculate_qparams()
|
||||
return cls(float(output_scale), int(output_zero_point))
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user