Allow any single non-batch dim to be ragged for NJT (#137125)

Fixes #137512

Relaxes the restriction that the ragged dim must be immediately adjacent to the batch dim, i.e. that shapes have the form `(B, *, D_0, ..., D_N)`. This makes it possible to construct NJTs of shape e.g. `(B, D, j0)` directly. Before this PR, the only way to get an NJT of shape `(B, D, j0)` was to construct one of shape `(B, j0, D)` and transpose it; this PR lets a user go straight there without the transpose. The standard `torch.nested.nested_tensor(list)` constructor has been updated to support this.
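For illustration, a minimal sketch of the new direct construction path (the shapes below are arbitrary examples, not from the PR):

```python
import torch

# Components ragged in their *second* dim: shapes (4, 2), (4, 5), (4, 3).
ts = [torch.randn(4, n) for n in (2, 5, 3)]

# New: an NJT of shape (B=3, D=4, j) can be constructed directly.
nt = torch.nested.nested_tensor(ts, layout=torch.jagged)
print(nt.shape)  # (3, 4, j) -- the nested int sits in dim 2, not dim 1

# Previous workaround: construct (B, j, D) and transpose.
nt_via_transpose = torch.nested.nested_tensor(
    [t.t() for t in ts], layout=torch.jagged
).transpose(1, 2)
```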

At the very least, this is useful for testing on transposed NJTs. I'm willing to make this functionality private if needed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137125
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
Joel Schlosser 2024-11-06 11:03:17 -05:00 committed by PyTorch MergeBot
parent d1e2e81ede
commit 3abbde976d
2 changed files with 140 additions and 16 deletions


```diff
@@ -5709,6 +5709,104 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
         (nt * 2).backward(torch.ones_like(nt))
         self.assertEqual(orig_t.grad, torch.ones_like(orig_t) * 2)

+    @dtypes(torch.float32)
+    def test_construction_from_list(self, device, dtype):
+        from torch.fx.experimental.symbolic_shapes import is_nested_int
+
+        # success case: single ragged dim anywhere but the batch dim
+        for nt_dim in [2, 3, 4]:
+            for ragged_dim in range(1, nt_dim):
+                B = 6
+                shapes = [list(range(3, 3 + nt_dim - 1)) for _ in range(B)]
+                for b in range(B):
+                    # subtract 1 to convert to component dim space
+                    shapes[b][ragged_dim - 1] = torch.randint(
+                        2, 9, (1,), device=device, dtype=torch.int64
+                    ).item()
+                components = [
+                    torch.randn(shape, device=device, dtype=dtype) for shape in shapes
+                ]
+                nt = torch.nested.nested_tensor(components, layout=torch.jagged)
+
+                self.assertEqual(nt.dim(), nt_dim)
+                self.assertEqual(nt._ragged_idx, ragged_dim)
+                for d in range(nt_dim):
+                    self.assertEqual(d == ragged_dim, is_nested_int(nt.shape[d]))
+
+        # error case: empty list
+        with self.assertRaisesRegex(
+            RuntimeError, "Cannot construct a nested tensor from an empty tensor list"
+        ):
+            torch.nested.nested_tensor([], layout=torch.jagged)
+
+        # error case: list of zero-dim tensors
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Cannot construct a nested tensor from a list of zero-dim tensors",
+        ):
+            torch.nested.nested_tensor(
+                [
+                    torch.tensor(3.0, device=device, dtype=dtype),
+                    torch.tensor(4.0, device=device, dtype=dtype),
+                    torch.tensor(5.0, device=device, dtype=dtype),
+                ],
+                layout=torch.jagged,
+            )
+
+        # error case: multiple ragged dims
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Cannot represent given tensor list as a nested tensor with the jagged layout",
+        ):
+            torch.nested.nested_tensor(
+                [
+                    torch.randn(2, 3, device=device, dtype=dtype),
+                    torch.randn(4, 5, device=device, dtype=dtype),
+                ],
+                layout=torch.jagged,
+            )
+
+        # error case: components on multiple devices
+        if "cuda" in device:
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "When constructing a nested tensor, all tensors in list must be on the same device",
+            ):
+                torch.nested.nested_tensor(
+                    [
+                        torch.randn(2, 3, device=device, dtype=dtype),
+                        torch.randn(2, 4, device="cpu", dtype=dtype),
+                    ],
+                    layout=torch.jagged,
+                )
+
+        # error case: components with multiple dtypes
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "When constructing a nested tensor, all tensors in list must have the same dtype",
+        ):
+            torch.nested.nested_tensor(
+                [
+                    torch.randn(2, 3, device=device, dtype=dtype),
+                    torch.randn(2, 4, device=device, dtype=torch.float64),
+                ],
+                layout=torch.jagged,
+            )
+
+        # error case: components with multiple dims
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "When constructing a nested tensor, all tensors in list must have the same dim",
+        ):
+            torch.nested.nested_tensor(
+                [
+                    torch.randn(2, 3, device=device, dtype=dtype),
+                    torch.randn(2, 3, 4, device=device, dtype=dtype),
+                ],
+                layout=torch.jagged,
+            )
+
     @dtypes(torch.double, torch.half)
     @onlyCUDA
     def test_device_dtype_transfer_updates_offsets(self, device, dtype):
```


```diff
@@ -419,6 +419,8 @@ def jagged_from_list(
 ) -> Tuple[NestedTensor, torch.Tensor]:
     """Constructs a NestedTensor backed by jagged layout from a list of tensors"""
+    if len(tensors) == 0:
+        raise RuntimeError("Cannot construct a nested tensor from an empty tensor list")
     if not len(set(t.dtype for t in tensors)) == 1:  # noqa: C401
         raise RuntimeError(
             "When constructing a nested tensor, all tensors in list must have the same dtype"
@@ -427,22 +429,40 @@ def jagged_from_list(
         raise RuntimeError(
             "When constructing a nested tensor, all tensors in list must be on the same device"
         )

-    # Check that the NT is representable by the jagged layout.
-    # Jagged layout represents (B, *, D_0, D_1, ..., D_N), where the only
-    # raggedness allowed is for the single dim immediately adjacent to the batch dim.
-    sizes = [t.shape for t in tensors]
-    non_first_sizes = [s[1:] for s in sizes]
-    at_most_first_ragged = all(s == non_first_sizes[0] for s in non_first_sizes)
-    if not at_most_first_ragged:
+    if not len(set(t.dim() for t in tensors)) == 1:  # noqa: C401
         raise RuntimeError(
-            "Cannot represent given tensor list as a nested tensor with the jagged layout. "
-            "Note that the jagged layout only represents shapes of the form "
-            "(B, *, D_0, D_1, ..., D_N), with only * allowed to be ragged."
+            "When constructing a nested tensor, all tensors in list must have the same dim"
         )
+
+    component_dim = tensors[0].dim()
+    if component_dim == 0:
+        raise RuntimeError(
+            "Cannot construct a nested tensor from a list of zero-dim tensors"
+        )
+
+    # Check that the NT is representable by the jagged layout, which
+    # allows for a single ragged dimension after the batch dim.
+    # e.g. (B, *, D_0, ..., D_N), (B, D_0, *, ..., D_N), etc.
+    sizes = [t.shape for t in tensors]
+    ragged_idx = None
+    for d in range(component_dim):
+        dim_is_ragged = any(size[d] != sizes[0][d] for size in sizes)
+        if dim_is_ragged:
+            if ragged_idx is None:
+                # add 1 to convert to outer NJT dim space
+                ragged_idx = d + 1
+            else:
+                raise RuntimeError(
+                    "Cannot represent given tensor list as a nested tensor with the jagged layout. "
+                    "Note that the jagged layout only allows for a single ragged dimension. "
+                    "For example: (B, *, D_0, D_1, ..., D_N), with ragged * dim."
+                )
+
+    # allow for a rectangular NJT and default the ragged dim next to the batch dim
+    if ragged_idx is None:
+        ragged_idx = 1

     # Set properties appropriately.
-    values = torch.cat(tensors, dim=0)
+    values = torch.cat(tensors, dim=(ragged_idx - 1))
     to_kwargs = {}
     if device is not None:
         to_kwargs["device"] = device
@@ -458,15 +478,21 @@ def jagged_from_list(
     offsets = torch.cat(
         [
             torch.zeros(1, dtype=torch.int64, device=values.device),
-            torch.tensor([s[0] for s in sizes], device=values.device).cumsum(dim=0),
+            torch.tensor(
+                [s[ragged_idx - 1] for s in sizes], device=values.device
+            ).cumsum(dim=0),
         ]
     )

     # compute this now since it's easy
-    min_seqlen = min(t.shape[0] for t in tensors)
-    max_seqlen = max(t.shape[0] for t in tensors)
+    min_seqlen = min(t.shape[ragged_idx - 1] for t in tensors)
+    max_seqlen = max(t.shape[ragged_idx - 1] for t in tensors)

     ret_nt = nested_view_from_values_offsets(
-        values, offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen
+        values,
+        offsets,
+        min_seqlen=min_seqlen,
+        max_seqlen=max_seqlen,
+        ragged_idx=ragged_idx,
     )
     return (ret_nt, offsets)  # type: ignore[return-value]
```
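As a worked example of what the updated construction produces (a sketch using the public constructor; the shapes are arbitrary, not from the PR): for components of shapes `(5, 2)`, `(5, 3)`, `(5, 4)`, the loop above detects raggedness in component dim 1, so `ragged_idx` becomes 2, `values` is the concatenation along component dim 1 (shape `(5, 9)`), and `offsets` is the cumsum of the ragged lengths.

```python
import torch

# Components ragged in dim 1: shapes (5, 2), (5, 3), (5, 4).
ts = [torch.randn(5, n) for n in (2, 3, 4)]
nt = torch.nested.nested_tensor(ts, layout=torch.jagged)

print(nt._ragged_idx)     # 2 (component dim 1, shifted into outer NJT dim space)
print(nt.values().shape)  # torch.Size([5, 9]) -- 2 + 3 + 4 along the ragged dim
print(nt.offsets())       # tensor([0, 2, 5, 9]) -- cumsum of ragged lengths
```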