quantized layer norm: add to static quant (#36690)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36690

Adds eager mode static quantization support for LayerNorm: a quantized module (torch.nn.quantized.LayerNorm), an entry in the default module mapping, and tests.

Test Plan:
```
python test/quantization/test_quantized_nn_mods.py ModuleAPITest.test_layer_norm
python test/quantization/test_quantization.py EagerModePostTrainingQuantTest.test_normalization
```
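
For context, a minimal sketch of the eager mode post-training static quantization flow these tests exercise. The ToyModel, shapes, and calibration loop below are illustrative only, not part of this change:

```
import torch
import torch.quantization
import torch.nn as nn
import torch.nn.quantized as nnq

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(5, 8)
        self.layer_norm = nn.LayerNorm(8)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.layer_norm(self.fc(self.quant(x))))

model = ToyModel().eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)    # attach observers
for _ in range(4):                                 # calibrate on dummy data
    model(torch.randn(2, 5))
torch.quantization.convert(model, inplace=True)    # swap in quantized modules
assert type(model.layer_norm) == nnq.LayerNorm
```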

Imported from OSS

Differential Revision: D21055401

fbshipit-source-id: 188329f35359576d50ed0db5fb675ce68c28bf7d
Vasiliy Kuznetsov 2020-04-16 18:13:30 -07:00 committed by Facebook GitHub Bot
parent 24aac32171
commit 2c558dba3d
6 changed files with 162 additions and 1 deletion


@@ -35,7 +35,7 @@ from torch.testing._internal.common_quantization import QuantizationTestCase, \
    prepare_dynamic, convert_dynamic, SingleLayerLinearDynamicModel, \
    TwoLayerLinearModel, NestedModel, ResNetBase, LSTMDynamicModel, \
    ModelWithNoQconfigPropagation, ModelForFusionWithBias, \
    ActivationsTestModel, ActivationsQATTestModel
    ActivationsTestModel, ActivationsQATTestModel, NormalizationTestModel
from torch.testing._internal.common_quantization import AnnotatedTwoLayerLinearModel, AnnotatedNestedModel, \
    AnnotatedSubNestedModel, AnnotatedCustomConfigNestedModel
@@ -315,6 +315,29 @@ class EagerModePostTrainingQuantTest(QuantizationTestCase):
        checkQuantized(model)

    def test_normalization(self):
        r"""
        Test quantization of normalization layers
        """
        model = NormalizationTestModel()
        model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
        prepare(model, inplace=True)
        self.checkObservers(model)
        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model.layer_norm)
            self.assertEqual(type(model.layer_norm), nnq.LayerNorm)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        model_oneline = quantize(
            NormalizationTestModel(), test_only_eval_fn, self.calib_data)
        checkQuantized(model_oneline)

    @given(qengine=st.sampled_from(("qnnpack", "fbgemm")))
    def test_save_load_state_dict(self, qengine):
        r"""Test PTQ flow of creating a model and quantizing it and saving the quantized state_dict


@@ -850,5 +850,36 @@ class ModuleAPITest(QuantizationTestCase):
        self.assertEqual(quant_ref.int_repr().numpy(), qy.int_repr().numpy(),
                         message="BatchNorm3d module API failed")

    def test_layer_norm(self):
        """Tests the correctness of the quantized LayerNorm module.
        The correctness is checked against a dequantize -> float LayerNorm -> quantize
        reference.
        """
        x_scale = 10.0 / 256
        x_zero_point = 0
        y_scale = 5.0 / 256
        y_zero_point = 127

        dims = (1, 4, 8)
        X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10
        qX = torch.quantize_per_tensor(X, x_scale, x_zero_point, dtype=torch.quint8)
        dqX = qX.dequantize()

        float_mod = torch.nn.LayerNorm(dqX.size()[1:])
        float_mod.weight = torch.nn.Parameter(torch.rand(*dims[1:]))
        float_mod.bias = torch.nn.Parameter(torch.rand(*dims[1:]))

        dqY_ref = float_mod(dqX)
        qY_ref = torch.quantize_per_tensor(
            dqY_ref, y_scale, y_zero_point, dtype=torch.quint8)

        quant_mod = nnq.LayerNorm(
            qX.size()[1:], float_mod.weight, float_mod.bias, y_scale, y_zero_point)
        qY = quant_mod(qX)

        self.assertEqual(qY_ref.int_repr().numpy(), qY.int_repr().numpy(),
                         message="LayerNorm module API failed, qY_ref\n{} vs qY\n{}"
                         .format(qY_ref, qY))


if __name__ == '__main__':
    run_tests()


@@ -4,6 +4,7 @@ from torch.nn.modules.pooling import MaxPool2d
from .activation import ReLU, ReLU6, Hardswish
from .batchnorm import BatchNorm2d, BatchNorm3d
from .normalization import LayerNorm
from .conv import Conv1d, Conv2d, Conv3d
from .linear import Linear
@@ -89,6 +90,7 @@ __all__ = [
    'ReLU',
    'ReLU6',
    'Hardswish',
    'LayerNorm',
    # Wrapper modules
    'FloatFunctional',
    'QFunctional',


@@ -0,0 +1,91 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
import torch.nn.quantized.functional


class LayerNorm(torch.nn.LayerNorm):
    r"""Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization`_ .

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard deviation are calculated separately over the last
    dimensions, which must have the shape specified by :attr:`normalized_shape`.
    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which apply
        scalar scale and bias for each entire channel/plane with the
        :attr:`affine` option, Layer Normalization applies per-element scale and
        bias with :attr:`elementwise_affine`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        weight: the learnable affine weight (:math:`\gamma`), typically taken from the
            observed float module
        bias: the learnable affine bias (:math:`\beta`), typically taken from the
            observed float module
        scale: quantization scale of the output, passed to the quantized operator
        zero_point: quantization zero point of the output, passed to the quantized operator
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value; when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 5, 10, 10)
        >>> # With Learnable Parameters
        >>> m = nn.LayerNorm(input.size()[1:])
        >>> # Without Learnable Parameters
        >>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False)
        >>> # Normalize over last two dimensions
        >>> m = nn.LayerNorm([10, 10])
        >>> # Normalize over last dimension of size 10
        >>> m = nn.LayerNorm(10)
        >>> # Activating the module
        >>> output = m(input)

    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
    """

    def __init__(self, normalized_shape, weight, bias, scale, zero_point, eps=1e-5,
                 elementwise_affine=True):
        super(LayerNorm, self).__init__(
            normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
        self.weight = weight
        self.bias = bias
        self.scale = scale
        self.zero_point = zero_point

    def forward(self, input):
        return torch.ops.quantized.layer_norm(
            input, self.normalized_shape, weight=self.weight, bias=self.bias,
            eps=self.eps, output_scale=self.scale, output_zero_point=self.zero_point)

    def _get_name(self):
        return 'QuantizedLayerNorm'

    @classmethod
    def from_float(cls, mod):
        activation_post_process = mod.activation_post_process
        scale, zero_point = activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.normalized_shape, mod.weight, mod.bias, float(scale),
            int(zero_point), mod.eps, mod.elementwise_affine)
        return new_mod
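
A minimal usage sketch of the module above, constructed directly with explicit output quantization parameters; all numeric values are arbitrary and only for illustration:

```
import torch
import torch.nn.quantized as nnq

x = torch.randn(1, 4, 8)
qx = torch.quantize_per_tensor(x, 0.05, 64, dtype=torch.quint8)

ln = nnq.LayerNorm(
    (4, 8),                                 # normalized_shape: last two dims
    torch.nn.Parameter(torch.ones(4, 8)),   # weight (gamma)
    torch.nn.Parameter(torch.zeros(4, 8)),  # bias (beta)
    scale=0.02, zero_point=128)             # output quantization parameters

qy = ln(qx)
print(qy.q_scale(), qy.q_zero_point())      # -> 0.02, 128
```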


@@ -21,6 +21,7 @@ DEFAULT_MODULE_MAPPING = {
    nn.Conv3d: nnq.Conv3d,
    nn.BatchNorm2d: nnq.BatchNorm2d,
    nn.BatchNorm3d: nnq.BatchNorm3d,
    nn.LayerNorm: nnq.LayerNorm,
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    # Wrapper Modules:
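
Roughly, eager mode convert() walks the model and replaces every module whose type appears in this mapping with its quantized counterpart via that class's from_float() (as defined for LayerNorm above). A simplified sketch of the swap, not the actual torch.quantization implementation:

```
def swap_modules(module, mapping):
    # Simplified: the real convert() also checks each child's qconfig,
    # handles fused and custom modules, and works on a copy unless inplace=True.
    for name, child in module.named_children():
        swap_modules(child, mapping)
        if type(child) in mapping:
            setattr(module, name, mapping[type(child)].from_float(child))
```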


@@ -311,6 +311,19 @@ class LinearReluModel(torch.nn.Module):
        x = self.relu(self.fc(x))
        return x

class NormalizationTestModel(torch.nn.Module):
    def __init__(self):
        super(NormalizationTestModel, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.layer_norm = torch.nn.LayerNorm(8)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc1(x)
        x = self.layer_norm(x)
        return x
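
The tests drive this helper model with self.calib_data from QuantizationTestCase; a rough standalone equivalent, assuming the class above and a dummy calibration tensor (trailing dimension 5 to match fc1):

```
model = NormalizationTestModel().eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)
model(torch.rand(2, 5))                         # calibration pass
quantized = torch.quantization.convert(model)   # layer_norm becomes nnq.LayerNorm
```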

class NestedModel(torch.nn.Module):
    def __init__(self):
        super(NestedModel, self).__init__()