mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix to support no-batch-dim inputs in ConvTransposeNd._output_padding
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76151 Approved by: https://github.com/albanD
This commit is contained in:
parent
ea8a0184b7
commit
041e6e750a
|
|
@@ -13833,6 +13833,17 @@ class TestNNDeviceType(NNTestCase):
|
|||
gradcheck(lambda x, y: F.conv3d(x, y, padding='valid'), (x, y), check_forward_ad=check_forward_ad)
|
||||
gradgradcheck(lambda x, y: F.conv3d(x, y, padding='valid'), (x, y), check_fwd_over_rev=check_forward_ad)
|
||||
|
||||
@parametrize_test("N", range(2, 4), name_fn=lambda N: 'ConvTranspose{}d'.format(N))
|
||||
def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N):
|
||||
# For inputs with no batch dim, verify output is the correct shape when output_size is set.
|
||||
# See https://github.com/pytorch/pytorch/issues/75889
|
||||
inp = torch.randn((1, 15, 13) if N == 2 else (1, 15, 13, 13), device=device)
|
||||
output_size = (1, 240, 200) if N == 2 else (1, 240, 200, 200)
|
||||
ConvTransposeNd = getattr(nn, 'ConvTranspose{}d'.format(N))
|
||||
m = ConvTransposeNd(1, 1, kernel_size=16, stride=16, padding=7, bias=False, device=device)
|
||||
output = m(inp, output_size=output_size)
|
||||
self.assertEqual(output.shape, output_size)
|
||||
|
||||
@skipMeta
|
||||
@parametrize_test("input_shape,transposed,dilated,groups,layout,backend_expected", [
|
||||
# === slow ===
|
||||
|
|
|
|||
|
|
@@ -610,22 +610,24 @@ class _ConvTransposeNd(_ConvNd):
|
|||
# compatibility
|
||||
def _output_padding(self, input: Tensor, output_size: Optional[List[int]],
|
||||
stride: List[int], padding: List[int], kernel_size: List[int],
|
||||
dilation: Optional[List[int]] = None) -> List[int]:
|
||||
num_spatial_dims: int, dilation: Optional[List[int]] = None) -> List[int]:
|
||||
if output_size is None:
|
||||
ret = _single(self.output_padding) # converting to list if was not already
|
||||
else:
|
||||
k = input.dim() - 2
|
||||
if len(output_size) == k + 2:
|
||||
output_size = output_size[2:]
|
||||
if len(output_size) != k:
|
||||
has_batch_dim = input.dim() == num_spatial_dims + 2
|
||||
num_non_spatial_dims = 2 if has_batch_dim else 1
|
||||
if len(output_size) == num_non_spatial_dims + num_spatial_dims:
|
||||
output_size = output_size[num_non_spatial_dims:]
|
||||
if len(output_size) != num_spatial_dims:
|
||||
raise ValueError(
|
||||
"output_size must have {} or {} elements (got {})"
|
||||
.format(k, k + 2, len(output_size)))
|
||||
"ConvTranspose{}D: for {}D input, output_size must have {} or {} elements (got {})"
|
||||
.format(num_spatial_dims, input.dim(), num_spatial_dims,
|
||||
num_non_spatial_dims + num_spatial_dims, len(output_size)))
|
||||
|
||||
min_sizes = torch.jit.annotate(List[int], [])
|
||||
max_sizes = torch.jit.annotate(List[int], [])
|
||||
for d in range(k):
|
||||
dim_size = ((input.size(d + 2) - 1) * stride[d] -
|
||||
for d in range(num_spatial_dims):
|
||||
dim_size = ((input.size(d + num_non_spatial_dims) - 1) * stride[d] -
|
||||
2 * padding[d] +
|
||||
(dilation[d] if dilation is not None else 1) * (kernel_size[d] - 1) + 1)
|
||||
min_sizes.append(dim_size)
|
||||
|
|
@@ -642,7 +644,7 @@ class _ConvTransposeNd(_ConvNd):
|
|||
output_size, min_sizes, max_sizes, input.size()[2:]))
|
||||
|
||||
res = torch.jit.annotate(List[int], [])
|
||||
for d in range(k):
|
||||
for d in range(num_spatial_dims):
|
||||
res.append(output_size[d] - min_sizes[d])
|
||||
|
||||
ret = res
|
||||
|
|
@@ -769,8 +771,10 @@ class ConvTranspose1d(_ConvTransposeNd):
|
|||
assert isinstance(self.padding, tuple)
|
||||
# One cannot replace List by Tuple or Sequence in "_output_padding" because
|
||||
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
|
||||
num_spatial_dims = 1
|
||||
output_padding = self._output_padding(
|
||||
input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) # type: ignore[arg-type]
|
||||
input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type]
|
||||
num_spatial_dims, self.dilation) # type: ignore[arg-type]
|
||||
return F.conv_transpose1d(
|
||||
input, self.weight, self.bias, self.stride, self.padding,
|
||||
output_padding, self.groups, self.dilation)
|
||||
|
|
@@ -919,8 +923,10 @@ class ConvTranspose2d(_ConvTransposeNd):
|
|||
assert isinstance(self.padding, tuple)
|
||||
# One cannot replace List by Tuple or Sequence in "_output_padding" because
|
||||
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
|
||||
num_spatial_dims = 2
|
||||
output_padding = self._output_padding(
|
||||
input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) # type: ignore[arg-type]
|
||||
input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type]
|
||||
num_spatial_dims, self.dilation) # type: ignore[arg-type]
|
||||
|
||||
return F.conv_transpose2d(
|
||||
input, self.weight, self.bias, self.stride, self.padding,
|
||||
|
|
@@ -1067,8 +1073,10 @@ class ConvTranspose3d(_ConvTransposeNd):
|
|||
assert isinstance(self.padding, tuple)
|
||||
# One cannot replace List by Tuple or Sequence in "_output_padding" because
|
||||
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
|
||||
num_spatial_dims = 3
|
||||
output_padding = self._output_padding(
|
||||
input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) # type: ignore[arg-type]
|
||||
input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type]
|
||||
num_spatial_dims, self.dilation) # type: ignore[arg-type]
|
||||
|
||||
return F.conv_transpose3d(
|
||||
input, self.weight, self.bias, self.stride, self.padding,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user