diff --git a/test/test_jit.py b/test/test_jit.py
index 160dc291fc2..75fe4d58e60 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -4353,26 +4353,6 @@ class TestScript(JitTestCase):
         with self.assertRaises(RuntimeError):
             m.foo = 6

-    def test_script_packedsequence(self):
-        class ExperimentalLSTM(torch.nn.Module):
-            def __init__(self, input_dim, hidden_dim):
-                super().__init__()
-
-            def forward(self, input):
-                # type: (Tensor)
-                packed = torch.nn.utils.rnn.pack_padded_sequence(
-                    input=input, lengths=torch.tensor([1, 2]), enforce_sorted=False
-                )
-                output, lengths = torch.nn.utils.rnn.pad_packed_sequence(
-                    sequence=packed, total_length=2
-                )
-                # lengths is flipped, so is output
-                return output[0]
-
-        lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2)
-
-        with torch.jit._disable_emit_hooks():
-            self.checkModule(lstm, [torch.ones(2, 2)])
-
     def test_class_attribute(self):
         class M(torch.jit.ScriptModule):
@@ -11062,8 +11042,7 @@ a")
                x[seq_lens[b]:, b, :] = 0

            eager_seq, eager_lengths = pack_padded_pad_packed_script(x, seq_lens)
-            with torch.jit._disable_emit_hooks():
-                scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script)
+            scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script)
            script_seq, script_lengths = scripted_pack_padded_seq(x, seq_lens)
            self.assertEqual(eager_seq, script_seq)
            self.assertEqual(eager_lengths, script_lengths)
diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py
index 18503610259..4934a460fd6 100644
--- a/torch/nn/utils/rnn.py
+++ b/torch/nn/utils/rnn.py
@@ -56,16 +56,42 @@ class PackedSequence(PackedSequence_):
         (i.e., they only pass in tensors conforming to this constraint).
     """

-    def __new__(cls, data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
-        return super(PackedSequence, cls).__new__(
-            cls,
-            *_packed_sequence_init_args(data, batch_sizes, sorted_indices,
-                                        unsorted_indices))

     # NOTE [ device and dtype of a PackedSequence ]
     #
     # See the note above in doc string (starting with ":attr:`data` can be on
     # arbitrary device...").
+
+    def __new__(cls, data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
+        # PackedSequence used to only have __init__(self, data, batch_sizes)
+        # without a __new__ like this. So to preserve BC for calling in keyword
+        # arg style (e.g., `PackedSequence(data=..., batch_sizes=...)`), we have
+        # to provide two arguments with exact names `data` and `batch_sizes`.
+
+        # NB: if unsorted_indices is provided, it should be the inverse permutation
+        # to sorted_indices. Don't assert it here because the PackedSequence ctor
+        # should only be used internally.
+        if unsorted_indices is None:
+            unsorted_indices = invert_permutation(sorted_indices)
+
+        # support being called as `PackedSequence(data, batch_sizes, sorted_indices)`
+        if batch_sizes is not None:
+            if batch_sizes.device.type != 'cpu':
+                raise ValueError(
+                    "batch_sizes should always be on CPU. "
+                    "Instances of PackedSequence should never be created manually. "
+                    "They should be instantiated by functions like pack_sequence "
+                    "and pack_padded_sequences in nn.utils.rnn. "
+                    "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence")
+            return super(PackedSequence, cls).__new__(
+                cls, data, batch_sizes, sorted_indices, unsorted_indices)
+
+        # support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)`
+        else:
+            assert isinstance(data, (list, tuple)) and len(data) == 2
+            return super(PackedSequence, cls).__new__(
+                cls, data[0], data[1], sorted_indices)
+
     def pin_memory(self):
         # Why not convert `batch_sizes`?
         # See NOTE [ device and dtype of a PackedSequence ]
@@ -147,44 +173,7 @@ class PackedSequence(PackedSequence_):
         return self.data.is_pinned()


-# TorchScript doesn't support constructors on named tuples, so we use this helper
-# method to construct PackedSequence
-def _packed_sequence_init_args(data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
-    # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]
-    # NB: if unsorted_indices is provided, it should be the inverse permutation
-    # to sorted_indices. Don't assert it here because the PackedSequence ctor
-    # should only be used internally.
-
-    if unsorted_indices is None:
-        unsorted_indices = invert_permutation(sorted_indices)
-
-    # support being called as `PackedSequence(data, batch_sizes, sorted_indices)`
-    if batch_sizes is not None:
-        # TODO: Re-enable this check (.type isn't supported in TorchScript)
-        if batch_sizes.device.type != 'cpu':
-            raise ValueError(
-                "batch_sizes should always be on CPU. "
-                "Instances of PackedSequence should never be created manually. "
-                "They should be instantiated by functions like pack_sequence "
-                "and pack_padded_sequences in nn.utils.rnn. "
-                "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence")
-        return data, batch_sizes, sorted_indices, unsorted_indices
-
-    # support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)`
-    else:
-        assert isinstance(data, (list, tuple)) and len(data) == 2
-        return data[0], data[1], sorted_indices, unsorted_indices
-
-
-def _packed_sequence_init(data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
-    # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> PackedSequence
-    data, batch_sizes, sorted_indices, unsorted_indices = _packed_sequence_init_args(
-        data, batch_sizes, sorted_indices, unsorted_indices)
-    return PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices)
-
-
 def invert_permutation(permutation):
-    # type: (Optional[Tensor]) -> Optional[Tensor]
     if permutation is None:
         return None
     output = torch.empty_like(permutation, memory_format=torch.legacy_contiguous_format)
@@ -242,7 +231,7 @@ def pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)

     data, batch_sizes = \
         _VF._pack_padded_sequence(input, lengths, batch_first)
-    return _packed_sequence_init(data, batch_sizes, sorted_indices, None)
+    return PackedSequence(data, batch_sizes, sorted_indices, None)


 def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None):