import math
import sys
import errno
import os
import ctypes
import signal
import torch
import time
import traceback
import unittest
from torch import multiprocessing
from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
from torch.utils.data.dataset import random_split
from torch.utils.data.dataloader import default_collate, ExceptionWrapper
from common import TestCase, run_tests, TEST_NUMPY, IS_WINDOWS
from common_nn import TEST_CUDA


JOIN_TIMEOUT = 17.0 if IS_WINDOWS else 4.5


class TestDatasetRandomSplit(TestCase):
    def test_lengths_must_equal_dataset_size(self):
        with self.assertRaises(ValueError):
            random_split([1, 2, 3, 4], [1, 2])

    def test_splits_have_correct_size(self):
        splits = random_split([1, 2, 3, 4, 5, 6], [2, 4])
        self.assertEqual(len(splits), 2)
        self.assertEqual(len(splits[0]), 2)
        self.assertEqual(len(splits[1]), 4)

    def test_splits_are_mutually_exclusive(self):
        data = [5, 2, 3, 4, 1, 6]
        splits = random_split(data, [2, 4])
        all_values = []
        all_values.extend(list(splits[0]))
        all_values.extend(list(splits[1]))
        data.sort()
        all_values.sort()
        self.assertListEqual(data, all_values)


class TestTensorDataset(TestCase):

    def test_len(self):
        source = TensorDataset(torch.randn(15, 10, 2, 3, 4, 5), torch.randperm(15))
        self.assertEqual(len(source), 15)

    def test_getitem(self):
        t = torch.randn(15, 10, 2, 3, 4, 5)
        l = torch.randn(15, 10)
        source = TensorDataset(t, l)
        for i in range(15):
            self.assertEqual(t[i], source[i][0])
            self.assertEqual(l[i], source[i][1])

    def test_getitem_1d(self):
        t = torch.randn(15)
        l = torch.randn(15)
        source = TensorDataset(t, l)
        for i in range(15):
            self.assertEqual(t[i], source[i][0])
            self.assertEqual(l[i], source[i][1])


class TestConcatDataset(TestCase):

    def test_concat_two_singletons(self):
        result = ConcatDataset([[0], [1]])
        self.assertEqual(2, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(1, result[1])

    def test_concat_two_non_singletons(self):
        result = ConcatDataset([[0, 1, 2, 3, 4],
                                [5, 6, 7, 8, 9]])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

    def test_concat_two_non_singletons_with_empty(self):
        # Adding an empty dataset somewhere is correctly handled
        result = ConcatDataset([[0, 1, 2, 3, 4],
                                [],
                                [5, 6, 7, 8, 9]])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

    def test_concat_raises_index_error(self):
        result = ConcatDataset([[0, 1, 2, 3, 4],
                                [5, 6, 7, 8, 9]])
        with self.assertRaises(IndexError):
            # this one goes to 11
            result[11]

    def test_add_dataset(self):
        d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        d2 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        d3 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        result = d1 + d2 + d3
        self.assertEqual(21, len(result))
        self.assertEqual(0, (d1[0][0] - result[0][0]).abs().sum())
        self.assertEqual(0, (d2[0][0] - result[7][0]).abs().sum())
        self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum())


# Stores the first encountered exception in .exception.
# Inspired by https://stackoverflow.com/a/33599967
class ErrorTrackingProcess(multiprocessing.Process):

    def __init__(self, *args, **kwargs):
        super(ErrorTrackingProcess, self).__init__(*args, **kwargs)
        self._pconn, self._cconn = multiprocessing.Pipe()
        self._exception = None

    def run(self):
        # Disable stderr printing at the OS level so that workers don't write
        # to stderr.
        # We can't use sys.stderr.close(), otherwise Python's `raise` would
        # fail with "ValueError: I/O operation on closed file".
        os.close(sys.stderr.fileno())
        try:
            super(ErrorTrackingProcess, self).run()
            self._cconn.send(None)
        except Exception:
            self._cconn.send(ExceptionWrapper(sys.exc_info()))
            raise

    @property
    def exception(self):
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        if self._exception is None:
            return None
        else:
            return self._exception.exc_type(self._exception.exc_msg)

    # ESRCH means that os.kill could not find a live process with the given pid.
    def send_signal(self, signum, ignore_ESRCH=False):
        try:
            os.kill(self.pid, signum)
        except OSError as e:
            if not ignore_ESRCH or e.errno != errno.ESRCH:
                raise


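# ErrorTrackingProcess is consumed below in this pattern (a sketch mirroring
# test_timeout / test_segfault; `p.exception` rebuilds the child's exception
# from the type and message sent over the pipe):
#
#     p = ErrorTrackingProcess(target=_test_timeout)
#     p.start()
#     p.join(JOIN_TIMEOUT)
#     try:
#         assert isinstance(p.exception, RuntimeError)
#     finally:
#         p.terminate()

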
class ErrorDataset(Dataset):
    # Intentionally defines no __getitem__, so indexing falls back to
    # Dataset.__getitem__, which raises NotImplementedError (see _test_error).

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size


class SegfaultDataset(Dataset):

    def __init__(self, size):
        self.size = size

    def __getitem__(self, idx):
        # Read from a null pointer to reliably crash (segfault) the worker.
        return ctypes.string_at(0)

    def __len__(self):
        return self.size


class SleepDataset(Dataset):

    def __init__(self, size, sleep_sec):
        self.size = size
        self.sleep_sec = sleep_sec

    def __getitem__(self, idx):
        time.sleep(self.sleep_sec)
        return idx

    def __len__(self):
        return self.size


class SeedDataset(Dataset):

    def __init__(self, size):
        self.size = size

    def __getitem__(self, idx):
        return torch.initial_seed()

    def __len__(self):
        return self.size


# Inspired by https://stackoverflow.com/a/26703365
# Ensures that each worker processes at least one sample: every worker blocks
# on the barrier until all num_workers workers have fetched an item.
class SynchronizedSeedDataset(Dataset):

    def __init__(self, size, num_workers):
        assert size >= num_workers
        self.count = multiprocessing.Value('i', 0, lock=True)
        self.barrier = multiprocessing.Semaphore(0)
        self.num_workers = num_workers
        self.size = size

    def __getitem__(self, idx):
        with self.count.get_lock():
            self.count.value += 1
            if self.count.value == self.num_workers:
                self.barrier.release()
        self.barrier.acquire()
        self.barrier.release()
        return torch.initial_seed()

    def __len__(self):
        return self.size


def _test_timeout():
    dataset = SleepDataset(10, 10)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1)
    _ = next(iter(dataloader))


def _test_segfault():
    dataset = SegfaultDataset(10)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
    _ = next(iter(dataloader))


# Custom worker init function used by test_worker_init_fn below; it sets the
# same fixed seed in every worker.
def init_fn(worker_id):
    torch.manual_seed(12345)


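# In typical user code a worker_init_fn derives a distinct seed per worker
# from the worker id it receives; the fixed seed above is deliberate, so that
# test_worker_init_fn can verify the hook ran in every worker. A minimal
# sketch of the per-worker variant (hypothetical, not used by these tests):
#
#     def per_worker_init_fn(worker_id):
#         torch.manual_seed(12345 + worker_id)

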
class TestDataLoader(TestCase):

    def setUp(self):
        self.data = torch.randn(100, 2, 3, 5)
        self.labels = torch.randperm(50).repeat(2)
        self.dataset = TensorDataset(self.data, self.labels)

    def _test_sequential(self, loader):
        batch_size = loader.batch_size
        for i, (sample, target) in enumerate(loader):
            idx = i * batch_size
            self.assertEqual(sample, self.data[idx:idx + batch_size])
            self.assertEqual(target, self.labels[idx:idx + batch_size])
        self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

    def _test_shuffle(self, loader):
        found_data = {i: 0 for i in range(self.data.size(0))}
        found_labels = {i: 0 for i in range(self.labels.size(0))}
        batch_size = loader.batch_size
        for i, (batch_samples, batch_targets) in enumerate(loader):
            for sample, target in zip(batch_samples, batch_targets):
                for data_point_idx, data_point in enumerate(self.data):
                    if data_point.eq(sample).all():
                        self.assertFalse(found_data[data_point_idx])
                        found_data[data_point_idx] += 1
                        break
                self.assertEqual(target, self.labels[data_point_idx])
                found_labels[data_point_idx] += 1
            self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
            self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
        self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

    def _test_error(self, loader):
        it = iter(loader)
        errors = 0
        while True:
            try:
                next(it)
            except NotImplementedError:
                errors += 1
            except StopIteration:
                self.assertEqual(errors,
                                 math.ceil(float(len(loader.dataset)) / loader.batch_size))
                return

    def test_sequential(self):
        self._test_sequential(DataLoader(self.dataset))

    def test_sequential_batch(self):
        self._test_sequential(DataLoader(self.dataset, batch_size=2))

    def test_growing_dataset(self):
        dataset = [torch.ones(4) for _ in range(4)]
        dataloader_seq = DataLoader(dataset, shuffle=False)
        dataloader_shuffle = DataLoader(dataset, shuffle=True)
        dataset.append(torch.ones(4))
        self.assertEqual(len(dataloader_seq), 5)
        self.assertEqual(len(dataloader_shuffle), 5)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_sequential_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for input, target in loader:
            self.assertTrue(input.is_pinned())
            self.assertTrue(target.is_pinned())

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_multiple_dataloaders(self):
        loader1_it = iter(DataLoader(self.dataset, num_workers=1))
        loader2_it = iter(DataLoader(self.dataset, num_workers=2))
        next(loader1_it)
        next(loader1_it)
        next(loader2_it)
        next(loader2_it)
        next(loader1_it)
        next(loader2_it)

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    @unittest.skip("temporarily disable until flaky failures are fixed")
    def test_segfault(self):
        p = ErrorTrackingProcess(target=_test_segfault)
        p.start()
        p.join(JOIN_TIMEOUT)
        try:
            self.assertFalse(p.is_alive())
            self.assertNotEqual(p.exitcode, 0)
            if IS_WINDOWS:
                self.assertIsInstance(p.exception, OSError)
                self.assertRegex(str(p.exception), r'access violation reading ')
            else:
                self.assertIsInstance(p.exception, RuntimeError)
                self.assertRegex(str(p.exception), r'DataLoader worker \(pid \d+\) is killed by signal: ')
        finally:
            p.terminate()

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_timeout(self):
        p = ErrorTrackingProcess(target=_test_timeout)
        p.start()
        p.join(JOIN_TIMEOUT)
        try:
            self.assertFalse(p.is_alive())
            self.assertNotEqual(p.exitcode, 0)
            self.assertIsInstance(p.exception, RuntimeError)
            self.assertRegex(str(p.exception), r'DataLoader timed out after \d+ seconds')
        finally:
            p.terminate()

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_worker_seed(self):
        num_workers = 6
        dataset = SynchronizedSeedDataset(num_workers, num_workers)
        dataloader = DataLoader(dataset, batch_size=1, num_workers=num_workers)
        seeds = set()
        for batch in dataloader:
            seeds.add(batch[0])
        self.assertEqual(len(seeds), num_workers)

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_worker_init_fn(self):
        dataset = SeedDataset(4)
        dataloader = DataLoader(dataset, batch_size=2, num_workers=2,
                                worker_init_fn=init_fn)
        for batch in dataloader:
            self.assertEqual(12345, batch[0])
            self.assertEqual(12345, batch[1])

    def test_shuffle(self):
        self._test_shuffle(DataLoader(self.dataset, shuffle=True))

    def test_shuffle_batch(self):
        self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_sequential_workers(self):
        self._test_sequential(DataLoader(self.dataset, num_workers=4))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_sequential_batch_workers(self):
        self._test_sequential(DataLoader(self.dataset, batch_size=2, num_workers=4))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_shuffle_workers(self):
        self._test_shuffle(DataLoader(self.dataset, shuffle=True, num_workers=4))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_shuffle_batch_workers(self):
        self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4))

    def _test_batch_sampler(self, **kwargs):
        # [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9), ...]
        batches = []
        for i in range(0, 100, 5):
            batches.append(tuple(range(i, i + 2)))
            batches.append(tuple(range(i + 2, i + 5)))

        dl = DataLoader(self.dataset, batch_sampler=batches, **kwargs)
        self.assertEqual(len(dl), 40)
        for i, (input, _target) in enumerate(dl):
            if i % 2 == 0:
                offset = i * 5 // 2
                self.assertEqual(len(input), 2)
                self.assertEqual(input, self.data[offset:offset + 2])
            else:
                offset = i * 5 // 2
                self.assertEqual(len(input), 3)
                self.assertEqual(input, self.data[offset:offset + 3])

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_batch_sampler(self):
        self._test_batch_sampler()
        self._test_batch_sampler(num_workers=4)

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_shuffle_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
        for input, target in loader:
            self.assertTrue(input.is_pinned())
            self.assertTrue(target.is_pinned())

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_numpy(self):
        import numpy as np

        class TestDataset(torch.utils.data.Dataset):
            def __getitem__(self, i):
                return np.ones((2, 3, 4)) * i

            def __len__(self):
                return 1000

        loader = DataLoader(TestDataset(), batch_size=12)
        batch = next(iter(loader))
        self.assertIsInstance(batch, torch.DoubleTensor)
        self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4]))

    def test_error(self):
        self._test_error(DataLoader(ErrorDataset(100), batch_size=2, shuffle=True))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_error_workers(self):
        self._test_error(DataLoader(ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_partial_workers(self):
        "check that workers exit even if the iterator is not exhausted"
        loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4, pin_memory=True))
        workers = loader.workers
        worker_manager_thread = loader.worker_manager_thread
        for i, sample in enumerate(loader):
            if i == 3:
                break
        del loader
        for w in workers:
            w.join(JOIN_TIMEOUT)
            self.assertFalse(w.is_alive(), 'subprocess not terminated')
            self.assertEqual(w.exitcode, 0)
        worker_manager_thread.join(JOIN_TIMEOUT)
        self.assertFalse(worker_manager_thread.is_alive())

    def test_len(self):
        def check_len(dl, expected):
            self.assertEqual(len(dl), expected)
            n = 0
            for sample in dl:
                n += 1
            self.assertEqual(n, expected)
        check_len(self.dataset, 100)
        check_len(DataLoader(self.dataset, batch_size=2), 50)
        check_len(DataLoader(self.dataset, batch_size=3), 34)

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_numpy_scalars(self):
        import numpy as np

        class ScalarDataset(torch.utils.data.Dataset):
            def __init__(self, dtype):
                self.dtype = dtype

            def __getitem__(self, i):
                return self.dtype()

            def __len__(self):
                return 4

        dtypes = {
            np.float64: torch.DoubleTensor,
            np.float32: torch.FloatTensor,
            np.float16: torch.HalfTensor,
            np.int64: torch.LongTensor,
            np.int32: torch.IntTensor,
            np.int16: torch.ShortTensor,
            np.int8: torch.CharTensor,
            np.uint8: torch.ByteTensor,
        }
        for dt, tt in dtypes.items():
            dset = ScalarDataset(dt)
            loader = DataLoader(dset, batch_size=2)
            batch = next(iter(loader))
            self.assertIsInstance(batch, tt)

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_default_collate_bad_numpy_types(self):
        import numpy as np

        # Should be a no-op
        arr = np.array(['a', 'b', 'c'])
        default_collate(arr)

        arr = np.array([[['a', 'b', 'c']]])
        self.assertRaises(TypeError, lambda: default_collate(arr))

        arr = np.array([object(), object(), object()])
        self.assertRaises(TypeError, lambda: default_collate(arr))

        arr = np.array([[[object(), object(), object()]]])
        self.assertRaises(TypeError, lambda: default_collate(arr))


class StringDataset(Dataset):
    def __init__(self):
        self.s = '12345'

    def __len__(self):
        return len(self.s)

    def __getitem__(self, ndx):
        return (self.s[ndx], ndx)


class TestStringDataLoader(TestCase):
    def setUp(self):
        self.dataset = StringDataset()

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_shuffle_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
        for batch_ndx, (s, n) in enumerate(loader):
            self.assertIsInstance(s[0], str)
            self.assertTrue(n.is_pinned())


class DictDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, ndx):
        return {
            'a_tensor': torch.Tensor(4, 2).fill_(ndx),
            'another_dict': {
                'a_number': ndx,
            },
        }


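# For a batch of two consecutive DictDataset samples, default_collate collates
# the dict recursively, yielding roughly the structure that
# TestDictDataLoader.test_sequential_batch asserts below (sketch):
#
#     {'a_tensor': <tensor of size (2, 4, 2)>,
#      'another_dict': {'a_number': <tensor of size (2,)>}}

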
class TestDictDataLoader(TestCase):
    def setUp(self):
        self.dataset = DictDataset()

    def test_sequential_batch(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=False)
        batch_size = loader.batch_size
        for i, sample in enumerate(loader):
            idx = i * batch_size
            self.assertEqual(set(sample.keys()), {'a_tensor', 'another_dict'})
            self.assertEqual(set(sample['another_dict'].keys()), {'a_number'})

            t = sample['a_tensor']
            self.assertEqual(t.size(), torch.Size([batch_size, 4, 2]))
            self.assertTrue((t[0] == idx).all())
            self.assertTrue((t[1] == idx + 1).all())

            n = sample['another_dict']['a_number']
            self.assertEqual(n.size(), torch.Size([batch_size]))
            self.assertEqual(n[0], idx)
            self.assertEqual(n[1], idx + 1)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for batch_ndx, sample in enumerate(loader):
            self.assertTrue(sample['a_tensor'].is_pinned())
            self.assertTrue(sample['another_dict']['a_number'].is_pinned())


if __name__ == '__main__':
    run_tests()