[ONNX] Refactor MaxPool to support dynamic inputs (#113318)
In https://github.com/pytorch/pytorch/pull/106270, the [`ceil_mode` corner issue](https://github.com/onnx/onnx/issues/5711) was worked around with `get_pool_ceil_padding`. However, computing the ceil padding on the converter side only works when the input shapes are known at export time, so that fix regressed models exported with dynamic input shapes. This PR (1) refactors the code to follow the torchlib implementation, (2) adds a dynamic-shapes test, and (3) disables the corner-case tests, with comments to re-enable them once the [real fix in ONNX](https://github.com/onnx/onnx/pull/5741) is merged.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113318
Approved by: https://github.com/thiagocrepaldi
commit e8e3afb784 (parent a4dc3716c0)
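For context, the regressed scenario looks like the sketch below: with dynamic spatial axes the exporter has no concrete input sizes, so converter-side ceil padding via `get_pool_ceil_padding` has nothing to compute the padding from. The model, shapes, and axis names here are illustrative, not taken from the PR:

```python
import io

import torch

# A ceil_mode pool exported with dynamic height/width axes.
model = torch.nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
x = torch.randn(1, 3, 64, 64)

f = io.BytesIO()
torch.onnx.export(
    model,
    x,
    f,
    opset_version=10,
    input_names=["input_0"],
    output_names=["output_0"],
    # Dynamic H and W mean the converter cannot precompute ceil-mode
    # padding from the input shape at export time.
    dynamic_axes={"input_0": {2: "h", 3: "w"}, "output_0": {2: "h", 3: "w"}},
)
```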
@@ -12,6 +12,11 @@ graph {
       i: 0
       type: INT
     }
+    attribute {
+      name: "dilations"
+      ints: 1
+      type: INTS
+    }
     attribute {
       name: "kernel_shape"
       ints: 3

@@ -13,6 +13,11 @@ graph {
       i: 0
       type: INT
     }
+    attribute {
+      name: "dilations"
+      ints: 1
+      type: INTS
+    }
     attribute {
       name: "kernel_shape"
       ints: 3

@@ -36,6 +41,11 @@ graph {
     output: "onnx::Slice_4"
     name: "MaxPool_4"
     op_type: "MaxPool"
+    attribute {
+      name: "dilations"
+      ints: 1
+      type: INTS
+    }
     attribute {
       name: "kernel_shape"
       ints: 1

@@ -56,7 +66,7 @@ graph {
       t {
         dims: 1
         data_type: 7
-        raw_data: "\002\000\000\000\000\000\000\000"
+        raw_data: "\001\000\000\000\000\000\000\000"
       }
       type: TENSOR
     }

@@ -84,7 +94,7 @@ graph {
       t {
         dims: 1
         data_type: 7
-        raw_data: "\001\000\000\000\000\000\000\000"
+        raw_data: "\002\000\000\000\000\000\000\000"
       }
       type: TENSOR
     }

@@ -92,8 +102,8 @@ graph {
   node {
     input: "onnx::Slice_4"
-    input: "onnx::Slice_6"
-    input: "onnx::Slice_7"
+    input: "onnx::Slice_5"
+    input: "onnx::Slice_7"
     output: "onnx::Sub_8"
     name: "Slice_8"
     op_type: "Slice"

@@ -139,6 +139,7 @@ class TestONNXOpset(pytorch_test_common.ExportTestCase):
                 "op_name": "MaxPool",
                 "attributes": [
                     {"name": "ceil_mode", "i": 0, "type": 2},
+                    {"name": "dilations", "ints": [1], "type": 7},
                     {"name": "kernel_shape", "ints": [2], "type": 7},
                     {"name": "pads", "ints": [0, 0], "type": 7},
                     {"name": "strides", "ints": [1], "type": 7},

@@ -1454,6 +1454,33 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         x = torch.randn(20, 16, 50, 44, 31)
         self.run_test(model, x)
 
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_maxpool_dynamic(self):
+        class test(torch.nn.Module):
+            def __init__(self, in_channels, out_channels):
+                super().__init__()
+                norm_layer = functools.partial(torch.nn.BatchNorm2d, eps=0.0009)
+                self.avgpool = torch.nn.MaxPool2d((2, 2), stride=2, ceil_mode=True)
+                self.conv = torch.nn.Conv2d(
+                    in_channels, out_channels, kernel_size=1, stride=1, bias=False
+                )
+                self.norm = norm_layer(out_channels)
+
+            def forward(self, x):
+                return self.norm(self.conv(self.avgpool(x)))
+
+        model = test(8, 16)
+        inputs = torch.randn(2, 8, 64, 64)
+        self.run_test(
+            model,
+            inputs,
+            input_names=["input_0"],
+            dynamic_axes={"input_0": {3: "x", 2: "y"}, "output_0": {3: "x", 2: "y"}},
+            output_names=["output_0"],
+        )
+
+    # TODO: Enable after https://github.com/onnx/onnx/pull/5741 or after ONNX 1.15.1+ is released
+    @skipIfUnsupportedMaxOpsetVersion(9)
     def test_maxpool_1d_ceil_corner(self):
         model = torch.nn.MaxPool1d(
             kernel_size=1, dilation=1, stride=2, ceil_mode=True, return_indices=False

@@ -1461,6 +1488,8 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         x = torch.randn(1, 3, 32)
         self.run_test(model, x)
 
+    # TODO: Enable after https://github.com/onnx/onnx/pull/5741 or after ONNX 1.15.1+ is released
+    @skipIfUnsupportedMaxOpsetVersion(9)
     def test_maxpool_2d_ceil_corner(self):
         model = torch.nn.MaxPool2d(
             kernel_size=[1, 1],

@@ -1472,6 +1501,8 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         x = torch.randn(1, 3, 32, 32)
         self.run_test(model, x)
 
+    # TODO: Enable after https://github.com/onnx/onnx/pull/5741 or after ONNX 1.15.1+ is released
+    @skipIfUnsupportedMaxOpsetVersion(9)
     def test_maxpool_3d_ceil_corner(self):
         model = torch.nn.MaxPool3d(
             kernel_size=[7, 8, 4],

@@ -1484,6 +1515,8 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         x = torch.randn(1, 3, 51, 52, 45)
         self.run_test(model, x)
 
+    # TODO: Enable after https://github.com/onnx/onnx/pull/5741 or after ONNX 1.15.1+ is released
+    @skipIfUnsupportedMaxOpsetVersion(9)
     @skipIfUnsupportedMinOpsetVersion(8)
     def test_maxpool_1d_ceil_corner_with_indices(self):
         model = torch.nn.MaxPool1d(

@@ -1492,6 +1525,8 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         x = torch.randn(1, 3, 32)
         self.run_test(model, x)
 
+    # TODO: Enable after https://github.com/onnx/onnx/pull/5741 or after ONNX 1.15.1+ is released
+    @skipIfUnsupportedMaxOpsetVersion(9)
     @skipIfUnsupportedMinOpsetVersion(8)
     def test_maxpool_2d_ceil_corner_with_indices(self):
         model = torch.nn.MaxPool2d(

@@ -1504,6 +1539,8 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         x = torch.randn(1, 3, 32, 32)
         self.run_test(model, x)
 
+    # TODO: Enable after https://github.com/onnx/onnx/pull/5741 or after ONNX 1.15.1+ is released
+    @skipIfUnsupportedMaxOpsetVersion(9)
     @skipIfUnsupportedMinOpsetVersion(8)
     def test_maxpool_3d_ceil_corner_with_indices(self):
         model = torch.nn.MaxPool3d(

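For reference, here is the arithmetic behind the disabled corner case, as a sketch: the numbers mirror `test_maxpool_1d_ceil_corner`, and the formulas are the standard ONNX and PyTorch output-shape rules rather than code taken from the PR.

```python
import math

# test_maxpool_1d_ceil_corner: kernel_size=1, stride=2, ceil_mode=True, L_in=32
L_in, k, s, p, d = 32, 1, 2, 0, 1

# ONNX MaxPool with ceil_mode=1 keeps every ceil-rounded window:
onnx_out = math.ceil((L_in + 2 * p - d * (k - 1) - 1) / s) + 1  # 17

# PyTorch additionally drops the last window if it would start past the
# end of the (padded) input: window 16 would start at 16 * 2 = 32 >= 32.
torch_out = onnx_out
if (torch_out - 1) * s >= L_in + p:
    torch_out -= 1  # 16
```

The off-by-one between 17 and 16 is exactly onnx/onnx#5711. A converter can only paper over it with extra padding when the concrete input length is known, which is why these tests stay disabled until the ONNX-side fix lands.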
@@ -3,7 +3,7 @@ from __future__ import annotations
 import functools
 import sys
 import warnings
-from typing import Callable, List, Optional, Sequence, Tuple, Union
+from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch._C._onnx as _C_onnx

@@ -135,36 +135,162 @@ def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
     )
 
 
+def _aten_max_pool_onnx(
+    g: jit_utils.GraphContext,
+    self: _C.Value,
+    kernel_shape: Sequence[int],
+    strides: Sequence[int],
+    pads: Sequence[int],
+    dilations: Sequence[int],
+    ceil_mode: bool,
+    unbatched_rank: int,
+) -> _C.Value:
+    self_rank = g.op("Size", g.op("Shape", self))
+    if self_rank == unbatched_rank:  # C,H,W -> N,C,H,W and N=1
+        self = g.op(
+            "Unsqueeze",
+            self,
+            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
+        )
+
+    pool_result, _ = g.op(
+        "MaxPool",
+        self,
+        outputs=2,
+        ceil_mode_i=ceil_mode,
+        dilations_i=dilations,
+        kernel_shape_i=kernel_shape,
+        pads_i=pads,
+        strides_i=strides,
+    )
+
+    if self_rank == unbatched_rank:
+        pool_result = g.op(
+            "Squeeze",
+            pool_result,
+            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
+        )
+
+    return pool_result
+
+
+# For MaxPool
+def _adjust_attributes_of_max_pool(
+    expand_size: int,
+    kernel_size: Union[Sequence[int], int],
+    stride: Union[Sequence[int], int],
+    padding: Union[Sequence[int], int],
+    dilation: Union[Sequence[int], int],
+) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]:
+    """Adjust attributes of avg_pool to match ONNX specification."""
+
+    if isinstance(dilation, int):
+        dilation = [dilation] * expand_size
+
+    if isinstance(kernel_size, int):
+        kernel_shape = [kernel_size] * expand_size
+    else:
+        kernel_shape = kernel_size  # type: ignore[assignment]
+
+    if isinstance(padding, int):
+        pads = [padding] * expand_size * 2  # type: ignore[operator, assignment]
+    elif len(padding) == 1:
+        pads = padding * expand_size * 2  # type: ignore[operator, assignment]
+    elif len(padding) == 2:
+        # 2D padding
+        pads = padding * 2  # type: ignore[operator, assignment]
+    elif len(padding) == 3:
+        # 3D padding
+        pads = padding * 2  # type: ignore[operator, assignment]
+    else:
+        # When padding is already done for all dimensions,
+        # we don't need to double it
+        # eg: (1, 1, 1, 1, 1, 1)
+        pads = padding  # type: ignore[assignment]
+
+    if isinstance(stride, int):
+        strides = [stride] * expand_size
+    elif not stride:
+        strides = kernel_shape
+    else:
+        strides = stride  # type: ignore[assignment]
+
+    return (kernel_shape, strides, pads, dilation)
+
+
+def _aten_max_pool_with_indices_onnx(
+    g: jit_utils.GraphContext,
+    self: _C.Value,
+    kernel_shape: Sequence[int],
+    strides: Sequence[int],
+    pads: Sequence[int],
+    dilations: Sequence[int],
+    ceil_mode: bool,
+    unbatched_rank: int,
+    n_dims_one: Sequence[int],
+    n_dims_zero: Sequence[int],
+    n_dims_axes: Sequence[int],
+) -> Tuple[_C.Value, Sequence[int]]:
+    self_rank = g.op("Size", g.op("Shape", self))
+    if self_rank == unbatched_rank:  # C,H,W -> N,C,H,W and N=1
+        self = g.op(
+            "Unsqueeze",
+            self,
+            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
+        )
+
+    pool_result, indices = g.op(
+        "MaxPool",
+        self,
+        outputs=2,
+        ceil_mode_i=ceil_mode,
+        dilations_i=dilations,
+        kernel_shape_i=kernel_shape,
+        pads_i=pads,
+        strides_i=strides,
+    )
+    _, flatten_indices = g.op(
+        "MaxPool",
+        self,
+        outputs=2,
+        dilations_i=dilations,
+        kernel_shape_i=n_dims_one,
+        strides_i=n_dims_one,
+    )
+
+    ends = g.op("Constant", value_t=torch.tensor(n_dims_one))
+    starts = g.op("Constant", value_t=torch.tensor(n_dims_zero))
+    axes = g.op("Constant", value_t=torch.tensor(n_dims_axes))
+
+    delta = g.op("Slice", flatten_indices, starts, ends, axes)
+    indices = g.op("Sub", indices, delta)
+
+    if self_rank == unbatched_rank:
+        pool_result = g.op(
+            "Squeeze", pool_result, value_t=torch.tensor([0], dtype=torch.int64)
+        )
+        indices = g.op("Squeeze", indices, value_t=torch.tensor([0], dtype=torch.int64))
+
+    return (pool_result, indices)
+
+
 @_onnx_symbolic(
     "aten::max_pool1d",
-    decorate=[
-        _apply_params(
-            "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
-        )
-    ],
+    decorate=[_apply_params("max_pool1d", 1, return_indices=False)],
 )
 @_onnx_symbolic(
     "aten::max_pool2d",
-    decorate=[
-        _apply_params(
-            "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
-        )
-    ],
+    decorate=[_apply_params("max_pool2d", 2, return_indices=False)],
 )
 @_onnx_symbolic(
     "aten::max_pool3d",
-    decorate=[
-        _apply_params(
-            "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
-        )
-    ],
+    decorate=[_apply_params("max_pool3d", 3, return_indices=False)],
 )
 @_onnx_symbolic(
     "aten::max_pool1d_with_indices",
     decorate=[
         _apply_params(
             "max_pool1d_with_indices",
-            torch.nn.modules.utils._single,
             1,
             return_indices=True,
         )

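A quick sketch of what the new `_adjust_attributes_of_max_pool` helper returns for plain-int arguments (`expand_size=2`, i.e. the `max_pool2d` case); the expected values follow directly from the branches shown above:

```python
kernel, strides, pads, dilations = _adjust_attributes_of_max_pool(
    expand_size=2, kernel_size=3, stride=2, padding=1, dilation=1
)
assert kernel == [3, 3]      # int kernel expanded once per spatial dim
assert strides == [2, 2]     # int stride expanded once per spatial dim
assert pads == [1, 1, 1, 1]  # ONNX expects begin and end pads per dim
assert dilations == [1, 1]

# An empty stride falls back to the kernel shape, matching PyTorch's
# default of stride = kernel_size:
_, strides, _, _ = _adjust_attributes_of_max_pool(2, 3, [], 0, 1)
assert strides == [3, 3]
```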
@@ -175,7 +301,6 @@ def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
     decorate=[
         _apply_params(
             "max_pool2d_with_indices",
-            torch.nn.modules.utils._pair,
             2,
             return_indices=True,
         )

@@ -186,70 +311,53 @@ def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
     decorate=[
         _apply_params(
             "max_pool3d_with_indices",
-            torch.nn.modules.utils._triple,
             3,
             return_indices=True,
         )
     ],
 )
 @_beartype.beartype
-def _max_pool(name: str, tuple_fn: Callable, ndims: int, return_indices: bool):
+def _max_pool(name: str, expand_size: int, return_indices: bool):
     @symbolic_helper.quantized_args(True, False, False, False, False, False)
     @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
-    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
-        if not stride:
-            stride = kernel_size
-        padding = tuple(tuple_fn(padding))
-        if ceil_mode:
-            padding_ceil = opset9.get_pool_ceil_padding(
-                input, kernel_size, stride, padding
-            )
-            padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding))
-        else:
-            padding = padding * 2
-        kwargs = {
-            "kernel_shape_i": tuple_fn(kernel_size),
-            "pads_i": padding,
-            "strides_i": tuple_fn(stride),
-            "ceil_mode_i": 0,
-        }
-        if set(tuple_fn(dilation)) != {1}:
-            kwargs["dilations_i"] = tuple_fn(dilation)
-        # easy but hacky way to get flattened indices values
-        # to be used to convert the indices values to non-flattened.
-        # In ONNX the indices are computed as a flatten 1-D tensor,
-        # so the values in indices are in [0, N x C x D1 x ... x Dn).
-        # To convert the indices to the same format used by Pytorch,
-        # we first execute a maxpool with a kernel and stride of 1 on the same input.
-        # This will result in a tensor of indices in which each index will have it's own value.
-        # Using this tensor as a reference, we extract the first index of each axis and subtract
-        # it from each index of this axis in the indices to convert.
-        # This step will result in a tensor were each dimension has values of indices within
-        # the dimension it is in.
-        # For more information :
-        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
+    def symbolic_fn(
+        g: jit_utils.GraphContext,
+        input: _C.Value,
+        kernel_size: Sequence[int],
+        stride: Sequence[int],
+        padding: Union[int, Sequence[int]],
+        dilation: Sequence[int],
+        ceil_mode: bool,
+    ):
+        kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
+            expand_size, kernel_size, stride, padding, dilation
+        )
 
         if return_indices:
-            r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
-            _, flattened_indices = g.op(
-                "MaxPool",
-                input,
-                outputs=2,
-                kernel_shape_i=[1 for _ in range(ndims)],
-                strides_i=[1 for _ in range(ndims)],
-            )
-            # convert indices to have non-flattened indices values
-            s = symbolic_helper._slice_helper(
-                g,
-                flattened_indices,
-                axes=[2 + i for i in range(ndims)],
-                starts=list(tuple_fn(0)),
-                ends=list(tuple_fn(1)),
-            )
-            indices = opset9.sub(g, indices, s)
-            return r, indices
+            return _aten_max_pool_with_indices_onnx(
+                g,
+                input,
+                kernel_shape,
+                strides,
+                pads,
+                dilations,
+                ceil_mode,
+                expand_size + 1,
+                ([1] * expand_size),
+                ([0] * expand_size),
+                ([2 + i for i in range(expand_size)]),
+            )
         else:
-            r = g.op("MaxPool", input, outputs=1, **kwargs)
-            return r
+            return _aten_max_pool_onnx(
+                g,
+                input,
+                kernel_shape,
+                strides,
+                pads,
+                dilations,
+                ceil_mode,
+                expand_size + 1,
+            )
 
     return symbolic_fn
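To make the Slice/Sub trick in `_aten_max_pool_with_indices_onnx` concrete, here is an eager-mode sketch with plain tensors standing in for graph values (the shapes are illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4, 4)
N, C, H, W = x.shape

# PyTorch reports max-pool indices relative to each (n, c) plane.
_, local_idx = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)

# ONNX MaxPool reports indices flattened over the whole N*C*H*W tensor;
# simulate that by adding each plane's starting offset.
plane_offset = (torch.arange(N * C) * H * W).reshape(N, C, 1, 1)
onnx_idx = local_idx + plane_offset

# In the exported graph those offsets are not known statically. They are
# recovered from a second MaxPool with kernel 1 and stride 1, whose
# ONNX-style indices simply enumerate every element in order:
enum = torch.arange(N * C * H * W).reshape(N, C, H, W)

# Slicing out the first spatial entry of each plane (the Slice node in the
# converter) yields plane_offset, and Sub restores per-plane indices.
delta = enum[:, :, :1, :1]
assert torch.equal(onnx_idx - delta, local_idx)
```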