[jit] support pad_sequence/pack_sequence (#39844)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39844

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D22026720

Pulled By: wanchaol

fbshipit-source-id: cc51ea77eff3689e319ec7e89a54c788646b5940
Commit 4b028a8e07 (parent 4f761f325c)
Author: Wanchao Liang, 2020-06-19 19:01:26 -07:00; committed by Facebook GitHub Bot

3 changed files with 54 additions and 28 deletions
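What the change enables, as a minimal sketch: once pad_sequence, pack_sequence, and pad_packed_sequence are scriptable, functions that call them can be compiled with torch.jit.script. The wrapper functions below (pad_batch, pack_then_unpack) are illustrative only and are not part of this PR.

from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence, pack_sequence, pad_packed_sequence


@torch.jit.script
def pad_batch(tensor_list: List[torch.Tensor]) -> torch.Tensor:
    # Stack variable-length tensors into a single padded (batch, max_len) tensor.
    return pad_sequence(tensor_list, True, 0.0)


@torch.jit.script
def pack_then_unpack(tensor_list: List[torch.Tensor]) -> torch.Tensor:
    # Round-trip through PackedSequence; inputs must be sorted by decreasing
    # length because enforce_sorted is passed as True.
    packed = pack_sequence(tensor_list, True)
    padded, lengths = pad_packed_sequence(packed)
    return padded


print(pad_batch([torch.tensor([1., 2., 3.]), torch.tensor([4., 5.])]))
print(pack_then_unpack([torch.tensor([1., 2., 3.]), torch.tensor([4., 5.])]))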


@@ -62,10 +62,7 @@ C10_DEFINE_bool(
     "Whether to print performance stats for AI-PEP.");
 C10_DEFINE_int(pytext_len, 0, "Length of input sequence.");
-C10_DEFINE_bool(
-    vulkan,
-    false,
-    "Whether to use Vulkan backend (GPU).");
+C10_DEFINE_bool(vulkan, false, "Whether to use Vulkan backend (GPU).");
 std::vector<std::string>
 split(char separator, const std::string& string, bool ignore_empty = true) {

test/test_jit.py

@@ -2713,27 +2713,6 @@ class TestScript(JitTestCase):
         with self.assertRaises(RuntimeError):
             m.foo = 6

-    def test_script_packedsequence(self):
-        class ExperimentalLSTM(torch.nn.Module):
-            def __init__(self, input_dim, hidden_dim):
-                super().__init__()
-
-            def forward(self, input):
-                # type: (Tensor)
-                packed = torch.nn.utils.rnn.pack_padded_sequence(
-                    input=input, lengths=torch.tensor([1, 2]), enforce_sorted=False
-                )
-                output, lengths = torch.nn.utils.rnn.pad_packed_sequence(
-                    sequence=packed, total_length=2
-                )
-                # lengths is flipped, so is output
-                return output[0]
-
-        lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2)
-
-        with torch.jit._disable_emit_hooks():
-            self.checkModule(lstm, [torch.ones(2, 2)])
-
     def test_class_attribute(self):
         class M(torch.jit.ScriptModule):
             FOO = 0
@@ -9088,6 +9067,55 @@ a")
         self.assertEqual(eager_seq, script_seq)
         self.assertEqual(eager_lengths, script_lengths)

+        class ExperimentalLSTM(torch.nn.Module):
+            def __init__(self, input_dim, hidden_dim):
+                super().__init__()
+
+            def forward(self, input):
+                # type: (Tensor)
+                packed = pack_padded_sequence(
+                    input=input, lengths=torch.tensor([1, 2]), enforce_sorted=False
+                )
+                output, lengths = pad_packed_sequence(
+                    sequence=packed, total_length=2
+                )
+                # lengths is flipped, so is output
+                return output[0]
+
+        lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2)
+
+        with torch.jit._disable_emit_hooks():
+            self.checkModule(lstm, [torch.ones(2, 2)])
+
+    def test_script_pad_sequence_pack_sequence(self):
+        from torch.nn.utils.rnn import pad_sequence, pack_sequence, pad_packed_sequence
+
+        def pad_sequence_func(tensor_list, batch_first=False, padding_value=0.0):
+            # type: (List[Tensor], bool, float) -> Tensor
+            return pad_sequence(tensor_list, batch_first, padding_value)
+
+        def pack_sequence_func(tensor_list, enforce_sorted=True):
+            # type: (List[Tensor], bool) -> Tensor
+            return pad_packed_sequence(pack_sequence(tensor_list, enforce_sorted))[0]
+
+        ones3 = torch.ones(3, 5)
+        ones4 = torch.ones(4, 5)
+        ones5 = torch.ones(5, 5)
+        tensor1 = torch.tensor([1, 2, 3])
+        tensor2 = torch.tensor([4, 5])
+        tensor3 = torch.tensor([6])
+        with torch.jit._disable_emit_hooks():
+            self.checkScript(pad_sequence_func,
+                             ([ones3, ones4, ones5],))
+            self.checkScript(pad_sequence_func,
+                             ([ones3, ones4, ones5], True))
+            self.checkScript(pad_sequence_func,
+                             ([ones3, ones4, ones5], True, 2.5))
+            self.checkScript(pack_sequence_func,
+                             ([tensor1, tensor2, tensor3],))
+            self.checkScript(pack_sequence_func,
+                             ([tensor1, tensor2, tensor3], False))
+
     def test_script_get_tracing_state(self):
         def test_if_tracing(x):
             if torch._C._get_tracing_state():

torch/nn/utils/rnn.py

@@ -315,7 +315,8 @@ def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_le
     return padded_output, lengths


-def pad_sequence(sequences, batch_first=False, padding_value=0):
+def pad_sequence(sequences, batch_first=False, padding_value=0.0):
+    # type: (List[Tensor], bool, float) -> Tensor
     r"""Pad a list of variable length Tensors with ``padding_value``

     ``pad_sequence`` stacks a list of Tensors along a new dimension,
@@ -362,7 +363,7 @@ def pad_sequence(sequences, batch_first=False, padding_value=0):
     else:
         out_dims = (max_len, len(sequences)) + trailing_dims

-    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
+    out_tensor = sequences[0].new_full(out_dims, padding_value)
     for i, tensor in enumerate(sequences):
         length = tensor.size(0)
         # use index notation to prevent duplicate references to the tensor
@@ -405,5 +406,5 @@ def pack_sequence(sequences, enforce_sorted=True):
     Returns:
         a :class:`PackedSequence` object
     """
-    lengths = [v.size(0) for v in sequences]
+    lengths = torch.as_tensor([v.size(0) for v in sequences])
     return pack_padded_sequence(pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted)
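
For reference, a quick eager-mode sketch of the padding behavior described in the pad_sequence docstring above; the example values are illustrative and the expected output is shown in a comment.

import torch
from torch.nn.utils.rnn import pad_sequence

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
# Shorter sequences are padded with padding_value up to the longest length.
print(pad_sequence([a, b], batch_first=True, padding_value=0.0))
# tensor([[1, 2, 3],
#         [4, 5, 0]])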