[ONNX] Refactor MaxPool to support dynamic inputs (#113318)

In https://github.com/pytorch/pytorch/pull/106270, the [`ceil_mode` corner issue](https://github.com/onnx/onnx/issues/5711) was solved by applying `get_pool_ceil_padding`. However, computing the ceil padding on the converter side only works when the input shapes are known at export time, so that change regressed exports that use dynamic input shapes.
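
A minimal sketch of the corner case, reusing the shapes from the 1-D corner test below (the pre-fix ONNX behavior is as reported in the linked issue):

```python
import torch

# With ceil_mode=True, PyTorch drops a trailing window that would start
# beyond the end of the (unpadded) input, so this pool yields 16 outputs;
# the ONNX output-shape formula before onnx/onnx#5741 allows that window
# and yields 17, hence the shape mismatch.
pool = torch.nn.MaxPool1d(kernel_size=1, stride=2, ceil_mode=True)
x = torch.randn(1, 3, 32)
print(pool(x).shape)  # torch.Size([1, 3, 16])
```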

This PR (1) refactors the code to follow the torchlib implementation, (2) adds dynamic shape tests, and (3) disables the corner-case tests, with comments saying to re-enable them once the [real fix from ONNX](https://github.com/onnx/onnx/pull/5741) is merged.
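
For reference, a minimal sketch of the kind of export that regressed (the module, axis names, and file name are made up for illustration):

```python
import torch

class Model(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.max_pool2d(
            x, kernel_size=2, stride=2, ceil_mode=True
        )

# With dynamic spatial axes, H and W are unknown at export time, so the
# converter cannot compute a concrete ceil padding via get_pool_ceil_padding.
torch.onnx.export(
    Model(),
    torch.randn(1, 3, 64, 64),
    "maxpool.onnx",
    input_names=["x"],
    output_names=["y"],
    dynamic_axes={"x": {2: "h", 3: "w"}, "y": {2: "h_out", 3: "w_out"}},
)
```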
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113318
Approved by: https://github.com/thiagocrepaldi
AllenTiTaiWang 2023-11-09 18:57:33 +00:00 committed by PyTorch MergeBot
parent a4dc3716c0
commit e8e3afb784
5 changed files with 234 additions and 73 deletions

View File

@@ -12,6 +12,11 @@ graph {
       i: 0
       type: INT
     }
+    attribute {
+      name: "dilations"
+      ints: 1
+      type: INTS
+    }
     attribute {
       name: "kernel_shape"
       ints: 3
View File

@@ -13,6 +13,11 @@ graph {
       i: 0
       type: INT
     }
+    attribute {
+      name: "dilations"
+      ints: 1
+      type: INTS
+    }
     attribute {
       name: "kernel_shape"
       ints: 3
@@ -36,6 +41,11 @@ graph {
     output: "onnx::Slice_4"
     name: "MaxPool_4"
     op_type: "MaxPool"
+    attribute {
+      name: "dilations"
+      ints: 1
+      type: INTS
+    }
     attribute {
       name: "kernel_shape"
       ints: 1
@@ -56,7 +66,7 @@ graph {
       t {
         dims: 1
         data_type: 7
-        raw_data: "\002\000\000\000\000\000\000\000"
+        raw_data: "\001\000\000\000\000\000\000\000"
       }
       type: TENSOR
     }
@@ -84,7 +94,7 @@ graph {
       t {
         dims: 1
         data_type: 7
-        raw_data: "\001\000\000\000\000\000\000\000"
+        raw_data: "\002\000\000\000\000\000\000\000"
      }
      type: TENSOR
    }
@@ -92,8 +102,8 @@ graph {
   node {
     input: "onnx::Slice_4"
     input: "onnx::Slice_6"
-    input: "onnx::Slice_7"
     input: "onnx::Slice_5"
+    input: "onnx::Slice_7"
     output: "onnx::Sub_8"
     name: "Slice_8"
     op_type: "Slice"

View File

@@ -139,6 +139,7 @@ class TestONNXOpset(pytorch_test_common.ExportTestCase):
                 "op_name": "MaxPool",
                 "attributes": [
                     {"name": "ceil_mode", "i": 0, "type": 2},
+                    {"name": "dilations", "ints": [1], "type": 7},
                     {"name": "kernel_shape", "ints": [2], "type": 7},
                     {"name": "pads", "ints": [0, 0], "type": 7},
                     {"name": "strides", "ints": [1], "type": 7},

View File

@@ -1454,6 +1454,33 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         x = torch.randn(20, 16, 50, 44, 31)
         self.run_test(model, x)
 
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_maxpool_dynamic(self):
+        class test(torch.nn.Module):
+            def __init__(self, in_channels, out_channels):
+                super().__init__()
+                norm_layer = functools.partial(torch.nn.BatchNorm2d, eps=0.0009)
+                self.avgpool = torch.nn.MaxPool2d((2, 2), stride=2, ceil_mode=True)
+                self.conv = torch.nn.Conv2d(
+                    in_channels, out_channels, kernel_size=1, stride=1, bias=False
+                )
+                self.norm = norm_layer(out_channels)
+
+            def forward(self, x):
+                return self.norm(self.conv(self.avgpool(x)))
+
+        model = test(8, 16)
+        inputs = torch.randn(2, 8, 64, 64)
+        self.run_test(
+            model,
+            inputs,
+            input_names=["input_0"],
+            dynamic_axes={"input_0": {3: "x", 2: "y"}, "output_0": {3: "x", 2: "y"}},
+            output_names=["output_0"],
+        )
+
+    # TODO: Enable after https://github.com/onnx/onnx/pull/5741 or after ONNX 1.15.1+ is released
+    @skipIfUnsupportedMaxOpsetVersion(9)
     def test_maxpool_1d_ceil_corner(self):
         model = torch.nn.MaxPool1d(
             kernel_size=1, dilation=1, stride=2, ceil_mode=True, return_indices=False
@@ -1461,6 +1488,8 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         x = torch.randn(1, 3, 32)
         self.run_test(model, x)
 
+    # TODO: Enable after https://github.com/onnx/onnx/pull/5741 or after ONNX 1.15.1+ is released
+    @skipIfUnsupportedMaxOpsetVersion(9)
     def test_maxpool_2d_ceil_corner(self):
         model = torch.nn.MaxPool2d(
             kernel_size=[1, 1],
@@ -1472,6 +1501,8 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         x = torch.randn(1, 3, 32, 32)
         self.run_test(model, x)
 
+    # TODO: Enable after https://github.com/onnx/onnx/pull/5741 or after ONNX 1.15.1+ is released
+    @skipIfUnsupportedMaxOpsetVersion(9)
     def test_maxpool_3d_ceil_corner(self):
         model = torch.nn.MaxPool3d(
             kernel_size=[7, 8, 4],
@@ -1484,6 +1515,8 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         x = torch.randn(1, 3, 51, 52, 45)
         self.run_test(model, x)
 
+    # TODO: Enable after https://github.com/onnx/onnx/pull/5741 or after ONNX 1.15.1+ is released
+    @skipIfUnsupportedMaxOpsetVersion(9)
     @skipIfUnsupportedMinOpsetVersion(8)
     def test_maxpool_1d_ceil_corner_with_indices(self):
         model = torch.nn.MaxPool1d(
@@ -1492,6 +1525,8 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         x = torch.randn(1, 3, 32)
         self.run_test(model, x)
 
+    # TODO: Enable after https://github.com/onnx/onnx/pull/5741 or after ONNX 1.15.1+ is released
+    @skipIfUnsupportedMaxOpsetVersion(9)
     @skipIfUnsupportedMinOpsetVersion(8)
     def test_maxpool_2d_ceil_corner_with_indices(self):
         model = torch.nn.MaxPool2d(
@@ -1504,6 +1539,8 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         x = torch.randn(1, 3, 32, 32)
         self.run_test(model, x)
 
+    # TODO: Enable after https://github.com/onnx/onnx/pull/5741 or after ONNX 1.15.1+ is released
+    @skipIfUnsupportedMaxOpsetVersion(9)
     @skipIfUnsupportedMinOpsetVersion(8)
     def test_maxpool_3d_ceil_corner_with_indices(self):
         model = torch.nn.MaxPool3d(

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
 import functools
 import sys
 import warnings
-from typing import Callable, List, Optional, Sequence, Tuple, Union
+from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch._C._onnx as _C_onnx
@@ -135,36 +135,162 @@ def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
     )
 
 
+def _aten_max_pool_onnx(
+    g: jit_utils.GraphContext,
+    self: _C.Value,
+    kernel_shape: Sequence[int],
+    strides: Sequence[int],
+    pads: Sequence[int],
+    dilations: Sequence[int],
+    ceil_mode: bool,
+    unbatched_rank: int,
+) -> _C.Value:
+    self_rank = g.op("Size", g.op("Shape", self))
+    if self_rank == unbatched_rank:  # C,H,W -> N,C,H,W and N=1
+        self = g.op(
+            "Unsqueeze",
+            self,
+            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
+        )
+
+    pool_result, _ = g.op(
+        "MaxPool",
+        self,
+        outputs=2,
+        ceil_mode_i=ceil_mode,
+        dilations_i=dilations,
+        kernel_shape_i=kernel_shape,
+        pads_i=pads,
+        strides_i=strides,
+    )
+
+    if self_rank == unbatched_rank:
+        pool_result = g.op(
+            "Squeeze",
+            pool_result,
+            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
+        )
+
+    return pool_result
+
+
+# For MaxPool
+def _adjust_attributes_of_max_pool(
+    expand_size: int,
+    kernel_size: Union[Sequence[int], int],
+    stride: Union[Sequence[int], int],
+    padding: Union[Sequence[int], int],
+    dilation: Union[Sequence[int], int],
+) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]:
+    """Adjust attributes of avg_pool to match ONNX specification."""
+
+    if isinstance(dilation, int):
+        dilation = [dilation] * expand_size
+
+    if isinstance(kernel_size, int):
+        kernel_shape = [kernel_size] * expand_size
+    else:
+        kernel_shape = kernel_size  # type: ignore[assignment]
+
+    if isinstance(padding, int):
+        pads = [padding] * expand_size * 2  # type: ignore[operator, assignment]
+    elif len(padding) == 1:
+        pads = padding * expand_size * 2  # type: ignore[operator, assignment]
+    elif len(padding) == 2:
+        # 2D padding
+        pads = padding * 2  # type: ignore[operator, assignment]
+    elif len(padding) == 3:
+        # 3D padding
+        pads = padding * 2  # type: ignore[operator, assignment]
+    else:
+        # When padding is already done for all dimensions,
+        # we don't need to double it
+        # eg: (1, 1, 1, 1, 1, 1)
+        pads = padding  # type: ignore[assignment]
+
+    if isinstance(stride, int):
+        strides = [stride] * expand_size
+    elif not stride:
+        strides = kernel_shape
+    else:
+        strides = stride  # type: ignore[assignment]
+
+    return (kernel_shape, strides, pads, dilation)
+
+
+def _aten_max_pool_with_indices_onnx(
+    g: jit_utils.GraphContext,
+    self: _C.Value,
+    kernel_shape: Sequence[int],
+    strides: Sequence[int],
+    pads: Sequence[int],
+    dilations: Sequence[int],
+    ceil_mode: bool,
+    unbatched_rank: int,
+    n_dims_one: Sequence[int],
+    n_dims_zero: Sequence[int],
+    n_dims_axes: Sequence[int],
+) -> Tuple[_C.Value, Sequence[int]]:
+    self_rank = g.op("Size", g.op("Shape", self))
+    if self_rank == unbatched_rank:  # C,H,W -> N,C,H,W and N=1
+        self = g.op(
+            "Unsqueeze",
+            self,
+            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
+        )
+
+    pool_result, indices = g.op(
+        "MaxPool",
+        self,
+        outputs=2,
+        ceil_mode_i=ceil_mode,
+        dilations_i=dilations,
+        kernel_shape_i=kernel_shape,
+        pads_i=pads,
+        strides_i=strides,
+    )
+    _, flatten_indices = g.op(
+        "MaxPool",
+        self,
+        outputs=2,
+        dilations_i=dilations,
+        kernel_shape_i=n_dims_one,
+        strides_i=n_dims_one,
+    )
+
+    ends = g.op("Constant", value_t=torch.tensor(n_dims_one))
+    starts = g.op("Constant", value_t=torch.tensor(n_dims_zero))
+    axes = g.op("Constant", value_t=torch.tensor(n_dims_axes))
+
+    delta = g.op("Slice", flatten_indices, starts, ends, axes)
+    indices = g.op("Sub", indices, delta)
+
+    if self_rank == unbatched_rank:
+        pool_result = g.op(
+            "Squeeze", pool_result, value_t=torch.tensor([0], dtype=torch.int64)
+        )
+        indices = g.op("Squeeze", indices, value_t=torch.tensor([0], dtype=torch.int64))
+
+    return (pool_result, indices)
+
+
 @_onnx_symbolic(
     "aten::max_pool1d",
-    decorate=[
-        _apply_params(
-            "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
-        )
-    ],
+    decorate=[_apply_params("max_pool1d", 1, return_indices=False)],
 )
 @_onnx_symbolic(
     "aten::max_pool2d",
-    decorate=[
-        _apply_params(
-            "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
-        )
-    ],
+    decorate=[_apply_params("max_pool2d", 2, return_indices=False)],
 )
 @_onnx_symbolic(
     "aten::max_pool3d",
-    decorate=[
-        _apply_params(
-            "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
-        )
-    ],
+    decorate=[_apply_params("max_pool3d", 3, return_indices=False)],
 )
 @_onnx_symbolic(
     "aten::max_pool1d_with_indices",
     decorate=[
         _apply_params(
             "max_pool1d_with_indices",
-            torch.nn.modules.utils._single,
             1,
             return_indices=True,
         )
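
Illustration (not part of the diff): how `_adjust_attributes_of_max_pool` above expands scalar arguments for a 2-D pool; the import assumes the helper stays module-level in `torch.onnx.symbolic_opset10` after this change.

from torch.onnx.symbolic_opset10 import _adjust_attributes_of_max_pool

# Scalar kernel_size=3, stride=2, padding=1, dilation=1 expanded over 2 spatial
# dims; pads are doubled into ONNX-style (begin, end) pairs.
kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(2, 3, 2, 1, 1)
print(kernel_shape, strides, pads, dilations)  # [3, 3] [2, 2] [1, 1, 1, 1] [1, 1]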
@@ -175,7 +301,6 @@ def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
     decorate=[
         _apply_params(
             "max_pool2d_with_indices",
-            torch.nn.modules.utils._pair,
             2,
             return_indices=True,
         )
@@ -186,70 +311,53 @@ def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
     decorate=[
         _apply_params(
             "max_pool3d_with_indices",
-            torch.nn.modules.utils._triple,
             3,
             return_indices=True,
         )
     ],
 )
 @_beartype.beartype
-def _max_pool(name: str, tuple_fn: Callable, ndims: int, return_indices: bool):
+def _max_pool(name: str, expand_size: int, return_indices: bool):
     @symbolic_helper.quantized_args(True, False, False, False, False, False)
     @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
-    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
-        if not stride:
-            stride = kernel_size
-        padding = tuple(tuple_fn(padding))
-        if ceil_mode:
-            padding_ceil = opset9.get_pool_ceil_padding(
-                input, kernel_size, stride, padding
-            )
-            padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding))
-        else:
-            padding = padding * 2
-        kwargs = {
-            "kernel_shape_i": tuple_fn(kernel_size),
-            "pads_i": padding,
-            "strides_i": tuple_fn(stride),
-            "ceil_mode_i": 0,
-        }
-        if set(tuple_fn(dilation)) != {1}:
-            kwargs["dilations_i"] = tuple_fn(dilation)
-        # easy but hacky way to get flattened indices values
-        # to be used to convert the indices values to non-flattened.
-        # In ONNX the indices are computed as a flatten 1-D tensor,
-        # so the values in indices are in [0, N x C x D1 x ... x Dn).
-        # To convert the indices to the same format used by Pytorch,
-        # we first execute a maxpool with a kernel and stride of 1 on the same input.
-        # This will result in a tensor of indices in which each index will have it's own value.
-        # Using this tensor as a reference, we extract the first index of each axis and subtract
-        # it from each index of this axis in the indices to convert.
-        # This step will result in a tensor were each dimension has values of indices within
-        # the dimension it is in.
-        # For more information :
-        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
+    def symbolic_fn(
+        g: jit_utils.GraphContext,
+        input: _C.Value,
+        kernel_size: Sequence[int],
+        stride: Sequence[int],
+        padding: Union[int, Sequence[int]],
+        dilation: Sequence[int],
+        ceil_mode: bool,
+    ):
+        kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
+            expand_size, kernel_size, stride, padding, dilation
+        )
+
         if return_indices:
-            r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
-            _, flattened_indices = g.op(
-                "MaxPool",
-                input,
-                outputs=2,
-                kernel_shape_i=[1 for _ in range(ndims)],
-                strides_i=[1 for _ in range(ndims)],
-            )
-            # convert indices to have non-flattened indices values
-            s = symbolic_helper._slice_helper(
+            return _aten_max_pool_with_indices_onnx(
                 g,
-                flattened_indices,
-                axes=[2 + i for i in range(ndims)],
-                starts=list(tuple_fn(0)),
-                ends=list(tuple_fn(1)),
+                input,
+                kernel_shape,
+                strides,
+                pads,
+                dilations,
+                ceil_mode,
+                expand_size + 1,
+                ([1] * expand_size),
+                ([0] * expand_size),
+                ([2 + i for i in range(expand_size)]),
             )
-            indices = opset9.sub(g, indices, s)
-            return r, indices
         else:
-            r = g.op("MaxPool", input, outputs=1, **kwargs)
-            return r
+            return _aten_max_pool_onnx(
+                g,
+                input,
+                kernel_shape,
+                strides,
+                pads,
+                dilations,
+                ceil_mode,
+                expand_size + 1,
+            )
 
     return symbolic_fn
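
Illustration (not part of the diff): why `_aten_max_pool_with_indices_onnx` runs a second MaxPool with kernel 1 and subtracts a sliced offset. ONNX flattens indices over N x C x D1 x ... x Dn, while PyTorch flattens them per (N, C) plane:

import torch

x = torch.arange(8.0).reshape(1, 2, 2, 2)  # N=1, C=2, H=W=2
_, idx = torch.nn.functional.max_pool2d(x, kernel_size=2, return_indices=True)
print(idx.flatten().tolist())  # [3, 3]: PyTorch's per-plane indices
# ONNX reports the same maxima as [3, 7]; subtracting each plane's first flat
# index (0 and 4, recovered via the kernel-1 MaxPool plus Slice) gives [3, 3].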