mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[torch.onnx] support torch.nn.functional.grid_sample
summary - Adds `F.grid_sample` support - Adds a test case Fixes #27212 Pull Request resolved: https://github.com/pytorch/pytorch/pull/76159 Approved by: https://github.com/justinchuby, https://github.com/BowenBao
This commit is contained in:
parent
e14f5336c0
commit
0ae3aa648e
|
|
@ -81,7 +81,7 @@ fi
|
|||
|
||||
if [[ "${SHARD_NUMBER}" == "2" ]]; then
|
||||
# Update the loop for new opsets
|
||||
for i in $(seq 10 15); do
|
||||
for i in $(seq 10 16); do
|
||||
pytest "${args[@]}" \
|
||||
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset$i"
|
||||
done
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from torch.nn import Module
|
|||
import onnx
|
||||
|
||||
import io
|
||||
import itertools
|
||||
|
||||
from torch.onnx.symbolic_helper import _export_onnx_opset_version
|
||||
from torch.onnx import producer_name, producer_version
|
||||
|
|
@ -369,6 +370,30 @@ class TestONNXOpset(TestCase):
|
|||
x = torch.randn(20, 16, 50)
|
||||
check_onnx_opsets_operator(MyDynamicModel(), x, ops, opset_versions=[9, 10])
|
||||
|
||||
def test_grid_sample(self):
|
||||
n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4
|
||||
ops = {16: [{"op_name": "GridSample"}]}
|
||||
|
||||
class MyModule(Module):
|
||||
def forward(self, x, grid, mode, padding_mode, align_corers):
|
||||
return torch.nn.functional.grid_sample(x, grid, mode, padding_mode, align_corners)
|
||||
|
||||
for mode, padding_mode, align_corners in itertools.product(
|
||||
("bilinear", "nearest", "bicubic"),
|
||||
("zeros", "border", "reflection"),
|
||||
(True, False),
|
||||
):
|
||||
|
||||
args = (
|
||||
torch.randn(n, c, h_in, w_in), # x
|
||||
torch.randn(n, h_out, w_out, 2), # grid,
|
||||
mode,
|
||||
padding_mode,
|
||||
align_corners,
|
||||
)
|
||||
check_onnx_opsets_operator(MyModule(), args, ops, opset_versions=[16], training=torch.onnx.TrainingMode.TRAINING)
|
||||
check_onnx_opsets_operator(MyModule(), args, ops, opset_versions=[16], training=torch.onnx.TrainingMode.EVAL)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -9183,6 +9183,7 @@ class _TestONNXRuntime:
|
|||
dynamic_axes={"size": [0, 1]},
|
||||
test_with_inputs=[(boxes, size), (boxes, size_2)])
|
||||
|
||||
@skipIfUnsupportedMaxOpsetVersion(15) # TODO: Opset 16 RoiAlign result mismatch
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
def test_roi_align(self):
|
||||
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
|
||||
|
|
@ -9190,6 +9191,7 @@ class _TestONNXRuntime:
|
|||
model = ops.RoIAlign((5, 5), 1., 2)
|
||||
self.run_test(model, (x, single_roi))
|
||||
|
||||
@skipIfUnsupportedMaxOpsetVersion(15) # TODO: Opset 16 RoiAlign result mismatch
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
def test_roi_align_aligned(self):
|
||||
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
|
||||
|
|
@ -9295,6 +9297,7 @@ class _TestONNXRuntime:
|
|||
test_with_inputs=[(images, features), (images2, test_features)],
|
||||
dict_check=False)
|
||||
|
||||
@skipIfUnsupportedMaxOpsetVersion(15) # TODO: Opset 16 RoiAlign result mismatch
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
@disableScriptTest()
|
||||
def test_multi_scale_roi_align(self):
|
||||
|
|
@ -10986,6 +10989,33 @@ class _TestONNXRuntime:
|
|||
self.run_test(Module(False), x, rtol=1e-3, atol=1e-6)
|
||||
self.run_test(Module(True), x, rtol=1e-3, atol=1e-6)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(16)
|
||||
def test_grid_sample(self):
|
||||
n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4
|
||||
|
||||
class GridSampleModule(torch.nn.Module):
|
||||
|
||||
def __init__(self, mode, padding_mode, align_corners) -> None:
|
||||
super().__init__()
|
||||
self.mode, self.padding_mode, self.align_corners = mode, padding_mode, align_corners
|
||||
|
||||
def forward(self, input, grid):
|
||||
return torch.nn.functional.grid_sample(input, grid, self.mode, self.padding_mode, self.align_corners)
|
||||
|
||||
for mode, padding_mode, align_corners in itertools.product(
|
||||
("bilinear", "nearest", "bicubic"),
|
||||
("zeros", "border", "reflection"),
|
||||
(True, False),
|
||||
):
|
||||
atol_rtol = {}
|
||||
if (mode, padding_mode) == ("bicubic", "border"):
|
||||
if align_corners:
|
||||
atol_rtol.update({"atol": 0.3, "rtol": 0.4})
|
||||
else:
|
||||
atol_rtol.update({"atol": 0.02, "rtol": 0.02})
|
||||
input, grid = torch.randn(n, c, h_in, w_in), torch.randn(n, h_out, w_out, 2)
|
||||
self.run_test(GridSampleModule(mode, padding_mode, align_corners), (input, grid), **atol_rtol)
|
||||
|
||||
|
||||
def make_test(name, base, layer, bidirectional, initial_state,
|
||||
variable_length, dropout, script_test_min_opset_version,
|
||||
|
|
@ -11123,6 +11153,8 @@ TestONNXRuntime_opset14 = MakeTestCase(14, keep_initializers_as_inputs=False)
|
|||
|
||||
TestONNXRuntime_opset15 = MakeTestCase(15, keep_initializers_as_inputs=False)
|
||||
|
||||
TestONNXRuntime_opset16 = MakeTestCase(16, keep_initializers_as_inputs=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ static const int OPSET_VERSION_12 = 12;
|
|||
static const int OPSET_VERSION_13 = 13;
|
||||
static const int OPSET_VERSION_14 = 14;
|
||||
static const int OPSET_VERSION_15 = 15;
|
||||
static const int OPSET_VERSION_16 = 16;
|
||||
|
||||
using ValueToParamPairMap = std::map<Value*, std::pair<std::string, IValue>>;
|
||||
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ namespace onnx = ::ONNX_NAMESPACE;
|
|||
const static int kInvalidOpsetVersion = -1;
|
||||
// Based on OP_SET_ID_VERSION_MAP in
|
||||
// https://github.com/onnx/onnx/blob/master/onnx/helper.py.
|
||||
constexpr static std::array<int64_t, 16> kOpsetVersionToIRVersion = {
|
||||
constexpr static std::array<int64_t, 17> kOpsetVersionToIRVersion = {
|
||||
kInvalidOpsetVersion,
|
||||
3,
|
||||
kInvalidOpsetVersion,
|
||||
|
|
@ -75,6 +75,7 @@ constexpr static std::array<int64_t, 16> kOpsetVersionToIRVersion = {
|
|||
7,
|
||||
7,
|
||||
7,
|
||||
8,
|
||||
8};
|
||||
|
||||
std::string getNodeStackTraceString(const Node* n) {
|
||||
|
|
|
|||
|
|
@ -197,7 +197,7 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
|
|||
|
||||
opset_version (int, default 13): The version of the
|
||||
`default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
|
||||
to target. Must be >= 7 and <= 15.
|
||||
to target. Must be >= 7 and <= 16.
|
||||
do_constant_folding (bool, default True): Apply the constant-folding optimization.
|
||||
Constant-folding will replace some of the ops that have all constant inputs
|
||||
with pre-computed constant nodes.
|
||||
|
|
|
|||
|
|
@ -1034,7 +1034,7 @@ def args_have_same_dtype(args):
|
|||
return has_same_dtype
|
||||
|
||||
_default_onnx_opset_version = 13
|
||||
_onnx_main_opset = 15
|
||||
_onnx_main_opset = 16
|
||||
_onnx_stable_opsets = list(range(7, _onnx_main_opset))
|
||||
_export_onnx_opset_version = _default_onnx_opset_version
|
||||
_constant_folding_opset_versions = list(range(9, _onnx_main_opset + 1))
|
||||
|
|
|
|||
46
torch/onnx/symbolic_opset16.py
Normal file
46
torch/onnx/symbolic_opset16.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
||||
|
||||
# This file exports ONNX ops for opset 16
|
||||
|
||||
# Note [ONNX Operators that are added/updated in opset 16]
|
||||
#
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
|
||||
# New operators:
|
||||
# GridSample https://github.com/onnx/onnx/pull/3557
|
||||
#
|
||||
# Updated operators:
|
||||
# Identity
|
||||
# If
|
||||
# LeakyRelu
|
||||
# Loop
|
||||
# PRelu
|
||||
# RoiAlign
|
||||
# Scan
|
||||
# ScatterElemenets
|
||||
# ScatterND
|
||||
# Where
|
||||
# GreaterOrEqual
|
||||
# LessOrEqual
|
||||
# SequenceMap
|
||||
|
||||
from torch.onnx.symbolic_helper import parse_args
|
||||
|
||||
from torch.nn.functional import GRID_SAMPLE_INTERPOLATION_MODES, GRID_SAMPLE_PADDING_MODES
|
||||
|
||||
|
||||
# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
|
||||
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
|
||||
@parse_args("v", "v", "i", "i", "b")
|
||||
def grid_sampler(g, input, grid, mode_enum, padding_mode_enum, align_corners):
|
||||
mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg]
|
||||
padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum] # type: ignore[call-arg]
|
||||
return g.op(
|
||||
"GridSample",
|
||||
input,
|
||||
grid,
|
||||
align_corners_i=int(align_corners),
|
||||
mode_s=mode_s,
|
||||
padding_mode_s=padding_mode_s,
|
||||
)
|
||||
Loading…
Reference in New Issue
Block a user