[ONNX] Add Squeeze/Unsqueeze dynamic dimensions support when opset >= 13 (#71158)

* Add Squeeze/Unsqueeze dynamic axes support when opset >= 13

Co-authored-by: hwangdeyu <dejack953outlook.com>
Co-authored-by: Gary Miguel <garymmgarymm.org>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73104
This commit is contained in:
BowenBao 2022-02-22 14:55:19 -08:00 committed by PyTorch MergeBot
parent 80291dff43
commit bd4902d81f
3 changed files with 55 additions and 9 deletions

View File

@ -1525,6 +1525,16 @@ class TestONNXRuntime(unittest.TestCase):
x = torch.randn(2, 1, 4)
self.run_test(Squeeze(), x)
@skipIfUnsupportedMinOpsetVersion(13)
def test_squeeze_dynamic_dim(self):
class Squeeze(torch.nn.Module):
def forward(self, x, dim: int):
return torch.squeeze(x, dim)
x = torch.randn(2, 1, 4)
dim = 1
self.run_test(Squeeze(), (x, dim))
def test_unsqueeze(self):
class Unsqueeze(torch.nn.Module):
def forward(self, x):
@ -1533,6 +1543,16 @@ class TestONNXRuntime(unittest.TestCase):
x = torch.randn(2, 3, 4)
self.run_test(Unsqueeze(), x)
@skipIfUnsupportedMinOpsetVersion(13)
def test_unsqueeze_dynamic_dim(self):
class Unsqueeze(torch.nn.Module):
def forward(self, x, dim: int):
return torch.unsqueeze(x, dim)
x = torch.randn(2, 1, 4)
dim = -1
self.run_test(Unsqueeze(), (x, dim))
def test_maxpool_default_stride(self):
class MaxPoolModel(torch.nn.Module):
def forward(self, x):

View File

@ -476,19 +476,38 @@ def _interpolate_warning(interpolate_mode):
"to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).\n"
"We recommend using opset 11 and above for models using this operator.")
def _unsqueeze_helper(g, input, axes_i):
if _is_constant(axes_i[0]):
if _export_onnx_opset_version >= 13:
axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
return g.op("Unsqueeze", input, axes)
else:
return g.op("Unsqueeze", input, axes_i=axes_i)
# Tensor type
if _export_onnx_opset_version < 13:
raise ValueError(f"Opset version must be >= 13 for Unsqueeze with dynamic axes. {input.node().sourceRange()}")
return g.op("Unsqueeze", input, axes_i[0])
def _squeeze_helper(g, input, axes_i):
if _is_constant(axes_i[0]):
if _export_onnx_opset_version >= 13:
axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
return g.op("Squeeze", input, axes)
else:
return g.op("Squeeze", input, axes_i=axes_i)
# Tensor type
if _export_onnx_opset_version < 13:
raise ValueError(f"Opset version must be >= 13 for Squeeze with dynamic axes. {input.node().sourceRange()}")
axes_t = axes_i[0]
axes_rank = _get_tensor_rank(axes_t)
if axes_rank > 1:
raise ValueError("For Squeeze axses as input, the axes rank must be one in ONNX spec.")
elif axes_rank == 0:
# The axes is a scalar. Unsqueeze it to a rank 1 tensor.
axes_t = _unsqueeze_helper(g, axes_t, [0])
return g.op("Squeeze", input, axes_t)
return g.op("Squeeze", input, axes_t)
def _reducesum_helper(g, input, axes_i=None, keepdims_i=1, noop_with_empty_axes_i=0):
keepdims_i = _maybe_get_const(keepdims_i, "i")
@ -501,6 +520,7 @@ def _reducesum_helper(g, input, axes_i=None, keepdims_i=1, noop_with_empty_axes_
else:
return g.op("ReduceSum", input, axes_i=axes_i, keepdims_i=keepdims_i)
def _interpolate_size_to_scales(g, input, output_size, dim):
output_size = _maybe_get_const(output_size, "is")
if _is_value(output_size):

View File

@ -569,6 +569,10 @@ def squeeze(g, self, dim=None):
if dim is None:
return g.op("Squeeze", self)
# dim as a tensor
if not sym_help._is_constant(dim):
return sym_help._squeeze_helper(g, self, [dim])
dim = sym_help._get_const(dim, "i", "dim")
input_rank = sym_help._get_tensor_rank(self)
@ -606,8 +610,10 @@ def squeeze(g, self, dim=None):
return sym_help._squeeze_helper(g, self, [dim])
@parse_args("v", "i")
def unsqueeze(g, self, dim):
if sym_help._is_constant(dim):
dim = sym_help._get_const(dim, "i", "dim")
return sym_help._unsqueeze_helper(g, self, [dim])
def mm(g, self, other):