mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64111 Reviewed By: mruberry Differential Revision: D30639383 Pulled By: ezyang fbshipit-source-id: 96b243307413c99a67d55d862a71937e1ef210f4
374 lines
14 KiB
Python
374 lines
14 KiB
Python
import bisect
|
|
import functools
|
|
import warnings
|
|
from typing import (
|
|
Callable,
|
|
Dict,
|
|
Generic,
|
|
Iterable,
|
|
Iterator,
|
|
List,
|
|
Optional,
|
|
Sequence,
|
|
Tuple,
|
|
TypeVar,
|
|
)
|
|
|
|
# No 'default_generator' in torch/__init__.pyi
|
|
from torch import default_generator, randperm
|
|
from torch._utils import _accumulate
|
|
from torch.utils.data._typing import _DataPipeMeta
|
|
|
|
from ... import Generator, Tensor
|
|
|
|
# Module-wide type variables. ``T_co`` is declared covariant; ``T`` is invariant.
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')


class DataChunk(list, Generic[T]):
    r"""A ``list`` subclass holding one chunk of items.

    The constructor copies ``items`` into the list storage and additionally
    keeps a reference to the original ``items`` object as ``self.items``;
    :meth:`raw_iterator` walks that original object, not the list contents.

    Args:
        items: iterable whose elements make up the chunk.
    """

    def __init__(self, items):
        super().__init__(items)
        self.items = items

    def as_str(self, indent=''):
        r"""Render the chunk as ``[a, b, ...]`` prefixed by ``indent``.

        Elements are obtained through ``iter(self)`` and converted with ``str``.
        """
        body = ", ".join(map(str, iter(self)))
        return indent + "[" + body + "]"

    def __iter__(self) -> Iterator[T]:
        # Lazily walk the underlying list storage.
        yield from super().__iter__()

    def raw_iterator(self):
        r"""Yield each element of the ``items`` object passed to ``__init__``."""
        for element in self.items:
            yield element
|
|
|
|
|
|
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """
    # Registry of extra callables reachable as attributes on instances (see
    # ``__getattr__``); populated by ``register_function`` and
    # ``register_datapipe_as_function``.
    functions: Dict[str, Callable] = {}

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

    def __getattr__(self, attribute_name):
        r"""Look up ``attribute_name`` in the registered ``Dataset.functions``.

        Python only calls this after normal attribute lookup has failed. A
        registered function is returned with this dataset bound as its first
        positional argument (via ``functools.partial``).

        Raises:
            AttributeError: if ``attribute_name`` is not registered.
        """
        if attribute_name in Dataset.functions:
            function = functools.partial(Dataset.functions[attribute_name], self)
            return function
        # Include the type and attribute name so failures are debuggable
        # (a bare ``AttributeError()`` carries no message at all).
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attribute_name))

    @classmethod
    def register_function(cls, function_name, function):
        r"""Store ``function`` under ``function_name`` in ``cls.functions``.

        Any existing entry with the same name is overwritten.
        """
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(cls, function_name, cls_to_register):
        r"""Expose DataPipe class ``cls_to_register`` as ``function_name``.

        Once registered, ``dataset.<function_name>(*args, **kwargs)`` builds
        ``cls_to_register(dataset, *args, **kwargs)``.

        Raises:
            Exception: if ``function_name`` is already present in ``cls.functions``.
        """
        if function_name in cls.functions:
            raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))

        # ``cls`` here shadows the classmethod's ``cls``: it is the DataPipe
        # class bound by the ``partial`` below.
        def class_function(cls, source_dp, *args, **kwargs):
            return cls(source_dp, *args, **kwargs)

        function = functools.partial(class_function, cls_to_register)
        cls.functions[function_name] = function
|
|
|
|
|
|
class IterableDataset(Dataset[T_co], metaclass=_DataPipeMeta):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    Such form of datasets is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]

        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
        [3, 4, 5, 6]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 3, 4, 4, 5, 5, 6, 6]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...

        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
        [3, 4, 5, 6]
    """
    # Registry of extra callables reachable as attributes on IterableDataset
    # instances; separate from ``Dataset.functions``.
    functions: Dict[str, Callable] = {}
    # Optional class-wide hook consulted by ``__reduce_ex__``; managed via
    # ``set_reduce_ex_hook``.
    reduce_ex_hook : Optional[Callable] = None

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])

    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]

    def __getattr__(self, attribute_name):
        r"""Look up ``attribute_name`` in ``IterableDataset.functions``.

        Python only calls this after normal attribute lookup has failed. A
        registered function is returned with this dataset bound as its first
        positional argument (via ``functools.partial``).

        Raises:
            AttributeError: if ``attribute_name`` is not registered.
        """
        if attribute_name in IterableDataset.functions:
            function = functools.partial(IterableDataset.functions[attribute_name], self)
            return function
        # Include the type and attribute name so failures are debuggable
        # (a bare ``AttributeError()`` carries no message at all).
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attribute_name))

    def __reduce_ex__(self, *args, **kwargs):
        # A registered hook gets first chance at reducing the object; a hook
        # that raises NotImplementedError defers to the default protocol.
        if IterableDataset.reduce_ex_hook is not None:
            try:
                return IterableDataset.reduce_ex_hook(self)
            except NotImplementedError:
                pass
        return super().__reduce_ex__(*args, **kwargs)

    @classmethod
    def set_reduce_ex_hook(cls, hook_fn):
        r"""Install ``hook_fn`` as the class-wide ``__reduce_ex__`` hook.

        Passing ``None`` clears the hook.

        Raises:
            Exception: if a hook is already installed and ``hook_fn`` is not ``None``.
        """
        if IterableDataset.reduce_ex_hook is not None and hook_fn is not None:
            raise Exception("Attempt to override existing reduce_ex_hook")
        IterableDataset.reduce_ex_hook = hook_fn
|
|
|
|
|
|
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Map-style dataset backed by one or more tensors.

    Sample ``i`` is the tuple obtained by indexing every wrapped tensor with
    ``i``, so samples are taken along the first dimension.

    Args:
        *tensors (Tensor): tensors whose first-dimension sizes all agree.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        # Each tensor must agree with the first one on dimension 0.
        assert all(tensors[0].size(0) == candidate.size(0) for candidate in tensors), \
            "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        picked = [tensor[index] for tensor in self.tensors]
        return tuple(picked)

    def __len__(self):
        leading = self.tensors[0]
        return leading.size(0)
|
|
|
|
|
|
class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets. A global
    index is mapped to the dataset whose cumulative-size range contains it,
    and to the corresponding local index within that dataset.

    Args:
        datasets (iterable): datasets to be concatenated. Must be non-empty
            and must not contain any :class:`IterableDataset`.
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        r"""Return the running totals of ``len(e)`` for each ``e`` in ``sequence``.

        For example, elements of lengths ``3, 2, 4`` give ``[3, 5, 9]``.
        """
        totals, running = [], 0
        for element in sequence:
            running += len(element)
            totals.append(running)
        return totals

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        # Materialize before measuring: ``datasets`` is annotated as any
        # Iterable, and calling ``len`` on e.g. a generator raises TypeError.
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        r"""Return the sample at global index ``idx`` (negative indices allowed).

        Raises:
            ValueError: if ``idx`` is negative and ``-idx > len(self)``.
            IndexError: if ``idx >= len(self)`` (the dataset lookup goes out of range).
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # First dataset whose cumulative size is strictly greater than idx.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        r"""Deprecated misspelled alias of :attr:`cumulative_sizes`."""
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
|
|
|
|
|
|
class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        # Stored as given; it is re-iterated by every __iter__/__len__ call.
        self.datasets = datasets

    def __iter__(self):
        for member in self.datasets:
            assert isinstance(member, IterableDataset), "ChainDataset only supports IterableDataset"
            for sample in member:
                yield sample

    def __len__(self):
        def _checked_len(member):
            # Validate each member before asking for its length.
            assert isinstance(member, IterableDataset), "ChainDataset only supports IterableDataset"
            return len(member)

        return sum(_checked_len(member) for member in self.datasets)
|
|
|
|
|
|
class Subset(Dataset[T_co]):
    r"""
    View of ``dataset`` restricted to the positions listed in ``indices``.

    Position ``i`` of the subset maps to ``dataset[indices[i]]``.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        if not isinstance(idx, list):
            return self.dataset[self.indices[idx]]
        # A list of subset positions is translated element-wise and forwarded
        # to the underlying dataset as one list-valued lookup.
        translated = [self.indices[position] for position in idx]
        return self.dataset[translated]

    def __len__(self):
        return len(self.indices)
|
|
|
|
|
|
def random_split(dataset: Dataset[T], lengths: Sequence[int],
                 generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.
    Optionally fix the generator for reproducible results, e.g.:

    >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced. Every entry must
            be non-negative and the entries must sum to ``len(dataset)``.
        generator (Generator): Generator used for the random permutation.

    Returns:
        List[Subset]: one :class:`Subset` per entry of ``lengths``, in order.
        Together the subsets cover every index of ``dataset`` exactly once.

    Raises:
        ValueError: if ``lengths`` does not sum to ``len(dataset)``, or if any
            entry of ``lengths`` is negative.
    """
    total = sum(lengths)
    # Cannot verify that dataset is Sized
    if total != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
    # A negative entry can still sum correctly, but the slices below would then
    # have sizes different from the request and may overlap (e.g. [5, -2, 7]).
    if any(length < 0 for length in lengths):
        raise ValueError("All entries of lengths must be non-negative, got {}".format(list(lengths)))

    indices = randperm(total, generator=generator).tolist()
    # _accumulate yields running end offsets; each split takes
    # indices[end - length : end].
    return [Subset(dataset, indices[offset - length : offset])
            for offset, length in zip(_accumulate(lengths), lengths)]
|