Added a flatten module (#22245)

Summary:
https://github.com/pytorch/pytorch/issues/2118

I'm not sure I'm doing it correctly, so I'll add tests if we decide that it's roughly correct.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22245

Differential Revision: D16508957

Pulled By: Chillee

fbshipit-source-id: a8dc7af999ba698c921006889f71cb1bc5a59d50
This commit is contained in:
Horace He 2019-07-25 22:45:24 -07:00 committed by Facebook Github Bot
parent 5b0484d977
commit 1c00e0fc3f
4 changed files with 44 additions and 2 deletions

View File

@ -869,3 +869,9 @@ Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torch.nn.utils.rnn.pack_sequence
:hidden:`Flatten`
~~~~~~~~~~~~~~~~~
.. autoclass:: Flatten
:members:

View File

@ -104,6 +104,11 @@ module_tests = [
module_name='Tanh',
input_size=(2, 3, 4, 5)
),
dict(
module_name='Flatten',
input_size=(2, 3, 4, 5),
reference_fn=lambda i, *_: torch.flatten(i, 1)
),
dict(
module_name='Softmax',
constructor_args=(1,),

View File

@ -28,7 +28,8 @@ from .distance import PairwiseDistance, CosineSimilarity
from .fold import Fold, Unfold
from .adaptive import AdaptiveLogSoftmaxWithLoss
from .transformer import TransformerEncoder, TransformerDecoder, \
TransformerEncoderLayer, TransformerDecoderLayer, Transformer
TransformerEncoderLayer, TransformerDecoderLayer, Transformer
from .flatten import Flatten
__all__ = [
'Module', 'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
@ -50,6 +51,7 @@ __all__ = [
'PairwiseDistance', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d',
'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d',
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
'Flatten'
]

View File

@ -0,0 +1,29 @@
from .module import Module
class Flatten(Module):
r"""
Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.
Args:
start_dim: first dim to flatten (default = 1).
end_dim: last dim to flatten (default = -1).
Shape:
- Input: :math:`(N, *dims)`
- Output: :math:`(N, \prod *dims)` (for the default case).
Examples::
>>> m = nn.Sequential(
>>> nn.Conv2d(1, 32, 5, 1, 1),
>>> nn.Flatten()
>>> )
"""
__constants__ = ['start_dim', 'end_dim']
def __init__(self, start_dim=1, end_dim=-1):
super(Flatten, self).__init__()
self.start_dim = start_dim
self.end_dim = end_dim
def forward(self, input):
return input.flatten(self.start_dim, self.end_dim)