[ONNX] Add Squeeze/Unsqueeze dynamic dimensions support when opset >= 13 (#71158)
* Add Squeeze/Unsqueeze dynamic axes support when opset >= 13

Co-authored-by: hwangdeyu <dejack953@outlook.com>
Co-authored-by: Gary Miguel <garymm@garymm.org>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73104

parent 80291dff43
commit bd4902d81f
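For context, a minimal usage sketch of what this change enables (illustrative only, not part of the commit; the module name and the scripting/export call are assumptions patterned on the new tests): exporting a model whose squeeze dimension arrives as a runtime input requires opset 13, where Squeeze/Unsqueeze take axes as a tensor input rather than a fixed attribute.

import io
import torch

# Hypothetical module mirroring the new test: the dim to squeeze is a
# forward() argument, so it is not a compile-time constant in the graph.
class DynamicSqueeze(torch.nn.Module):
    def forward(self, x, dim: int):
        return torch.squeeze(x, dim)

x = torch.randn(2, 1, 4)
f = io.BytesIO()
# Scripting keeps `dim` symbolic (tracing would bake it in); opset 13 is
# required because earlier opsets encode axes as a node attribute.
torch.onnx.export(torch.jit.script(DynamicSqueeze()), (x, 1), f, opset_version=13)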
@@ -1525,6 +1525,16 @@ class TestONNXRuntime(unittest.TestCase):
         x = torch.randn(2, 1, 4)
         self.run_test(Squeeze(), x)
 
+    @skipIfUnsupportedMinOpsetVersion(13)
+    def test_squeeze_dynamic_dim(self):
+        class Squeeze(torch.nn.Module):
+            def forward(self, x, dim: int):
+                return torch.squeeze(x, dim)
+
+        x = torch.randn(2, 1, 4)
+        dim = 1
+        self.run_test(Squeeze(), (x, dim))
+
     def test_unsqueeze(self):
         class Unsqueeze(torch.nn.Module):
             def forward(self, x):
@@ -1533,6 +1543,16 @@ class TestONNXRuntime(unittest.TestCase):
         x = torch.randn(2, 3, 4)
         self.run_test(Unsqueeze(), x)
 
+    @skipIfUnsupportedMinOpsetVersion(13)
+    def test_unsqueeze_dynamic_dim(self):
+        class Unsqueeze(torch.nn.Module):
+            def forward(self, x, dim: int):
+                return torch.unsqueeze(x, dim)
+
+        x = torch.randn(2, 1, 4)
+        dim = -1
+        self.run_test(Unsqueeze(), (x, dim))
+
     def test_maxpool_default_stride(self):
         class MaxPoolModel(torch.nn.Module):
             def forward(self, x):
@@ -476,19 +476,38 @@ def _interpolate_warning(interpolate_mode):
                   "to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).\n"
                   "We recommend using opset 11 and above for models using this operator.")
 
 
 def _unsqueeze_helper(g, input, axes_i):
-    if _export_onnx_opset_version >= 13:
-        axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
-        return g.op("Unsqueeze", input, axes)
-    else:
+    if _is_constant(axes_i[0]):
+        if _export_onnx_opset_version >= 13:
+            axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
+            return g.op("Unsqueeze", input, axes)
         return g.op("Unsqueeze", input, axes_i=axes_i)
+    # Tensor type
+    if _export_onnx_opset_version < 13:
+        raise ValueError(f"Opset version must be >= 13 for Unsqueeze with dynamic axes. {input.node().sourceRange()}")
+    return g.op("Unsqueeze", input, axes_i[0])
 
 
 def _squeeze_helper(g, input, axes_i):
-    if _export_onnx_opset_version >= 13:
-        axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
-        return g.op("Squeeze", input, axes)
-    else:
+    if _is_constant(axes_i[0]):
+        if _export_onnx_opset_version >= 13:
+            axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
+            return g.op("Squeeze", input, axes)
         return g.op("Squeeze", input, axes_i=axes_i)
+    # Tensor type
+    if _export_onnx_opset_version < 13:
+        raise ValueError(f"Opset version must be >= 13 for Squeeze with dynamic axes. {input.node().sourceRange()}")
+    axes_t = axes_i[0]
+    axes_rank = _get_tensor_rank(axes_t)
+    if axes_rank > 1:
+        raise ValueError("For Squeeze axes as input, the axes rank must be one in ONNX spec.")
+    elif axes_rank == 0:
+        # The axes is a scalar. Unsqueeze it to a rank 1 tensor.
+        axes_t = _unsqueeze_helper(g, axes_t, [0])
+        return g.op("Squeeze", input, axes_t)
+    return g.op("Squeeze", input, axes_t)
 
 
 def _reducesum_helper(g, input, axes_i=None, keepdims_i=1, noop_with_empty_axes_i=0):
     keepdims_i = _maybe_get_const(keepdims_i, "i")
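The helper changes above track the ONNX opset 13 schema change: for Squeeze and Unsqueeze, axes moved from a node attribute to an optional tensor input, which is what makes runtime-computed axes representable. A minimal sketch of the two node forms using onnx.helper (illustrative, not from this commit):

from onnx import helper

# Opset <= 12: axes is an attribute, fixed when the model is built.
unsqueeze_v11 = helper.make_node("Unsqueeze", inputs=["x"], outputs=["y"], axes=[0])

# Opset >= 13: axes is a second input, so it can be produced by another node
# at runtime -- the case the new _unsqueeze_helper/_squeeze_helper paths emit.
unsqueeze_v13 = helper.make_node("Unsqueeze", inputs=["x", "axes"], outputs=["y"])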
@@ -501,6 +520,7 @@ def _reducesum_helper(g, input, axes_i=None, keepdims_i=1, noop_with_empty_axes_i=0):
     else:
         return g.op("ReduceSum", input, axes_i=axes_i, keepdims_i=keepdims_i)
 
 
 def _interpolate_size_to_scales(g, input, output_size, dim):
     output_size = _maybe_get_const(output_size, "is")
     if _is_value(output_size):
@@ -569,6 +569,10 @@ def squeeze(g, self, dim=None):
     if dim is None:
         return g.op("Squeeze", self)
 
+    # dim as a tensor
+    if not sym_help._is_constant(dim):
+        return sym_help._squeeze_helper(g, self, [dim])
+
     dim = sym_help._get_const(dim, "i", "dim")
 
     input_rank = sym_help._get_tensor_rank(self)
@@ -606,8 +610,10 @@ def squeeze(g, self, dim=None):
     return sym_help._squeeze_helper(g, self, [dim])
 
 
-@parse_args("v", "i")
 def unsqueeze(g, self, dim):
+    if sym_help._is_constant(dim):
+        dim = sym_help._get_const(dim, "i", "dim")
+
     return sym_help._unsqueeze_helper(g, self, [dim])
 
 def mm(g, self, other):
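As a quick end-to-end check of the dispatch above, one could inspect the exported graph and confirm that the opset 13 path feeds axes to Unsqueeze as a second input. A hypothetical sketch (not part of the commit's test suite; the module and assertions are assumptions):

import io
import onnx
import torch

# Hypothetical module: unsqueeze along a dim supplied at call time.
class DynamicUnsqueeze(torch.nn.Module):
    def forward(self, x, dim: int):
        return torch.unsqueeze(x, dim)

f = io.BytesIO()
torch.onnx.export(torch.jit.script(DynamicUnsqueeze()),
                  (torch.randn(2, 1, 4), -1), f, opset_version=13)
model = onnx.load_model_from_string(f.getvalue())
for node in model.graph.node:
    if node.op_type == "Unsqueeze":
        # data input plus the runtime axes tensor
        assert len(node.input) == 2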