pytorch/torch/nn/modules/flatten.py
Horace He 1c00e0fc3f Added a flatten module (#22245)
Summary:
https://github.com/pytorch/pytorch/issues/2118

I'm not sure I'm doing it correctly, so I'll add tests if we decide that it's roughly correct.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22245

Differential Revision: D16508957

Pulled By: Chillee

fbshipit-source-id: a8dc7af999ba698c921006889f71cb1bc5a59d50
2019-07-25 22:48:52 -07:00


from .module import Module


class Flatten(Module):
    r"""
    Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.
    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Shape:
        - Input: :math:`(N, *dims)`
        - Output: :math:`(N, \prod *dims)` (for the default case).


    Examples::
        >>> m = nn.Sequential(
        >>>     nn.Conv2d(1, 32, 5, 1, 1),
        >>>     nn.Flatten()
        >>> )
    """
    __constants__ = ['start_dim', 'end_dim']

    def __init__(self, start_dim=1, end_dim=-1):
        super(Flatten, self).__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input):
        return input.flatten(self.start_dim, self.end_dim)
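The "Shape" note in the docstring only covers the default case. As a rough illustration of how `start_dim` and `end_dim` select the range that gets merged, here is a minimal pure-Python sketch of the shape arithmetic. `flattened_shape` is a hypothetical helper written for this note, not part of PyTorch, and the 28x28 input size in the comments is an assumption chosen only to make the numbers concrete.

```python
from math import prod  # available in Python 3.8+


def flattened_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical helper: the shape left after merging dims
    start_dim..end_dim (inclusive) into a single dim."""
    ndim = len(shape)
    # Map negative indices onto 0..ndim-1, as Python sequences do.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not (0 <= start <= end < ndim):
        raise ValueError("invalid flatten range")
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


# Assumed input: a batch of 8 single-channel 28x28 images.
# Conv2d(1, 32, 5, 1, 1) would turn (8, 1, 28, 28) into (8, 32, 26, 26),
# since (28 + 2*1 - 5) / 1 + 1 = 26.
print(flattened_shape((8, 32, 26, 26)))        # (8, 21632)
print(flattened_shape((8, 32, 26, 26), 2, 3))  # (8, 32, 676)
```

With the defaults, everything after the batch dim collapses into one dim, which is the `(N, \prod *dims)` case from the docstring; a narrower range leaves the dims outside it untouched.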