Update weight tensor initialization in RMSNormalization (#166550)

Ensure a >1d tensor is used as the weight for ONNX Runtime (ORT) compatibility.
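For context, here is a minimal sketch of the case this change targets, modeled on the tests in the diff below. The body of RMSNormModel is an assumption (its definition is not part of this diff); the input shape and opset_version=23 come from the tests, and the call assumes a PyTorch build whose dynamo exporter supports opset 23.

import torch


class RMSNormModel(torch.nn.Module):
    """Hypothetical stand-in for the test model: rms_norm with no explicit weight."""

    def forward(self, x):
        # weight defaults to None, so the exporter has to synthesize one
        return torch.nn.functional.rms_norm(x, normalized_shape=[3])


x = torch.randn(2, 5, 3)

# Export through the dynamo-based exporter at opset 23, where ONNX RMSNormalization
# is available. Previously the synthesized weight was a 0-d scalar Constant, which
# is not ORT-compatible; it is now a tensor shaped like the input, so the exported
# program can be validated against ORT rather than only the reference evaluator.
onnx_program = torch.onnx.export(
    RMSNormModel().eval(), (x,), opset_version=23, dynamo=True
)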

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166550
Approved by: https://github.com/titaiwangms
Justin Chu 2025-10-31 14:29:27 +00:00 committed by PyTorch MergeBot
parent 5bcfdae71d
commit 160ab53dd5
3 changed files with 24 additions and 11 deletions

View File

@@ -17,6 +17,8 @@ from torch.utils import _pytree as torch_pytree
 class _WithExport:
     def export(self, model, args=(), kwargs=None, **options) -> torch.onnx.ONNXProgram:
+        if isinstance(model, torch.nn.Module):
+            model = model.eval()
         onnx_program = torch.onnx.export(
             model,
             args,
@@ -751,7 +753,7 @@ class DynamoExporterNewOpsetsTest(common_utils.TestCase, _WithExport):
         x = torch.randn(2, 5, 3)
         onnx_program = self.export(RMSNormModel(), (x,), opset_version=23)
-        onnx_testing.assert_onnx_program(onnx_program, backend="reference")
+        onnx_testing.assert_onnx_program(onnx_program)

         # Test with multi-dimensional normalized_shape
         class RMSNormModel2D(torch.nn.Module):
@@ -760,7 +762,7 @@ class DynamoExporterNewOpsetsTest(common_utils.TestCase, _WithExport):
         x = torch.randn(2, 5, 7, 3)
         onnx_program = self.export(RMSNormModel2D(), (x,), opset_version=23)
-        onnx_testing.assert_onnx_program(onnx_program, backend="reference")
+        onnx_testing.assert_onnx_program(onnx_program)

     def test_rms_norm_with_weight(self):
         """Test RMS normalization with weight parameter."""
@@ -790,7 +792,7 @@ class DynamoExporterNewOpsetsTest(common_utils.TestCase, _WithExport):
         onnx_program = self.export(RMSNormWithEps(), (x,), opset_version=23)
-        onnx_testing.assert_onnx_program(onnx_program, backend="reference")
+        onnx_testing.assert_onnx_program(onnx_program)

     def test_enable_gqa_in_attention_23_with_dropout(self):
         class Model(torch.nn.Module):

View File

@@ -458,6 +458,18 @@ TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
     TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
     TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True),
     TorchLibOpInfo("gelu_op20", nn_ops.aten_gelu_opset20, opset_introduced=20),
+    TorchLibOpInfo(
+        "nn.functional.group_norm", nn_ops.aten_group_norm, opset_introduced=21
+    ).skip(
+        reason="ONNX Runtime does not support zero sized inputs for GroupNorm",
+        matcher=lambda sample: sample.input.numel() == 0,
+    ),
+    TorchLibOpInfo(
+        "nn.functional.rms_norm", nn_ops.aten_rms_norm, opset_introduced=23
+    ).skip(
+        reason="ONNX Runtime does not support <1d inputs or zero sized inputs for RMSNorm",
+        matcher=lambda sample: len(sample.input.shape) < 2 or sample.input.numel() == 0,
+    ),
 )
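The new skip matchers read as plain predicates over a sample's input tensor. A standalone restatement for illustration (the tensors here are not taken from the op test suite):

import torch

def ort_skips_rms_norm_sample(x: torch.Tensor) -> bool:
    # Mirrors the matcher above: ORT's RMSNorm rejects inputs with fewer than
    # two dimensions as well as zero-sized inputs.
    return len(x.shape) < 2 or x.numel() == 0

assert ort_skips_rms_norm_sample(torch.randn(3))         # 1-d input: skipped
assert ort_skips_rms_norm_sample(torch.randn(0, 4))      # zero-sized input: skipped
assert not ort_skips_rms_norm_sample(torch.randn(2, 4))  # 2-d, non-empty: tested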

View File

@@ -6,7 +6,7 @@
 from __future__ import annotations

-from typing import Optional, TYPE_CHECKING
+from typing import Optional, Sequence, TYPE_CHECKING

 from onnxscript.onnx_opset import (  # type: ignore[attr-defined]
     opset20 as op20,
@@ -25,9 +25,6 @@ if TYPE_CHECKING:
 aten = torch.ops.aten

-_INT64_MAX = 9223372036854775807
-_INT64_MIN = -9223372036854775808
-
 @onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20)
 def aten_gelu_opset20(
@@ -51,9 +48,9 @@ def aten_group_norm(
     c = op21.Shape(input, start=1, end=2)
     if weight is None:
-        weight = op21.ConstantOfShape(c, value=ir.tensor(1.0, dtype=input.dtype))
+        weight = op21.ConstantOfShape(c, value=ir.tensor([1.0], dtype=input.dtype))
     if bias is None:
-        bias = op21.ConstantOfShape(c, value=ir.tensor(0.0, dtype=input.dtype))
+        bias = op21.ConstantOfShape(c, value=ir.tensor([0.0], dtype=input.dtype))

     return op21.GroupNormalization(
         input, weight, bias, epsilon=eps, num_groups=num_groups
     )
@@ -62,7 +59,7 @@ def aten_group_norm(
 @onnx_impl(aten.rms_norm.default, trace_only=True, opset_introduced=23)
 def aten_rms_norm(
     input: TFloat,
-    normalized_shape: list[int],
+    normalized_shape: Sequence[int],
     weight: Optional[TFloat] = None,
     eps: Optional[float] = None,
 ) -> TFloat:
@@ -81,7 +78,9 @@ def aten_rms_norm(
     # Create weight tensor if not provided
     if weight is None:
-        weight = op23.Constant(value=ir.tensor(1.0, dtype=input.dtype))
+        weight = op23.ConstantOfShape(
+            op23.Shape(input), value=ir.tensor([1], dtype=input.dtype)
+        )

     return op23.RMSNormalization(input, weight, axis=axis, epsilon=eps)
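Restating the core change outside the diff: when the caller passes no weight, the exporter used to emit a 0-d scalar Constant as the default scale, which is not ORT-compatible per the commit message; it now emits ConstantOfShape(Shape(input)), i.e. a ones tensor with the same shape (and therefore the same rank) as the input. A rough equivalent in plain PyTorch, for illustration only (the real code builds ONNX graph nodes through onnxscript's op23 as shown above):

import torch

def default_rms_norm_weight(x: torch.Tensor) -> torch.Tensor:
    # Old behavior, roughly: a 0-d scalar one.
    #   weight = torch.tensor(1.0, dtype=x.dtype)   # rank 0, rejected by ORT
    # New behavior, roughly: ones with the input's shape, so the weight is >1-d
    # for the >=2-d inputs the tests exercise, and a no-op scale either way.
    return torch.ones_like(x)

x = torch.randn(2, 5, 3)
weight = default_rms_norm_weight(x)
assert weight.shape == x.shape and weight.ndim == x.ndim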