Update weight tensor initialization in RMSNormalization (#166550)

Ensure the weight passed to RMSNormalization is a >1d tensor for ONNX Runtime (ORT) compatibility.
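For context, a minimal repro sketch of the affected path (assumed setup, mirroring the tests below; not part of this commit): when rms_norm is exported without an explicit weight, the exporter must synthesize one, and ORT rejects the scalar weight that was synthesized previously.

import torch

class RMSNormModel(torch.nn.Module):
    def forward(self, x):
        # No explicit weight: the exporter synthesizes one for ONNX
        # RMSNormalization, and ORT expects it to be at least 1-d.
        return torch.nn.functional.rms_norm(x, normalized_shape=(3,))

x = torch.randn(2, 5, 3)
# Opset 23 introduces the RMSNormalization operator.
onnx_program = torch.onnx.export(
    RMSNormModel().eval(), (x,), opset_version=23, dynamo=True
)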

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166550
Approved by: https://github.com/titaiwangms
Justin Chu 2025-10-31 14:29:27 +00:00 committed by PyTorch MergeBot
parent 5bcfdae71d
commit 160ab53dd5
3 changed files with 24 additions and 11 deletions


@@ -17,6 +17,8 @@ from torch.utils import _pytree as torch_pytree
class _WithExport:
def export(self, model, args=(), kwargs=None, **options) -> torch.onnx.ONNXProgram:
+if isinstance(model, torch.nn.Module):
+    model = model.eval()
onnx_program = torch.onnx.export(
model,
args,
@@ -751,7 +753,7 @@ class DynamoExporterNewOpsetsTest(common_utils.TestCase, _WithExport):
x = torch.randn(2, 5, 3)
onnx_program = self.export(RMSNormModel(), (x,), opset_version=23)
-onnx_testing.assert_onnx_program(onnx_program, backend="reference")
+onnx_testing.assert_onnx_program(onnx_program)
# Test with multi-dimensional normalized_shape
class RMSNormModel2D(torch.nn.Module):
@@ -760,7 +762,7 @@ class DynamoExporterNewOpsetsTest(common_utils.TestCase, _WithExport):
x = torch.randn(2, 5, 7, 3)
onnx_program = self.export(RMSNormModel2D(), (x,), opset_version=23)
-onnx_testing.assert_onnx_program(onnx_program, backend="reference")
+onnx_testing.assert_onnx_program(onnx_program)
def test_rms_norm_with_weight(self):
"""Test RMS normalization with weight parameter."""
@@ -790,7 +792,7 @@ class DynamoExporterNewOpsetsTest(common_utils.TestCase, _WithExport):
onnx_program = self.export(RMSNormWithEps(), (x,), opset_version=23)
-onnx_testing.assert_onnx_program(onnx_program, backend="reference")
+onnx_testing.assert_onnx_program(onnx_program)
def test_enable_gqa_in_attention_23_with_dropout(self):
class Model(torch.nn.Module):
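The harness change above switches the model to eval mode before export; an illustrative aside (not part of the diff) on why that matters for the new dropout test:

import torch

# Dropout is stochastic in train mode but the identity in eval mode,
# so exporting an eval-mode model yields a deterministic graph.
drop = torch.nn.Dropout(p=0.5)
x = torch.ones(4)

drop.eval()
assert torch.equal(drop(x), x)  # eval mode: no-op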


@@ -458,6 +458,18 @@ TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True),
TorchLibOpInfo("gelu_op20", nn_ops.aten_gelu_opset20, opset_introduced=20),
+TorchLibOpInfo(
+    "nn.functional.group_norm", nn_ops.aten_group_norm, opset_introduced=21
+).skip(
+    reason="ONNX Runtime does not support zero sized inputs for GroupNorm",
+    matcher=lambda sample: sample.input.numel() == 0,
+),
+TorchLibOpInfo(
+    "nn.functional.rms_norm", nn_ops.aten_rms_norm, opset_introduced=23
+).skip(
+    reason="ONNX Runtime does not support <1d inputs or zero sized inputs for RMSNorm",
+    matcher=lambda sample: len(sample.input.shape) < 2 or sample.input.numel() == 0,
+),
)
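To make the new rms_norm skip matcher concrete, an illustrative adaptation (not part of the diff; the real matcher receives an OpInfo sample and inspects sample.input):

import torch

# Shapes the matcher above flags as unsupported by ONNX Runtime:
matcher = lambda t: len(t.shape) < 2 or t.numel() == 0

vector = torch.randn(4)     # rank 1: len(shape) < 2
empty = torch.randn(0, 3)   # zero-sized: numel() == 0
matrix = torch.randn(2, 3)  # supported; runs through ORT as usual

assert matcher(vector) and matcher(empty) and not matcher(matrix)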


@@ -6,7 +6,7 @@
from __future__ import annotations
-from typing import Optional, TYPE_CHECKING
+from typing import Optional, Sequence, TYPE_CHECKING
from onnxscript.onnx_opset import ( # type: ignore[attr-defined]
opset20 as op20,
@@ -25,9 +25,6 @@ if TYPE_CHECKING:
aten = torch.ops.aten
-_INT64_MAX = 9223372036854775807
-_INT64_MIN = -9223372036854775808
@onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20)
def aten_gelu_opset20(
@@ -51,9 +48,9 @@ def aten_group_norm(
c = op21.Shape(input, start=1, end=2)
if weight is None:
-weight = op21.ConstantOfShape(c, value=ir.tensor(1.0, dtype=input.dtype))
+weight = op21.ConstantOfShape(c, value=ir.tensor([1.0], dtype=input.dtype))
if bias is None:
-bias = op21.ConstantOfShape(c, value=ir.tensor(0.0, dtype=input.dtype))
+bias = op21.ConstantOfShape(c, value=ir.tensor([0.0], dtype=input.dtype))
return op21.GroupNormalization(
input, weight, bias, epsilon=eps, num_groups=num_groups
)
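The [1.0]/[0.0] change matters because ConstantOfShape's value attribute is specified as a one-element tensor, and a rank-1 tensor of shape [1] is the form strict checkers accept where a 0-d scalar can be rejected. A standalone sketch with the onnx Python package (assumed available; not part of the diff):

import numpy as np
from onnx import helper, numpy_helper

# ConstantOfShape's `value` must be a one-element tensor; shape [1]
# (as in the fix above) matches what the operator schema describes.
value = numpy_helper.from_array(np.array([1.0], dtype=np.float32), name="value")
node = helper.make_node(
    "ConstantOfShape", inputs=["shape"], outputs=["out"], value=value
)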
@@ -62,7 +59,7 @@ def aten_group_norm(
@onnx_impl(aten.rms_norm.default, trace_only=True, opset_introduced=23)
def aten_rms_norm(
input: TFloat,
-normalized_shape: list[int],
+normalized_shape: Sequence[int],
weight: Optional[TFloat] = None,
eps: Optional[float] = None,
) -> TFloat:
@@ -81,7 +78,9 @@ def aten_rms_norm(
# Create weight tensor if not provided
if weight is None:
-weight = op23.Constant(value=ir.tensor(1.0, dtype=input.dtype))
+weight = op23.ConstantOfShape(
+    op23.Shape(input), value=ir.tensor([1], dtype=input.dtype)
+)
return op23.RMSNormalization(input, weight, axis=axis, epsilon=eps)
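Finally, a NumPy sketch of the intended semantics (assumed reference behavior; not part of the diff), showing that the synthesized all-ones weight, now shaped like the input rather than a scalar, changes nothing numerically while satisfying ONNX Runtime's rank expectations:

import numpy as np

def rms_norm_ref(x, weight, axis, eps=1e-5):
    # Normalize over dims [axis, ..., rank-1], then scale by weight.
    dims = tuple(range(axis, x.ndim))
    ms = np.mean(np.square(x), axis=dims, keepdims=True)
    return x / np.sqrt(ms + eps) * weight

x = np.random.randn(2, 5, 3).astype(np.float32)
weight = np.ones_like(x)  # all ones: a no-op scale with the input's shape
out = rms_norm_ref(x, weight, axis=2)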