mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Add pixel_unshuffle support in opset 9
Current we are unable to utilize ONNX's SpaceToDepth operator due to the lack of the mode_s attribute, hence we add an alternative symbolic in opset 9 to support pixel_unshuffle - Adds support for pixel_unshuffle in opset9 - Adds support for dynamic input shapes for pixel_shuffle and pixel_unshuffle Pull Request resolved: https://github.com/pytorch/pytorch/pull/72449
This commit is contained in:
parent
0d66748948
commit
671c8a459a
|
|
@ -5903,7 +5903,24 @@ class TestONNXRuntime(unittest.TestCase):
|
||||||
return torch.pixel_shuffle(x, upscale_factor=2)
|
return torch.pixel_shuffle(x, upscale_factor=2)
|
||||||
|
|
||||||
x = torch.randn(2, 16, 4, 3, requires_grad=True)
|
x = torch.randn(2, 16, 4, 3, requires_grad=True)
|
||||||
|
y = torch.randn(4, 32, 8, 4, requires_grad=True)
|
||||||
self.run_test(PixelShuffle(), x)
|
self.run_test(PixelShuffle(), x)
|
||||||
|
self.run_test(PixelShuffle(), x, input_names=["x"],
|
||||||
|
dynamic_axes={"x": [0, 1, 2, 3]},
|
||||||
|
test_with_inputs=[y])
|
||||||
|
|
||||||
|
@skipIfUnsupportedMinOpsetVersion(9)
|
||||||
|
def test_pixel_unshuffle(self):
|
||||||
|
class PixelUnshuffle(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.pixel_unshuffle(x, downscale_factor=2)
|
||||||
|
|
||||||
|
x = torch.randn(2, 16, 4, 6, requires_grad=True)
|
||||||
|
y = torch.randn(4, 32, 8, 4, requires_grad=True)
|
||||||
|
self.run_test(PixelUnshuffle(), x)
|
||||||
|
self.run_test(PixelUnshuffle(), x, input_names=["x"],
|
||||||
|
dynamic_axes={"x": [0, 1, 2, 3]},
|
||||||
|
test_with_inputs=[y])
|
||||||
|
|
||||||
@skipIfUnsupportedMinOpsetVersion(9)
|
@skipIfUnsupportedMinOpsetVersion(9)
|
||||||
def test_reciprocal(self):
|
def test_reciprocal(self):
|
||||||
|
|
|
||||||
|
|
@ -2157,20 +2157,69 @@ def pixel_shuffle(g, self, upscale_factor):
|
||||||
dims = sym_help._get_tensor_sizes(self)
|
dims = sym_help._get_tensor_sizes(self)
|
||||||
if len(dims) != 4:
|
if len(dims) != 4:
|
||||||
return _unimplemented("pixel_shuffle", "only support 4d input")
|
return _unimplemented("pixel_shuffle", "only support 4d input")
|
||||||
if any([i is None for i in dims[1:]]):
|
if any(i is None for i in dims[1:]):
|
||||||
return _unimplemented("pixel_shuffle", "only support static input shape, except for batch size")
|
after_view = sym_help._reshape_helper(g, sym_help._unsqueeze_helper(g, self, [2, 3]),
|
||||||
output_channel = dims[1] // upscale_factor // upscale_factor
|
g.op("Constant", value_t=torch.tensor([0, -1,
|
||||||
after_view = sym_help._reshape_helper(g, self,
|
upscale_factor, upscale_factor,
|
||||||
g.op("Constant", value_t=torch.tensor([-1, output_channel,
|
0, 0])),
|
||||||
upscale_factor, upscale_factor,
|
allowzero=0)
|
||||||
dims[2], dims[3]])),
|
after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
|
||||||
allowzero=0)
|
# For dynamic input shapes, two reshapes are performed
|
||||||
after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
|
reshape_h = sym_help._reshape_helper(g, after_transpose,
|
||||||
return sym_help._reshape_helper(g, after_transpose,
|
g.op("Constant", value_t=torch.tensor([0, 0, -1, 1, 0, 0])),
|
||||||
g.op("Constant", value_t=torch.tensor([-1, output_channel,
|
allowzero=0)
|
||||||
dims[2] * upscale_factor,
|
reshape_w = sym_help._reshape_helper(g, reshape_h,
|
||||||
dims[3] * upscale_factor])),
|
g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, 1])),
|
||||||
allowzero=0)
|
allowzero=0)
|
||||||
|
return sym_help._squeeze_helper(g, reshape_w, [3, 5])
|
||||||
|
else:
|
||||||
|
output_channel = dims[1] // upscale_factor // upscale_factor
|
||||||
|
after_view = sym_help._reshape_helper(g, self,
|
||||||
|
g.op("Constant", value_t=torch.tensor([-1, output_channel,
|
||||||
|
upscale_factor, upscale_factor,
|
||||||
|
dims[2], dims[3]])),
|
||||||
|
allowzero=0)
|
||||||
|
after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
|
||||||
|
return sym_help._reshape_helper(g, after_transpose,
|
||||||
|
g.op("Constant", value_t=torch.tensor([-1, output_channel,
|
||||||
|
dims[2] * upscale_factor,
|
||||||
|
dims[3] * upscale_factor])),
|
||||||
|
allowzero=0)
|
||||||
|
|
||||||
|
|
||||||
|
@parse_args("v", "i")
|
||||||
|
def pixel_unshuffle(g, self, downscale_factor):
|
||||||
|
dims = sym_help._get_tensor_sizes(self)
|
||||||
|
if len(dims) != 4:
|
||||||
|
return _unimplemented("pixel_shuffle", "only support 4d input")
|
||||||
|
if any(i is None for i in dims[1:]):
|
||||||
|
# For dynamic input shapes, two reshapes are performed
|
||||||
|
reshape_h = sym_help._reshape_helper(g, sym_help._unsqueeze_helper(g, self, [3]),
|
||||||
|
g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])),
|
||||||
|
allowzero=0)
|
||||||
|
reshape_w = sym_help._reshape_helper(g, reshape_h,
|
||||||
|
g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])),
|
||||||
|
allowzero=0)
|
||||||
|
after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4])
|
||||||
|
final_reshape = sym_help._reshape_helper(g, after_transpose,
|
||||||
|
g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])),
|
||||||
|
allowzero=0)
|
||||||
|
return sym_help._squeeze_helper(g, final_reshape, [2, 3])
|
||||||
|
else:
|
||||||
|
output_channel = dims[1] * downscale_factor * downscale_factor
|
||||||
|
after_view = sym_help._reshape_helper(g, self,
|
||||||
|
g.op("Constant", value_t=torch.tensor([-1, dims[1],
|
||||||
|
dims[2] // downscale_factor,
|
||||||
|
downscale_factor,
|
||||||
|
dims[3] // downscale_factor,
|
||||||
|
downscale_factor])),
|
||||||
|
allowzero=0)
|
||||||
|
after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4])
|
||||||
|
return sym_help._reshape_helper(g, after_transpose,
|
||||||
|
g.op("Constant", value_t=torch.tensor([-1, output_channel,
|
||||||
|
dims[2] // downscale_factor,
|
||||||
|
dims[3] // downscale_factor])),
|
||||||
|
allowzero=0)
|
||||||
|
|
||||||
|
|
||||||
def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
|
def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user