Commit Graph

6 Commits

Author SHA1 Message Date
Justin Chu
c5701d0ab5 [ONNX] Create fake implementations for onnx ops; fix boolean mask in attention (#165780)
Previously we rely on the concreate implementation to generate fake implementation. This makes the fake implementation overly complicated and breaks in some cases when there are dynamic shapes.

This PR updates onnx op registration to instead take a dedicated fake implementation.

**Also fixed: When boolean mask is supplied to torch sdpa, it was previously taken the negation, which is incorrect.**

Fix https://github.com/pytorch/pytorch/issues/164909 Also taken changes from https://github.com/pytorch/pytorch/pull/156635

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165780
Approved by: https://github.com/titaiwangms
2025-10-29 04:51:49 +00:00
Yuanyuan Chen
315ffdc1e4 [4/N] Apply ruff UP035 rule to python code (#164206)
Follows #164104

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164206
Approved by: https://github.com/albanD
2025-10-01 19:05:53 +00:00
Justin Chu
fdf68fa5d7 [ONNX] Fix rotary_embedding_23 implementation (#162865)
The implementation of rotary_embedding_23 when input is 3D was incorrect.

## Tested

Locally with

```py
import onnx_ir as ir
import onnx
import torch
import os
import numpy as np

base_path = "/home/justinchu/dev/onnx/onnx/backend/test/data/node"
test_names = [
    "test_rotary_embedding",
    "test_rotary_embedding_3d_input",
    "test_rotary_embedding_interleaved",
    "test_rotary_embedding_no_position_ids",
    "test_rotary_embedding_no_position_ids_interleaved",
    "test_rotary_embedding_no_position_ids_rotary_dim",
    "test_rotary_embedding_with_interleaved_rotary_dim",
    "test_rotary_embedding_with_rotary_dim",
]
model_paths = [os.path.join(base_path, name) for name in test_names]

for path in model_paths:
    print(f"Checking {path} for issues...")

    model = onnx.load(os.path.join(path, "model.onnx"))
    input0 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_0.pb"))
    ).numpy()
    input1 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_1.pb"))
    ).numpy()
    input2 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_2.pb"))
    ).numpy()
    if os.path.exists(os.path.join(path, "test_data_set_0", "input_3.pb")):
        input3 = ir.from_proto(
            onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_3.pb"))
        ).numpy()
    else:
        input3 = None
    output0 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "output_0.pb"))
    ).numpy()

    m = ir.from_proto(model)

    node = m.graph[-1]
    print(node)
    assert node.op_type == "RotaryEmbedding"

    interleaved = node.attributes.get_int("interleaved", 0)
    num_heads = node.attributes.get_int("num_heads", 0)
    rotary_embedding_dim = node.attributes.get_int("rotary_embedding_dim", 0)

    torch_out = torch.onnx.ops.rotary_embedding(
        torch.tensor(input0),
        torch.tensor(input1),
        torch.tensor(input2),
        position_ids=torch.tensor(input3) if input3 is not None else None,
        interleaved=bool(interleaved),
        num_heads=num_heads,
        rotary_embedding_dim=rotary_embedding_dim,
    )
    torch_out = torch_out.detach().cpu().numpy()
    np.testing.assert_allclose(torch_out, output0)
```

Fix https://github.com/pytorch/pytorch/issues/162848

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162865
Approved by: https://github.com/kunal-vaishnavi, https://github.com/titaiwangms
2025-09-16 03:30:05 +00:00
Vinayak Pawar
9ad5e8edb1 Improve typing of ONNX decorators with ParamSpec (#162332)
## Summary
This PR improves typing in ONNX-related modules by replacing TypeVar bound to Callable[..., Any] with ParamSpec to preserve parameter types and avoid type erasure in decorator functions.

## Changes
- `torch/onnx/_internal/exporter/_flags.py`: Replace TCallable TypeVar with ParamSpec
- `torch/onnx/ops/_impl.py`: Replace _T TypeVar with ParamSpec for _onnx_op decorator
- `torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py`: Replace _T TypeVar with ParamSpec

## Motivation
The previous implementation used TypeVar bound to Callable which erased parameter type information to Any. ParamSpec preserves the exact parameter types and return types, providing better type safety and IDE support.

## Testing
- Verified all changes compile and import correctly
- Created comprehensive test suite to validate ParamSpec functionality
- No linting errors introduced
- Maintains backward compatibility

Fixes #142306
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162332
Approved by: https://github.com/Skylion007
2025-09-07 18:06:03 +00:00
Justin Chu
fbbab794ef [ONNX] Implement Attention-23 (#156431)
Implement Attention-23 using sdpa and flexattention.

- I used copilot for this.
- Also updated the conversion logic to remove trailing None inputs.

@gramalingam @kunal-vaishnavi @titaiwangms
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156431
Approved by: https://github.com/titaiwangms

Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-06-20 23:54:57 +00:00
Justin Chu
3e57de1251 [ONNX] Create support for rotary embeddings (#154745)
This PR registers the RotaryEmbedding op in the `torch.ops.onnx` name spaces and allows the exporter to recognize and export onnx operators.

## Design

ONNX operators of their respective opset version is implemented in torch/onnx/ops/_impl.py, and are registered in the torch.ops.onnx namespace following the following rule:

`OpType-version => torch.ops.onnx.OpType.opset{version}`

For example, `RotaryEmbedding-23` becomes `torch.ops.onnx.RotaryEmbedding.opset23`

This name is parsed by the exporter to create an onnx node in the graph without having to go through translation.

When users use the ops in the model, we provide more convenient, unversioned functions under `torch.onnx.ops` that will dispatch to the implementations based on user input (type and provided attributes). For example, users can directly call `torch.onnx.ops.rotary_embedding()` to use the op natively in their pytorch models. I chose snake case naming to make the functions more pythonic and aligned with other torch apis.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154745
Approved by: https://github.com/titaiwangms
2025-06-04 03:07:43 +00:00