mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62764 Fixes #58733 - Support dynamic interleave for cases with dynamic repeat values - Moved repeat_interleave symbolic from opset 11 to opset 13, as sequence as output types for loop outputs is needed for this change Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D30375179 Pulled By: msaroufim fbshipit-source-id: 787f96bf91d124fd0483761088c5f4ae930d96a9 Co-authored-by: Shubham Bhokare <shubhambhokare@gmail.com>
This commit is contained in:
parent
8760254911
commit
db0771b05d
|
|
@ -4323,7 +4323,7 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
x = torch.tensor([[1, 2], [3, 4]])
|
||||
self.run_test(RepeatsDimsModel2(), (x,))
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
@skipIfUnsupportedMinOpsetVersion(13)
|
||||
def test_dynamic_repeat_interleave(self):
|
||||
class SingleDynamicModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
|
@ -4345,25 +4345,62 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
self.run_test(NegDynamicModel(), x, test_with_inputs=[another_x],
|
||||
input_names=["input_1"], dynamic_axes={"input_1" : {1 : "w"}})
|
||||
|
||||
class SingleDynamicModel2(torch.nn.Module):
|
||||
class SingleDynamicModelFloat(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
repeats = torch.tensor([4])
|
||||
return torch.repeat_interleave(x, repeats, dim=0)
|
||||
|
||||
x = torch.tensor([[1, 2], [3, 4]])
|
||||
another_x = torch.tensor([[7, 8], [5, 6]])
|
||||
self.run_test(SingleDynamicModel2(), x, test_with_inputs=[another_x],
|
||||
x = torch.tensor([[1.1, 2.1], [3.1, 4.1]])
|
||||
another_x = torch.tensor([[7.1, 8.1], [5.1, 6.1]])
|
||||
self.run_test(SingleDynamicModelFloat(), x, test_with_inputs=[another_x],
|
||||
input_names=["input_1"], dynamic_axes={"input_1" : {0 : "h"}})
|
||||
|
||||
class AllDynamicModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
repeats = torch.tensor([4])
|
||||
class DynamicRepeatsModel(torch.nn.Module):
|
||||
def forward(self, x, repeats):
|
||||
return torch.repeat_interleave(x, repeats, dim=1)
|
||||
|
||||
x = torch.tensor([[1, 2, 4], [3, 4, 7]])
|
||||
another_x = torch.tensor([[7, 8], [5, 6]])
|
||||
repeats = torch.tensor([2])
|
||||
another_repeats = torch.tensor([4])
|
||||
self.run_test(DynamicRepeatsModel(), (x, repeats), test_with_inputs=[(another_x, another_repeats)],
|
||||
input_names=["input_1", "repeats_1"],
|
||||
dynamic_axes={"input_1" : {1 : "w"}, "repeats_1" : {0 : "r"}})
|
||||
|
||||
class DynamicRepeatsModel2(torch.nn.Module):
|
||||
def forward(self, x, repeats):
|
||||
return torch.repeat_interleave(x, repeats, dim=1)
|
||||
|
||||
x = torch.tensor([[1, 2, 4], [3, 4, 7]])
|
||||
repeats = torch.tensor([2])
|
||||
another_repeats = torch.tensor([4])
|
||||
self.run_test(DynamicRepeatsModel2(), (x, repeats), test_with_inputs=[(x, another_repeats)],
|
||||
input_names=["input_1", "repeats_1"],
|
||||
dynamic_axes={"repeats_1" : {0 : "r"}})
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(13)
|
||||
def test_multiple_dynamic_repeat_interleave(self):
|
||||
class DynamicRepeatsModel(torch.nn.Module):
|
||||
def forward(self, x, repeats):
|
||||
return torch.repeat_interleave(x, repeats, dim=1)
|
||||
|
||||
x = torch.tensor([[1, 2, 4], [3, 4, 7]])
|
||||
repeats = torch.tensor([2, 3, 4])
|
||||
another_repeats = torch.tensor([4, 3, 2])
|
||||
self.run_test(DynamicRepeatsModel(), (x, repeats), test_with_inputs=[(x, another_repeats)],
|
||||
input_names=["input_1", "repeats_1"],
|
||||
dynamic_axes={"repeats_1" : {0 : "r"}})
|
||||
|
||||
class DynamicRepeatsModel2(torch.nn.Module):
|
||||
def forward(self, x, repeats):
|
||||
return torch.repeat_interleave(x, repeats, dim=0)
|
||||
|
||||
x = torch.tensor([[1, 2, 4, 16], [3, 9, 27, 81], [2, 3, 5, 7]])
|
||||
another_x = torch.tensor([[7, 8], [5, 6]])
|
||||
self.run_test(AllDynamicModel(), x, test_with_inputs=[another_x],
|
||||
input_names=["input_1"], dynamic_axes={"input_1" : {0 : "h", 1 : "w"}})
|
||||
x = torch.tensor([[1, 2, 4], [3, 4, 7]])
|
||||
repeats = torch.tensor([2, 3])
|
||||
another_repeats = torch.tensor([4, 3])
|
||||
self.run_test(DynamicRepeatsModel2(), (x, repeats), test_with_inputs=[(x, another_repeats)],
|
||||
input_names=["input_1", "repeats_1"],
|
||||
dynamic_axes={"repeats_1" : {0 : "r"}})
|
||||
|
||||
def test_view(self):
|
||||
class ViewModel(torch.nn.Module):
|
||||
|
|
|
|||
|
|
@ -894,110 +894,6 @@ def chunk(g, self, chunks, dim):
|
|||
chunk_vec = g.op("Concat", *chunk_vec, axis_i=0)
|
||||
return split(g, self, chunk_vec, dim)
|
||||
|
||||
def repeat_interleave(g, self, repeats, dim=None, output_size=None):
|
||||
input = self
|
||||
final_dim = dim
|
||||
# if dim is None flatten
|
||||
# By default, use the flattened input array, and return a flat output array
|
||||
if sym_help._is_none(dim):
|
||||
input = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1])))
|
||||
dim = 0
|
||||
else:
|
||||
dim = sym_help._maybe_get_scalar(dim)
|
||||
|
||||
repeats_dim = sym_help._get_tensor_rank(repeats)
|
||||
repeats_sizes = sym_help._get_tensor_sizes(repeats)
|
||||
input_sizes = sym_help._get_tensor_sizes(input)
|
||||
if repeats_dim is None:
|
||||
raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
|
||||
"repeats rank.")
|
||||
if repeats_sizes is None:
|
||||
raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
|
||||
"repeats size.")
|
||||
if input_sizes is None:
|
||||
raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
|
||||
"input size.")
|
||||
# Handle cases where dim is negative
|
||||
if dim < 0:
|
||||
dim += len(input_sizes)
|
||||
|
||||
output_sizes = input_sizes.copy()
|
||||
perm_i = [0]
|
||||
for idx, input_size in enumerate(input_sizes):
|
||||
perm_i.append(idx + 1)
|
||||
if input_size is None:
|
||||
output_sizes[idx], input_sizes[idx] = 0, -1
|
||||
perm_i[0], perm_i[dim] = perm_i[dim], perm_i[0]
|
||||
|
||||
# Cases when repeats is a single value tensor and dim has unknown input size
|
||||
if (repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1)) and output_sizes[dim] == 0:
|
||||
if not sym_help._is_tensor(repeats):
|
||||
repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
|
||||
reps = sym_help._size_helper(g, input, dim)
|
||||
reps = unsqueeze(g, reps, 0)
|
||||
repeats = g.op("Expand", repeats, reps)
|
||||
# There are cases when the repeats are 1-d tensor with multiple repeats, but dim
|
||||
# provided along one of the dynamic axes provided. A simple example would be
|
||||
# input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
|
||||
# Now, repeat interleaving can be performed in pytorch when the value of * matches
|
||||
# with the number of elements in repeat, for example if * -> 2, number of repeats
|
||||
# should be 2 as well.
|
||||
else:
|
||||
return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim)
|
||||
|
||||
reps_like = g.op("ConstantOfShape", g.op("Shape", repeats),
|
||||
value_t=torch.tensor([1], dtype=torch.long))
|
||||
r_splits = split(g, repeats, reps_like, 0)
|
||||
i_splits = split(g, input, reps_like, dim)
|
||||
|
||||
output_sizes[dim], input_sizes[dim] = -1, 1
|
||||
|
||||
# Create a loop to iterate over each value along the dimension
|
||||
# and perform individual interleaving using the repeats tensor
|
||||
# Loop is of the following pattern
|
||||
# input (trip_count, cond)
|
||||
# int trip_count = ...;
|
||||
# bool cond = ...;
|
||||
# for (int i=0; i < trip_count && cond; ++i) {
|
||||
# cond = ...;
|
||||
# }
|
||||
|
||||
# Loop conditions
|
||||
loop_condition = g.op("Constant", value_t=torch.tensor(1))
|
||||
loop_condition = g.op("Cast", loop_condition, to_i=9)
|
||||
loop_len = reps
|
||||
loop = g.op("Loop", loop_len, loop_condition)
|
||||
|
||||
# Loop inputs
|
||||
loop_block = _add_block(loop.node())
|
||||
block_input_iter = _add_input_to_block(loop_block)
|
||||
cond = _add_input_to_block(loop_block)
|
||||
|
||||
r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
|
||||
i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)
|
||||
|
||||
i_split = unsqueeze(loop_block, i_split, dim + 1)
|
||||
r_concat = [loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[:dim + 1])),
|
||||
r_split,
|
||||
loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1:]))]
|
||||
r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
|
||||
i_split = expand(loop_block, i_split, r_concat, None)
|
||||
i_split = sym_help._reshape_helper(loop_block, i_split,
|
||||
g.op("Constant", value_t=torch.LongTensor(output_sizes)))
|
||||
|
||||
# Loop outputs
|
||||
cond_out = loop_block.op("Cast", loop_condition, to_i=9)
|
||||
_add_output_to_block(loop_block, cond_out)
|
||||
_add_output_to_block(loop_block, i_split)
|
||||
loop_out = loop.node().output()
|
||||
|
||||
# In this loop, the outputs are scan outputs and are concatenated along
|
||||
# the zero'th dimension (by default). In order to avoid this and concatenate
|
||||
# along the dimension provided, some post-processing is required
|
||||
loop_out = g.op("Transpose", loop_out, perm_i=perm_i)
|
||||
return sym_help._reshape_helper(g, loop_out,
|
||||
g.op("Constant", value_t=torch.LongTensor(output_sizes)))
|
||||
|
||||
|
||||
def normal(g, loc, scale, seed):
|
||||
# If you can sample from a given distribution with mean 0 and variance 1, then you can easily sample from a
|
||||
|
|
|
|||
|
|
@ -5,7 +5,9 @@
|
|||
import torch
|
||||
import torch.onnx.symbolic_helper as sym_help
|
||||
from torch.onnx.symbolic_helper import parse_args, _unimplemented
|
||||
from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input, nonzero
|
||||
from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input, nonzero, expand
|
||||
from torch.onnx.symbolic_opset11 import unsqueeze
|
||||
from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block
|
||||
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
|
|
@ -196,3 +198,117 @@ def unsafe_chunk(g, self, chunks, dim, _outputs=None):
|
|||
# user's modules.
|
||||
splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long))
|
||||
return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
|
||||
|
||||
def repeat_interleave(g, self, repeats, dim=None, output_size=None):
|
||||
input = self
|
||||
final_dim = dim
|
||||
# if dim is None flatten
|
||||
# By default, use the flattened input array, and return a flat output array
|
||||
if sym_help._is_none(dim):
|
||||
input = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1])))
|
||||
dim = 0
|
||||
else:
|
||||
dim = sym_help._maybe_get_scalar(dim)
|
||||
|
||||
repeats_dim = sym_help._get_tensor_rank(repeats)
|
||||
repeats_sizes = sym_help._get_tensor_sizes(repeats)
|
||||
input_sizes = sym_help._get_tensor_sizes(input)
|
||||
if repeats_dim is None:
|
||||
raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
|
||||
"repeats rank.")
|
||||
if repeats_sizes is None:
|
||||
raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
|
||||
"repeats size.")
|
||||
if input_sizes is None:
|
||||
raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
|
||||
"input size.")
|
||||
# Handle cases where dim is negative
|
||||
if dim < 0:
|
||||
dim += len(input_sizes)
|
||||
|
||||
output_sizes = input_sizes.copy()
|
||||
for idx, input_size in enumerate(input_sizes):
|
||||
if input_size is None:
|
||||
output_sizes[idx], input_sizes[idx] = 0, -1
|
||||
print(output_sizes, input_sizes)
|
||||
|
||||
cond_dynamic_repeats = (repeats_dim == 1 and repeats_sizes[0] is None)
|
||||
# If input size is dynamic or repeats vector is dynamic
|
||||
if output_sizes[dim] == 0 or cond_dynamic_repeats:
|
||||
reps = sym_help._size_helper(g, input, dim)
|
||||
reps = unsqueeze(g, reps, 0)
|
||||
# Check if repeats vector is a single integer value
|
||||
# or a single dimension tensor with non-dynamic values
|
||||
if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
|
||||
if not sym_help._is_tensor(repeats):
|
||||
repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
|
||||
repeats = g.op("Expand", repeats, reps)
|
||||
# Check if repeats is dynamic
|
||||
# As repeats is dynamic, we use a where node as a substitute for the if statement
|
||||
# If repests_dim = 1, expand repeats otherwise use original tensor
|
||||
elif cond_dynamic_repeats:
|
||||
repeat_dim = sym_help._size_helper(g, repeats, g.op("Constant", value_t=torch.LongTensor([0])))
|
||||
repeat_cond = g.op("Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1])))
|
||||
repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats)
|
||||
# There are cases when the repeats are 1-d tensor with multiple repeats, but dim
|
||||
# provided along one of the dynamic axes provided. A simple example would be
|
||||
# input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
|
||||
# Now, repeat interleaving can be performed in pytorch when the value of * matches
|
||||
# with the number of elements in repeat, for example if * -> 2, number of repeats
|
||||
# should be 2 as well.
|
||||
else:
|
||||
return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim)
|
||||
|
||||
reps_like = g.op("ConstantOfShape", g.op("Shape", repeats),
|
||||
value_t=torch.tensor([1], dtype=torch.long))
|
||||
r_splits = split(g, repeats, reps_like, 0)
|
||||
i_splits = split(g, input, reps_like, dim)
|
||||
|
||||
output_sizes[dim], input_sizes[dim] = -1, 1
|
||||
|
||||
# Create a loop to iterate over each value along the dimension
|
||||
# and perform individual interleaving using the repeats tensor
|
||||
# Loop is of the following pattern
|
||||
# input (trip_count, cond)
|
||||
# int trip_count = ...;
|
||||
# bool cond = ...;
|
||||
# for (int i=0; i < trip_count && cond; ++i) {
|
||||
# cond = ...;
|
||||
# }
|
||||
|
||||
# Loop conditions
|
||||
loop_condition = g.op("Constant", value_t=torch.tensor(1))
|
||||
loop_condition = g.op("Cast", loop_condition, to_i=9)
|
||||
loop_len = reps
|
||||
|
||||
# Create an empty sequence to store final expansions
|
||||
final_splits = g.op("SequenceEmpty")
|
||||
loop = g.op("Loop", loop_len, loop_condition, final_splits)
|
||||
|
||||
# Loop inputs
|
||||
loop_block = _add_block(loop.node())
|
||||
block_input_iter = _add_input_to_block(loop_block)
|
||||
cond = _add_input_to_block(loop_block)
|
||||
final_splits = _add_input_to_block(loop_block)
|
||||
|
||||
r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
|
||||
i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)
|
||||
|
||||
i_split = unsqueeze(loop_block, i_split, dim + 1)
|
||||
r_concat = [loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[:dim + 1])),
|
||||
r_split,
|
||||
loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1:]))]
|
||||
r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
|
||||
i_split = expand(loop_block, i_split, r_concat, None)
|
||||
i_split = sym_help._reshape_helper(loop_block, i_split,
|
||||
g.op("Constant", value_t=torch.LongTensor(output_sizes)))
|
||||
final_splits = loop_block.op("SequenceInsert", final_splits, i_split)
|
||||
|
||||
# Loop outputs
|
||||
cond_out = loop_block.op("Cast", loop_condition, to_i=9)
|
||||
_add_output_to_block(loop_block, cond_out)
|
||||
_add_output_to_block(loop_block, final_splits)
|
||||
|
||||
loop_out = loop.node().output()
|
||||
loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim)
|
||||
return loop_out
|
||||
|
|
|
|||
|
|
@ -2058,7 +2058,7 @@ def repeat_interleave(g, self, repeats, dim=None, output_size=None):
|
|||
if not sym_help._is_tensor(repeats):
|
||||
repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
|
||||
if input_sizes[dim] == 0:
|
||||
return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 11,
|
||||
return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13,
|
||||
"Unsupported along dimension with unknown input size")
|
||||
else:
|
||||
reps = input_sizes[dim]
|
||||
|
|
@ -2067,8 +2067,11 @@ def repeat_interleave(g, self, repeats, dim=None, output_size=None):
|
|||
# Cases where repeats is a 1 dim Tensor
|
||||
elif repeats_dim == 1:
|
||||
if input_sizes[dim] == 0:
|
||||
return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 11,
|
||||
return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13,
|
||||
"Unsupported along dimension with unknown input size")
|
||||
if repeats_sizes[0] is None:
|
||||
return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13,
|
||||
"Unsupported for cases with dynamic repeats")
|
||||
assert repeats_sizes[0] == input_sizes[dim], "repeats must have the same size as input along dim"
|
||||
reps = repeats_sizes[0]
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user