Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Update NestedTensor docs (#80963)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80963
Approved by: https://github.com/george-qi

This commit is contained in:
parent 712b3e76ef
commit c97ff3d51e
@@ -10,16 +10,11 @@ Introduction
The PyTorch API of nested tensors is in prototype stage and will change in the near future.

.. warning::

  torch.NestedTensor currently does not support autograd. It needs to be used in the context
  of torch.inference_mode().

NestedTensor allows the user to pack a list of Tensors into a single, efficient data structure.

The only constraint on the input Tensors is that their dimension must match.

This enables more efficient metadata representations and access to purpose-built kernels.

Construction is straightforward and involves passing a list of Tensors to the constructor.
@@ -35,7 +30,7 @@ nested_tensor([
  tensor([3, 4, 5, 6, 7])
])

Data type and device can be chosen via the usual keyword arguments.

>>> nt = torch.nested_tensor([a, b], dtype=torch.float32, device="cuda")
>>> nt
@ -44,22 +39,108 @@ nested_tensor([
|
|||
  tensor([3., 4., 5., 6., 7.], device='cuda:0')
])

In order to form a valid NestedTensor the passed Tensors also all need to match in dimension, but none of the other attributes need to.

>>> a = torch.randn(3, 50, 70) # image 1
>>> b = torch.randn(3, 128, 64) # image 2
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32)
>>> nt.dim()
4

If one of the dimensions doesn't match, the constructor throws an error.
>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(3, 128, 64) # image 2
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: All Tensors given to nested_tensor must have the same dimension. Found dimension 3 for Tensor at index 1 and dimension 2 for Tensor at index 0.

Note that the passed Tensors are copied into a contiguous piece of memory. The resulting
NestedTensor allocates new memory to store them and does not keep a reference.
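The copy semantics described above can be illustrated with a small pure-Python sketch. This is not PyTorch's implementation; the `flatten` and `pack` helpers and the flat-buffer layout are hypothetical, chosen only to show the idea of copying constituents into one contiguous allocation rather than keeping references.

```python
# Illustrative sketch only: pack a list of nested lists ("tensors")
# into one flat, contiguous buffer plus per-entry offsets, mirroring
# how constituents are copied rather than referenced.

def flatten(entry):
    """Flatten an arbitrarily nested list into a flat list of numbers."""
    if isinstance(entry, list):
        out = []
        for item in entry:
            out.extend(flatten(item))
        return out
    return [entry]

def pack(tensors):
    """Copy every entry into one contiguous buffer; record offsets."""
    buffer, offsets = [], [0]
    for t in tensors:
        buffer.extend(flatten(t))
        offsets.append(len(buffer))
    return buffer, offsets

a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]       # shape (2, 3)
b = [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]  # shape (3, 2)
buffer, offsets = pack([a, b])
# The buffer is a single contiguous copy; mutating `a` afterwards
# does not affect it, because no reference to `a` is kept.
a[0][0] = -100.0
print(buffer[:3])  # still [1.0, 2.0, 3.0]
```

The offsets record where each constituent's data starts and ends inside the shared buffer, which is what makes per-entry views cheap.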
At this moment we only support one level of nesting, i.e. a simple, flat list of Tensors. In the future
we can add support for multiple levels of nesting, such as a list that consists entirely of lists of Tensors.
Note that for this extension it is important to maintain an even level of nesting across entries so that the resulting NestedTensor
has a well defined dimension. If you have a need for this feature, please feel encouraged to open a feature request so that
we can track it and plan accordingly.
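The "even level of nesting" requirement can be made concrete with a short sketch. The `nesting_depth` helper is hypothetical, not part of the API: a nested list only has a well defined dimension if every branch nests to the same depth.

```python
# Illustrative sketch: a nested list has a well defined dimension
# only if all of its branches are nested to the same depth.

def nesting_depth(entry):
    """Return the nesting depth, or None if branches disagree."""
    if not isinstance(entry, list):
        return 0
    depths = {nesting_depth(item) for item in entry}
    if len(depths) != 1 or None in depths:
        return None  # uneven nesting: no well defined dimension
    return 1 + depths.pop()

print(nesting_depth([[1, 2], [3, 4, 5]]))  # 2: even nesting
print(nesting_depth([[1, 2], 3]))          # None: uneven nesting
```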
size
+++++++++++++++++++++++++
Even though a NestedTensor does not support .size() (or .shape), it supports .size(i) if dimension i is regular.

>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(32, 128) # text 2
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32)
>>> nt.size(0)
2
>>> nt.size(1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Given dimension 1 is irregular and does not have a size.
>>> nt.size(2)
128
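The .size(i) semantics above can be mimicked over the constituents' shapes. The `nested_size` helper below is a hypothetical sketch, not the real implementation: dimension 0 is the number of constituents, and any other dimension only has a size if all constituents agree on it.

```python
# Illustrative sketch of .size(i) semantics: dimension 0 is the number
# of constituents; dimension i > 0 has a size only if it is "regular",
# i.e. all constituent shapes agree at position i - 1.

def nested_size(shapes, i):
    """shapes: list of constituent shapes, e.g. [(50, 128), (32, 128)]."""
    if i == 0:
        return len(shapes)
    sizes = {shape[i - 1] for shape in shapes}
    if len(sizes) != 1:
        raise RuntimeError(
            f"Given dimension {i} is irregular and does not have a size.")
    return sizes.pop()

shapes = [(50, 128), (32, 128)]  # two "text" entries of different length
print(nested_size(shapes, 0))    # 2
print(nested_size(shapes, 2))    # 128
# nested_size(shapes, 1) raises: dimension 1 is irregular (50 vs 32)
```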
If all dimensions are regular, the NestedTensor is intended to be semantically indistinguishable from a regular torch.Tensor.

>>> a = torch.randn(20, 128) # text 1
>>> nt = torch.nested_tensor([a, a], dtype=torch.float32)
>>> nt.size(0)
2
>>> nt.size(1)
20
>>> nt.size(2)
128
>>> torch.stack(nt.unbind()).size()
torch.Size([2, 20, 128])
>>> torch.stack([a, a]).size()
torch.Size([2, 20, 128])
>>> torch.equal(torch.stack(nt.unbind()), torch.stack([a, a]))
True

In the future we might make it easier to detect this condition and convert seamlessly.

Please open a feature request if you have a need for this (or any other related feature for that matter).
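Detecting that condition amounts to checking that every constituent has the same shape, in which case the nested structure is equivalent to a dense, stacked tensor. The helpers below (`is_regular`, `dense_shape`) are hypothetical sketch names, not part of any API:

```python
# Illustrative sketch: if every constituent has the same shape, the
# nested structure is "regular" and equivalent to a dense, stacked
# representation of shape (num_constituents, *constituent_shape).

def is_regular(shapes):
    """True if all constituent shapes match exactly."""
    return len(set(shapes)) == 1

def dense_shape(shapes):
    """Shape of the equivalent stacked tensor, or None if irregular."""
    if not is_regular(shapes):
        return None
    return (len(shapes),) + shapes[0]

print(dense_shape([(20, 128), (20, 128)]))  # (2, 20, 128)
print(dense_shape([(50, 128), (32, 128)]))  # None: irregular
```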
unbind
+++++++++++++++++++++++++
unbind allows you to retrieve a view of the constituents.

>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 4)
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32)
>>> nt
nested_tensor([
  tensor([[ 1.2286, -1.2343, -1.4842],
          [-0.7827,  0.6745,  0.0658]]),
  tensor([[-1.1247, -0.4078, -1.0633,  0.8083],
          [-0.2871, -0.2980,  0.5559,  1.9885],
          [ 0.4074,  2.4855,  0.0733,  0.8285]])
])
>>> nt.unbind()
(tensor([[ 1.2286, -1.2343, -1.4842],
        [-0.7827,  0.6745,  0.0658]]), tensor([[-1.1247, -0.4078, -1.0633,  0.8083],
        [-0.2871, -0.2980,  0.5559,  1.9885],
        [ 0.4074,  2.4855,  0.0733,  0.8285]]))
>>> nt.unbind()[0] is not a
True
>>> nt.unbind()[0].mul_(3)
tensor([[ 3.6858, -3.7030, -4.4525],
        [-2.3481,  2.0236,  0.1975]])
>>> nt
nested_tensor([
  tensor([[ 3.6858, -3.7030, -4.4525],
          [-2.3481,  2.0236,  0.1975]]),
  tensor([[-1.1247, -0.4078, -1.0633,  0.8083],
          [-0.2871, -0.2980,  0.5559,  1.9885],
          [ 0.4074,  2.4855,  0.0733,  0.8285]])
])

Note that nt.unbind()[0] is not a, but rather a slice of the underlying memory, which represents the first entry or constituent of the NestedTensor.
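The view semantics of unbind can be sketched in pure Python. The `FlatNested` class and its `unbind`/`mul_` methods are hypothetical names for illustration, not the real implementation: constituents are windows into one shared buffer, so an in-place update through a view is visible through the container, mirroring nt.unbind()[0].mul_(3) above.

```python
# Illustrative sketch: "unbind" returns views (windows) into one shared
# underlying buffer, so mutating a constituent in place is visible
# through the container.

class FlatNested:
    """Hypothetical container: one flat buffer plus entry offsets."""
    def __init__(self, entries):
        self.buffer = []
        self.offsets = [0]
        for entry in entries:
            self.buffer.extend(entry)  # copy on construction
            self.offsets.append(len(self.buffer))

    def unbind(self):
        # Real views would share storage directly; with plain lists we
        # return (start, stop) windows and mutate through them.
        return [(start, stop) for start, stop in
                zip(self.offsets, self.offsets[1:])]

    def mul_(self, view, factor):
        """In-place multiply of one constituent through its view."""
        start, stop = view
        for j in range(start, stop):
            self.buffer[j] *= factor

nt = FlatNested([[1.0, 2.0], [3.0, 4.0, 5.0]])
first = nt.unbind()[0]
nt.mul_(first, 3)   # in-place update through the view
print(nt.buffer)    # [3.0, 6.0, 3.0, 4.0, 5.0]
```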
Nested tensor methods
+++++++++++++++++++++++++