Add check nested_tensor_from_jagged param jagged_dim >= 1 (#157770)

Fixes #157404

## Test Result

```bash
pytest test/test_nestedtensor.py

...............................................s..........ssssss.................................................................................................s.s..sssss..s...ss............................................................. [ 44%]
...........................................................sssss....sss...s.........ss....s....sss.........s.sss...s..s......s............s.sss.ss...............s.....................s....s......................s.s.....s....s..s..ssssssssss [ 59%]
sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss..ssssss.ssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.ssssssss...............................s........................................... [ 74%]
.......sss...................................................................................................................................................................................................................................... [ 89%]
....sss..........................................................................................................................................................                                                                                [100%]

==================================================================================================== 1317 passed, 258 skipped in 2504.27s (0:41:44) ====================================================================================================
```

![image](https://github.com/user-attachments/assets/dcc8e46d-b88f-4580-b4ad-0999bad33ec9)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157770
Approved by: https://github.com/soulitzer

Co-authored-by: Jeffrey Wan <soulitzer@gmail.com>
This commit is contained in:
zeshengzong 2025-07-10 00:34:35 +00:00 committed by PyTorch MergeBot
parent 752f202ef3
commit ed03492238
2 changed files with 9 additions and 2 deletions

View File

@ -5640,6 +5640,11 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
):
torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None)
with self.assertRaisesRegex(ValueError, "Expected jagged_dim >=1, but got 0."):
torch.nested.nested_tensor_from_jagged(
values, lengths=lengths, jagged_dim=0
)
@onlyCPU
def test_nested_tensor_from_jagged_fx_trace(self, device):
def fn(x, y):

View File

@ -392,8 +392,8 @@ def nested_tensor_from_jagged(
offsets (optional :class:`torch.Tensor`): Offsets into the jagged dimension of shape B + 1.
lengths (optional :class:`torch.Tensor`): Lengths of the batch elements of shape B.
jagged_dim (optional int): Indicates which dimension in values is the packed jagged
dimension. If None, this is set to dim=1 (i.e. the dimension immediately following
the batch dimension). Default: None
dimension. Must be >= 1 as the batch dimension (dim=0) cannot be ragged.
If None, this is set to dim=1 (i.e. the dimension immediately following the batch dimension). Default: None
min_seqlen (optional int): If set, uses the specified value as the cached minimum sequence
length for the returned nested tensor. This can be a useful alternative to computing
this value on-demand, possibly avoiding a GPU -> CPU sync. Default: None
@ -450,6 +450,8 @@ def nested_tensor_from_jagged(
if jagged_dim is None:
jagged_dim = 1
elif jagged_dim < 1:
raise ValueError(f"Expected jagged_dim >=1, but got {jagged_dim}.")
from torch.nested._internal.nested_tensor import (
nested_view_from_values_offsets_lengths,