Update NestedTensor docs (#80963)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80963
Approved by: https://github.com/george-qi
This commit is contained in:
Christian Puhrsch 2022-07-07 22:15:37 +00:00 committed by PyTorch MergeBot
parent 712b3e76ef
commit c97ff3d51e

View File

@ -10,16 +10,11 @@ Introduction
The PyTorch API of nested tensors is in prototype stage and will change in the near future.
.. warning::
torch.NestedTensor currently does not support autograd. It needs to be used in the context
of torch.inference_mode().
NestedTensor allows the user to pack a list of Tensors into a single, efficient datastructure.
The only constraint on the input Tensors is that their dimension must match.
This enables more efficient metadata representations and operator coverage.
This enables more efficient metadata representations and access to purpose built kernels.
Construction is straightforward and involves passing a list of Tensors to the constructor.
@ -35,7 +30,7 @@ nested_tensor([
tensor([3, 4, 5, 6, 7])
])
Data type and device can be chosen via the usual keyword arguments
Data type and device can be chosen via the usual keyword arguments.
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32, device="cuda")
>>> nt
@ -44,22 +39,108 @@ nested_tensor([
tensor([3., 4., 5., 6., 7.], device='cuda:0')
])
In order to form a valid NestedTensor the passed Tensors also all need to match in dimension, but none of the other attributes need to.
Operator coverage
+++++++++++++++++
>>> a = torch.randn(3, 50, 70) # image 1
>>> b = torch.randn(3, 128, 64) # image 2
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32)
>>> nt.dim()
4
We are currently on our path to wholesale extend operator coverage guided by specific ML use cases.
If one of the dimensions don't match, the constructor throws an error.
Operator coverage thus is currently very limited and only unbind is supported.
>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(3, 128, 64) # image 2
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: All Tensors given to nested_tensor must have the same dimension. Found dimension 3 for Tensor at index 1 and dimension 2 for Tensor at index 0.
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32, device="cuda")
Note that the passed Tensors are being copied into a contiguous piece of memory. The resulting
NestedTensor allocates new memory to store them and does not keep a reference.
At this moment we only support one level of nesting, i.e. a simple, flat list of Tensors. In the future
we can add support for multiple levels of nesting, such as a list that consists entirely of lists of Tensors.
Note that for this extension it is important to maintain an even level of nesting across entries so that the resulting NestedTensor
has a well defined dimension. If you have a need for this feature, please feel encourage to open a feature request so that
we can track it and plan accordingly.
size
+++++++++++++++++++++++++
Even though a NestedTensor does not support .size() (or .shape), it supports .size(i) if dimension i is regular.
>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(32, 128) # text 2
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32)
>>> nt.size(0)
2
>>> nt.size(1)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Given dimension 1 is irregular and does not have a size.
>>> nt.size(2)
128
If all dimensions are regular, the NestedTensor is intended to be semantically indistinguishable from a regular torch.Tensor.
>>> a = torch.randn(20, 128) # text 1
>>> nt = torch.nested_tensor([a, a], dtype=torch.float32)
>>> nt.size(0)
2
>>> nt.size(1)
20
>>> nt.size(2)
128
>>> torch.stack(nt.unbind()).size()
torch.Size([2, 20, 128])
>>> torch.stack([a, a]).size()
torch.Size([2, 20, 128])
>>> torch.equal(torch.stack(nt.unbind()), torch.stack([a, a]))
True
In the future we might make it easier to detect this condition and convert seamlessly.
Please open a feature request if you have a need for this (or any other related feature for that manner).
unbind
+++++++++++++++++++++++++
unbind allows you to retrieve a view of the constituents.
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 4)
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32)
>>> nt
nested_tensor([
tensor([0., 1., 2.], device='cuda:0'),
tensor([3., 4., 5., 6., 7.], device='cuda:0')
tensor([[ 1.2286, -1.2343, -1.4842],
[-0.7827, 0.6745, 0.0658]]),
tensor([[-1.1247, -0.4078, -1.0633, 0.8083],
[-0.2871, -0.2980, 0.5559, 1.9885],
[ 0.4074, 2.4855, 0.0733, 0.8285]])
])
>>> nt.unbind()
[tensor([0., 1., 2.], device='cuda:0'), tensor([3., 4., 5., 6., 7.], device='cuda:0')]
(tensor([[ 1.2286, -1.2343, -1.4842],
[-0.7827, 0.6745, 0.0658]]), tensor([[-1.1247, -0.4078, -1.0633, 0.8083],
[-0.2871, -0.2980, 0.5559, 1.9885],
[ 0.4074, 2.4855, 0.0733, 0.8285]]))
>>> nt.unbind()[0] is not a
True
>>> nt.unbind()[0].mul_(3)
tensor([[ 3.6858, -3.7030, -4.4525],
[-2.3481, 2.0236, 0.1975]])
>>> nt
nested_tensor([
tensor([[ 3.6858, -3.7030, -4.4525],
[-2.3481, 2.0236, 0.1975]]),
tensor([[-1.1247, -0.4078, -1.0633, 0.8083],
[-0.2871, -0.2980, 0.5559, 1.9885],
[ 0.4074, 2.4855, 0.0733, 0.8285]])
])
Note that nt.unbind()[0] is not a, but rather a slice of the underlying memory, which represents the first entry or constituent of the NestedTensor.
Nested tensor methods
+++++++++++++++++++++++++