[jit] support pad_sequence/pack_sequence (#39844)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39844

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D22026720

Pulled By: wanchaol

fbshipit-source-id: cc51ea77eff3689e319ec7e89a54c788646b5940
This commit is contained in:
parent 4f761f325c
commit 4b028a8e07
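For context, a minimal sketch of what this change enables (illustrative only, not part of the commit; the function names pad_then_flatten and pack_then_unpack are made up): after the change, functions that call torch.nn.utils.rnn.pad_sequence and pack_sequence can be compiled with torch.jit.script.

from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence, pack_sequence, pad_packed_sequence


@torch.jit.script
def pad_then_flatten(tensors: List[torch.Tensor]) -> torch.Tensor:
    # Pad variable-length 1-D tensors into one (max_len, batch) tensor, then flatten.
    return pad_sequence(tensors, batch_first=False, padding_value=0.0).flatten()


@torch.jit.script
def pack_then_unpack(tensors: List[torch.Tensor]) -> torch.Tensor:
    # Round-trip through a PackedSequence and return the padded output.
    packed = pack_sequence(tensors, enforce_sorted=True)
    output, lengths = pad_packed_sequence(packed)
    return output


seqs = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0]), torch.tensor([6.0])]
print(pad_then_flatten(seqs))
print(pack_then_unpack(seqs))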
@@ -62,10 +62,7 @@ C10_DEFINE_bool(
     "Whether to print performance stats for AI-PEP.");
 
 C10_DEFINE_int(pytext_len, 0, "Length of input sequence.");
-C10_DEFINE_bool(
-    vulkan,
-    false,
-    "Whether to use Vulkan backend (GPU).");
+C10_DEFINE_bool(vulkan, false, "Whether to use Vulkan backend (GPU).");
 
 std::vector<std::string>
 split(char separator, const std::string& string, bool ignore_empty = true) {
@@ -2713,27 +2713,6 @@ class TestScript(JitTestCase):
         with self.assertRaises(RuntimeError):
             m.foo = 6
 
-    def test_script_packedsequence(self):
-        class ExperimentalLSTM(torch.nn.Module):
-            def __init__(self, input_dim, hidden_dim):
-                super().__init__()
-
-            def forward(self, input):
-                # type: (Tensor)
-                packed = torch.nn.utils.rnn.pack_padded_sequence(
-                    input=input, lengths=torch.tensor([1, 2]), enforce_sorted=False
-                )
-                output, lengths = torch.nn.utils.rnn.pad_packed_sequence(
-                    sequence=packed, total_length=2
-                )
-                # lengths is flipped, so is output
-                return output[0]
-
-        lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2)
-
-        with torch.jit._disable_emit_hooks():
-            self.checkModule(lstm, [torch.ones(2, 2)])
-
     def test_class_attribute(self):
         class M(torch.jit.ScriptModule):
             FOO = 0
@@ -9088,6 +9067,55 @@ a")
         self.assertEqual(eager_seq, script_seq)
         self.assertEqual(eager_lengths, script_lengths)
 
+        class ExperimentalLSTM(torch.nn.Module):
+            def __init__(self, input_dim, hidden_dim):
+                super().__init__()
+
+            def forward(self, input):
+                # type: (Tensor)
+                packed = pack_padded_sequence(
+                    input=input, lengths=torch.tensor([1, 2]), enforce_sorted=False
+                )
+                output, lengths = pad_packed_sequence(
+                    sequence=packed, total_length=2
+                )
+                # lengths is flipped, so is output
+                return output[0]
+
+        lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2)
+
+        with torch.jit._disable_emit_hooks():
+            self.checkModule(lstm, [torch.ones(2, 2)])
+
+    def test_script_pad_sequence_pack_sequence(self):
+        from torch.nn.utils.rnn import pad_sequence, pack_sequence, pad_packed_sequence
+
+        def pad_sequence_func(tensor_list, batch_first=False, padding_value=0.0):
+            # type: (List[Tensor], bool, float) -> Tensor
+            return pad_sequence(tensor_list, batch_first, padding_value)
+
+        def pack_sequence_func(tensor_list, enforce_sorted=True):
+            # type: (List[Tensor], bool) -> Tensor
+            return pad_packed_sequence(pack_sequence(tensor_list, enforce_sorted))[0]
+
+        ones3 = torch.ones(3, 5)
+        ones4 = torch.ones(4, 5)
+        ones5 = torch.ones(5, 5)
+        tensor1 = torch.tensor([1, 2, 3])
+        tensor2 = torch.tensor([4, 5])
+        tensor3 = torch.tensor([6])
+        with torch.jit._disable_emit_hooks():
+            self.checkScript(pad_sequence_func,
+                             ([ones3, ones4, ones5],))
+            self.checkScript(pad_sequence_func,
+                             ([ones3, ones4, ones5], True))
+            self.checkScript(pad_sequence_func,
+                             ([ones3, ones4, ones5], True, 2.5))
+            self.checkScript(pack_sequence_func,
+                             ([tensor1, tensor2, tensor3],))
+            self.checkScript(pack_sequence_func,
+                             ([tensor1, tensor2, tensor3], False))
+
     def test_script_get_tracing_state(self):
         def test_if_tracing(x):
             if torch._C._get_tracing_state():
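The wrappers added above mirror what checkScript exercises: the scripted function should return exactly what the eager function returns. A rough stand-alone equivalent of one of these checks (the comparison below is an illustrative assumption about what checkScript verifies, not its actual implementation; pad_sequence_func is reused from the test):

from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence


def pad_sequence_func(tensor_list: List[torch.Tensor],
                      batch_first: bool = False,
                      padding_value: float = 0.0) -> torch.Tensor:
    return pad_sequence(tensor_list, batch_first, padding_value)


scripted = torch.jit.script(pad_sequence_func)
inputs = [torch.ones(3, 5), torch.ones(4, 5), torch.ones(5, 5)]
eager_out = pad_sequence_func(inputs, True, 2.5)
script_out = scripted(inputs, True, 2.5)
# Eager and scripted results should be identical.
assert torch.equal(eager_out, script_out)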
@@ -315,7 +315,8 @@ def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_le
     return padded_output, lengths
 
 
-def pad_sequence(sequences, batch_first=False, padding_value=0):
+def pad_sequence(sequences, batch_first=False, padding_value=0.0):
+    # type: (List[Tensor], bool, float) -> Tensor
     r"""Pad a list of variable length Tensors with ``padding_value``
 
     ``pad_sequence`` stacks a list of Tensors along a new dimension,
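The two-line change above gives pad_sequence a TorchScript signature: the MyPy-style type comment declares the argument types, and the default presumably changes from 0 to 0.0 so the default literal agrees with the declared float. A tiny sketch of the same pattern (illustrative only, not library code; the function name scaled is made up):

import torch


def scaled(x, scale=1.0):
    # type: (torch.Tensor, float) -> torch.Tensor
    # The float literal default matches the declared float type.
    return x * scale


scripted = torch.jit.script(scaled)
print(scripted(torch.ones(2), 2.5))
print(scripted(torch.ones(2)))  # default scale is used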
@@ -362,7 +363,7 @@ def pad_sequence(sequences, batch_first=False, padding_value=0):
     else:
         out_dims = (max_len, len(sequences)) + trailing_dims
 
-    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
+    out_tensor = sequences[0].new_full(out_dims, padding_value)
     for i, tensor in enumerate(sequences):
         length = tensor.size(0)
         # use index notation to prevent duplicate references to the tensor
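The replaced line above swaps the legacy `.data.new(*sizes).fill_(value)` construction for Tensor.new_full, which allocates a tensor of the given size filled with the value while keeping the source tensor's dtype and device; presumably the legacy form was dropped because it does not compile under TorchScript. A small equivalence check (purely illustrative):

import torch

src = torch.arange(6, dtype=torch.float32).reshape(2, 3)
out_dims = (4, 2, 3)
padding_value = 2.5

old_style = src.data.new(*out_dims).fill_(padding_value)  # legacy, eager-only construction
new_style = src.new_full(out_dims, padding_value)         # same result via new_full

assert torch.equal(old_style, new_style)
assert new_style.dtype == src.dtype and new_style.device == src.device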
@@ -405,5 +406,5 @@ def pack_sequence(sequences, enforce_sorted=True):
     Returns:
         a :class:`PackedSequence` object
     """
-    lengths = [v.size(0) for v in sequences]
+    lengths = torch.as_tensor([v.size(0) for v in sequences])
     return pack_padded_sequence(pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted)
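With the last change, pack_sequence builds its lengths as a Tensor rather than a Python list before handing them to pack_padded_sequence; the tensor form matches the Tensor type that the scripted signature of pack_padded_sequence expects, while eager mode accepts either. A quick illustration of the resulting packed structure (illustrative only, not library code):

import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

sequences = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0]), torch.tensor([6.0])]
lengths = torch.as_tensor([v.size(0) for v in sequences])  # tensor([3, 2, 1])
packed = pack_padded_sequence(pad_sequence(sequences), lengths)

print(packed.data)         # time-major packed values: tensor([1., 4., 6., 2., 5., 3.])
print(packed.batch_sizes)  # tensor([3, 2, 1])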