mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add check nested_tensor_from_jagged param jagged_dim >= 1 (#157770)
Fixes #157404 ## Test Result ```bash pytest test/test_nestedtensor.py ...............................................s..........ssssss.................................................................................................s.s..sssss..s...ss............................................................. [ 44%] ...........................................................sssss....sss...s.........ss....s....sss.........s.sss...s..s......s............s.sss.ss...............s.....................s....s......................s.s.....s....s..s..ssssssssss [ 59%] sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss..ssssss.ssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.ssssssss...............................s........................................... [ 74%] .......sss...................................................................................................................................................................................................................................... [ 89%] ....sss.......................................................................................................................................................... [100%] ==================================================================================================== 1317 passed, 258 skipped in 2504.27s (0:41:44) ==================================================================================================== ```  Pull Request resolved: https://github.com/pytorch/pytorch/pull/157770 Approved by: https://github.com/soulitzer Co-authored-by: Jeffrey Wan <soulitzer@gmail.com>
This commit is contained in:
parent
752f202ef3
commit
ed03492238
|
|
@ -5640,6 +5640,11 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
|
|||
):
|
||||
torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Expected jagged_dim >=1, but got 0."):
|
||||
torch.nested.nested_tensor_from_jagged(
|
||||
values, lengths=lengths, jagged_dim=0
|
||||
)
|
||||
|
||||
@onlyCPU
|
||||
def test_nested_tensor_from_jagged_fx_trace(self, device):
|
||||
def fn(x, y):
|
||||
|
|
|
|||
|
|
@ -392,8 +392,8 @@ def nested_tensor_from_jagged(
|
|||
offsets (optional :class:`torch.Tensor`): Offsets into the jagged dimension of shape B + 1.
|
||||
lengths (optional :class:`torch.Tensor`): Lengths of the batch elements of shape B.
|
||||
jagged_dim (optional int): Indicates which dimension in values is the packed jagged
|
||||
dimension. If None, this is set to dim=1 (i.e. the dimension immediately following
|
||||
the batch dimension). Default: None
|
||||
dimension. Must be >= 1 as the batch dimension (dim=0) cannot be ragged.
|
||||
If None, this is set to dim=1 (i.e. the dimension immediately following the batch dimension). Default: None
|
||||
min_seqlen (optional int): If set, uses the specified value as the cached minimum sequence
|
||||
length for the returned nested tensor. This can be a useful alternative to computing
|
||||
this value on-demand, possibly avoiding a GPU -> CPU sync. Default: None
|
||||
|
|
@ -450,6 +450,8 @@ def nested_tensor_from_jagged(
|
|||
|
||||
if jagged_dim is None:
|
||||
jagged_dim = 1
|
||||
elif jagged_dim < 1:
|
||||
raise ValueError(f"Expected jagged_dim >=1, but got {jagged_dim}.")
|
||||
|
||||
from torch.nested._internal.nested_tensor import (
|
||||
nested_view_from_values_offsets_lengths,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user