mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Torch's list of wrapping datasets has: `TensorDataset`, `ConcatDataset`, `ChainDataset`. `TensorDataset` is useful for stacking sets of tensors but can't work with objects that lack a `.size()` method. This PR proposes `StackDataset`, similar to `TensorDataset` but for the general case, like `ConcatDataset`. Possible uses of `StackDataset` are multimodal networks with different inputs such as image+text, or stacking non-tensor inputs together with the property to predict. Pull Request resolved: https://github.com/pytorch/pytorch/pull/101338 Approved by: https://github.com/ejguan, https://github.com/NivekT
77 lines
1.9 KiB
Python
77 lines
1.9 KiB
Python
# TODO(VitalyFedyunin): Rearranging this imports leads to crash,
|
|
# need to cleanup dependencies and fix it
|
|
from torch.utils.data.sampler import (
|
|
BatchSampler,
|
|
RandomSampler,
|
|
Sampler,
|
|
SequentialSampler,
|
|
SubsetRandomSampler,
|
|
WeightedRandomSampler,
|
|
)
|
|
from torch.utils.data.dataset import (
|
|
ChainDataset,
|
|
ConcatDataset,
|
|
Dataset,
|
|
IterableDataset,
|
|
StackDataset,
|
|
Subset,
|
|
TensorDataset,
|
|
random_split,
|
|
)
|
|
from torch.utils.data.datapipes.datapipe import (
|
|
DFIterDataPipe,
|
|
DataChunk,
|
|
IterDataPipe,
|
|
MapDataPipe,
|
|
)
|
|
from torch.utils.data.dataloader import (
|
|
DataLoader,
|
|
_DatasetKind,
|
|
get_worker_info,
|
|
default_collate,
|
|
default_convert,
|
|
)
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
from torch.utils.data.datapipes._decorator import (
|
|
argument_validation,
|
|
functional_datapipe,
|
|
guaranteed_datapipes_determinism,
|
|
non_deterministic,
|
|
runtime_validation,
|
|
runtime_validation_disabled,
|
|
)
|
|
|
|
__all__ = ['BatchSampler',
|
|
'ChainDataset',
|
|
'ConcatDataset',
|
|
'DFIterDataPipe',
|
|
'DataChunk',
|
|
'DataLoader',
|
|
'Dataset',
|
|
'DistributedSampler',
|
|
'IterDataPipe',
|
|
'IterableDataset',
|
|
'MapDataPipe',
|
|
'RandomSampler',
|
|
'Sampler',
|
|
'SequentialSampler',
|
|
'StackDataset',
|
|
'Subset',
|
|
'SubsetRandomSampler',
|
|
'TensorDataset',
|
|
'WeightedRandomSampler',
|
|
'_DatasetKind',
|
|
'argument_validation',
|
|
'default_collate',
|
|
'default_convert',
|
|
'functional_datapipe',
|
|
'get_worker_info',
|
|
'guaranteed_datapipes_determinism',
|
|
'non_deterministic',
|
|
'random_split',
|
|
'runtime_validation',
|
|
'runtime_validation_disabled']
|
|
|
|
# Please keep this list sorted
|
|
assert __all__ == sorted(__all__)
|