mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[DataPipe] adding SequenceWrapperMapDataPipe (#66275)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66275

Once this is added to Core, TorchData's PR will not need a custom class and can use this wrapper instead.

cc VitalyFedyunin ejguan NivekT

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D31485822

Pulled By: NivekT

fbshipit-source-id: 790de27629c89c0ca7163a8ee5a09ee8b8233340
This commit is contained in:
parent
a7cc07f109
commit
e808e3d3d6
|
|
@ -112,7 +112,8 @@ def create_temp_dir_and_files():
|
|||
# Then, reset the DataPipe and return a tuple of two lists
|
||||
# 1. A list of elements yielded before the reset
|
||||
# 2. A list of all elements of the DataPipe after the reset
|
||||
def reset_after_n_next_calls(datapipe: IterDataPipe[T_co], n: int) -> Tuple[List[T_co], List[T_co]]:
|
||||
def reset_after_n_next_calls(datapipe: Union[IterDataPipe[T_co], MapDataPipe[T_co]],
|
||||
n: int) -> Tuple[List[T_co], List[T_co]]:
|
||||
it = iter(datapipe)
|
||||
res_before_reset = []
|
||||
for _ in range(n):
|
||||
|
|
@ -644,19 +645,6 @@ class IDP_NoLen(IterDataPipe):
|
|||
yield i
|
||||
|
||||
|
||||
class MDP(MapDataPipe):
|
||||
def __init__(self, input_dp):
|
||||
super().__init__()
|
||||
self.input_dp = input_dp
|
||||
self.length = len(input_dp)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.input_dp[index]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.length
|
||||
|
||||
|
||||
def _fake_fn(data, *args, **kwargs):
|
||||
return data
|
||||
|
||||
|
|
@ -1443,8 +1431,8 @@ class TestFunctionalMapDataPipe(TestCase):
|
|||
picklable_datapipes: List[
|
||||
Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]]
|
||||
] = [
|
||||
(dp.map.Mapper, MDP(arr), (), {}),
|
||||
(dp.map.Mapper, MDP(arr), (_fake_fn, (0,), {'test': True}), {}),
|
||||
(dp.map.Mapper, dp.map.SequenceWrapper(arr), (), {}),
|
||||
(dp.map.Mapper, dp.map.SequenceWrapper(arr), (_fake_fn, (0,), {'test': True}), {}),
|
||||
]
|
||||
for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes:
|
||||
p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
|
||||
|
|
@ -1452,7 +1440,7 @@ class TestFunctionalMapDataPipe(TestCase):
|
|||
unpicklable_datapipes: List[
|
||||
Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]]
|
||||
] = [
|
||||
(dp.map.Mapper, MDP(arr), (lambda x: x,), {}),
|
||||
(dp.map.Mapper, dp.map.SequenceWrapper(arr), (lambda x: x,), {}),
|
||||
]
|
||||
for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes:
|
||||
with warnings.catch_warnings(record=True) as wa:
|
||||
|
|
@ -1464,9 +1452,36 @@ class TestFunctionalMapDataPipe(TestCase):
|
|||
with self.assertRaises(AttributeError):
|
||||
p = pickle.dumps(datapipe)
|
||||
|
||||
def test_sequence_wrapper_datapipe(self):
    """Check SequenceWrapper ordering, copy semantics, reset and length."""
    source = list(range(10))
    wrapped = dp.map.SequenceWrapper(source)

    # Functional Test: elements come back unchanged and in order
    self.assertEqual(source, list(wrapped))

    # Functional Test: the default deepcopy isolates the pipe from later
    # mutation of the caller's list, so the appended 11 must not show up
    source.append(11)
    self.assertEqual(list(range(10)), list(wrapped))

    # Functional Test: with deepcopy=False the pipe shares the caller's
    # list, so the appended 4 must be visible through the pipe
    shared = [1, 2, 3]
    wrapped_shared = dp.map.SequenceWrapper(shared, deepcopy=False)
    shared.append(4)
    self.assertEqual(list(shared), list(wrapped_shared))

    # Reset Test: consume a few elements, then restart the DataPipe
    expected = list(range(10))
    n_elements_before_reset = 5
    before, after = reset_after_n_next_calls(wrapped, n_elements_before_reset)
    self.assertEqual(list(range(5)), before)
    self.assertEqual(expected, after)

    # __len__ Test: length is inherited from the wrapped sequence
    self.assertEqual(len(expected), len(wrapped))
|
||||
|
||||
def test_concat_datapipe(self):
|
||||
input_dp1 = MDP(range(10))
|
||||
input_dp2 = MDP(range(5))
|
||||
input_dp1 = dp.map.SequenceWrapper(range(10))
|
||||
input_dp2 = dp.map.SequenceWrapper(range(5))
|
||||
|
||||
with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
|
||||
dp.map.Concater()
|
||||
|
|
@ -1482,7 +1497,7 @@ class TestFunctionalMapDataPipe(TestCase):
|
|||
|
||||
def test_map_datapipe(self):
|
||||
arr = range(10)
|
||||
input_dp = MDP(arr)
|
||||
input_dp = dp.map.SequenceWrapper(arr)
|
||||
|
||||
def fn(item, dtype=torch.float, *, sum=False):
|
||||
data = torch.tensor(item, dtype=dtype)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
# Functional DataPipe
|
||||
from torch.utils.data.datapipes.map.callable import MapperMapDataPipe as Mapper
|
||||
from torch.utils.data.datapipes.map.combining import ConcaterMapDataPipe as Concater
|
||||
from torch.utils.data.datapipes.map.utils import SequenceWrapperMapDataPipe as SequenceWrapper
|
||||
|
||||
|
||||
__all__ = ['Concater', 'Mapper']
|
||||
__all__ = ['Concater', 'Mapper', 'SequenceWrapper']
|
||||
|
|
|
|||
38
torch/utils/data/datapipes/map/utils.py
Normal file
38
torch/utils/data/datapipes/map/utils.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
import copy
|
||||
import warnings
|
||||
from torch.utils.data import MapDataPipe
|
||||
|
||||
|
||||
class SequenceWrapperMapDataPipe(MapDataPipe):
    r""":class:`SequenceWrapperMapDataPipe`.

    Map DataPipe that wraps a sequence object.

    Indexing and ``len()`` are forwarded directly to the wrapped sequence.

    Args:
        sequence: Sequence object to be wrapped into a MapDataPipe
        deepcopy: Option to deepcopy input sequence object

    .. note::
        If `deepcopy` is set to False explicitly, users should ensure
        that the data pipeline doesn't contain any in-place operations over
        the wrapped sequence, in order to prevent data inconsistency
        across iterations.
    """
    def __init__(self, sequence, deepcopy=True):
        if deepcopy:
            try:
                self.sequence = copy.deepcopy(sequence)
            except TypeError:
                # The sequence cannot be deep-copied: warn and fall back to
                # sharing the caller's object rather than failing outright.
                warnings.warn(
                    "The input sequence can not be deepcopied, "
                    "please be aware of in-place modification would affect source data"
                )
                self.sequence = sequence
        else:
            self.sequence = sequence

    def __getitem__(self, index):
        # Delegates to the wrapped sequence, so index semantics (and any
        # exception raised for a bad index) are those of the sequence itself.
        return self.sequence[index]

    def __len__(self):
        return len(self.sequence)
|
||||
Loading…
Reference in New Issue
Block a user