Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Summary: https://github.com/pytorch/pytorch/issues/2118
I'm not sure I'm doing it correctly, so I'll add tests if we decide that it's roughly correct.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/22245
Differential Revision: D16508957
Pulled By: Chillee
fbshipit-source-id: a8dc7af999ba698c921006889f71cb1bc5a59d50
30 lines · 819 B · Python
from .module import Module


class Flatten(Module):
    r"""
    Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.
    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Shape:
        - Input: :math:`(N, *dims)`
        - Output: :math:`(N, \prod *dims)` (for the default case).


    Examples::
        >>> m = nn.Sequential(
        >>>     nn.Conv2d(1, 32, 5, 1, 1),
        >>>     nn.Flatten()
        >>> )
    """
    __constants__ = ['start_dim', 'end_dim']

    def __init__(self, start_dim=1, end_dim=-1):
        super(Flatten, self).__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input):
        return input.flatten(self.start_dim, self.end_dim)
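For reference, a minimal usage sketch of the module defined above, assuming it is exported as ``nn.Flatten`` as the pull request intends; the input shapes below are illustrative and not taken from the source.

    import torch
    import torch.nn as nn

    x = torch.randn(8, 3, 4, 5)

    # Default behaviour (start_dim=1, end_dim=-1): every dim after the
    # batch dim is collapsed, so (8, 3, 4, 5) becomes (8, 60).
    flat = nn.Flatten()
    print(flat(x).shape)      # torch.Size([8, 60])

    # A custom range flattens only dims 2..3: (8, 3, 4, 5) becomes (8, 3, 20).
    flat_hw = nn.Flatten(start_dim=2, end_dim=3)
    print(flat_hw(x).shape)   # torch.Size([8, 3, 20])

Because ``forward`` simply calls ``input.flatten(self.start_dim, self.end_dim)``, the module is a thin, stateless wrapper around ``Tensor.flatten`` that can be dropped into a ``Sequential`` pipeline between convolutional and linear layers.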