diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
index f8c0d431536..3e064d6c39d 100644
--- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
+++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
@@ -226,6 +226,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
   m.impl("reshape", native::reshape_symint);
   OP_DECOMPOSE(resolve_conj);
   OP_DECOMPOSE(resolve_neg);
+  OP_DECOMPOSE(rms_norm);
   OP_DECOMPOSE(row_stack);
   OP_DECOMPOSE(rrelu);
   OP_DECOMPOSE(rrelu_);
diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp
index 88d53da856d..27a701dd2eb 100644
--- a/aten/src/ATen/native/layer_norm.cpp
+++ b/aten/src/ATen/native/layer_norm.cpp
@@ -2,6 +2,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -18,6 +19,9 @@
 #include
 #include
 #include
+#include
+#include
+#include
 #include
 #endif
@@ -258,4 +262,49 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
   rstd = rstd.view(stat_shape);
   return std::make_tuple(out, mean, rstd);
 }
+
+Tensor rms_norm(
+    const Tensor& input,
+    IntArrayRef normalized_shape,
+    const c10::optional<Tensor>& weight_opt /* optional */,
+    c10::optional<double> eps) {
+
+  // See [Note: hacky wrapper removal for optional tensor]
+  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
+  const Tensor& weight = *weight_maybe_owned;
+  auto bias_opt = at::optional<Tensor>();
+  const Tensor& bias = *at::borrow_from_optional_tensor(bias_opt);
+  (void) _check_layer_norm_inputs(input, normalized_shape, weight, bias);
+
+  std::vector<int64_t> dims_to_reduce;
+  for (const auto i : c10::irange(normalized_shape.size())) {
+    dims_to_reduce.push_back(input.dim() - i - 1);
+  }
+  IntArrayRef dims_to_reduce_ref = IntArrayRef(dims_to_reduce);
+
+  auto result = AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
+    input.scalar_type(),
+    "rms_norm",
+    [&] {
+      scalar_t eps_val;
+      if (!eps.has_value()) {
+        eps_val = std::numeric_limits<at::scalar_value_type<scalar_t>::type>::epsilon();
+      } else {
+        eps_val = eps.value();
+      }
+
+      auto result = input.mul(at::rsqrt(at::pow(input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val)));
+
+      if (weight_opt.has_value()) {
+        result = result.mul(weight_opt.value());
+      }
+
+      return result;
+    });
+
+  return result;
+}
 } // namespace at::native
diff --git a/aten/src/ATen/native/layer_norm.h b/aten/src/ATen/native/layer_norm.h
index 13fb1e4783d..38e63569586 100644
--- a/aten/src/ATen/native/layer_norm.h
+++ b/aten/src/ATen/native/layer_norm.h
@@ -71,6 +71,12 @@ void layer_norm_cpu_out(
     int64_t M,
     int64_t N);
 
+Tensor rms_norm(
+    const Tensor& input,
+    IntArrayRef normalized_shape,
+    const c10::optional<Tensor>& weight_opt /* optional */,
+    c10::optional<double> eps);
+
 using forward_fn = void (*)(
     const Tensor& /* X */,
     const Tensor& /* gamma */,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 454faaec7aa..89a2d8eb827 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3268,6 +3268,8 @@
   autogen: native_layer_norm_backward.out
   tags: core
 
+- func: rms_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor
+
 - func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
   variants: function, method
   dispatch:
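The new `rms_norm` entry above has no dedicated kernel; the ATen function is a composite that reduces over the trailing `normalized_shape` dimensions and defaults `eps` to the input dtype's machine epsilon. The following is a minimal Python sketch of the same computation, not the ATen code itself; the helper name `rms_norm_reference` is made up for illustration, and it assumes real-valued floating-point input::

    import torch

    def rms_norm_reference(x, normalized_shape, weight=None, eps=None):
        # Reduce over the last len(normalized_shape) dims, mirroring the dims_to_reduce loop above.
        dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
        if eps is None:
            # Mirrors the std::numeric_limits<...>::epsilon() default in the C++ branch.
            eps = torch.finfo(x.dtype).eps
        out = x * torch.rsqrt(x.pow(2).mean(dim=dims, keepdim=True) + eps)
        return out * weight if weight is not None else out

    x = torch.randn(2, 3, 5)
    expected = rms_norm_reference(x, [5], weight=torch.ones(5))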
diff --git a/docs/source/nn.functional.rst b/docs/source/nn.functional.rst
index a2241cb4764..9d2ea0eef5e 100644
--- a/docs/source/nn.functional.rst
+++ b/docs/source/nn.functional.rst
@@ -108,6 +108,7 @@ Non-linear activation functions
     instance_norm
     layer_norm
     local_response_norm
+    rms_norm
     normalize
 
 .. _Link 1: https://arxiv.org/abs/1611.00712
diff --git a/docs/source/nn.rst b/docs/source/nn.rst
index e14cc0ac443..18c5b3850c0 100644
--- a/docs/source/nn.rst
+++ b/docs/source/nn.rst
@@ -207,6 +207,7 @@ Normalization Layers
     nn.LazyInstanceNorm3d
     nn.LayerNorm
     nn.LocalResponseNorm
+    nn.RMSNorm
 
 Recurrent Layers
 ----------------
@@ -527,6 +528,18 @@ Lazy Modules Initialization
 
     nn.modules.lazy.LazyModuleMixin
 
+Aliases
+_______
+
+The following are aliases to their counterparts in ``torch.nn``:
+
+.. currentmodule:: torch
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+    :template: classtemplate.rst
+
+    nn.modules.normalization.RMSNorm
 
 .. This module needs to be documented. Adding here in the meantime
 .. for tracking purposes
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index 3a2e742efc3..39aaa8a12ff 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -338,6 +338,7 @@ inductor_override_kwargs = {
     ("nn.functional.cosine_similarity", "cuda", f16): {"reference_in_float": True},
     ("nn.functional.instance_norm", "cuda", f16): {"reference_in_float": True},
     ("nn.functional.local_response_norm", "cuda", f16): {"reference_in_float": True},
+    ("nn.functional.rms_norm", "cuda", f16): {"reference_in_float": True},
     ("nn.functional.soft_margin_loss", "cuda", f16): {"reference_in_float": True},
     ("nn.functional.softmin", "cuda", f16): {"atol": 1e-4, "rtol": 0.01},
     ("nn.functional.softsign", "cuda", f16): {"reference_in_float": True},
diff --git a/test/test_fx.py b/test/test_fx.py
index 961759e76ba..18f842d6704 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -4287,6 +4287,7 @@ class TestFunctionalTracing(JitTestCase):
         "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
         "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
         "layer_norm": ARG_TYPE_MISMATCH,
+        "rms_norm": ARG_TYPE_MISMATCH,
         "lp_pool1d": ARG_TYPE_MISMATCH,
 
         "affine_grid": CONTROL_FLOW,
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index f786b9ae28b..0c62e2e95e6 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -1990,6 +1990,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
         "torch.resolve_conj",
         "torch.resolve_neg",
         "torch.result_type",
+        "torch.rms_norm",
         "torch.rnn_relu_cell",
         "torch.rnn_relu",
         "torch.rnn_tanh_cell",
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 09fa7b8f37a..bb68b397325 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -2574,6 +2574,21 @@ def layer_norm(
         )
     return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
 
+def rms_norm(
+    input: Tensor,
+    normalized_shape: List[int],
+    weight: Optional[Tensor] = None,
+    eps: Optional[float] = None,
+) -> Tensor:
+    r"""Apply Root Mean Square Layer Normalization.
+
+    See :class:`~torch.nn.RMSNorm` for details.
+    """
+    if has_torch_function_variadic(input, weight):
+        return handle_torch_function(
+            rms_norm, (input, weight), input, normalized_shape, weight=weight, eps=eps
+        )
+    return torch.rms_norm(input, normalized_shape, weight, eps)
 
 def group_norm(
     input: Tensor, num_groups: int, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, eps: float = 1e-5
diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in
index b3878c215ea..5bb847a0a72 100644
--- a/torch/nn/functional.pyi.in
+++ b/torch/nn/functional.pyi.in
@@ -299,6 +299,12 @@ def layer_norm(
     bias: Optional[Tensor] = ...,
     eps: float = ...,
 ) -> Tensor: ...
+def rms_norm(
+    input: Tensor,
+    normalized_shape: Sequence[int],
+    weight: Optional[Tensor] = ...,
+    eps: Optional[float] = ...,
+) -> Tensor: ...
 def group_norm(
     input: Tensor,
     num_groups: int,
diff --git a/torch/nn/modules/__init__.py b/torch/nn/modules/__init__.py
index 67916b3ae75..403d0d547e2 100644
--- a/torch/nn/modules/__init__.py
+++ b/torch/nn/modules/__init__.py
@@ -19,7 +19,7 @@ from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm, \
     LazyBatchNorm1d, LazyBatchNorm2d, LazyBatchNorm3d
 from .instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d, \
     LazyInstanceNorm1d, LazyInstanceNorm2d, LazyInstanceNorm3d
-from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm
+from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm, RMSNorm
 from .dropout import Dropout, Dropout1d, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout
 from .padding import ReflectionPad1d, ReflectionPad2d, ReflectionPad3d, ReplicationPad1d, ReplicationPad2d, \
     ReplicationPad3d, ZeroPad1d, ZeroPad2d, ZeroPad3d, ConstantPad1d, ConstantPad2d, ConstantPad3d, \
@@ -49,7 +49,7 @@ __all__ = [
     'ParameterList', 'ParameterDict', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
     'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
     'LPPool1d', 'LPPool2d', 'LPPool3d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
-    'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
+    'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'RMSNorm', 'SyncBatchNorm',
     'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
     'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
     'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py
index 6502ec2a471..97c9c307c5d 100644
--- a/torch/nn/modules/normalization.py
+++ b/torch/nn/modules/normalization.py
@@ -7,9 +7,9 @@ from .. import functional as F
 from .. import init
 
 from torch import Tensor, Size
-from typing import Union, List, Tuple
+from typing import Union, List, Optional, Tuple
 
-__all__ = ['LocalResponseNorm', 'CrossMapLRN2d', 'LayerNorm', 'GroupNorm']
+__all__ = ['LocalResponseNorm', 'CrossMapLRN2d', 'LayerNorm', 'GroupNorm', 'RMSNorm']
 
 class LocalResponseNorm(Module):
     r"""Applies local response normalization over an input signal.
@@ -292,6 +292,88 @@ class GroupNorm(Module):
             'affine={affine}'.format(**self.__dict__)
 
 
+class RMSNorm(Module):
+    r"""Applies Root Mean Square Layer Normalization over a mini-batch of inputs.
+
+    This layer implements the operation as described in
+    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__
+
+    .. math::
+        y_i = \frac{x_i}{\mathrm{RMS}(x)} * \gamma_i, \quad
+        \text{where} \quad \mathrm{RMS}(x) = \sqrt{\epsilon + \frac{1}{n} \sum_{i=1}^{n} x_i^2}
+
+    The root mean square norm is taken over the last ``D`` dimensions, where ``D``
+    is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
+    is ``(3, 5)`` (a 2-dimensional shape), the RMS norm is computed over
+    the last 2 dimensions of the input.
+
+    Args:
+        normalized_shape (int or list or torch.Size): input shape from an expected input
+            of size
+
+            .. math::
+                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
+                    \times \ldots \times \text{normalized\_shape}[-1]]
+
+            If a single integer is used, it is treated as a singleton list, and this module will
+            normalize over the last dimension which is expected to be of that specific size.
+        eps: a value added to the denominator for numerical stability. Default: :func:`torch.finfo(x.dtype).eps`
+        elementwise_affine: a boolean value that when set to ``True``, this module
+            has learnable per-element affine parameters initialized to ones (for weights).
+            Default: ``True``.
+
+    Shape:
+        - Input: :math:`(N, *)`
+        - Output: :math:`(N, *)` (same shape as input)
+
+    Examples::
+
+        >>> rms_norm = nn.RMSNorm([2, 3])
+        >>> input = torch.randn(2, 2, 3)
+        >>> rms_norm(input)
+
+    """
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    normalized_shape: Tuple[int, ...]
+    eps: Optional[float]
+    elementwise_affine: bool
+
+    def __init__(self, normalized_shape: _shape_t, eps: Optional[float] = None, elementwise_affine: bool = True,
+                 device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        if isinstance(normalized_shape, numbers.Integral):
+            # mypy error: incompatible types in assignment
+            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+        if self.elementwise_affine:
+            self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
+        else:
+            self.register_parameter('weight', None)
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        """
+        Resets parameters based on their initialization used in __init__.
+        """
+        if self.elementwise_affine:
+            init.ones_(self.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Runs the forward pass.
+        """
+        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
+
+    def extra_repr(self) -> str:
+        """
+        Extra information about the module.
+        """
+        return '{normalized_shape}, eps={eps}, ' \
+            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
+
+
 # TODO: ContrastiveNorm2d
 # TODO: DivisiveNorm2d
 # TODO: SubtractiveNorm2d
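For reference, a short usage sketch of the module added above, checking it against the docstring's formula. This is an illustration that assumes the patch is applied so `nn.RMSNorm` is importable; the exact shapes and `eps` value are arbitrary::

    import torch
    from torch import nn

    # Normalize over the last dimension of size 5; weight starts as ones.
    rms_norm = nn.RMSNorm([5], eps=1e-6)
    x = torch.randn(4, 3, 5)
    y = rms_norm(x)

    # Manual check: x / sqrt(mean(x^2) + eps) * weight, reduced over the last dim.
    manual = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6) * rms_norm.weight
    torch.testing.assert_close(y, manual)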
diff --git a/torch/overrides.py b/torch/overrides.py
index 8076802e48a..a7a783cd39e 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -910,6 +910,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.nn.functional.prelu: lambda input, weight: -1,
         torch.nn.functional.relu: lambda input, inplace=False: -1,
         torch.nn.functional.relu6: lambda input, inplace=False: -1,
+        torch.nn.functional.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
         torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1,
         torch.nn.functional.selu: lambda input, inplace=False: -1,
         torch.nn.functional.silu: lambda input, inplace=False: -1,
@@ -1008,6 +1009,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.renorm: lambda input, p, dim, maxnorm, out=None: -1,
         torch.repeat_interleave: lambda input, dim=None: -1,
         torch.reshape: lambda input, shape: -1,
+        torch.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
         torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
         torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
         torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index e47e4c620c9..7196577f8b0 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -35,7 +35,7 @@ from torch.testing._internal.common_cuda import (
 from torch.testing._internal.common_utils import (
     make_fullrank_matrices_with_distinct_singular_values,
     TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
-    torch_to_numpy_dtype_dict, TEST_WITH_ASAN,
+    torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN,
     GRADCHECK_NONDET_TOL, freeze_rng_state, slowTest, TEST_WITH_SLOW,
     TEST_WITH_TORCHINDUCTOR
 )
@@ -4465,6 +4465,29 @@ def sample_inputs_native_layer_norm(opinfo, device, dtype, requires_grad, **kwar
             args=(normalized_shape, None, None, eps),
         )
 
+def sample_inputs_rms_norm(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as input shape, normalized_shape and a kwarg dict for eps
+    cases: Tuple[Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
+        ((1, 2, 3), (1, 2, 3), {'eps': 0.5}),
+        ((2, 2, 3), (2, 3), {'eps': -0.5}),
+        ((1,), (1,), {}),
+        ((1, 2), (2,), {}),
+        ((0, 1), (1,), {}),
+    )
+
+    for input_shape, normalized_shape, kwargs in cases:
+        # Shape of weight should be the same as normalized_shape
+        weight = make_arg(normalized_shape)
+        yield SampleInput(
+            make_arg(input_shape),
+            args=(normalized_shape, weight),
+            kwargs=kwargs
+        )
+    # Without any optional args
+    yield SampleInput(make_arg((1, 2)), args=((2,),))
+
 def error_inputs_group_norm(opinfo, device, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
 
@@ -4509,6 +4532,31 @@ def error_inputs_native_layer_norm(opinfo, device, **kwargs):
     )
     yield ErrorInput(s4, error_regex=err_msg4)
 
+def error_inputs_rms_norm(opinfo, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
+    input_shape = (1, 2, 3)
+
+    err_msg1 = "Expected normalized_shape to be at least 1-dimensional"
+    s1 = SampleInput(
+        make_arg(input_shape), args=(tuple(), None, 1e-5)
+    )
+    yield ErrorInput(s1, error_regex=err_msg1)
+
+    normalized_shape = (1, 2, 3)
+    weight = make_arg((1, 2))
+    err_msg2 = "Expected weight to be of same shape as normalized_shape"
+    s2 = SampleInput(
+        make_arg(input_shape), args=(normalized_shape, weight, 1e-5)
+    )
+    yield ErrorInput(s2, error_regex=err_msg2)
+
+
+    err_msg4 = "Given normalized_shape="
+    s4 = SampleInput(
+        make_arg((2, 2, 3)), args=((2, 2), None, 1e-5)
+    )
+    yield ErrorInput(s4, error_regex=err_msg4)
+
 def sample_inputs_local_response_norm(opinfo, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
@@ -10012,6 +10060,18 @@ def reference_native_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], w
     return Y.reshape(*inp.shape), mean.reshape(stat_shape), (1.0 / np.sqrt(var + eps)).reshape(stat_shape)
 
 
+def reference_rms_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, eps=None):
+    if eps is None:
+        eps = torch.finfo(numpy_to_torch_dtype(inp.dtype)).eps
+    feature_size = np.prod(normalized_shape)
+    inp_view = inp.reshape(-1, feature_size)  # type: ignore[call-overload]
+    rms = np.sqrt((inp_view**2).mean(axis=-1, keepdims=True) + eps)
+    Y = inp_view / rms
+    if weight is not None:
+        Y = Y * weight.reshape(-1)
+    return Y.reshape(*inp.shape)
+
+
 def reference_group_norm(inp: np.ndarray, num_groups: int, weight=None, bias=None, eps=1e-5):
     inp_view = inp
     if np.prod(inp.shape) != 0:
@@ -13656,6 +13716,16 @@ op_db: List[OpInfo] = [
            ],
            sample_inputs_func=sample_inputs_layer_norm,
            supports_expanded_weight=True,),
+    OpInfo('nn.functional.rms_norm',
+           aten_name='rms_norm',
+           aliases=('rms_norm',),
+           ref=reference_rms_norm,
+           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_rms_norm,
+           error_inputs_func=error_inputs_rms_norm,),
     OpInfo('nn.functional.local_response_norm',
            dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py
index a5b756c3736..97a0c97066b 100644
--- a/torch/testing/_internal/common_modules.py
+++ b/torch/testing/_internal/common_modules.py
@@ -1928,6 +1928,55 @@ def module_inputs_torch_nn_LayerNorm(module_info, device, dtype, requires_grad,
                     desc='3d_elementwise_affine_no_bias'),
     ]
 
+def module_inputs_torch_nn_RMSNorm(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def rms_norm_reference_fn(m, p, i):
+        eps = m.eps
+        if eps is None:
+            eps = torch.finfo(i.dtype).eps
+        ndim = i.ndim
+        normalized_shape = m.normalized_shape
+        weight = m.weight
+        dims = [ndim - i - 1 for i in range(len(normalized_shape))]
+        result = i * torch.rsqrt(i.pow(2).mean(dim=dims, keepdim=True) + eps)
+        if weight is not None:
+            result *= weight
+        return result
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((4, 5, 5))),
+            desc='1d_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((128, 5, 5))),
+            desc='1d_elementwise_affine_large_batch',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 5, 5))),
+            desc='1d_no_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_no_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((0, 5))),
+            desc='1d_empty_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+    ]
+
 def module_inputs_torch_nn_LocalResponseNorm(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
@@ -4071,6 +4120,9 @@ module_db: List[ModuleInfo] = [
                # No channels_last support for LayerNorm currently.
                DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
                ),
+    ModuleInfo(torch.nn.RMSNorm,
+               module_inputs_func=module_inputs_torch_nn_RMSNorm,
+               ),
     # TransformerEncoder takes the same inputs as TransformerEncoderLayer
     ModuleInfo(torch.nn.TransformerEncoder,
                train_and_eval_differ=True,
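Since the first hunk of this patch registers ``rms_norm`` with ``OP_DECOMPOSE`` for ``FuncTorchBatchedDecomposition``, ``torch.vmap`` should be able to batch the new functional by decomposing it. A quick sanity sketch (again assuming the patch is applied; shapes are arbitrary, and the per-sample result equals the unbatched call because the reduction only covers the last dimension)::

    import torch
    import torch.nn.functional as F

    x = torch.randn(8, 4, 5)  # batch of 8 inputs of shape (4, 5)
    batched = torch.vmap(lambda t: F.rms_norm(t, [5]))(x)
    direct = F.rms_norm(x, [5])
    torch.testing.assert_close(batched, direct)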