[ONNX] Update repeat_interleave for dynamic repeats (#59979) (#62764)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62764

Fixes #58733

- Support repeat_interleave for cases where the repeats input is dynamic (see the export sketch below)
- Moved the repeat_interleave symbolic from opset 11 to opset 13, since sequence-typed Loop outputs are required for this change
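
As a usage sketch (model, values, and input names are illustrative, mirroring the new tests), this is the kind of export the change enables:

import io
import torch

class DynamicRepeatsModel(torch.nn.Module):
    def forward(self, x, repeats):
        return torch.repeat_interleave(x, repeats, dim=1)

x = torch.tensor([[1, 2, 4], [3, 4, 7]])
repeats = torch.tensor([2, 3, 4])
# Dynamic repeats values require opset >= 13 after this change
torch.onnx.export(DynamicRepeatsModel(), (x, repeats), io.BytesIO(),
                  opset_version=13,
                  input_names=["input_1", "repeats_1"],
                  dynamic_axes={"repeats_1": {0: "r"}})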

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D30375179

Pulled By: msaroufim

fbshipit-source-id: 787f96bf91d124fd0483761088c5f4ae930d96a9

Co-authored-by: Shubham Bhokare <shubhambhokare@gmail.com>
BowenBao authored on 2021-08-20 12:44:29 -07:00, committed by Facebook GitHub Bot
parent 8760254911
commit db0771b05d
4 changed files with 171 additions and 119 deletions

File: test/onnx/test_pytorch_onnx_onnxruntime.py

@@ -4323,7 +4323,7 @@ class TestONNXRuntime(unittest.TestCase):
         x = torch.tensor([[1, 2], [3, 4]])
         self.run_test(RepeatsDimsModel2(), (x,))

-    @skipIfUnsupportedMinOpsetVersion(11)
+    @skipIfUnsupportedMinOpsetVersion(13)
     def test_dynamic_repeat_interleave(self):
         class SingleDynamicModel(torch.nn.Module):
             def forward(self, x):
@@ -4345,25 +4345,62 @@ class TestONNXRuntime(unittest.TestCase):
         self.run_test(NegDynamicModel(), x, test_with_inputs=[another_x],
                       input_names=["input_1"], dynamic_axes={"input_1" : {1 : "w"}})

-        class SingleDynamicModel2(torch.nn.Module):
+        class SingleDynamicModelFloat(torch.nn.Module):
             def forward(self, x):
                 repeats = torch.tensor([4])
                 return torch.repeat_interleave(x, repeats, dim=0)

-        x = torch.tensor([[1, 2], [3, 4]])
-        another_x = torch.tensor([[7, 8], [5, 6]])
-        self.run_test(SingleDynamicModel2(), x, test_with_inputs=[another_x],
+        x = torch.tensor([[1.1, 2.1], [3.1, 4.1]])
+        another_x = torch.tensor([[7.1, 8.1], [5.1, 6.1]])
+        self.run_test(SingleDynamicModelFloat(), x, test_with_inputs=[another_x],
                       input_names=["input_1"], dynamic_axes={"input_1" : {0 : "h"}})

-        class AllDynamicModel(torch.nn.Module):
-            def forward(self, x):
-                repeats = torch.tensor([4])
+        class DynamicRepeatsModel(torch.nn.Module):
+            def forward(self, x, repeats):
                 return torch.repeat_interleave(x, repeats, dim=1)
+
+        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
+        another_x = torch.tensor([[7, 8], [5, 6]])
+        repeats = torch.tensor([2])
+        another_repeats = torch.tensor([4])
+        self.run_test(DynamicRepeatsModel(), (x, repeats), test_with_inputs=[(another_x, another_repeats)],
+                      input_names=["input_1", "repeats_1"],
+                      dynamic_axes={"input_1" : {1 : "w"}, "repeats_1" : {0 : "r"}})
+
+        class DynamicRepeatsModel2(torch.nn.Module):
+            def forward(self, x, repeats):
+                return torch.repeat_interleave(x, repeats, dim=1)
+
+        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
+        repeats = torch.tensor([2])
+        another_repeats = torch.tensor([4])
+        self.run_test(DynamicRepeatsModel2(), (x, repeats), test_with_inputs=[(x, another_repeats)],
+                      input_names=["input_1", "repeats_1"],
+                      dynamic_axes={"repeats_1" : {0 : "r"}})
+
+    @skipIfUnsupportedMinOpsetVersion(13)
+    def test_multiple_dynamic_repeat_interleave(self):
+        class DynamicRepeatsModel(torch.nn.Module):
+            def forward(self, x, repeats):
+                return torch.repeat_interleave(x, repeats, dim=1)
+
+        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
+        repeats = torch.tensor([2, 3, 4])
+        another_repeats = torch.tensor([4, 3, 2])
+        self.run_test(DynamicRepeatsModel(), (x, repeats), test_with_inputs=[(x, another_repeats)],
+                      input_names=["input_1", "repeats_1"],
+                      dynamic_axes={"repeats_1" : {0 : "r"}})
+
+        class DynamicRepeatsModel2(torch.nn.Module):
+            def forward(self, x, repeats):
+                return torch.repeat_interleave(x, repeats, dim=0)

-        x = torch.tensor([[1, 2, 4, 16], [3, 9, 27, 81], [2, 3, 5, 7]])
-        another_x = torch.tensor([[7, 8], [5, 6]])
-        self.run_test(AllDynamicModel(), x, test_with_inputs=[another_x],
-                      input_names=["input_1"], dynamic_axes={"input_1" : {0 : "h", 1 : "w"}})
+        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
+        repeats = torch.tensor([2, 3])
+        another_repeats = torch.tensor([4, 3])
+        self.run_test(DynamicRepeatsModel2(), (x, repeats), test_with_inputs=[(x, another_repeats)],
+                      input_names=["input_1", "repeats_1"],
+                      dynamic_axes={"repeats_1" : {0 : "r"}})

     def test_view(self):
         class ViewModel(torch.nn.Module):
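
For intuition, the behavior these tests check matches the following eager reference (repeat_interleave_reference is a hypothetical helper written for illustration; it is not part of this PR):

import torch

def repeat_interleave_reference(x, repeats, dim):
    # Split x into size-1 slices along dim, tile slice i repeats[i] times,
    # and concatenate along the same dim -- the eager analogue of the
    # Loop + SequenceInsert + ConcatFromSequence pattern exported below.
    slices = torch.split(x, 1, dim=dim)
    tiled = [s.repeat_interleave(int(r), dim=dim) for s, r in zip(slices, repeats)]
    return torch.cat(tiled, dim=dim)

x = torch.tensor([[1, 2, 4], [3, 4, 7]])
repeats = torch.tensor([2, 3, 4])
assert torch.equal(repeat_interleave_reference(x, repeats, 1),
                   torch.repeat_interleave(x, repeats, dim=1))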

File: torch/onnx/symbolic_opset11.py

@@ -894,110 +894,6 @@ def chunk(g, self, chunks, dim):
     chunk_vec = g.op("Concat", *chunk_vec, axis_i=0)
     return split(g, self, chunk_vec, dim)

-
-def repeat_interleave(g, self, repeats, dim=None, output_size=None):
-    input = self
-    final_dim = dim
-    # if dim is None flatten
-    # By default, use the flattened input array, and return a flat output array
-    if sym_help._is_none(dim):
-        input = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1])))
-        dim = 0
-    else:
-        dim = sym_help._maybe_get_scalar(dim)
-
-    repeats_dim = sym_help._get_tensor_rank(repeats)
-    repeats_sizes = sym_help._get_tensor_sizes(repeats)
-    input_sizes = sym_help._get_tensor_sizes(input)
-    if repeats_dim is None:
-        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
-                           "repeats rank.")
-    if repeats_sizes is None:
-        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
-                           "repeats size.")
-    if input_sizes is None:
-        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
-                           "input size.")
-
-    # Handle cases where dim is negative
-    if dim < 0:
-        dim += len(input_sizes)
-
-    output_sizes = input_sizes.copy()
-    perm_i = [0]
-    for idx, input_size in enumerate(input_sizes):
-        perm_i.append(idx + 1)
-        if input_size is None:
-            output_sizes[idx], input_sizes[idx] = 0, -1
-    perm_i[0], perm_i[dim] = perm_i[dim], perm_i[0]
-
-    # Cases when repeats is a single value tensor and dim has unknown input size
-    if (repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1)) and output_sizes[dim] == 0:
-        if not sym_help._is_tensor(repeats):
-            repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
-        reps = sym_help._size_helper(g, input, dim)
-        reps = unsqueeze(g, reps, 0)
-        repeats = g.op("Expand", repeats, reps)
-    # There are cases when the repeats are 1-d tensor with multiple repeats, but dim
-    # provided along one of the dynamic axes provided. A simple example would be
-    # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
-    # Now, repeat interleaving can be performed in pytorch when the value of * matches
-    # with the number of elements in repeat, for example if * -> 2, number of repeats
-    # should be 2 as well.
-    else:
-        return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim)
-
-    reps_like = g.op("ConstantOfShape", g.op("Shape", repeats),
-                     value_t=torch.tensor([1], dtype=torch.long))
-    r_splits = split(g, repeats, reps_like, 0)
-    i_splits = split(g, input, reps_like, dim)
-
-    output_sizes[dim], input_sizes[dim] = -1, 1
-
-    # Create a loop to iterate over each value along the dimension
-    # and perform individual interleaving using the repeats tensor
-    # Loop is of the following pattern
-    # input (trip_count, cond)
-    #   int trip_count = ...;
-    #   bool cond = ...;
-    #   for (int i=0; i < trip_count && cond; ++i) {
-    #     cond = ...;
-    #   }
-
-    # Loop conditions
-    loop_condition = g.op("Constant", value_t=torch.tensor(1))
-    loop_condition = g.op("Cast", loop_condition, to_i=9)
-    loop_len = reps
-    loop = g.op("Loop", loop_len, loop_condition)
-
-    # Loop inputs
-    loop_block = _add_block(loop.node())
-    block_input_iter = _add_input_to_block(loop_block)
-    cond = _add_input_to_block(loop_block)
-
-    r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
-    i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)
-
-    i_split = unsqueeze(loop_block, i_split, dim + 1)
-    r_concat = [loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[:dim + 1])),
-                r_split,
-                loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1:]))]
-    r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
-    i_split = expand(loop_block, i_split, r_concat, None)
-    i_split = sym_help._reshape_helper(loop_block, i_split,
-                                       g.op("Constant", value_t=torch.LongTensor(output_sizes)))
-
-    # Loop outputs
-    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
-    _add_output_to_block(loop_block, cond_out)
-    _add_output_to_block(loop_block, i_split)
-    loop_out = loop.node().output()
-
-    # In this loop, the outputs are scan outputs and are concatenated along
-    # the zero'th dimension (by default). In order to avoid this and concatenate
-    # along the dimension provided, some post-processing is required
-    loop_out = g.op("Transpose", loop_out, perm_i=perm_i)
-    return sym_help._reshape_helper(g, loop_out,
-                                    g.op("Constant", value_t=torch.LongTensor(output_sizes)))

 def normal(g, loc, scale, seed):
     # If you can sample from a given distribution with mean 0 and variance 1, then you can easily sample from a
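
The removed implementation collects each iteration's result as a Loop scan output. Scan outputs are stacked along a new leading axis, so every iteration must produce the same shape; that holds when a single repeat value is broadcast, but not when each index has its own repeat count. A small shape check illustrating why (values are illustrative):

import torch

x = torch.tensor([[1, 2, 4], [3, 4, 7]])
repeats = torch.tensor([2, 3, 4])
slices = torch.split(x, 1, dim=1)
tiles = [s.repeat_interleave(int(r), dim=1) for s, r in zip(slices, repeats)]
print([tuple(t.shape) for t in tiles])  # [(2, 2), (2, 3), (2, 4)] -- unequal shapes
# torch.stack(tiles) would fail (the scan-output analogue needs equal shapes),
# while torch.cat(tiles, dim=1) works (the sequence analogue used in opset 13).
print(torch.cat(tiles, dim=1).shape)    # torch.Size([2, 9])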

File: torch/onnx/symbolic_opset13.py

@@ -5,7 +5,9 @@
 import torch
 import torch.onnx.symbolic_helper as sym_help
 from torch.onnx.symbolic_helper import parse_args, _unimplemented
-from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input, nonzero
+from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input, nonzero, expand, where
+from torch.onnx.symbolic_opset11 import unsqueeze
+from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block

 # EDITING THIS FILE? READ THIS FIRST!
@@ -196,3 +198,117 @@ def unsafe_chunk(g, self, chunks, dim, _outputs=None):
     # user's modules.
     splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long))
     return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
+
+
+def repeat_interleave(g, self, repeats, dim=None, output_size=None):
+    input = self
+    final_dim = dim
+    # if dim is None flatten
+    # By default, use the flattened input array, and return a flat output array
+    if sym_help._is_none(dim):
+        input = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1])))
+        dim = 0
+    else:
+        dim = sym_help._maybe_get_scalar(dim)
+
+    repeats_dim = sym_help._get_tensor_rank(repeats)
+    repeats_sizes = sym_help._get_tensor_sizes(repeats)
+    input_sizes = sym_help._get_tensor_sizes(input)
+    if repeats_dim is None:
+        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
+                           "repeats rank.")
+    if repeats_sizes is None:
+        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
+                           "repeats size.")
+    if input_sizes is None:
+        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
+                           "input size.")
+
+    # Handle cases where dim is negative
+    if dim < 0:
+        dim += len(input_sizes)
+
+    output_sizes = input_sizes.copy()
+    for idx, input_size in enumerate(input_sizes):
+        if input_size is None:
+            output_sizes[idx], input_sizes[idx] = 0, -1
+
+    cond_dynamic_repeats = (repeats_dim == 1 and repeats_sizes[0] is None)
+    # If input size is dynamic or repeats vector is dynamic
+    if output_sizes[dim] == 0 or cond_dynamic_repeats:
+        reps = sym_help._size_helper(g, input, dim)
+        reps = unsqueeze(g, reps, 0)
+        # Check if repeats vector is a single integer value
+        # or a single dimension tensor with non-dynamic values
+        if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
+            if not sym_help._is_tensor(repeats):
+                repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
+            repeats = g.op("Expand", repeats, reps)
+        # Check if repeats is dynamic
+        # As repeats is dynamic, we use a Where node as a substitute for the if statement
+        # If repeats_dim == 1, expand repeats, otherwise use the original tensor
+        elif cond_dynamic_repeats:
+            repeat_dim = sym_help._size_helper(g, repeats, g.op("Constant", value_t=torch.LongTensor([0])))
+            repeat_cond = g.op("Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1])))
+            repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats)
+    # There are cases when the repeats are 1-d tensor with multiple repeats, but dim
+    # provided along one of the dynamic axes provided. A simple example would be
+    # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
+    # Now, repeat interleaving can be performed in pytorch when the value of * matches
+    # with the number of elements in repeat, for example if * -> 2, number of repeats
+    # should be 2 as well.
+    else:
+        return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim)
+
+    reps_like = g.op("ConstantOfShape", g.op("Shape", repeats),
+                     value_t=torch.tensor([1], dtype=torch.long))
+    r_splits = split(g, repeats, reps_like, 0)
+    i_splits = split(g, input, reps_like, dim)
+
+    output_sizes[dim], input_sizes[dim] = -1, 1
+
+    # Create a loop to iterate over each value along the dimension
+    # and perform individual interleaving using the repeats tensor
+    # Loop is of the following pattern
+    # input (trip_count, cond)
+    #   int trip_count = ...;
+    #   bool cond = ...;
+    #   for (int i=0; i < trip_count && cond; ++i) {
+    #     cond = ...;
+    #   }
+
+    # Loop conditions
+    loop_condition = g.op("Constant", value_t=torch.tensor(1))
+    loop_condition = g.op("Cast", loop_condition, to_i=9)
+    loop_len = reps
+
+    # Create an empty sequence to store final expansions
+    final_splits = g.op("SequenceEmpty")
+    loop = g.op("Loop", loop_len, loop_condition, final_splits)
+
+    # Loop inputs
+    loop_block = _add_block(loop.node())
+    block_input_iter = _add_input_to_block(loop_block)
+    cond = _add_input_to_block(loop_block)
+    final_splits = _add_input_to_block(loop_block)
+
+    r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
+    i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)
+
+    i_split = unsqueeze(loop_block, i_split, dim + 1)
+    r_concat = [loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[:dim + 1])),
+                r_split,
+                loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1:]))]
+    r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
+    i_split = expand(loop_block, i_split, r_concat, None)
+    i_split = sym_help._reshape_helper(loop_block, i_split,
+                                       g.op("Constant", value_t=torch.LongTensor(output_sizes)))
+    final_splits = loop_block.op("SequenceInsert", final_splits, i_split)
+
+    # Loop outputs
+    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
+    _add_output_to_block(loop_block, cond_out)
+    _add_output_to_block(loop_block, final_splits)
+
+    loop_out = loop.node().output()
+    loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim)
+    return loop_out
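
The Where node above stands in for a branch on a value that is only known at runtime: the graph computes both the expanded and the original repeats and selects one. A minimal eager sketch of that selection (normalize_repeats is a hypothetical name for illustration):

import torch

def normalize_repeats(repeats, dim_size):
    # Eager rendering of Equal(Shape(repeats), 1) -> Expand -> Where:
    # a single-element repeats vector is broadcast to dim_size,
    # a full-length repeats vector passes through unchanged.
    is_single = torch.tensor(repeats.numel() == 1)
    expanded = repeats.expand(dim_size) if repeats.numel() == 1 else repeats
    return torch.where(is_single, expanded, repeats)

print(normalize_repeats(torch.tensor([4]), 3))        # tensor([4, 4, 4])
print(normalize_repeats(torch.tensor([2, 3, 4]), 3))  # tensor([2, 3, 4])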

File: torch/onnx/symbolic_opset9.py

@@ -2058,7 +2058,7 @@ def repeat_interleave(g, self, repeats, dim=None, output_size=None):
         if not sym_help._is_tensor(repeats):
             repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
         if input_sizes[dim] == 0:
-            return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 11,
+            return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13,
                                                              "Unsupported along dimension with unknown input size")
         else:
             reps = input_sizes[dim]
@@ -2067,8 +2067,11 @@ def repeat_interleave(g, self, repeats, dim=None, output_size=None):
     # Cases where repeats is a 1 dim Tensor
     elif repeats_dim == 1:
         if input_sizes[dim] == 0:
-            return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 11,
+            return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13,
                                                              "Unsupported along dimension with unknown input size")
+        if repeats_sizes[0] is None:
+            return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13,
+                                                             "Unsupported for cases with dynamic repeats")
         assert repeats_sizes[0] == input_sizes[dim], "repeats must have the same size as input along dim"
         reps = repeats_sizes[0]
     else:
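
As a quick check of the shape contract asserted above (values are illustrative): repeats needs one entry per index of the input along dim, and the output size along dim is the sum of the repeats:

import torch

x = torch.tensor([[1, 2, 4], [3, 4, 7]])   # shape (2, 3)
repeats = torch.tensor([2, 3, 4])          # len(repeats) == x.shape[1]
out = torch.repeat_interleave(x, repeats, dim=1)
print(out.shape)                           # torch.Size([2, 9]); 9 == 2 + 3 + 4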