Update weight tensor initialization in RMSNormalization (#166550)
Ensure the weight is a >1-D tensor for ORT compatibility.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166550
Approved by: https://github.com/titaiwangms
This commit is contained in: parent 5bcfdae71d, commit 160ab53dd5
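The fix itself is small: where the exporter previously materialized the default RMSNorm weight as a rank-0 scalar constant, it now emits ConstantOfShape(Shape(input)) so the weight carries the input's full shape. A minimal numpy sketch (illustrative only; the real change builds these as ONNX ops) showing that the new weight is numerically a no-op:

import numpy as np

x = np.random.randn(2, 5, 3).astype(np.float32)

# Before: the default weight was a rank-0 scalar constant.
scalar_weight = np.array(1.0, dtype=np.float32)   # shape (), rank 0
# After: ConstantOfShape(Shape(input)) yields ones with the input's full shape.
full_weight = np.ones(x.shape, dtype=np.float32)  # shape (2, 5, 3), rank 3

# Both scale the normalized tensor identically; the change only raises the
# weight's rank so that ONNX Runtime accepts the graph.
assert np.allclose(x * scalar_weight, x * full_weight)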
@@ -17,6 +17,8 @@ from torch.utils import _pytree as torch_pytree
 class _WithExport:
     def export(self, model, args=(), kwargs=None, **options) -> torch.onnx.ONNXProgram:
+        if isinstance(model, torch.nn.Module):
+            model = model.eval()
         onnx_program = torch.onnx.export(
             model,
             args,
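The two added lines switch modules to inference mode before export. Train-mode layers such as Dropout are stochastic, so exporting without this normalization can capture nondeterministic behavior; a quick eager-mode sketch of the difference:

import torch

m = torch.nn.Dropout(p=0.5)
x = torch.ones(4)

m.train()
print(m(x))  # randomly zeroed entries, survivors scaled by 1/(1-p)

m.eval()
print(m(x))  # identity in eval mode: tensor([1., 1., 1., 1.])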
@@ -751,7 +753,7 @@ class DynamoExporterNewOpsetsTest(common_utils.TestCase, _WithExport):
         x = torch.randn(2, 5, 3)
         onnx_program = self.export(RMSNormModel(), (x,), opset_version=23)
-        onnx_testing.assert_onnx_program(onnx_program, backend="reference")
+        onnx_testing.assert_onnx_program(onnx_program)

         # Test with multi-dimensional normalized_shape
         class RMSNormModel2D(torch.nn.Module):
@@ -760,7 +762,7 @@ class DynamoExporterNewOpsetsTest(common_utils.TestCase, _WithExport):
         x = torch.randn(2, 5, 7, 3)
         onnx_program = self.export(RMSNormModel2D(), (x,), opset_version=23)
-        onnx_testing.assert_onnx_program(onnx_program, backend="reference")
+        onnx_testing.assert_onnx_program(onnx_program)

     def test_rms_norm_with_weight(self):
         """Test RMS normalization with weight parameter."""
@@ -790,7 +792,7 @@ class DynamoExporterNewOpsetsTest(common_utils.TestCase, _WithExport):
         onnx_program = self.export(RMSNormWithEps(), (x,), opset_version=23)

-        onnx_testing.assert_onnx_program(onnx_program, backend="reference")
+        onnx_testing.assert_onnx_program(onnx_program)

     def test_enable_gqa_in_attention_23_with_dropout(self):
         class Model(torch.nn.Module):
@@ -458,6 +458,18 @@ TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
     TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
     TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True),
     TorchLibOpInfo("gelu_op20", nn_ops.aten_gelu_opset20, opset_introduced=20),
+    TorchLibOpInfo(
+        "nn.functional.group_norm", nn_ops.aten_group_norm, opset_introduced=21
+    ).skip(
+        reason="ONNX Runtime does not support zero sized inputs for GroupNorm",
+        matcher=lambda sample: sample.input.numel() == 0,
+    ),
+    TorchLibOpInfo(
+        "nn.functional.rms_norm", nn_ops.aten_rms_norm, opset_introduced=23
+    ).skip(
+        reason="ONNX Runtime does not support <1d inputs or zero sized inputs for RMSNorm",
+        matcher=lambda sample: len(sample.input.shape) < 2 or sample.input.numel() == 0,
+    ),
 )
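The skip matchers above are plain predicates over OpInfo samples; applying the RMSNorm condition directly to tensors (rather than to sample objects) shows which inputs are excluded from the ORT comparison:

import torch

def rms_skip(t: torch.Tensor) -> bool:
    # Same condition as the matcher above, taking a tensor instead of a sample.
    return len(t.shape) < 2 or t.numel() == 0

print(rms_skip(torch.randn(3)))        # True  -- 1-D input, skipped
print(rms_skip(torch.randn(0, 3)))     # True  -- zero-sized input, skipped
print(rms_skip(torch.randn(2, 5, 3)))  # False -- runs against ONNX Runtime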
@@ -6,7 +6,7 @@

 from __future__ import annotations

-from typing import Optional, TYPE_CHECKING
+from typing import Optional, Sequence, TYPE_CHECKING

 from onnxscript.onnx_opset import (  # type: ignore[attr-defined]
     opset20 as op20,
@@ -25,9 +25,6 @@ if TYPE_CHECKING:
 aten = torch.ops.aten

-_INT64_MAX = 9223372036854775807
-_INT64_MIN = -9223372036854775808
-

 @onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20)
 def aten_gelu_opset20(
@@ -51,9 +48,9 @@ def aten_group_norm(
     c = op21.Shape(input, start=1, end=2)
     if weight is None:
-        weight = op21.ConstantOfShape(c, value=ir.tensor(1.0, dtype=input.dtype))
+        weight = op21.ConstantOfShape(c, value=ir.tensor([1.0], dtype=input.dtype))
     if bias is None:
-        bias = op21.ConstantOfShape(c, value=ir.tensor(0.0, dtype=input.dtype))
+        bias = op21.ConstantOfShape(c, value=ir.tensor([0.0], dtype=input.dtype))
     return op21.GroupNormalization(
         input, weight, bias, epsilon=eps, num_groups=num_groups
     )
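The GroupNorm change is the same class of fix: the ONNX ConstantOfShape operator expects its value attribute to be a one-element tensor, and wrapping the literal in a list presumably makes ir.tensor produce rank 1 instead of rank 0. A small numpy sketch of the distinction:

import numpy as np

scalar = np.array(1.0, dtype=np.float32)      # shape (),  rank 0
one_elem = np.array([1.0], dtype=np.float32)  # shape (1,), rank 1

# ConstantOfShape wants the one-element (rank-1) form for its `value`
# attribute; the rank-0 form is what the old code produced.
print(scalar.shape, one_elem.shape)  # () (1,)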
@@ -62,7 +59,7 @@ def aten_group_norm(
 @onnx_impl(aten.rms_norm.default, trace_only=True, opset_introduced=23)
 def aten_rms_norm(
     input: TFloat,
-    normalized_shape: list[int],
+    normalized_shape: Sequence[int],
     weight: Optional[TFloat] = None,
     eps: Optional[float] = None,
 ) -> TFloat:
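Widening normalized_shape from list[int] to Sequence[int] reflects how shapes usually arrive: torch.Size is a tuple subclass, which satisfies Sequence[int] but not list[int] under strict type checking. A brief illustration with a hypothetical helper:

from typing import Sequence

import torch

def leading_axis(normalized_shape: Sequence[int]) -> int:
    # Hypothetical helper: accepts any sequence of ints, lists and tuples alike.
    return -len(normalized_shape)

size = torch.randn(2, 5, 3).shape  # torch.Size, a tuple subclass
assert isinstance(size, tuple)
print(leading_axis(size))  # -3
print(leading_axis([3]))   # -1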
@@ -81,7 +78,9 @@ def aten_rms_norm(
     # Create weight tensor if not provided
     if weight is None:
-        weight = op23.Constant(value=ir.tensor(1.0, dtype=input.dtype))
+        weight = op23.ConstantOfShape(
+            op23.Shape(input), value=ir.tensor([1], dtype=input.dtype)
+        )

     return op23.RMSNormalization(input, weight, axis=axis, epsilon=eps)
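As a sanity check on the new default weight, an all-ones scale leaves RMSNorm's output unchanged, so graphs exported with the materialized weight should match the weightless eager result. A minimal eager-mode sketch (assumes torch.nn.functional.rms_norm, available in recent PyTorch releases):

import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 3)

ref = F.rms_norm(x, normalized_shape=(3,))
out = F.rms_norm(x, normalized_shape=(3,), weight=torch.ones(3))

# The ones weight is a numerical no-op; only the graph structure differs.
torch.testing.assert_close(ref, out)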