Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Add RMSNorm module (#121364)
Similar to dbeed9724b/torchmultimodal/modules/layers/normalizations.py (L51).

**The implementation here is not optimized and we welcome pull requests to improve this.**

- Use `normalized_shape` instead of a singular integer `dim` to stay aligned with the `nn.LayerNorm` implementation.
- Remove the [upcast to float and downcast](dbeed9724b/torchmultimodal/modules/layers/normalizations.py (L73)).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121364
Approved by: https://github.com/albanD
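For context, a minimal usage sketch of the new module and its functional counterpart as added by this PR (the tensor shapes and the explicit `eps` value below are illustrative, not taken from the PR itself):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Module form: normalize over the last two dimensions (normalized_shape=[2, 3]).
# elementwise_affine=True (the default) adds a learnable per-element weight initialized to ones.
rms_norm = nn.RMSNorm([2, 3], eps=1e-6)
x = torch.randn(4, 2, 3)
y = rms_norm(x)  # same shape as x

# Functional form; weight and eps are optional. With eps=None,
# torch.finfo(x.dtype).eps is used as the default.
y_fn = F.rms_norm(x, [2, 3], weight=rms_norm.weight, eps=1e-6)
```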
This commit is contained in:
parent b693fff5d7
commit a7306de0dc
@@ -226,6 +226,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
  m.impl("reshape", native::reshape_symint);
  OP_DECOMPOSE(resolve_conj);
  OP_DECOMPOSE(resolve_neg);
  OP_DECOMPOSE(rms_norm);
  OP_DECOMPOSE(row_stack);
  OP_DECOMPOSE(rrelu);
  OP_DECOMPOSE(rrelu_);
@@ -2,6 +2,7 @@
#include <ATen/native/layer_norm.h>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/mixed_data_type.h>
#include <c10/util/irange.h>
@@ -18,6 +19,9 @@
#include <ATen/ops/native_layer_norm.h>
#include <ATen/ops/native_layer_norm_backward_native.h>
#include <ATen/ops/native_layer_norm_native.h>
#include <ATen/ops/pow.h>
#include <ATen/ops/rsqrt.h>
#include <ATen/ops/rms_norm.h>
#include <ATen/ops/zeros_like_native.h>
#endif
@@ -258,4 +262,49 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
  rstd = rstd.view(stat_shape);
  return std::make_tuple(out, mean, rstd);
}

Tensor rms_norm(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const c10::optional<Tensor>& weight_opt /* optional */,
    c10::optional<double> eps) {

  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  auto bias_opt = at::optional<Tensor>();
  const Tensor& bias = *at::borrow_from_optional_tensor(bias_opt);
  (void) _check_layer_norm_inputs(input, normalized_shape, weight, bias);

  std::vector<int64_t> dims_to_reduce;
  for (const auto i : c10::irange(normalized_shape.size())) {
    dims_to_reduce.push_back(input.dim() - i - 1);
  }
  IntArrayRef dims_to_reduce_ref = IntArrayRef(dims_to_reduce);

  auto result = AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      input.scalar_type(),
      "rms_norm",
      [&] {
        scalar_t eps_val;
        if (!eps.has_value()) {
          eps_val = std::numeric_limits<at::scalar_value_type<scalar_t>::type>::epsilon();
        } else {
          eps_val = eps.value();
        }

        auto result = input.mul(at::rsqrt(at::pow(input, 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val)));

        if (weight_opt.has_value()) {
          result = result.mul(weight_opt.value());
        }

        return result;
      });

  return result;

}
} // namespace at::native
@@ -71,6 +71,12 @@ void layer_norm_cpu_out(
    int64_t M,
    int64_t N);

Tensor rms_norm(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const c10::optional<Tensor>& weight_opt /* optional */,
    c10::optional<double> eps);

using forward_fn = void (*)(
    const Tensor& /* X */,
    const Tensor& /* gamma */,
@@ -3268,6 +3268,8 @@
  autogen: native_layer_norm_backward.out
  tags: core

- func: rms_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor

- func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
  variants: function, method
  dispatch:
@@ -108,6 +108,7 @@ Non-linear activation functions
    instance_norm
    layer_norm
    local_response_norm
    rms_norm
    normalize

.. _Link 1: https://arxiv.org/abs/1611.00712
@@ -207,6 +207,7 @@ Normalization Layers
    nn.LazyInstanceNorm3d
    nn.LayerNorm
    nn.LocalResponseNorm
    nn.RMSNorm

Recurrent Layers
----------------
@@ -527,6 +528,18 @@ Lazy Modules Initialization

    nn.modules.lazy.LazyModuleMixin

Aliases
_______

The following are aliases to their counterparts in ``torch.nn``:

.. currentmodule:: torch
.. autosummary::
    :toctree: generated
    :nosignatures:
    :template: classtemplate.rst

    nn.modules.normalization.RMSNorm

.. This module needs to be documented. Adding here in the meantime
.. for tracking purposes
@@ -338,6 +338,7 @@ inductor_override_kwargs = {
    ("nn.functional.cosine_similarity", "cuda", f16): {"reference_in_float": True},
    ("nn.functional.instance_norm", "cuda", f16): {"reference_in_float": True},
    ("nn.functional.local_response_norm", "cuda", f16): {"reference_in_float": True},
    ("nn.functional.rms_norm", "cuda", f16): {"reference_in_float": True},
    ("nn.functional.soft_margin_loss", "cuda", f16): {"reference_in_float": True},
    ("nn.functional.softmin", "cuda", f16): {"atol": 1e-4, "rtol": 0.01},
    ("nn.functional.softsign", "cuda", f16): {"reference_in_float": True},
@@ -4287,6 +4287,7 @@ class TestFunctionalTracing(JitTestCase):
        "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
        "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
        "layer_norm": ARG_TYPE_MISMATCH,
        "rms_norm": ARG_TYPE_MISMATCH,
        "lp_pool1d": ARG_TYPE_MISMATCH,

        "affine_grid": CONTROL_FLOW,
@@ -1990,6 +1990,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
        "torch.resolve_conj",
        "torch.resolve_neg",
        "torch.result_type",
        "torch.rms_norm",
        "torch.rnn_relu_cell",
        "torch.rnn_relu",
        "torch.rnn_tanh_cell",
@@ -2574,6 +2574,21 @@ def layer_norm(
    )
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)

def rms_norm(
    input: Tensor,
    normalized_shape: List[int],
    weight: Optional[Tensor] = None,
    eps: Optional[float] = None,
) -> Tensor:
    r"""Apply Root Mean Square Layer Normalization.

    See :class:`~torch.nn.RMSNorm` for details.
    """
    if has_torch_function_variadic(input, weight):
        return handle_torch_function(
            rms_norm, (input, weight), input, normalized_shape, weight=weight, eps=eps
        )
    return torch.rms_norm(input, normalized_shape, weight, eps)

def group_norm(
    input: Tensor, num_groups: int, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, eps: float = 1e-5
@@ -299,6 +299,12 @@ def layer_norm(
    bias: Optional[Tensor] = ...,
    eps: float = ...,
) -> Tensor: ...
def rms_norm(
    input: Tensor,
    normalized_shape: Sequence[int],
    weight: Optional[Tensor] = ...,
    eps: Optional[float] = ...,
) -> Tensor: ...
def group_norm(
    input: Tensor,
    num_groups: int,
@@ -19,7 +19,7 @@ from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm, \
    LazyBatchNorm1d, LazyBatchNorm2d, LazyBatchNorm3d
from .instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d, \
    LazyInstanceNorm1d, LazyInstanceNorm2d, LazyInstanceNorm3d
from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm
from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm, RMSNorm
from .dropout import Dropout, Dropout1d, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout
from .padding import ReflectionPad1d, ReflectionPad2d, ReflectionPad3d, ReplicationPad1d, ReplicationPad2d, \
    ReplicationPad3d, ZeroPad1d, ZeroPad2d, ZeroPad3d, ConstantPad1d, ConstantPad2d, ConstantPad3d, \
@@ -49,7 +49,7 @@ __all__ = [
    'ParameterList', 'ParameterDict', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
    'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
    'LPPool1d', 'LPPool2d', 'LPPool3d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
    'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
    'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'RMSNorm', 'SyncBatchNorm',
    'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
    'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
    'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
@@ -7,9 +7,9 @@ from .. import functional as F
from .. import init

from torch import Tensor, Size
from typing import Union, List, Tuple
from typing import Union, List, Optional, Tuple

__all__ = ['LocalResponseNorm', 'CrossMapLRN2d', 'LayerNorm', 'GroupNorm']
__all__ = ['LocalResponseNorm', 'CrossMapLRN2d', 'LayerNorm', 'GroupNorm', 'RMSNorm']

class LocalResponseNorm(Module):
    r"""Applies local response normalization over an input signal.
@@ -292,6 +292,88 @@ class GroupNorm(Module):
            'affine={affine}'.format(**self.__dict__)


class RMSNorm(Module):
    r"""Applies Root Mean Square Layer Normalization over a mini-batch of inputs.

    This layer implements the operation as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/pdf/1910.07467.pdf>`__

    .. math::
        y = \frac{x}{\sqrt{\mathrm{RMS}[x] + \epsilon}} * \gamma

    The root mean squared norm is taken over the last ``D`` dimensions, where ``D``
    is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
    is ``(3, 5)`` (a 2-dimensional shape), the rms norm is computed over
    the last 2 dimensions of the input.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: :func:`torch.finfo(x.dtype).eps`
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> rms_norm = nn.RMSNorm([2, 3])
        >>> input = torch.randn(2, 2, 3)
        >>> rms_norm(input)

    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: Optional[float]
    elementwise_affine: bool

    def __init__(self, normalized_shape: _shape_t, eps: Optional[float] = None, elementwise_affine: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
            self.register_parameter('weight', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Resets parameters based on their initialization used in __init__.
        """
        if self.elementwise_affine:
            init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Runs forward pass.
        """
        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)

    def extra_repr(self) -> str:
        """
        Extra information about the module.
        """
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)


# TODO: ContrastiveNorm2d
# TODO: DivisiveNorm2d
# TODO: SubtractiveNorm2d
@@ -910,6 +910,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.nn.functional.prelu: lambda input, weight: -1,
        torch.nn.functional.relu: lambda input, inplace=False: -1,
        torch.nn.functional.relu6: lambda input, inplace=False: -1,
        torch.nn.functional.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
        torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1,
        torch.nn.functional.selu: lambda input, inplace=False: -1,
        torch.nn.functional.silu: lambda input, inplace=False: -1,
@@ -1008,6 +1009,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.renorm: lambda input, p, dim, maxnorm, out=None: -1,
        torch.repeat_interleave: lambda input, dim=None: -1,
        torch.reshape: lambda input, shape: -1,
        torch.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
        torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
        torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
        torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
@@ -35,7 +35,7 @@ from torch.testing._internal.common_cuda import (
from torch.testing._internal.common_utils import (
    make_fullrank_matrices_with_distinct_singular_values,
    TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
    torch_to_numpy_dtype_dict, TEST_WITH_ASAN,
    torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN,
    GRADCHECK_NONDET_TOL, freeze_rng_state, slowTest, TEST_WITH_SLOW,
    TEST_WITH_TORCHINDUCTOR
)
@@ -4465,6 +4465,29 @@ def sample_inputs_native_layer_norm(opinfo, device, dtype, requires_grad, **kwar
            args=(normalized_shape, None, None, eps),
        )

def sample_inputs_rms_norm(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as input shape, normalized_shape and a kwarg dict for eps
    cases: Tuple[Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
        ((1, 2, 3), (1, 2, 3), {'eps': 0.5}),
        ((2, 2, 3), (2, 3), {'eps': -0.5}),
        ((1,), (1,), {}),
        ((1, 2), (2,), {}),
        ((0, 1), (1,), {}),
    )

    for input_shape, normalized_shape, kwargs in cases:
        # Shape of weight and bias should be the same as normalized_shape
        weight = make_arg(normalized_shape)
        yield SampleInput(
            make_arg(input_shape),
            args=(normalized_shape, weight),
            kwargs=kwargs
        )
    # Without any optional args
    yield SampleInput(make_arg((1, 2)), args=((2,),))

def error_inputs_group_norm(opinfo, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
@@ -4509,6 +4532,31 @@ def error_inputs_native_layer_norm(opinfo, device, **kwargs):
    )
    yield ErrorInput(s4, error_regex=err_msg4)

def error_inputs_rms_norm(opinfo, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
    input_shape = (1, 2, 3)

    err_msg1 = "Expected normalized_shape to be at least 1-dimensional"
    s1 = SampleInput(
        make_arg(input_shape), args=(tuple(), None, 1e-5)
    )
    yield ErrorInput(s1, error_regex=err_msg1)

    normalized_shape = (1, 2, 3)
    weight = make_arg((1, 2))
    err_msg2 = "Expected weight to be of same shape as normalized_shape"
    s2 = SampleInput(
        make_arg(input_shape), args=(normalized_shape, weight, 1e-5)
    )
    yield ErrorInput(s2, error_regex=err_msg2)

    err_msg4 = "Given normalized_shape="
    s4 = SampleInput(
        make_arg((2, 2, 3)), args=((2, 2), None, 1e-5)
    )
    yield ErrorInput(s4, error_regex=err_msg4)


def sample_inputs_local_response_norm(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -10012,6 +10060,18 @@ def reference_native_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], w
    return Y.reshape(*inp.shape), mean.reshape(stat_shape), (1.0 / np.sqrt(var + eps)).reshape(stat_shape)


def reference_rms_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, eps=None):
    if eps is None:
        eps = torch.finfo(numpy_to_torch_dtype(inp.dtype)).eps
    feature_size = np.prod(normalized_shape)
    inp_view = inp.reshape(-1, feature_size)  # type: ignore[call-overload]
    rms = np.sqrt((inp_view**2).mean(axis=-1, keepdims=True) + eps)
    Y = inp_view / rms
    if weight is not None:
        Y = Y * weight.reshape(-1)
    return Y.reshape(*inp.shape)


def reference_group_norm(inp: np.ndarray, num_groups: int, weight=None, bias=None, eps=1e-5):
    inp_view = inp
    if np.prod(inp.shape) != 0:
@@ -13656,6 +13716,16 @@ op_db: List[OpInfo] = [
           ],
           sample_inputs_func=sample_inputs_layer_norm,
           supports_expanded_weight=True,),
    OpInfo('nn.functional.rms_norm',
           aten_name='rms_norm',
           aliases=('rms_norm',),
           ref=reference_rms_norm,
           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           sample_inputs_func=sample_inputs_rms_norm,
           error_inputs_func=error_inputs_rms_norm,),
    OpInfo('nn.functional.local_response_norm',
           dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16),
           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
@@ -1928,6 +1928,55 @@ def module_inputs_torch_nn_LayerNorm(module_info, device, dtype, requires_grad,
            desc='3d_elementwise_affine_no_bias'),
    ]

def module_inputs_torch_nn_RMSNorm(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    def rms_norm_reference_fn(m, p, i):
        eps = m.eps
        if eps is None:
            eps = torch.finfo(i.dtype).eps
        ndim = i.ndim
        normalized_shape = m.normalized_shape
        weight = m.weight
        dims = [ndim - i - 1 for i in range(len(normalized_shape))]
        result = i * torch.rsqrt(i.pow(2).mean(dim=dims, keepdim=True) + eps)
        if weight is not None:
            result *= weight
        return result

    return [
        ModuleInput(
            constructor_input=FunctionInput([5], 1e-3),
            forward_input=FunctionInput(make_input((4, 5, 5))),
            desc='1d_elementwise_affine',
            reference_fn=rms_norm_reference_fn),
        ModuleInput(
            constructor_input=FunctionInput([5], 1e-3),
            forward_input=FunctionInput(make_input((128, 5, 5))),
            desc='1d_elementwise_affine_large_batch',
            reference_fn=rms_norm_reference_fn),
        ModuleInput(
            constructor_input=FunctionInput([5], 1e-3, False),
            forward_input=FunctionInput(make_input((4, 5, 5))),
            desc='1d_no_elementwise_affine',
            reference_fn=rms_norm_reference_fn),
        ModuleInput(
            constructor_input=FunctionInput([2, 2, 5], 1e-3),
            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
            desc='3d_elementwise_affine',
            reference_fn=rms_norm_reference_fn),
        ModuleInput(
            constructor_input=FunctionInput([2, 2, 5], 1e-3, False),
            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
            desc='3d_no_elementwise_affine',
            reference_fn=rms_norm_reference_fn),
        ModuleInput(
            constructor_input=FunctionInput([5], 1e-3),
            forward_input=FunctionInput(make_input((0, 5))),
            desc='1d_empty_elementwise_affine',
            reference_fn=rms_norm_reference_fn),
    ]


def module_inputs_torch_nn_LocalResponseNorm(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -4071,6 +4120,9 @@ module_db: List[ModuleInfo] = [
                   # No channels_last support for LayerNorm currently.
                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
               ),
    ModuleInfo(torch.nn.RMSNorm,
               module_inputs_func=module_inputs_torch_nn_RMSNorm,
               ),
    # TransformerEncoder takes the same inputs as TransformerEncoderLayer
    ModuleInfo(torch.nn.TransformerEncoder,
               train_and_eval_differ=True,