Add RMSNorm module (#121364)

Similar to the torchmultimodal implementation at `dbeed9724b/torchmultimodal/modules/layers/normalizations.py` (L51)

**The implementation here is not optimized; pull requests to improve it are welcome.**

- Use `normalized_shape` instead of a single integer `dim` to align with the `nn.LayerNorm` implementation
- Remove the upcast to float and downcast done there (`dbeed9724b/torchmultimodal/modules/layers/normalizations.py`, L73)
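
For quick reference, a minimal usage sketch of the API added by this PR (module and functional forms, mirroring the docstring example further down); the shapes and `eps` value here are illustrative only:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Module form: normalize over the last two dimensions of shape (2, 3)
rms_norm = nn.RMSNorm([2, 3], eps=1e-6)
x = torch.randn(4, 2, 3)
y = rms_norm(x)  # same shape as x

# Functional form, with an optional elementwise weight
weight = torch.ones(2, 3)
y_f = F.rms_norm(x, [2, 3], weight, 1e-6)
```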

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121364
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki 2024-03-25 12:58:08 -07:00 committed by PyTorch MergeBot
parent b693fff5d7
commit a7306de0dc
16 changed files with 307 additions and 5 deletions

View File

@ -226,6 +226,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
m.impl("reshape", native::reshape_symint);
OP_DECOMPOSE(resolve_conj);
OP_DECOMPOSE(resolve_neg);
OP_DECOMPOSE(rms_norm);
OP_DECOMPOSE(row_stack);
OP_DECOMPOSE(rrelu);
OP_DECOMPOSE(rrelu_);

View File

@ -2,6 +2,7 @@
#include <ATen/native/layer_norm.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/mixed_data_type.h>
#include <c10/util/irange.h>
@ -18,6 +19,9 @@
#include <ATen/ops/native_layer_norm.h>
#include <ATen/ops/native_layer_norm_backward_native.h>
#include <ATen/ops/native_layer_norm_native.h>
#include <ATen/ops/pow.h>
#include <ATen/ops/rsqrt.h>
#include <ATen/ops/rms_norm.h>
#include <ATen/ops/zeros_like_native.h>
#endif
@ -258,4 +262,49 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
rstd = rstd.view(stat_shape);
return std::make_tuple(out, mean, rstd);
}
Tensor rms_norm(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const c10::optional<Tensor>& weight_opt /* optional */,
    c10::optional<double> eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  auto bias_opt = at::optional<Tensor>();
  const Tensor& bias = *at::borrow_from_optional_tensor(bias_opt);
  (void) _check_layer_norm_inputs(input, normalized_shape, weight, bias);

  std::vector<int64_t> dims_to_reduce;
  for (const auto i : c10::irange(normalized_shape.size())) {
    dims_to_reduce.push_back(input.dim() - i - 1);
  }
  IntArrayRef dims_to_reduce_ref = IntArrayRef(dims_to_reduce);

  auto result = AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      input.scalar_type(),
      "rms_norm",
      [&] {
        scalar_t eps_val;
        if (!eps.has_value()) {
          eps_val = std::numeric_limits<at::scalar_value_type<scalar_t>::type>::epsilon();
        } else {
          eps_val = eps.value();
        }

        auto result = input.mul(at::rsqrt(at::pow(input, 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val)));

        if (weight_opt.has_value()) {
          result = result.mul(weight_opt.value());
        }

        return result;
      });

  return result;
}
} // namespace at::native
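
Roughly, the composite kernel above amounts to the following Python sketch (an illustration of the math, not code from this PR; the C++ path picks a per-dtype machine epsilon when `eps` is not given):

```python
import torch

def rms_norm_sketch(x, normalized_shape, weight=None, eps=None):
    # Reduce over the trailing len(normalized_shape) dimensions, keeping them for broadcasting.
    dims = [x.dim() - i - 1 for i in range(len(normalized_shape))]
    if eps is None:
        eps = torch.finfo(x.dtype).eps  # stand-in for the dtype-dependent default
    out = x * torch.rsqrt(x.pow(2).mean(dim=dims, keepdim=True) + eps)
    if weight is not None:
        out = out * weight
    return out
```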

View File

@ -71,6 +71,12 @@ void layer_norm_cpu_out(
int64_t M,
int64_t N);
Tensor rms_norm(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const c10::optional<Tensor>& weight_opt /* optional */,
    c10::optional<double> eps);
using forward_fn = void (*)(
const Tensor& /* X */,
const Tensor& /* gamma */,

View File

@ -3268,6 +3268,8 @@
autogen: native_layer_norm_backward.out
tags: core
- func: rms_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor
- func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
variants: function, method
dispatch:

View File

@ -108,6 +108,7 @@ Non-linear activation functions
instance_norm
layer_norm
local_response_norm
rms_norm
normalize
.. _Link 1: https://arxiv.org/abs/1611.00712

View File

@ -207,6 +207,7 @@ Normalization Layers
nn.LazyInstanceNorm3d
nn.LayerNorm
nn.LocalResponseNorm
nn.RMSNorm
Recurrent Layers
----------------
@ -527,6 +528,18 @@ Lazy Modules Initialization
nn.modules.lazy.LazyModuleMixin
Aliases
_______
The following are aliases to their counterparts in ``torch.nn``:
.. currentmodule:: torch
.. autosummary::
    :toctree: generated
    :nosignatures:
    :template: classtemplate.rst

    nn.modules.normalization.RMSNorm
.. This module needs to be documented. Adding here in the meantime
.. for tracking purposes

View File

@ -338,6 +338,7 @@ inductor_override_kwargs = {
("nn.functional.cosine_similarity", "cuda", f16): {"reference_in_float": True},
("nn.functional.instance_norm", "cuda", f16): {"reference_in_float": True},
("nn.functional.local_response_norm", "cuda", f16): {"reference_in_float": True},
("nn.functional.rms_norm", "cuda", f16): {"reference_in_float": True},
("nn.functional.soft_margin_loss", "cuda", f16): {"reference_in_float": True},
("nn.functional.softmin", "cuda", f16): {"atol": 1e-4, "rtol": 0.01},
("nn.functional.softsign", "cuda", f16): {"reference_in_float": True},

View File

@ -4287,6 +4287,7 @@ class TestFunctionalTracing(JitTestCase):
"fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
"fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
"layer_norm": ARG_TYPE_MISMATCH,
"rms_norm": ARG_TYPE_MISMATCH,
"lp_pool1d": ARG_TYPE_MISMATCH,
"affine_grid": CONTROL_FLOW,

View File

@ -1990,6 +1990,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch.resolve_conj",
"torch.resolve_neg",
"torch.result_type",
"torch.rms_norm",
"torch.rnn_relu_cell",
"torch.rnn_relu",
"torch.rnn_tanh_cell",

View File

@ -2574,6 +2574,21 @@ def layer_norm(
)
return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
def rms_norm(
    input: Tensor,
    normalized_shape: List[int],
    weight: Optional[Tensor] = None,
    eps: Optional[float] = None,
) -> Tensor:
    r"""Apply Root Mean Square Layer Normalization.

    See :class:`~torch.nn.RMSNorm` for details.
    """
    if has_torch_function_variadic(input, weight):
        return handle_torch_function(
            rms_norm, (input, weight), input, normalized_shape, weight=weight, eps=eps
        )
    return torch.rms_norm(input, normalized_shape, weight, eps)
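
As a sanity sketch (not part of the diff), the functional should match the manual formula when `eps` is passed explicitly:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)
eps = 1e-6
ref = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
out = F.rms_norm(x, [8], weight=None, eps=eps)
torch.testing.assert_close(out, ref)
```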
def group_norm(
input: Tensor, num_groups: int, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, eps: float = 1e-5

View File

@ -299,6 +299,12 @@ def layer_norm(
bias: Optional[Tensor] = ...,
eps: float = ...,
) -> Tensor: ...
def rms_norm(
    input: Tensor,
    normalized_shape: Sequence[int],
    weight: Optional[Tensor] = ...,
    eps: Optional[float] = ...,
) -> Tensor: ...
def group_norm(
input: Tensor,
num_groups: int,

View File

@ -19,7 +19,7 @@ from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm, \
LazyBatchNorm1d, LazyBatchNorm2d, LazyBatchNorm3d
from .instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d, \
LazyInstanceNorm1d, LazyInstanceNorm2d, LazyInstanceNorm3d
from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm
from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm, RMSNorm
from .dropout import Dropout, Dropout1d, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout
from .padding import ReflectionPad1d, ReflectionPad2d, ReflectionPad3d, ReplicationPad1d, ReplicationPad2d, \
ReplicationPad3d, ZeroPad1d, ZeroPad2d, ZeroPad3d, ConstantPad1d, ConstantPad2d, ConstantPad3d, \
@ -49,7 +49,7 @@ __all__ = [
'ParameterList', 'ParameterDict', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
'LPPool1d', 'LPPool2d', 'LPPool3d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'RMSNorm', 'SyncBatchNorm',
'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',

View File

@ -7,9 +7,9 @@ from .. import functional as F
from .. import init
from torch import Tensor, Size
from typing import Union, List, Tuple
from typing import Union, List, Optional, Tuple
__all__ = ['LocalResponseNorm', 'CrossMapLRN2d', 'LayerNorm', 'GroupNorm']
__all__ = ['LocalResponseNorm', 'CrossMapLRN2d', 'LayerNorm', 'GroupNorm', 'RMSNorm']
class LocalResponseNorm(Module):
r"""Applies local response normalization over an input signal.
@ -292,6 +292,88 @@ class GroupNorm(Module):
'affine={affine}'.format(**self.__dict__)
class RMSNorm(Module):
    r"""Applies Root Mean Square Layer Normalization over a mini-batch of inputs.

    This layer implements the operation as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/pdf/1910.07467.pdf>`__

    .. math::
        y_i = \frac{x_i}{\mathrm{RMS}(x)} * \gamma_i, \quad
        \text{where} \quad \mathrm{RMS}(x) = \sqrt{\epsilon + \frac{1}{n} \sum_{i=1}^{n} x_i^2}

    The root mean square norm is taken over the last ``D`` dimensions, where ``D``
    is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
    is ``(3, 5)`` (a 2-dimensional shape), the RMS norm is computed over
    the last 2 dimensions of the input.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: :func:`torch.finfo(x.dtype).eps`
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones. Default: ``True``.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> rms_norm = nn.RMSNorm([2, 3])
        >>> input = torch.randn(2, 2, 3)
        >>> rms_norm(input)

    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: Optional[float]
    elementwise_affine: bool

    def __init__(self, normalized_shape: _shape_t, eps: Optional[float] = None, elementwise_affine: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
            self.register_parameter('weight', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Resets parameters based on their initialization used in __init__.
        """
        if self.elementwise_affine:
            init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Runs forward pass.
        """
        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)

    def extra_repr(self) -> str:
        """
        Extra information about the module.
        """
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
# TODO: ContrastiveNorm2d
# TODO: DivisiveNorm2d
# TODO: SubtractiveNorm2d
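
For illustration (a sketch assuming no affine weight, not code from this PR), the module output matches the plain formula:

```python
import torch
import torch.nn as nn

m = nn.RMSNorm(8, eps=1e-5, elementwise_affine=False)
x = torch.randn(3, 8)
expected = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-5)
torch.testing.assert_close(m(x), expected)
```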

View File

@ -910,6 +910,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.nn.functional.prelu: lambda input, weight: -1,
torch.nn.functional.relu: lambda input, inplace=False: -1,
torch.nn.functional.relu6: lambda input, inplace=False: -1,
torch.nn.functional.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1,
torch.nn.functional.selu: lambda input, inplace=False: -1,
torch.nn.functional.silu: lambda input, inplace=False: -1,
@ -1008,6 +1009,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.renorm: lambda input, p, dim, maxnorm, out=None: -1,
torch.repeat_interleave: lambda input, dim=None: -1,
torch.reshape: lambda input, shape: -1,
torch.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,

View File

@ -35,7 +35,7 @@ from torch.testing._internal.common_cuda import (
from torch.testing._internal.common_utils import (
make_fullrank_matrices_with_distinct_singular_values,
TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
torch_to_numpy_dtype_dict, TEST_WITH_ASAN,
torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN,
GRADCHECK_NONDET_TOL, freeze_rng_state, slowTest, TEST_WITH_SLOW,
TEST_WITH_TORCHINDUCTOR
)
@ -4465,6 +4465,29 @@ def sample_inputs_native_layer_norm(opinfo, device, dtype, requires_grad, **kwar
args=(normalized_shape, None, None, eps),
)
def sample_inputs_rms_norm(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as input shape, normalized_shape and a kwarg dict for eps
    cases: Tuple[Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
        ((1, 2, 3), (1, 2, 3), {'eps': 0.5}),
        ((2, 2, 3), (2, 3), {'eps': -0.5}),
        ((1,), (1,), {}),
        ((1, 2), (2,), {}),
        ((0, 1), (1,), {}),
    )

    for input_shape, normalized_shape, kwargs in cases:
        # Shape of weight should be the same as normalized_shape
        weight = make_arg(normalized_shape)
        yield SampleInput(
            make_arg(input_shape),
            args=(normalized_shape, weight),
            kwargs=kwargs
        )
    # Without any optional args
    yield SampleInput(make_arg((1, 2)), args=((2,),))
def error_inputs_group_norm(opinfo, device, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
@ -4509,6 +4532,31 @@ def error_inputs_native_layer_norm(opinfo, device, **kwargs):
)
yield ErrorInput(s4, error_regex=err_msg4)
def error_inputs_rms_norm(opinfo, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
    input_shape = (1, 2, 3)

    err_msg1 = "Expected normalized_shape to be at least 1-dimensional"
    s1 = SampleInput(
        make_arg(input_shape), args=(tuple(), None, 1e-5)
    )
    yield ErrorInput(s1, error_regex=err_msg1)

    normalized_shape = (1, 2, 3)
    weight = make_arg((1, 2))
    err_msg2 = "Expected weight to be of same shape as normalized_shape"
    s2 = SampleInput(
        make_arg(input_shape), args=(normalized_shape, weight, 1e-5)
    )
    yield ErrorInput(s2, error_regex=err_msg2)

    err_msg4 = "Given normalized_shape="
    s4 = SampleInput(
        make_arg((2, 2, 3)), args=((2, 2), None, 1e-5)
    )
    yield ErrorInput(s4, error_regex=err_msg4)
def sample_inputs_local_response_norm(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@ -10012,6 +10060,18 @@ def reference_native_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], w
return Y.reshape(*inp.shape), mean.reshape(stat_shape), (1.0 / np.sqrt(var + eps)).reshape(stat_shape)
def reference_rms_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, eps=None):
    if eps is None:
        eps = torch.finfo(numpy_to_torch_dtype(inp.dtype)).eps
    feature_size = np.prod(normalized_shape)
    inp_view = inp.reshape(-1, feature_size)  # type: ignore[call-overload]
    rms = np.sqrt((inp_view**2).mean(axis=-1, keepdims=True) + eps)
    Y = inp_view / rms
    if weight is not None:
        Y = Y * weight.reshape(-1)
    return Y.reshape(*inp.shape)
def reference_group_norm(inp: np.ndarray, num_groups: int, weight=None, bias=None, eps=1e-5):
inp_view = inp
if np.prod(inp.shape) != 0:
@ -13656,6 +13716,16 @@ op_db: List[OpInfo] = [
],
sample_inputs_func=sample_inputs_layer_norm,
supports_expanded_weight=True,),
OpInfo('nn.functional.rms_norm',
aten_name='rms_norm',
aliases=('rms_norm',),
ref=reference_rms_norm,
dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_rms_norm,
error_inputs_func=error_inputs_rms_norm,),
OpInfo('nn.functional.local_response_norm',
dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),

View File

@ -1928,6 +1928,55 @@ def module_inputs_torch_nn_LayerNorm(module_info, device, dtype, requires_grad,
desc='3d_elementwise_affine_no_bias'),
]
def module_inputs_torch_nn_RMSNorm(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    def rms_norm_reference_fn(m, p, i):
        eps = m.eps
        if eps is None:
            eps = torch.finfo(i.dtype).eps
        ndim = i.ndim
        normalized_shape = m.normalized_shape
        weight = m.weight
        dims = [ndim - i - 1 for i in range(len(normalized_shape))]
        result = i * torch.rsqrt(i.pow(2).mean(dim=dims, keepdim=True) + eps)
        if weight is not None:
            result *= weight
        return result

    return [
        ModuleInput(
            constructor_input=FunctionInput([5], 1e-3),
            forward_input=FunctionInput(make_input((4, 5, 5))),
            desc='1d_elementwise_affine',
            reference_fn=rms_norm_reference_fn),
        ModuleInput(
            constructor_input=FunctionInput([5], 1e-3),
            forward_input=FunctionInput(make_input((128, 5, 5))),
            desc='1d_elementwise_affine_large_batch',
            reference_fn=rms_norm_reference_fn),
        ModuleInput(
            constructor_input=FunctionInput([5], 1e-3, False),
            forward_input=FunctionInput(make_input((4, 5, 5))),
            desc='1d_no_elementwise_affine',
            reference_fn=rms_norm_reference_fn),
        ModuleInput(
            constructor_input=FunctionInput([2, 2, 5], 1e-3),
            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
            desc='3d_elementwise_affine',
            reference_fn=rms_norm_reference_fn),
        ModuleInput(
            constructor_input=FunctionInput([2, 2, 5], 1e-3, False),
            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
            desc='3d_no_elementwise_affine',
            reference_fn=rms_norm_reference_fn),
        ModuleInput(
            constructor_input=FunctionInput([5], 1e-3),
            forward_input=FunctionInput(make_input((0, 5))),
            desc='1d_empty_elementwise_affine',
            reference_fn=rms_norm_reference_fn),
    ]
def module_inputs_torch_nn_LocalResponseNorm(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@ -4071,6 +4120,9 @@ module_db: List[ModuleInfo] = [
# No channels_last support for LayerNorm currently.
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
),
ModuleInfo(torch.nn.RMSNorm,
module_inputs_func=module_inputs_torch_nn_RMSNorm,
),
# TransformerEncoder takes the same inputs as TransformerEncoderLayer
ModuleInfo(torch.nn.TransformerEncoder,
train_and_eval_differ=True,