mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Improve torch.flatten docs and add tests to test_view_ops (#49501)
Summary: Addresses https://github.com/pytorch/pytorch/issues/39474 Pull Request resolved: https://github.com/pytorch/pytorch/pull/49501 Reviewed By: mrshenli Differential Revision: D25740586 Pulled By: soulitzer fbshipit-source-id: 3d7bdbab91eb208ac9e6832bb766d9d95a00c103
This commit is contained in:
parent
b76822eb49
commit
fdb81c538a
|
|
@ -100,6 +100,12 @@ class TestViewOps(TestCase):
|
|||
|
||||
return True
|
||||
|
||||
# Returns true if v1 and v2 are views of the same base
|
||||
def is_view_of_same_base(self, v1, v2):
|
||||
if (not v1._is_view() or v1 is v2):
|
||||
return False
|
||||
return self.is_view_of(v1._base, v2)
|
||||
|
||||
# Performs transpose if contiguous=True, else returns the input tensor as is
|
||||
def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1):
|
||||
if contiguous:
|
||||
|
|
@ -457,6 +463,64 @@ class TestViewOps(TestCase):
|
|||
nv[6] = 0
|
||||
self.assertNotEqual(t[1, 1], nv[6])
|
||||
|
||||
def test_flatten_view(self, device):
|
||||
def test_writes_propagate(t, v):
|
||||
idx_t = (0,) * t.ndim
|
||||
idx_v = (0,) * v.ndim
|
||||
v[idx_v] = 0
|
||||
self.assertEqual(t[idx_t], v[idx_v])
|
||||
|
||||
t = torch.ones(1, 2, 3, 4, device=device)
|
||||
v = t.flatten()
|
||||
self.assertTrue(self.is_view_of(t, v))
|
||||
test_writes_propagate(t, v)
|
||||
|
||||
# zero-dimensional tensor
|
||||
t = torch.tensor(1, device=device)
|
||||
v = t.flatten()
|
||||
test_writes_propagate(t, v)
|
||||
self.assertTrue(self.is_view_of(t, v))
|
||||
|
||||
t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3)
|
||||
v = t.flatten(0, 1)
|
||||
test_writes_propagate(t, v)
|
||||
self.assertTrue(self.is_view_of_same_base(t, v))
|
||||
|
||||
# stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups:
|
||||
t = torch.ones(720, device=device) \
|
||||
.as_strided((2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0))
|
||||
# [--1--|---2---|-3-] [--1--|----2---|-3-]
|
||||
v1 = t.flatten(0, 1)
|
||||
v2 = v1.flatten(1, 3)
|
||||
v3 = v2.flatten(2, 2)
|
||||
test_writes_propagate(t, v1)
|
||||
self.assertTrue(self.is_view_of_same_base(t, v1))
|
||||
test_writes_propagate(t, v2)
|
||||
self.assertTrue(self.is_view_of_same_base(t, v2))
|
||||
test_writes_propagate(t, v3)
|
||||
self.assertTrue(self.is_view_of_same_base(t, v3))
|
||||
|
||||
@onlyOnCPUAndCUDA
|
||||
def test_flatten_nonview(self, device):
|
||||
def assert_is_nonview(t, nv):
|
||||
idx_t = (0,) * t.ndim
|
||||
idx_nv = (0,) * nv.ndim
|
||||
self.assertTrue(not nv._is_view())
|
||||
nv[idx_nv] = 0
|
||||
self.assertNotEqual(t[idx_t], nv[idx_nv])
|
||||
t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
|
||||
nv = t.flatten(1, 3)
|
||||
assert_is_nonview(t, nv)
|
||||
|
||||
t = torch.ones(2, 2, device=device).T
|
||||
nv = t.flatten()
|
||||
assert_is_nonview(t, nv)
|
||||
|
||||
# flatten returns the original object if start_dim=end_dim
|
||||
t = t = torch.ones(2, 2, device=device)
|
||||
nv = t.flatten(1, 1)
|
||||
self.assertTrue(t is nv)
|
||||
|
||||
def test_basic_indexing_slice_view(self, device):
|
||||
t = torch.ones(5, 5, device=device)
|
||||
v = t[:2, :3]
|
||||
|
|
|
|||
|
|
@ -3095,7 +3095,17 @@ add_docstr(torch.flatten,
|
|||
r"""
|
||||
flatten(input, start_dim=0, end_dim=-1) -> Tensor
|
||||
|
||||
Flattens a contiguous range of dims in a tensor.
|
||||
Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim`
|
||||
are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened.
|
||||
The order of elements in :attr:`input` is unchanged.
|
||||
|
||||
Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view,
|
||||
or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can
|
||||
be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the
|
||||
flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned.
|
||||
|
||||
.. note::
|
||||
Flattening a zero-dimensional tensor will return a one-dimensional view.
|
||||
|
||||
Args:
|
||||
{input}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user