Implement narrow from a regular tensor to jagged tensor (#112770)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112770 Approved by: https://github.com/cpuhrsch
This commit is contained in:
parent 3700894099
commit 1aece432ba
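For orientation before the diffs, here is a minimal sketch of the API this change adds, assuming a PyTorch build that contains it (shapes and values below are illustrative, not taken from the PR):

import torch

# A padded (B, S_max, D) buffer where each row only uses a window of S_max.
base = torch.randn(3, 8, 4)
starts = torch.tensor([0, 2, 1], dtype=torch.int64)
lengths = torch.tensor([5, 3, 6], dtype=torch.int64)

# Jagged-layout narrow builds a nested tensor that views `base` instead of copying it.
nt = torch.nested.narrow(base, 1, starts, lengths, layout=torch.jagged)

print(nt.is_contiguous())  # False: ragged windows into a padded buffer
print(nt.offsets())        # per-row offsets into the flattened values buffer
print(nt.lengths())        # per-row lengths (None when the jagged NT is contiguous)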
@@ -14,7 +14,11 @@ from torch._dynamo.testing import normalize_gm
 from torch._higher_order_ops.wrap import wrap
 from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
-from torch.nested._internal.nested_tensor import jagged_from_list, ViewBufferFromNested
+from torch.nested._internal.nested_tensor import (
+    jagged_from_list,
+    jagged_from_tensor_and_lengths,
+    ViewBufferFromNested,
+)
 from torch.testing._internal.inductor_utils import HAS_CUDA

@@ -799,6 +803,19 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
         )
         return jagged_from_list(out, offsets)

+    def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True):
+        # Makes a jagged tensor with N constituent tensors with size
+        # as specified ((S0, S1, S2), D)
+        max_dim = (starts + lengths).max()
+        values_tensor = torch.randn(
+            starts.shape[0],
+            max_dim.item(),
+            inner_dim,
+            requires_grad=requires_grad,
+            dtype=torch.float64,
+        )
+        return jagged_from_tensor_and_lengths(values_tensor, starts, lengths)
+
     def _check_recompiles(self, fn, inputs1, inputs2, recompiles):
         compile_count = [0]
@@ -3325,6 +3325,66 @@ class TestNestedTensorSubclass(NestedTestCase):
         nt, _ = jagged_from_list(test_tensor_list, None)
         _ = torch.nn.functional.layer_norm(nt, (nt.shape[-2], nt.shape[-1]))

+    def test_narrow(self, device):
+        starts = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
+        lengths = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
+        nt = torch.nested.narrow(
+            torch.arange(0, 10, device=device, dtype=torch.int64).unsqueeze(0).expand(5, -1).clone().detach(),
+            1,
+            starts,
+            lengths,
+            layout=torch.jagged
+        )
+
+        # TODO: Use this approach when unbind is functional
+        # unbinded_nt = nt.unbind()
+        # for i in range(starts.shape[0]):
+        #     self.assertEqual(torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64), unbinded_nt[i])
+        for i in range(starts.shape[0]):
+            self.assertEqual(
+                torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64),
+                nt.values()[nt.offsets()[i]:(nt.offsets()[i] + nt.lengths()[i])]
+            )
+
+    def test_is_contiguous(self, device):
+        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
+        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
+        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
+        nt_contiguous, _ = jagged_from_list([a, b, c], None)
+
+        starts_nc = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
+        lengths_nc = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
+        narrow_base = torch.arange(0, 10, device=device, dtype=torch.int64).unsqueeze(0).expand(5, -1).clone()
+        nt_noncontiguous = torch.nested.narrow(
+            narrow_base,
+            1,
+            starts_nc,
+            lengths_nc,
+            layout=torch.jagged
+        )
+
+        starts_c = torch.tensor([1, 0, 0, 0, 0], device=device, dtype=torch.int64)
+        lengths_c = torch.tensor([9, 10, 10, 10, 8], device=device, dtype=torch.int64)
+        nt_contiguous_narrow = torch.nested.narrow(
+            narrow_base,
+            1,
+            starts_c,
+            lengths_c,
+            layout=torch.jagged
+        )
+
+        # Test contiguous case
+        assert nt_contiguous.is_contiguous()
+
+        # Test narrow case
+        assert not nt_noncontiguous.is_contiguous()
+        assert nt_contiguous_narrow.is_contiguous()
+
+        # Test querying by memory_format
+        self.assertTrue(nt_contiguous.is_contiguous(memory_format=torch.contiguous_format))
+        self.assertTrue(not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format))
+        self.assertTrue(nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format))
+

 instantiate_parametrized_tests(TestNestedTensor)
 instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
@@ -309,7 +309,7 @@ class MetaConverter:
                 from torch._dynamo.source import AttrSource
                 from torch.fx.experimental.symbolic_shapes import DimDynamic

-                if shape_env and not t.is_nested:
+                if shape_env and not t.is_nested and not t._base.is_nested:
                     base_dynamic_dims = [DimDynamic.STATIC] * t._base.dim()
                 else:
                     base_dynamic_dims = None
@@ -1,7 +1,7 @@
-from typing import List, Optional
+from typing import List, Optional, Union

 import torch
-from torch import Tensor
+from torch import SymInt, Tensor
 from torch._C import _add_docstr, _nested  # type: ignore[attr-defined]

 from torch.types import _device as Device, _dtype as DType

@@ -10,6 +10,7 @@ __all__ = [
     "to_padded_tensor",
     "as_nested_tensor",
    "nested_tensor",
+    "narrow",
 ]

 # Nested Tensor constructor functions
@@ -187,3 +188,69 @@ Example::
         return nt
     else:
         raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}")
+
+
+def narrow(tensor: Tensor, dim: int, start: Union[int, Tensor], length: Union[int, Tensor], layout=torch.strided) -> Tensor:
+    r"""
+    Constructs a nested tensor (which might be a view) from :attr:`tensor`, a strided tensor. This follows
+    similar semantics to torch.Tensor.narrow, where in the :attr:`dim`-th dimension the new nested tensor
+    shows only the elements in the interval `[start, start+length)`. As nested representations
+    allow for a different `start` and `length` at each 'row' of that dimension, :attr:`start` and :attr:`length`
+    can also be tensors of shape `tensor.shape[0]`.
+
+    There are some differences depending on the layout you use for the nested tensor. With the strided
+    layout, torch.narrow copies the narrowed data into a contiguous NT with strided layout, while with the
+    jagged layout, narrow() creates a non-contiguous view of your original strided tensor. This particular
+    representation is really useful for representing kv-caches in Transformer models, as specialized
+    SDPA kernels can deal with this format easily, resulting in performance improvements.
+
+    Args:
+        tensor (:class:`torch.Tensor`): a strided tensor, which will be used as the underlying data
+            for the nested tensor if using the jagged layout or will be copied for the strided layout.
+        dim (int): the dimension where narrow will be applied. Only `dim=1` is supported for the
+            jagged layout, while the strided layout supports all dims.
+        start (Union[int, :class:`torch.Tensor`]): starting element for the narrow operation
+        length (Union[int, :class:`torch.Tensor`]): number of elements taken during the narrow op
+
+    Keyword arguments:
+        layout (:class:`torch.layout`, optional): the desired layout of the returned nested tensor.
+            Only strided and jagged layouts are supported. Default: if None, the strided layout.
+
+    Example::
+
+        >>> starts = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64)
+        >>> lengths = torch.tensor([3, 2, 2, 1, 5], dtype=torch.int64)
+        >>> narrow_base = torch.randn(5, 10, 20)
+        >>> nt_narrowed = torch.nested.narrow(narrow_base, 1, starts, lengths, layout=torch.jagged)
+        >>> nt_narrowed.is_contiguous()
+        False
+    """
+    if not isinstance(start, (int, SymInt, Tensor)):
+        raise RuntimeError("start must be an integer or a tensor")
+
+    if not isinstance(length, (int, SymInt, Tensor)):
+        raise RuntimeError("length must be an integer or a tensor")
+
+    if layout == torch.strided:
+        if isinstance(start, Tensor) or isinstance(length, Tensor):
+            raise RuntimeError("start and length must be integers for the strided layout NT impl")
+        # TODO: switch to as_nested_tensor(tensor) when it is available
+        nt = as_nested_tensor(torch.unbind(tensor), layout=torch.strided).narrow(dim, start, length)
+    elif layout == torch.jagged:
+        if dim != 1:
+            raise RuntimeError("jagged layout only supports dim=1")
+
+        from torch.nested._internal.nested_tensor import jagged_from_tensor_and_lengths
+
+        if isinstance(start, (int, SymInt)):
+            start = torch.tensor([start], device=tensor.device, dtype=torch.int64)
+
+        if isinstance(length, (int, SymInt)):
+            length = torch.tensor([length], device=tensor.device, dtype=torch.int64)
+
+        nt, _, _ = jagged_from_tensor_and_lengths(tensor, start, length)
+    else:
+        raise RuntimeError(f"Specified layout is unsupported for nested narrow: {layout}")
+
+    return nt
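To make the layout note in the docstring concrete, a small hedged sketch (illustrative tensors; assumes a build with this function): the jagged path returns a view that shares storage with the input, and tensor-valued start/length with the strided layout is rejected by the check shown above.

import torch

# e.g. a kv-cache: a padded (B, S_max, D) buffer with a different valid length per sequence.
kv_cache = torch.randn(4, 16, 8)
starts = torch.zeros(4, dtype=torch.int64)
seq_lens = torch.tensor([3, 16, 7, 12], dtype=torch.int64)

nt = torch.nested.narrow(kv_cache, 1, starts, seq_lens, layout=torch.jagged)

# The jagged result views the original buffer, so no data is copied.
assert nt.values().data_ptr() == kv_cache.data_ptr()
print(nt.is_contiguous())  # False

# Tensor-valued start/length are only allowed for the jagged layout.
try:
    torch.nested.narrow(kv_cache, 1, starts, seq_lens, layout=torch.strided)
except RuntimeError as e:
    print(e)  # "start and length must be integers for the strided layout NT impl"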
@@ -2,6 +2,7 @@ from typing import Tuple

 import torch
 from torch._C import DispatchKey, DispatchKeySet
+from torch._prims_common import is_expandable_to
 from torch.fx.experimental.symbolic_shapes import has_free_symbols
 from torch.utils.weak import WeakTensorKeyDictionary
 from typing import *  # noqa: F403

@@ -21,6 +22,7 @@ def get_tensor_id(tensor, *, coeff=1):
 class NestedTensor(torch.Tensor):
     _values: torch.Tensor  # type: ignore[assignment]
     _offsets: torch.Tensor
+    _lengths: Optional[torch.Tensor]
     # NOTE [ Singleton ints for ragged sizes and strides ]
     #
     # Jagged layout tensors are tensors that represent a n-dim tensor with a

@@ -46,6 +48,7 @@ class NestedTensor(torch.Tensor):
         values,
         offsets,
         *,
+        lengths=None,
         ragged_size=None,
         **kwargs,
     ):

@@ -69,7 +72,7 @@ class NestedTensor(torch.Tensor):
         )
         return r

-    def __init__(self, values, offsets, *, ragged_size=None, **kwargs):
+    def __init__(self, values, offsets, *, lengths=None, ragged_size=None, **kwargs):
         super().__init__()
         # Only support jagged for now.
         assert offsets is not None

@@ -82,7 +85,10 @@ class NestedTensor(torch.Tensor):
             # we perform operations on fake nested tensors.
             # Calling get_tensor_id won't work in those cases because we want
             # the existing symbolic ragged_size to be propagated.
-            ragged_size = get_tensor_id(offsets, coeff=1)
+            if lengths is None:
+                ragged_size = get_tensor_id(offsets, coeff=1)
+            else:
+                ragged_size = get_tensor_id(lengths, coeff=1)
         B = offsets.shape[0] - 1
         Ds = values.shape[1:]
         self._size = (B, ragged_size, *Ds)

@@ -97,6 +103,7 @@ class NestedTensor(torch.Tensor):
         )
         self._values = values
         self._offsets = offsets
+        self._lengths = lengths

     def values(self):
         return self._values

@@ -104,6 +111,9 @@ class NestedTensor(torch.Tensor):
     def offsets(self):
         return self._offsets

+    def lengths(self):
+        return self._lengths
+
     def __repr__(self):
         # We should implement this in torch/_tensor_str.py instead
         grad_fn_str = (

@@ -111,7 +121,7 @@ class NestedTensor(torch.Tensor):
         )
         if self.grad_fn:
             grad_fn_str = f", grad_fn={self.grad_fn}"
-        return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str})"
+        return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._lengths is None})"

     def __reduce_ex__(self, proto):
         state = torch._utils._get_obj_state(self)

@@ -131,13 +141,20 @@ class NestedTensor(torch.Tensor):
             "requires_grad": self.requires_grad,
             "ragged_size": self._size[self._ragged_idx],
         }
-        return ["_values", "_offsets"], ctx
+        inner_tensors = ["_values", "_offsets"]
+        if self._lengths is not None:
+            inner_tensors.append("_lengths")
+        return inner_tensors, ctx

     @staticmethod
     def __tensor_unflatten__(inner_tensors: Dict, meta):
-        assert len(inner_tensors) == 2
+        assert len(inner_tensors) >= 2 and len(inner_tensors) <= 3
         values = inner_tensors["_values"]
         offsets = inner_tensors["_offsets"]
+        if "_lengths" in inner_tensors and inner_tensors["_lengths"] is not None:
+            lengths = inner_tensors["_lengths"]
+        else:
+            lengths = None

         # NOTE [ Storing symbolic values as plain attributes on subclasses ]
         #

@@ -173,6 +190,7 @@ class NestedTensor(torch.Tensor):
         return NestedTensor(
             values,
             offsets=offsets,
+            lengths=lengths,
             ragged_size=meta["ragged_size"],
             requires_grad=meta["requires_grad"],
         )

@@ -232,6 +250,18 @@ class ViewNestedFromBuffer(torch.autograd.Function):
         return gO.values(), None, None


+# Not actually a view!
+# NOTE: @jbschlosser is working on making it a view
+class ViewNonContiguousNestedFromBuffer(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, lengths: torch.Tensor):  # type: ignore[override]
+        return NestedTensor(values.detach(), offsets=offsets, lengths=lengths)
+
+    @staticmethod
+    def backward(ctx, gO: NestedTensor):  # type: ignore[override]
+        return gO.values(), None, None
+
+
 # Need to make it obvious that users should be passing in offsets
 def jagged_from_list(
     tensors: List[torch.Tensor],
@@ -285,5 +315,65 @@ def jagged_from_list(
     return ViewNestedFromBuffer.apply(values, offsets), offsets  # type: ignore[call-overload]


+def jagged_from_tensor_and_lengths(
+    tensor: torch.Tensor, starts: torch.Tensor, lengths: torch.Tensor
+) -> Tuple[NestedTensor, torch.Tensor, Optional[torch.Tensor]]:
+    """Constructs a NestedTensor backed by jagged layout from a tensor, starts of sequences, and sequence lengths"""
+    batch_size = tensor.shape[0]
+    if is_expandable_to(starts.shape, (batch_size,)) and is_expandable_to(
+        lengths.shape, (batch_size,)
+    ):
+        start_list = starts.expand(batch_size)
+        length_list = lengths.expand(batch_size)
+    else:
+        raise RuntimeError(
+            "When constructing a jagged nested tensor using narrow(), "
+            "your start and length must be Tensors that broadcast to input.shape[0]"
+        )
+
+    # Calculate jagged offsets
+    assert (
+        len(tensor.shape) >= 2
+    ), "tensor must at least be 2D for the nested narrow op to work"
+    max_seq_len = tensor.shape[1]
+    offset_lengths = max_seq_len * torch.arange(
+        0, batch_size, dtype=torch.int64, device=tensor.device
+    )
+    # Jagged layout specifies that offsets are stored as int64 on the same device as values.
+    offsets = torch.cat(
+        [
+            start_list + offset_lengths,
+            (start_list[-1] + offset_lengths[-1] + length_list[-1]).unsqueeze(0),
+        ]
+    )
+
+    # Reshape buffer to flatten the 1st and 2nd dimension (view used to enforce non-copy)
+    if len(tensor.shape) > 2:
+        values = tensor.view(-1, *tensor.shape[2:])
+    else:
+        values = tensor.view(-1)
+
+    # Check if offsets and lengths make it possibly contiguous and return a regular NT
+    is_contiguous = True
+    orig_dim = tensor.shape[1]
+    if torch.any(length_list[1:-1].ne(orig_dim)):
+        is_contiguous = False
+    if torch.any(offsets[1:-2].diff().ne(orig_dim)):
+        is_contiguous = False
+    if offsets[0] + length_list[0] != orig_dim:
+        is_contiguous = False
+
+    if is_contiguous:
+        return (
+            ViewNestedFromBuffer.apply(
+                values[offsets[0] : offsets[-1]], offsets - offsets[0]
+            ),
+            offsets,
+            None,
+        )
+
+    return ViewNonContiguousNestedFromBuffer.apply(values, offsets, length_list), offsets, length_list  # type: ignore[call-overload]
+
+
 def buffer_from_jagged(jagged):
     return ViewBufferFromNested.apply(jagged)
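As a worked example of the offset arithmetic in jagged_from_tensor_and_lengths (illustrative numbers only): row i of the padded input starts at i * max_seq_len in the flattened values buffer, so the per-row offsets are starts + max_seq_len * arange(B), with one final boundary appended.

import torch

batch_size, max_seq_len = 3, 10
starts = torch.tensor([0, 1, 2], dtype=torch.int64)
lengths = torch.tensor([3, 2, 2], dtype=torch.int64)

offset_lengths = max_seq_len * torch.arange(0, batch_size, dtype=torch.int64)  # [0, 10, 20]
offsets = torch.cat(
    [
        starts + offset_lengths,                                       # [0, 11, 22]
        (starts[-1] + offset_lengths[-1] + lengths[-1]).unsqueeze(0),  # [24]
    ]
)
print(offsets)  # tensor([ 0, 11, 22, 24])

# Row i of the jagged tensor is values[offsets[i] : offsets[i] + lengths[i]],
# where values is the (batch_size * max_seq_len, D) flattening of the padded input.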
@@ -1,4 +1,5 @@
 import functools
+import math

 import torch

@@ -65,7 +66,11 @@ def check_schema(schema_str: str, func, *args, **kwargs) -> None:

     arg_type_check_fns = {
         "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
-        "jt": lambda x: isinstance(x, NestedTensor),
+        "jt": lambda x: isinstance(x, NestedTensor)
+        and x._lengths is None,  # ops with "jt" require contiguous JT only
+        "jt_all": lambda x: isinstance(
+            x, NestedTensor
+        ),  # ops with "jt_all" can accept all kinds of JT
         "any": lambda x: True,
     }
     for i, named_arg_type in enumerate(named_arg_types):

@@ -297,7 +302,7 @@ def jagged_torch_function(func, *args, **kwargs):
         torch.ops.aten.sym_stride.default,
         torch.ops.aten.sym_storage_offset.default,
     ],
-    "self: jt",
+    "self: jt_all",
 )
 def tensor_attr_supported_getter(func, *args, **kwargs):
     if func == torch.ops.aten.is_non_overlapping_and_dense.default:

@@ -310,23 +315,25 @@ def tensor_attr_supported_getter(func, *args, **kwargs):
         return len(args[0]._size)

     if func == torch.ops.aten.sym_numel.default:
+        if args[0]._lengths is not None:
+            return int(sum(args[0]._lengths) * math.prod(args[0]._size[2:]))
         return args[0]._values.numel()

     if func == torch.ops.aten.sym_stride.default:
         return args[0]._strides

     if func == torch.ops.aten.sym_storage_offset.default:
-        return 0
+        return args[0]._values.storage_offset()


-@register_jagged_func(torch.ops.prim.layout.default, "self: jt")
+@register_jagged_func(torch.ops.prim.layout.default, "self: jt_all")
 def prim_layout_default(func, *args, **kwargs):
     return torch.jagged


 @register_jagged_func(
     [torch.ops.aten.size.default],
-    "self: jt",
+    "self: jt_all",
 )
 def tensor_attr_unsupported_getter(func, *args, **kwargs):
     if func == torch.ops.aten.size.default:
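To illustrate the sym_numel branch added above for lengths-backed (non-contiguous) jagged tensors, a tiny worked example with made-up numbers: only the selected elements count, not the whole padded values buffer.

import math
import torch

lengths = torch.tensor([3, 2, 5])  # per-row lengths of a non-contiguous jagged NT
trailing_dims = (4,)               # the non-ragged feature dims, i.e. _size[2:]

numel = int(sum(lengths) * math.prod(trailing_dims))
print(numel)  # 40, while the padded values buffer holds B * max_seq_len * 4 elements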
@@ -336,7 +343,7 @@ def tensor_attr_unsupported_getter(func, *args, **kwargs):
     )


-@register_jagged_func(torch.ops.aten.is_contiguous.default, "self: jt")
+@register_jagged_func(torch.ops.aten.is_contiguous.default, "self: jt_all")
 def is_contiguous_general(func, *args, **kwargs):
     from torch._prims_common import is_contiguous_for_memory_format

@@ -344,16 +351,21 @@ def is_contiguous_general(func, *args, **kwargs):
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
     inp = new_kwargs.pop("input")

+    # If created from narrow() check for lengths
+    if inp.lengths() is not None:
+        return False
+
     new_kwargs["memory_format"] = new_kwargs.get(
         "memory_format", torch.contiguous_format
     )
     if new_kwargs["memory_format"] == torch.preserve_format:
         return True
-    return is_contiguous_for_memory_format(inp, **new_kwargs)
+    return is_contiguous_for_memory_format(inp.values(), **new_kwargs)


 register_jagged_func(
-    torch.ops.aten.is_contiguous.memory_format, "self: jt, memory_format: any?"
+    torch.ops.aten.is_contiguous.memory_format, "self: jt_all, memory_format: any?"
 )(is_contiguous_general)

@@ -508,7 +520,7 @@ def split_with_sizes_default(func, *args, **kwargs):
     ]


-@register_jagged_func(torch.ops.aten.unbind.int, "self: jt, dim: any?")
+@register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?")
 def unbind_int(func, *args, **kwargs):
     # Note that this specializes on the length of the offsets
     _, new_kwargs = normalize_function(

@@ -520,10 +532,15 @@ def unbind_int(func, *args, **kwargs):
         raise RuntimeError("unbind(): only supported for NestedTensor on dim=0")

     inp = new_kwargs.pop("input")
-    values = inp._values
+    values = inp.values()
     offsets = inp.offsets()
+    lengths = inp.lengths()

-    return torch.split(values, offsets.diff().tolist())
+    if lengths is None:
+        return torch.split(values, offsets.diff().tolist())
+    return [
+        values[offsets[i] : (offsets[i] + lengths[i])] for i in range(lengths.shape[0])
+    ]


 @register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt, dim: any")
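Finally, a hedged sketch of what the lengths-aware unbind branch above produces for a narrow()-built nested tensor (it mirrors the test added earlier in this commit and assumes a build with this change):

import torch

base = torch.arange(0, 10, dtype=torch.int64).unsqueeze(0).expand(5, -1).clone()
starts = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64)
lengths = torch.tensor([3, 2, 2, 1, 5], dtype=torch.int64)
nt = torch.nested.narrow(base, 1, starts, lengths, layout=torch.jagged)

values, offsets, lens = nt.values(), nt.offsets(), nt.lengths()
for i in range(lens.shape[0]):
    # The same slicing the new unbind branch performs per constituent tensor.
    print(values[offsets[i] : offsets[i] + lens[i]])
# Row i prints torch.arange(starts[i], starts[i] + lengths[i]).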