Add torch.unflatten and improve its docs (#81399)
unflatten now has a free-function version, torch.unflatten, in addition to
the method torch.Tensor.unflatten.
Updated docs to reflect this and polished them a little.
For consistency, changed the signature of the int version of unflatten in
native_functions.yaml.
Some override tests were failing because the .int and .Dimname overloads of
unflatten take different numbers of arguments, which required some changes
to test/test_overrides.py.
Removed support for mixing integer and string arguments when specifying
dimensions in unflatten.
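As a quick illustration (a sketch based on the diff below, not part of the
patch; shapes are taken from the new docstring examples):

    >>> import torch
    >>> t = torch.randn(3, 4, 1)
    >>> t.unflatten(1, (2, 2)).shape           # existing method form
    torch.Size([3, 2, 2, 1])
    >>> torch.unflatten(t, 1, (-1, 2)).shape   # new free-function form; -1 is inferred
    torch.Size([3, 2, 2, 1])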
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81399
Approved by: https://github.com/Lezcano, https://github.com/ngimel
This commit is contained in:
parent 5257d1d64b
commit fd84c458f4
@@ -2899,7 +2899,7 @@ static inline void handle_unflatten_exception(const std::runtime_error &e,
   }
 }
 
-Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::optional<DimnameList> names) {
+Tensor unflatten_impl(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::optional<DimnameList> names) {
   dim = maybe_wrap_dim(dim, self.dim());
 
   TORCH_CHECK(sizes.size() > 0, "unflatten: sizes must be non-empty");
@@ -2938,8 +2938,12 @@ Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::option
   return result;
 }
 
+Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes) {
+  return native::unflatten_impl(self, dim, sizes, c10::nullopt);
+}
+
 Tensor unflatten(const Tensor& self, Dimname dim, IntArrayRef sizes, DimnameList names) {
-  return native::unflatten(self, dimname_to_position(self, dim), sizes, names);
+  return native::unflatten_impl(self, dimname_to_position(self, dim), sizes, names);
 }
 
 Tensor view_as(const Tensor& self, const Tensor& other) {
@@ -2247,11 +2247,11 @@
 - func: flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a)
   variants: function, method
 
-- func: unflatten.int(Tensor(a) self, int dim, int[] sizes, Dimname[]? names=None) -> Tensor(a)
-  variants: method
+- func: unflatten.int(Tensor(a) self, int dim, int[] sizes) -> Tensor(a)
+  variants: function, method
 
 - func: unflatten.Dimname(Tensor(a) self, Dimname dim, int[] sizes, Dimname[] names) -> Tensor(a)
-  variants: method
+  variants: function, method
 
 - func: fill.Scalar(Tensor self, Scalar value) -> Tensor
   variants: function
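In schema terms, the .int overload loses its optional names argument and
both overloads gain a function variant. A sketch of the resulting Python
surface (the overload is selected by the type of dim; examples mirror the
tests and docstring below):

    >>> torch.unflatten(torch.ones(4), 0, (2, 2)).shape    # unflatten.int, free function
    torch.Size([2, 2])
    >>> n = torch.ones(4, names=('A',))
    >>> n.unflatten('A', (('B1', 2), ('B2', 2))).names     # unflatten.Dimname, method
    ('B1', 'B2')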
@@ -300,7 +300,6 @@ operators, see :ref:`name_inference_reference-doc`.
 .. automethod:: align_as
 .. automethod:: align_to
 
-.. automethod:: unflatten
 .. py:method:: flatten(dims, out_dim) -> Tensor
    :noindex:
 
@@ -685,6 +685,7 @@ Tensor class reference
     Tensor.type
     Tensor.type_as
     Tensor.unbind
+    Tensor.unflatten
     Tensor.unfold
     Tensor.uniform_
     Tensor.unique
@@ -533,6 +533,7 @@ Other Operations
     tril_indices
     triu
     triu_indices
+    unflatten
     vander
     view_as_real
     view_as_complex
@@ -114,6 +114,7 @@ ALLOW_LIST = [
     ("c10d::broadcast", datetime.date(2022, 6, 25)),
     ("aten::.*functional", datetime.date(2022, 8, 1)),
     ("aten::_foreach.*", datetime.date(2022, 8, 1)),
+    ("aten::unflatten", datetime.date(2022, 8, 10)),
     # TODO: FIXME: prims shouldn't be checked
     ("prims::.*", datetime.date(9999, 1, 1)),
 ]
@@ -1083,22 +1083,13 @@ class TestNamedTensor(TestCase):
     def test_unflatten(self):
         # test args: tensor, int, namedshape
         self.assertTrue(torch.equal(
-            torch.ones(4).unflatten(0, (('A', 2), ('B', 2))),
+            torch.ones(4, names=('A',)).unflatten('A', (('A', 2), ('B', 2))),
             torch.ones(2, 2, names=('A', 'B'))))
         self.assertTrue(torch.equal(
-            torch.ones(4).unflatten(0, [('A', 2), ('B', 2)]),
+            torch.ones(4, names=('A',)).unflatten('A', [('A', 2), ('B', 2)]),
             torch.ones(2, 2, names=('A', 'B'))))
         self.assertTrue(torch.equal(
-            torch.ones(4).unflatten(0, (['A', 2], ['B', 2])),
-            torch.ones(2, 2, names=('A', 'B'))))
-        self.assertTrue(torch.equal(
-            torch.ones(4).unflatten(-1, (['A', 2], ['B', 2])),
-            torch.ones(2, 2, names=('A', 'B'))))
-        self.assertTrue(torch.equal(
-            torch.ones(4).unflatten(-1, (['A', -1], ['B', 2])),
-            torch.ones(2, 2, names=('A', 'B'))))
-        self.assertTrue(torch.equal(
-            torch.ones(4).unflatten(-1, (['A', 2], ['B', -1])),
+            torch.ones(4, names=('A',)).unflatten('A', (['A', 2], ['B', 2])),
             torch.ones(2, 2, names=('A', 'B'))))
         self.assertTrue(torch.equal(
             torch.ones(2, 10, names=('A', 'B')).unflatten('B', (['B1', -1],)),
@@ -1112,18 +1103,13 @@ class TestNamedTensor(TestCase):
                 .unflatten('B', (['B1', 3], ['B2', -1], ['B3', 4])),
             torch.ones(2, 3, 0, 4, names=('A', 'B1', 'B2', 'B3'))))
 
-        # test args: namedtensor, int, namedshape
-        self.assertTrue(torch.equal(
-            torch.ones(2, 4, names=('A', 'B')).unflatten(1, (('B1', 2), ('B2', 2))),
-            torch.ones(2, 2, 2, names=('A', 'B1', 'B2'))))
-
         # test args: namedtensor, str, namedshape
         self.assertTrue(torch.equal(
             torch.ones(2, 4, names=('A', 'B')).unflatten('B', (('B1', 2), ('B2', 2))),
             torch.ones(2, 2, 2, names=('A', 'B1', 'B2'))))
 
         # test invalid args: namedtensor, str, sizes
-        with self.assertRaisesRegex(TypeError, r"received an invalid combination of arguments"):
+        with self.assertRaisesRegex(TypeError, r"unflatten\(\): argument 'dim' \(position 1\) must be int, not str"):
             torch.tensor([1], names=('A',)).unflatten('A', (1, 1))
 
         # test invalid args: namedtensor, int, sizes
@@ -337,7 +337,7 @@ def generate_tensor_like_torch_implementations():
     msg = (
         "The following functions are not tested for __torch_function__ "
         "support, please ensure there is an entry in the dict returned by "
-        "torch._overrides.get_testing_overrides for this function or if a "
+        "torch.overrides.get_testing_overrides for this function or if a "
        "__torch_function__ override does not make sense, add an entry to "
        "the tuple returned by torch._overrides.get_ignored_functions.\n\n{}"
    )
@@ -648,7 +648,11 @@ def generate_tensor_like_override_tests(cls):
             func_args.append(3.5)
         elif t == 'bool':
             func_args.append(False)
-        elif t.startswith('int') or t in {'Dimname', 'DimnameList'}:
+        elif t == 'Dimname':
+            func_args.append("")
+        elif t == 'DimnameList':
+            func_args.append([""])
+        elif t.startswith('int'):
             func_args.append(0)
         elif t in {'Stream'}:
             func_args.append(torch.Stream())
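An illustrative note (not from the patch) on why a Dimname argument can no
longer be faked with the integer 0: an int dim selects the .int overload,
and the two unflatten overloads now take different numbers of arguments, so
the generated test calls need a string to reach the .Dimname path:

    >>> torch.ones(2, 4).unflatten(1, (2, 2)).shape        # .int overload: (dim, sizes)
    torch.Size([2, 2, 2])
    >>> n = torch.ones(2, 4, names=('A', 'B'))
    >>> n.unflatten('B', (('B1', 2), ('B2', 2))).names     # .Dimname overload via str dim
    ('A', 'B1', 'B2')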
@@ -5857,7 +5857,7 @@ class TestTorch(TestCase):
             torch.ones(2, 3, 0, 4, 5, 2))
 
         # test invalid args: tensor, str, sizes
-        with self.assertRaisesRegex(TypeError, r"received an invalid combination of arguments"):
+        with self.assertRaisesRegex(TypeError, r"unflatten\(\): argument 'dim' \(position 1\) must be int, not str"):
             torch.tensor([1]).unflatten('A', (1, 1))
 
         # test invalid args: tensor, str, namedshape
@@ -1135,34 +1135,10 @@ class Tensor(torch._C._TensorBase):
         )
 
     def unflatten(self, dim, sizes):
-        r"""Expands the dimension :attr:`dim` of the :attr:`self` tensor over multiple dimensions
-        of sizes given by :attr:`sizes`.
+        r"""
+        unflatten(dim, sizes) -> Tensor
 
-        * :attr:`sizes` is the new shape of the unflattened dimension and it can be a `Tuple[int]` as well
-        as `torch.Size` if :attr:`self` is a `Tensor`, or `namedshape` (Tuple[(name: str, size: int)])
-        if :attr:`self` is a `NamedTensor`. The total number of elements in sizes must match the number
-        of elements in the original dim being unflattened.
-
-        Args:
-            dim (Union[int, str]): Dimension to unflatten
-            sizes (Union[Tuple[int] or torch.Size, Tuple[Tuple[str, int]]]): New shape of the unflattened dimension
-
-        Examples:
-            >>> torch.randn(3, 4, 1).unflatten(1, (2, 2)).shape
-            torch.Size([3, 2, 2, 1])
-            >>> torch.randn(3, 4, 1).unflatten(1, (-1, 2)).shape # the size -1 is inferred from the size of dimension 1
-            torch.Size([3, 2, 2, 1])
-            >>> torch.randn(2, 4, names=('A', 'B')).unflatten('B', (('B1', 2), ('B2', 2)))
-            tensor([[[-1.1772, 0.0180],
-                     [ 0.2412, 0.1431]],
-                    [[-1.1819, -0.8899],
-                     [ 1.5813, 0.2274]]], names=('A', 'B1', 'B2'))
-            >>> torch.randn(2, names=('A',)).unflatten('A', (('B1', -1), ('B2', 1)))
-            tensor([[-0.8591],
-                    [ 0.3100]], names=('B1', 'B2'))
-
-        .. warning::
-            The named tensor API is experimental and subject to change.
+        See :func:`torch.unflatten`.
 
         """
         if has_torch_function_unary(self):
@@ -1177,6 +1153,8 @@ class Tensor(torch._C._TensorBase):
         ):
             names, sizes = unzip_namedshape(sizes)
             return super(Tensor, self).unflatten(dim, sizes, names)
+        else:
+            return super(Tensor, self).unflatten(dim, sizes)
 
     def rename_(self, *names, **rename_map):
         """In-place version of :meth:`~Tensor.rename`."""
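The method keeps routing namedshape input through unzip_namedshape before
calling the .Dimname overload; roughly (an assumed sketch, not the actual
helper's code):

    >>> namedshape = (('B1', 2), ('B2', 3))
    >>> names, sizes = zip(*namedshape)   # approximately what unzip_namedshape returns
    >>> names, sizes
    (('B1', 'B2'), (2, 3))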
@@ -4552,6 +4552,41 @@ Example::
     ),
 )
 
+add_docstr(
+    torch.unflatten,
+    r"""
+unflatten(input, dim, sizes) -> Tensor
+
+Expands a dimension of the input tensor over multiple dimensions.
+
+.. seealso::
+
+    :func:`torch.flatten` is the inverse of this function. It coalesces several dimensions into one.
+
+Args:
+    {input}
+    dim (int): Dimension to be unflattened, specified as an index into
+        ``input.shape``.
+    sizes (Tuple[int]): New shape of the unflattened dimension.
+        One of its elements can be `-1` in which case the corresponding output
+        dimension is inferred. Otherwise, the product of ``sizes`` *must*
+        equal ``input.shape[dim]``.
+
+Returns:
+    A view of input with the specified dimension unflattened.
+
+Examples::
+    >>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape
+    torch.Size([3, 2, 2, 1])
+    >>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape
+    torch.Size([3, 2, 2, 1])
+    >>> torch.unflatten(torch.randn(5, 12, 3), -1, (2, 2, 3, 1, 1)).shape
+    torch.Size([5, 2, 2, 3, 1, 1, 3])
+""".format(
+        **common_args
+    ),
+)
+
 add_docstr(
     torch.gather,
     r"""
@@ -126,7 +126,7 @@ Tensor UnflattenImpl::forward(const Tensor& input) {
     }
     return input.unflatten(dimname, sizes, names);
   }
-  return input.unflatten(options.dim(), options.sizes(), torch::nullopt);
+  return input.unflatten(options.dim(), options.sizes());
 }
 
 // ============================================================================
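The C++ frontend module mirrors Python's nn.Unflatten; for reference, its
documented Python usage (a sketch, independent of this patch):

    >>> import torch
    >>> from torch import nn
    >>> m = nn.Unflatten(1, (2, 3))
    >>> m(torch.randn(4, 6)).shape
    torch.Size([4, 2, 3])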
@@ -1073,6 +1073,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.true_divide: lambda input, other: -1,
         torch.trunc: lambda input, out=None: -1,
         torch.unbind: lambda input, dim=0: -1,
+        torch.unflatten: lambda input, dim, sizes, names: -1,
         torch.unique: lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1,
         torch.unique_consecutive: lambda input, return_inverse=False, return_counts=False, dim=None: -1,
         torch.unsafe_chunk: lambda input, chunks, dim=0: -1,
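This entry lets the override machinery exercise torch.unflatten. For
illustration, a minimal sketch (hypothetical subclass, not from the patch)
of intercepting it via __torch_function__:

    import torch

    class LoggingTensor(torch.Tensor):
        # Hypothetical: log calls to torch.unflatten, then defer to the
        # default Tensor behavior for the actual computation.
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            if func is torch.unflatten:
                print("torch.unflatten intercepted")
            return super().__torch_function__(func, types, args, kwargs)

    t = torch.randn(3, 4).as_subclass(LoggingTensor)
    torch.unflatten(t, 1, (2, 2))   # prints, then returns the unflattened view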