mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Cleaning up onnx module imports to prepare for updating `__init__`.
- Simplify importing the `_C` and `_C._onnx` namespaces
- Remove alias of the symbolic_helper module in imports
- Remove any module level function imports. Import modules instead
- Alias `symbolic_opsetx` as `opsetx`
- Fix some docstrings
Requires:
- https://github.com/pytorch/pytorch/pull/77448
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77423
Approved by: https://github.com/BowenBao
50 lines
1.4 KiB
Python
50 lines
1.4 KiB
Python
"""This file exports ONNX ops for opset 16.
|
|
|
|
Note [ONNX Operators that are added/updated in opset 16]
|
|
|
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
|
|
New operators:
|
|
GridSample https://github.com/onnx/onnx/pull/3557
|
|
|
|
Updated operators:
|
|
Identity
|
|
If
|
|
LeakyRelu
|
|
Loop
|
|
PRelu
|
|
RoiAlign
|
|
Scan
|
|
ScatterElements
|
|
ScatterND
|
|
Where
|
|
GreaterOrEqual
|
|
LessOrEqual
|
|
SequenceMap
|
|
"""
|
|
|
|
# EDITING THIS FILE? READ THIS FIRST!
|
|
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
|
|
|
from torch.nn.functional import (
|
|
GRID_SAMPLE_INTERPOLATION_MODES,
|
|
GRID_SAMPLE_PADDING_MODES,
|
|
)
|
|
from torch.onnx import symbolic_helper
|
|
|
|
|
|
# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
|
|
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
|
|
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
def grid_sampler(g, input, grid, mode_enum, padding_mode_enum, align_corners):
    """Export ``aten::grid_sampler`` as the ONNX ``GridSample`` operator (opset 16).

    Args:
        g: The graph being built; ``g.op`` is used to emit the ONNX node.
        input: The value to sample from.
        grid: The sampling-grid value.
        mode_enum: Integer interpolation mode, one of the values of
            ``torch.nn.functional.GRID_SAMPLE_INTERPOLATION_MODES``.
        padding_mode_enum: Integer padding mode, one of the values of
            ``torch.nn.functional.GRID_SAMPLE_PADDING_MODES``.
        align_corners: Boolean flag, emitted as the integer attribute
            ``align_corners``.

    Returns:
        The output of the emitted ``GridSample`` node.

    Raises:
        KeyError: If ``mode_enum`` or ``padding_mode_enum`` is not a value of
            the corresponding ``torch.nn.functional`` mode table.

    A 5-D ``input`` is rejected via ``symbolic_helper._onnx_unsupported``
    rather than exported.
    """
    # torch.grid_sampler also accepts 5-D volumetric input (N, C, D, H, W), but
    # ONNX GridSample in opset 16 is only defined for 4-D input (N, C, H, W).
    # Refuse the 5-D case instead of silently emitting an invalid model.
    if symbolic_helper._get_tensor_rank(input) == 5:
        return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input")

    # The torch.nn.functional tables map mode *names* to the integer enums that
    # aten receives; invert them to recover the string attributes ONNX expects.
    mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum]  # type: ignore[call-arg]
    padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum]  # type: ignore[call-arg]
    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=mode_s,
        padding_mode_s=padding_mode_s,
    )
|