Added a flatten module (#22245)
Summary: https://github.com/pytorch/pytorch/issues/2118

I'm not sure I'm doing it correctly, so I'll add tests if we decide that it's roughly correct.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/22245
Differential Revision: D16508957
Pulled By: Chillee
fbshipit-source-id: a8dc7af999ba698c921006889f71cb1bc5a59d50
commit 1c00e0fc3f (parent 5b0484d977)
@@ -869,3 +869,9 @@ Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torch.nn.utils.rnn.pack_sequence

:hidden:`Flatten`
~~~~~~~~~~~~~~~~~

.. autoclass:: Flatten
    :members:
@@ -104,6 +104,11 @@ module_tests = [
        module_name='Tanh',
        input_size=(2, 3, 4, 5)
    ),
    dict(
        module_name='Flatten',
        input_size=(2, 3, 4, 5),
        reference_fn=lambda i, *_: torch.flatten(i, 1)
    ),
    dict(
        module_name='Softmax',
        constructor_args=(1,),
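The new test entry uses torch.flatten(i, 1) as its reference: the harness checks that the module's output matches that call. A minimal standalone sketch of the same comparison (the tensor shape mirrors the input_size above; this snippet is illustrative and not part of the diff):

import torch
import torch.nn as nn

# nn.Flatten() with default arguments should behave like torch.flatten(input, 1):
# keep dim 0 (the batch dimension) and flatten everything after it.
i = torch.randn(2, 3, 4, 5)            # same shape as input_size in the test entry
out_module = nn.Flatten()(i)
out_reference = torch.flatten(i, 1)

assert out_module.shape == (2, 60)      # 3 * 4 * 5 = 60
assert torch.equal(out_module, out_reference)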
@@ -28,7 +28,8 @@ from .distance import PairwiseDistance, CosineSimilarity
from .fold import Fold, Unfold
from .adaptive import AdaptiveLogSoftmaxWithLoss
from .transformer import TransformerEncoder, TransformerDecoder, \
    TransformerEncoderLayer, TransformerDecoderLayer, Transformer
from .flatten import Flatten

__all__ = [
    'Module', 'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
@@ -50,6 +51,7 @@ __all__ = [
    'PairwiseDistance', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d',
    'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d',
    'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
    'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
    'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
    'Flatten'
]
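The two hunks above are what expose the new class at the usual torch.nn namespace: the import pulls it into torch.nn.modules, and the __all__ entry re-exports it. A small illustrative snippet (not part of the diff), assuming the change is applied:

import torch.nn as nn
from torch.nn import Flatten   # resolvable because of the import/__all__ change above

assert Flatten is nn.Flatten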
torch/nn/modules/flatten.py (new file, 29 lines)
@@ -0,0 +1,29 @@
from .module import Module

class Flatten(Module):
    r"""
    Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.
    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Shape:
        - Input: :math:`(N, *dims)`
        - Output: :math:`(N, \prod *dims)` (for the default case).


    Examples::
        >>> m = nn.Sequential(
        >>>     nn.Conv2d(1, 32, 5, 1, 1),
        >>>     nn.Flatten()
        >>> )
    """
    __constants__ = ['start_dim', 'end_dim']

    def __init__(self, start_dim=1, end_dim=-1):
        super(Flatten, self).__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input):
        return input.flatten(self.start_dim, self.end_dim)
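A short usage sketch of the class as defined above (the shapes are illustrative and not part of the diff), showing the default behaviour and a non-default start_dim/end_dim range:

import torch
import torch.nn as nn

x = torch.randn(32, 1, 5, 5)

# Default start_dim=1, end_dim=-1: keep the batch dimension, flatten the rest.
print(nn.Flatten()(x).shape)                         # torch.Size([32, 25])

# Flatten only dims 2..3, leaving batch and channel dimensions untouched.
print(nn.Flatten(start_dim=2, end_dim=3)(x).shape)   # torch.Size([32, 1, 25])

# Inside a Sequential, as in the docstring example above.
model = nn.Sequential(
    nn.Conv2d(1, 32, 5, 1, 1),   # 5x5 input with padding 1 -> 3x3 output per channel
    nn.Flatten(),
)
print(model(x).shape)                                # torch.Size([32, 288]); 288 = 32 channels * 3 * 3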