import bisect
import warnings

from torch._utils import _accumulate
from torch import randperm

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, which provides the size of the dataset, and ``__getitem__``,
    which supports integer indexing in the range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

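# A minimal sketch (not part of the original module) of a custom Dataset
# subclass, showing the ``__len__``/``__getitem__`` contract described in the
# docstring above. The class name and the list-backed storage are illustrative
# assumptions, not an API defined by this file.
class _ExampleListDataset(Dataset):

    def __init__(self, samples):
        # ``samples`` is assumed to be any sequence of items (e.g. a list).
        self.samples = samples

    def __getitem__(self, index):
        # Integer indexing in range 0 .. len(self) - 1, per the contract.
        return self.samples[index]

    def __len__(self):
        return len(self.samples)
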
class TensorDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size in the first dimension.
    """

    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)

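# A hypothetical usage sketch (not part of the original module): wrapping a
# feature tensor and a label tensor so that indexing returns aligned pairs.
# The tensor shapes below are arbitrary example values.
def _tensor_dataset_example():
    import torch

    features = torch.randn(100, 10)    # 100 samples, 10 features each
    labels = torch.randint(0, 2, (100,))
    dataset = TensorDataset(features, labels)

    x, y = dataset[3]                  # sample 3: (features[3], labels[3])
    return len(dataset), x.size(), y
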
class ConcatDataset(Dataset):
    """
    Dataset to concatenate multiple datasets.
    Purpose: useful to assemble different existing datasets, possibly
    large-scale datasets, as the concatenation operation is done in an
    on-the-fly manner.

    Arguments:
        datasets (sequence): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # ``cumulative_sizes`` holds the running totals of dataset lengths, so
        # bisect_right finds the first dataset whose cumulative size exceeds idx.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            # Convert the global index into an index local to the selected dataset.
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes

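# A hypothetical sketch (not part of the original module) of concatenating two
# TensorDatasets and indexing across the boundary between them. Sizes are
# arbitrary example values.
def _concat_dataset_example():
    import torch

    first = TensorDataset(torch.zeros(3, 2))   # global indices 0..2
    second = TensorDataset(torch.ones(5, 2))   # global indices 3..7
    combined = ConcatDataset([first, second])  # equivalently: first + second

    assert len(combined) == 8
    # Index 4 falls past the first dataset, so it maps to second[4 - 3].
    (sample,) = combined[4]
    return sample
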
class Subset(Dataset):
    """
    Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

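# A hypothetical sketch (not part of the original module): viewing every other
# element of a dataset as a Subset, without copying any underlying data.
def _subset_example():
    import torch

    base = TensorDataset(torch.arange(10).view(10, 1))
    evens = Subset(base, [0, 2, 4, 6, 8])

    assert len(evens) == 5
    # evens[1] re-indexes into the base dataset at position 2.
    return evens[1]
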
def random_split(dataset, lengths):
    """
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    # Shuffle all indices once, then carve out consecutive slices of the
    # requested lengths; _accumulate yields the running end offsets.
    indices = randperm(sum(lengths))
    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
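
# A hypothetical sketch (not part of the original module): an 80/20 split of a
# 10-sample dataset. The sizes are arbitrary example values; the two subsets
# are non-overlapping and together cover every sample exactly once.
def _random_split_example():
    import torch

    dataset = TensorDataset(torch.randn(10, 3))
    train_set, val_set = random_split(dataset, [8, 2])
    return len(train_set), len(val_set)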