# pytorch/torch/nn/modules/conv.py
import math
import torch
from torch.nn.parameter import Parameter
from .. import functional as F
from .module import Module
from .utils import _single, _pair, _triple


class _ConvNd(Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding, groups, bias):
        super(_ConvNd, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        if transposed:
            self.weight = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
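
    # Weight layout sanity check (an illustrative sketch, not part of the
    # original module): with groups=2 each group convolves in_channels // 2
    # input planes to out_channels // 2 output planes, e.g.
    #
    #   conv = Conv2d(16, 32, kernel_size=3, groups=2)
    #   conv.weight.size()    # torch.Size([32, 8, 3, 3])
    #
    # and for the transposed case the first two dimensions swap roles:
    #
    #   deconv = ConvTranspose2d(16, 32, kernel_size=3, groups=2)
    #   deconv.weight.size()  # torch.Size([16, 16, 3, 3])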

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)
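
    # Worked example of the fan-in bound above (a sketch for a hypothetical
    # Conv2d(16, 32, 3)): n = 16 * 3 * 3 = 144, so weights and bias are drawn
    # uniformly from [-1/12, 1/12] since 1 / sqrt(144) = 1 / 12.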

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)


class Conv1d(_ConvNd):
    r"""Applies a 1D convolution over an input signal composed of several input
    planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, L)`
    and output :math:`(N, C_{out}, L_{out})` can be precisely described as:

    .. math::

        \begin{array}{ll}
        out(N_i, C_{out_j}) = bias(C_{out_j})
            + \sum_{{k}=0}^{C_{in}-1} weight(C_{out_j}, k) \star input(N_i, k)
        \end{array}

    where :math:`\star` is the valid `cross-correlation`_ operator

    | :attr:`stride` controls the stride for the cross-correlation.
    | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
      for :attr:`padding` number of points
    | :attr:`dilation` controls the spacing between the kernel points. It is harder to describe,
      but this `link`_ has a nice visualization of what :attr:`dilation` does.
    | :attr:`groups` controls the connections between inputs and outputs.
    |       At groups=1, all inputs are convolved to all outputs.
    |       At groups=2, the operation becomes equivalent to having two conv layers
            side by side, each seeing half the input channels,
            and producing half the output channels, and both subsequently concatenated.

    .. note::

        Depending on the size of your kernel, several (of the last)
        columns of the input might be lost, because it is a valid `cross-correlation`_,
        and not a full `cross-correlation`_.
        It is up to the user to add proper padding.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution
        padding (int or tuple, optional): Zero-padding added to both sides of the input
        dilation (int or tuple, optional): Spacing between kernel elements
        groups (int, optional): Number of blocked connections from input channels to output channels
        bias (bool, optional): If True, adds a learnable bias to the output

    Shape:
        - Input: :math:`(N, C_{in}, L_{in})`
        - Output: :math:`(N, C_{out}, L_{out})` where
          :math:`L_{out} = floor((L_{in} + 2 * padding - dilation * (kernel\_size - 1) - 1) / stride + 1)`

    Attributes:
        weight (Tensor): the learnable weights of the module of shape (out_channels, in_channels, kernel_size)
        bias (Tensor): the learnable bias of the module of shape (out_channels)

    Examples::

        >>> m = nn.Conv1d(16, 33, 3, stride=2)
        >>> input = autograd.Variable(torch.randn(20, 16, 50))
        >>> output = m(input)
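        >>> # a hedged check of the Shape formula above: with L_in = 50,
        >>> # L_out = floor((50 - 1 * (3 - 1) - 1) / 2 + 1) = 24, so
        >>> # output.size() is (20, 33, 24)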

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        dilation = _single(dilation)
        super(Conv1d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _single(0), groups, bias)

    def forward(self, input):
        return F.conv1d(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


class Conv2d(_ConvNd):
    r"""Applies a 2D convolution over an input signal composed of several input
    planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, H, W)`
    and output :math:`(N, C_{out}, H_{out}, W_{out})` can be precisely described as:

    .. math::

        \begin{array}{ll}
        out(N_i, C_{out_j}) = bias(C_{out_j})
            + \sum_{{k}=0}^{C_{in}-1} weight(C_{out_j}, k) \star input(N_i, k)
        \end{array}

    where :math:`\star` is the valid 2D `cross-correlation`_ operator

    | :attr:`stride` controls the stride for the cross-correlation.
    | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
      for :attr:`padding` number of points
    | :attr:`dilation` controls the spacing between the kernel points. It is harder to describe,
      but this `link`_ has a nice visualization of what :attr:`dilation` does.
    | :attr:`groups` controls the connections between inputs and outputs.
    |       At groups=1, all inputs are convolved to all outputs.
    |       At groups=2, the operation becomes equivalent to having two conv layers
            side by side, each seeing half the input channels,
            and producing half the output channels, and both subsequently concatenated.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    .. note::

        Depending on the size of your kernel, several (of the last)
        columns of the input might be lost, because it is a valid `cross-correlation`_,
        and not a full `cross-correlation`_.
        It is up to the user to add proper padding.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution
        padding (int or tuple, optional): Zero-padding added to both sides of the input
        dilation (int or tuple, optional): Spacing between kernel elements
        groups (int, optional): Number of blocked connections from input channels to output channels
        bias (bool, optional): If True, adds a learnable bias to the output

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
          :math:`H_{out} = floor((H_{in} + 2 * padding[0] - dilation[0] * (kernel\_size[0] - 1) - 1) / stride[0] + 1)`
          :math:`W_{out} = floor((W_{in} + 2 * padding[1] - dilation[1] * (kernel\_size[1] - 1) - 1) / stride[1] + 1)`

    Attributes:
        weight (Tensor): the learnable weights of the module of shape (out_channels, in_channels, kernel_size[0], kernel_size[1])
        bias (Tensor): the learnable bias of the module of shape (out_channels)

    Examples::

        >>> # With square kernels and equal stride
        >>> m = nn.Conv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        >>> input = autograd.Variable(torch.randn(20, 16, 50, 100))
        >>> output = m(input)
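        >>> # a hedged check of the Shape formulas above for the dilated module:
        >>> # H_out = floor((50 + 8 - 3 * 2 - 1) / 2 + 1) = 26 and
        >>> # W_out = floor((100 + 4 - 4 - 1) / 1 + 1) = 100, so output is of
        >>> # size (20, 33, 26, 100)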

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias)

    def forward(self, input):
        return F.conv2d(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


class Conv3d(_ConvNd):
    r"""Applies a 3D convolution over an input signal composed of several input
    planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)`
    and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as:

    .. math::

        \begin{array}{ll}
        out(N_i, C_{out_j}) = bias(C_{out_j})
            + \sum_{{k}=0}^{C_{in}-1} weight(C_{out_j}, k) \star input(N_i, k)
        \end{array}

    where :math:`\star` is the valid 3D `cross-correlation`_ operator

    | :attr:`stride` controls the stride for the cross-correlation.
    | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
      for :attr:`padding` number of points
    | :attr:`dilation` controls the spacing between the kernel points. It is harder to describe,
      but this `link`_ has a nice visualization of what :attr:`dilation` does.
    | :attr:`groups` controls the connections between inputs and outputs.
    |       At groups=1, all inputs are convolved to all outputs.
    |       At groups=2, the operation becomes equivalent to having two conv layers
            side by side, each seeing half the input channels,
            and producing half the output channels, and both subsequently concatenated.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:

        - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions
        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
          the second `int` for the height dimension and the third `int` for the width dimension

    .. note::

        Depending on the size of your kernel, several (of the last)
        columns of the input might be lost, because it is a valid `cross-correlation`_,
        and not a full `cross-correlation`_.
        It is up to the user to add proper padding.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution
        padding (int or tuple, optional): Zero-padding added to both sides of the input
        dilation (int or tuple, optional): Spacing between kernel elements
        groups (int, optional): Number of blocked connections from input channels to output channels
        bias (bool, optional): If True, adds a learnable bias to the output

    Shape:
        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
          :math:`D_{out} = floor((D_{in} + 2 * padding[0] - dilation[0] * (kernel\_size[0] - 1) - 1) / stride[0] + 1)`
          :math:`H_{out} = floor((H_{in} + 2 * padding[1] - dilation[1] * (kernel\_size[1] - 1) - 1) / stride[1] + 1)`
          :math:`W_{out} = floor((W_{in} + 2 * padding[2] - dilation[2] * (kernel\_size[2] - 1) - 1) / stride[2] + 1)`

    Attributes:
        weight (Tensor): the learnable weights of the module of shape (out_channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2])
        bias (Tensor): the learnable bias of the module of shape (out_channels)

    Examples::

        >>> # With cubic kernels and equal stride
        >>> m = nn.Conv3d(16, 33, 3, stride=2)
        >>> # non-cubic kernels and unequal stride and with padding
        >>> m = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))
        >>> input = autograd.Variable(torch.randn(20, 16, 10, 50, 100))
        >>> output = m(input)
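        >>> # a hedged check of the Shape formulas above:
        >>> # D_out = floor((10 + 8 - 2 - 1) / 2 + 1) = 8, H_out = 50,
        >>> # W_out = floor((100 - 1 - 1) / 1 + 1) = 99, so output is of size
        >>> # (20, 33, 8, 50, 99)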

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        super(Conv3d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _triple(0), groups, bias)

    def forward(self, input):
        return F.conv3d(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


class _ConvTransposeMixin(object):

    def forward(self, input, output_size=None):
        output_padding = self._output_padding(input, output_size)
        func = self._backend.ConvNd(
            self.stride, self.padding, self.dilation, self.transposed,
            output_padding, self.groups)
        if self.bias is None:
            return func(input, self.weight)
        else:
            return func(input, self.weight, self.bias)

    def _output_padding(self, input, output_size):
        if output_size is None:
            return self.output_padding

        output_size = list(output_size)
        k = input.dim() - 2
        if len(output_size) == k + 2:
            # strip the batch and channel dimensions, keeping the k spatial ones
            output_size = output_size[-k:]
        if len(output_size) != k:
            raise ValueError(
                "output_size must have {} or {} elements (got {})"
                .format(k, k + 2, len(output_size)))

        def dim_size(d):
            return ((input.size(d + 2) - 1) * self.stride[d] -
                    2 * self.padding[d] + self.kernel_size[d])

        min_sizes = [dim_size(d) for d in range(k)]
        max_sizes = [min_sizes[d] + self.stride[d] - 1 for d in range(k)]
        for size, min_size, max_size in zip(output_size, min_sizes, max_sizes):
            if size < min_size or size > max_size:
                raise ValueError((
                    "requested an output size of {}, but valid sizes range "
                    "from {} to {} (for an input of {})").format(
                        output_size, min_sizes, max_sizes, input.size()[2:]))

        return tuple([output_size[d] - min_sizes[d] for d in range(k)])
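
# A worked example of the output_padding selection above (an illustrative
# sketch, not part of the original module): for a hypothetical
# ConvTranspose1d(8, 4, 4, stride=3, padding=1) applied to an input of
# length 5, dim_size gives (5 - 1) * 3 - 2 * 1 + 4 = 14, so valid output
# lengths range over [14, 16] and requesting output_size=(15,) yields
# output_padding = (1,).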


class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
    r"""Applies a 1D transposed convolution operator over an input image
    composed of several input planes.

    This module can be seen as the gradient of Conv1d with respect to its input.

    .. note::

        Depending on the size of your kernel, several (of the last)
        columns of the input might be lost, because it is a valid `cross-correlation`_,
        and not a full `cross-correlation`_.
        It is up to the user to add proper padding.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution
        padding (int or tuple, optional): Zero-padding added to both sides of the input
        output_padding (int or tuple, optional): Zero-padding added to one side of the output
        groups (int, optional): Number of blocked connections from input channels to output channels
        bias (bool, optional): If True, adds a learnable bias to the output

    Shape:
        - Input: :math:`(N, C_{in}, L_{in})`
        - Output: :math:`(N, C_{out}, L_{out})` where
          :math:`L_{out} = (L_{in} - 1) * stride - 2 * padding + kernel\_size + output\_padding`

    Attributes:
        weight (Tensor): the learnable weights of the module of shape (in_channels, out_channels, kernel_size)
        bias (Tensor): the learnable bias of the module of shape (out_channels)
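
    Examples::

        >>> # a minimal usage sketch, mirroring the Conv1d example above
        >>> m = nn.ConvTranspose1d(16, 33, 3, stride=2)
        >>> input = autograd.Variable(torch.randn(20, 16, 50))
        >>> output = m(input)
        >>> # by the Shape formula above, L_out = (50 - 1) * 2 + 3 = 101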
"""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True):
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        dilation = _single(1)
        output_padding = _single(output_padding)
        super(ConvTranspose1d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias)

    def forward(self, input, output_size=None):
        output_padding = self._output_padding(input, output_size)
        return F.conv_transpose1d(
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups)


class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
    r"""Applies a 2D transposed convolution operator over an input image
    composed of several input planes.

    This module can be seen as the gradient of Conv2d with respect to its input.

    | :attr:`stride` controls the stride for the cross-correlation.
    | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
      for :attr:`padding` number of points
    | If :attr:`output_padding` is non-zero, then the output is implicitly zero-padded on one side
      for :attr:`output_padding` number of points
    | :attr:`dilation` controls the spacing between the kernel points. It is harder to describe,
      but this `link`_ has a nice visualization of what :attr:`dilation` does.
    | :attr:`groups` controls the connections between inputs and outputs.
    |       At groups=1, all inputs are convolved to all outputs.
    |       At groups=2, the operation becomes equivalent to having two conv layers
            side by side, each seeing half the input channels,
            and producing half the output channels, and both subsequently concatenated.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
    can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    .. note::

        Depending on the size of your kernel, several (of the last)
        columns of the input might be lost, because it is a valid `cross-correlation`_,
        and not a full `cross-correlation`_.
        It is up to the user to add proper padding.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution
        padding (int or tuple, optional): Zero-padding added to both sides of the input
        output_padding (int or tuple, optional): Zero-padding added to one side of the output
        groups (int, optional): Number of blocked connections from input channels to output channels
        bias (bool, optional): If True, adds a learnable bias to the output

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
          :math:`H_{out} = (H_{in} - 1) * stride[0] - 2 * padding[0] + kernel\_size[0] + output\_padding[0]`
          :math:`W_{out} = (W_{in} - 1) * stride[1] - 2 * padding[1] + kernel\_size[1] + output\_padding[1]`

    Attributes:
        weight (Tensor): the learnable weights of the module of shape (in_channels, out_channels, kernel_size[0], kernel_size[1])
        bias (Tensor): the learnable bias of the module of shape (out_channels)

    Examples::

        >>> # With square kernels and equal stride
        >>> m = nn.ConvTranspose2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = autograd.Variable(torch.randn(20, 16, 50, 100))
        >>> output = m(input)
        >>> # exact output size can be also specified as an argument
        >>> input = autograd.Variable(torch.randn(1, 16, 12, 12))
        >>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(input)
        >>> h.size()
        torch.Size([1, 16, 6, 6])
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12])
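        >>> # a hedged note on the numbers above: the minimum output
        >>> # height/width is (6 - 1) * 2 - 2 * 1 + 3 = 11 and the maximum is
        >>> # 11 + stride - 1 = 12, so output_size=(12, 12) is met by choosing
        >>> # output_padding=(1, 1); sizes outside [11, 12] raise a ValueError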

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(1)
        output_padding = _pair(output_padding)
        super(ConvTranspose2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias)

    def forward(self, input, output_size=None):
        output_padding = self._output_padding(input, output_size)
        return F.conv_transpose2d(
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups)


class ConvTranspose3d(_ConvTransposeMixin, _ConvNd):
    r"""Applies a 3D transposed convolution operator over an input image composed of several input
    planes.

    The transposed convolution operator multiplies each input value element-wise by a learnable kernel,
    and sums over the outputs from all input feature planes.

    This module can be seen as the gradient of Conv3d with respect to its input.

    | :attr:`stride` controls the stride for the cross-correlation.
    | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
      for :attr:`padding` number of points
    | If :attr:`output_padding` is non-zero, then the output is implicitly zero-padded on one side
      for :attr:`output_padding` number of points
    | :attr:`dilation` controls the spacing between the kernel points. It is harder to describe,
      but this `link`_ has a nice visualization of what :attr:`dilation` does.
    | :attr:`groups` controls the connections between inputs and outputs.
    |       At groups=1, all inputs are convolved to all outputs.
    |       At groups=2, the operation becomes equivalent to having two conv layers
            side by side, each seeing half the input channels,
            and producing half the output channels, and both subsequently concatenated.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
    can either be:

        - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions
        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
          the second `int` for the height dimension and the third `int` for the width dimension

    .. note::

        Depending on the size of your kernel, several (of the last)
        columns of the input might be lost, because it is a valid `cross-correlation`_,
        and not a full `cross-correlation`_.
        It is up to the user to add proper padding.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution
        padding (int or tuple, optional): Zero-padding added to both sides of the input
        output_padding (int or tuple, optional): Zero-padding added to one side of the output
        groups (int, optional): Number of blocked connections from input channels to output channels
        bias (bool, optional): If True, adds a learnable bias to the output

    Shape:
        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
          :math:`D_{out} = (D_{in} - 1) * stride[0] - 2 * padding[0] + kernel\_size[0] + output\_padding[0]`
          :math:`H_{out} = (H_{in} - 1) * stride[1] - 2 * padding[1] + kernel\_size[1] + output\_padding[1]`
          :math:`W_{out} = (W_{in} - 1) * stride[2] - 2 * padding[2] + kernel\_size[2] + output\_padding[2]`

    Attributes:
        weight (Tensor): the learnable weights of the module of shape (in_channels, out_channels, kernel_size[0], kernel_size[1], kernel_size[2])
        bias (Tensor): the learnable bias of the module of shape (out_channels)

    Examples::

        >>> # With cubic kernels and equal stride
        >>> m = nn.ConvTranspose3d(16, 33, 3, stride=2)
        >>> # non-cubic kernels and unequal stride and with padding
        >>> m = nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2))
        >>> input = autograd.Variable(torch.randn(20, 16, 10, 50, 100))
        >>> output = m(input)
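        >>> # a hedged check of the Shape formulas above:
        >>> # D_out = (10 - 1) * 2 + 3 = 21, H_out = (50 - 1) - 8 + 5 = 46,
        >>> # W_out = (100 - 1) - 4 + 2 = 97, so output is of size
        >>> # (20, 33, 21, 46, 97)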

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True):
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(1)
        output_padding = _triple(output_padding)
        super(ConvTranspose3d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias)

    def forward(self, input, output_size=None):
        output_padding = self._output_padding(input, output_size)
        return F.conv_transpose3d(
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups)

# TODO: Conv2dLocal
# TODO: Conv2dMap
# TODO: ConvTranspose2dMap