[DataPipe] adding SequenceWrapperMapDataPipe (#66275)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66275

Once this is added to Core, TorchData's PR will not need a custom class and can use this wrapper instead.

cc VitalyFedyunin ejguan NivekT

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D31485822

Pulled By: NivekT

fbshipit-source-id: 790de27629c89c0ca7163a8ee5a09ee8b8233340
This commit is contained in:
Kevin Tse 2021-10-08 08:29:25 -07:00 committed by Facebook GitHub Bot
parent a7cc07f109
commit e808e3d3d6
3 changed files with 75 additions and 21 deletions

View File

@ -112,7 +112,8 @@ def create_temp_dir_and_files():
# Then, reset the DataPipe and return a tuple of two lists
# 1. A list of elements yielded before the reset
# 2. A list of all elements of the DataPipe after the reset
def reset_after_n_next_calls(datapipe: IterDataPipe[T_co], n: int) -> Tuple[List[T_co], List[T_co]]:
def reset_after_n_next_calls(datapipe: Union[IterDataPipe[T_co], MapDataPipe[T_co]],
n: int) -> Tuple[List[T_co], List[T_co]]:
it = iter(datapipe)
res_before_reset = []
for _ in range(n):
@ -644,19 +645,6 @@ class IDP_NoLen(IterDataPipe):
yield i
class MDP(MapDataPipe):
    """Minimal map-style test DataPipe that forwards indexing to a wrapped source."""

    def __init__(self, input_dp):
        super().__init__()
        self.input_dp = input_dp
        # Cache the length once at construction so __len__ never re-queries the source.
        self.length = len(input_dp)

    def __getitem__(self, index):
        # Pure delegation: element lookup goes straight to the wrapped source.
        return self.input_dp[index]

    def __len__(self) -> int:
        return self.length
def _fake_fn(data, *args, **kwargs):
return data
@ -1443,8 +1431,8 @@ class TestFunctionalMapDataPipe(TestCase):
picklable_datapipes: List[
Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]]
] = [
(dp.map.Mapper, MDP(arr), (), {}),
(dp.map.Mapper, MDP(arr), (_fake_fn, (0,), {'test': True}), {}),
(dp.map.Mapper, dp.map.SequenceWrapper(arr), (), {}),
(dp.map.Mapper, dp.map.SequenceWrapper(arr), (_fake_fn, (0,), {'test': True}), {}),
]
for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes:
p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
@ -1452,7 +1440,7 @@ class TestFunctionalMapDataPipe(TestCase):
unpicklable_datapipes: List[
Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]]
] = [
(dp.map.Mapper, MDP(arr), (lambda x: x,), {}),
(dp.map.Mapper, dp.map.SequenceWrapper(arr), (lambda x: x,), {}),
]
for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes:
with warnings.catch_warnings(record=True) as wa:
@ -1464,9 +1452,36 @@ class TestFunctionalMapDataPipe(TestCase):
with self.assertRaises(AttributeError):
p = pickle.dumps(datapipe)
def test_sequence_wrapper_datapipe(self):
    source = list(range(10))
    input_dp = dp.map.SequenceWrapper(source)

    # Functional Test: wrapper yields the same elements in the same order
    self.assertEqual(source, list(input_dp))

    # Functional Test: default deepcopy isolates the wrapper from later mutation
    source.append(11)
    self.assertEqual(list(range(10)), list(input_dp))  # input_dp shouldn't have 11

    # Functional Test: with deepcopy=False the wrapper tracks in-place changes
    shared = [1, 2, 3]
    input_dp_non_deep = dp.map.SequenceWrapper(shared, deepcopy=False)
    shared.append(4)
    self.assertEqual(list(shared), list(input_dp_non_deep))  # should have 4

    # Reset Test: consume part of the pipe, reset, then consume everything
    source = list(range(10))
    n_elements_before_reset = 5
    res_before_reset, res_after_reset = reset_after_n_next_calls(input_dp, n_elements_before_reset)
    self.assertEqual(list(range(5)), res_before_reset)
    self.assertEqual(source, res_after_reset)

    # __len__ Test: length is inherited from the wrapped sequence
    self.assertEqual(len(source), len(input_dp))
def test_concat_datapipe(self):
input_dp1 = MDP(range(10))
input_dp2 = MDP(range(5))
input_dp1 = dp.map.SequenceWrapper(range(10))
input_dp2 = dp.map.SequenceWrapper(range(5))
with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
dp.map.Concater()
@ -1482,7 +1497,7 @@ class TestFunctionalMapDataPipe(TestCase):
def test_map_datapipe(self):
arr = range(10)
input_dp = MDP(arr)
input_dp = dp.map.SequenceWrapper(arr)
def fn(item, dtype=torch.float, *, sum=False):
data = torch.tensor(item, dtype=dtype)

View File

@ -1,6 +1,7 @@
# Functional DataPipe
from torch.utils.data.datapipes.map.callable import MapperMapDataPipe as Mapper
from torch.utils.data.datapipes.map.combining import ConcaterMapDataPipe as Concater
from torch.utils.data.datapipes.map.utils import SequenceWrapperMapDataPipe as SequenceWrapper
__all__ = ['Concater', 'Mapper']
__all__ = ['Concater', 'Mapper', 'SequenceWrapper']

View File

@ -0,0 +1,38 @@
import copy
import warnings
from torch.utils.data import MapDataPipe
class SequenceWrapperMapDataPipe(MapDataPipe):
    r""":class:`SequenceWrapperMapDataPipe`.

    Map DataPipe that wraps a sequence object.

    Args:
        sequence: Sequence object to be wrapped into a MapDataPipe
        deepcopy: Option to deepcopy input sequence object

    .. note::
        If `deepcopy` is set to False explicitly, users should ensure
        that data pipeline doesn't contain any in-place operations over
        the iterable instance, in order to prevent data inconsistency
        across iterations.
    """

    def __init__(self, sequence, deepcopy: bool = True):
        if deepcopy:
            try:
                # Copy the input so later in-place mutation of the caller's
                # sequence cannot change what this DataPipe yields.
                self.sequence = copy.deepcopy(sequence)
            except TypeError:
                # Not all sequences are deep-copyable (e.g. ones holding open
                # file handles); fall back to sharing a reference and warn.
                warnings.warn(
                    "The input sequence can not be deepcopied, "
                    "please be aware of in-place modification would affect source data"
                )
                self.sequence = sequence
        else:
            self.sequence = sequence

    def __getitem__(self, index):
        # Delegate indexing straight to the wrapped sequence.
        return self.sequence[index]

    def __len__(self) -> int:
        return len(self.sequence)