Improve torch.flatten docs and add tests to test_view_ops (#49501)

Summary:
Addresses https://github.com/pytorch/pytorch/issues/39474

Pull Request resolved: https://github.com/pytorch/pytorch/pull/49501

Reviewed By: mrshenli

Differential Revision: D25740586

Pulled By: soulitzer

fbshipit-source-id: 3d7bdbab91eb208ac9e6832bb766d9d95a00c103
This commit is contained in:
Jeffrey Wan 2021-01-04 11:10:08 -08:00 committed by Facebook GitHub Bot
parent b76822eb49
commit fdb81c538a
2 changed files with 75 additions and 1 deletions

View File

@ -100,6 +100,12 @@ class TestViewOps(TestCase):
return True
# Returns true if v1 and v2 are views of the same base
def is_view_of_same_base(self, v1, v2):
if (not v1._is_view() or v1 is v2):
return False
return self.is_view_of(v1._base, v2)
# Performs transpose if contiguous=True, else returns the input tensor as is
def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1):
if contiguous:
@ -457,6 +463,64 @@ class TestViewOps(TestCase):
nv[6] = 0
self.assertNotEqual(t[1, 1], nv[6])
def test_flatten_view(self, device):
def test_writes_propagate(t, v):
idx_t = (0,) * t.ndim
idx_v = (0,) * v.ndim
v[idx_v] = 0
self.assertEqual(t[idx_t], v[idx_v])
t = torch.ones(1, 2, 3, 4, device=device)
v = t.flatten()
self.assertTrue(self.is_view_of(t, v))
test_writes_propagate(t, v)
# zero-dimensional tensor
t = torch.tensor(1, device=device)
v = t.flatten()
test_writes_propagate(t, v)
self.assertTrue(self.is_view_of(t, v))
t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3)
v = t.flatten(0, 1)
test_writes_propagate(t, v)
self.assertTrue(self.is_view_of_same_base(t, v))
# stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups:
t = torch.ones(720, device=device) \
.as_strided((2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0))
# [--1--|---2---|-3-] [--1--|----2---|-3-]
v1 = t.flatten(0, 1)
v2 = v1.flatten(1, 3)
v3 = v2.flatten(2, 2)
test_writes_propagate(t, v1)
self.assertTrue(self.is_view_of_same_base(t, v1))
test_writes_propagate(t, v2)
self.assertTrue(self.is_view_of_same_base(t, v2))
test_writes_propagate(t, v3)
self.assertTrue(self.is_view_of_same_base(t, v3))
@onlyOnCPUAndCUDA
def test_flatten_nonview(self, device):
def assert_is_nonview(t, nv):
idx_t = (0,) * t.ndim
idx_nv = (0,) * nv.ndim
self.assertTrue(not nv._is_view())
nv[idx_nv] = 0
self.assertNotEqual(t[idx_t], nv[idx_nv])
t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
nv = t.flatten(1, 3)
assert_is_nonview(t, nv)
t = torch.ones(2, 2, device=device).T
nv = t.flatten()
assert_is_nonview(t, nv)
# flatten returns the original object if start_dim=end_dim
t = t = torch.ones(2, 2, device=device)
nv = t.flatten(1, 1)
self.assertTrue(t is nv)
def test_basic_indexing_slice_view(self, device):
t = torch.ones(5, 5, device=device)
v = t[:2, :3]

View File

@ -3095,7 +3095,17 @@ add_docstr(torch.flatten,
r"""
flatten(input, start_dim=0, end_dim=-1) -> Tensor
Flattens a contiguous range of dims in a tensor.
Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim`
are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened.
The order of elements in :attr:`input` is unchanged.
Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view,
or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can
be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the
flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned.
.. note::
Flattening a zero-dimensional tensor will return a one-dimensional view.
Args:
{input}