mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add ConvTranspose1d module (#449)
This commit is contained in:
parent
24a2f2e3a0
commit
3a07228509
|
|
@ -21,10 +21,13 @@ Convolution Layers
|
|||
.. autoclass:: Conv2d
|
||||
:members:
|
||||
|
||||
.. autoclass:: ConvTranspose2d
|
||||
.. autoclass:: Conv3d
|
||||
:members:
|
||||
|
||||
.. autoclass:: Conv3d
|
||||
.. autoclass:: ConvTranspose1d
|
||||
:members:
|
||||
|
||||
.. autoclass:: ConvTranspose2d
|
||||
:members:
|
||||
|
||||
.. autoclass:: ConvTranspose3d
|
||||
|
|
|
|||
|
|
@ -962,7 +962,7 @@ class TestNN(NNTestCase):
|
|||
for bidirectional in (False, True):
|
||||
for dropout in (0, 1): # Because of dropout randomness, can only compare 0 and 1
|
||||
for batch_first in (False, True):
|
||||
num_directions = 2 if bidirectional else 1
|
||||
num_directions = 2 if bidirectional else 1
|
||||
if batch_first:
|
||||
input_val = torch.randn(batch, seq_length, input_size)
|
||||
else:
|
||||
|
|
@ -1211,6 +1211,19 @@ new_module_tests = [
|
|||
input_size=(2, 4, 6),
|
||||
cudnn=True,
|
||||
),
|
||||
dict(
|
||||
module_name='ConvTranspose1d',
|
||||
constructor_args=(3, 4, 3, (3,), 1, (1,)),
|
||||
cudnn=True,
|
||||
input_size=(1, 3, 7)
|
||||
),
|
||||
dict(
|
||||
module_name='ConvTranspose1d',
|
||||
constructor_args=(3, 4, 3, 2, 1, 1, 1, False),
|
||||
input_size=(1, 3, 6),
|
||||
cudnn=True,
|
||||
desc='no_bias'
|
||||
),
|
||||
dict(
|
||||
module_name='MaxPool1d',
|
||||
constructor_args=(4,),
|
||||
|
|
|
|||
|
|
@ -124,6 +124,13 @@ def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1,
|
|||
return f(input, weight, bias) if bias is not None else f(input, weight)
|
||||
|
||||
|
||||
def conv_transpose1d(input, weight, bias=None, stride=1, padding=0,
                     output_padding=0, groups=1):
    """Applies a 1D transposed convolution to ``input`` using ``weight``.

    ``stride``, ``padding`` and ``output_padding`` are each normalised with
    ``_single`` before being handed to ``ConvNd``. Dilation is not
    configurable here and is always ``_single(1)``.

    Args:
        input: input passed to the underlying ``ConvNd`` function
        weight: kernel passed to the underlying ``ConvNd`` function
        bias: optional bias; when ``None`` the function is invoked with
            ``input`` and ``weight`` only
        stride: stride of the operation. Default: 1
        padding: padding of the operation. Default: 0
        output_padding: output padding of the operation. Default: 0
        groups: forwarded unchanged to ``ConvNd``. Default: 1

    Returns:
        Whatever the constructed ``ConvNd`` function returns for the inputs.
    """
    stride_1d = _single(stride)
    padding_1d = _single(padding)
    dilation_1d = _single(1)
    output_padding_1d = _single(output_padding)

    # NOTE(review): the positional ``True`` presumably selects the transposed
    # variant of ConvNd -- confirm against the ConvNd signature.
    conv_fn = ConvNd(stride_1d, padding_1d, dilation_1d, True,
                     output_padding_1d, groups)

    if bias is None:
        return conv_fn(input, weight)
    return conv_fn(input, weight, bias)
|
||||
|
||||
|
||||
def conv_transpose2d(input, weight, bias=None, stride=1, padding=0,
|
||||
output_padding=0, groups=1):
|
||||
"""Applies a 2D transposed convolution operator over an input image
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from .module import Module
|
||||
from .linear import Linear
|
||||
from .conv import Conv1d, Conv2d, ConvTranspose2d, Conv3d, ConvTranspose3d
|
||||
from .conv import Conv1d, Conv2d, Conv3d, \
|
||||
ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
|
||||
from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \
|
||||
Softmax, Softmax2d, LogSoftmax, ELU, Hardshrink, LeakyReLU, LogSigmoid, \
|
||||
Softplus, Softshrink, PReLU, Softsign, Softmin, Tanhshrink, RReLU
|
||||
|
|
|
|||
|
|
@ -69,45 +69,45 @@ class Conv1d(_ConvNd):
|
|||
and output :math:`(N, C_{out}, L_{out})` can be precisely described as:
|
||||
|
||||
.. math::
|
||||
|
||||
|
||||
\begin{array}{ll}
|
||||
out(N_i, C_{out_j}) = bias(C_{out_j})
|
||||
out(N_i, C_{out_j}) = bias(C_{out_j})
|
||||
+ \sum_{{k}=0}^{C_{in}-1} weight(C_{out_j}, k) \star input(N_i, k)
|
||||
\end{array}
|
||||
|
||||
where :math:`\star` is the valid `cross-correlation`_ operator
|
||||
|
||||
where :math:`\star` is the valid `cross-correlation`_ operator
|
||||
|
||||
| :attr:`stride` controls the stride for the cross-correlation.
|
||||
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
|
||||
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
|
||||
for :attr:`padding` number of points
|
||||
| :attr:`dilation` controls the spacing between the kernel points. It is harder to describe,
|
||||
| :attr:`dilation` controls the spacing between the kernel points. It is harder to describe,
|
||||
but this `link`_ has a nice visualization of what :attr:`dilation` does.
|
||||
| :attr:`groups` controls the connections between inputs and outputs.
|
||||
| At groups=1, all inputs are convolved to all outputs.
|
||||
| At groups=2, the operation becomes equivalent to having two conv layers
|
||||
side by side, each seeing half the input channels,
|
||||
| :attr:`groups` controls the connections between inputs and outputs.
|
||||
| At groups=1, all inputs are convolved to all outputs.
|
||||
| At groups=2, the operation becomes equivalent to having two conv layers
|
||||
side by side, each seeing half the input channels,
|
||||
and producing half the output channels, and both subsequently concatenated.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
|
||||
Depending on the size of your kernel, several (of the last)
|
||||
columns of the input might be lost, because it is a valid `cross-correlation`_,
|
||||
columns of the input might be lost, because it is a valid `cross-correlation`_,
|
||||
and not a full `cross-correlation`_.
|
||||
It is up to the user to add proper padding.
|
||||
|
||||
|
||||
Args:
|
||||
in_channels (long): The number of expected input channels in the image given as input
|
||||
out_channels (long): The number of output channels the convolution layer will produce
|
||||
kernel_size (long): the size of the convolving kernel.
|
||||
stride (long, optional): the stride of the convolving kernel.
|
||||
padding (long, optional): zero-padding to be added to the input on both sides
|
||||
dilation (long, optional): controls the kernel striding
|
||||
groups (long, optional): controls the number of blocked connections from input to output
|
||||
bias (bool, optional): If True, adds a learnable bias to the output before convolving
|
||||
|
||||
in_channels (int): Number of channels in the input image
|
||||
out_channels (int): Number of channels produced by the convolution
|
||||
kernel_size (int or tuple): Size of the convolving kernel
|
||||
stride (int or tuple, optional): Stride of the convolution
|
||||
padding (int or tuple, optional): Zero-padding added to both sides of the input
|
||||
dilation (int or tuple, optional): Spacing between kernel elements
|
||||
groups (int, optional): Number of blocked connections from input channels to output channels
|
||||
bias (bool, optional): If True, adds a learnable bias to the output
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, C_{in}, L_{in})`
|
||||
- Output: :math:`(N, C_{out}, L_{out})` where
|
||||
- Output: :math:`(N, C_{out}, L_{out})` where
|
||||
:math:`L_{out} = floor((L_{in} + 2 * padding - dilation * (kernel\_size - 1) - 1) / stride + 1)`
|
||||
|
||||
Attributes:
|
||||
|
|
@ -115,14 +115,14 @@ class Conv1d(_ConvNd):
|
|||
bias (Tensor): the learnable bias of the module of shape (out_channels)
|
||||
|
||||
Examples::
|
||||
|
||||
|
||||
>>> m = nn.Conv1d(16, 33, 3, stride=2)
|
||||
>>> input = autograd.Variable(torch.randn(20, 16, 50))
|
||||
>>> output = m(input)
|
||||
|
||||
|
||||
.. _cross-correlation:
|
||||
https://en.wikipedia.org/wiki/Cross-correlation
|
||||
|
||||
|
||||
.. _link:
|
||||
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
|
||||
"""
|
||||
|
|
@ -150,51 +150,51 @@ class Conv2d(_ConvNd):
|
|||
and output :math:`(N, C_{out}, H_{out}, W_{out})` can be precisely described as:
|
||||
|
||||
.. math::
|
||||
|
||||
|
||||
\begin{array}{ll}
|
||||
out(N_i, C_{out_j}) = bias(C_{out_j})
|
||||
out(N_i, C_{out_j}) = bias(C_{out_j})
|
||||
+ \sum_{{k}=0}^{C_{in}-1} weight(C_{out_j}, k) \star input(N_i, k)
|
||||
\end{array}
|
||||
|
||||
where :math:`\star` is the valid 2D `cross-correlation`_ operator
|
||||
|
||||
where :math:`\star` is the valid 2D `cross-correlation`_ operator
|
||||
|
||||
| :attr:`stride` controls the stride for the cross-correlation.
|
||||
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
|
||||
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
|
||||
for :attr:`padding` number of points
|
||||
| :attr:`dilation` controls the spacing between the kernel points. It is harder to describe,
|
||||
| :attr:`dilation` controls the spacing between the kernel points. It is harder to describe,
|
||||
but this `link`_ has a nice visualization of what :attr:`dilation` does.
|
||||
| :attr:`groups` controls the connections between inputs and outputs.
|
||||
| At groups=1, all inputs are convolved to all outputs.
|
||||
| At groups=2, the operation becomes equivalent to having two conv layers
|
||||
side by side, each seeing half the input channels,
|
||||
| :attr:`groups` controls the connections between inputs and outputs.
|
||||
| At groups=1, all inputs are convolved to all outputs.
|
||||
| At groups=2, the operation becomes equivalent to having two conv layers
|
||||
side by side, each seeing half the input channels,
|
||||
and producing half the output channels, and both subsequently concatenated.
|
||||
|
||||
|
||||
The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:
|
||||
|
||||
- a single ``int`` -- in which case the same value is used for the height and width dimension
|
||||
- a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
|
||||
- a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
|
||||
and the second `int` for the width dimension
|
||||
|
||||
.. note::
|
||||
|
||||
|
||||
Depending on the size of your kernel, several (of the last)
|
||||
columns of the input might be lost, because it is a valid `cross-correlation`_,
|
||||
columns of the input might be lost, because it is a valid `cross-correlation`_,
|
||||
and not a full `cross-correlation`_.
|
||||
It is up to the user to add proper padding.
|
||||
|
||||
|
||||
Args:
|
||||
in_channels (long): The number of expected input channels in the image given as input
|
||||
out_channels (long): The number of output channels the convolution layer will produce
|
||||
kernel_size (long or tuple): the size of the convolving kernel.
|
||||
stride (long or tuple, optional): the stride of the convolving kernel.
|
||||
padding (long or tuple, optional): zero-padding to be added to the input on both sides
|
||||
dilation (long or tuple, optional): controls the kernel striding
|
||||
groups (long, optional): controls the number of blocked connections from input to output
|
||||
bias (bool, optional): If True, adds a learnable bias to the output before convolving
|
||||
|
||||
in_channels (int): Number of channels in the input image
|
||||
out_channels (int): Number of channels produced by the convolution
|
||||
kernel_size (int or tuple): Size of the convolving kernel
|
||||
stride (int or tuple, optional): Stride of the convolution
|
||||
padding (int or tuple, optional): Zero-padding added to both sides of the input
|
||||
dilation (int or tuple, optional): Spacing between kernel elements
|
||||
groups (int, optional): Number of blocked connections from input channels to output channels
|
||||
bias (bool, optional): If True, adds a learnable bias to the output
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||
- Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
|
||||
- Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
|
||||
:math:`H_{out} = floor((H_{in} + 2 * padding[0] - dilation[0] * (kernel\_size[0] - 1) - 1) / stride[0] + 1)`
|
||||
:math:`W_{out} = floor((W_{in} + 2 * padding[1] - dilation[1] * (kernel\_size[1] - 1) - 1) / stride[1] + 1)`
|
||||
|
||||
|
|
@ -203,7 +203,7 @@ class Conv2d(_ConvNd):
|
|||
bias (Tensor): the learnable bias of the module of shape (out_channels)
|
||||
|
||||
Examples::
|
||||
|
||||
|
||||
>>> # With square kernels and equal stride
|
||||
>>> m = nn.Conv2d(16, 33, 3, stride=2)
|
||||
>>> # non-square kernels and unequal stride and with padding
|
||||
|
|
@ -243,51 +243,51 @@ class Conv3d(_ConvNd):
|
|||
and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as:
|
||||
|
||||
.. math::
|
||||
|
||||
|
||||
\begin{array}{ll}
|
||||
out(N_i, C_{out_j}) = bias(C_{out_j})
|
||||
out(N_i, C_{out_j}) = bias(C_{out_j})
|
||||
+ \sum_{{k}=0}^{C_{in}-1} weight(C_{out_j}, k) \star input(N_i, k)
|
||||
\end{array}
|
||||
|
||||
where :math:`\star` is the valid 3D `cross-correlation`_ operator
|
||||
|
||||
where :math:`\star` is the valid 3D `cross-correlation`_ operator
|
||||
|
||||
| :attr:`stride` controls the stride for the cross-correlation.
|
||||
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
|
||||
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
|
||||
for :attr:`padding` number of points
|
||||
| :attr:`dilation` controls the spacing between the kernel points. It is harder to describe,
|
||||
| :attr:`dilation` controls the spacing between the kernel points. It is harder to describe,
|
||||
but this `link`_ has a nice visualization of what :attr:`dilation` does.
|
||||
| :attr:`groups` controls the connections between inputs and outputs.
|
||||
| At groups=1, all inputs are convolved to all outputs.
|
||||
| At groups=2, the operation becomes equivalent to having two conv layers
|
||||
side by side, each seeing half the input channels,
|
||||
| :attr:`groups` controls the connections between inputs and outputs.
|
||||
| At groups=1, all inputs are convolved to all outputs.
|
||||
| At groups=2, the operation becomes equivalent to having two conv layers
|
||||
side by side, each seeing half the input channels,
|
||||
and producing half the output channels, and both subsequently concatenated.
|
||||
|
||||
|
||||
The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:
|
||||
|
||||
- a single ``int`` -- in which case the same value is used for the depth, height and width dimensions
|
||||
- a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
|
||||
- a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
|
||||
the second `int` for the height dimension and the third `int` for the width dimension
|
||||
|
||||
.. note::
|
||||
|
||||
|
||||
Depending on the size of your kernel, several (of the last)
|
||||
columns of the input might be lost, because it is a valid `cross-correlation`_,
|
||||
columns of the input might be lost, because it is a valid `cross-correlation`_,
|
||||
and not a full `cross-correlation`_.
|
||||
It is up to the user to add proper padding.
|
||||
|
||||
|
||||
Args:
|
||||
in_channels (long): The number of expected input channels in the image given as input
|
||||
out_channels (long): The number of output channels the convolution layer will produce
|
||||
kernel_size (long or tuple): the size of the convolving kernel.
|
||||
stride (long or tuple, optional): the stride of the convolving kernel.
|
||||
padding (long or tuple, optional): zero-padding to be added to the input on both sides
|
||||
dilation (long or tuple, optional): controls the kernel striding
|
||||
groups (long, optional): controls the number of blocked connections from input to output
|
||||
bias (bool, optional): If True, adds a learnable bias to the output before convolving
|
||||
|
||||
in_channels (int): Number of channels in the input image
|
||||
out_channels (int): Number of channels produced by the convolution
|
||||
kernel_size (int or tuple): Size of the convolving kernel
|
||||
stride (int or tuple, optional): Stride of the convolution
|
||||
padding (int or tuple, optional): Zero-padding added to both sides of the input
|
||||
dilation (int or tuple, optional): Spacing between kernel elements
|
||||
groups (int, optional): Number of blocked connections from input channels to output channels
|
||||
bias (bool, optional): If True, adds a learnable bias to the output
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
|
||||
- Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
|
||||
- Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
|
||||
:math:`D_{out} = floor((D_{in} + 2 * padding[0] - dilation[0] * (kernel\_size[0] - 1) - 1) / stride[0] + 1)`
|
||||
:math:`H_{out} = floor((H_{in} + 2 * padding[1] - dilation[1] * (kernel\_size[1] - 1) - 1) / stride[1] + 1)`
|
||||
:math:`W_{out} = floor((W_{in} + 2 * padding[2] - dilation[2] * (kernel\_size[2] - 1) - 1) / stride[2] + 1)`
|
||||
|
|
@ -367,54 +367,102 @@ class _ConvTransposeMixin(object):
|
|||
return tuple([output_size[d] - min_sizes[d] for d in range(k)])
|
||||
|
||||
|
||||
class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
    r"""Applies a 1D transposed convolution operator over an input image
    composed of several input planes.

    This module can be seen as the gradient of Conv1d with respect to its input.

    Dilation is not configurable for this module: the constructor always
    uses a dilation of 1.

    .. note::

         Depending on the size of your kernel, several (of the last)
         columns of the input might be lost, because it is a valid `cross-correlation`_,
         and not a full `cross-correlation`_.
         It is up to the user to add proper padding.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution
        padding (int or tuple, optional): Zero-padding added to both sides of the input
        output_padding (int or tuple, optional): Extra size added to the output length
            (the ``output_padding`` term of the shape formula below)
        groups (int, optional): Number of blocked connections from input channels to output channels
        bias (bool, optional): If True, adds a learnable bias to the output

    Shape:
        - Input: :math:`(N, C_{in}, L_{in})`
        - Output: :math:`(N, C_{out}, L_{out})` where
          :math:`L_{out} = (L_{in} - 1) * stride - 2 * padding + kernel\_size + output\_padding`

    Attributes:
        weight (Tensor): the learnable weights of the module of shape (in_channels, out_channels, kernel_size[0])
        bias (Tensor): the learnable bias of the module of shape (out_channels)
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True):
        # Normalise every spatial argument with ``_single`` so the base
        # class receives them in one uniform form.
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        # Dilation is not exposed by this constructor; it is fixed at 1.
        dilation = _single(1)
        output_padding = _single(output_padding)
        # NOTE(review): the positional ``True`` presumably marks the module
        # as transposed -- confirm against ``_ConvNd.__init__``.
        super(ConvTranspose1d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias)

    def forward(self, input, output_size=None):
        """Runs the transposed convolution on ``input``.

        Args:
            input: input forwarded to ``F.conv_transpose1d``
            output_size (optional): forwarded to ``self._output_padding``,
                which determines the output padding used for this call
        """
        output_padding = self._output_padding(input, output_size)
        return F.conv_transpose1d(
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups)
|
||||
|
||||
|
||||
class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
|
||||
r"""Applies a 2D transposed convolution operator over an input image composed of several input
|
||||
planes.
|
||||
The transposed convolution operator multiplies each input value element-wise by a learnable kernel,
|
||||
and sums over the outputs from all input feature planes.
|
||||
|
||||
**This module can be seen as the exact reverse of Conv2d**
|
||||
r"""Applies a 2D transposed convolution operator over an input image
|
||||
composed of several input planes.
|
||||
|
||||
This module can be seen as the gradient of Conv2d with respect to its input.
|
||||
|
||||
| :attr:`stride` controls the stride for the cross-correlation.
|
||||
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
|
||||
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
|
||||
for :attr:`padding` number of points
|
||||
| If :attr:`padding` is non-zero, then the output is implicitly zero-padded on both sides
|
||||
| If :attr:`padding` is non-zero, then the output is implicitly zero-padded on both sides
|
||||
for :attr:`padding` number of points
|
||||
| :attr:`dilation` controls the spacing between the kernel points. It is harder to describe,
|
||||
| :attr:`dilation` controls the spacing between the kernel points. It is harder to describe,
|
||||
but this `link`_ has a nice visualization of what :attr:`dilation` does.
|
||||
| :attr:`groups` controls the connections between inputs and outputs.
|
||||
| At groups=1, all inputs are convolved to all outputs.
|
||||
| At groups=2, the operation becomes equivalent to having two conv layers
|
||||
side by side, each seeing half the input channels,
|
||||
| :attr:`groups` controls the connections between inputs and outputs.
|
||||
| At groups=1, all inputs are convolved to all outputs.
|
||||
| At groups=2, the operation becomes equivalent to having two conv layers
|
||||
side by side, each seeing half the input channels,
|
||||
and producing half the output channels, and both subsequently concatenated.
|
||||
|
||||
The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`,
|
||||
|
||||
The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`,
|
||||
:attr:`dilation` can either be:
|
||||
|
||||
- a single ``int`` -- in which case the same value is used for the height and width dimension
|
||||
- a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
|
||||
- a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
|
||||
and the second `int` for the width dimension
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
|
||||
Depending on the size of your kernel, several (of the last)
|
||||
columns of the input might be lost, because it is a valid `cross-correlation`_,
|
||||
columns of the input might be lost, because it is a valid `cross-correlation`_,
|
||||
and not a full `cross-correlation`_.
|
||||
It is up to the user to add proper padding.
|
||||
|
||||
|
||||
Args:
|
||||
in_channels (long): The number of expected input channels in the image given as input
|
||||
out_channels (long): The number of output channels the convolution layer will produce
|
||||
kernel_size (long or tuple): the size of the convolving kernel.
|
||||
stride (long or tuple, optional): the stride of the convolving kernel.
|
||||
padding (long or tuple, optional): zero-padding to be added to the input on both sides
|
||||
dilation (long or tuple, optional): controls the kernel striding
|
||||
groups (long, optional): controls the number of blocked connections from input to output
|
||||
bias (bool, optional): If True, adds a learnable bias to the output before convolving
|
||||
|
||||
in_channels (int): Number of channels in the input image
|
||||
out_channels (int): Number of channels produced by the convolution
|
||||
kernel_size (int or tuple): Size of the convolving kernel
|
||||
stride (int or tuple, optional): Stride of the convolution
|
||||
padding (int or tuple, optional): Zero-padding added to both sides of the input
|
||||
dilation (int or tuple, optional): Spacing between kernel elements
|
||||
groups (int, optional): Number of blocked connections from input channels to output channels
|
||||
bias (bool, optional): If True, adds a learnable bias to the output
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||
- Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
|
||||
- Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
|
||||
:math:`H_{out} = (H_{in} - 1) * stride[0] - 2 * padding[0] + kernel_size[0] + output_padding[0]`
|
||||
:math:`W_{out} = (W_{in} - 1) * stride[1] - 2 * padding[1] + kernel_size[1] + output_padding[1]`
|
||||
|
||||
|
|
@ -435,7 +483,11 @@ class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
|
|||
>>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
|
||||
>>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
|
||||
>>> h = downsample(input)
|
||||
>>> h.size()
|
||||
torch.Size([1, 16, 6, 6])
|
||||
>>> output = upsample(h, output_size=input.size())
|
||||
>>> output.size()
|
||||
torch.Size([1, 16, 12, 12])
|
||||
|
||||
.. _cross-correlation:
|
||||
https://en.wikipedia.org/wiki/Cross-correlation
|
||||
|
|
@ -467,49 +519,49 @@ class ConvTranspose3d(_ConvTransposeMixin, _ConvNd):
|
|||
planes.
|
||||
The transposed convolution operator multiplies each input value element-wise by a learnable kernel,
|
||||
and sums over the outputs from all input feature planes.
|
||||
|
||||
**This module can be seen as the exact reverse of Conv3d**
|
||||
|
||||
**This module can be seen as the exact reverse of Conv3d**
|
||||
|
||||
| :attr:`stride` controls the stride for the cross-correlation.
|
||||
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
|
||||
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
|
||||
for :attr:`padding` number of points
|
||||
| If :attr:`padding` is non-zero, then the output is implicitly zero-padded on both sides
|
||||
| If :attr:`padding` is non-zero, then the output is implicitly zero-padded on both sides
|
||||
for :attr:`padding` number of points
|
||||
| :attr:`dilation` controls the spacing between the kernel points. It is harder to describe,
|
||||
| :attr:`dilation` controls the spacing between the kernel points. It is harder to describe,
|
||||
but this `link`_ has a nice visualization of what :attr:`dilation` does.
|
||||
| :attr:`groups` controls the connections between inputs and outputs.
|
||||
| At groups=1, all inputs are convolved to all outputs.
|
||||
| At groups=2, the operation becomes equivalent to having two conv layers
|
||||
side by side, each seeing half the input channels,
|
||||
| :attr:`groups` controls the connections between inputs and outputs.
|
||||
| At groups=1, all inputs are convolved to all outputs.
|
||||
| At groups=2, the operation becomes equivalent to having two conv layers
|
||||
side by side, each seeing half the input channels,
|
||||
and producing half the output channels, and both subsequently concatenated.
|
||||
|
||||
The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`,
|
||||
|
||||
The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`,
|
||||
:attr:`dilation` can either be:
|
||||
|
||||
- a single ``int`` -- in which case the same value is used for the depth, height and width dimensions
|
||||
- a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
|
||||
- a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
|
||||
the second `int` for the height dimension and the third `int` for the width dimension
|
||||
|
||||
.. note::
|
||||
|
||||
|
||||
Depending on the size of your kernel, several (of the last)
|
||||
columns of the input might be lost, because it is a valid `cross-correlation`_,
|
||||
columns of the input might be lost, because it is a valid `cross-correlation`_,
|
||||
and not a full `cross-correlation`_.
|
||||
It is up to the user to add proper padding.
|
||||
|
||||
|
||||
Args:
|
||||
in_channels (long): The number of expected input channels in the image given as input
|
||||
out_channels (long): The number of output channels the convolution layer will produce
|
||||
kernel_size (long or tuple): the size of the convolving kernel.
|
||||
stride (long or tuple, optional): the stride of the convolving kernel.
|
||||
padding (long or tuple, optional): zero-padding to be added to the input on both sides
|
||||
dilation (long or tuple, optional): controls the kernel striding
|
||||
groups (long, optional): controls the number of blocked connections from input to output
|
||||
bias (bool, optional): If True, adds a learnable bias to the output before convolving
|
||||
|
||||
in_channels (int): Number of channels in the input image
|
||||
out_channels (int): Number of channels produced by the convolution
|
||||
kernel_size (int or tuple): Size of the convolving kernel
|
||||
stride (int or tuple, optional): Stride of the convolution
|
||||
padding (int or tuple, optional): Zero-padding added to both sides of the input
|
||||
dilation (int or tuple, optional): Spacing between kernel elements
|
||||
groups (int, optional): Number of blocked connections from input channels to output channels
|
||||
bias (bool, optional): If True, adds a learnable bias to the output
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
|
||||
- Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
|
||||
- Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
|
||||
:math:`D_{out} = (D_{in} - 1) * stride[0] - 2 * padding[0] + kernel_size[0] + output_padding[0]`
|
||||
:math:`H_{out} = (H_{in} - 1) * stride[1] - 2 * padding[1] + kernel_size[1] + output_padding[1]`
|
||||
:math:`W_{out} = (W_{in} - 1) * stride[2] - 2 * padding[2] + kernel_size[2] + output_padding[2]`
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user