mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[DataPipe] Adding usage examples for IterDataPipes (#73033)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73033 Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D34313793 Pulled By: NivekT fbshipit-source-id: 51125be2f79d73d02658b2b1c2691f96be8d4769
This commit is contained in:
parent
722edfe676
commit
3e3c2df7c6
|
|
@ -33,6 +33,20 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
|
|||
multiple indices, the left-most one is used, and other indices will be removed.
|
||||
- Integer is used for list/tuple. ``-1`` represents appending the result at the end.
|
||||
- Key is used for dict. New key is acceptable.
|
||||
|
||||
Example:
|
||||
>>> from torchdata.datapipes.iter import IterableWrapper, Mapper
|
||||
>>> def add_one(x):
|
||||
... return x + 1
|
||||
>>> dp = IterableWrapper(range(10))
|
||||
>>> map_dp_1 = dp.map(add_one) # Invocation via functional form is preferred
|
||||
>>> list(map_dp_1)
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
>>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
|
||||
>>> # Use `functools.partial` or explicitly define the function instead
|
||||
>>> map_dp_2 = Mapper(dp, lambda x: x + 1)
|
||||
>>> list(map_dp_2)
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
"""
|
||||
datapipe: IterDataPipe
|
||||
fn: Callable
|
||||
|
|
@ -166,7 +180,6 @@ class CollatorIterDataPipe(MapperIterDataPipe):
|
|||
>>> ds = MyIterDataPipe(start=3, end=7)
|
||||
>>> print(list(ds))
|
||||
[3, 4, 5, 6]
|
||||
|
||||
>>> def collate_fn(batch):
|
||||
... return torch.tensor(batch, dtype=torch.float)
|
||||
...
|
||||
|
|
|
|||
|
|
@ -67,6 +67,13 @@ class ShufflerIterDataPipe(IterDataPipe[T_co]):
|
|||
buffer_size: The buffer size for shuffling (default to ``10000``)
|
||||
unbatch_level: Specifies if it is necessary to unbatch source data before
|
||||
applying the shuffle
|
||||
|
||||
Example:
|
||||
>>> from torchdata.datapipes.iter import IterableWrapper
|
||||
>>> dp = IterableWrapper(range(10))
|
||||
>>> shuffle_dp = dp.shuffle()
|
||||
>>> list(shuffle_dp)
|
||||
[0, 4, 1, 6, 3, 2, 9, 5, 7, 8]
|
||||
"""
|
||||
datapipe: IterDataPipe[T_co]
|
||||
buffer_size: int
|
||||
|
|
|
|||
|
|
@ -21,6 +21,14 @@ class ConcaterIterDataPipe(IterDataPipe):
|
|||
|
||||
Args:
|
||||
datapipes: Iterable DataPipes being concatenated
|
||||
|
||||
Example:
|
||||
>>> import random
|
||||
>>> from torchdata.datapipes.iter import IterableWrapper
|
||||
>>> dp1 = IterableWrapper(range(3))
|
||||
>>> dp2 = IterableWrapper(range(5))
|
||||
>>> list(dp1.concat(dp2))
|
||||
[0, 1, 2, 0, 1, 2, 3, 4]
|
||||
"""
|
||||
datapipes: Tuple[IterDataPipe]
|
||||
length: Optional[int]
|
||||
|
|
@ -61,6 +69,15 @@ class ForkerIterDataPipe(IterDataPipe):
|
|||
buffer_size: this restricts how far ahead the leading child DataPipe
|
||||
can read relative to the slowest child DataPipe.
|
||||
Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
|
||||
|
||||
Example:
|
||||
>>> from torchdata.datapipes.iter import IterableWrapper
|
||||
>>> source_dp = IterableWrapper(range(5))
|
||||
>>> dp1, dp2 = source_dp.fork(num_instances=2)
|
||||
>>> list(dp1)
|
||||
[0, 1, 2, 3, 4]
|
||||
>>> list(dp2)
|
||||
[0, 1, 2, 3, 4]
|
||||
"""
|
||||
def __new__(cls, datapipe: IterDataPipe, num_instances: int, buffer_size: int = 1000):
|
||||
if num_instances < 1:
|
||||
|
|
@ -187,6 +204,25 @@ class DemultiplexerIterDataPipe(IterDataPipe):
|
|||
buffer_size: this defines the maximum number of inputs that the buffer can hold across all child
|
||||
DataPipes while waiting for their values to be yielded.
|
||||
Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
|
||||
|
||||
Examples:
|
||||
>>> from torchdata.datapipes.iter import IterableWrapper
|
||||
>>> def odd_or_even(n):
|
||||
... return n % 2
|
||||
>>> source_dp = IterableWrapper(range(5))
|
||||
>>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even)
|
||||
>>> list(dp1)
|
||||
[0, 2, 4]
|
||||
>>> list(dp2)
|
||||
[1, 3]
|
||||
>>> # It can also filter out any element that gets `None` from the `classifier_fn`
|
||||
>>> def odd_or_even_no_zero(n):
|
||||
... return n % 2 if n != 0 else None
|
||||
>>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True)
|
||||
>>> list(dp1)
|
||||
[2, 4]
|
||||
>>> list(dp2)
|
||||
[1, 3]
|
||||
"""
|
||||
def __new__(cls, datapipe: IterDataPipe, num_instances: int,
|
||||
classifier_fn: Callable[[T_co], Optional[int]], drop_none: bool = False, buffer_size: int = 1000):
|
||||
|
|
@ -326,6 +362,12 @@ class MultiplexerIterDataPipe(IterDataPipe):
|
|||
|
||||
Args:
|
||||
datapipes: Iterable DataPipes that will take turn to yield their elements, until they are all exhausted
|
||||
|
||||
Example:
|
||||
>>> from torchdata.datapipes.iter import IterableWrapper
|
||||
>>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
|
||||
>>> list(dp1.mux(dp2, dp3))
|
||||
[0, 10, 20, 1, 11, 21, 2, 12, 22, 3, 13, 23, 4, 14, 24]
|
||||
"""
|
||||
def __init__(self, *datapipes):
|
||||
self.datapipes = datapipes
|
||||
|
|
@ -363,6 +405,12 @@ class ZipperIterDataPipe(IterDataPipe[Tuple[T_co]]):
|
|||
|
||||
Args:
|
||||
*datapipes: Iterable DataPipes being aggregated
|
||||
|
||||
Example:
|
||||
>>> from torchdata.datapipes.iter import IterableWrapper
|
||||
>>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
|
||||
>>> list(dp1.zip(dp2, dp3))
|
||||
[(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
|
||||
"""
|
||||
datapipes: Tuple[IterDataPipe]
|
||||
length: Optional[int]
|
||||
|
|
|
|||
|
|
@ -17,6 +17,12 @@ class FileListerIterDataPipe(IterDataPipe[str]):
|
|||
non_deterministic: Whether to return pathname in sorted order or not.
|
||||
If ``False``, the results yielded from each root directory will be sorted
|
||||
length: Nominal length of the datapipe
|
||||
|
||||
Example:
|
||||
>>> from torchdata.datapipes.iter import FileLister
|
||||
>>> dp = FileLister(root=".", recursive=True)
|
||||
>>> list(dp)
|
||||
['example.py', './data/data.tar']
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -22,6 +22,14 @@ class FileOpenerIterDataPipe(IterDataPipe[Tuple[str, IOBase]]):
|
|||
Note:
|
||||
The opened file handles will be closed by Python's GC periodically. Users can choose
|
||||
to close them explicitly.
|
||||
|
||||
Example:
|
||||
>>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader
|
||||
>>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt'))
|
||||
>>> dp = FileOpener(dp)
|
||||
>>> dp = StreamReader(dp)
|
||||
>>> list(dp)
|
||||
[('./abc.txt', 'abc')]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -58,6 +58,13 @@ class BatcherIterDataPipe(IterDataPipe[DataChunk]):
|
|||
drop_last: Option to drop the last batch if it's not full
|
||||
wrapper_class: wrapper to apply onto each batch (type ``List``) before yielding,
|
||||
defaults to ``DataChunk``
|
||||
|
||||
Example:
|
||||
>>> from torchdata.datapipes.iter import IterableWrapper
|
||||
>>> dp = IterableWrapper(range(10))
|
||||
>>> dp = dp.batch(batch_size=3, drop_last=True)
|
||||
>>> list(dp)
|
||||
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
|
||||
"""
|
||||
datapipe: IterDataPipe
|
||||
batch_size: int
|
||||
|
|
@ -111,6 +118,16 @@ class UnBatcherIterDataPipe(IterDataPipe):
|
|||
datapipe: Iterable DataPipe being un-batched
|
||||
unbatch_level: Defaults to ``1`` (only flattening the top level). If set to ``2``,
|
||||
it will flatten the top two levels, and ``-1`` will flatten the entire DataPipe.
|
||||
|
||||
Example:
|
||||
>>> from torchdata.datapipes.iter import IterableWrapper
|
||||
>>> source_dp = IterableWrapper([[[0, 1], [2]], [[3, 4], [5]], [[6]]])
|
||||
>>> dp1 = source_dp.unbatch()
|
||||
>>> list(dp1)
|
||||
[[0, 1], [2], [3, 4], [5], [6]]
|
||||
>>> dp2 = source_dp.unbatch(unbatch_level=2)
|
||||
>>> list(dp2)
|
||||
[0, 1, 2, 3, 4, 5, 6]
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
|
@ -149,16 +166,42 @@ class UnBatcherIterDataPipe(IterDataPipe):
|
|||
class GrouperIterDataPipe(IterDataPipe[DataChunk]):
|
||||
r"""
|
||||
Groups data from input IterDataPipe by keys which are generated from ``group_key_fn``,
|
||||
and yields a ``DataChunk`` with size ranging from ``guaranteed_group_size``
|
||||
to ``group_size`` (functional name: ``groupby``).
|
||||
and yields a ``DataChunk`` with batch size up to ``group_size`` if defined (functional name: ``groupby``).
|
||||
|
||||
The samples are read sequentially from the source ``datapipe``, and a batch of samples belonging to the same group
|
||||
will be yielded as soon as the size of the batch reaches ``group_size``. When the buffer is full,
|
||||
the DataPipe will yield the largest batch with the same key, provided that its size is larger
|
||||
than ``guaranteed_group_size``. If its size is smaller, it will be dropped if ``drop_remaining=True``.
|
||||
|
||||
After iterating through the entirety of source ``datapipe``, everything not dropped due to the buffer capacity
|
||||
will be yielded from the buffer, even if the group sizes are smaller than ``guaranteed_group_size``.
|
||||
|
||||
Args:
|
||||
datapipe: Iterable datapipe to be grouped
|
||||
group_key_fn: Function used to generate group key from the data of the source datapipe
|
||||
buffer_size: The size of buffer for ungrouped data
|
||||
group_size: The size of each group
|
||||
guaranteed_group_size: The guaranteed minimum group size
|
||||
drop_remaining: Specifies if the group smaller than `guaranteed_group_size` will be dropped from buffer
|
||||
group_size: The max size of each group, a batch is yielded as soon as it reaches this size
|
||||
guaranteed_group_size: The guaranteed minimum group size to be yielded in case the buffer is full
|
||||
drop_remaining: Specifies if the group smaller than ``guaranteed_group_size`` will be dropped from buffer
|
||||
when the buffer is full
|
||||
|
||||
Example:
|
||||
>>> import os
|
||||
>>> from torchdata.datapipes.iter import IterableWrapper
|
||||
>>> def group_fn(file):
|
||||
... return os.path.basename(file).split(".")[0]
|
||||
>>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
|
||||
>>> dp0 = source_dp.groupby(group_key_fn=group_fn)
|
||||
>>> list(dp0)
|
||||
[['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
|
||||
>>> # A group is yielded as soon as its size equals `group_size`
|
||||
>>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2)
|
||||
>>> list(dp1)
|
||||
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
|
||||
>>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size`
|
||||
>>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
|
||||
>>> list(dp2)
|
||||
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
|
||||
"""
|
||||
def __init__(self,
|
||||
datapipe: IterDataPipe[T_co],
|
||||
|
|
|
|||
|
|
@ -28,6 +28,15 @@ class FilterIterDataPipe(IterDataPipe[T_co]):
|
|||
datapipe: Iterable DataPipe being filtered
|
||||
filter_fn: Customized function mapping an element to a boolean.
|
||||
drop_empty_batches: By default, drops a batch if it is empty after filtering instead of keeping an empty list
|
||||
|
||||
Example:
|
||||
>>> from torchdata.datapipes.iter import IterableWrapper
|
||||
>>> def is_even(n):
|
||||
... return n % 2 == 0
|
||||
>>> dp = IterableWrapper(range(5))
|
||||
>>> filter_dp = dp.filter(filter_fn=is_even)
|
||||
>>> list(filter_dp)
|
||||
[0, 2, 4]
|
||||
"""
|
||||
datapipe: IterDataPipe
|
||||
filter_fn: Callable
|
||||
|
|
|
|||
|
|
@ -10,6 +10,13 @@ class StreamReaderIterDataPipe(IterDataPipe[Tuple[str, bytes]]):
|
|||
datapipe: Iterable DataPipe provides label/URL and byte stream
|
||||
chunk: Number of bytes to be read from stream per iteration.
|
||||
If ``None``, all bytes will be read until the EOF.
|
||||
|
||||
Example:
|
||||
>>> from torchdata.datapipes.iter import IterableWrapper, StreamReader
|
||||
>>> from io import StringIO
|
||||
>>> dp = IterableWrapper([("alphabet", StringIO("abcde"))])
|
||||
>>> list(StreamReader(dp, chunk=1))
|
||||
[('alphabet', 'a'), ('alphabet', 'b'), ('alphabet', 'c'), ('alphabet', 'd'), ('alphabet', 'e')]
|
||||
"""
|
||||
def __init__(self, datapipe, chunk=None):
|
||||
self.datapipe = datapipe
|
||||
|
|
|
|||
|
|
@ -13,9 +13,13 @@ class IterableWrapperIterDataPipe(IterDataPipe):
|
|||
iterator. The copy is made when the first element is read in ``iter()``.
|
||||
|
||||
.. note::
|
||||
If ``deepcopy`` is explicitly set to ``False``, users should ensure
|
||||
that the data pipeline doesn't contain any in-place operations over
|
||||
the iterable instance to prevent data inconsistency across iterations.
|
||||
If ``deepcopy`` is explicitly set to ``False``, users should ensure
|
||||
that the data pipeline doesn't contain any in-place operations over
|
||||
the iterable instance to prevent data inconsistency across iterations.
|
||||
|
||||
Example:
|
||||
>>> from torchdata.datapipes.iter import IterableWrapper
|
||||
>>> dp = IterableWrapper(range(10))
|
||||
"""
|
||||
def __init__(self, iterable, deepcopy=True):
|
||||
self.iterable = iterable
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ class MapDataPipe(Dataset[T_co], metaclass=_DataPipeMeta):
|
|||
of :class:`~torch.utils.data.DataLoader`.
|
||||
|
||||
These DataPipes can be invoked in two ways, using the class constructor or applying their
|
||||
functional form onto an existing `MapDataPipe` (available to most but not all DataPipes).
|
||||
functional form onto an existing `MapDataPipe` (recommended, available to most but not all DataPipes).
|
||||
|
||||
Note:
|
||||
:class:`~torch.utils.data.DataLoader` by default constructs an index
|
||||
|
|
@ -97,12 +97,15 @@ class MapDataPipe(Dataset[T_co], metaclass=_DataPipeMeta):
|
|||
Example:
|
||||
>>> from torchdata.datapipes.map import SequenceWrapper, Mapper
|
||||
>>> dp = SequenceWrapper(range(10))
|
||||
>>> map_dp_1 = dp.map(lambda x: x + 1) # Using functional form
|
||||
>>> list(map_dp_1) # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
>>> map_dp_1 = dp.map(lambda x: x + 1) # Using functional form (recommended)
|
||||
>>> list(map_dp_1)
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
>>> map_dp_2 = Mapper(dp, lambda x: x + 1) # Using class constructor
|
||||
>>> list(map_dp_2) # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
>>> list(map_dp_2)
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
>>> batch_dp = map_dp_1.batch(batch_size=2)
|
||||
>>> list(batch_dp) # [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
|
||||
>>> list(batch_dp)
|
||||
[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
|
||||
"""
|
||||
functions: Dict[str, Callable] = {}
|
||||
|
||||
|
|
@ -257,7 +260,7 @@ class IterDataPipe(IterableDataset[T_co], metaclass=_DataPipeMeta):
|
|||
on its iterator.
|
||||
|
||||
These DataPipes can be invoked in two ways, using the class constructor or applying their
|
||||
functional form onto an existing `IterDataPipe` (available to most but not all DataPipes).
|
||||
functional form onto an existing `IterDataPipe` (recommended, available to most but not all DataPipes).
|
||||
You can chain multiple `IterDataPipe` together to form a pipeline that will perform multiple
|
||||
operations in succession.
|
||||
|
||||
|
|
@ -276,11 +279,14 @@ class IterDataPipe(IterableDataset[T_co], metaclass=_DataPipeMeta):
|
|||
>>> from torchdata.datapipes.iter import IterableWrapper, Mapper
|
||||
>>> dp = IterableWrapper(range(10))
|
||||
>>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor
|
||||
>>> map_dp_2 = dp.map(lambda x: x + 1) # Using functional form
|
||||
>>> list(map_dp_1) # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
>>> list(map_dp_2) # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
>>> map_dp_2 = dp.map(lambda x: x + 1) # Using functional form (recommended)
|
||||
>>> list(map_dp_1)
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
>>> list(map_dp_2)
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
>>> filter_dp = map_dp_1.filter(lambda x: x % 2 == 0)
|
||||
>>> list(filter_dp) # [2, 4, 6, 8, 10]
|
||||
>>> list(filter_dp)
|
||||
[2, 4, 6, 8, 10]
|
||||
"""
|
||||
functions: Dict[str, Callable] = {}
|
||||
reduce_ex_hook : Optional[Callable] = None
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user