pytorch/test/onnx
Justin Chu fdf68fa5d7 [ONNX] Fix rotary_embedding_23 implementation (#162865)
The `rotary_embedding_23` implementation was incorrect for 3D inputs.
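
For readers unfamiliar with the 3D case, the following is a minimal pure-Python sketch of how a 3D input is usually understood for this operator. It assumes the ONNX `RotaryEmbedding` convention that a 3D input has shape `(batch_size, sequence_length, hidden_size)` with `hidden_size = num_heads * head_size`, that no `position_ids` are given (so the caches are indexed by batch and sequence position), and that the layout is non-interleaved. The function names `rotate_head` and `rotary_embedding_3d` are hypothetical and are not part of PyTorch or ONNX; this is not the code changed in this commit.

```py
def rotate_head(x, cos, sin):
    """Rotate one head vector, pairing the first half with the second half."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    real = [c * a - s * b for a, b, c, s in zip(x1, x2, cos, sin)]
    imag = [s * a + c * b for a, b, c, s in zip(x1, x2, cos, sin)]
    return real + imag


def rotary_embedding_3d(x, cos_cache, sin_cache, num_heads):
    """Apply the rotation to a 3D input given as nested lists.

    x:          [batch][seq][hidden], hidden = num_heads * head_size (assumed)
    cos_cache:  [batch][seq][head_size // 2] (assumed: no position_ids)
    sin_cache:  same shape as cos_cache
    """
    out = []
    for b_idx, batch in enumerate(x):
        out_batch = []
        for s_idx, hidden in enumerate(batch):
            head_size = len(hidden) // num_heads
            row = []
            # Split the hidden dimension into heads, rotate each head with the
            # same per-position cos/sin values, then concatenate back to 3D.
            for h in range(num_heads):
                head = hidden[h * head_size:(h + 1) * head_size]
                row.extend(
                    rotate_head(head, cos_cache[b_idx][s_idx], sin_cache[b_idx][s_idx])
                )
            out_batch.append(row)
        out.append(out_batch)
    return out


# One batch, one position, two heads of size 2, rotated by 90 degrees.
result = rotary_embedding_3d(
    [[[1.0, 0.0, 0.0, 1.0]]], [[[0.0]]], [[[1.0]]], num_heads=2
)
```

With `cos = 0` and `sin = 1`, each head is rotated by 90 degrees independently: `[1, 0]` becomes `[0, 1]` and `[0, 1]` becomes `[-1, 0]`, and the output keeps the original 3D shape.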

## Tested

Tested locally against the ONNX backend node test data with the following script:

```py
import os

import numpy as np
import onnx
import onnx_ir as ir
import torch

# Local checkout of the ONNX backend node test data; adjust for your machine.
base_path = "/home/justinchu/dev/onnx/onnx/backend/test/data/node"
test_names = [
    "test_rotary_embedding",
    "test_rotary_embedding_3d_input",
    "test_rotary_embedding_interleaved",
    "test_rotary_embedding_no_position_ids",
    "test_rotary_embedding_no_position_ids_interleaved",
    "test_rotary_embedding_no_position_ids_rotary_dim",
    "test_rotary_embedding_with_interleaved_rotary_dim",
    "test_rotary_embedding_with_rotary_dim",
]
model_paths = [os.path.join(base_path, name) for name in test_names]


def load_test_tensor(path, file_name):
    """Load a tensor from test_data_set_0 as a NumPy array, or None if absent."""
    tensor_path = os.path.join(path, "test_data_set_0", file_name)
    if not os.path.exists(tensor_path):
        return None
    return ir.from_proto(onnx.load_tensor(tensor_path)).numpy()


for path in model_paths:
    print(f"Checking {path} for issues...")

    model = onnx.load(os.path.join(path, "model.onnx"))
    input0 = load_test_tensor(path, "input_0.pb")
    input1 = load_test_tensor(path, "input_1.pb")
    input2 = load_test_tensor(path, "input_2.pb")
    # position_ids (input_3) is optional and missing from some test cases.
    input3 = load_test_tensor(path, "input_3.pb")
    output0 = load_test_tensor(path, "output_0.pb")

    m = ir.from_proto(model)

    # The test model's last node is the RotaryEmbedding op under test.
    node = m.graph[-1]
    print(node)
    assert node.op_type == "RotaryEmbedding"

    # Read the op attributes, falling back to 0 when an attribute is unset.
    interleaved = node.attributes.get_int("interleaved", 0)
    num_heads = node.attributes.get_int("num_heads", 0)
    rotary_embedding_dim = node.attributes.get_int("rotary_embedding_dim", 0)

    torch_out = torch.onnx.ops.rotary_embedding(
        torch.tensor(input0),
        torch.tensor(input1),
        torch.tensor(input2),
        position_ids=torch.tensor(input3) if input3 is not None else None,
        interleaved=bool(interleaved),
        num_heads=num_heads,
        rotary_embedding_dim=rotary_embedding_dim,
    )
    torch_out = torch_out.detach().cpu().numpy()
    # Compare against the expected output shipped with the ONNX test case.
    np.testing.assert_allclose(torch_out, output0)
```

Fixes https://github.com/pytorch/pytorch/issues/162848

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162865
Approved by: https://github.com/kunal-vaishnavi, https://github.com/titaiwangms
2025-09-16 03:30:05 +00:00
assets
expect
exporter [ONNX] Support enable_gqa when dropout is non-zero (#162771) 2025-09-12 04:00:57 +00:00
internal [ONNX] Refactor torchscript based exporter (#161323) 2025-09-02 16:10:30 +00:00
model_defs PEP585 update - test (#145176) 2025-01-22 04:48:28 +00:00
ops [ONNX] Fix rotary_embedding_23 implementation (#162865) 2025-09-16 03:30:05 +00:00
torchlib [BE]: ruff PLC0207 - use maxsplit kwarg (#160107) 2025-08-08 03:14:59 +00:00
autograd_helper.py
onnx_test_common.py [ONNX] Refactor torchscript based exporter (#161323) 2025-09-02 16:10:30 +00:00
pytorch_test_common.py [ONNX] Clean up the diagnostics module (#149864) 2025-03-26 05:58:32 +00:00
test_autograd_funs.py [ONNX] Refactor torchscript based exporter (#161323) 2025-09-02 16:10:30 +00:00
test_custom_ops.py
test_fx_type_promotion.py
test_lazy_import.py
test_models_onnxruntime.py PEP585 update - test (#145176) 2025-01-22 04:48:28 +00:00
test_models_quantized_onnxruntime.py Add __main__ guards to tests (#154716) 2025-06-04 14:38:13 +00:00
test_models.py
test_onnx_opset.py [ONNX] Default to dynamo export (#159646) 2025-09-02 22:45:55 +00:00
test_onnxscript_no_runtime.py [ONNX] Default to dynamo export (#159646) 2025-09-02 22:45:55 +00:00
test_onnxscript_runtime.py [ONNX] Refactor torchscript based exporter (#161323) 2025-09-02 16:10:30 +00:00
test_op_consistency.py PEP585 update - test (#145176) 2025-01-22 04:48:28 +00:00
test_pytorch_jit_onnx.py [ONNX] Refactor torchscript based exporter (#161323) 2025-09-02 16:10:30 +00:00
test_pytorch_onnx_onnxruntime_cuda.py
test_pytorch_onnx_onnxruntime.py [ONNX] Default to dynamo export (#159646) 2025-09-02 22:45:55 +00:00
test_pytorch_onnx_shape_inference.py [ONNX] Default to dynamo export (#159646) 2025-09-02 22:45:55 +00:00
test_symbolic_helper.py [ONNX] Refactor torchscript based exporter (#161323) 2025-09-02 16:10:30 +00:00
test_utility_funs.py [ONNX] Default to dynamo export (#159646) 2025-09-02 22:45:55 +00:00
verify.py Update ruff linter for PEP585 (#147540) 2025-02-22 04:45:17 +00:00