pytorch/torch/namedtensor.py
Richard Zou f996f8d61d Update tensor.view_names / tensor.names_ API (#23973)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23973

Without loss of generality, I describe the API for `tensor.view_names`.
`tensor.names_` has an analogous API.

`tensor.view_names(*names)` returns a view on tensor with named dims `names`.
`names` must be of length `tensor.dim()`; otherwise, if '*' is in `names`,
then it (known as the "glob") is expanded greedily to be equal to the
corresponding names from `tensor.names`.

For example,
```
>>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
>>> x.view_names('*', 'height', 'width').names
('N', 'C', 'height', 'width')

>>> x.view_names('batch', '*', 'width').names
('batch', 'C', 'H', 'width')
```

tensor.view_names(**rename_map) returns a view on tensor that has
renamed dims as specified in the mapping `rename_map`.

For example,
```
>>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
>>> x.view_names(W='width', H='height').names
('N', 'C', 'height', 'width')
```

These are different(!!!) from the C++ API, which only allows the
following:
- tensor.view_names(optional<DimnameList>)

C++ API parity for named tensors is not important right now; I am
punting that to the future.

Test Plan: - [namedtensor ci]

Differential Revision: D16710916

Pulled By: zou3519

fbshipit-source-id: 7cb8056c0fb4c97b04c3a2d1dd0f737e0a67ce34
2019-08-14 09:40:35 -07:00

115 lines
4.3 KiB
Python

import torch
from collections import OrderedDict
"""
This file contains helper functions that implement experimental functionality
for named tensors in python. All of these are experimental, unstable, and
subject to change or deletion.
"""
def _assert_namedtensor_build(api_name):
if not torch._C._BUILD_NAMEDTENSOR:
raise RuntimeError('NYI: {} is experimental and a part '
'of our named tensors project.'.format(api_name))
def _build_dim_map(tensor):
"""Returns a map of { dim: dim_name } where dim is a name if the dim is named
and the dim index otherwise."""
return OrderedDict([(idx if name is None else name, name)
for idx, name in enumerate(tensor.names)])
def _namer_api_name(inplace):
if inplace:
return 'names_'
else:
return 'view_names'
def _expand_single_glob(numel_pre_glob, numel_post_glob, names):
return names[numel_pre_glob:len(names) - numel_post_glob]
def _update_names_with_list(tensor, names, inplace):
# Special case for tensor.view_names(None)
if len(names) == 1 and names[0] is None:
return tensor._update_names(None, inplace)
glob_indices = [i for i, x in enumerate(names) if x == '*']
if len(glob_indices) >= 2:
raise RuntimeError('{}: More than one \'*\' found in names ('
'{}). This function supports up to one \'*\'.'
.format(_namer_api_name(inplace), names))
elif len(glob_indices) == 1:
glob_idx = glob_indices[0]
globbed_names = _expand_single_glob(glob_idx, len(names) - glob_idx - 1, tensor.names)
return tensor._update_names(
names[:glob_idx] + globbed_names + names[glob_idx + 1:], inplace)
else:
return tensor._update_names(names, inplace)
def _update_names_with_mapping(tensor, rename_map, inplace):
dim_map = _build_dim_map(tensor)
for old_dim in rename_map.keys():
new_dim = rename_map[old_dim]
if old_dim in dim_map.keys():
dim_map[old_dim] = new_dim
else:
raise RuntimeError(('{api_name}: Tried to rename dim \'{old_dim}\' to dim '
'{new_dim} in Tensor[{dims}] but dim \'{old_dim}\' does not exist')
.format(old_dim=old_dim, new_dim=new_dim, dims=tensor.names,
api_name=_namer_api_name(inplace)))
return tensor._update_names(tuple(dim_map.values()), inplace)
def _update_names(tensor, names, rename_map, inplace):
"""There are two usages:
tensor.view_names(*names) returns a view on tensor with named dims `names`.
`names` must be of length `tensor.dim()`; otherwise, if '*' is in `names`,
then it is expanded greedily to be equal to the corresponding names from
`tensor.names`.
For example,
```
>>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
>>> x.view_names('*', 'height', 'width').names
('N', 'C', 'height', 'width')
>>> x.view_names('batch', '*', 'width').names
('batch', 'C', 'H', 'width')
```
tensor.view_names(**rename_map) returns a view on tensor that has renamed dims
as specified in the mapping `rename_map`.
For example,
```
>>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
>>> x.view_names(W='width', H='height').names
('N', 'C', 'height', 'width')
```
Finally, tensor.view_names has an in-place version called tensor.names_.
"""
_assert_namedtensor_build(_namer_api_name(inplace))
has_names = len(names) > 0
has_rename_pairs = bool(rename_map)
if has_names and has_rename_pairs:
raise RuntimeError('{api_name}: This function takes either positional '
'args or keyword args, but not both. Use tensor.{api_name}(*names) '
'to name dims and tensor.{api_name}(**rename_map) to rename '
'dims.'.format(api_name=_namer_api_name(inplace)))
# Special case for tensor.view_names(*[]), which is valid for a 0 dim tensor.
if not has_names and not has_rename_pairs:
return _update_names_with_list(tensor, names, inplace)
if has_names:
return _update_names_with_list(tensor, names, inplace)
return _update_names_with_mapping(tensor, rename_map, inplace)