diff --git a/docs/source/nn.rst b/docs/source/nn.rst index e90664b782e..ff66e4ea398 100644 --- a/docs/source/nn.rst +++ b/docs/source/nn.rst @@ -21,10 +21,13 @@ Convolution Layers .. autoclass:: Conv2d :members: -.. autoclass:: ConvTranspose2d +.. autoclass:: Conv3d :members: -.. autoclass:: Conv3d +.. autoclass:: ConvTranspose1d + :members: + +.. autoclass:: ConvTranspose2d :members: .. autoclass:: ConvTranspose3d diff --git a/test/test_nn.py b/test/test_nn.py index 0c6125a19cf..108ff923f38 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -962,7 +962,7 @@ class TestNN(NNTestCase): for bidirectional in (False, True): for dropout in (0, 1): # Because of dropout randomness, can only compare 0 and 1 for batch_first in (False, True): - num_directions = 2 if bidirectional else 1 + num_directions = 2 if bidirectional else 1 if batch_first: input_val = torch.randn(batch, seq_length, input_size) else: @@ -1211,6 +1211,19 @@ new_module_tests = [ input_size=(2, 4, 6), cudnn=True, ), + dict( + module_name='ConvTranspose1d', + constructor_args=(3, 4, 3, (3,), 1, (1,)), + cudnn=True, + input_size=(1, 3, 7) + ), + dict( + module_name='ConvTranspose1d', + constructor_args=(3, 4, 3, 2, 1, 1, 1, False), + input_size=(1, 3, 6), + cudnn=True, + desc='no_bias' + ), dict( module_name='MaxPool1d', constructor_args=(4,), diff --git a/torch/nn/functional.py b/torch/nn/functional.py index f3e14ea6073..25bb4d7dd43 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -124,6 +124,13 @@ def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, return f(input, weight, bias) if bias is not None else f(input, weight) +def conv_transpose1d(input, weight, bias=None, stride=1, padding=0, + output_padding=0, groups=1): + f = ConvNd(_single(stride), _single(padding), _single(1), True, + _single(output_padding), groups) + return f(input, weight, bias) if bias is not None else f(input, weight) + + def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1): """Applies a 2D transposed convolution operator over an input image diff --git a/torch/nn/modules/__init__.py b/torch/nn/modules/__init__.py index 21da670ac6a..7533664696f 100644 --- a/torch/nn/modules/__init__.py +++ b/torch/nn/modules/__init__.py @@ -1,6 +1,7 @@ from .module import Module from .linear import Linear -from .conv import Conv1d, Conv2d, ConvTranspose2d, Conv3d, ConvTranspose3d +from .conv import Conv1d, Conv2d, Conv3d, \ + ConvTranspose1d, ConvTranspose2d, ConvTranspose3d from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \ Softmax, Softmax2d, LogSoftmax, ELU, Hardshrink, LeakyReLU, LogSigmoid, \ Softplus, Softshrink, PReLU, Softsign, Softmin, Tanhshrink, RReLU diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 6e0fdd58bce..62a8469b3af 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -69,45 +69,45 @@ class Conv1d(_ConvNd): and output :math:`(N, C_{out}, L_{out})` can be precisely described as: .. math:: - + \begin{array}{ll} - out(N_i, C_{out_j}) = bias(C_{out_j}) + out(N_i, C_{out_j}) = bias(C_{out_j}) + \sum_{{k}=0}^{C_{in}-1} weight(C_{out_j}, k) \star input(N_i, k) \end{array} - - where :math:`\star` is the valid `cross-correlation`_ operator + + where :math:`\star` is the valid `cross-correlation`_ operator | :attr:`stride` controls the stride for the cross-correlation. - | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides + | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides for :attr:`padding` number of points - | :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, + | :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - | :attr:`groups` controls the connections between inputs and outputs. - | At groups=1, all inputs are convolved to all outputs. - | At groups=2, the operation becomes equivalent to having two conv layers - side by side, each seeing half the input channels, + | :attr:`groups` controls the connections between inputs and outputs. + | At groups=1, all inputs are convolved to all outputs. + | At groups=2, the operation becomes equivalent to having two conv layers + side by side, each seeing half the input channels, and producing half the output channels, and both subsequently concatenated. - + .. note:: - + Depending of the size of your kernel, several (of the last) - columns of the input might be lost, because it is a valid `cross-correlation`_, + columns of the input might be lost, because it is a valid `cross-correlation`_, and not a full `cross-correlation`_. It is up to the user to add proper padding. - + Args: - in_channels (long): The number of expected input channels in the image given as input - out_channels (long): The number of output channels the convolution layer will produce - kernel_size (long): the size of the convolving kernel. - stride (long, optional): the stride of the convolving kernel. - padding (long, optional): zero-padding to be added to the input on both sides - dilation (long, optional): controls the kernel striding - groups (long, optional): controls the number of blocked connections from input to output - bias (bool, optional): If True, adds a learnable bias to the output before convolving - + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution + padding (int or tuple, optional): Zero-padding added to both sides of the input + dilation (int or tuple, optional): Spacing between kernel elements + groups (int, optional): Number of blocked connections from input channels to output channels + bias (bool, optional): If True, adds a learnable bias to the output + Shape: - Input: :math:`(N, C_{in}, L_{in})` - - Output: :math:`(N, C_{out}, L_{out})` where + - Output: :math:`(N, C_{out}, L_{out})` where :math:`L_{out} = floor((L_{in} + 2 * padding - dilation * (kernel\_size - 1) - 1) / stride + 1)` Attributes: @@ -115,14 +115,14 @@ class Conv1d(_ConvNd): bias (Tensor): the learnable bias of the module of shape (out_channels) Examples:: - + >>> m = nn.Conv1d(16, 33, 3, stride=2) >>> input = autograd.Variable(torch.randn(20, 16, 50)) >>> output = m(input) - + .. _cross-correlation: https://en.wikipedia.org/wiki/Cross-correlation - + .. _link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ @@ -150,51 +150,51 @@ class Conv2d(_ConvNd): and output :math:`(N, C_{out}, H_{out}, W_{out})` can be precisely described as: .. math:: - + \begin{array}{ll} - out(N_i, C_{out_j}) = bias(C_{out_j}) + out(N_i, C_{out_j}) = bias(C_{out_j}) + \sum_{{k}=0}^{C_{in}-1} weight(C_{out_j}, k) \star input(N_i, k) \end{array} - - where :math:`\star` is the valid 2D `cross-correlation`_ operator + + where :math:`\star` is the valid 2D `cross-correlation`_ operator | :attr:`stride` controls the stride for the cross-correlation. - | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides + | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides for :attr:`padding` number of points - | :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, + | :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - | :attr:`groups` controls the connections between inputs and outputs. - | At groups=1, all inputs are convolved to all outputs. - | At groups=2, the operation becomes equivalent to having two conv layers - side by side, each seeing half the input channels, + | :attr:`groups` controls the connections between inputs and outputs. + | At groups=1, all inputs are convolved to all outputs. + | At groups=2, the operation becomes equivalent to having two conv layers + side by side, each seeing half the input channels, and producing half the output channels, and both subsequently concatenated. - + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: - a single ``int`` -- in which case the same value is used for the height and width dimension - - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, and the second `int` for the width dimension .. note:: - + Depending of the size of your kernel, several (of the last) - columns of the input might be lost, because it is a valid `cross-correlation`_, + columns of the input might be lost, because it is a valid `cross-correlation`_, and not a full `cross-correlation`_. It is up to the user to add proper padding. - + Args: - in_channels (long): The number of expected input channels in the image given as input - out_channels (long): The number of output channels the convolution layer will produce - kernel_size (long or tuple): the size of the convolving kernel. - stride (long or tuple, optional): the stride of the convolving kernel. - padding (long or tuple, optional): zero-padding to be added to the input on both sides - dilation (long or tuple, optional): controls the kernel striding - groups (long, optional): controls the number of blocked connections from input to output - bias (bool, optional): If True, adds a learnable bias to the output before convolving - + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution + padding (int or tuple, optional): Zero-padding added to both sides of the input + dilation (int or tuple, optional): Spacing between kernel elements + groups (int, optional): Number of blocked connections from input channels to output channels + bias (bool, optional): If True, adds a learnable bias to the output + Shape: - Input: :math:`(N, C_{in}, H_{in}, W_{in})` - - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where :math:`H_{out} = floor((H_{in} + 2 * padding[0] - dilation[0] * (kernel\_size[0] - 1) - 1) / stride[0] + 1)` :math:`W_{out} = floor((W_{in} + 2 * padding[1] - dilation[1] * (kernel\_size[1] - 1) - 1) / stride[1] + 1)` @@ -203,7 +203,7 @@ class Conv2d(_ConvNd): bias (Tensor): the learnable bias of the module of shape (out_channels) Examples:: - + >>> # With square kernels and equal stride >>> m = nn.Conv2d(16, 33, 3, stride=2) >>> # non-square kernels and unequal stride and with padding @@ -243,51 +243,51 @@ class Conv3d(_ConvNd): and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as: .. math:: - + \begin{array}{ll} - out(N_i, C_{out_j}) = bias(C_{out_j}) + out(N_i, C_{out_j}) = bias(C_{out_j}) + \sum_{{k}=0}^{C_{in}-1} weight(C_{out_j}, k) \star input(N_i, k) \end{array} - - where :math:`\star` is the valid 3D `cross-correlation`_ operator + + where :math:`\star` is the valid 3D `cross-correlation`_ operator | :attr:`stride` controls the stride for the cross-correlation. - | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides + | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides for :attr:`padding` number of points - | :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, + | :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - | :attr:`groups` controls the connections between inputs and outputs. - | At groups=1, all inputs are convolved to all outputs. - | At groups=2, the operation becomes equivalent to having two conv layers - side by side, each seeing half the input channels, + | :attr:`groups` controls the connections between inputs and outputs. + | At groups=1, all inputs are convolved to all outputs. + | At groups=2, the operation becomes equivalent to having two conv layers + side by side, each seeing half the input channels, and producing half the output channels, and both subsequently concatenated. - + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: - a single ``int`` -- in which case the same value is used for the height and width dimension - - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, the second `int` for the width dimension and the third `int` for the width dimension .. note:: - + Depending of the size of your kernel, several (of the last) - columns of the input might be lost, because it is a valid `cross-correlation`_, + columns of the input might be lost, because it is a valid `cross-correlation`_, and not a full `cross-correlation`_. It is up to the user to add proper padding. - + Args: - in_channels (long): The number of expected input channels in the image given as input - out_channels (long): The number of output channels the convolution layer will produce - kernel_size (long or tuple): the size of the convolving kernel. - stride (long or tuple, optional): the stride of the convolving kernel. - padding (long or tuple, optional): zero-padding to be added to the input on both sides - dilation (long or tuple, optional): controls the kernel striding - groups (long, optional): controls the number of blocked connections from input to output - bias (bool, optional): If True, adds a learnable bias to the output before convolving - + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution + padding (int or tuple, optional): Zero-padding added to both sides of the input + dilation (int or tuple, optional): Spacing between kernel elements + groups (int, optional): Number of blocked connections from input channels to output channels + bias (bool, optional): If True, adds a learnable bias to the output + Shape: - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` - - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where + - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where :math:`D_{out} = floor((D_{in} + 2 * padding[0] - dilation[0] * (kernel\_size[0] - 1) - 1) / stride[0] + 1)` :math:`H_{out} = floor((H_{in} + 2 * padding[1] - dilation[1] * (kernel\_size[1] - 1) - 1) / stride[1] + 1)` :math:`W_{out} = floor((W_{in} + 2 * padding[2] - dilation[2] * (kernel\_size[2] - 1) - 1) / stride[2] + 1)` @@ -367,54 +367,102 @@ class _ConvTransposeMixin(object): return tuple([output_size[d] - min_sizes[d] for d in range(k)]) +class ConvTranspose1d(_ConvTransposeMixin, _ConvNd): + """Applies a 1D transposed convolution operator over an input image + composed of several input planes. + + This module can be seen as the gradient of Conv1d with respect to its input. + + .. note:: + + Depending of the size of your kernel, several (of the last) + columns of the input might be lost, because it is a valid `cross-correlation`_, + and not a full `cross-correlation`_. + It is up to the user to add proper padding. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution + padding (int or tuple, optional): Zero-padding added to both sides of the input + dilation (int or tuple, optional): Spacing between kernel elements + groups (int, optional): Number of blocked connections from input channels to output channels + bias (bool, optional): If True, adds a learnable bias to the output + + Shape: + - Input: :math:`(N, C_{in}, L_{in})` + - Output: :math:`(N, C_{out}, L_{out})` where + :math:`L_{out} = (L_{in} - 1) * stride - 2 * padding + kernel_size + output_padding` + + Attributes: + weight (Tensor): the learnable weights of the module of shape (in_channels, out_channels, kernel_size[0], kernel_size[1]) + bias (Tensor): the learnable bias of the module of shape (out_channels) + """ + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, output_padding=0, groups=1, bias=True): + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = _single(padding) + dilation = _single(1) + output_padding = _single(output_padding) + super(ConvTranspose1d, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + True, output_padding, groups, bias) + + def forward(self, input, output_size=None): + output_padding = self._output_padding(input, output_size) + return F.conv_transpose1d( + input, self.weight, self.bias, self.stride, self.padding, + output_padding, self.groups) + + class ConvTranspose2d(_ConvTransposeMixin, _ConvNd): - r"""Applies a 2D transposed convolution operator over an input image composed of several input - planes. - The transposed convolution operator multiplies each input value element-wise by a learnable kernel, - and sums over the outputs from all input feature planes. - - **This module can be seen as the exact reverse of Conv2d** + r"""Applies a 2D transposed convolution operator over an input image + composed of several input planes. + + This module can be seen as the gradient of Conv2d with respect to its input. | :attr:`stride` controls the stride for the cross-correlation. - | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides + | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides for :attr:`padding` number of points - | If :attr:`padding` is non-zero, then the output is implicitly zero-padded on both sides + | If :attr:`padding` is non-zero, then the output is implicitly zero-padded on both sides for :attr:`padding` number of points - | :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, + | :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - | :attr:`groups` controls the connections between inputs and outputs. - | At groups=1, all inputs are convolved to all outputs. - | At groups=2, the operation becomes equivalent to having two conv layers - side by side, each seeing half the input channels, + | :attr:`groups` controls the connections between inputs and outputs. + | At groups=1, all inputs are convolved to all outputs. + | At groups=2, the operation becomes equivalent to having two conv layers + side by side, each seeing half the input channels, and producing half the output channels, and both subsequently concatenated. - - The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`, + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`, :attr:`dilation` can either be: - a single ``int`` -- in which case the same value is used for the height and width dimension - - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, and the second `int` for the width dimension - + .. note:: - + Depending of the size of your kernel, several (of the last) - columns of the input might be lost, because it is a valid `cross-correlation`_, + columns of the input might be lost, because it is a valid `cross-correlation`_, and not a full `cross-correlation`_. It is up to the user to add proper padding. - + Args: - in_channels (long): The number of expected input channels in the image given as input - out_channels (long): The number of output channels the convolution layer will produce - kernel_size (long or tuple): the size of the convolving kernel. - stride (long or tuple, optional): the stride of the convolving kernel. - padding (long or tuple, optional): zero-padding to be added to the input on both sides - dilation (long or tuple, optional): controls the kernel striding - groups (long, optional): controls the number of blocked connections from input to output - bias (bool, optional): If True, adds a learnable bias to the output before convolving - + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution + padding (int or tuple, optional): Zero-padding added to both sides of the input + dilation (int or tuple, optional): Spacing between kernel elements + groups (int, optional): Number of blocked connections from input channels to output channels + bias (bool, optional): If True, adds a learnable bias to the output + Shape: - Input: :math:`(N, C_{in}, H_{in}, W_{in})` - - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where :math:`H_{out} = (H_{in} - 1) * stride[0] - 2 * padding[0] + kernel_size[0] + output_padding[0]` :math:`W_{out} = (W_{in} - 1) * stride[1] - 2 * padding[1] + kernel_size[1] + output_padding[1]` @@ -435,7 +483,11 @@ class ConvTranspose2d(_ConvTransposeMixin, _ConvNd): >>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1) >>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1) >>> h = downsample(input) + >>> h.size() + torch.Size([1, 16, 6, 6]) >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12]) .. _cross-correlation: https://en.wikipedia.org/wiki/Cross-correlation @@ -467,49 +519,49 @@ class ConvTranspose3d(_ConvTransposeMixin, _ConvNd): planes. The transposed convolution operator multiplies each input value element-wise by a learnable kernel, and sums over the outputs from all input feature planes. - - **This module can be seen as the exact reverse of Conv3d** + + **This module can be seen as the exact reverse of Conv3d** | :attr:`stride` controls the stride for the cross-correlation. - | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides + | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides for :attr:`padding` number of points - | If :attr:`padding` is non-zero, then the output is implicitly zero-padded on both sides + | If :attr:`padding` is non-zero, then the output is implicitly zero-padded on both sides for :attr:`padding` number of points - | :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, + | :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - | :attr:`groups` controls the connections between inputs and outputs. - | At groups=1, all inputs are convolved to all outputs. - | At groups=2, the operation becomes equivalent to having two conv layers - side by side, each seeing half the input channels, + | :attr:`groups` controls the connections between inputs and outputs. + | At groups=1, all inputs are convolved to all outputs. + | At groups=2, the operation becomes equivalent to having two conv layers + side by side, each seeing half the input channels, and producing half the output channels, and both subsequently concatenated. - - The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`, + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`, :attr:`dilation` can either be: - a single ``int`` -- in which case the same value is used for the height and width dimension - - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, the second `int` for the width dimension and the third `int` for the width dimension .. note:: - + Depending of the size of your kernel, several (of the last) - columns of the input might be lost, because it is a valid `cross-correlation`_, + columns of the input might be lost, because it is a valid `cross-correlation`_, and not a full `cross-correlation`_. It is up to the user to add proper padding. - + Args: - in_channels (long): The number of expected input channels in the image given as input - out_channels (long): The number of output channels the convolution layer will produce - kernel_size (long or tuple): the size of the convolving kernel. - stride (long or tuple, optional): the stride of the convolving kernel. - padding (long or tuple, optional): zero-padding to be added to the input on both sides - dilation (long or tuple, optional): controls the kernel striding - groups (long, optional): controls the number of blocked connections from input to output - bias (bool, optional): If True, adds a learnable bias to the output before convolving - + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution + padding (int or tuple, optional): Zero-padding added to both sides of the input + dilation (int or tuple, optional): Spacing between kernel elements + groups (int, optional): Number of blocked connections from input channels to output channels + bias (bool, optional): If True, adds a learnable bias to the output + Shape: - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` - - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where + - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where :math:`D_{out} = (D_{in} - 1) * stride[0] - 2 * padding[0] + kernel_size[0] + output_padding[0]` :math:`H_{out} = (H_{in} - 1) * stride[1] - 2 * padding[1] + kernel_size[1] + output_padding[1]` :math:`W_{out} = (W_{in} - 1) * stride[2] - 2 * padding[2] + kernel_size[2] + output_padding[2]`