Fix to support no-batch-dim inputs in ConvTransposeNd._output_padding

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76151

Approved by: https://github.com/albanD
This commit is contained in:
Joel Benjamin Schlosser 2022-04-22 10:23:18 -04:00 committed by PyTorch MergeBot
parent ea8a0184b7
commit 041e6e750a
2 changed files with 32 additions and 13 deletions

View File

@ -13833,6 +13833,17 @@ class TestNNDeviceType(NNTestCase):
gradcheck(lambda x, y: F.conv3d(x, y, padding='valid'), (x, y), check_forward_ad=check_forward_ad)
gradgradcheck(lambda x, y: F.conv3d(x, y, padding='valid'), (x, y), check_fwd_over_rev=check_forward_ad)
@parametrize_test("N", range(2, 4), name_fn=lambda N: 'ConvTranspose{}d'.format(N))
def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N):
# For inputs with no batch dim, verify output is the correct shape when output_size is set.
# See https://github.com/pytorch/pytorch/issues/75889
inp = torch.randn((1, 15, 13) if N == 2 else (1, 15, 13, 13), device=device)
output_size = (1, 240, 200) if N == 2 else (1, 240, 200, 200)
ConvTransposeNd = getattr(nn, 'ConvTranspose{}d'.format(N))
m = ConvTransposeNd(1, 1, kernel_size=16, stride=16, padding=7, bias=False, device=device)
output = m(inp, output_size=output_size)
self.assertEqual(output.shape, output_size)
@skipMeta
@parametrize_test("input_shape,transposed,dilated,groups,layout,backend_expected", [
# === slow ===

View File

@ -610,22 +610,24 @@ class _ConvTransposeNd(_ConvNd):
# compatibility
def _output_padding(self, input: Tensor, output_size: Optional[List[int]],
stride: List[int], padding: List[int], kernel_size: List[int],
dilation: Optional[List[int]] = None) -> List[int]:
num_spatial_dims: int, dilation: Optional[List[int]] = None) -> List[int]:
if output_size is None:
ret = _single(self.output_padding) # converting to list if was not already
else:
k = input.dim() - 2
if len(output_size) == k + 2:
output_size = output_size[2:]
if len(output_size) != k:
has_batch_dim = input.dim() == num_spatial_dims + 2
num_non_spatial_dims = 2 if has_batch_dim else 1
if len(output_size) == num_non_spatial_dims + num_spatial_dims:
output_size = output_size[num_non_spatial_dims:]
if len(output_size) != num_spatial_dims:
raise ValueError(
"output_size must have {} or {} elements (got {})"
.format(k, k + 2, len(output_size)))
"ConvTranspose{}D: for {}D input, output_size must have {} or {} elements (got {})"
.format(num_spatial_dims, input.dim(), num_spatial_dims,
num_non_spatial_dims + num_spatial_dims, len(output_size)))
min_sizes = torch.jit.annotate(List[int], [])
max_sizes = torch.jit.annotate(List[int], [])
for d in range(k):
dim_size = ((input.size(d + 2) - 1) * stride[d] -
for d in range(num_spatial_dims):
dim_size = ((input.size(d + num_non_spatial_dims) - 1) * stride[d] -
2 * padding[d] +
(dilation[d] if dilation is not None else 1) * (kernel_size[d] - 1) + 1)
min_sizes.append(dim_size)
@ -642,7 +644,7 @@ class _ConvTransposeNd(_ConvNd):
output_size, min_sizes, max_sizes, input.size()[2:]))
res = torch.jit.annotate(List[int], [])
for d in range(k):
for d in range(num_spatial_dims):
res.append(output_size[d] - min_sizes[d])
ret = res
@ -769,8 +771,10 @@ class ConvTranspose1d(_ConvTransposeNd):
assert isinstance(self.padding, tuple)
# One cannot replace List by Tuple or Sequence in "_output_padding" because
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
num_spatial_dims = 1
output_padding = self._output_padding(
input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) # type: ignore[arg-type]
input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type]
num_spatial_dims, self.dilation) # type: ignore[arg-type]
return F.conv_transpose1d(
input, self.weight, self.bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
@ -919,8 +923,10 @@ class ConvTranspose2d(_ConvTransposeNd):
assert isinstance(self.padding, tuple)
# One cannot replace List by Tuple or Sequence in "_output_padding" because
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
num_spatial_dims = 2
output_padding = self._output_padding(
input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) # type: ignore[arg-type]
input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type]
num_spatial_dims, self.dilation) # type: ignore[arg-type]
return F.conv_transpose2d(
input, self.weight, self.bias, self.stride, self.padding,
@ -1067,8 +1073,10 @@ class ConvTranspose3d(_ConvTransposeNd):
assert isinstance(self.padding, tuple)
# One cannot replace List by Tuple or Sequence in "_output_padding" because
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
num_spatial_dims = 3
output_padding = self._output_padding(
input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) # type: ignore[arg-type]
input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type]
num_spatial_dims, self.dilation) # type: ignore[arg-type]
return F.conv_transpose3d(
input, self.weight, self.bias, self.stride, self.padding,