mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50108 Test Plan: Imported from OSS Reviewed By: glaringlee Differential Revision: D25789184 Pulled By: ejguan fbshipit-source-id: 0eeeeeda62533e7137d56f313b7bf11406b32611
165 lines
5.6 KiB
Python
165 lines
5.6 KiB
Python
import tempfile
|
|
import warnings
|
|
|
|
import torch
|
|
from torch.testing._internal.common_utils import (TestCase, run_tests)
|
|
from torch.utils.data import IterableDataset, RandomSampler
|
|
from torch.utils.data.datasets import \
|
|
(CollateIterableDataset, BatchIterableDataset, ListDirFilesIterableDataset,
|
|
LoadFilesFromDiskIterableDataset, SamplerIterableDataset)
|
|
|
|
|
|
def create_temp_dir_and_files():
    """Create a temporary directory containing three empty temporary files.

    Returns:
        A tuple ``(temp_dir, path1, path2, path3)``. ``temp_dir`` is the
        ``tempfile.TemporaryDirectory`` object; the caller owns it and must
        call ``temp_dir.cleanup()``. ``path1``..``path3`` are the paths of the
        three files created inside ``temp_dir.name``.
    """
    # The temp dir and files within it will be released and deleted in tearDown().
    # Adding `noqa: P201` to avoid mypy's warning on not releasing the dir handle within this function.
    temp_dir = tempfile.TemporaryDirectory()  # noqa: P201
    temp_dir_path = temp_dir.name

    file_names = []
    for _ in range(3):
        # Only the files on disk are needed, not open handles. Close each
        # handle right away (delete=False keeps the file) so no descriptors
        # leak and the directory can later be removed even on platforms that
        # may refuse to delete files which are still open (e.g. Windows).
        with tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False) as f:
            file_names.append(f.name)

    return (temp_dir, *file_names)
|
|
|
|
|
|
class TestIterableDatasetBasic(TestCase):
    """Tests for the file-system based iterable datasets.

    Each test gets a fresh temporary directory holding three empty files
    (created in ``setUp`` and removed in ``tearDown``).
    """

    def setUp(self):
        ret = create_temp_dir_and_files()
        self.temp_dir = ret[0]
        self.temp_files = ret[1:]

    def tearDown(self):
        try:
            self.temp_dir.cleanup()
        except Exception as e:
            # Best-effort cleanup: report the problem but do not fail the test.
            warnings.warn("TestIterableDatasetBasic was not able to cleanup temp dir due to {}".format(str(e)))

    def test_listdirfiles_iterable_dataset(self):
        temp_dir = self.temp_dir.name
        dataset = ListDirFilesIterableDataset(temp_dir, '')

        count = 0
        for pathname in dataset:
            count += 1
            self.assertIn(pathname, self.temp_files)
        # Guard against a vacuous pass: every created file must be listed.
        self.assertEqual(count, len(self.temp_files))

    def test_loadfilesfromdisk_iterable_dataset(self):
        temp_dir = self.temp_dir.name
        dataset1 = ListDirFilesIterableDataset(temp_dir, '')
        dataset2 = LoadFilesFromDiskIterableDataset(dataset1)

        count = 0
        for rec in dataset2:
            count += 1
            # rec is a (pathname, stream) pair; the stream must hold the file's bytes.
            self.assertIn(rec[0], self.temp_files)
            with open(rec[0], 'rb') as f:
                self.assertEqual(rec[1].read(), f.read())
            # Close the stream handed out by the dataset so the temp dir can be removed.
            rec[1].close()
        self.assertEqual(count, len(self.temp_files))
|
|
|
|
|
|
class IterDatasetWithoutLen(IterableDataset):
    """Iterable wrapper around ``ds`` that defines no ``__len__`` of its own.

    Used to exercise the functional datasets against an unsized source.
    """

    def __init__(self, ds):
        super().__init__()
        self.ds = ds

    def __iter__(self):
        # A new pass over the wrapped object on every iteration.
        yield from self.ds
|
|
|
|
|
|
class IterDatasetWithLen(IterableDataset):
    """Iterable wrapper around ``ds`` that also reports a length.

    The length is snapshotted once, from ``len(ds)``, when the wrapper is built.
    """

    def __init__(self, ds):
        super().__init__()
        self.ds = ds
        self.length = len(ds)

    def __len__(self):
        return self.length

    def __iter__(self):
        yield from self.ds
|
|
|
|
|
|
class TestFunctionalIterableDataset(TestCase):
    """Tests for the collate / batch / sampler iterable dataset wrappers."""

    def test_collate_dataset(self):
        arrs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
        ds_len = IterDatasetWithLen(arrs)
        ds_nolen = IterDatasetWithoutLen(arrs)

        def _collate_fn(batch):
            return torch.tensor(sum(batch), dtype=torch.float)

        # Custom collate_fn is applied per element; the source length is forwarded.
        collate_ds = CollateIterableDataset(ds_len, collate_fn=_collate_fn)
        self.assertEqual(len(ds_len), len(collate_ds))
        ds_iter = iter(ds_len)
        for x in collate_ds:
            y = next(ds_iter)
            self.assertEqual(x, torch.tensor(sum(y), dtype=torch.float))

        # Default collate_fn; an unsized source makes len() unavailable.
        collate_ds_nolen = CollateIterableDataset(ds_nolen)  # type: ignore
        with self.assertRaises(NotImplementedError):
            len(collate_ds_nolen)
        ds_nolen_iter = iter(ds_nolen)
        for x in collate_ds_nolen:
            y = next(ds_nolen_iter)
            self.assertEqual(x, torch.tensor(y))

    def test_batch_dataset(self):
        arrs = range(10)
        ds = IterDatasetWithLen(arrs)
        with self.assertRaises(AssertionError):
            BatchIterableDataset(ds, batch_size=0)

        # Default not drop the last batch
        batch_ds1 = BatchIterableDataset(ds, batch_size=3)
        self.assertEqual(len(batch_ds1), 4)
        batch_iter = iter(batch_ds1)
        value = 0
        for i in range(len(batch_ds1)):
            batch = next(batch_iter)
            if i == 3:
                self.assertEqual(len(batch), 1)
                self.assertEqual(batch, [9])
            else:
                self.assertEqual(len(batch), 3)
                for x in batch:
                    self.assertEqual(x, value)
                    value += 1
        # Exactly len(batch_ds1) batches must be produced, no more.
        with self.assertRaises(StopIteration):
            next(batch_iter)

        # Drop the last batch
        batch_ds2 = BatchIterableDataset(ds, batch_size=3, drop_last=True)
        self.assertEqual(len(batch_ds2), 3)
        value = 0
        for batch in batch_ds2:
            self.assertEqual(len(batch), 3)
            for x in batch:
                self.assertEqual(x, value)
                value += 1
        # Three full batches (0..8) were seen; the trailing [9] was dropped.
        self.assertEqual(value, 9)

        batch_ds3 = BatchIterableDataset(ds, batch_size=2)
        self.assertEqual(len(batch_ds3), 5)
        batch_ds4 = BatchIterableDataset(ds, batch_size=2, drop_last=True)
        self.assertEqual(len(batch_ds4), 5)

        ds_nolen = IterDatasetWithoutLen(arrs)
        batch_ds_nolen = BatchIterableDataset(ds_nolen, batch_size=5)
        with self.assertRaises(NotImplementedError):
            len(batch_ds_nolen)

    def test_sampler_dataset(self):
        arrs = range(10)
        ds = IterDatasetWithLen(arrs)
        # Default SequentialSampler
        sampled_ds = SamplerIterableDataset(ds)  # type: ignore
        self.assertEqual(len(sampled_ds), 10)
        i = 0
        for x in sampled_ds:
            self.assertEqual(x, i)
            i += 1
        self.assertEqual(i, 10)

        # RandomSampler: with replacement it draws len(ds) samples by default,
        # each of which must be a valid index into ``ds``.
        random_sampled_ds = SamplerIterableDataset(ds, sampler=RandomSampler, replacement=True)  # type: ignore
        self.assertEqual(len(random_sampled_ds), 10)
        samples = list(random_sampled_ds)
        self.assertEqual(len(samples), 10)
        for x in samples:
            self.assertIn(x, arrs)

        # Requires `__len__` to build SamplerDataset
        ds_nolen = IterDatasetWithoutLen(arrs)
        with self.assertRaises(AssertionError):
            SamplerIterableDataset(ds_nolen)
|
|
|
|
|
|
if __name__ == '__main__':
    # Executing this file directly hands control to PyTorch's shared test runner.
    run_tests()
|