[ONNX] Add pixel_unshuffle support in opset 9
Currently we are unable to utilize ONNX's SpaceToDepth operator because it lacks the mode_s attribute, so an alternative symbolic is added in opset 9 to support pixel_unshuffle.
- Adds support for pixel_unshuffle in opset 9
- Adds support for dynamic input shapes for pixel_shuffle and pixel_unshuffle
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72449
This commit is contained in:
parent 0d66748948
commit 671c8a459a
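Since SpaceToDepth cannot express the required channel ordering, the opset 9 symbolic lowers the op to a Reshape -> Transpose -> Reshape chain instead. As a quick sanity sketch (not part of the commit; the helper name below is made up), the same decomposition written in plain PyTorch for the static-shape case, checked against torch.pixel_unshuffle:

import torch

def pixel_unshuffle_decomposed(x, r):
    # Illustrative only: mirrors the Reshape -> Transpose -> Reshape pattern
    # the opset 9 symbolic emits when the input shape is static.
    n, c, h, w = x.shape
    x = x.reshape(n, c, h // r, r, w // r, r)       # split H and W into (H/r, r) and (W/r, r)
    x = x.permute(0, 1, 3, 5, 2, 4)                 # move both r factors next to the channel dim
    return x.reshape(n, c * r * r, h // r, w // r)  # fold them into the channels

x = torch.randn(2, 16, 4, 6)
assert torch.equal(pixel_unshuffle_decomposed(x, 2), torch.pixel_unshuffle(x, 2))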
@@ -5903,7 +5903,24 @@ class TestONNXRuntime(unittest.TestCase):
                 return torch.pixel_shuffle(x, upscale_factor=2)
 
         x = torch.randn(2, 16, 4, 3, requires_grad=True)
+        y = torch.randn(4, 32, 8, 4, requires_grad=True)
         self.run_test(PixelShuffle(), x)
+        self.run_test(PixelShuffle(), x, input_names=["x"],
+                      dynamic_axes={"x": [0, 1, 2, 3]},
+                      test_with_inputs=[y])
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_pixel_unshuffle(self):
+        class PixelUnshuffle(torch.nn.Module):
+            def forward(self, x):
+                return torch.pixel_unshuffle(x, downscale_factor=2)
+
+        x = torch.randn(2, 16, 4, 6, requires_grad=True)
+        y = torch.randn(4, 32, 8, 4, requires_grad=True)
+        self.run_test(PixelUnshuffle(), x)
+        self.run_test(PixelUnshuffle(), x, input_names=["x"],
+                      dynamic_axes={"x": [0, 1, 2, 3]},
+                      test_with_inputs=[y])
 
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_reciprocal(self):
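For reference, the dynamic-axes path exercised by run_test above can be reproduced standalone roughly as follows (a sketch, not part of the commit; it assumes onnxruntime is installed, and the output name "y" is arbitrary):

import io
import numpy as np
import torch
import onnxruntime as ort

class PixelUnshuffle(torch.nn.Module):
    def forward(self, x):
        return torch.pixel_unshuffle(x, downscale_factor=2)

# Export with every axis marked dynamic, mirroring the dynamic_axes used in the test.
f = io.BytesIO()
torch.onnx.export(PixelUnshuffle(), torch.randn(2, 16, 4, 6), f,
                  opset_version=9, input_names=["x"], output_names=["y"],
                  dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1, 2, 3]})

# Run the exported model on a shape different from the export-time shape.
sess = ort.InferenceSession(f.getvalue(), providers=["CPUExecutionProvider"])
x = torch.randn(4, 32, 8, 4)
y = sess.run(None, {"x": x.numpy()})[0]
np.testing.assert_allclose(y, torch.pixel_unshuffle(x, 2).numpy(), rtol=1e-5, atol=1e-6)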
@@ -2157,20 +2157,69 @@ def pixel_shuffle(g, self, upscale_factor):
     dims = sym_help._get_tensor_sizes(self)
     if len(dims) != 4:
         return _unimplemented("pixel_shuffle", "only support 4d input")
-    if any([i is None for i in dims[1:]]):
-        return _unimplemented("pixel_shuffle", "only support static input shape, except for batch size")
-    output_channel = dims[1] // upscale_factor // upscale_factor
-    after_view = sym_help._reshape_helper(g, self,
-                                          g.op("Constant", value_t=torch.tensor([-1, output_channel,
-                                                                                 upscale_factor, upscale_factor,
-                                                                                 dims[2], dims[3]])),
-                                          allowzero=0)
-    after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
-    return sym_help._reshape_helper(g, after_transpose,
-                                    g.op("Constant", value_t=torch.tensor([-1, output_channel,
-                                                                           dims[2] * upscale_factor,
-                                                                           dims[3] * upscale_factor])),
-                                    allowzero=0)
+    if any(i is None for i in dims[1:]):
+        after_view = sym_help._reshape_helper(g, sym_help._unsqueeze_helper(g, self, [2, 3]),
+                                              g.op("Constant", value_t=torch.tensor([0, -1,
+                                                                                     upscale_factor, upscale_factor,
+                                                                                     0, 0])),
+                                              allowzero=0)
+        after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
+        # For dynamic input shapes, two reshapes are performed
+        reshape_h = sym_help._reshape_helper(g, after_transpose,
+                                             g.op("Constant", value_t=torch.tensor([0, 0, -1, 1, 0, 0])),
+                                             allowzero=0)
+        reshape_w = sym_help._reshape_helper(g, reshape_h,
+                                             g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, 1])),
+                                             allowzero=0)
+        return sym_help._squeeze_helper(g, reshape_w, [3, 5])
+    else:
+        output_channel = dims[1] // upscale_factor // upscale_factor
+        after_view = sym_help._reshape_helper(g, self,
+                                              g.op("Constant", value_t=torch.tensor([-1, output_channel,
+                                                                                     upscale_factor, upscale_factor,
+                                                                                     dims[2], dims[3]])),
+                                              allowzero=0)
+        after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
+        return sym_help._reshape_helper(g, after_transpose,
+                                        g.op("Constant", value_t=torch.tensor([-1, output_channel,
+                                                                               dims[2] * upscale_factor,
+                                                                               dims[3] * upscale_factor])),
+                                        allowzero=0)
+
+
+@parse_args("v", "i")
+def pixel_unshuffle(g, self, downscale_factor):
+    dims = sym_help._get_tensor_sizes(self)
+    if len(dims) != 4:
+        return _unimplemented("pixel_unshuffle", "only support 4d input")
+    if any(i is None for i in dims[1:]):
+        # For dynamic input shapes, two reshapes are performed
+        reshape_h = sym_help._reshape_helper(g, sym_help._unsqueeze_helper(g, self, [3]),
+                                             g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])),
+                                             allowzero=0)
+        reshape_w = sym_help._reshape_helper(g, reshape_h,
+                                             g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])),
+                                             allowzero=0)
+        after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4])
+        final_reshape = sym_help._reshape_helper(g, after_transpose,
+                                                 g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])),
+                                                 allowzero=0)
+        return sym_help._squeeze_helper(g, final_reshape, [2, 3])
+    else:
+        output_channel = dims[1] * downscale_factor * downscale_factor
+        after_view = sym_help._reshape_helper(g, self,
+                                              g.op("Constant", value_t=torch.tensor([-1, dims[1],
+                                                                                     dims[2] // downscale_factor,
+                                                                                     downscale_factor,
+                                                                                     dims[3] // downscale_factor,
+                                                                                     downscale_factor])),
+                                              allowzero=0)
+        after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4])
+        return sym_help._reshape_helper(g, after_transpose,
+                                        g.op("Constant", value_t=torch.tensor([-1, output_channel,
+                                                                               dims[2] // downscale_factor,
+                                                                               dims[3] // downscale_factor])),
+                                        allowzero=0)
 
 
 def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
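The dynamic-shape branches above cannot bake H and W into the Reshape targets, so they lean on two ONNX Reshape conventions: a 0 entry keeps the corresponding input dimension, and a single -1 entry is inferred from the element count. That is why each spatial dimension is merged in its own reshape, with placeholder 1-dims squeezed away at the end. A rough PyTorch mirror of that op sequence for pixel_shuffle, checked against the eager op (illustrative only, not the committed code):

import torch
import torch.nn.functional as F

def pixel_shuffle_dynamic_style(x, r):
    # Mirrors the Unsqueeze -> Reshape -> Transpose -> Reshape -> Reshape -> Squeeze
    # sequence emitted by the dynamic-shape branch; shapes shown for input (N, C, H, W).
    n = x.shape[0]
    x = x.unsqueeze(2).unsqueeze(3)                              # (N, C, 1, 1, H, W)
    x = x.reshape(n, -1, r, r, x.shape[-2], x.shape[-1])         # (N, C/r^2, r, r, H, W)
    x = x.permute(0, 1, 4, 2, 5, 3)                              # (N, C/r^2, H, r, W, r)
    x = x.reshape(n, x.shape[1], -1, 1, x.shape[4], x.shape[5])  # (N, C/r^2, H*r, 1, W, r)
    x = x.reshape(*x.shape[:4], -1, 1)                           # (N, C/r^2, H*r, 1, W*r, 1)
    return x.squeeze(5).squeeze(3)                               # (N, C/r^2, H*r, W*r)

x = torch.randn(3, 16, 5, 7)
assert torch.equal(pixel_shuffle_dynamic_style(x, 2), F.pixel_shuffle(x, 2))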
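Those 0 and -1 entries in the Constant shape tensors rely on ONNX Reshape semantics: a 0 copies the corresponding input dimension (the behaviour _reshape_helper preserves by passing allowzero=0), and a single -1 is inferred at runtime. A minimal hand-built model demonstrating this, assuming compatible onnx and onnxruntime installs (graph and tensor names, and the opset 13 choice, are arbitrary; the 0-copies rule holds for earlier Reshape versions as well):

import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper

# Reshape target [0, -1, 2, 2]: dim 0 is copied from the runtime input,
# dim 1 is inferred, the last two are fixed.
shape = helper.make_tensor("shape", TensorProto.INT64, [4], [0, -1, 2, 2])
node = helper.make_node("Reshape", ["x", "shape"], ["y"])
graph = helper.make_graph(
    [node], "reshape_zero_demo",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [None, None, None, None])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [None, None, None, None])],
    initializer=[shape])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
for batch in (2, 5):  # the copied dimension can change between runs
    x = np.random.randn(batch, 16, 1, 1).astype(np.float32)
    print(sess.run(None, {"x": x})[0].shape)  # (2, 4, 2, 2), then (5, 4, 2, 2)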