Revert D20103905: [jit] Fix flipped PackedSequence outputs in script

Test Plan: revert-hammer

Differential Revision: D20103905

Original commit changeset: 84081213ed21

fbshipit-source-id: 2b260654fac87e52fbaf8035018e4ea484928af1
Brian Vaughan 2020-02-27 13:26:35 -08:00 committed by Facebook Github Bot
parent a7cf5c859f
commit 243af17d65
2 changed files with 33 additions and 65 deletions


@@ -4353,26 +4353,6 @@ class TestScript(JitTestCase):
with self.assertRaises(RuntimeError):
m.foo = 6
def test_script_packedsequence(self):
class ExperimentalLSTM(torch.nn.Module):
def __init__(self, input_dim, hidden_dim):
super().__init__()
def forward(self, input):
# type: (Tensor) -> Tensor
packed = torch.nn.utils.rnn.pack_padded_sequence(
input=input, lengths=torch.tensor([1, 2]), enforce_sorted=False
)
output, lengths = torch.nn.utils.rnn.pad_packed_sequence(
sequence=packed, total_length=2
)
# lengths is flipped, so is output
return output[0]
lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2)
with torch.jit._disable_emit_hooks():
self.checkModule(lstm, [torch.ones(2, 2)])
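# Illustrative sketch (not part of this commit): the eager behaviour the test above
# scripts via checkModule. With enforce_sorted=False the batch is re-ordered by
# decreasing length internally (lengths [1, 2] become [2, 1], hence "flipped"), and
# sorted_indices/unsorted_indices record that re-ordering so pad_packed_sequence can
# map the result back to the original batch order.
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

x = torch.ones(2, 2)  # same (seq_len=2, batch=2) input the test feeds the module
packed = pack_padded_sequence(x, lengths=torch.tensor([1, 2]), enforce_sorted=False)
print(packed.sorted_indices, packed.unsorted_indices)  # tensor([1, 0]) and tensor([1, 0])
output, lengths = pad_packed_sequence(packed, total_length=2)
print(output.shape)  # torch.Size([2, 2]); the module above returns output[0]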
def test_class_attribute(self):
class M(torch.jit.ScriptModule):
@@ -11062,8 +11042,7 @@ a")
x[seq_lens[b]:, b, :] = 0
eager_seq, eager_lengths = pack_padded_pad_packed_script(x, seq_lens)
with torch.jit._disable_emit_hooks():
scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script)
scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script)
script_seq, script_lengths = scripted_pack_padded_seq(x, seq_lens)
self.assertEqual(eager_seq, script_seq)
self.assertEqual(eager_lengths, script_lengths)
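# Illustrative sketch (not part of this commit): an approximation of the
# pack_padded_pad_packed_script helper that the hunk above scripts; its real
# definition appears earlier in test_jit.py and may differ in detail. The test runs
# it eagerly and as a torch.jit.script function and asserts both paths agree.
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def pack_padded_pad_packed_script(x, seq_lens):
    # Pack a padded batch, then immediately unpack it again.
    packed = pack_padded_sequence(x, seq_lens)
    padded, lengths = pad_packed_sequence(packed)
    return padded, lengths

x = torch.randn(5, 3, 4)            # (seq_len, batch, features)
seq_lens = torch.tensor([5, 3, 2])  # descending, since enforce_sorted defaults to True
eager_seq, eager_lengths = pack_padded_pad_packed_script(x, seq_lens)
scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script)
script_seq, script_lengths = scripted_pack_padded_seq(x, seq_lens)
assert torch.equal(eager_seq, script_seq)
assert torch.equal(eager_lengths, script_lengths)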


@@ -56,16 +56,42 @@ class PackedSequence(PackedSequence_):
(i.e., they only pass in tensors conforming to this constraint).
"""
def __new__(cls, data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
return super(PackedSequence, cls).__new__(
cls,
*_packed_sequence_init_args(data, batch_sizes, sorted_indices,
unsorted_indices))
# NOTE [ device and dtype of a PackedSequence ]
#
# See the note above in doc string (starting with ":attr:`data` can be on
# arbitrary device...").
def __new__(cls, data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
# PackedSequence used to only have __init__(self, data, batch_sizes)
# without a __new__ like this. So to preserve BC for calling in keyword
# arg style (e.g., `PackedSequence(data=..., batch_sizes=...)`), we have
# to provide two arguments with exact names `data` and `batch_sizes`.
# NB: if unsorted_indices is provided, it should be the inverse permutation
# to sorted_indices. Don't assert it here because the PackedSequence ctor
# should only be used internally.
if unsorted_indices is None:
unsorted_indices = invert_permutation(sorted_indices)
# support being called as `PackedSequence(data, batch_sizes, sorted_indices)`
if batch_sizes is not None:
if batch_sizes.device.type != 'cpu':
raise ValueError(
"batch_sizes should always be on CPU. "
"Instances of PackedSequence should never be created manually. "
"They should be instantiated by functions like pack_sequence "
"and pack_padded_sequences in nn.utils.rnn. "
"https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence")
return super(PackedSequence, cls).__new__(
cls, data, batch_sizes, sorted_indices, unsorted_indices)
# support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)`
else:
assert isinstance(data, (list, tuple)) and len(data) == 2
return super(PackedSequence, cls).__new__(
cls, data[0], data[1], sorted_indices)
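# Illustrative sketch (not part of this commit) of the keyword-argument call form the
# comments above keep for backward compatibility, and of unsorted_indices being filled
# in via invert_permutation(sorted_indices) when it is omitted. Real code should obtain
# PackedSequence objects from pack_padded_sequence / pack_sequence instead.
import torch
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence

ps = pack_padded_sequence(torch.randn(4, 2, 3), torch.tensor([2, 4]), enforce_sorted=False)
rebuilt = PackedSequence(data=ps.data, batch_sizes=ps.batch_sizes,
                         sorted_indices=ps.sorted_indices)
assert torch.equal(rebuilt.unsorted_indices, ps.unsorted_indices)
# batch_sizes must stay on the CPU; a CUDA batch_sizes tensor raises the ValueError above.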
def pin_memory(self):
# Why not convert `batch_sizes`?
# See NOTE [ device and dtype of a PackedSequence ]
@@ -147,44 +173,7 @@ class PackedSequence(PackedSequence_):
return self.data.is_pinned()
# TorchScript doesn't support constructors on named tuples, so we use this helper
# method to construct PackedSequence
def _packed_sequence_init_args(data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]
# NB: if unsorted_indices is provided, it should be the inverse permutation
# to sorted_indices. Don't assert it here because the PackedSequence ctor
# should only be used internally.
if unsorted_indices is None:
unsorted_indices = invert_permutation(sorted_indices)
# support being called as `PackedSequence(data, batch_sizes, sorted_indices)`
if batch_sizes is not None:
# TODO: Re-enable this check (.type isn't supported in TorchScript)
if batch_sizes.device.type != 'cpu':
raise ValueError(
"batch_sizes should always be on CPU. "
"Instances of PackedSequence should never be created manually. "
"They should be instantiated by functions like pack_sequence "
"and pack_padded_sequences in nn.utils.rnn. "
"https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence")
return data, batch_sizes, sorted_indices, unsorted_indices
# support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)`
else:
assert isinstance(data, (list, tuple)) and len(data) == 2
return data[0], data[1], sorted_indices, unsorted_indices
def _packed_sequence_init(data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> PackedSequence
data, batch_sizes, sorted_indices, unsorted_indices = _packed_sequence_init_args(
data, batch_sizes, sorted_indices, unsorted_indices)
return PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices)
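# Hypothetical sketch (the names Pair, _pair_args and make_pair are invented here) of
# the pattern behind the helpers above: TorchScript doesn't support constructors on
# named tuples, so the argument massaging lives in a free function that scripted code
# can call, while plain NamedTuple construction itself stays scriptable.
import torch
from typing import NamedTuple, Optional, Tuple

class Pair(NamedTuple):  # stand-in for PackedSequence_
    first: torch.Tensor
    second: torch.Tensor

def _pair_args(a: torch.Tensor, b: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    # Normalize the accepted call forms into plain fields, like _packed_sequence_init_args.
    if b is None:
        return a, a
    return a, b

@torch.jit.script
def make_pair(a: torch.Tensor, b: torch.Tensor) -> Pair:
    first, second = _pair_args(a, b)  # the helper is compiled recursively
    return Pair(first, second)        # constructing the NamedTuple itself is scriptable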
def invert_permutation(permutation):
# type: (Optional[Tensor]) -> Optional[Tensor]
if permutation is None:
return None
output = torch.empty_like(permutation, memory_format=torch.legacy_contiguous_format)
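# Illustrative check (not part of this commit) of the inverse-permutation relationship
# the comments above rely on: scattering arange through a permutation yields its
# inverse, so inverse[permutation[i]] == i, which is how unsorted_indices undoes
# sorted_indices.
import torch

permutation = torch.tensor([2, 0, 1])  # e.g. a sorted_indices tensor
inverse = torch.empty_like(permutation)
inverse.scatter_(0, permutation, torch.arange(permutation.numel()))
print(inverse)  # tensor([1, 2, 0])
assert torch.equal(inverse[permutation], torch.arange(3))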
@@ -242,7 +231,7 @@ def pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)
data, batch_sizes = \
_VF._pack_padded_sequence(input, lengths, batch_first)
return _packed_sequence_init(data, batch_sizes, sorted_indices, None)
return PackedSequence(data, batch_sizes, sorted_indices, None)
def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None):