Implement narrow from a regular tensor to jagged tensor (#112770)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112770
Approved by: https://github.com/cpuhrsch
Antoni Viros 2023-11-10 20:52:56 +00:00 committed by PyTorch MergeBot
parent 3700894099
commit 1aece432ba
6 changed files with 271 additions and 20 deletions

View File

@ -14,7 +14,11 @@ from torch._dynamo.testing import normalize_gm
from torch._higher_order_ops.wrap import wrap
from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
from torch.nested._internal.nested_tensor import jagged_from_list, ViewBufferFromNested
from torch.nested._internal.nested_tensor import (
jagged_from_list,
jagged_from_tensor_and_lengths,
ViewBufferFromNested,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
@ -799,6 +803,19 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
)
return jagged_from_list(out, offsets)
def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True):
# Makes a jagged tensor with N constituent tensors whose sizes
# are as specified: ((S0, S1, S2), D)
max_dim = (starts + lengths).max()
values_tensor = torch.randn(
starts.shape[0],
max_dim.item(),
inner_dim,
requires_grad=requires_grad,
dtype=torch.float64,
)
return jagged_from_tensor_and_lengths(values_tensor, starts, lengths)
def _check_recompiles(self, fn, inputs1, inputs2, recompiles):
compile_count = [0]

View File

@ -3325,6 +3325,66 @@ class TestNestedTensorSubclass(NestedTestCase):
nt, _ = jagged_from_list(test_tensor_list, None)
_ = torch.nn.functional.layer_norm(nt, (nt.shape[-2], nt.shape[-1]))
def test_narrow(self, device):
starts = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
lengths = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
nt = torch.nested.narrow(
torch.arange(0, 10, device=device, dtype=torch.int64).unsqueeze(0).expand(5, -1).clone().detach(),
1,
starts,
lengths,
layout=torch.jagged
)
# TODO: Use this approach when unbind is functional
# unbinded_nt = nt.unbind()
# for i in range(starts.shape[0]):
# self.assertEqual(torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64), unbinded_nt[i])
for i in range(starts.shape[0]):
self.assertEqual(
torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64),
nt.values()[nt.offsets()[i]:(nt.offsets()[i] + nt.lengths()[i])]
)
def test_is_contiguous(self, device):
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
nt_contiguous, _ = jagged_from_list([a, b, c], None)
starts_nc = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
lengths_nc = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
narrow_base = torch.arange(0, 10, device=device, dtype=torch.int64).unsqueeze(0).expand(5, -1).clone()
nt_noncontiguous = torch.nested.narrow(
narrow_base,
1,
starts_nc,
lengths_nc,
layout=torch.jagged
)
starts_c = torch.tensor([1, 0, 0, 0, 0], device=device, dtype=torch.int64)
lengths_c = torch.tensor([9, 10, 10, 10, 8], device=device, dtype=torch.int64)
nt_contiguous_narrow = torch.nested.narrow(
narrow_base,
1,
starts_c,
lengths_c,
layout=torch.jagged
)
# Test contiguous case
assert nt_contiguous.is_contiguous()
# Test narrow case
assert not nt_noncontiguous.is_contiguous()
assert nt_contiguous_narrow.is_contiguous()
# Test querying by memory_format
self.assertTrue(nt_contiguous.is_contiguous(memory_format=torch.contiguous_format))
self.assertTrue(not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format))
self.assertTrue(nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format))
instantiate_parametrized_tests(TestNestedTensor)
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())

View File

@ -309,7 +309,7 @@ class MetaConverter:
from torch._dynamo.source import AttrSource
from torch.fx.experimental.symbolic_shapes import DimDynamic
if shape_env and not t.is_nested:
if shape_env and not t.is_nested and not t._base.is_nested:
base_dynamic_dims = [DimDynamic.STATIC] * t._base.dim()
else:
base_dynamic_dims = None

View File

@ -1,7 +1,7 @@
from typing import List, Optional
from typing import List, Optional, Union
import torch
from torch import Tensor
from torch import SymInt, Tensor
from torch._C import _add_docstr, _nested # type: ignore[attr-defined]
from torch.types import _device as Device, _dtype as DType
@ -10,6 +10,7 @@ __all__ = [
"to_padded_tensor",
"as_nested_tensor",
"nested_tensor",
"narrow",
]
# Nested Tensor constructor functions
@ -187,3 +188,69 @@ Example::
return nt
else:
raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}")
def narrow(tensor: Tensor, dim: int, start: Union[int, Tensor], length: Union[int, Tensor], layout=torch.strided) -> Tensor:
r"""
Constructs a nested tensor (which might be a view) from :attr:`tensor`, a strided tensor. This follows
similar semantics to torch.Tensor.narrow, where in the :attr:`dim`-th dimension the new nested tensor
shows only the elements in the interval `[start, start+length)`. As nested representations
allow for a different `start` and `length` at each 'row' of that dimension, :attr:`start` and :attr:`length`
can also be 1-D tensors of length `tensor.shape[0]`.
There are some differences depending on the layout you use for the nested tensor. With the strided layout,
torch.narrow copies the narrowed data into a contiguous NT with strided layout, while with the
jagged layout narrow() creates a non-contiguous view of the original strided tensor. This particular
representation is useful for representing kv-caches in Transformer models, as specialized
SDPA kernels can handle this format directly, resulting in performance improvements.
Args:
tensor (:class:`torch.Tensor`): a strided tensor, which will be used as the underlying data
for the nested tensor if using the jagged layout or will be copied for the strided layout.
dim (int): the dimension where narrow will be applied. Only `dim=1` is supported for the
jagged layout, while the strided layout supports all dims.
start (Union[int, :class:`torch.Tensor`]): starting element for the narrow operation
length (Union[int, :class:`torch.Tensor`]): number of elements taken during the narrow op
Keyword arguments:
layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
Only strided and jagged layouts are supported. Default: if None, the strided layout.
Example::
>>> starts = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64)
>>> lengths = torch.tensor([3, 2, 2, 1, 5], dtype=torch.int64)
>>> narrow_base = torch.randn(5, 10, 20)
>>> nt_narrowed = torch.nested.narrow(narrow_base, 1, starts, lengths, layout=torch.jagged)
>>> nt_narrowed.is_contiguous()
False
"""
if not isinstance(start, (int, SymInt, Tensor)):
raise RuntimeError("start must be an integer or a tensor")
if not isinstance(length, (int, SymInt, Tensor)):
raise RuntimeError("length must be an integer or a tensor")
if layout == torch.strided:
if isinstance(start, Tensor) or isinstance(length, Tensor):
raise RuntimeError("start and length must be integers for the strided layout NT impl")
# TODO: switch to as_nested_tensor(tensor) when it is available
nt = as_nested_tensor(torch.unbind(tensor), layout=torch.strided).narrow(dim, start, length)
elif layout == torch.jagged:
if dim != 1:
raise RuntimeError("jagged layout only supports dim=1")
from torch.nested._internal.nested_tensor import jagged_from_tensor_and_lengths
if isinstance(start, (int, SymInt)):
start = torch.tensor([start], device=tensor.device, dtype=torch.int64)
if isinstance(length, (int, SymInt)):
length = torch.tensor([length], device=tensor.device, dtype=torch.int64)
nt, _, _ = jagged_from_tensor_and_lengths(tensor, start, length)
else:
raise RuntimeError(f"Specified layout is unsupported for nested narrow: {layout}")
return nt
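For illustration, a minimal usage sketch of the jagged-layout path, mirroring the tests in this PR (shapes and values are illustrative; `values()`, `offsets()` and `lengths()` are the NestedTensor accessors touched in the nested_tensor.py changes below):

import torch

base = torch.arange(0, 10).unsqueeze(0).expand(5, -1).clone()   # shape (5, 10)
starts = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64)
lengths = torch.tensor([3, 2, 2, 1, 5], dtype=torch.int64)

# Jagged layout: a non-contiguous view over `base`, not a copy.
nt = torch.nested.narrow(base, 1, starts, lengths, layout=torch.jagged)
print(nt.is_contiguous())   # False
print(nt.values().shape)    # torch.Size([50]) -- the flattened (5, 10) buffer
print(nt.offsets())         # tensor([ 0, 11, 22, 33, 44, 49])
print(nt.lengths())         # tensor([3, 2, 2, 1, 5])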

View File

@ -2,6 +2,7 @@ from typing import Tuple
import torch
from torch._C import DispatchKey, DispatchKeySet
from torch._prims_common import is_expandable_to
from torch.fx.experimental.symbolic_shapes import has_free_symbols
from torch.utils.weak import WeakTensorKeyDictionary
from typing import * # noqa: F403
@ -21,6 +22,7 @@ def get_tensor_id(tensor, *, coeff=1):
class NestedTensor(torch.Tensor):
_values: torch.Tensor # type: ignore[assignment]
_offsets: torch.Tensor
_lengths: Optional[torch.Tensor]
# NOTE [ Singleton ints for ragged sizes and strides ]
#
# Jagged layout tensors are tensors that represent a n-dim tensor with a
@ -46,6 +48,7 @@ class NestedTensor(torch.Tensor):
values,
offsets,
*,
lengths=None,
ragged_size=None,
**kwargs,
):
@ -69,7 +72,7 @@ class NestedTensor(torch.Tensor):
)
return r
def __init__(self, values, offsets, *, ragged_size=None, **kwargs):
def __init__(self, values, offsets, *, lengths=None, ragged_size=None, **kwargs):
super().__init__()
# Only support jagged for now.
assert offsets is not None
@ -82,7 +85,10 @@ class NestedTensor(torch.Tensor):
# we perform operations on fake nested tensors.
# Calling get_tensor_id won't work in those cases because we want
# the existing symbolic ragged_size to be propagated.
ragged_size = get_tensor_id(offsets, coeff=1)
if lengths is None:
ragged_size = get_tensor_id(offsets, coeff=1)
else:
ragged_size = get_tensor_id(lengths, coeff=1)
B = offsets.shape[0] - 1
Ds = values.shape[1:]
self._size = (B, ragged_size, *Ds)
@ -97,6 +103,7 @@ class NestedTensor(torch.Tensor):
)
self._values = values
self._offsets = offsets
self._lengths = lengths
def values(self):
return self._values
@ -104,6 +111,9 @@ class NestedTensor(torch.Tensor):
def offsets(self):
return self._offsets
def lengths(self):
return self._lengths
def __repr__(self):
# We should implement this in torch/_tensor_str.py instead
grad_fn_str = (
@ -111,7 +121,7 @@ class NestedTensor(torch.Tensor):
)
if self.grad_fn:
grad_fn_str = f", grad_fn={self.grad_fn}"
return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str})"
return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._lengths is None})"
def __reduce_ex__(self, proto):
state = torch._utils._get_obj_state(self)
@ -131,13 +141,20 @@ class NestedTensor(torch.Tensor):
"requires_grad": self.requires_grad,
"ragged_size": self._size[self._ragged_idx],
}
return ["_values", "_offsets"], ctx
inner_tensors = ["_values", "_offsets"]
if self._lengths is not None:
inner_tensors.append("_lengths")
return inner_tensors, ctx
@staticmethod
def __tensor_unflatten__(inner_tensors: Dict, meta):
assert len(inner_tensors) == 2
assert len(inner_tensors) >= 2 and len(inner_tensors) <= 3
values = inner_tensors["_values"]
offsets = inner_tensors["_offsets"]
if "_lengths" in inner_tensors and inner_tensors["_lengths"] is not None:
lengths = inner_tensors["_lengths"]
else:
lengths = None
# NOTE [ Storing symbolic values as plain attributes on subclasses ]
#
@ -173,6 +190,7 @@ class NestedTensor(torch.Tensor):
return NestedTensor(
values,
offsets=offsets,
lengths=lengths,
ragged_size=meta["ragged_size"],
requires_grad=meta["requires_grad"],
)
@ -232,6 +250,18 @@ class ViewNestedFromBuffer(torch.autograd.Function):
return gO.values(), None, None
# Not actually a view!
# NOTE: @jbschlosser is working on making it a view
class ViewNonContiguousNestedFromBuffer(torch.autograd.Function):
@staticmethod
def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, lengths: torch.Tensor): # type: ignore[override]
return NestedTensor(values.detach(), offsets=offsets, lengths=lengths)
@staticmethod
def backward(ctx, gO: NestedTensor): # type: ignore[override]
return gO.values(), None, None
# Need to make it obvious that users should be passing in offsets
def jagged_from_list(
tensors: List[torch.Tensor],
@ -285,5 +315,65 @@ def jagged_from_list(
return ViewNestedFromBuffer.apply(values, offsets), offsets # type: ignore[call-overload]
def jagged_from_tensor_and_lengths(
tensor: torch.Tensor, starts: torch.Tensor, lengths: torch.Tensor
) -> Tuple[NestedTensor, torch.Tensor, Optional[torch.Tensor]]:
"""Constructs a NestedTensor backed by jagged layout from a tensor, starts of sequences, and sequence lengths"""
batch_size = tensor.shape[0]
if is_expandable_to(starts.shape, (batch_size,)) and is_expandable_to(
lengths.shape, (batch_size,)
):
start_list = starts.expand(batch_size)
length_list = lengths.expand(batch_size)
else:
raise RuntimeError(
"When constructing a jagged nested tensor using narrow(), "
"your start and length must be Tensors that broadcast to input.shape[0]"
)
# Calculate jagged offsets
assert (
len(tensor.shape) >= 2
), "tensor must at least be 2D for the nested narrow op to work"
max_seq_len = tensor.shape[1]
offset_lengths = max_seq_len * torch.arange(
0, batch_size, dtype=torch.int64, device=tensor.device
)
# Jagged layout specifies that offsets are stored as int64 on the same device as values.
offsets = torch.cat(
[
start_list + offset_lengths,
(start_list[-1] + offset_lengths[-1] + length_list[-1]).unsqueeze(0),
]
)
# Reshape buffer to flatten the 1st and 2nd dimension (view used to enforce non-copy)
if len(tensor.shape) > 2:
values = tensor.view(-1, *tensor.shape[2:])
else:
values = tensor.view(-1)
# Check if offsets and lengths make it possibly contiguous and return a regular NT
is_contiguous = True
orig_dim = tensor.shape[1]
if torch.any(length_list[1:-1].ne(orig_dim)):
is_contiguous = False
if torch.any(offsets[1:-2].diff().ne(orig_dim)):
is_contiguous = False
if offsets[0] + length_list[0] != orig_dim:
is_contiguous = False
if is_contiguous:
return (
ViewNestedFromBuffer.apply(
values[offsets[0] : offsets[-1]], offsets - offsets[0]
),
offsets,
None,
)
return ViewNonContiguousNestedFromBuffer.apply(values, offsets, length_list), offsets, length_list # type: ignore[call-overload]
def buffer_from_jagged(jagged):
return ViewBufferFromNested.apply(jagged)
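To make the offsets arithmetic in jagged_from_tensor_and_lengths concrete, a small worked sketch (values chosen to match the tests above; not part of the diff):

import torch

batch_size, max_seq_len = 5, 10
starts = torch.tensor([0, 1, 2, 3, 4])
lengths = torch.tensor([3, 2, 2, 1, 5])

# Row i of the (5, 10, ...) buffer begins at i * max_seq_len in the flattened values.
row_base = max_seq_len * torch.arange(batch_size)        # tensor([ 0, 10, 20, 30, 40])
offsets = torch.cat([starts + row_base,
                     (starts[-1] + row_base[-1] + lengths[-1]).unsqueeze(0)])
# offsets == tensor([ 0, 11, 22, 33, 44, 49]); component i lives at
# values[offsets[i] : offsets[i] + lengths[i]], which is why lengths must be kept alongside offsets.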

View File

@ -1,4 +1,5 @@
import functools
import math
import torch
@ -65,7 +66,11 @@ def check_schema(schema_str: str, func, *args, **kwargs) -> None:
arg_type_check_fns = {
"t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
"jt": lambda x: isinstance(x, NestedTensor),
"jt": lambda x: isinstance(x, NestedTensor)
and x._lengths is None, # ops with "jt" require contiguous JT only
"jt_all": lambda x: isinstance(
x, NestedTensor
), # ops with "jt_all" can accept all kinds of JT
"any": lambda x: True,
}
for i, named_arg_type in enumerate(named_arg_types):
@ -297,7 +302,7 @@ def jagged_torch_function(func, *args, **kwargs):
torch.ops.aten.sym_stride.default,
torch.ops.aten.sym_storage_offset.default,
],
"self: jt",
"self: jt_all",
)
def tensor_attr_supported_getter(func, *args, **kwargs):
if func == torch.ops.aten.is_non_overlapping_and_dense.default:
@ -310,23 +315,25 @@ def tensor_attr_supported_getter(func, *args, **kwargs):
return len(args[0]._size)
if func == torch.ops.aten.sym_numel.default:
if args[0]._lengths is not None:
return int(sum(args[0]._lengths) * math.prod(args[0]._size[2:]))
return args[0]._values.numel()
if func == torch.ops.aten.sym_stride.default:
return args[0]._strides
if func == torch.ops.aten.sym_storage_offset.default:
return 0
return args[0]._values.storage_offset()
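For intuition on the sym_numel branch above: the element count of a non-contiguous jagged NT is the total ragged length times the product of the trailing dims. A worked example with illustrative shapes:

import math
import torch

lengths = torch.tensor([3, 2, 2, 1, 5])   # ragged row lengths (the NT's _lengths)
inner_dims = (20,)                         # corresponds to _size[2:]
numel = int(lengths.sum()) * math.prod(inner_dims)   # 13 * 20 = 260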
@register_jagged_func(torch.ops.prim.layout.default, "self: jt")
@register_jagged_func(torch.ops.prim.layout.default, "self: jt_all")
def prim_layout_default(func, *args, **kwargs):
return torch.jagged
@register_jagged_func(
[torch.ops.aten.size.default],
"self: jt",
"self: jt_all",
)
def tensor_attr_unsupported_getter(func, *args, **kwargs):
if func == torch.ops.aten.size.default:
@ -336,7 +343,7 @@ def tensor_attr_unsupported_getter(func, *args, **kwargs):
)
@register_jagged_func(torch.ops.aten.is_contiguous.default, "self: jt")
@register_jagged_func(torch.ops.aten.is_contiguous.default, "self: jt_all")
def is_contiguous_general(func, *args, **kwargs):
from torch._prims_common import is_contiguous_for_memory_format
@ -344,16 +351,21 @@ def is_contiguous_general(func, *args, **kwargs):
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
# If created from narrow() check for lengths
if inp.lengths() is not None:
return False
new_kwargs["memory_format"] = new_kwargs.get(
"memory_format", torch.contiguous_format
)
if new_kwargs["memory_format"] == torch.preserve_format:
return True
return is_contiguous_for_memory_format(inp, **new_kwargs)
return is_contiguous_for_memory_format(inp.values(), **new_kwargs)
register_jagged_func(
torch.ops.aten.is_contiguous.memory_format, "self: jt, memory_format: any?"
torch.ops.aten.is_contiguous.memory_format, "self: jt_all, memory_format: any?"
)(is_contiguous_general)
@ -508,7 +520,7 @@ def split_with_sizes_default(func, *args, **kwargs):
]
@register_jagged_func(torch.ops.aten.unbind.int, "self: jt, dim: any?")
@register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?")
def unbind_int(func, *args, **kwargs):
# Note that this specializes on the length of the offsets
_, new_kwargs = normalize_function(
@ -520,10 +532,15 @@ def unbind_int(func, *args, **kwargs):
raise RuntimeError("unbind(): only supported for NestedTensor on dim=0")
inp = new_kwargs.pop("input")
values = inp._values
values = inp.values()
offsets = inp.offsets()
lengths = inp.lengths()
return torch.split(values, offsets.diff().tolist())
if lengths is None:
return torch.split(values, offsets.diff().tolist())
return [
values[offsets[i] : (offsets[i] + lengths[i])] for i in range(lengths.shape[0])
]
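To see the non-contiguous branch of unbind in action, a small sketch using the buffer, offsets and lengths from the narrow examples above (illustrative values, not part of the diff):

import torch

values = torch.arange(50)                         # flattened (5, 10) buffer
offsets = torch.tensor([0, 11, 22, 33, 44, 49])
lengths = torch.tensor([3, 2, 2, 1, 5])

components = [values[offsets[i] : offsets[i] + lengths[i]] for i in range(lengths.shape[0])]
# components[0] == tensor([0, 1, 2]); components[1] == tensor([11, 12]); ...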
@register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt, dim: any")