[DataPipe] adding SequenceWrapperMapDataPipe (#66275)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66275

Once this is added to Core, TorchData's PR will not need a custom class and can use this wrapper instead.

cc VitalyFedyunin ejguan NivekT

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D31485822

Pulled By: NivekT

fbshipit-source-id: 790de27629c89c0ca7163a8ee5a09ee8b8233340
This commit is contained in:
Kevin Tse 2021-10-08 08:29:25 -07:00 committed by Facebook GitHub Bot
parent a7cc07f109
commit e808e3d3d6
3 changed files with 75 additions and 21 deletions

View File

@ -112,7 +112,8 @@ def create_temp_dir_and_files():
# Then, reset the DataPipe and return a tuple of two lists # Then, reset the DataPipe and return a tuple of two lists
# 1. A list of elements yielded before the reset # 1. A list of elements yielded before the reset
# 2. A list of all elements of the DataPipe after the reset # 2. A list of all elements of the DataPipe after the reset
def reset_after_n_next_calls(datapipe: IterDataPipe[T_co], n: int) -> Tuple[List[T_co], List[T_co]]: def reset_after_n_next_calls(datapipe: Union[IterDataPipe[T_co], MapDataPipe[T_co]],
n: int) -> Tuple[List[T_co], List[T_co]]:
it = iter(datapipe) it = iter(datapipe)
res_before_reset = [] res_before_reset = []
for _ in range(n): for _ in range(n):
@ -644,19 +645,6 @@ class IDP_NoLen(IterDataPipe):
yield i yield i
class MDP(MapDataPipe):
    r"""Minimal map-style DataPipe that exposes an arbitrary indexable object.

    Item lookup is delegated to the wrapped object; the length is
    snapshotted once at construction time.
    """

    def __init__(self, input_dp):
        super().__init__()
        self.length = len(input_dp)  # capture size up front
        self.input_dp = input_dp

    def __getitem__(self, index):
        # Delegate indexing straight to the underlying object.
        item = self.input_dp[index]
        return item

    def __len__(self) -> int:
        return self.length
def _fake_fn(data, *args, **kwargs): def _fake_fn(data, *args, **kwargs):
return data return data
@ -1443,8 +1431,8 @@ class TestFunctionalMapDataPipe(TestCase):
picklable_datapipes: List[ picklable_datapipes: List[
Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]] Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]]
] = [ ] = [
(dp.map.Mapper, MDP(arr), (), {}), (dp.map.Mapper, dp.map.SequenceWrapper(arr), (), {}),
(dp.map.Mapper, MDP(arr), (_fake_fn, (0,), {'test': True}), {}), (dp.map.Mapper, dp.map.SequenceWrapper(arr), (_fake_fn, (0,), {'test': True}), {}),
] ]
for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes: for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes:
p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg] p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
@ -1452,7 +1440,7 @@ class TestFunctionalMapDataPipe(TestCase):
unpicklable_datapipes: List[ unpicklable_datapipes: List[
Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]] Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]]
] = [ ] = [
(dp.map.Mapper, MDP(arr), (lambda x: x,), {}), (dp.map.Mapper, dp.map.SequenceWrapper(arr), (lambda x: x,), {}),
] ]
for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes: for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes:
with warnings.catch_warnings(record=True) as wa: with warnings.catch_warnings(record=True) as wa:
@ -1464,9 +1452,36 @@ class TestFunctionalMapDataPipe(TestCase):
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
p = pickle.dumps(datapipe) p = pickle.dumps(datapipe)
def test_sequence_wrapper_datapipe(self):
    r"""Exercise SequenceWrapper: element order, deepcopy semantics, reset, and length."""
    source = list(range(10))
    wrapper = dp.map.SequenceWrapper(source)

    # Functional Test: the wrapper yields exactly the wrapped elements, in order.
    self.assertEqual(source, list(wrapper))

    # Functional Test: default deepcopy — later mutation of the source must not leak in.
    source.append(11)
    self.assertEqual(list(range(10)), list(wrapper))  # wrapper shouldn't see 11

    # Functional Test: deepcopy=False aliases the source, so mutations show up.
    shared = [1, 2, 3]
    aliased = dp.map.SequenceWrapper(shared, deepcopy=False)
    shared.append(4)
    self.assertEqual(list(shared), list(aliased))  # should include 4

    # Reset Test: resetting mid-iteration restarts from the beginning.
    source = list(range(10))
    n_elements_before_reset = 5
    res_before_reset, res_after_reset = reset_after_n_next_calls(wrapper, n_elements_before_reset)
    self.assertEqual(list(range(n_elements_before_reset)), res_before_reset)
    self.assertEqual(source, res_after_reset)

    # __len__ Test: length is inherited from the wrapped sequence.
    self.assertEqual(len(source), len(wrapper))
def test_concat_datapipe(self): def test_concat_datapipe(self):
input_dp1 = MDP(range(10)) input_dp1 = dp.map.SequenceWrapper(range(10))
input_dp2 = MDP(range(5)) input_dp2 = dp.map.SequenceWrapper(range(5))
with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"): with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
dp.map.Concater() dp.map.Concater()
@ -1482,7 +1497,7 @@ class TestFunctionalMapDataPipe(TestCase):
def test_map_datapipe(self): def test_map_datapipe(self):
arr = range(10) arr = range(10)
input_dp = MDP(arr) input_dp = dp.map.SequenceWrapper(arr)
def fn(item, dtype=torch.float, *, sum=False): def fn(item, dtype=torch.float, *, sum=False):
data = torch.tensor(item, dtype=dtype) data = torch.tensor(item, dtype=dtype)

View File

@ -1,6 +1,7 @@
# Functional DataPipe # Functional DataPipe
from torch.utils.data.datapipes.map.callable import MapperMapDataPipe as Mapper from torch.utils.data.datapipes.map.callable import MapperMapDataPipe as Mapper
from torch.utils.data.datapipes.map.combining import ConcaterMapDataPipe as Concater from torch.utils.data.datapipes.map.combining import ConcaterMapDataPipe as Concater
from torch.utils.data.datapipes.map.utils import SequenceWrapperMapDataPipe as SequenceWrapper
__all__ = ['Concater', 'Mapper'] __all__ = ['Concater', 'Mapper', 'SequenceWrapper']

View File

@ -0,0 +1,38 @@
import copy
import warnings
from torch.utils.data import MapDataPipe
class SequenceWrapperMapDataPipe(MapDataPipe):
    r""":class:`SequenceWrapperMapDataPipe`.

    Map DataPipe that wraps a sequence object.

    Args:
        sequence: Sequence object to be wrapped into a MapDataPipe
        deepcopy: Option to deepcopy input sequence object

    .. note::
        If `deepcopy` is set to False explicitly, users should ensure
        that data pipeline doesn't contain any in-place operations over
        the iterable instance, in order to prevent data inconsistency
        across iterations.
    """

    def __init__(self, sequence, deepcopy=True):
        if deepcopy:
            try:
                # Snapshot the sequence so later in-place edits to the
                # caller's object do not affect this DataPipe.
                self.sequence = copy.deepcopy(sequence)
            except TypeError:
                # Some objects (e.g. ones holding generators or locks)
                # cannot be deepcopied; fall back to aliasing and warn.
                warnings.warn(
                    "The input sequence can not be deepcopied, "
                    "please be aware of in-place modification would affect source data"
                )
                self.sequence = sequence
        else:
            # Caller explicitly opted out of copying; alias the source.
            self.sequence = sequence

    def __getitem__(self, index):
        # Delegate indexing directly to the wrapped sequence.
        return self.sequence[index]

    def __len__(self):
        return len(self.sequence)