from .module import Module
from .utils import _pair, _quadruple, _ntuple
from .. import functional as F

from torch import Tensor
from ..common_types import _size_2_t, _size_4_t, _size_6_t


# TODO: grad_output size asserts in THNN
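
# Note: every module in this file delegates to F.pad, which consumes the
# padding tuple starting from the last tensor dimension: (left, right) for
# 1d, (left, right, top, bottom) for 2d, and so on. A quick sketch of the
# convention:
#
#   >>> F.pad(torch.ones(1, 1, 2), (1, 2), 'constant', 0.).shape
#   torch.Size([1, 1, 5])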


class _ConstantPadNd(Module):
    __constants__ = ['padding', 'value']
    value: float

    def __init__(self, value: float) -> None:
        super(_ConstantPadNd, self).__init__()
        self.value = value

    def forward(self, input: Tensor) -> Tensor:
        return F.pad(input, self.padding, 'constant', self.value)

    def extra_repr(self) -> str:
        return 'padding={}, value={}'.format(self.padding, self.value)
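
# Each concrete subclass only normalizes and stores `padding`; the forward
# pass is the single functional call above. A minimal equivalence sketch,
# assuming a hypothetical input `x`:
#
#   >>> x = torch.randn(1, 2, 4)
#   >>> torch.equal(ConstantPad1d(2, 3.5)(x), F.pad(x, (2, 2), 'constant', 3.5))
#   True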


class ConstantPad1d(_ConstantPadNd):
    r"""Pads the input tensor boundaries with a constant value.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If an `int`, uses the
            same padding on both boundaries. If a 2-`tuple`, uses
            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)

    Shape:
        - Input: :math:`(N, C, W_{in})`
        - Output: :math:`(N, C, W_{out})` where

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.ConstantPad1d(2, 3.5)
        >>> input = torch.randn(1, 2, 4)
        >>> input
        tensor([[[-1.0491, -0.7152, -0.0749,  0.8530],
                 [-1.3287,  1.8966,  0.1466, -0.2771]]])
        >>> m(input)
        tensor([[[ 3.5000,  3.5000, -1.0491, -0.7152, -0.0749,  0.8530,  3.5000,
                   3.5000],
                 [ 3.5000,  3.5000, -1.3287,  1.8966,  0.1466, -0.2771,  3.5000,
                   3.5000]]])
        >>> m = nn.ConstantPad1d(2, 3.5)
        >>> input = torch.randn(1, 2, 3)
        >>> input
        tensor([[[ 1.6616,  1.4523, -1.1255],
                 [-3.6372,  0.1182, -1.8652]]])
        >>> m(input)
        tensor([[[ 3.5000,  3.5000,  1.6616,  1.4523, -1.1255,  3.5000,  3.5000],
                 [ 3.5000,  3.5000, -3.6372,  0.1182, -1.8652,  3.5000,  3.5000]]])
        >>> # using different paddings for different sides
        >>> m = nn.ConstantPad1d((3, 1), 3.5)
        >>> m(input)
        tensor([[[ 3.5000,  3.5000,  3.5000,  1.6616,  1.4523, -1.1255,  3.5000],
                 [ 3.5000,  3.5000,  3.5000, -3.6372,  0.1182, -1.8652,  3.5000]]])
    """
    padding: _size_2_t

    def __init__(self, padding: _size_2_t, value: float) -> None:
        super(ConstantPad1d, self).__init__(value)
        self.padding = _pair(padding)
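
# `_pair`, `_quadruple`, and `_ntuple(6)` normalize the constructor argument:
# an int is repeated into a tuple of the required length (_pair(2) == (2, 2)),
# while a tuple of the right length passes through unchanged.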


class ConstantPad2d(_ConstantPadNd):
    r"""Pads the input tensor boundaries with a constant value.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If an `int`, uses the
            same padding in all boundaries. If a 4-`tuple`, uses
            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
            :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})`
        - Output: :math:`(N, C, H_{out}, W_{out})` where

          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.ConstantPad2d(2, 3.5)
        >>> input = torch.randn(1, 2, 2)
        >>> input
        tensor([[[ 1.6585,  0.4320],
                 [-0.8701, -0.4649]]])
        >>> m(input)
        tensor([[[ 3.5000,  3.5000,  3.5000,  3.5000,  3.5000,  3.5000],
                 [ 3.5000,  3.5000,  3.5000,  3.5000,  3.5000,  3.5000],
                 [ 3.5000,  3.5000,  1.6585,  0.4320,  3.5000,  3.5000],
                 [ 3.5000,  3.5000, -0.8701, -0.4649,  3.5000,  3.5000],
                 [ 3.5000,  3.5000,  3.5000,  3.5000,  3.5000,  3.5000],
                 [ 3.5000,  3.5000,  3.5000,  3.5000,  3.5000,  3.5000]]])
        >>> # using different paddings for different sides
        >>> m = nn.ConstantPad2d((3, 0, 2, 1), 3.5)
        >>> m(input)
        tensor([[[ 3.5000,  3.5000,  3.5000,  3.5000,  3.5000],
                 [ 3.5000,  3.5000,  3.5000,  3.5000,  3.5000],
                 [ 3.5000,  3.5000,  3.5000,  1.6585,  0.4320],
                 [ 3.5000,  3.5000,  3.5000, -0.8701, -0.4649],
                 [ 3.5000,  3.5000,  3.5000,  3.5000,  3.5000]]])
    """
    __constants__ = ['padding', 'value']
    padding: _size_4_t

    def __init__(self, padding: _size_4_t, value: float) -> None:
        super(ConstantPad2d, self).__init__(value)
        self.padding = _quadruple(padding)


class ConstantPad3d(_ConstantPadNd):
    r"""Pads the input tensor boundaries with a constant value.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If an `int`, uses the
            same padding in all boundaries. If a 6-`tuple`, uses
            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
            :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
            :math:`\text{padding\_front}`, :math:`\text{padding\_back}`)

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` where

          :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}`

          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.ConstantPad3d(3, 3.5)
        >>> input = torch.randn(16, 3, 10, 20, 30)
        >>> output = m(input)
        >>> # using different paddings for different sides
        >>> m = nn.ConstantPad3d((3, 3, 6, 6, 0, 1), 3.5)
        >>> output = m(input)
    """
    padding: _size_6_t

    def __init__(self, padding: _size_6_t, value: float) -> None:
        super(ConstantPad3d, self).__init__(value)
        self.padding = _ntuple(6)(padding)
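
# Shape check for the first example above: padding 3 on every side turns
# (16, 3, 10, 20, 30) into (16, 3, 16, 26, 36), since each of D, H, and W
# grows by 3 + 3.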


class _ReflectionPadNd(Module):
    __constants__ = ['padding']

    def forward(self, input: Tensor) -> Tensor:
        return F.pad(input, self.padding, 'reflect')

    def extra_repr(self) -> str:
        return '{}'.format(self.padding)
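
# Reflection mirrors the input around each boundary without repeating the
# edge element, so every padding amount must be strictly smaller than the
# corresponding input dimension. For a hypothetical (1, 1, 4) input `x`,
# F.pad(x, (3, 3), 'reflect') is valid but a padding of 4 raises an error.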


class ReflectionPad1d(_ReflectionPadNd):
    r"""Pads the input tensor using the reflection of the input boundary.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If an `int`, uses the
            same padding in all boundaries. If a 2-`tuple`, uses
            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)

    Shape:
        - Input: :math:`(N, C, W_{in})`
        - Output: :math:`(N, C, W_{out})` where

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.ReflectionPad1d(2)
        >>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4)
        >>> input
        tensor([[[0., 1., 2., 3.],
                 [4., 5., 6., 7.]]])
        >>> m(input)
        tensor([[[2., 1., 0., 1., 2., 3., 2., 1.],
                 [6., 5., 4., 5., 6., 7., 6., 5.]]])
        >>> # using different paddings for different sides
        >>> m = nn.ReflectionPad1d((3, 1))
        >>> m(input)
        tensor([[[3., 2., 1., 0., 1., 2., 3., 2.],
                 [7., 6., 5., 4., 5., 6., 7., 6.]]])
    """
    padding: _size_2_t

    def __init__(self, padding: _size_2_t) -> None:
        super(ReflectionPad1d, self).__init__()
        self.padding = _pair(padding)


class ReflectionPad2d(_ReflectionPadNd):
    r"""Pads the input tensor using the reflection of the input boundary.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If an `int`, uses the
            same padding in all boundaries. If a 4-`tuple`, uses
            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
            :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})`
        - Output: :math:`(N, C, H_{out}, W_{out})` where

          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.ReflectionPad2d(2)
        >>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)
        >>> input
        tensor([[[[0., 1., 2.],
                  [3., 4., 5.],
                  [6., 7., 8.]]]])
        >>> m(input)
        tensor([[[[8., 7., 6., 7., 8., 7., 6.],
                  [5., 4., 3., 4., 5., 4., 3.],
                  [2., 1., 0., 1., 2., 1., 0.],
                  [5., 4., 3., 4., 5., 4., 3.],
                  [8., 7., 6., 7., 8., 7., 6.],
                  [5., 4., 3., 4., 5., 4., 3.],
                  [2., 1., 0., 1., 2., 1., 0.]]]])
        >>> # using different paddings for different sides
        >>> m = nn.ReflectionPad2d((1, 1, 2, 0))
        >>> m(input)
        tensor([[[[7., 6., 7., 8., 7.],
                  [4., 3., 4., 5., 4.],
                  [1., 0., 1., 2., 1.],
                  [4., 3., 4., 5., 4.],
                  [7., 6., 7., 8., 7.]]]])
    """
    padding: _size_4_t

    def __init__(self, padding: _size_4_t) -> None:
        super(ReflectionPad2d, self).__init__()
        self.padding = _quadruple(padding)


class _ReplicationPadNd(Module):
    __constants__ = ['padding']

    def forward(self, input: Tensor) -> Tensor:
        return F.pad(input, self.padding, 'replicate')

    def extra_repr(self) -> str:
        return '{}'.format(self.padding)
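
# Replication repeats the outermost element instead of mirroring. A small
# sketch, assuming a hypothetical 3d input:
#
#   >>> x = torch.tensor([[[0., 1., 2.]]])
#   >>> F.pad(x, (2, 0), 'replicate')
#   tensor([[[0., 0., 0., 1., 2.]]])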


class ReplicationPad1d(_ReplicationPadNd):
    r"""Pads the input tensor using replication of the input boundary.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If an `int`, uses the
            same padding in all boundaries. If a 2-`tuple`, uses
            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)

    Shape:
        - Input: :math:`(N, C, W_{in})`
        - Output: :math:`(N, C, W_{out})` where

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.ReplicationPad1d(2)
        >>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4)
        >>> input
        tensor([[[0., 1., 2., 3.],
                 [4., 5., 6., 7.]]])
        >>> m(input)
        tensor([[[0., 0., 0., 1., 2., 3., 3., 3.],
                 [4., 4., 4., 5., 6., 7., 7., 7.]]])
        >>> # using different paddings for different sides
        >>> m = nn.ReplicationPad1d((3, 1))
        >>> m(input)
        tensor([[[0., 0., 0., 0., 1., 2., 3., 3.],
                 [4., 4., 4., 4., 5., 6., 7., 7.]]])
    """
    padding: _size_2_t

    def __init__(self, padding: _size_2_t) -> None:
        super(ReplicationPad1d, self).__init__()
        self.padding = _pair(padding)


class ReplicationPad2d(_ReplicationPadNd):
    r"""Pads the input tensor using replication of the input boundary.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If an `int`, uses the
            same padding in all boundaries. If a 4-`tuple`, uses
            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
            :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})`
        - Output: :math:`(N, C, H_{out}, W_{out})` where

          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.ReplicationPad2d(2)
        >>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)
        >>> input
        tensor([[[[0., 1., 2.],
                  [3., 4., 5.],
                  [6., 7., 8.]]]])
        >>> m(input)
        tensor([[[[0., 0., 0., 1., 2., 2., 2.],
                  [0., 0., 0., 1., 2., 2., 2.],
                  [0., 0., 0., 1., 2., 2., 2.],
                  [3., 3., 3., 4., 5., 5., 5.],
                  [6., 6., 6., 7., 8., 8., 8.],
                  [6., 6., 6., 7., 8., 8., 8.],
                  [6., 6., 6., 7., 8., 8., 8.]]]])
        >>> # using different paddings for different sides
        >>> m = nn.ReplicationPad2d((1, 1, 2, 0))
        >>> m(input)
        tensor([[[[0., 0., 1., 2., 2.],
                  [0., 0., 1., 2., 2.],
                  [0., 0., 1., 2., 2.],
                  [3., 3., 4., 5., 5.],
                  [6., 6., 7., 8., 8.]]]])
    """
    padding: _size_4_t

    def __init__(self, padding: _size_4_t) -> None:
        super(ReplicationPad2d, self).__init__()
        self.padding = _quadruple(padding)


class ReplicationPad3d(_ReplicationPadNd):
    r"""Pads the input tensor using replication of the input boundary.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If an `int`, uses the
            same padding in all boundaries. If a 6-`tuple`, uses
            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
            :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
            :math:`\text{padding\_front}`, :math:`\text{padding\_back}`)

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` where

          :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}`

          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.ReplicationPad3d(3)
        >>> input = torch.randn(16, 3, 8, 320, 480)
        >>> output = m(input)
        >>> # using different paddings for different sides
        >>> m = nn.ReplicationPad3d((3, 3, 6, 6, 1, 1))
        >>> output = m(input)
    """
    padding: _size_6_t

    def __init__(self, padding: _size_6_t) -> None:
        super(ReplicationPad3d, self).__init__()
        self.padding = _ntuple(6)(padding)
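
# Shape check for the second example above: (3, 3, 6, 6, 1, 1) pads W by
# 3 + 3, H by 6 + 6, and D by 1 + 1, so (16, 3, 8, 320, 480) becomes
# (16, 3, 10, 332, 486).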


class ZeroPad2d(ConstantPad2d):
    r"""Pads the input tensor boundaries with zero.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If an `int`, uses the
            same padding in all boundaries. If a 4-`tuple`, uses
            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
            :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})`
        - Output: :math:`(N, C, H_{out}, W_{out})` where

          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.ZeroPad2d(2)
        >>> input = torch.randn(1, 1, 3, 3)
        >>> input
        tensor([[[[-0.1678, -0.4418,  1.9466],
                  [ 0.9604, -0.4219, -0.5241],
                  [-0.9162, -0.5436, -0.6446]]]])
        >>> m(input)
        tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
                  [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
                  [ 0.0000,  0.0000, -0.1678, -0.4418,  1.9466,  0.0000,  0.0000],
                  [ 0.0000,  0.0000,  0.9604, -0.4219, -0.5241,  0.0000,  0.0000],
                  [ 0.0000,  0.0000, -0.9162, -0.5436, -0.6446,  0.0000,  0.0000],
                  [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
                  [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]]])
        >>> # using different paddings for different sides
        >>> m = nn.ZeroPad2d((1, 1, 2, 0))
        >>> m(input)
        tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
                  [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
                  [ 0.0000, -0.1678, -0.4418,  1.9466,  0.0000],
                  [ 0.0000,  0.9604, -0.4219, -0.5241,  0.0000],
                  [ 0.0000, -0.9162, -0.5436, -0.6446,  0.0000]]]])
    """
    padding: _size_4_t

    def __init__(self, padding: _size_4_t) -> None:
        super(ZeroPad2d, self).__init__(padding, 0.)
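
# ZeroPad2d(p) is exactly ConstantPad2d(p, 0.), so for a hypothetical 4d input
# x, torch.equal(ZeroPad2d(1)(x), F.pad(x, (1, 1, 1, 1), 'constant', 0.)) holds.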