diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 55a86a11f8a..0e0234b0894 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -5640,6 +5640,11 @@ class TestNestedTensorSubclass(NestedTensorTestCase): ): torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None) + with self.assertRaisesRegex(ValueError, "Expected jagged_dim >=1, but got 0."): + torch.nested.nested_tensor_from_jagged( + values, lengths=lengths, jagged_dim=0 + ) + @onlyCPU def test_nested_tensor_from_jagged_fx_trace(self, device): def fn(x, y): diff --git a/torch/nested/__init__.py b/torch/nested/__init__.py index 433c22489f0..5aa739efd2e 100644 --- a/torch/nested/__init__.py +++ b/torch/nested/__init__.py @@ -392,8 +392,8 @@ def nested_tensor_from_jagged( offsets (optional :class:`torch.Tensor`): Offsets into the jagged dimension of shape B + 1. lengths (optional :class:`torch.Tensor`): Lengths of the batch elements of shape B. jagged_dim (optional int): Indicates which dimension in values is the packed jagged - dimension. If None, this is set to dim=1 (i.e. the dimension immediately following - the batch dimension). Default: None + dimension. Must be >= 1 as the batch dimension (dim=0) cannot be ragged. + If None, this is set to dim=1 (i.e. the dimension immediately following the batch dimension). Default: None min_seqlen (optional int): If set, uses the specified value as the cached minimum sequence length for the returned nested tensor. This can be a useful alternative to computing this value on-demand, possibly avoiding a GPU -> CPU sync. Default: None @@ -450,6 +450,8 @@ def nested_tensor_from_jagged( if jagged_dim is None: jagged_dim = 1 + elif jagged_dim < 1: + raise ValueError(f"Expected jagged_dim >=1, but got {jagged_dim}.") from torch.nested._internal.nested_tensor import ( nested_view_from_values_offsets_lengths,