[ONNX] Add pixel_unshuffle support in opset 9

Current we are unable to utilize ONNX's SpaceToDepth operator due to the lack of the mode_s attribute, hence we add an alternative symbolic in opset 9 to support pixel_unshuffle

- Adds support for pixel_unshuffle in opset9
- Adds support for dynamic input shapes for pixel_shuffle and pixel_unshuffle
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72449
This commit is contained in:
shubhambhokare1 2022-02-19 00:15:16 +00:00 committed by PyTorch MergeBot
parent 0d66748948
commit 671c8a459a
2 changed files with 80 additions and 14 deletions

View File

@ -5903,7 +5903,24 @@ class TestONNXRuntime(unittest.TestCase):
return torch.pixel_shuffle(x, upscale_factor=2)
x = torch.randn(2, 16, 4, 3, requires_grad=True)
y = torch.randn(4, 32, 8, 4, requires_grad=True)
self.run_test(PixelShuffle(), x)
self.run_test(PixelShuffle(), x, input_names=["x"],
dynamic_axes={"x": [0, 1, 2, 3]},
test_with_inputs=[y])
@skipIfUnsupportedMinOpsetVersion(9)
def test_pixel_unshuffle(self):
class PixelUnshuffle(torch.nn.Module):
def forward(self, x):
return torch.pixel_unshuffle(x, downscale_factor=2)
x = torch.randn(2, 16, 4, 6, requires_grad=True)
y = torch.randn(4, 32, 8, 4, requires_grad=True)
self.run_test(PixelUnshuffle(), x)
self.run_test(PixelUnshuffle(), x, input_names=["x"],
dynamic_axes={"x": [0, 1, 2, 3]},
test_with_inputs=[y])
@skipIfUnsupportedMinOpsetVersion(9)
def test_reciprocal(self):

View File

@ -2157,20 +2157,69 @@ def pixel_shuffle(g, self, upscale_factor):
dims = sym_help._get_tensor_sizes(self)
if len(dims) != 4:
return _unimplemented("pixel_shuffle", "only support 4d input")
if any([i is None for i in dims[1:]]):
return _unimplemented("pixel_shuffle", "only support static input shape, except for batch size")
output_channel = dims[1] // upscale_factor // upscale_factor
after_view = sym_help._reshape_helper(g, self,
g.op("Constant", value_t=torch.tensor([-1, output_channel,
upscale_factor, upscale_factor,
dims[2], dims[3]])),
allowzero=0)
after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
return sym_help._reshape_helper(g, after_transpose,
g.op("Constant", value_t=torch.tensor([-1, output_channel,
dims[2] * upscale_factor,
dims[3] * upscale_factor])),
allowzero=0)
if any(i is None for i in dims[1:]):
after_view = sym_help._reshape_helper(g, sym_help._unsqueeze_helper(g, self, [2, 3]),
g.op("Constant", value_t=torch.tensor([0, -1,
upscale_factor, upscale_factor,
0, 0])),
allowzero=0)
after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
# For dynamic input shapes, two reshapes are performed
reshape_h = sym_help._reshape_helper(g, after_transpose,
g.op("Constant", value_t=torch.tensor([0, 0, -1, 1, 0, 0])),
allowzero=0)
reshape_w = sym_help._reshape_helper(g, reshape_h,
g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, 1])),
allowzero=0)
return sym_help._squeeze_helper(g, reshape_w, [3, 5])
else:
output_channel = dims[1] // upscale_factor // upscale_factor
after_view = sym_help._reshape_helper(g, self,
g.op("Constant", value_t=torch.tensor([-1, output_channel,
upscale_factor, upscale_factor,
dims[2], dims[3]])),
allowzero=0)
after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
return sym_help._reshape_helper(g, after_transpose,
g.op("Constant", value_t=torch.tensor([-1, output_channel,
dims[2] * upscale_factor,
dims[3] * upscale_factor])),
allowzero=0)
@parse_args("v", "i")
def pixel_unshuffle(g, self, downscale_factor):
dims = sym_help._get_tensor_sizes(self)
if len(dims) != 4:
return _unimplemented("pixel_shuffle", "only support 4d input")
if any(i is None for i in dims[1:]):
# For dynamic input shapes, two reshapes are performed
reshape_h = sym_help._reshape_helper(g, sym_help._unsqueeze_helper(g, self, [3]),
g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])),
allowzero=0)
reshape_w = sym_help._reshape_helper(g, reshape_h,
g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])),
allowzero=0)
after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4])
final_reshape = sym_help._reshape_helper(g, after_transpose,
g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])),
allowzero=0)
return sym_help._squeeze_helper(g, final_reshape, [2, 3])
else:
output_channel = dims[1] * downscale_factor * downscale_factor
after_view = sym_help._reshape_helper(g, self,
g.op("Constant", value_t=torch.tensor([-1, dims[1],
dims[2] // downscale_factor,
downscale_factor,
dims[3] // downscale_factor,
downscale_factor])),
allowzero=0)
after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4])
return sym_help._reshape_helper(g, after_transpose,
g.op("Constant", value_t=torch.tensor([-1, output_channel,
dims[2] // downscale_factor,
dims[3] // downscale_factor])),
allowzero=0)
def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,