mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[DataPipe] adding SequenceWrapperMapDataPipe (#66275)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66275 Once this is added to Core, TorchData's PR will not need a custom class and can use this wrapper instead. cc VitalyFedyunin ejguan NivekT Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D31485822 Pulled By: NivekT fbshipit-source-id: 790de27629c89c0ca7163a8ee5a09ee8b8233340
This commit is contained in:
parent
a7cc07f109
commit
e808e3d3d6
|
|
@ -112,7 +112,8 @@ def create_temp_dir_and_files():
|
||||||
# Then, reset the DataPipe and return a tuple of two lists
|
# Then, reset the DataPipe and return a tuple of two lists
|
||||||
# 1. A list of elements yielded before the reset
|
# 1. A list of elements yielded before the reset
|
||||||
# 2. A list of all elements of the DataPipe after the reset
|
# 2. A list of all elements of the DataPipe after the reset
|
||||||
def reset_after_n_next_calls(datapipe: IterDataPipe[T_co], n: int) -> Tuple[List[T_co], List[T_co]]:
|
def reset_after_n_next_calls(datapipe: Union[IterDataPipe[T_co], MapDataPipe[T_co]],
|
||||||
|
n: int) -> Tuple[List[T_co], List[T_co]]:
|
||||||
it = iter(datapipe)
|
it = iter(datapipe)
|
||||||
res_before_reset = []
|
res_before_reset = []
|
||||||
for _ in range(n):
|
for _ in range(n):
|
||||||
|
|
@ -644,19 +645,6 @@ class IDP_NoLen(IterDataPipe):
|
||||||
yield i
|
yield i
|
||||||
|
|
||||||
|
|
||||||
class MDP(MapDataPipe):
    """Minimal MapDataPipe test helper wrapping any indexable container.

    The wrapped object must support ``__getitem__`` and ``__len__``; its
    length is captured once at construction time and reported thereafter.
    """

    def __init__(self, input_dp):
        super().__init__()
        # Hold a direct reference (no copy) and cache the length up front.
        self.input_dp = input_dp
        self.length = len(input_dp)

    def __getitem__(self, index):
        # Delegate indexing straight to the wrapped container.
        return self.input_dp[index]

    def __len__(self) -> int:
        # Length recorded at construction, not the wrapped object's live length.
        return self.length
|
|
||||||
|
|
||||||
|
|
||||||
def _fake_fn(data, *args, **kwargs):
|
def _fake_fn(data, *args, **kwargs):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
@ -1443,8 +1431,8 @@ class TestFunctionalMapDataPipe(TestCase):
|
||||||
picklable_datapipes: List[
|
picklable_datapipes: List[
|
||||||
Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]]
|
Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]]
|
||||||
] = [
|
] = [
|
||||||
(dp.map.Mapper, MDP(arr), (), {}),
|
(dp.map.Mapper, dp.map.SequenceWrapper(arr), (), {}),
|
||||||
(dp.map.Mapper, MDP(arr), (_fake_fn, (0,), {'test': True}), {}),
|
(dp.map.Mapper, dp.map.SequenceWrapper(arr), (_fake_fn, (0,), {'test': True}), {}),
|
||||||
]
|
]
|
||||||
for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes:
|
for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes:
|
||||||
p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
|
p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
|
||||||
|
|
@ -1452,7 +1440,7 @@ class TestFunctionalMapDataPipe(TestCase):
|
||||||
unpicklable_datapipes: List[
|
unpicklable_datapipes: List[
|
||||||
Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]]
|
Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]]
|
||||||
] = [
|
] = [
|
||||||
(dp.map.Mapper, MDP(arr), (lambda x: x,), {}),
|
(dp.map.Mapper, dp.map.SequenceWrapper(arr), (lambda x: x,), {}),
|
||||||
]
|
]
|
||||||
for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes:
|
for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes:
|
||||||
with warnings.catch_warnings(record=True) as wa:
|
with warnings.catch_warnings(record=True) as wa:
|
||||||
|
|
@ -1464,9 +1452,36 @@ class TestFunctionalMapDataPipe(TestCase):
|
||||||
with self.assertRaises(AttributeError):
|
with self.assertRaises(AttributeError):
|
||||||
p = pickle.dumps(datapipe)
|
p = pickle.dumps(datapipe)
|
||||||
|
|
||||||
|
def test_sequence_wrapper_datapipe(self):
    """SequenceWrapper: element order, deepcopy semantics, reset, and length."""
    source = list(range(10))
    input_dp = dp.map.SequenceWrapper(source)

    # Functional Test: all elements come back equal and in the same order.
    self.assertEqual(source, list(input_dp))

    # Functional Test: default deepcopy — mutating the source afterwards
    # must not leak into the wrapper.
    source.append(11)
    self.assertEqual(list(range(10)), list(input_dp))  # input_dp shouldn't have 11

    # Functional Test: deepcopy=False shares the caller's list, so later
    # mutations are visible through the wrapper.
    shared = [1, 2, 3]
    input_dp_non_deep = dp.map.SequenceWrapper(shared, deepcopy=False)
    shared.append(4)
    self.assertEqual(list(shared), list(input_dp_non_deep))  # should have 4

    # Reset Test: iteration restarts from the beginning after a reset.
    source = list(range(10))
    n_elements_before_reset = 5
    res_before_reset, res_after_reset = reset_after_n_next_calls(input_dp, n_elements_before_reset)
    self.assertEqual(list(range(5)), res_before_reset)
    self.assertEqual(source, res_after_reset)

    # __len__ Test: length is inherited from the wrapped sequence.
    self.assertEqual(len(source), len(input_dp))
|
||||||
|
|
||||||
def test_concat_datapipe(self):
|
def test_concat_datapipe(self):
|
||||||
input_dp1 = MDP(range(10))
|
input_dp1 = dp.map.SequenceWrapper(range(10))
|
||||||
input_dp2 = MDP(range(5))
|
input_dp2 = dp.map.SequenceWrapper(range(5))
|
||||||
|
|
||||||
with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
|
with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
|
||||||
dp.map.Concater()
|
dp.map.Concater()
|
||||||
|
|
@ -1482,7 +1497,7 @@ class TestFunctionalMapDataPipe(TestCase):
|
||||||
|
|
||||||
def test_map_datapipe(self):
|
def test_map_datapipe(self):
|
||||||
arr = range(10)
|
arr = range(10)
|
||||||
input_dp = MDP(arr)
|
input_dp = dp.map.SequenceWrapper(arr)
|
||||||
|
|
||||||
def fn(item, dtype=torch.float, *, sum=False):
|
def fn(item, dtype=torch.float, *, sum=False):
|
||||||
data = torch.tensor(item, dtype=dtype)
|
data = torch.tensor(item, dtype=dtype)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
# Functional DataPipe
|
# Functional DataPipe
|
||||||
from torch.utils.data.datapipes.map.callable import MapperMapDataPipe as Mapper
|
from torch.utils.data.datapipes.map.callable import MapperMapDataPipe as Mapper
|
||||||
from torch.utils.data.datapipes.map.combining import ConcaterMapDataPipe as Concater
|
from torch.utils.data.datapipes.map.combining import ConcaterMapDataPipe as Concater
|
||||||
|
from torch.utils.data.datapipes.map.utils import SequenceWrapperMapDataPipe as SequenceWrapper
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['Concater', 'Mapper']
|
__all__ = ['Concater', 'Mapper', 'SequenceWrapper']
|
||||||
|
|
|
||||||
38
torch/utils/data/datapipes/map/utils.py
Normal file
38
torch/utils/data/datapipes/map/utils.py
Normal file
|
|
@ -0,0 +1,38 @@
|
||||||
|
import copy
import warnings
from torch.utils.data import MapDataPipe


class SequenceWrapperMapDataPipe(MapDataPipe):
    r""":class:`SequenceWrapperMapDataPipe`.

    Map DataPipe that wraps a sequence object.

    Args:
        sequence: Sequence object to be wrapped into a MapDataPipe
        deepcopy: Option to deepcopy input sequence object

    .. note::
        If `deepcopy` is set to False explicitly, users should ensure
        that data pipeline doesn't contain any in-place operations over
        the iterable instance, in order to prevent data inconsistency
        across iterations.
    """
    def __init__(self, sequence, deepcopy=True):
        if deepcopy:
            try:
                # Snapshot the input so later caller-side mutations don't leak in.
                self.sequence = copy.deepcopy(sequence)
            except TypeError:
                # Some objects cannot be deepcopied (copy.deepcopy raises
                # TypeError); fall back to sharing the caller's object and warn.
                warnings.warn(
                    "The input sequence can not be deepcopied, "
                    "please be aware of in-place modification would affect source data"
                )
                self.sequence = sequence
        else:
            self.sequence = sequence

    def __getitem__(self, index):
        # Delegate indexing (and any resulting exceptions) to the wrapped sequence.
        return self.sequence[index]

    def __len__(self) -> int:
        return len(self.sequence)
|
||||||
Loading…
Reference in New Issue
Block a user