mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45290 Test Plan: Imported from OSS Reviewed By: gchanan Differential Revision: D24001084 Pulled By: erjia-guan fbshipit-source-id: d8a7455cf3f18e1f8c1edc53c42c1a99c8573c51
355 lines
14 KiB
Python
355 lines
14 KiB
Python
import bisect
|
|
import random
|
|
import warnings
|
|
|
|
from torch._utils import _accumulate
|
|
from torch import randperm
|
|
# No 'default_generator' in torch/__init__.pyi
|
|
from torch import default_generator # type: ignore
|
|
from typing import TypeVar, Generic, Iterable, Iterator, Sequence, List, Optional, Tuple
|
|
from ... import Tensor, Generator
|
|
|
|
# Invariant type variable; ``random_split`` uses it to tie the sample type of
# its input dataset to the sample type of the returned subsets.
T = TypeVar("T")
# Covariant type variable for the sample type a dataset yields, so that a
# ``Dataset[Sub]`` is accepted where a ``Dataset[Base]`` is expected.
T_co = TypeVar("T_co", covariant=True)
|
|
|
|
|
|
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    Every dataset that maps keys to data samples should subclass this class.
    Subclasses must override :meth:`__getitem__` so that a data sample can be
    fetched for a given key. Subclasses may also override :meth:`__len__`;
    many :class:`~torch.utils.data.Sampler` implementations and the default
    options of :class:`~torch.utils.data.DataLoader` expect it to return the
    size of the dataset.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        # Abstract hook: concrete map-style datasets provide the lookup.
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        # ``a + b`` builds a ConcatDataset holding ``a`` followed by ``b``.
        parts = [self, other]
        return ConcatDataset(parts)

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
|
|
|
|
|
|
class IterableDataset(Dataset[T_co]):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    Such form of datasets is particularly useful when data come from a stream.

    All subclasses should override :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> import math
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super(MyIterableDataset, self).__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]

        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
        [3, 4, 5, 6]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> import math
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super(MyIterableDataset, self).__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 3, 4, 4, 5, 5, 6, 6]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...

        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
        [3, 4, 5, 6]
    """

    def __iter__(self) -> Iterator[T_co]:
        # Abstract hook: concrete iterable-style datasets provide the iterator.
        raise NotImplementedError

    def __add__(self, other: Dataset[T_co]) -> 'ChainDataset':
        # ``a + b`` chains two iterable datasets lazily (see ChainDataset).
        return ChainDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
|
|
|
|
|
|
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension:
    ``dataset[i]`` is the tuple ``(t[i] for t in tensors)``.

    Arguments:
        *tensors (Tensor): tensors that have the same size in the first dimension.

    Note:
        Constructing with no tensors succeeds, but ``len()`` then raises
        ``IndexError`` because the length is read from the first tensor.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        # Every tensor must contribute the same number of samples (size of dim 0);
        # the message makes a mismatch diagnosable instead of a bare AssertionError.
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), \
            "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        # Index every wrapped tensor along its first dimension.
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        # All tensors share the same first-dimension size (checked in __init__).
        return self.tensors[0].size(0)
|
|
|
|
|
|
class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Arguments:
        datasets (iterable): Datasets to be concatenated. Any iterable of
            map-style datasets is accepted (it is materialized into a list);
            it must be non-empty and must not contain an IterableDataset.
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        # Running totals of the element lengths: entry i is the number of samples
        # in sequence[0..i], so the last entry is the overall length.
        r, s = [], 0
        for e in sequence:
            length = len(e)
            r.append(length + s)
            s += length
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__()
        # Materialize before checking the size: ``len()`` on a generator (which the
        # ``Iterable`` annotation allows) would raise TypeError.
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Support Python-style negative indices relative to the total length.
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # Find the first dataset whose cumulative size exceeds idx.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            # Offset idx by the total size of all preceding datasets.
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        # Deprecated misspelled alias kept for backward compatibility.
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
|
|
|
|
|
|
class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chaining happens on the fly while iterating, so combining large-scale
    datasets with this class stays efficient.

    Arguments:
        datasets (iterable of IterableDataset): datasets to be chained together
    """
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = datasets

    def __iter__(self):
        # Yield every sample of each member dataset, one dataset after another.
        for member in self.datasets:
            assert isinstance(member, IterableDataset), "ChainDataset only supports IterableDataset"
            yield from member

    def __len__(self):
        # Sum of the member lengths; each member is type-checked before len() is taken.
        size = 0
        for member in self.datasets:
            assert isinstance(member, IterableDataset), "ChainDataset only supports IterableDataset"
            # Cannot verify that all self.datasets are Sized
            size += len(member)  # type: ignore
        return size
|
|
|
|
|
|
class BufferedShuffleDataset(IterableDataset[T_co]):
    r"""Dataset shuffled from the original dataset.

    This class is useful to shuffle an existing instance of an IterableDataset.
    A buffer of `buffer_size` items is first filled from the dataset. After
    that, every incoming item replaces a uniformly chosen buffered item, and the
    displaced item is yielded. Once the dataset is exhausted, the remaining
    buffer is shuffled and drained.

    `buffer_size` is required to be larger than 0. For `buffer_size == 1`, the
    dataset is not shuffled. In order to fully shuffle the whole dataset, `buffer_size`
    is required to be greater than or equal to the size of dataset.

    When it is used with :class:`~torch.utils.data.DataLoader`, each item in the
    dataset will be yielded from the :class:`~torch.utils.data.DataLoader` iterator.
    The way to set up a random seed depends on :attr:`num_workers`, because the
    shuffling draws from Python's global :mod:`random` module.

    For single-process mode (:attr:`num_workers == 0`), the random seed is required to
    be set before the :class:`~torch.utils.data.DataLoader` in the main process.

        >>> ds = BufferedShuffleDataset(dataset)
        >>> random.seed(...)
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))

    For multi-process mode (:attr:`num_workers > 0`), the random seed is set by a callable
    function in each worker.

        >>> ds = BufferedShuffleDataset(dataset)
        >>> def init_fn(worker_id):
        ...     random.seed(...)
        >>> print(list(torch.utils.data.DataLoader(ds, ..., num_workers=n, worker_init_fn=init_fn)))

    Arguments:
        dataset (IterableDataset): The original IterableDataset.
        buffer_size (int): The buffer size for shuffling.
    """
    dataset: IterableDataset[T_co]
    buffer_size: int

    def __init__(self, dataset: IterableDataset[T_co], buffer_size: int) -> None:
        super().__init__()
        assert buffer_size > 0, "buffer_size should be larger than 0"
        self.dataset = dataset
        self.buffer_size = buffer_size

    def __iter__(self) -> Iterator[T_co]:
        pool: List[T_co] = []
        for item in self.dataset:
            if len(pool) != self.buffer_size:
                # Still filling the buffer: nothing is emitted yet.
                pool.append(item)
                continue
            # Buffer is full: emit a random slot and refill it with the new item.
            slot = random.randint(0, self.buffer_size - 1)
            yield pool[slot]
            pool[slot] = item
        # Source exhausted: drain what is left in random order.
        random.shuffle(pool)
        while pool:
            yield pool.pop()
|
|
|
|
|
|
class Subset(Dataset[T_co]):
    r"""
    View of a dataset restricted to the given indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset, self.indices = dataset, indices

    def __getitem__(self, idx):
        # Translate the subset-local position into an index of the parent dataset.
        parent_index = self.indices[idx]
        return self.dataset[parent_index]

    def __len__(self):
        # One sample per selected index.
        return len(self.indices)
|
|
|
|
|
|
def random_split(dataset: Dataset[T], lengths: Sequence[int],
                 generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.
    Optionally fix the generator for reproducible results, e.g.:

    >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced; each must be
            non-negative and together they must sum to ``len(dataset)``
        generator (Generator): Generator used for the random permutation.

    Returns:
        A list of :class:`Subset` objects, one per entry of ``lengths``, in order.

    Raises:
        ValueError: if any length is negative, or if the lengths do not sum to
            ``len(dataset)``.
    """
    # Reject negative lengths up front: they can still sum to len(dataset), but
    # the offset slicing below would then silently produce splits whose sizes
    # differ from the requested ones (a negative slice bound counts from the end).
    if any(length < 0 for length in lengths):
        raise ValueError(f"Lengths of splits must be non-negative, got {list(lengths)}")
    total = sum(lengths)
    # Cannot verify that dataset is Sized
    if total != len(dataset):  # type: ignore
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(total, generator=generator).tolist()
    # ``_accumulate`` presumably yields running totals (like itertools.accumulate),
    # so each ``offset`` is the end of a split and ``offset - length`` its start.
    return [Subset(dataset, indices[offset - length : offset])
            for offset, length in zip(_accumulate(lengths), lengths)]
|