[torch.onnx] support torch.nn.functional.grid_sample

summary

- Adds `F.grid_sample` support
- Adds a test case

Fixes #27212
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76159
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
This commit is contained in:
Masaki Kozuki 2022-05-02 22:07:56 +00:00 committed by PyTorch MergeBot
parent e14f5336c0
commit 0ae3aa648e
8 changed files with 109 additions and 4 deletions

View File

@ -81,7 +81,7 @@ fi
if [[ "${SHARD_NUMBER}" == "2" ]]; then
# Update the loop for new opsets
for i in $(seq 10 15); do
for i in $(seq 10 16); do
pytest "${args[@]}" \
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset$i"
done

View File

@ -9,6 +9,7 @@ from torch.nn import Module
import onnx
import io
import itertools
from torch.onnx.symbolic_helper import _export_onnx_opset_version
from torch.onnx import producer_name, producer_version
@ -369,6 +370,30 @@ class TestONNXOpset(TestCase):
x = torch.randn(20, 16, 50)
check_onnx_opsets_operator(MyDynamicModel(), x, ops, opset_versions=[9, 10])
def test_grid_sample(self):
n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4
ops = {16: [{"op_name": "GridSample"}]}
class MyModule(Module):
def forward(self, x, grid, mode, padding_mode, align_corers):
return torch.nn.functional.grid_sample(x, grid, mode, padding_mode, align_corners)
for mode, padding_mode, align_corners in itertools.product(
("bilinear", "nearest", "bicubic"),
("zeros", "border", "reflection"),
(True, False),
):
args = (
torch.randn(n, c, h_in, w_in), # x
torch.randn(n, h_out, w_out, 2), # grid,
mode,
padding_mode,
align_corners,
)
check_onnx_opsets_operator(MyModule(), args, ops, opset_versions=[16], training=torch.onnx.TrainingMode.TRAINING)
check_onnx_opsets_operator(MyModule(), args, ops, opset_versions=[16], training=torch.onnx.TrainingMode.EVAL)
if __name__ == "__main__":
run_tests()

View File

@ -9183,6 +9183,7 @@ class _TestONNXRuntime:
dynamic_axes={"size": [0, 1]},
test_with_inputs=[(boxes, size), (boxes, size_2)])
@skipIfUnsupportedMaxOpsetVersion(15) # TODO: Opset 16 RoiAlign result mismatch
@skipIfUnsupportedMinOpsetVersion(11)
def test_roi_align(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
@ -9190,6 +9191,7 @@ class _TestONNXRuntime:
model = ops.RoIAlign((5, 5), 1., 2)
self.run_test(model, (x, single_roi))
@skipIfUnsupportedMaxOpsetVersion(15) # TODO: Opset 16 RoiAlign result mismatch
@skipIfUnsupportedMinOpsetVersion(11)
def test_roi_align_aligned(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
@ -9295,6 +9297,7 @@ class _TestONNXRuntime:
test_with_inputs=[(images, features), (images2, test_features)],
dict_check=False)
@skipIfUnsupportedMaxOpsetVersion(15) # TODO: Opset 16 RoiAlign result mismatch
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_multi_scale_roi_align(self):
@ -10986,6 +10989,33 @@ class _TestONNXRuntime:
self.run_test(Module(False), x, rtol=1e-3, atol=1e-6)
self.run_test(Module(True), x, rtol=1e-3, atol=1e-6)
@skipIfUnsupportedMinOpsetVersion(16)
def test_grid_sample(self):
n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4
class GridSampleModule(torch.nn.Module):
def __init__(self, mode, padding_mode, align_corners) -> None:
super().__init__()
self.mode, self.padding_mode, self.align_corners = mode, padding_mode, align_corners
def forward(self, input, grid):
return torch.nn.functional.grid_sample(input, grid, self.mode, self.padding_mode, self.align_corners)
for mode, padding_mode, align_corners in itertools.product(
("bilinear", "nearest", "bicubic"),
("zeros", "border", "reflection"),
(True, False),
):
atol_rtol = {}
if (mode, padding_mode) == ("bicubic", "border"):
if align_corners:
atol_rtol.update({"atol": 0.3, "rtol": 0.4})
else:
atol_rtol.update({"atol": 0.02, "rtol": 0.02})
input, grid = torch.randn(n, c, h_in, w_in), torch.randn(n, h_out, w_out, 2)
self.run_test(GridSampleModule(mode, padding_mode, align_corners), (input, grid), **atol_rtol)
def make_test(name, base, layer, bidirectional, initial_state,
variable_length, dropout, script_test_min_opset_version,
@ -11123,6 +11153,8 @@ TestONNXRuntime_opset14 = MakeTestCase(14, keep_initializers_as_inputs=False)
TestONNXRuntime_opset15 = MakeTestCase(15, keep_initializers_as_inputs=False)
TestONNXRuntime_opset16 = MakeTestCase(16, keep_initializers_as_inputs=False)
if __name__ == "__main__":
unittest.main()

View File

@ -16,6 +16,7 @@ static const int OPSET_VERSION_12 = 12;
static const int OPSET_VERSION_13 = 13;
static const int OPSET_VERSION_14 = 14;
static const int OPSET_VERSION_15 = 15;
static const int OPSET_VERSION_16 = 16;
using ValueToParamPairMap = std::map<Value*, std::pair<std::string, IValue>>;

View File

@ -59,7 +59,7 @@ namespace onnx = ::ONNX_NAMESPACE;
const static int kInvalidOpsetVersion = -1;
// Based on OP_SET_ID_VERSION_MAP in
// https://github.com/onnx/onnx/blob/master/onnx/helper.py.
constexpr static std::array<int64_t, 16> kOpsetVersionToIRVersion = {
constexpr static std::array<int64_t, 17> kOpsetVersionToIRVersion = {
kInvalidOpsetVersion,
3,
kInvalidOpsetVersion,
@ -75,6 +75,7 @@ constexpr static std::array<int64_t, 16> kOpsetVersionToIRVersion = {
7,
7,
7,
8,
8};
std::string getNodeStackTraceString(const Node* n) {

View File

@ -197,7 +197,7 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
opset_version (int, default 13): The version of the
`default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
to target. Must be >= 7 and <= 15.
to target. Must be >= 7 and <= 16.
do_constant_folding (bool, default True): Apply the constant-folding optimization.
Constant-folding will replace some of the ops that have all constant inputs
with pre-computed constant nodes.

View File

@ -1034,7 +1034,7 @@ def args_have_same_dtype(args):
return has_same_dtype
_default_onnx_opset_version = 13
_onnx_main_opset = 15
_onnx_main_opset = 16
_onnx_stable_opsets = list(range(7, _onnx_main_opset))
_export_onnx_opset_version = _default_onnx_opset_version
_constant_folding_opset_versions = list(range(9, _onnx_main_opset + 1))

View File

@ -0,0 +1,46 @@
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
# This file exports ONNX ops for opset 16
# Note [ONNX Operators that are added/updated in opset 16]
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
# New operators:
# GridSample https://github.com/onnx/onnx/pull/3557
#
# Updated operators:
# Identity
# If
# LeakyRelu
# Loop
# PRelu
# RoiAlign
# Scan
# ScatterElemenets
# ScatterND
# Where
# GreaterOrEqual
# LessOrEqual
# SequenceMap
from torch.onnx.symbolic_helper import parse_args
from torch.nn.functional import GRID_SAMPLE_INTERPOLATION_MODES, GRID_SAMPLE_PADDING_MODES
# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@parse_args("v", "v", "i", "i", "b")
def grid_sampler(g, input, grid, mode_enum, padding_mode_enum, align_corners):
mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg]
padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum] # type: ignore[call-arg]
return g.op(
"GridSample",
input,
grid,
align_corners_i=int(align_corners),
mode_s=mode_s,
padding_mode_s=padding_mode_s,
)