quantized layer norm: add to static quant (#36690)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36690

Adds the static quantization hook for LayerNorm.

Test Plan:
```
python test/quantization/test_quantized_nn_mods.py ModuleAPITest.test_layer_norm
python test/quantization/test_quantization.py EagerModePostTrainingQuantTest.test_normalization
```

Imported from OSS

Differential Revision: D21055401

fbshipit-source-id: 188329f35359576d50ed0db5fb675ce68c28bf7d
This commit is contained in:
parent 24aac32171
commit 2c558dba3d
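For orientation, the sketch below runs the eager-mode post-training static quantization flow that this hook participates in. The toy module and calibration data are made up for illustration (they roughly mirror the `NormalizationTestModel` added below); the qconfig/prepare/convert calls follow the test added in this commit.

```python
# Hedged sketch of the eager-mode static quantization flow (toy model, not from this PR).
# Requires a build with the fbgemm backend (x86).
import torch
import torch.nn.quantized as nnq

class TinyLayerNormModel(torch.nn.Module):  # hypothetical model for illustration
    def __init__(self):
        super(TinyLayerNormModel, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = torch.nn.Linear(5, 8)
        self.layer_norm = torch.nn.LayerNorm(8)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.layer_norm(self.fc(x))
        return self.dequant(x)

model = TinyLayerNormModel().eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)    # attach observers
with torch.no_grad():
    model(torch.randn(16, 5))                      # calibrate on sample data
torch.quantization.convert(model, inplace=True)    # nn.LayerNorm -> nnq.LayerNorm
assert type(model.layer_norm) == nnq.LayerNorm
```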
test/quantization/test_quantization.py

@@ -35,7 +35,7 @@ from torch.testing._internal.common_quantization import QuantizationTestCase, \
     prepare_dynamic, convert_dynamic, SingleLayerLinearDynamicModel, \
     TwoLayerLinearModel, NestedModel, ResNetBase, LSTMDynamicModel, \
     ModelWithNoQconfigPropagation, ModelForFusionWithBias, \
-    ActivationsTestModel, ActivationsQATTestModel
+    ActivationsTestModel, ActivationsQATTestModel, NormalizationTestModel

 from torch.testing._internal.common_quantization import AnnotatedTwoLayerLinearModel, AnnotatedNestedModel, \
     AnnotatedSubNestedModel, AnnotatedCustomConfigNestedModel
@@ -315,6 +315,29 @@ class EagerModePostTrainingQuantTest(QuantizationTestCase):

         checkQuantized(model)

+    def test_normalization(self):
+        r"""
+        Test quantization of normalization layers
+        """
+        model = NormalizationTestModel()
+        model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
+        prepare(model, inplace=True)
+        self.checkObservers(model)
+        test_only_eval_fn(model, self.calib_data)
+        model = convert(model)
+
+        def checkQuantized(model):
+            self.checkNoPrepModules(model.layer_norm)
+            self.assertEqual(type(model.layer_norm), nnq.LayerNorm)
+            test_only_eval_fn(model, self.calib_data)
+            self.checkScriptable(model, self.calib_data)
+
+        checkQuantized(model)
+
+        model_oneline = quantize(
+            NormalizationTestModel(), test_only_eval_fn, self.calib_data)
+        checkQuantized(model_oneline)
+
     @given(qengine=st.sampled_from(("qnnpack", "fbgemm")))
     def test_save_load_state_dict(self, qengine):
         r"""Test PTQ flow of creating a model and quantizing it and saving the quantized state_dict
test/quantization/test_quantized_nn_mods.py

@@ -850,5 +850,36 @@ class ModuleAPITest(QuantizationTestCase):
         self.assertEqual(quant_ref.int_repr().numpy(), qy.int_repr().numpy(),
                          message="BatchNorm3d module API failed")

+    def test_layer_norm(self):
+        """Tests the correctness of the layernorm module.
+        The correctness is defined against the functional implementation.
+        """
+        x_scale = 10.0 / 256
+        x_zero_point = 0
+        y_scale = 5.0 / 256
+        y_zero_point = 127
+
+        dims = (1, 4, 8)
+
+        X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10
+        qX = torch.quantize_per_tensor(X, x_scale, x_zero_point, dtype=torch.quint8)
+        dqX = qX.dequantize()
+
+        float_mod = torch.nn.LayerNorm(dqX.size()[1:])
+        float_mod.weight = torch.nn.Parameter(torch.rand(*dims[1:]))
+        float_mod.bias = torch.nn.Parameter(torch.rand(*dims[1:]))
+
+        dqY_ref = float_mod(dqX)
+        qY_ref = torch.quantize_per_tensor(
+            dqY_ref, y_scale, y_zero_point, dtype=torch.quint8)
+
+        quant_mod = nnq.LayerNorm(
+            qX.size()[1:], float_mod.weight, float_mod.bias, y_scale, y_zero_point)
+        qY = quant_mod(qX)
+
+        self.assertEqual(qY_ref.int_repr().numpy(), qY.int_repr().numpy(),
+                         message="LayerNorm module API failed, qY_ref\n{} vs qY\n{}"
+                         .format(qY_ref, qY))
+
 if __name__ == '__main__':
     run_tests()
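As a side note on the qparams chosen in the test above: for `quint8`, `quantize_per_tensor` applies the affine rule `q = clamp(round(x / scale) + zero_point, 0, 255)`, so with `x_scale = 10.0 / 256` a float value of 1.0 lands in integer bin 26. A quick check (not part of the diff):

```python
import torch
x_scale, x_zero_point = 10.0 / 256, 0
qx = torch.quantize_per_tensor(torch.tensor([1.0]), x_scale, x_zero_point, dtype=torch.quint8)
print(qx.int_repr())  # tensor([26], dtype=torch.uint8): round(1.0 / 0.0390625) + 0 = 26
```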
@@ -4,6 +4,7 @@ from torch.nn.modules.pooling import MaxPool2d

 from .activation import ReLU, ReLU6, Hardswish
 from .batchnorm import BatchNorm2d, BatchNorm3d
+from .normalization import LayerNorm
 from .conv import Conv1d, Conv2d, Conv3d
 from .linear import Linear

@@ -89,6 +90,7 @@ __all__ = [
     'ReLU',
     'ReLU6',
     'Hardswish',
+    'LayerNorm',
     # Wrapper modules
     'FloatFunctional',
     'QFunctional',
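With the import and `__all__` entries above (presumably in `torch/nn/quantized/modules/__init__.py`), the quantized module becomes reachable from the package namespace; a one-line check:

```python
import torch.nn.quantized as nnq
print(nnq.LayerNorm)  # <class 'torch.nn.quantized.modules.normalization.LayerNorm'>
```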
torch/nn/quantized/modules/normalization.py (new file, 91 lines)

@@ -0,0 +1,91 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
import torch.nn.quantized.functional

class LayerNorm(torch.nn.LayerNorm):
    r"""Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization`_ .

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard deviation are calculated separately over the last
    certain number of dimensions, which have to be of the shape specified by
    :attr:`normalized_shape`.
    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which apply a
        scalar scale and bias for each entire channel/plane with the
        :attr:`affine` option, Layer Normalization applies per-element scale and
        bias with :attr:`elementwise_affine`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 5, 10, 10)
        >>> # With Learnable Parameters
        >>> m = nn.LayerNorm(input.size()[1:])
        >>> # Without Learnable Parameters
        >>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False)
        >>> # Normalize over last two dimensions
        >>> m = nn.LayerNorm([10, 10])
        >>> # Normalize over last dimension of size 10
        >>> m = nn.LayerNorm(10)
        >>> # Activating the module
        >>> output = m(input)

    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
    """

    def __init__(self, normalized_shape, weight, bias, scale, zero_point, eps=1e-5,
                 elementwise_affine=True):
        super(LayerNorm, self).__init__(
            normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
        self.weight = weight
        self.bias = bias
        self.scale = scale
        self.zero_point = zero_point

    def forward(self, input):
        return torch.ops.quantized.layer_norm(
            input, self.normalized_shape, weight=self.weight, bias=self.bias,
            eps=self.eps, output_scale=self.scale, output_zero_point=self.zero_point)

    def _get_name(self):
        return 'QuantizedLayerNorm'

    @classmethod
    def from_float(cls, mod):
        activation_post_process = mod.activation_post_process
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.normalized_shape, mod.weight, mod.bias, float(scale),
            int(zero_point), mod.eps, mod.elementwise_affine)
        return new_mod
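Note that the ``Examples::`` block in the docstring above is carried over from the float ``nn.LayerNorm``. The quantized module itself is constructed with pre-computed affine parameters plus the output qparams, matching the ``__init__`` signature shown above; a minimal sketch (scale and zero_point values are arbitrary illustrations):

```python
import torch
import torch.nn.quantized as nnq

# Quantize a float input (input scale / zero_point chosen arbitrarily).
x = torch.randn(2, 4, 8)
qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=64, dtype=torch.quint8)

# Weight/bias are passed as Parameters, as in the test above; the module also
# takes the *output* scale and zero_point up front.
weight = torch.nn.Parameter(torch.ones(8))
bias = torch.nn.Parameter(torch.zeros(8))
ln = nnq.LayerNorm([8], weight=weight, bias=bias, scale=0.02, zero_point=128)

qy = ln(qx)            # quantized output carrying the requested scale / zero_point
y = qy.dequantize()    # back to float for inspection
```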
@@ -21,6 +21,7 @@ DEFAULT_MODULE_MAPPING = {
     nn.Conv3d: nnq.Conv3d,
     nn.BatchNorm2d: nnq.BatchNorm2d,
     nn.BatchNorm3d: nnq.BatchNorm3d,
+    nn.LayerNorm: nnq.LayerNorm,
     QuantStub: nnq.Quantize,
     DeQuantStub: nnq.DeQuantize,
     # Wrapper Modules:
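For context, ``convert`` walks the model and, for each module whose type is a key in this mapping, swaps it for the quantized counterpart via ``from_float``. A simplified sketch of that swap (the real logic in ``torch.quantization`` also handles fused and observed modules):

```python
def swap_module_sketch(mod, mapping):
    # Simplified: swap a single observed float module for its quantized counterpart.
    if type(mod) in mapping and getattr(mod, 'qconfig', None) is not None:
        return mapping[type(mod)].from_float(mod)  # e.g. nnq.LayerNorm.from_float
    return mod
```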
@@ -311,6 +311,19 @@ class LinearReluModel(torch.nn.Module):
         x = self.relu(self.fc(x))
         return x

+class NormalizationTestModel(torch.nn.Module):
+    def __init__(self):
+        super(NormalizationTestModel, self).__init__()
+        self.quant = torch.quantization.QuantStub()
+        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
+        self.layer_norm = torch.nn.LayerNorm((8))
+
+    def forward(self, x):
+        x = self.quant(x)
+        x = self.fc1(x)
+        x = self.layer_norm(x)
+        return x
+
 class NestedModel(torch.nn.Module):
     def __init__(self):
         super(NestedModel, self).__init__()