Fix device handling in nn.utils.rnn.unpad_sequence (#98042)
Without this change, I get the following error when `lengths` is on a CUDA device:
```
line 444, in unpad_sequence
mask = idx < length
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
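For context, a minimal sketch of the failure mode (the CUDA device, tensor shapes, and values below are illustrative, not taken from the PR):

```python
import torch
from torch.nn.utils.rnn import pad_sequence, unpad_sequence

# Hypothetical repro: both the padded batch and the lengths live on the GPU.
sequences = [torch.randn(n, 8, device="cuda") for n in (5, 3, 2)]
lengths = torch.tensor([5, 3, 2], device="cuda")
padded = pad_sequence(sequences, batch_first=True)

# Before this fix, unpad_sequence built `idx = torch.arange(max_length)` on the
# CPU, so the comparison `idx < length` mixed cuda:0 and cpu tensors and raised
# the RuntimeError shown above. The fix creates idx on lengths.device instead.
unpadded = unpad_sequence(padded, lengths, batch_first=True)
```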
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98042
Approved by: https://github.com/mikaylagawarecki
parent 1c21cd2213
commit fa1a8b9f96
```diff
@@ -440,7 +440,7 @@ def unpad_sequence(
         padded_sequences.transpose_(0, 1)
 
     max_length = padded_sequences.shape[1]
-    idx = torch.arange(max_length)
+    idx = torch.arange(max_length, device=lengths.device)
 
     for seq, length in zip(padded_sequences, lengths):
         mask = idx < length
```