# pytorch/test/test_datapipe.py
#
# Commit 015cabf82a (lixinyu, 2021-02-09): move GroupByFilename Dataset to DataPipe (#51709)
# Pull Request resolved: https://github.com/pytorch/pytorch/pull/51709
# Test Plan: Imported from OSS
# Reviewed By: ejguan
# Differential Revision: D26263585
# Pulled By: glaringlee
# fbshipit-source-id: 00e3e13b47b89117f1ccfc4cd6239940a40d071e

import os
import pickle
import random
import tempfile
import warnings
import tarfile
import zipfile
import numpy as np
from PIL import Image
import torch
from torch.testing._internal.common_utils import (TestCase, run_tests)
from torch.utils.data import IterDataPipe, RandomSampler
from typing import List, Tuple, Dict, Any, Type
import torch.utils.data.datapipes as dp
from torch.utils.data.datapipes.utils.decoder import (
    basichandlers as decoder_basichandlers,
    imagehandler as decoder_imagehandler)


def create_temp_dir_and_files():
    # The temp dir and files within it will be released and deleted in tearDown().
    # `noqa: P201` suppresses the lint warning about not releasing the dir handle
    # within this function.
    temp_dir = tempfile.TemporaryDirectory()  # noqa: P201
    temp_dir_path = temp_dir.name
    with tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False, suffix='.txt') as f:
        temp_file1_name = f.name
    with tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False, suffix='.byte') as f:
        temp_file2_name = f.name
    with tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False, suffix='.empty') as f:
        temp_file3_name = f.name

    with open(temp_file1_name, 'w') as f1:
        f1.write('0123456789abcdef')
    with open(temp_file2_name, 'wb') as f2:
        f2.write(b"0123456789abcdef")

    temp_sub_dir = tempfile.TemporaryDirectory(dir=temp_dir_path)  # noqa: P201
    temp_sub_dir_path = temp_sub_dir.name
    with tempfile.NamedTemporaryFile(dir=temp_sub_dir_path, delete=False, suffix='.txt') as f:
        temp_sub_file1_name = f.name
    with tempfile.NamedTemporaryFile(dir=temp_sub_dir_path, delete=False, suffix='.byte') as f:
        temp_sub_file2_name = f.name

    with open(temp_sub_file1_name, 'w') as f1:
        f1.write('0123456789abcdef')
    with open(temp_sub_file2_name, 'wb') as f2:
        f2.write(b"0123456789abcdef")

    return [(temp_dir, temp_file1_name, temp_file2_name, temp_file3_name),
            (temp_sub_dir, temp_sub_file1_name, temp_sub_file2_name)]


class TestIterableDataPipeBasic(TestCase):

    def setUp(self):
        ret = create_temp_dir_and_files()
        self.temp_dir = ret[0][0]
        self.temp_files = ret[0][1:]
        self.temp_sub_dir = ret[1][0]
        self.temp_sub_files = ret[1][1:]

    def tearDown(self):
        try:
            self.temp_sub_dir.cleanup()
            self.temp_dir.cleanup()
        except Exception as e:
            warnings.warn("TestIterableDataPipeBasic was not able to cleanup temp dir due to {}".format(str(e)))
    def test_listdirfiles_iterable_datapipe(self):
        temp_dir = self.temp_dir.name
        datapipe = dp.iter.ListDirFiles(temp_dir, '')

        count = 0
        for pathname in datapipe:
            count = count + 1
            self.assertTrue(pathname in self.temp_files)
        self.assertEqual(count, len(self.temp_files))

        count = 0
        datapipe = dp.iter.ListDirFiles(temp_dir, '', recursive=True)
        for pathname in datapipe:
            count = count + 1
            self.assertTrue((pathname in self.temp_files) or (pathname in self.temp_sub_files))
        self.assertEqual(count, len(self.temp_files) + len(self.temp_sub_files))

    def test_loadfilesfromdisk_iterable_datapipe(self):
        # test import datapipe class directly
        from torch.utils.data.datapipes.iter import ListDirFiles, LoadFilesFromDisk

        temp_dir = self.temp_dir.name
        datapipe1 = ListDirFiles(temp_dir, '')
        datapipe2 = LoadFilesFromDisk(datapipe1)

        count = 0
        for rec in datapipe2:
            count = count + 1
            self.assertTrue(rec[0] in self.temp_files)
            self.assertTrue(rec[1].read() == open(rec[0], 'rb').read())
        self.assertEqual(count, len(self.temp_files))

    def test_readfilesfromtar_iterable_datapipe(self):
        temp_dir = self.temp_dir.name
        temp_tarfile_pathname = os.path.join(temp_dir, "test_tar.tar")
        with tarfile.open(temp_tarfile_pathname, "w:gz") as tar:
            tar.add(self.temp_files[0])
            tar.add(self.temp_files[1])
            tar.add(self.temp_files[2])
        datapipe1 = dp.iter.ListDirFiles(temp_dir, '*.tar')
        datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1)
        datapipe3 = dp.iter.ReadFilesFromTar(datapipe2)

        # read extracted files before reaching the end of the tarfile
        count = 0
        for rec, temp_file in zip(datapipe3, self.temp_files):
            count = count + 1
            self.assertEqual(os.path.basename(rec[0]), os.path.basename(temp_file))
            self.assertEqual(rec[1].read(), open(temp_file, 'rb').read())
        self.assertEqual(count, len(self.temp_files))

        # read extracted files after reaching the end of the tarfile
        count = 0
        data_refs = []
        for rec in datapipe3:
            count = count + 1
            data_refs.append(rec)
        self.assertEqual(count, len(self.temp_files))
        for i in range(0, count):
            self.assertEqual(os.path.basename(data_refs[i][0]), os.path.basename(self.temp_files[i]))
            self.assertEqual(data_refs[i][1].read(), open(self.temp_files[i], 'rb').read())

    def test_readfilesfromzip_iterable_datapipe(self):
        temp_dir = self.temp_dir.name
        temp_zipfile_pathname = os.path.join(temp_dir, "test_zip.zip")
        with zipfile.ZipFile(temp_zipfile_pathname, 'w') as myzip:
            myzip.write(self.temp_files[0])
            myzip.write(self.temp_files[1])
            myzip.write(self.temp_files[2])
        datapipe1 = dp.iter.ListDirFiles(temp_dir, '*.zip')
        datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1)
        datapipe3 = dp.iter.ReadFilesFromZip(datapipe2)

        # read extracted files before reaching the end of the zipfile
        count = 0
        for rec, temp_file in zip(datapipe3, self.temp_files):
            count = count + 1
            self.assertEqual(os.path.basename(rec[0]), os.path.basename(temp_file))
            self.assertEqual(rec[1].read(), open(temp_file, 'rb').read())
        self.assertEqual(count, len(self.temp_files))

        # read extracted files after reaching the end of the zipfile
        count = 0
        data_refs = []
        for rec in datapipe3:
            count = count + 1
            data_refs.append(rec)
        self.assertEqual(count, len(self.temp_files))
        for i in range(0, count):
            self.assertEqual(os.path.basename(data_refs[i][0]), os.path.basename(self.temp_files[i]))
            self.assertEqual(data_refs[i][1].read(), open(self.temp_files[i], 'rb').read())

    def test_routeddecoder_iterable_datapipe(self):
        temp_dir = self.temp_dir.name
        temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png")
        img = Image.new('RGB', (2, 2), color='red')
        img.save(temp_pngfile_pathname)
        datapipe1 = dp.iter.ListDirFiles(temp_dir, ['*.png', '*.txt'])
        datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1)
        datapipe3 = dp.iter.RoutedDecoder(datapipe2, handlers=[decoder_imagehandler('rgb')])
        datapipe3.add_handler(decoder_basichandlers)

        for rec in datapipe3:
            ext = os.path.splitext(rec[0])[1]
            if ext == '.png':
                expected = np.array([[[1., 0., 0.], [1., 0., 0.]], [[1., 0., 0.], [1., 0., 0.]]], dtype=np.single)
                self.assertTrue(np.array_equal(rec[1], expected))
            else:
                self.assertTrue(rec[1] == open(rec[0], 'rb').read().decode('utf-8'))

    def test_groupbykey_iterable_datapipe(self):
        temp_dir = self.temp_dir.name
        temp_tarfile_pathname = os.path.join(temp_dir, "test_tar.tar")
        file_list = [
            "a.png", "b.png", "c.json", "a.json", "c.png", "b.json", "d.png",
            "d.json", "e.png", "f.json", "g.png", "f.png", "g.json", "e.json",
            "h.txt", "h.json"]
        with tarfile.open(temp_tarfile_pathname, "w:gz") as tar:
            for file_name in file_list:
                file_pathname = os.path.join(temp_dir, file_name)
                with open(file_pathname, 'w') as f:
                    f.write('12345abcde')
                tar.add(file_pathname)

        datapipe1 = dp.iter.ListDirFiles(temp_dir, '*.tar')
        datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1)
        datapipe3 = dp.iter.ReadFilesFromTar(datapipe2)
        datapipe4 = dp.iter.GroupByKey(datapipe3, group_size=2)

        expected_result = [
            ("a.png", "a.json"), ("c.png", "c.json"), ("b.png", "b.json"),
            ("d.png", "d.json"), ("f.png", "f.json"), ("g.png", "g.json"),
            ("e.png", "e.json"), ("h.json", "h.txt")]

        count = 0
        for rec, expected in zip(datapipe4, expected_result):
            count = count + 1
            self.assertEqual(os.path.basename(rec[0][0]), expected[0])
            self.assertEqual(os.path.basename(rec[1][0]), expected[1])
            self.assertEqual(rec[0][1].read(), b'12345abcde')
            self.assertEqual(rec[1][1].read(), b'12345abcde')
        self.assertEqual(count, 8)
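

# Illustrative sketch (not part of the test suite above): GroupByKey collects
# records whose filenames share a stem (e.g. "a.png"/"a.json") into groups of
# `group_size`, which is what test_groupbykey_iterable_datapipe verifies. The
# helper name `_groupbykey_pipeline_sketch` is hypothetical; it only re-chains
# the same dp.iter datapipes the tests exercise.
def _groupbykey_pipeline_sketch(root_dir, group_size=2):
    datapipe = dp.iter.GroupByKey(
        dp.iter.ReadFilesFromTar(
            dp.iter.LoadFilesFromDisk(
                dp.iter.ListDirFiles(root_dir, '*.tar'))),
        group_size=group_size)
    # Each group is a list of (pathname, stream) records with matching stems.
    return [[os.path.basename(name) for name, _ in group] for group in datapipe]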


class IDP_NoLen(IterDataPipe):
    def __init__(self, input_dp):
        super().__init__()
        self.input_dp = input_dp

    def __iter__(self):
        for i in self.input_dp:
            yield i


class IDP(IterDataPipe):
    def __init__(self, input_dp):
        super().__init__()
        self.input_dp = input_dp
        self.length = len(input_dp)

    def __iter__(self):
        for i in self.input_dp:
            yield i

    def __len__(self):
        return self.length


def _fake_fn(data, *args, **kwargs):
    return data
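

# Minimal sketch of the length protocol the helpers above exercise: IDP forwards
# the source's length, while IDP_NoLen deliberately omits __len__, so asking for
# its length (or the length of a datapipe built on top of it) is expected to
# fail. The name `_len_protocol_sketch` is hypothetical, for illustration only.
def _len_protocol_sketch():
    assert len(IDP(range(4))) == 4
    try:
        len(IDP_NoLen(range(4)))
    except (TypeError, NotImplementedError):
        # IDP_NoLen defines no __len__, so len() cannot succeed here.
        pass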


class TestFunctionalIterDataPipe(TestCase):

    def test_picklable(self):
        arr = range(10)
        picklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, List, Dict[str, Any]]] = [
            (dp.iter.Callable, IDP(arr), [], {}),
            (dp.iter.Callable, IDP(arr), [0], {'fn': _fake_fn, 'test': True}),
            (dp.iter.Collate, IDP(arr), [], {}),
            (dp.iter.Collate, IDP(arr), [0], {'collate_fn': _fake_fn, 'test': True}),
        ]
        for dpipe, input_dp, args, kargs in picklable_datapipes:
            p = pickle.dumps(dpipe(input_dp, *args, **kargs))  # type: ignore

        unpicklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, List, Dict[str, Any]]] = [
            (dp.iter.Callable, IDP(arr), [], {'fn': lambda x: x}),
            (dp.iter.Collate, IDP(arr), [], {'collate_fn': lambda x: x}),
        ]
        for dpipe, input_dp, args, kargs in unpicklable_datapipes:
            with self.assertRaises(AttributeError):
                p = pickle.dumps(dpipe(input_dp, *args, **kargs))  # type: ignore

    def test_callable_datapipe(self):
        arr = range(10)
        input_dp = IDP(arr)
        input_dp_nl = IDP_NoLen(arr)

        def fn(item, dtype=torch.float, *, sum=False):
            data = torch.tensor(item, dtype=dtype)
            return data if not sum else data.sum()

        callable_dp = dp.iter.Callable(input_dp, fn=fn)  # type: ignore
        self.assertEqual(len(input_dp), len(callable_dp))
        for x, y in zip(callable_dp, input_dp):
            self.assertEqual(x, torch.tensor(y, dtype=torch.float))

        callable_dp = dp.iter.Callable(input_dp, torch.int, fn=fn, sum=True)  # type: ignore
        self.assertEqual(len(input_dp), len(callable_dp))
        for x, y in zip(callable_dp, input_dp):
            self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())

        callable_dp_nl = dp.iter.Callable(input_dp_nl)  # type: ignore
        with self.assertRaises(NotImplementedError):
            len(callable_dp_nl)
        for x, y in zip(callable_dp_nl, input_dp_nl):
            self.assertEqual(x, torch.tensor(y, dtype=torch.float))

    def test_collate_datapipe(self):
        arrs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
        input_dp = IDP(arrs)
        input_dp_nl = IDP_NoLen(arrs)

        def _collate_fn(batch):
            return torch.tensor(sum(batch), dtype=torch.float)

        collate_dp = dp.iter.Collate(input_dp, collate_fn=_collate_fn)
        self.assertEqual(len(input_dp), len(collate_dp))
        for x, y in zip(collate_dp, input_dp):
            self.assertEqual(x, torch.tensor(sum(y), dtype=torch.float))

        collate_dp_nl = dp.iter.Collate(input_dp_nl)  # type: ignore
        with self.assertRaises(NotImplementedError):
            len(collate_dp_nl)
        for x, y in zip(collate_dp_nl, input_dp_nl):
            self.assertEqual(x, torch.tensor(y))

    def test_batch_datapipe(self):
        arrs = list(range(10))
        input_dp = IDP(arrs)
        with self.assertRaises(AssertionError):
            batch_dp0 = dp.iter.Batch(input_dp, batch_size=0)

        # By default, the last (partial) batch is kept
        bs = 3
        batch_dp1 = dp.iter.Batch(input_dp, batch_size=bs)
        self.assertEqual(len(batch_dp1), 4)
        for i, batch in enumerate(batch_dp1):
            self.assertEqual(len(batch), 1 if i == 3 else bs)
            self.assertEqual(batch, arrs[i * bs: i * bs + len(batch)])

        # Drop the last batch
        bs = 4
        batch_dp2 = dp.iter.Batch(input_dp, batch_size=bs, drop_last=True)
        self.assertEqual(len(batch_dp2), 2)
        for i, batch in enumerate(batch_dp2):
            self.assertEqual(len(batch), bs)
            self.assertEqual(batch, arrs[i * bs: i * bs + len(batch)])

        input_dp_nl = IDP_NoLen(range(10))
        batch_dp_nl = dp.iter.Batch(input_dp_nl, batch_size=2)
        with self.assertRaises(NotImplementedError):
            len(batch_dp_nl)

    def test_bucket_batch_datapipe(self):
        input_dp = IDP(range(20))
        with self.assertRaises(AssertionError):
            dp.iter.BucketBatch(input_dp, batch_size=0)

        input_dp_nl = IDP_NoLen(range(20))
        bucket_dp_nl = dp.iter.BucketBatch(input_dp_nl, batch_size=7)
        with self.assertRaises(NotImplementedError):
            len(bucket_dp_nl)

        def _helper(**kwargs):
            arrs = list(range(100))
            random.shuffle(arrs)
            input_dp = IDP(arrs)
            bucket_dp = dp.iter.BucketBatch(input_dp, **kwargs)
            if kwargs["sort_key"] is None:
                # Batch datapipe as reference
                ref_dp = dp.iter.Batch(input_dp, batch_size=kwargs['batch_size'], drop_last=kwargs['drop_last'])
                for batch, rbatch in zip(bucket_dp, ref_dp):
                    self.assertEqual(batch, rbatch)
            else:
                bucket_size = bucket_dp.bucket_size
                bucket_num = (len(input_dp) - 1) // bucket_size + 1
                it = iter(bucket_dp)
                for i in range(bucket_num):
                    ref = sorted(arrs[i * bucket_size: (i + 1) * bucket_size])
                    bucket: List = []
                    while len(bucket) < len(ref):
                        try:
                            batch = next(it)
                            bucket += batch
                        # If drop_last is set, the iterator may stop early
                        except StopIteration:
                            break
                    if len(bucket) != len(ref):
                        ref = ref[:len(bucket)]
                    # Each bucket should come out sorted
                    self.assertEqual(bucket, ref)

        # Test Bucket Batch without sort_key
        _helper(batch_size=7, drop_last=False, sort_key=None)
        _helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=None)

        # Test Bucket Batch with sort_key
        def _sort_fn(data):
            return data

        _helper(batch_size=7, drop_last=False, bucket_size_mul=5, sort_key=_sort_fn)
        _helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=_sort_fn)

    def test_sampler_datapipe(self):
        arrs = range(10)
        input_dp = IDP(arrs)
        # Default SequentialSampler
        sampled_dp = dp.iter.Sampler(input_dp)  # type: ignore
        self.assertEqual(len(sampled_dp), 10)
        i = 0
        for x in sampled_dp:
            self.assertEqual(x, i)
            i += 1

        # RandomSampler
        random_sampled_dp = dp.iter.Sampler(input_dp, sampler=RandomSampler, replacement=True)  # type: ignore

        # Requires `__len__` to build the Sampler datapipe
        input_dp_nolen = IDP_NoLen(arrs)
        with self.assertRaises(AssertionError):
            sampled_dp = dp.iter.Sampler(input_dp_nolen)


if __name__ == '__main__':
    run_tests()