type annotations for dataloader, dataset, sampler (#39392)
Summary: Fixes https://github.com/pytorch/pytorch/issues/38913

Pull Request resolved: https://github.com/pytorch/pytorch/pull/39392
Reviewed By: anjali411
Differential Revision: D22102489
Pulled By: zou3519
fbshipit-source-id: acb68d9521145f0b047214d62b5bdc5a0d1b9be4
This commit is contained in:
parent a6b703cc89
commit 0e09511af9

mypy.ini | 3
@@ -302,9 +302,6 @@ ignore_errors = True
 [mypy-torch.utils.data._utils.worker]
 ignore_errors = True
 
-[mypy-torch.utils.data.dataset]
-ignore_errors = True
-
 [mypy-torch.utils.data.distributed]
 ignore_errors = True
@@ -821,6 +821,16 @@ class TestDataLoader(TestCase):
         with self.assertRaisesRegex(RuntimeError, 'Error in worker_init_fn'):
             list(iter(loader))
 
+    def test_typing(self):
+        from typing import List
+        # Make sure there is no TypeError
+
+        class SomeDatasetClass(Dataset[List[torch.Tensor]]):
+            pass
+
+        def _create_dataloader(is_train: bool) -> DataLoader[List[torch.Tensor]]:
+            pass
+
     @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
     @unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
     def test_fd_limit_exceeded(self):

@@ -2019,5 +2029,6 @@ class TestSetAffinity(TestCase):
         self.assertEqual(sample, [2])
 
+
 if __name__ == '__main__':
     run_tests()
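For illustration only (not part of this commit), a sketch of the kind of user code the new generic parameters are meant to type-check; the class and function names are hypothetical, and it assumes a PyTorch build that contains this change:

import torch
from typing import List
from torch.utils.data import Dataset, DataLoader

class TensorListDataset(Dataset[List[torch.Tensor]]):
    # A toy map-style dataset whose samples are lists of tensors.
    def __init__(self, n: int) -> None:
        self.n = n

    def __getitem__(self, index: int) -> List[torch.Tensor]:
        return [torch.zeros(3), torch.ones(3)]

    def __len__(self) -> int:
        return self.n

def make_loader(batch_size: int) -> DataLoader[List[torch.Tensor]]:
    # mypy can now check that the DataLoader's element type matches the Dataset's.
    return DataLoader(TensorListDataset(8), batch_size=batch_size)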
@@ -8,6 +8,7 @@ in `./_utils/worker.py`.
 import threading
 import itertools
 import warnings
+from typing import Any, Callable, TypeVar, Generic, Sequence, List, Optional
 
 import multiprocessing as python_multiprocessing
 import torch
@@ -15,19 +16,27 @@ import torch.multiprocessing as multiprocessing
 from torch._utils import ExceptionWrapper
 from torch._six import queue, string_classes
 
-from . import IterableDataset, Sampler, SequentialSampler, RandomSampler, BatchSampler
+from . import IterableDataset, Sampler, SequentialSampler, RandomSampler, BatchSampler, Dataset
 from . import _utils
 
+T_co = TypeVar('T_co', covariant=True)
+T = TypeVar('T')
+_worker_init_fn_t = Callable[[int], None]
+
+# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
+# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
+# See https://github.com/python/mypy/issues/3737.
+_collate_fn_t = Callable[[List[T]], Any]
-get_worker_info = _utils.worker.get_worker_info
 
 # This function used to be defined in this file. However, it was moved to
 # _utils/collate.py. Although it is rather hard to access this from user land
 # (one has to explicitly directly `import torch.utils.data.dataloader`), there
 # probably is user code out there using it. This aliasing maintains BC in this
 # aspect.
-default_collate = _utils.collate.default_collate
+default_collate: _collate_fn_t = _utils.collate.default_collate
+
+get_worker_info = _utils.worker.get_worker_info
 
 class _DatasetKind(object):
     Map = 0
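As a hedged illustration of the `_collate_fn_t` alias introduced above (the function name is hypothetical, not part of the commit): a custom `collate_fn` takes the list of samples drawn for one batch and may return anything, which is why the alias is `Callable[[List[T]], Any]`:

from typing import Any, List
import torch

def pad_collate(batch: List[torch.Tensor]) -> Any:
    # Pad variable-length 1-D tensors to the longest one in the batch,
    # then stack them into a single (batch, max_len) tensor.
    max_len = max(t.size(0) for t in batch)
    padded = [torch.nn.functional.pad(t, (0, max_len - t.size(0))) for t in batch]
    return torch.stack(padded)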
@@ -57,7 +66,7 @@ class _InfiniteConstantSampler(Sampler):
         yield None
 
 
-class DataLoader(object):
+class DataLoader(Generic[T_co]):
     r"""
     Data loader. Combines a dataset and a sampler, and provides an iterable over
     the given dataset.
@@ -116,15 +125,24 @@ class DataLoader(object):
         details on these two types of datasets and how
         :class:`~torch.utils.data.IterableDataset` interacts with `Multi-process data loading`_.
     """
+    dataset: Dataset[T_co]
+    batch_size: Optional[int]
+    num_workers: int
+    pin_memory: bool
+    drop_last: bool
+    timeout: float
+    sampler: Sampler
 
     __initialized = False
 
-    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
-                 batch_sampler=None, num_workers=0, collate_fn=None,
-                 pin_memory=False, drop_last=False, timeout=0,
-                 worker_init_fn=None, multiprocessing_context=None,
-                 generator=None):
-        torch._C._log_api_usage_once("python.data_loader")
+    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
+                 shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
+                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
+                 num_workers: int = 0, collate_fn: _collate_fn_t = None,
+                 pin_memory: bool = False, drop_last: bool = False,
+                 timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
+                 multiprocessing_context=None, generator=None):
+        torch._C._log_api_usage_once("python.data_loader")  # type: ignore
 
         if num_workers < 0:
             raise ValueError('num_workers option should be non-negative; '
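A small usage sketch of the annotated constructor (illustrative only; `TensorDataset` stands in for any map-style dataset, and a PyTorch build containing this change is assumed):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10).float().unsqueeze(1))
loader = DataLoader(ds,
                    batch_size=4,    # Optional[int]
                    shuffle=True,    # bool
                    num_workers=0,   # int
                    drop_last=False, # bool
                    timeout=0)       # float
for batch, in loader:
    print(batch.shape)  # torch.Size([4, 1]) twice, then torch.Size([2, 1])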
@@ -146,7 +164,7 @@ class DataLoader(object):
         #    after spending time fixing the custom sampler errors.
         if isinstance(dataset, IterableDataset):
             self._dataset_kind = _DatasetKind.Iterable
-            # NOTE [ Custom Samplers and `IterableDataset` ]
+            # NOTE [ Custom Samplers and IterableDataset ]
             #
             # `IterableDataset` does not support custom `batch_sampler` or
             # `sampler` since the key is irrelevant (unless we support
@@ -212,7 +230,9 @@ class DataLoader(object):
                 sampler = _InfiniteConstantSampler()
             else:  # map-style
                 if shuffle:
-                    sampler = RandomSampler(dataset, generator=generator)
+                    # Cannot statically verify that dataset is Sized
+                    # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
+                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore
                 else:
                     sampler = SequentialSampler(dataset)
 
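A minimal sketch of what the branch above implies for user code (illustrative only, not part of the commit):

import torch
from torch.utils.data import DataLoader, TensorDataset, RandomSampler, SequentialSampler

ds = TensorDataset(torch.arange(6))
g = torch.Generator().manual_seed(0)

# shuffle=True makes the loader use RandomSampler(dataset, generator=generator) internally.
shuffled = DataLoader(ds, batch_size=2, sampler=RandomSampler(ds, generator=g))
# shuffle=False (the default) corresponds to SequentialSampler(dataset).
ordered = DataLoader(ds, batch_size=2, sampler=SequentialSampler(ds))
print([batch.tolist() for batch, in ordered])  # [[0, 1], [2, 3], [4, 5]]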
@@ -253,9 +273,10 @@ class DataLoader(object):
                     if multiprocessing_context not in valid_start_methods:
                         raise ValueError(
                             ('multiprocessing_context option '
-                             'should specify a valid start method in {}, but got '
-                             'multiprocessing_context={}').format(valid_start_methods, multiprocessing_context))
-                    multiprocessing_context = multiprocessing.get_context(multiprocessing_context)
+                             'should specify a valid start method in {!r}, but got '
+                             'multiprocessing_context={!r}').format(valid_start_methods, multiprocessing_context))
+                    # error: Argument 1 to "get_context" has incompatible type "Union[str, bytes]"; expected "str"  [arg-type]
+                    multiprocessing_context = multiprocessing.get_context(multiprocessing_context)  # type: ignore
 
                 if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
                     raise TypeError(('multiprocessing_context option should be a valid context '
@@ -275,7 +296,9 @@ class DataLoader(object):
 
         super(DataLoader, self).__setattr__(attr, val)
 
-    def __iter__(self):
+    # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
+    # since '_BaseDataLoaderIter' references 'DataLoader'.
+    def __iter__(self) -> '_BaseDataLoaderIter':
         if self.num_workers == 0:
             return _SingleProcessDataLoaderIter(self)
         else:
@@ -297,7 +320,7 @@ class DataLoader(object):
         else:
             return self.sampler
 
-    def __len__(self):
+    def __len__(self) -> int:
         if self._dataset_kind == _DatasetKind.Iterable:
             # NOTE [ IterableDataset and __len__ ]
             #
@@ -313,7 +336,9 @@ class DataLoader(object):
             # To provide a further warning, we track if `__len__` was called on the
             # `DataLoader`, save the returned value in `self._len_called`, and warn
             # if the iterator ends up yielding more than this number of samples.
-            length = self._IterableDataset_len_called = len(self.dataset)
+
+            # Cannot statically verify that dataset is Sized
+            length = self._IterableDataset_len_called = len(self.dataset)  # type: ignore
             if self.batch_size is not None:
                 from math import ceil
                 if self.drop_last:
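For reference, a worked sketch of the arithmetic this `__len__` branch performs (plain Python, not part of the commit): with length = 10 samples and batch_size = 3, drop_last=True gives floor(10 / 3) = 3 batches, while drop_last=False gives ceil(10 / 3) = 4 batches:

from math import ceil

length, batch_size = 10, 3
print(length // batch_size)       # 3  (drop_last=True)
print(ceil(length / batch_size))  # 4  (drop_last=False)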
@@ -326,7 +351,7 @@ class DataLoader(object):
 
 
 class _BaseDataLoaderIter(object):
-    def __init__(self, loader):
+    def __init__(self, loader: DataLoader) -> None:
         self._dataset = loader.dataset
         self._dataset_kind = loader._dataset_kind
         self._IterableDataset_len_called = loader._IterableDataset_len_called
@@ -341,7 +366,7 @@ class _BaseDataLoaderIter(object):
         self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
         self._num_yielded = 0
 
-    def __iter__(self):
+    def __iter__(self) -> '_BaseDataLoaderIter':
         return self
 
     def _next_index(self):
@@ -350,7 +375,7 @@ class _BaseDataLoaderIter(object):
     def _next_data(self):
         raise NotImplementedError
 
-    def __next__(self):
+    def __next__(self) -> Any:
         data = self._next_data()
         self._num_yielded += 1
         if self._dataset_kind == _DatasetKind.Iterable and \
@@ -368,7 +393,7 @@ class _BaseDataLoaderIter(object):
 
     next = __next__  # Python 2 compatibility
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._index_sampler)
 
     def __getstate__(self):
@@ -690,7 +715,8 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
 
         self._worker_init_fn = loader.worker_init_fn
         self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
-        self._worker_result_queue = multiprocessing_context.Queue()
+        # No certainty which module multiprocessing_context is
+        self._worker_result_queue = multiprocessing_context.Queue()  # type: ignore
         self._worker_pids_set = False
         self._shutdown = False
         self._send_idx = 0  # idx of the next task to be sent to workers
@@ -710,7 +736,8 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
         #   (i.e., if kind != Iterable).
         self._workers_status = []
         for i in range(self._num_workers):
-            index_queue = multiprocessing_context.Queue()
+            # No certainty which module multiprocessing_context is
+            index_queue = multiprocessing_context.Queue()  # type: ignore
             # index_queue.cancel_join_thread()
             w = multiprocessing_context.Process(
                 target=_utils.worker._worker_loop,
@@ -732,7 +759,9 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
 
         if self._pin_memory:
             self._pin_memory_thread_done_event = threading.Event()
-            self._data_queue = queue.Queue()
+
+            # Queue is not type-annotated
+            self._data_queue = queue.Queue()  # type: ignore
             pin_memory_thread = threading.Thread(
                 target=_utils.pin_memory._pin_memory_loop,
                 args=(self._worker_result_queue, self._data_queue,
@@ -1,46 +0,0 @@
-from typing import Any, Callable, TypeVar, Generic, overload, Sequence, List, Optional
-from . import Dataset, Sampler
-
-from torch.utils.data._utils.worker import get_worker_info as get_worker_info
-
-T_co = TypeVar('T_co', covariant=True)
-T = TypeVar('T')
-_worker_init_fn_t = Callable[[int], None]
-
-# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
-# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
-# See https://github.com/python/mypy/issues/3737.
-_collate_fn_t = Callable[[List[T]], Any]
-
-def default_collate(batch: List[T]) -> Any: ...
-
-class DataLoader(Generic[T_co]):
-    dataset: Dataset[T_co]
-    batch_size: int
-    num_workers: int
-    pin_memory: bool
-    drop_last: bool
-    timeout: float
-
-    @overload
-    def __init__(self, dataset: Dataset[T_co], batch_size: int=..., shuffle: bool=...,
-                 sampler: Optional[Sampler[int]]=..., num_workers: int=..., collate_fn: _collate_fn_t=...,
-                 pin_memory: bool=..., drop_last: bool=..., timeout: float=...,
-                 worker_init_fn: _worker_init_fn_t=...) -> None: ...
-    @overload
-    def __init__(self, dataset: Dataset[T_co], batch_sampler: Optional[Sampler[Sequence[int]]]=...,
-                 num_workers: int=..., collate_fn: _collate_fn_t=..., pin_memory: bool=..., timeout: float=...,
-                 worker_init_fn: _worker_init_fn_t=...) -> None: ...
-
-    def __len__(self) -> int: ...
-    # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
-    # since '_BaseDataLoaderIter' references 'DataLoader'. In mypy 0.720 and newer a new semantic
-    # analyzer is used that obviates the need for this but we leave the quoting in to support older
-    # versions of mypy
-    def __iter__(self) -> '_BaseDataLoaderIter': ...
-
-class _BaseDataLoaderIter:
-    def __init__(self, loader: DataLoader) -> None: ...
-    def __len__(self) -> int: ...
-    def __iter__(self) -> _BaseDataLoaderIter: ...
-    def __next__(self) -> Any: ...
@@ -2,10 +2,17 @@ import bisect
 import warnings
 
 from torch._utils import _accumulate
-from torch import randperm, default_generator
+from torch import randperm
+# No 'default_generator' in torch/__init__.pyi
+from torch import default_generator  # type: ignore
+from typing import TypeVar, Generic, Iterable, Iterator, Sequence, List, Optional, Tuple
+from ... import Tensor, Generator
 
+T_co = TypeVar('T_co', covariant=True)
+T = TypeVar('T')
+
 
-class Dataset(object):
+class Dataset(Generic[T_co]):
     r"""An abstract class representing a :class:`Dataset`.
 
     All datasets that represent a map from keys to data samples should subclass
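A hedged aside on why `T_co` is declared covariant (the toy classes below are hypothetical, purely for type-checking illustration): covariance lets a dataset of a more specific element type be used wherever a dataset of a more general element type is expected.

from torch.utils.data import Dataset

class Animal:
    pass

class Cat(Animal):
    pass

class CatDataset(Dataset[Cat]):
    def __getitem__(self, index: int) -> Cat:
        return Cat()

    def __len__(self) -> int:
        return 1

def first_animal(ds: 'Dataset[Animal]') -> Animal:
    return ds[0]

# Accepted by mypy because T_co is covariant: Dataset[Cat] is a subtype of Dataset[Animal].
print(first_animal(CatDataset()))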
@@ -21,10 +28,10 @@ class Dataset(object):
     dataset with non-integral indices/keys, a custom sampler must be provided.
     """
 
-    def __getitem__(self, index):
+    def __getitem__(self, index) -> T_co:
         raise NotImplementedError
 
-    def __add__(self, other):
+    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
         return ConcatDataset([self, other])
 
     # No `def __len__(self)` default?
@@ -32,7 +39,7 @@ class Dataset(object):
     # in pytorch/torch/utils/data/sampler.py
 
 
-class IterableDataset(Dataset):
+class IterableDataset(Dataset[T_co]):
     r"""An iterable Dataset.
 
     All datasets that represent an iterable of data samples should subclass it.
@@ -135,17 +142,17 @@ class IterableDataset(Dataset):
         [3, 4, 5, 6]
     """
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[T_co]:
         raise NotImplementedError
 
-    def __add__(self, other):
+    def __add__(self, other: Dataset[T_co]):
         return ChainDataset([self, other])
 
     # No `def __len__(self)` default?
     # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
 
 
-class TensorDataset(Dataset):
+class TensorDataset(Dataset[Tuple[Tensor, ...]]):
     r"""Dataset wrapping tensors.
 
     Each sample will be retrieved by indexing tensors along the first dimension.
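A short, hedged sketch (hypothetical class, not part of the commit) of a typed iterable-style dataset matching the `Iterator[T_co]` annotation above:

from typing import Iterator
import torch
from torch.utils.data import IterableDataset

class CountingStream(IterableDataset[torch.Tensor]):
    def __init__(self, n: int) -> None:
        self.n = n

    def __iter__(self) -> Iterator[torch.Tensor]:
        # Yield one scalar tensor per step; the element type matches the subscript.
        for i in range(self.n):
            yield torch.tensor(i)

print(list(CountingStream(3)))  # [tensor(0), tensor(1), tensor(2)]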
@@ -153,8 +160,9 @@ class TensorDataset(Dataset):
     Arguments:
         *tensors (Tensor): tensors that have the same size of the first dimension.
     """
+    tensors: Tuple[Tensor, ...]
 
-    def __init__(self, *tensors):
+    def __init__(self, *tensors: Tensor) -> None:
         assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
         self.tensors = tensors
 
@@ -165,7 +173,7 @@ class TensorDataset(Dataset):
         return self.tensors[0].size(0)
 
 
-class ConcatDataset(Dataset):
+class ConcatDataset(Dataset[T_co]):
     r"""Dataset as a concatenation of multiple datasets.
 
     This class is useful to assemble different existing datasets.
@@ -173,6 +181,8 @@ class ConcatDataset(Dataset):
     Arguments:
         datasets (sequence): List of datasets to be concatenated
     """
+    datasets: List[Dataset[T_co]]
+    cumulative_sizes: List[int]
 
     @staticmethod
     def cumsum(sequence):
@@ -183,9 +193,10 @@ class ConcatDataset(Dataset):
             s += l
         return r
 
-    def __init__(self, datasets):
+    def __init__(self, datasets: Iterable[Dataset]) -> None:
         super(ConcatDataset, self).__init__()
-        assert len(datasets) > 0, 'datasets should not be an empty iterable'
+        # Cannot verify that datasets is Sized
+        assert len(datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore
         self.datasets = list(datasets)
         for d in self.datasets:
             assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
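The `cumsum` helper above produces the running sizes that `ConcatDataset` later searches with `bisect` to map a global index to a (dataset, local index) pair. A standalone sketch of that idea (not the class's exact code):

import bisect

sizes = [3, 5, 2]             # lengths of the concatenated datasets
cumulative = []
s = 0
for l in sizes:
    s += l
    cumulative.append(s)      # [3, 8, 10]

idx = 6                       # global index into the concatenation
dataset_idx = bisect.bisect_right(cumulative, idx)                       # 1 -> second dataset
local_idx = idx - (cumulative[dataset_idx - 1] if dataset_idx else 0)    # 3
print(dataset_idx, local_idx)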
@@ -223,7 +234,7 @@ class ChainDataset(IterableDataset):
     Arguments:
         datasets (iterable of IterableDataset): datasets to be chained together
     """
-    def __init__(self, datasets):
+    def __init__(self, datasets: Iterable[Dataset]) -> None:
         super(ChainDataset, self).__init__()
         self.datasets = datasets
 
@@ -237,11 +248,12 @@ class ChainDataset(IterableDataset):
         total = 0
         for d in self.datasets:
             assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
-            total += len(d)
+            # Cannot verify that all self.datasets are Sized
+            total += len(d)  # type: ignore
         return total
 
 
-class Subset(Dataset):
+class Subset(Dataset[T_co]):
     r"""
     Subset of a dataset at specified indices.
 
@@ -249,7 +261,10 @@ class Subset(Dataset):
         dataset (Dataset): The whole Dataset
         indices (sequence): Indices in the whole set selected for subset
     """
-    def __init__(self, dataset, indices):
+    dataset: Dataset[T_co]
+    indices: Sequence[int]
+
+    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
         self.dataset = dataset
         self.indices = indices
 
@@ -260,7 +275,8 @@ class Subset(Dataset):
         return len(self.indices)
 
 
-def random_split(dataset, lengths, generator=default_generator):
+def random_split(dataset: Dataset[T], lengths: Sequence[int],
+                 generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
     r"""
     Randomly split a dataset into non-overlapping new datasets of given lengths.
     Optionally fix the generator for reproducible results, e.g.:
@@ -272,7 +288,8 @@ def random_split(dataset, lengths, generator=default_generator):
         lengths (sequence): lengths of splits to be produced
         generator (Generator): Generator used for the random permutation.
     """
-    if sum(lengths) != len(dataset):
+    # Cannot verify that dataset is Sized
+    if sum(lengths) != len(dataset):    # type: ignore
         raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
 
     indices = randperm(sum(lengths), generator=generator).tolist()
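A brief usage sketch of the newly annotated `random_split` signature (illustrative only; the split contents depend on the seed):

import torch
from torch.utils.data import TensorDataset, random_split

ds = TensorDataset(torch.arange(10))
train, val = random_split(ds, [8, 2], generator=torch.Generator().manual_seed(42))
print(len(train), len(val))  # 8 2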
@@ -1,35 +0,0 @@
-from typing import TypeVar, Generic, Iterable, Sequence, List, Optional, Tuple
-from ... import Tensor, Generator
-
-T_co = TypeVar('T_co', covariant=True)
-T = TypeVar('T')
-class Dataset(Generic[T_co]):
-    def __getitem__(self, index: int) -> T_co: ...
-    def __len__(self) -> int: ...
-    # error: Cannot use a covariant type variable as a parameter
-    def __add__(self, other: T_co) -> 'ConcatDataset[T_co]': ...  # type: ignore
-
-class IterableDataset(Dataset[T_co]):
-    def __iter__(self) -> Iterable[T_co]: ...
-
-class TensorDataset(Dataset[Tuple[Tensor, ...]]):
-    tensors: List[Tensor]
-
-    def __init__(self, *tensors: Tensor) -> None: ...
-
-class ConcatDataset(Dataset[T_co]):
-    datasets: List[Dataset[T_co]]
-    cumulative_sizes: List[int]
-
-    def __init__(self, datasets: Iterable[Dataset]) -> None: ...
-
-class ChainDataset(Dataset[T_co]):
-    def __init__(self, datasets: Iterable[Dataset]) -> None: ...
-
-class Subset(Dataset[T_co]):
-    dataset: Dataset[T_co]
-    indices: Sequence[int]
-
-    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None: ...
-
-def random_split(dataset: Dataset[T], lengths: Sequence[int], generator: Optional[Generator]) -> List[Subset[T]]: ...
@@ -1,8 +1,12 @@
 import torch
 from torch._six import int_classes as _int_classes
+from torch import Tensor
 
+from typing import Iterator, Optional, Sequence, List, TypeVar, Generic, Sized
 
-class Sampler(object):
+T_co = TypeVar('T_co', covariant=True)
+
+class Sampler(Generic[T_co]):
     r"""Base class for all Samplers.
 
     Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
@@ -14,10 +18,10 @@ class Sampler(object):
     calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
     """
 
-    def __init__(self, data_source):
+    def __init__(self, data_source: Optional[Sized]) -> None:
         pass
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[T_co]:
         raise NotImplementedError
 
     # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
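A hedged sketch (hypothetical class, not part of the commit) of a custom sampler written against the annotated base class: it yields integer indices, so it subclasses `Sampler[int]`, and its data source only needs to be `Sized`:

from typing import Iterator, Sized
from torch.utils.data import Sampler

class ReverseSampler(Sampler[int]):
    # Visit the data source back-to-front.
    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self) -> int:
        return len(self.data_source)

print(list(ReverseSampler(range(4))))  # [3, 2, 1, 0]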
@@ -48,12 +52,13 @@ class Sampler(object):
     #   (@ssnl verifies that this works on at least Python 3.7.)
 
 
-class SequentialSampler(Sampler):
+class SequentialSampler(Sampler[int]):
     r"""Samples elements sequentially, always in the same order.
 
     Arguments:
         data_source (Dataset): dataset to sample from
     """
+    data_source: Sized
 
     def __init__(self, data_source):
         self.data_source = data_source
@@ -61,11 +66,11 @@ class SequentialSampler(Sampler):
     def __iter__(self):
         return iter(range(len(self.data_source)))
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.data_source)
 
 
-class RandomSampler(Sampler):
+class RandomSampler(Sampler[int]):
     r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
     If with replacement, then user can specify :attr:`num_samples` to draw.
 
@@ -76,8 +81,11 @@ class RandomSampler(Sampler):
         is supposed to be specified only when `replacement` is ``True``.
         generator (Generator): Generator used in sampling.
     """
+    data_source: Sized
+    replacement: bool
 
-    def __init__(self, data_source, replacement=False, num_samples=None, generator=None):
+    def __init__(self, data_source: Sized, replacement: bool = False,
+                 num_samples: Optional[int] = None, generator=None) -> None:
         self.data_source = data_source
         self.replacement = replacement
         self._num_samples = num_samples
@@ -96,7 +104,7 @@ class RandomSampler(Sampler):
                              "value, but got num_samples={}".format(self.num_samples))
 
     @property
-    def num_samples(self):
+    def num_samples(self) -> int:
         # dataset size might change at runtime
         if self._num_samples is None:
             return len(self.data_source)
@@ -113,15 +121,16 @@ class RandomSampler(Sampler):
         return self.num_samples
 
 
-class SubsetRandomSampler(Sampler):
+class SubsetRandomSampler(Sampler[int]):
     r"""Samples elements randomly from a given list of indices, without replacement.
 
     Arguments:
         indices (sequence): a sequence of indices
         generator (Generator): Generator used in sampling.
     """
+    indices: Sequence[int]
 
-    def __init__(self, indices, generator=None):
+    def __init__(self, indices: Sequence[int], generator=None) -> None:
         self.indices = indices
         self.generator = generator
 
@@ -132,7 +141,7 @@ class SubsetRandomSampler(Sampler):
         return len(self.indices)
 
 
-class WeightedRandomSampler(Sampler):
+class WeightedRandomSampler(Sampler[int]):
     r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
 
     Args:
@@ -149,8 +158,12 @@ class WeightedRandomSampler(Sampler):
         >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
         [0, 1, 4, 3, 2]
     """
+    weights: Tensor
+    num_samples: int
+    replacement: bool
 
-    def __init__(self, weights, num_samples, replacement=True, generator=None):
+    def __init__(self, weights: Sequence[float], num_samples: int,
+                 replacement: bool = True, generator=None) -> None:
         if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
                 num_samples <= 0:
             raise ValueError("num_samples should be a positive integer "
@@ -171,12 +184,11 @@ class WeightedRandomSampler(Sampler):
         return self.num_samples
 
 
-class BatchSampler(Sampler):
+class BatchSampler(Sampler[List[int]]):
     r"""Wraps another sampler to yield a mini-batch of indices.
 
     Args:
         sampler (Sampler or Iterable): Base sampler. Can be any iterable object
-            with ``__len__`` implemented.
         batch_size (int): Size of mini-batch.
         drop_last (bool): If ``True``, the sampler will drop the last batch if
             its size would be less than ``batch_size``
@@ -188,7 +200,7 @@ class BatchSampler(Sampler):
         [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
     """
 
-    def __init__(self, sampler, batch_size, drop_last):
+    def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None:
         # Since collections.abc.Iterable does not check for `__getitem__`, which
         # is one way for an object to be an iterable, we don't do an `isinstance`
         # check here.
@@ -214,7 +226,11 @@ class BatchSampler(Sampler):
             yield batch
 
     def __len__(self):
+        # Can only be called if self.sampler has __len__ implemented
+        # We cannot enforce this condition, so we turn off typechecking for the
+        # implementation below.
+        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
         if self.drop_last:
-            return len(self.sampler) // self.batch_size
+            return len(self.sampler) // self.batch_size  # type: ignore
         else:
-            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
+            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore
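A worked check of the two return expressions above (plain arithmetic, not part of the commit): with len(sampler) = 10 and batch_size = 3, drop_last=True yields 10 // 3 = 3 full batches, and drop_last=False yields (10 + 3 - 1) // 3 = 4 batches, the last one holding a single index:

import torch
from torch.utils.data import BatchSampler, SequentialSampler

base = SequentialSampler(torch.arange(10))
print(len(BatchSampler(base, batch_size=3, drop_last=True)))   # 3
print(len(BatchSampler(base, batch_size=3, drop_last=False)))  # 4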
@@ -1,38 +0,0 @@
-from typing import Iterator, Optional, Sequence, List, TypeVar, Generic, Sized
-from ... import Tensor
-
-T_co = TypeVar('T_co', covariant=True)
-class Sampler(Generic[T_co]):
-    def __init__(self, data_source: Sized) -> None: ...
-    def __iter__(self) -> Iterator[T_co]: ...
-    def __len__(self) -> int: ...
-
-class SequentialSampler(Sampler[int]):
-    data_source: Sized
-    pass
-
-class RandomSampler(Sampler[int]):
-    data_source: Sized
-    replacement: bool
-    num_samples: int
-
-    def __init__(self, data_source: Sized, replacement: bool=..., num_samples: Optional[int]=...) -> None: ...
-
-class SubsetRandomSampler(Sampler[int]):
-    indices: Sequence[int]
-
-    def __init__(self, indices: Sequence[int]) -> None: ...
-
-class WeightedRandomSampler(Sampler[int]):
-    weights: Tensor
-    num_samples: int
-    replacement: bool
-
-    def __init__(self, weights: Sequence[float], num_samples: int, replacement: bool=...) -> None: ...
-
-class BatchSampler(Sampler[List[int]]):
-    sampler: Sampler[int]
-    batch_size: int
-    drop_last: bool
-
-    def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None: ...