mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
DataLoader now supports the constructor argument 'pin_memory'. When set to true, tensors in the sample are copied to pinned memory. This happens in a background thread when num_workers > 1.
150 lines
5.4 KiB
Python
150 lines
5.4 KiB
Python
import math
|
|
import sys
|
|
import torch
|
|
import traceback
|
|
import unittest
|
|
from torch.utils.data import Dataset, TensorDataset, DataLoader
|
|
from common import TestCase
|
|
|
|
|
|
class TestTensorDataset(TestCase):
    """Sanity checks for TensorDataset: length reporting and per-index
    access for both multi-dimensional and 1-d backing tensors."""

    def test_len(self):
        # len() must reflect the size of the first (sample) dimension.
        dataset = TensorDataset(torch.randn(15, 10, 2, 3, 4, 5), torch.randperm(15))
        self.assertEqual(len(dataset), 15)

    def test_getitem(self):
        data = torch.randn(15, 10, 2, 3, 4, 5)
        targets = torch.randn(15, 10)
        dataset = TensorDataset(data, targets)
        # Each index must yield the matching (data, target) pair.
        for idx in range(15):
            self.assertEqual(data[idx], dataset[idx][0])
            self.assertEqual(targets[idx], dataset[idx][1])

    def test_getitem_1d(self):
        data = torch.randn(15)
        targets = torch.randn(15)
        dataset = TensorDataset(data, targets)
        # 1-d tensors index as length-1 slices, hence the [i:i+1] comparison.
        for idx in range(15):
            self.assertEqual(data[idx:idx + 1], dataset[idx][0])
            self.assertEqual(targets[idx:idx + 1], dataset[idx][1])
class ErrorDataset(Dataset):
    """A dataset that reports a length but deliberately leaves
    __getitem__ unimplemented, so every element access raises
    NotImplementedError — used to exercise DataLoader error handling."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        # Only the declared length is implemented; indexing is meant to fail.
        return self.size
|
class TestDataLoader(TestCase):
    """Tests for DataLoader: sequential and shuffled iteration, batching,
    multi-process workers, pinned memory, collate error propagation, and
    worker shutdown when iteration stops early."""

    def setUp(self):
        self.data = torch.randn(100, 2, 3, 5)
        self.labels = torch.randperm(50).repeat(2)
        self.dataset = TensorDataset(self.data, self.labels)

    def _test_sequential(self, loader):
        # A non-shuffling loader must yield batches in dataset order.
        batch_size = loader.batch_size
        for i, (sample, target) in enumerate(loader):
            idx = i * batch_size
            self.assertEqual(sample, self.data[idx:idx + batch_size])
            self.assertEqual(target, self.labels[idx:idx + batch_size].view(-1, 1))
        # The loop index must show every sample was delivered exactly once.
        self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

    def _test_shuffle(self, loader):
        # Track how many times each data point / label index was seen.
        found_data = {i: 0 for i in range(self.data.size(0))}
        found_labels = {i: 0 for i in range(self.labels.size(0))}
        batch_size = loader.batch_size
        for i, (batch_samples, batch_targets) in enumerate(loader):
            for sample, target in zip(batch_samples, batch_targets):
                # Locate the original index of this (shuffled) sample.
                for data_point_idx, data_point in enumerate(self.data):
                    if data_point.eq(sample).all():
                        # Each sample must appear at most once per epoch.
                        self.assertFalse(found_data[data_point_idx])
                        found_data[data_point_idx] += 1
                        break
                # The target must travel together with its sample.
                self.assertEqual(target, self.labels.narrow(0, data_point_idx, 1))
                found_labels[data_point_idx] += 1
            self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
            self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
        self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

    def _test_error(self, loader):
        # Every batch must raise (ErrorDataset has no __getitem__), and the
        # error message must point the user at collate_fn.
        it = iter(loader)
        errors = 0
        while True:
            try:
                # NOTE: was `it.next()` (Python-2-only); `next(it)` works on
                # both Python 2 and 3.
                next(it)
            except NotImplementedError:
                msg = "".join(traceback.format_exception(*sys.exc_info()))
                self.assertTrue("collate_fn" in msg)
                errors += 1
            except StopIteration:
                # One error per batch, including a possible partial last batch.
                self.assertEqual(errors,
                                 math.ceil(float(len(loader.dataset)) / loader.batch_size))
                return

    def test_sequential(self):
        self._test_sequential(DataLoader(self.dataset))

    def test_sequential_batch(self):
        self._test_sequential(DataLoader(self.dataset, batch_size=2))

    def test_sequential_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        # `inp` rather than `input` to avoid shadowing the builtin.
        for inp, target in loader:
            self.assertTrue(inp.is_pinned())
            self.assertTrue(target.is_pinned())

    def test_shuffle(self):
        self._test_shuffle(DataLoader(self.dataset, shuffle=True))

    def test_shuffle_batch(self):
        self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True))

    def test_sequential_workers(self):
        self._test_sequential(DataLoader(self.dataset, num_workers=4))

    def test_sequential_batch_workers(self):
        # Renamed from the misspelled `test_seqential_batch_workers`.
        self._test_sequential(DataLoader(self.dataset, batch_size=2, num_workers=4))

    def test_shuffle_workers(self):
        self._test_shuffle(DataLoader(self.dataset, shuffle=True, num_workers=4))

    def test_shuffle_batch_workers(self):
        self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4))

    def test_shuffle_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
        for inp, target in loader:
            self.assertTrue(inp.is_pinned())
            self.assertTrue(target.is_pinned())

    def test_error(self):
        self._test_error(DataLoader(ErrorDataset(100), batch_size=2, shuffle=True))

    def test_error_workers(self):
        self._test_error(DataLoader(ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4))

    def test_partial_workers(self):
        """Check that workers exit even if the iterator is not exhausted."""
        loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4, pin_memory=True))
        # Grab references before dropping the iterator so we can join them.
        # NOTE(review): relies on the loader-iterator exposing `workers` and
        # `pin_thread` attributes — confirm against the DataLoader version
        # under test.
        workers = loader.workers
        pin_thread = loader.pin_thread
        for i, sample in enumerate(loader):
            if i == 3:
                break
        del loader
        for w in workers:
            w.join(1.0)  # timeout of one second
            self.assertFalse(w.is_alive(), 'subprocess not terminated')
            self.assertEqual(w.exitcode, 0)
        pin_thread.join(1.0)
        self.assertFalse(pin_thread.is_alive())
|
if __name__ == '__main__':
    # Discover and run all TestCase classes in this module.
    unittest.main()