pytorch/torch/utils/data/__init__.py
erjia 365ce350cb Make ShufflerDataPipe deterministic for SP & MP DataLoader (#77741)
This is the first PR to make DataPipe deterministic.

Users should be able to use `torch.manual_seed(seed)` to control the shuffle order for the following cases:
- Directly over `DataPipe`
- For single-process DataLoader
- Multiprocessing DataLoader

Unfortunately, for distributed training, users still have to run `apply_shuffle_seed` manually to make sure all distributed processes use the same shuffle order.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77741
Approved by: https://github.com/VitalyFedyunin, https://github.com/NivekT
2022-05-18 23:32:07 +00:00

79 lines
2.0 KiB
Python

# TODO(VitalyFedyunin): Rearranging these imports leads to a crash;
# we need to clean up the dependencies and fix it.
from torch.utils.data.sampler import (
BatchSampler,
RandomSampler,
Sampler,
SequentialSampler,
SubsetRandomSampler,
WeightedRandomSampler,
)
from torch.utils.data.dataset import (
ChainDataset,
ConcatDataset,
Dataset,
IterableDataset,
Subset,
TensorDataset,
random_split,
)
from torch.utils.data.datapipes.datapipe import (
DFIterDataPipe,
DataChunk,
IterDataPipe,
MapDataPipe,
)
from torch.utils.data.dataloader import (
DataLoader,
_DatasetKind,
get_worker_info,
default_collate,
default_convert,
)
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.datapipes._decorator import (
argument_validation,
functional_datapipe,
guaranteed_datapipes_determinism,
non_deterministic,
runtime_validation,
runtime_validation_disabled,
)
from torch.utils.data.dataloader_experimental import DataLoader2
from torch.utils.data import communication
# Public names re-exported by this package (the set bound by a star-import).
# Please keep this list sorted in plain ``sorted()`` order, i.e. by code point:
# capitalized names first, then ``_``-prefixed names, then lowercase names.
# The assertion below checks this at import time (unless run with ``-O``).
__all__ = [
    'BatchSampler',
    'ChainDataset',
    'ConcatDataset',
    'DFIterDataPipe',
    'DataChunk',
    'DataLoader',
    'DataLoader2',
    'Dataset',
    'DistributedSampler',
    'IterDataPipe',
    'IterableDataset',
    'MapDataPipe',
    'RandomSampler',
    'Sampler',
    'SequentialSampler',
    'Subset',
    'SubsetRandomSampler',
    'TensorDataset',
    'WeightedRandomSampler',
    '_DatasetKind',
    'argument_validation',
    'communication',
    'default_collate',
    'default_convert',
    'functional_datapipe',
    'get_worker_info',
    'guaranteed_datapipes_determinism',
    'non_deterministic',
    'random_split',
    'runtime_validation',
    'runtime_validation_disabled',
]
assert sorted(__all__) == __all__