Allow any single non-batch dim to be ragged for NJT (#137125)
Fixes #137512

Relaxes the restriction that the ragged dim must sit immediately next to the batch dim, e.g. `(B, *, D_0, ..., D_N)`. This allows NJTs of shape e.g. `(B, D, j0)` to be constructed directly. Before this PR, an NJT of shape `(B, D, j0)` could only be obtained by constructing an NJT of shape `(B, j0, D)` and transposing it; this PR lets a user go straight there without the transpose. The standard `torch.nested.nested_tensor(list)` constructor has been updated to support this.

At the very least, this is useful for testing on transposed NJTs. I'm willing to make this functionality private if needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137125
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
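As a quick illustration of the new constructor behavior (a minimal sketch based on the test added below; it assumes a PyTorch build that includes this change, and `_ragged_idx` is an internal attribute inspected here only for demonstration):

```python
import torch
from torch.fx.experimental.symbolic_shapes import is_nested_int

# Components share their first dim (D=4) and vary in their last dim, so the
# ragged dim ends up as the last dim of the NJT: shape (B, D, j0) with B=3, D=4.
components = [torch.randn(4, n) for n in (2, 5, 3)]
nt = torch.nested.nested_tensor(components, layout=torch.jagged)

assert nt.dim() == 3
assert nt._ragged_idx == 2          # ragged dim no longer forced to sit at dim 1
assert is_nested_int(nt.shape[2])   # the last dim is the symbolic ragged size
```

Previously, reaching this shape required building components of shape `(j, D)` and transposing the resulting `(B, j0, D)` NJT.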
This commit is contained in:
parent d1e2e81ede
commit 3abbde976d
@@ -5709,6 +5709,104 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
         (nt * 2).backward(torch.ones_like(nt))
         self.assertEqual(orig_t.grad, torch.ones_like(orig_t) * 2)
 
+    @dtypes(torch.float32)
+    def test_construction_from_list(self, device, dtype):
+        from torch.fx.experimental.symbolic_shapes import is_nested_int
+
+        # success case: single ragged dim anywhere but the batch dim
+        for nt_dim in [2, 3, 4]:
+            for ragged_dim in range(1, nt_dim):
+                B = 6
+                shapes = [list(range(3, 3 + nt_dim - 1)) for _ in range(B)]
+                for b in range(B):
+                    # subtract 1 to convert to component dim space
+                    shapes[b][ragged_dim - 1] = torch.randint(
+                        2, 9, (1,), device=device, dtype=torch.int64
+                    ).item()
+
+                components = [
+                    torch.randn(shape, device=device, dtype=dtype) for shape in shapes
+                ]
+                nt = torch.nested.nested_tensor(components, layout=torch.jagged)
+
+                self.assertEqual(nt.dim(), nt_dim)
+                self.assertEqual(nt._ragged_idx, ragged_dim)
+                for d in range(nt_dim):
+                    self.assertEqual(d == ragged_dim, is_nested_int(nt.shape[d]))
+
+        # error case: empty list
+        with self.assertRaisesRegex(
+            RuntimeError, "Cannot construct a nested tensor from an empty tensor list"
+        ):
+            torch.nested.nested_tensor([], layout=torch.jagged)
+
+        # error case: list of zero-dim tensors
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Cannot construct a nested tensor from a list of zero-dim tensors",
+        ):
+            torch.nested.nested_tensor(
+                [
+                    torch.tensor(3.0, device=device, dtype=dtype),
+                    torch.tensor(4.0, device=device, dtype=dtype),
+                    torch.tensor(5.0, device=device, dtype=dtype),
+                ],
+                layout=torch.jagged,
+            )
+
+        # error case: multiple ragged dims
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Cannot represent given tensor list as a nested tensor with the jagged layout",
+        ):
+            torch.nested.nested_tensor(
+                [
+                    torch.randn(2, 3, device=device, dtype=dtype),
+                    torch.randn(4, 5, device=device, dtype=dtype),
+                ],
+                layout=torch.jagged,
+            )
+
+        # error case: components on multiple devices
+        if "cuda" in device:
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "When constructing a nested tensor, all tensors in list must be on the same device",
+            ):
+                torch.nested.nested_tensor(
+                    [
+                        torch.randn(2, 3, device=device, dtype=dtype),
+                        torch.randn(2, 4, device="cpu", dtype=dtype),
+                    ],
+                    layout=torch.jagged,
+                )
+
+        # error case: components with multiple dtypes
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "When constructing a nested tensor, all tensors in list must have the same dtype",
+        ):
+            torch.nested.nested_tensor(
+                [
+                    torch.randn(2, 3, device=device, dtype=dtype),
+                    torch.randn(2, 4, device=device, dtype=torch.float64),
+                ],
+                layout=torch.jagged,
+            )
+
+        # error case: components with multiple dims
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "When constructing a nested tensor, all tensors in list must have the same dim",
+        ):
+            torch.nested.nested_tensor(
+                [
+                    torch.randn(2, 3, device=device, dtype=dtype),
+                    torch.randn(2, 3, 4, device=device, dtype=dtype),
+                ],
+                layout=torch.jagged,
+            )
+
     @dtypes(torch.double, torch.half)
     @onlyCUDA
     def test_device_dtype_transfer_updates_offsets(self, device, dtype):
@@ -419,6 +419,8 @@ def jagged_from_list(
 ) -> Tuple[NestedTensor, torch.Tensor]:
     """Constructs a NestedTensor backed by jagged layout from a list of tensors"""
 
+    if len(tensors) == 0:
+        raise RuntimeError("Cannot construct a nested tensor from an empty tensor list")
     if not len(set(t.dtype for t in tensors)) == 1:  # noqa: C401
         raise RuntimeError(
             "When constructing a nested tensor, all tensors in list must have the same dtype"
@@ -427,22 +429,40 @@ def jagged_from_list(
         raise RuntimeError(
             "When constructing a nested tensor, all tensors in list must be on the same device"
         )
-
-    # Check that the NT is representable by the jagged layout.
-    # Jagged layout represents (B, *, D_0, D_1, ..., D_N), where the only
-    # raggedness allowed is for the single dim immediately adjacent to the batch dim.
-    sizes = [t.shape for t in tensors]
-    non_first_sizes = [s[1:] for s in sizes]
-    at_most_first_ragged = all(s == non_first_sizes[0] for s in non_first_sizes)
-    if not at_most_first_ragged:
+    if not len(set(t.dim() for t in tensors)) == 1:  # noqa: C401
         raise RuntimeError(
-            "Cannot represent given tensor list as a nested tensor with the jagged layout. "
-            "Note that the jagged layout only represents shapes of the form "
-            "(B, *, D_0, D_1, ..., D_N), with only * allowed to be ragged."
+            "When constructing a nested tensor, all tensors in list must have the same dim"
         )
+    component_dim = tensors[0].dim()
+    if component_dim == 0:
+        raise RuntimeError(
+            "Cannot construct a nested tensor from a list of zero-dim tensors"
+        )
 
+    # Check that the NT is representable by the jagged layout, which
+    # allows for a single ragged dimension after the batch dim.
+    # e.g. (B, *, D_0, ..., D_N), (B, D_0, *, ..., D_N), etc.
+    sizes = [t.shape for t in tensors]
+    ragged_idx = None
+    for d in range(component_dim):
+        dim_is_ragged = any(size[d] != sizes[0][d] for size in sizes)
+        if dim_is_ragged:
+            if ragged_idx is None:
+                # add 1 to convert to outer NJT dim space
+                ragged_idx = d + 1
+            else:
+                raise RuntimeError(
+                    "Cannot represent given tensor list as a nested tensor with the jagged layout. "
+                    "Note that the jagged layout only allows for a single ragged dimension. "
+                    "For example: (B, *, D_0, D_1, ..., D_N), with ragged * dim."
+                )
+
+    # allow for a rectangular NJT and default the ragged dim next to the batch dim
+    if ragged_idx is None:
+        ragged_idx = 1
+
     # Set properties appropriately.
-    values = torch.cat(tensors, dim=0)
+    values = torch.cat(tensors, dim=(ragged_idx - 1))
     to_kwargs = {}
     if device is not None:
         to_kwargs["device"] = device
@@ -458,15 +478,21 @@ def jagged_from_list(
     offsets = torch.cat(
         [
             torch.zeros(1, dtype=torch.int64, device=values.device),
-            torch.tensor([s[0] for s in sizes], device=values.device).cumsum(dim=0),
+            torch.tensor(
+                [s[ragged_idx - 1] for s in sizes], device=values.device
+            ).cumsum(dim=0),
         ]
     )
 
     # compute this now since it's easy
-    min_seqlen = min(t.shape[0] for t in tensors)
-    max_seqlen = max(t.shape[0] for t in tensors)
+    min_seqlen = min(t.shape[ragged_idx - 1] for t in tensors)
+    max_seqlen = max(t.shape[ragged_idx - 1] for t in tensors)
     ret_nt = nested_view_from_values_offsets(
-        values, offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen
+        values,
+        offsets,
+        min_seqlen=min_seqlen,
+        max_seqlen=max_seqlen,
+        ragged_idx=ragged_idx,
     )
     return (ret_nt, offsets)  # type: ignore[return-value]
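For reference, a standalone sketch of the ragged-dim detection and offsets computation that `jagged_from_list` performs after this change. The helper below is hypothetical and only mirrors the logic in the diff above; it is not part of the PyTorch API:

```python
from typing import List, Optional, Tuple

import torch


def infer_ragged_idx_and_offsets(
    tensors: List[torch.Tensor],
) -> Tuple[int, torch.Tensor]:
    """Find the single ragged component dim (if any) and build cumulative
    offsets along it, mirroring the new jagged_from_list logic."""
    sizes = [t.shape for t in tensors]
    ragged_idx: Optional[int] = None
    for d in range(tensors[0].dim()):
        if any(size[d] != sizes[0][d] for size in sizes):
            if ragged_idx is not None:
                raise RuntimeError("only a single ragged dimension is allowed")
            ragged_idx = d + 1  # +1 converts to outer NJT dim space
    if ragged_idx is None:
        ragged_idx = 1  # rectangular input: default ragged dim next to batch dim

    offsets = torch.cat(
        [
            torch.zeros(1, dtype=torch.int64),
            torch.tensor([s[ragged_idx - 1] for s in sizes]).cumsum(dim=0),
        ]
    )
    return ragged_idx, offsets


# Components of shape (4, n): only the last dim varies, so ragged_idx == 2 and
# the concatenated values buffer is split along dim 1 at offsets [0, 2, 7, 10].
ragged_idx, offsets = infer_ragged_idx_and_offsets(
    [torch.randn(4, 2), torch.randn(4, 5), torch.randn(4, 3)]
)
print(ragged_idx, offsets.tolist())  # 2 [0, 2, 7, 10]
```

In the actual implementation, the components are then concatenated along `ragged_idx - 1` to form the backing `values` buffer, and `nested_view_from_values_offsets(..., ragged_idx=ragged_idx)` builds the NJT view over it.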