Implement narrow from a regular tensor to jagged tensor (#112770)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112770
Approved by: https://github.com/cpuhrsch
This commit is contained in:
parent
3700894099
commit
1aece432ba
@@ -14,7 +14,11 @@ from torch._dynamo.testing import normalize_gm
 from torch._higher_order_ops.wrap import wrap
 from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
-from torch.nested._internal.nested_tensor import jagged_from_list, ViewBufferFromNested
+from torch.nested._internal.nested_tensor import (
+    jagged_from_list,
+    jagged_from_tensor_and_lengths,
+    ViewBufferFromNested,
+)
 from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -799,6 +803,19 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
         )
         return jagged_from_list(out, offsets)

+    def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True):
+        # Makes a jagged tensor with N constituent tensors with size
+        # as specified ((S0, S1, S2), D)
+        max_dim = (starts + lengths).max()
+        values_tensor = torch.randn(
+            starts.shape[0],
+            max_dim.item(),
+            inner_dim,
+            requires_grad=requires_grad,
+            dtype=torch.float64,
+        )
+        return jagged_from_tensor_and_lengths(values_tensor, starts, lengths)
+
     def _check_recompiles(self, fn, inputs1, inputs2, recompiles):
         compile_count = [0]
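For orientation, a minimal sketch of what the new helper builds (illustrative shapes, not part of the diff; it assumes jagged_from_tensor_and_lengths returns a (nested tensor, offsets, lengths) triple as in the hunks further below):

import torch
from torch.nested._internal.nested_tensor import jagged_from_tensor_and_lengths

starts = torch.tensor([0, 1, 2], dtype=torch.int64)
lengths = torch.tensor([3, 2, 4], dtype=torch.int64)
inner_dim = 5

# Dense buffer of shape (B, (starts + lengths).max(), D) = (3, 6, 5),
# mirroring what _get_nc_jagged_tensor constructs internally.
values_tensor = torch.randn(starts.shape[0], int((starts + lengths).max()), inner_dim)

# Row i of the jagged result views values_tensor[i, starts[i]:starts[i] + lengths[i], :]
nt, offsets, out_lengths = jagged_from_tensor_and_lengths(values_tensor, starts, lengths)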
@@ -3325,6 +3325,66 @@ class TestNestedTensorSubclass(NestedTestCase):
         nt, _ = jagged_from_list(test_tensor_list, None)
         _ = torch.nn.functional.layer_norm(nt, (nt.shape[-2], nt.shape[-1]))

+    def test_narrow(self, device):
+        starts = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
+        lengths = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
+        nt = torch.nested.narrow(
+            torch.arange(0, 10, device=device, dtype=torch.int64).unsqueeze(0).expand(5, -1).clone().detach(),
+            1,
+            starts,
+            lengths,
+            layout=torch.jagged
+        )
+
+        # TODO: Use this approach when unbind is functional
+        # unbinded_nt = nt.unbind()
+        # for i in range(starts.shape[0]):
+        #     self.assertEqual(torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64), unbinded_nt[i])
+        for i in range(starts.shape[0]):
+            self.assertEqual(
+                torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64),
+                nt.values()[nt.offsets()[i]:(nt.offsets()[i] + nt.lengths()[i])]
+            )
+
+    def test_is_contiguous(self, device):
+        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
+        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
+        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
+        nt_contiguous, _ = jagged_from_list([a, b, c], None)
+
+        starts_nc = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
+        lengths_nc = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
+        narrow_base = torch.arange(0, 10, device=device, dtype=torch.int64).unsqueeze(0).expand(5, -1).clone()
+        nt_noncontiguous = torch.nested.narrow(
+            narrow_base,
+            1,
+            starts_nc,
+            lengths_nc,
+            layout=torch.jagged
+        )
+
+        starts_c = torch.tensor([1, 0, 0, 0, 0], device=device, dtype=torch.int64)
+        lengths_c = torch.tensor([9, 10, 10, 10, 8], device=device, dtype=torch.int64)
+        nt_contiguous_narrow = torch.nested.narrow(
+            narrow_base,
+            1,
+            starts_c,
+            lengths_c,
+            layout=torch.jagged
+        )
+
+        # Test contiguous case
+        assert nt_contiguous.is_contiguous()
+
+        # Test narrow case
+        assert not nt_noncontiguous.is_contiguous()
+        assert nt_contiguous_narrow.is_contiguous()
+
+        # Test querying by memory_format
+        self.assertTrue(nt_contiguous.is_contiguous(memory_format=torch.contiguous_format))
+        self.assertTrue(not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format))
+        self.assertTrue(nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format))
+
 
 instantiate_parametrized_tests(TestNestedTensor)
 instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
@@ -309,7 +309,7 @@ class MetaConverter:
                 from torch._dynamo.source import AttrSource
                 from torch.fx.experimental.symbolic_shapes import DimDynamic

-                if shape_env and not t.is_nested:
+                if shape_env and not t.is_nested and not t._base.is_nested:
                     base_dynamic_dims = [DimDynamic.STATIC] * t._base.dim()
                 else:
                     base_dynamic_dims = None
@@ -1,7 +1,7 @@
-from typing import List, Optional
+from typing import List, Optional, Union

 import torch
-from torch import Tensor
+from torch import SymInt, Tensor
 from torch._C import _add_docstr, _nested  # type: ignore[attr-defined]

 from torch.types import _device as Device, _dtype as DType
@@ -10,6 +10,7 @@ __all__ = [
     "to_padded_tensor",
     "as_nested_tensor",
     "nested_tensor",
+    "narrow",
 ]

 # Nested Tensor constructor functions
@@ -187,3 +188,69 @@ Example::
         return nt
     else:
         raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}")
+
+
+def narrow(tensor: Tensor, dim: int, start: Union[int, Tensor], length: Union[int, Tensor], layout=torch.strided) -> Tensor:
+    r"""
+    Constructs a nested tensor (which might be a view) from :attr:`tensor`, a strided tensor. This follows
+    similar semantics to torch.Tensor.narrow, where in the :attr:`dim`-th dimension the new nested tensor
+    shows only the elements in the interval `[start, start+length)`. As nested representations
+    allow for a different `start` and `length` at each 'row' of that dimension, :attr:`start` and :attr:`length`
+    can also be tensors of shape `tensor.shape[0]`.
+
+    There are some differences depending on the layout you use for the nested tensor. If using the strided layout,
+    torch.narrow will copy the narrowed data into a contiguous NT with strided layout, while
+    jagged layout narrow() will create a non-contiguous view of your original strided tensor. This particular
+    representation is really useful for representing kv-caches in Transformer models, as specialized
+    SDPA kernels can deal with this format easily, resulting in performance improvements.
+
+
+    Args:
+        tensor (:class:`torch.Tensor`): a strided tensor, which will be used as the underlying data
+            for the nested tensor if using the jagged layout or will be copied for the strided layout.
+        dim (int): the dimension where narrow will be applied. Only `dim=1` is supported for the
+            jagged layout, while the strided layout supports all dims
+        start (Union[int, :class:`torch.Tensor`]): starting element for the narrow operation
+        length (Union[int, :class:`torch.Tensor`]): number of elements taken during the narrow op
+
+    Keyword arguments:
+        layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
+            Only strided and jagged layouts are supported. Default: if None, the strided layout.
+
+    Example::
+
+        >>> starts = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64)
+        >>> lengths = torch.tensor([3, 2, 2, 1, 5], dtype=torch.int64)
+        >>> narrow_base = torch.randn(5, 10, 20)
+        >>> nt_narrowed = torch.nested.narrow(narrow_base, 1, starts, lengths, layout=torch.jagged)
+        >>> nt_narrowed.is_contiguous()
+        False
+    """
+    if not isinstance(start, (int, SymInt, Tensor)):
+        raise RuntimeError("start must be an integer or a tensor")
+
+    if not isinstance(length, (int, SymInt, Tensor)):
+        raise RuntimeError("length must be an integer or a tensor")
+
+    if layout == torch.strided:
+        if isinstance(start, Tensor) or isinstance(length, Tensor):
+            raise RuntimeError("start and length must be integers for the strided layout NT impl")
+        # TODO: switch to as_nested_tensor(tensor) when it is available
+        nt = as_nested_tensor(torch.unbind(tensor), layout=torch.strided).narrow(dim, start, length)
+    elif layout == torch.jagged:
+        if dim != 1:
+            raise RuntimeError("jagged layout only supports dim=1")
+
+        from torch.nested._internal.nested_tensor import jagged_from_tensor_and_lengths
+
+        if isinstance(start, (int, SymInt)):
+            start = torch.tensor([start], device=tensor.device, dtype=torch.int64)
+
+        if isinstance(length, (int, SymInt)):
+            length = torch.tensor([length], device=tensor.device, dtype=torch.int64)
+
+        nt, _, _ = jagged_from_tensor_and_lengths(tensor, start, length)
+    else:
+        raise RuntimeError(f"Specified layout is unsupported for nested narrow: {layout}")
+
+    return nt
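A hedged usage sketch of the new public API (illustrative shapes and values, not taken from the diff): with layout=torch.jagged the result is a non-contiguous nested view over the original buffer rather than a copy.

import torch

base = torch.randn(5, 10, 20)  # (batch, seq, feature)
starts = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64)
lengths = torch.tensor([3, 2, 2, 1, 5], dtype=torch.int64)

# Jagged layout: per-row starts/lengths, no copy of `base`.
nt_jagged = torch.nested.narrow(base, 1, starts, lengths, layout=torch.jagged)
print(nt_jagged.is_contiguous())  # False for a genuine per-row narrow

# The strided layout (per the docstring above) accepts only integer start/length
# and copies the narrowed data into a new strided nested tensor.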
@@ -2,6 +2,7 @@ from typing import Tuple

 import torch
 from torch._C import DispatchKey, DispatchKeySet
+from torch._prims_common import is_expandable_to
 from torch.fx.experimental.symbolic_shapes import has_free_symbols
 from torch.utils.weak import WeakTensorKeyDictionary
 from typing import *  # noqa: F403
@@ -21,6 +22,7 @@ def get_tensor_id(tensor, *, coeff=1):
 class NestedTensor(torch.Tensor):
     _values: torch.Tensor  # type: ignore[assignment]
     _offsets: torch.Tensor
+    _lengths: Optional[torch.Tensor]
     # NOTE [ Singleton ints for ragged sizes and strides ]
     #
     # Jagged layout tensors are tensors that represent a n-dim tensor with a
@@ -46,6 +48,7 @@ class NestedTensor(torch.Tensor):
         values,
         offsets,
         *,
+        lengths=None,
         ragged_size=None,
         **kwargs,
     ):
@@ -69,7 +72,7 @@ class NestedTensor(torch.Tensor):
         )
         return r

-    def __init__(self, values, offsets, *, ragged_size=None, **kwargs):
+    def __init__(self, values, offsets, *, lengths=None, ragged_size=None, **kwargs):
         super().__init__()
         # Only support jagged for now.
         assert offsets is not None
@@ -82,7 +85,10 @@ class NestedTensor(torch.Tensor):
             # we perform operations on fake nested tensors.
             # Calling get_tensor_id won't work in those cases because we want
             # the existing symbolic ragged_size to be propagated.
-            ragged_size = get_tensor_id(offsets, coeff=1)
+            if lengths is None:
+                ragged_size = get_tensor_id(offsets, coeff=1)
+            else:
+                ragged_size = get_tensor_id(lengths, coeff=1)
         B = offsets.shape[0] - 1
         Ds = values.shape[1:]
         self._size = (B, ragged_size, *Ds)
@@ -97,6 +103,7 @@ class NestedTensor(torch.Tensor):
         )
         self._values = values
         self._offsets = offsets
+        self._lengths = lengths

     def values(self):
         return self._values
@@ -104,6 +111,9 @@ class NestedTensor(torch.Tensor):
     def offsets(self):
         return self._offsets

+    def lengths(self):
+        return self._lengths
+
     def __repr__(self):
         # We should implement this in torch/_tensor_str.py instead
         grad_fn_str = (
@@ -111,7 +121,7 @@ class NestedTensor(torch.Tensor):
         )
         if self.grad_fn:
             grad_fn_str = f", grad_fn={self.grad_fn}"
-        return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str})"
+        return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._lengths is None})"

     def __reduce_ex__(self, proto):
         state = torch._utils._get_obj_state(self)
@@ -131,13 +141,20 @@ class NestedTensor(torch.Tensor):
             "requires_grad": self.requires_grad,
             "ragged_size": self._size[self._ragged_idx],
         }
-        return ["_values", "_offsets"], ctx
+        inner_tensors = ["_values", "_offsets"]
+        if self._lengths is not None:
+            inner_tensors.append("_lengths")
+        return inner_tensors, ctx

     @staticmethod
     def __tensor_unflatten__(inner_tensors: Dict, meta):
-        assert len(inner_tensors) == 2
+        assert len(inner_tensors) >= 2 and len(inner_tensors) <= 3
         values = inner_tensors["_values"]
         offsets = inner_tensors["_offsets"]
+        if "_lengths" in inner_tensors and inner_tensors["_lengths"] is not None:
+            lengths = inner_tensors["_lengths"]
+        else:
+            lengths = None

         # NOTE [ Storing symbolic values as plain attributes on subclasses ]
         #
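A minimal sketch of the flatten/unflatten round trip this enables for non-contiguous nested tensors (hypothetical driver code; it assumes direct construction of NestedTensor with values, offsets, and lengths as shown elsewhere in this diff):

import torch
from torch.nested._internal.nested_tensor import NestedTensor

values = torch.randn(30, 4)                                   # flattened (3 x 10, 4) buffer
offsets = torch.tensor([0, 11, 22, 27], dtype=torch.int64)
lengths = torch.tensor([3, 2, 5], dtype=torch.int64)
nt = NestedTensor(values, offsets=offsets, lengths=lengths)

# With lengths present, "_lengths" is reported as a third inner tensor.
inner_names, ctx = nt.__tensor_flatten__()  # ["_values", "_offsets", "_lengths"]
inner = {name: getattr(nt, name) for name in inner_names}
nt2 = NestedTensor.__tensor_unflatten__(inner, ctx)  # lengths survive the round trip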
@@ -173,6 +190,7 @@ class NestedTensor(torch.Tensor):
         return NestedTensor(
             values,
             offsets=offsets,
+            lengths=lengths,
             ragged_size=meta["ragged_size"],
             requires_grad=meta["requires_grad"],
         )
@@ -232,6 +250,18 @@ class ViewNestedFromBuffer(torch.autograd.Function):
         return gO.values(), None, None


+# Not actually a view!
+# NOTE: @jbschlosser is working on making it a view
+class ViewNonContiguousNestedFromBuffer(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, lengths: torch.Tensor):  # type: ignore[override]
+        return NestedTensor(values.detach(), offsets=offsets, lengths=lengths)
+
+    @staticmethod
+    def backward(ctx, gO: NestedTensor):  # type: ignore[override]
+        return gO.values(), None, None
+
+
 # Need to make it obvious that users should be passing in offsets
 def jagged_from_list(
     tensors: List[torch.Tensor],
@@ -285,5 +315,65 @@ def jagged_from_list(
     return ViewNestedFromBuffer.apply(values, offsets), offsets  # type: ignore[call-overload]


+def jagged_from_tensor_and_lengths(
+    tensor: torch.Tensor, starts: torch.Tensor, lengths: torch.Tensor
+) -> Tuple[NestedTensor, torch.Tensor, Optional[torch.Tensor]]:
+    """Constructs a NestedTensor backed by jagged layout from a tensor, starts of sequences, and sequence lengths"""
+    batch_size = tensor.shape[0]
+    if is_expandable_to(starts.shape, (batch_size,)) and is_expandable_to(
+        lengths.shape, (batch_size,)
+    ):
+        start_list = starts.expand(batch_size)
+        length_list = lengths.expand(batch_size)
+    else:
+        raise RuntimeError(
+            "When constructing a jagged nested tensor using narrow(), "
+            "your start and length must be Tensors that broadcast to input.shape[0]"
+        )
+
+    # Calculate jagged offsets
+    assert (
+        len(tensor.shape) >= 2
+    ), "tensor must at least be 2D for the nested narrow op to work"
+    max_seq_len = tensor.shape[1]
+    offset_lengths = max_seq_len * torch.arange(
+        0, batch_size, dtype=torch.int64, device=tensor.device
+    )
+    # Jagged layout specifies that offsets are stored as int64 on the same device as values.
+    offsets = torch.cat(
+        [
+            start_list + offset_lengths,
+            (start_list[-1] + offset_lengths[-1] + length_list[-1]).unsqueeze(0),
+        ]
+    )
+
+    # Reshape buffer to flatten the 1st and 2nd dimension (view used to enforce non-copy)
+    if len(tensor.shape) > 2:
+        values = tensor.view(-1, *tensor.shape[2:])
+    else:
+        values = tensor.view(-1)
+
+    # Check if offsets and lengths make it possibly contiguous and return a regular NT
+    is_contiguous = True
+    orig_dim = tensor.shape[1]
+    if torch.any(length_list[1:-1].ne(orig_dim)):
+        is_contiguous = False
+    if torch.any(offsets[1:-2].diff().ne(orig_dim)):
+        is_contiguous = False
+    if offsets[0] + length_list[0] != orig_dim:
+        is_contiguous = False
+
+    if is_contiguous:
+        return (
+            ViewNestedFromBuffer.apply(
+                values[offsets[0] : offsets[-1]], offsets - offsets[0]
+            ),
+            offsets,
+            None,
+        )
+
+    return ViewNonContiguousNestedFromBuffer.apply(values, offsets, length_list), offsets, length_list  # type: ignore[call-overload]
+
+
 def buffer_from_jagged(jagged):
     return ViewBufferFromNested.apply(jagged)
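To make the offsets arithmetic concrete, a small worked example (illustrative numbers only): for a dense (3, 4, ...) buffer, row i starts at 4 * i in the flattened values, so each jagged offset is starts[i] + 4 * i, with one extra final entry marking the end of the last row.

import torch

starts = torch.tensor([1, 0, 2], dtype=torch.int64)
lengths = torch.tensor([2, 4, 1], dtype=torch.int64)
max_seq_len = 4  # tensor.shape[1]

offset_lengths = max_seq_len * torch.arange(3)  # [0, 4, 8]
offsets = torch.cat([
    starts + offset_lengths,                                       # [1, 4, 10]
    (starts[-1] + offset_lengths[-1] + lengths[-1]).unsqueeze(0),  # [11]
])
# offsets == [1, 4, 10, 11]; row i of the jagged view is
# values[offsets[i] : offsets[i] + lengths[i]], and the contiguity check
# above fails here because offsets[0] + lengths[0] == 3 != 4.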
@@ -1,4 +1,5 @@
 import functools
+import math

 import torch

@@ -65,7 +66,11 @@ def check_schema(schema_str: str, func, *args, **kwargs) -> None:

     arg_type_check_fns = {
         "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
-        "jt": lambda x: isinstance(x, NestedTensor),
+        "jt": lambda x: isinstance(x, NestedTensor)
+        and x._lengths is None,  # ops with "jt" require contiguous JT only
+        "jt_all": lambda x: isinstance(
+            x, NestedTensor
+        ),  # ops with "jt_all" can accept all kinds of JT
         "any": lambda x: True,
     }
     for i, named_arg_type in enumerate(named_arg_types):
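A hedged illustration of what the stricter schema buys (the two predicates are copied from arg_type_check_fns above; the tensors are illustrative): ops registered with "self: jt" only accept contiguous jagged tensors, while "self: jt_all" also accepts the narrow()-produced ones that carry _lengths.

import torch
from torch.nested._internal.nested_tensor import NestedTensor, jagged_from_list

is_jt = lambda x: isinstance(x, NestedTensor) and x._lengths is None  # contiguous only
is_jt_all = lambda x: isinstance(x, NestedTensor)                     # any jagged NT

contiguous_nt, _ = jagged_from_list([torch.randn(2, 3), torch.randn(4, 3)], None)
narrowed_nt = torch.nested.narrow(
    torch.randn(3, 8),
    1,
    torch.tensor([0, 1, 2]),
    torch.tensor([5, 4, 3]),
    layout=torch.jagged,
)

assert is_jt(contiguous_nt) and is_jt_all(contiguous_nt)
assert not is_jt(narrowed_nt) and is_jt_all(narrowed_nt)  # narrow() populates _lengths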
@@ -297,7 +302,7 @@ def jagged_torch_function(func, *args, **kwargs):
         torch.ops.aten.sym_stride.default,
         torch.ops.aten.sym_storage_offset.default,
     ],
-    "self: jt",
+    "self: jt_all",
 )
 def tensor_attr_supported_getter(func, *args, **kwargs):
     if func == torch.ops.aten.is_non_overlapping_and_dense.default:
@@ -310,23 +315,25 @@ def tensor_attr_supported_getter(func, *args, **kwargs):
         return len(args[0]._size)

     if func == torch.ops.aten.sym_numel.default:
+        if args[0]._lengths is not None:
+            return int(sum(args[0]._lengths) * math.prod(args[0]._size[2:]))
         return args[0]._values.numel()

     if func == torch.ops.aten.sym_stride.default:
         return args[0]._strides

     if func == torch.ops.aten.sym_storage_offset.default:
-        return 0
+        return args[0]._values.storage_offset()


-@register_jagged_func(torch.ops.prim.layout.default, "self: jt")
+@register_jagged_func(torch.ops.prim.layout.default, "self: jt_all")
 def prim_layout_default(func, *args, **kwargs):
     return torch.jagged


 @register_jagged_func(
     [torch.ops.aten.size.default],
-    "self: jt",
+    "self: jt_all",
 )
 def tensor_attr_unsupported_getter(func, *args, **kwargs):
     if func == torch.ops.aten.size.default:
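The numel change can be checked by hand (illustrative numbers): for a narrowed nested tensor, only the selected elements count, not the whole dense buffer backing it.

import math

lengths = [3, 2, 4]   # per-row lengths of a narrowed NT (self._lengths)
trailing = (5,)       # self._size[2:]
numel = int(sum(lengths) * math.prod(trailing))  # 9 * 5 = 45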
@@ -336,7 +343,7 @@ def tensor_attr_unsupported_getter(func, *args, **kwargs):
     )


-@register_jagged_func(torch.ops.aten.is_contiguous.default, "self: jt")
+@register_jagged_func(torch.ops.aten.is_contiguous.default, "self: jt_all")
 def is_contiguous_general(func, *args, **kwargs):
     from torch._prims_common import is_contiguous_for_memory_format

@@ -344,16 +351,21 @@ def is_contiguous_general(func, *args, **kwargs):
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
     inp = new_kwargs.pop("input")
+
+    # If created from narrow() check for lengths
+    if inp.lengths() is not None:
+        return False
+
     new_kwargs["memory_format"] = new_kwargs.get(
         "memory_format", torch.contiguous_format
     )
     if new_kwargs["memory_format"] == torch.preserve_format:
         return True
-    return is_contiguous_for_memory_format(inp, **new_kwargs)
+    return is_contiguous_for_memory_format(inp.values(), **new_kwargs)


 register_jagged_func(
-    torch.ops.aten.is_contiguous.memory_format, "self: jt, memory_format: any?"
+    torch.ops.aten.is_contiguous.memory_format, "self: jt_all, memory_format: any?"
 )(is_contiguous_general)

@@ -508,7 +520,7 @@ def split_with_sizes_default(func, *args, **kwargs):
     ]


-@register_jagged_func(torch.ops.aten.unbind.int, "self: jt, dim: any?")
+@register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?")
 def unbind_int(func, *args, **kwargs):
     # Note that this specializes on the length of the offsets
     _, new_kwargs = normalize_function(
@@ -520,10 +532,15 @@ def unbind_int(func, *args, **kwargs):
         raise RuntimeError("unbind(): only supported for NestedTensor on dim=0")

     inp = new_kwargs.pop("input")
-    values = inp._values
+    values = inp.values()
     offsets = inp.offsets()
+    lengths = inp.lengths()

-    return torch.split(values, offsets.diff().tolist())
+    if lengths is None:
+        return torch.split(values, offsets.diff().tolist())
+    return [
+        values[offsets[i] : (offsets[i] + lengths[i])] for i in range(lengths.shape[0])
+    ]


 @register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt, dim: any")
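A hedged sketch of the lengths-aware unbind path (illustrative values; nt is assumed to come from torch.nested.narrow with layout=torch.jagged, so its lengths() is populated): each constituent is a slice of the flattened values buffer.

import torch

base = torch.arange(0, 10).unsqueeze(0).expand(3, -1).clone()
starts = torch.tensor([0, 1, 2], dtype=torch.int64)
lengths = torch.tensor([3, 2, 4], dtype=torch.int64)
nt = torch.nested.narrow(base, 1, starts, lengths, layout=torch.jagged)

values, offsets = nt.values(), nt.offsets()
# Mirrors the list comprehension in unbind_int above:
pieces = [values[offsets[i] : offsets[i] + lengths[i]] for i in range(lengths.shape[0])]
# pieces[0] == tensor([0, 1, 2]); pieces[1] == tensor([1, 2]); pieces[2] == tensor([2, 3, 4, 5])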