type annotations for dataloader, dataset, sampler (#39392)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/38913

Pull Request resolved: https://github.com/pytorch/pytorch/pull/39392

Reviewed By: anjali411

Differential Revision: D22102489

Pulled By: zou3519

fbshipit-source-id: acb68d9521145f0b047214d62b5bdc5a0d1b9be4
Authored by Wojciech Baranowski on 2020-07-07 07:14:33 -07:00; committed by Facebook GitHub Bot
Parent: a6b703cc89
Commit: 0e09511af9
8 changed files with 133 additions and 182 deletions
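
To illustrate what the new generics enable (a sketch with made-up names, not part of this diff), a map-style dataset can now be parameterized by its sample type, and a DataLoader can be annotated with the same element type, much like the test_typing test added below:

    from typing import List

    import torch
    from torch.utils.data import DataLoader, Dataset

    class TensorListDataset(Dataset[List[torch.Tensor]]):
        # Hypothetical map-style dataset whose samples are lists of tensors.
        def __init__(self, size: int) -> None:
            self.size = size

        def __getitem__(self, index: int) -> List[torch.Tensor]:
            return [torch.zeros(3), torch.ones(3)]

        def __len__(self) -> int:
            return self.size

    def make_loader(ds: Dataset[List[torch.Tensor]]) -> DataLoader[List[torch.Tensor]]:
        # DataLoader is now Generic[T_co], so it can carry the element type
        # of the dataset it wraps.
        return DataLoader(ds, batch_size=4)

    loader = make_loader(TensorListDataset(16))

Subscripting Dataset and DataLoader is also valid at runtime, which is what the new test checks ("Make sure there is no TypeError").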

View File

@ -302,9 +302,6 @@ ignore_errors = True
[mypy-torch.utils.data._utils.worker]
ignore_errors = True
[mypy-torch.utils.data.dataset]
ignore_errors = True
[mypy-torch.utils.data.distributed]
ignore_errors = True

View File

@ -821,6 +821,16 @@ class TestDataLoader(TestCase):
with self.assertRaisesRegex(RuntimeError, 'Error in worker_init_fn'):
list(iter(loader))
def test_typing(self):
from typing import List
# Make sure there is no TypeError
class SomeDatasetClass(Dataset[List[torch.Tensor]]):
pass
def _create_dataloader(is_train: bool) -> DataLoader[List[torch.Tensor]]:
pass
@unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
@unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
def test_fd_limit_exceeded(self):
@ -2019,5 +2029,6 @@ class TestSetAffinity(TestCase):
self.assertEqual(sample, [2])
if __name__ == '__main__':
run_tests()

View File

@ -8,6 +8,7 @@ in `./_utils/worker.py`.
import threading
import itertools
import warnings
from typing import Any, Callable, TypeVar, Generic, Sequence, List, Optional
import multiprocessing as python_multiprocessing
import torch
@ -15,19 +16,27 @@ import torch.multiprocessing as multiprocessing
from torch._utils import ExceptionWrapper
from torch._six import queue, string_classes
from . import IterableDataset, Sampler, SequentialSampler, RandomSampler, BatchSampler
from . import IterableDataset, Sampler, SequentialSampler, RandomSampler, BatchSampler, Dataset
from . import _utils
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
_worker_init_fn_t = Callable[[int], None]
# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
# See https://github.com/python/mypy/issues/3737.
_collate_fn_t = Callable[[List[T]], Any]
get_worker_info = _utils.worker.get_worker_info
# This function used to be defined in this file. However, it was moved to
# _utils/collate.py. Although it is rather hard to access this from user land
# (one has to explicitly directly `import torch.utils.data.dataloader`), there
# probably is user code out there using it. This aliasing maintains BC in this
# aspect.
default_collate = _utils.collate.default_collate
default_collate: _collate_fn_t = _utils.collate.default_collate
get_worker_info = _utils.worker.get_worker_info
class _DatasetKind(object):
Map = 0
@ -57,7 +66,7 @@ class _InfiniteConstantSampler(Sampler):
yield None
class DataLoader(object):
class DataLoader(Generic[T_co]):
r"""
Data loader. Combines a dataset and a sampler, and provides an iterable over
the given dataset.
@ -116,15 +125,24 @@ class DataLoader(object):
details on these two types of datasets and how
:class:`~torch.utils.data.IterableDataset` interacts with `Multi-process data loading`_.
"""
dataset: Dataset[T_co]
batch_size: Optional[int]
num_workers: int
pin_memory: bool
drop_last: bool
timeout: float
sampler: Sampler
__initialized = False
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None,
generator=None):
torch._C._log_api_usage_once("python.data_loader")
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
batch_sampler: Optional[Sampler[Sequence[int]]] = None,
num_workers: int = 0, collate_fn: _collate_fn_t = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
multiprocessing_context=None, generator=None):
torch._C._log_api_usage_once("python.data_loader") # type: ignore
if num_workers < 0:
raise ValueError('num_workers option should be non-negative; '
@ -146,7 +164,7 @@ class DataLoader(object):
# after spending time fixing the custom sampler errors.
if isinstance(dataset, IterableDataset):
self._dataset_kind = _DatasetKind.Iterable
# NOTE [ Custom Samplers and `IterableDataset` ]
# NOTE [ Custom Samplers and IterableDataset ]
#
# `IterableDataset` does not support custom `batch_sampler` or
# `sampler` since the key is irrelevant (unless we support
@ -212,7 +230,9 @@ class DataLoader(object):
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset, generator=generator)
# Cannot statically verify that dataset is Sized
# Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
sampler = RandomSampler(dataset, generator=generator) # type: ignore
else:
sampler = SequentialSampler(dataset)
@ -253,9 +273,10 @@ class DataLoader(object):
if multiprocessing_context not in valid_start_methods:
raise ValueError(
('multiprocessing_context option '
'should specify a valid start method in {}, but got '
'multiprocessing_context={}').format(valid_start_methods, multiprocessing_context))
multiprocessing_context = multiprocessing.get_context(multiprocessing_context)
'should specify a valid start method in {!r}, but got '
'multiprocessing_context={!r}').format(valid_start_methods, multiprocessing_context))
# error: Argument 1 to "get_context" has incompatible type "Union[str, bytes]"; expected "str" [arg-type]
multiprocessing_context = multiprocessing.get_context(multiprocessing_context) # type: ignore
if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
raise TypeError(('multiprocessing_context option should be a valid context '
@ -275,7 +296,9 @@ class DataLoader(object):
super(DataLoader, self).__setattr__(attr, val)
def __iter__(self):
# We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
# since '_BaseDataLoaderIter' references 'DataLoader'.
def __iter__(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
@ -297,7 +320,7 @@ class DataLoader(object):
else:
return self.sampler
def __len__(self):
def __len__(self) -> int:
if self._dataset_kind == _DatasetKind.Iterable:
# NOTE [ IterableDataset and __len__ ]
#
@ -313,7 +336,9 @@ class DataLoader(object):
# To provide a further warning, we track if `__len__` was called on the
# `DataLoader`, save the returned value in `self._len_called`, and warn
# if the iterator ends up yielding more than this number of samples.
length = self._IterableDataset_len_called = len(self.dataset)
# Cannot statically verify that dataset is Sized
length = self._IterableDataset_len_called = len(self.dataset) # type: ignore
if self.batch_size is not None:
from math import ceil
if self.drop_last:
@ -326,7 +351,7 @@ class DataLoader(object):
class _BaseDataLoaderIter(object):
def __init__(self, loader):
def __init__(self, loader: DataLoader) -> None:
self._dataset = loader.dataset
self._dataset_kind = loader._dataset_kind
self._IterableDataset_len_called = loader._IterableDataset_len_called
@ -341,7 +366,7 @@ class _BaseDataLoaderIter(object):
self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
self._num_yielded = 0
def __iter__(self):
def __iter__(self) -> '_BaseDataLoaderIter':
return self
def _next_index(self):
@ -350,7 +375,7 @@ class _BaseDataLoaderIter(object):
def _next_data(self):
raise NotImplementedError
def __next__(self):
def __next__(self) -> Any:
data = self._next_data()
self._num_yielded += 1
if self._dataset_kind == _DatasetKind.Iterable and \
@ -368,7 +393,7 @@ class _BaseDataLoaderIter(object):
next = __next__ # Python 2 compatibility
def __len__(self):
def __len__(self) -> int:
return len(self._index_sampler)
def __getstate__(self):
@ -690,7 +715,8 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
self._worker_init_fn = loader.worker_init_fn
self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
self._worker_result_queue = multiprocessing_context.Queue()
# No certainty which module multiprocessing_context is
self._worker_result_queue = multiprocessing_context.Queue() # type: ignore
self._worker_pids_set = False
self._shutdown = False
self._send_idx = 0 # idx of the next task to be sent to workers
@ -710,7 +736,8 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
# (i.e., if kind != Iterable).
self._workers_status = []
for i in range(self._num_workers):
index_queue = multiprocessing_context.Queue()
# No certainty which module multiprocessing_context is
index_queue = multiprocessing_context.Queue() # type: ignore
# index_queue.cancel_join_thread()
w = multiprocessing_context.Process(
target=_utils.worker._worker_loop,
@ -732,7 +759,9 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
if self._pin_memory:
self._pin_memory_thread_done_event = threading.Event()
self._data_queue = queue.Queue()
# Queue is not type-annotated
self._data_queue = queue.Queue() # type: ignore
pin_memory_thread = threading.Thread(
target=_utils.pin_memory._pin_memory_loop,
args=(self._worker_result_queue, self._data_queue,
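
Because DataLoader cannot be parameterized by the return type of collate_fn (see the mypy issue referenced in the comment above), custom collate functions are typed against _collate_fn_t, i.e. they take a List of samples and return Any. A minimal sketch, using a hypothetical stack_collate helper that is not part of this diff:

    from typing import Any, List, Tuple

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def stack_collate(batch: List[Tuple[torch.Tensor, ...]]) -> Any:
        # Stack each field of the (input, target) samples into one batch tensor.
        inputs, targets = zip(*batch)
        return torch.stack(inputs), torch.stack(targets)

    dataset = TensorDataset(torch.randn(8, 3), torch.arange(8))
    loader: DataLoader[Tuple[torch.Tensor, ...]] = DataLoader(
        dataset, batch_size=2, collate_fn=stack_collate)

Note that the DataLoader's type parameter reflects the dataset's sample type (Tuple[Tensor, ...] for TensorDataset), not whatever the collate function returns.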

View File

@ -1,46 +0,0 @@
from typing import Any, Callable, TypeVar, Generic, overload, Sequence, List, Optional
from . import Dataset, Sampler
from torch.utils.data._utils.worker import get_worker_info as get_worker_info
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
_worker_init_fn_t = Callable[[int], None]
# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
# See https://github.com/python/mypy/issues/3737.
_collate_fn_t = Callable[[List[T]], Any]
def default_collate(batch: List[T]) -> Any: ...
class DataLoader(Generic[T_co]):
dataset: Dataset[T_co]
batch_size: int
num_workers: int
pin_memory: bool
drop_last: bool
timeout: float
@overload
def __init__(self, dataset: Dataset[T_co], batch_size: int=..., shuffle: bool=...,
sampler: Optional[Sampler[int]]=..., num_workers: int=..., collate_fn: _collate_fn_t=...,
pin_memory: bool=..., drop_last: bool=..., timeout: float=...,
worker_init_fn: _worker_init_fn_t=...) -> None: ...
@overload
def __init__(self, dataset: Dataset[T_co], batch_sampler: Optional[Sampler[Sequence[int]]]=...,
num_workers: int=..., collate_fn: _collate_fn_t=..., pin_memory: bool=..., timeout: float=...,
worker_init_fn: _worker_init_fn_t=...) -> None: ...
def __len__(self) -> int: ...
# We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
# since '_BaseDataLoaderIter' references 'DataLoader'. In mypy 0.720 and newer a new semantic
# analyzer is used that obviates the need for this but we leave the quoting in to support older
# versions of mypy
def __iter__(self) -> '_BaseDataLoaderIter':...
class _BaseDataLoaderIter:
def __init__(self, loader: DataLoader) -> None:...
def __len__(self) -> int: ...
def __iter__(self) -> _BaseDataLoaderIter: ...
def __next__(self) -> Any: ...

View File

@ -2,10 +2,17 @@ import bisect
import warnings
from torch._utils import _accumulate
from torch import randperm, default_generator
from torch import randperm
# No 'default_generator' in torch/__init__.pyi
from torch import default_generator # type: ignore
from typing import TypeVar, Generic, Iterable, Iterator, Sequence, List, Optional, Tuple
from ... import Tensor, Generator
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
class Dataset(object):
class Dataset(Generic[T_co]):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
@ -21,10 +28,10 @@ class Dataset(object):
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
def __getitem__(self, index):
def __getitem__(self, index) -> T_co:
raise NotImplementedError
def __add__(self, other):
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])
# No `def __len__(self)` default?
@ -32,7 +39,7 @@ class Dataset(object):
# in pytorch/torch/utils/data/sampler.py
class IterableDataset(Dataset):
class IterableDataset(Dataset[T_co]):
r"""An iterable Dataset.
All datasets that represent an iterable of data samples should subclass it.
@ -135,17 +142,17 @@ class IterableDataset(Dataset):
[3, 4, 5, 6]
"""
def __iter__(self):
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
def __add__(self, other):
def __add__(self, other: Dataset[T_co]):
return ChainDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
class TensorDataset(Dataset):
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
r"""Dataset wrapping tensors.
Each sample will be retrieved by indexing tensors along the first dimension.
@ -153,8 +160,9 @@ class TensorDataset(Dataset):
Arguments:
*tensors (Tensor): tensors that have the same size of the first dimension.
"""
tensors: Tuple[Tensor, ...]
def __init__(self, *tensors):
def __init__(self, *tensors: Tensor) -> None:
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
self.tensors = tensors
@ -165,7 +173,7 @@ class TensorDataset(Dataset):
return self.tensors[0].size(0)
class ConcatDataset(Dataset):
class ConcatDataset(Dataset[T_co]):
r"""Dataset as a concatenation of multiple datasets.
This class is useful to assemble different existing datasets.
@ -173,6 +181,8 @@ class ConcatDataset(Dataset):
Arguments:
datasets (sequence): List of datasets to be concatenated
"""
datasets: List[Dataset[T_co]]
cumulative_sizes: List[int]
@staticmethod
def cumsum(sequence):
@ -183,9 +193,10 @@ class ConcatDataset(Dataset):
s += l
return r
def __init__(self, datasets):
def __init__(self, datasets: Iterable[Dataset]) -> None:
super(ConcatDataset, self).__init__()
assert len(datasets) > 0, 'datasets should not be an empty iterable'
# Cannot verify that datasets is Sized
assert len(datasets) > 0, 'datasets should not be an empty iterable' # type: ignore
self.datasets = list(datasets)
for d in self.datasets:
assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
@ -223,7 +234,7 @@ class ChainDataset(IterableDataset):
Arguments:
datasets (iterable of IterableDataset): datasets to be chained together
"""
def __init__(self, datasets):
def __init__(self, datasets: Iterable[Dataset]) -> None:
super(ChainDataset, self).__init__()
self.datasets = datasets
@ -237,11 +248,12 @@ class ChainDataset(IterableDataset):
total = 0
for d in self.datasets:
assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
total += len(d)
# Cannot verify that all self.datasets are Sized
total += len(d) # type: ignore
return total
class Subset(Dataset):
class Subset(Dataset[T_co]):
r"""
Subset of a dataset at specified indices.
@ -249,7 +261,10 @@ class Subset(Dataset):
dataset (Dataset): The whole Dataset
indices (sequence): Indices in the whole set selected for subset
"""
def __init__(self, dataset, indices):
dataset: Dataset[T_co]
indices: Sequence[int]
def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
self.dataset = dataset
self.indices = indices
@ -260,7 +275,8 @@ class Subset(Dataset):
return len(self.indices)
def random_split(dataset, lengths, generator=default_generator):
def random_split(dataset: Dataset[T], lengths: Sequence[int],
generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
r"""
Randomly split a dataset into non-overlapping new datasets of given lengths.
Optionally fix the generator for reproducible results, e.g.:
@ -272,7 +288,8 @@ def random_split(dataset, lengths, generator=default_generator):
lengths (sequence): lengths of splits to be produced
generator (Generator): Generator used for the random permutation.
"""
if sum(lengths) != len(dataset):
# Cannot verify that dataset is Sized
if sum(lengths) != len(dataset): # type: ignore
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
indices = randperm(sum(lengths), generator=generator).tolist()

View File

@ -1,35 +0,0 @@
from typing import TypeVar, Generic, Iterable, Sequence, List, Optional, Tuple
from ... import Tensor, Generator
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
class Dataset(Generic[T_co]):
def __getitem__(self, index: int) -> T_co: ...
def __len__(self) -> int: ...
# error: Cannot use a covariant type variable as a parameter
def __add__(self, other: T_co) -> 'ConcatDataset[T_co]': ... # type: ignore
class IterableDataset(Dataset[T_co]):
def __iter__(self) -> Iterable[T_co]: ...
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
tensors: List[Tensor]
def __init__(self, *tensors: Tensor) -> None: ...
class ConcatDataset(Dataset[T_co]):
datasets: List[Dataset[T_co]]
cumulative_sizes: List[int]
def __init__(self, datasets: Iterable[Dataset]) -> None: ...
class ChainDataset(Dataset[T_co]):
def __init__(self, datasets: Iterable[Dataset]) -> None: ...
class Subset(Dataset[T_co]):
dataset: Dataset[T_co]
indices: Sequence[int]
def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None: ...
def random_split(dataset: Dataset[T], lengths: Sequence[int], generator: Optional[Generator]) -> List[Subset[T]]: ...

View File

@ -1,8 +1,12 @@
import torch
from torch._six import int_classes as _int_classes
from torch import Tensor
from typing import Iterator, Optional, Sequence, List, TypeVar, Generic, Sized
class Sampler(object):
T_co = TypeVar('T_co', covariant=True)
class Sampler(Generic[T_co]):
r"""Base class for all Samplers.
Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
@ -14,10 +18,10 @@ class Sampler(object):
calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
"""
def __init__(self, data_source):
def __init__(self, data_source: Optional[Sized]) -> None:
pass
def __iter__(self):
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
# NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
@ -48,12 +52,13 @@ class Sampler(object):
# (@ssnl verifies that this works on at least Python 3.7.)
class SequentialSampler(Sampler):
class SequentialSampler(Sampler[int]):
r"""Samples elements sequentially, always in the same order.
Arguments:
data_source (Dataset): dataset to sample from
"""
data_source: Sized
def __init__(self, data_source):
self.data_source = data_source
@ -61,11 +66,11 @@ class SequentialSampler(Sampler):
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
def __len__(self) -> int:
return len(self.data_source)
class RandomSampler(Sampler):
class RandomSampler(Sampler[int]):
r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
If with replacement, then user can specify :attr:`num_samples` to draw.
@ -76,8 +81,11 @@ class RandomSampler(Sampler):
is supposed to be specified only when `replacement` is ``True``.
generator (Generator): Generator used in sampling.
"""
data_source: Sized
replacement: bool
def __init__(self, data_source, replacement=False, num_samples=None, generator=None):
def __init__(self, data_source: Sized, replacement: bool = False,
num_samples: Optional[int] = None, generator=None) -> None:
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
@ -96,7 +104,7 @@ class RandomSampler(Sampler):
"value, but got num_samples={}".format(self.num_samples))
@property
def num_samples(self):
def num_samples(self) -> int:
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
@ -113,15 +121,16 @@ class RandomSampler(Sampler):
return self.num_samples
class SubsetRandomSampler(Sampler):
class SubsetRandomSampler(Sampler[int]):
r"""Samples elements randomly from a given list of indices, without replacement.
Arguments:
indices (sequence): a sequence of indices
generator (Generator): Generator used in sampling.
"""
indices: Sequence[int]
def __init__(self, indices, generator=None):
def __init__(self, indices: Sequence[int], generator=None) -> None:
self.indices = indices
self.generator = generator
@ -132,7 +141,7 @@ class SubsetRandomSampler(Sampler):
return len(self.indices)
class WeightedRandomSampler(Sampler):
class WeightedRandomSampler(Sampler[int]):
r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
Args:
@ -149,8 +158,12 @@ class WeightedRandomSampler(Sampler):
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
"""
weights: Tensor
num_samples: int
replacement: bool
def __init__(self, weights, num_samples, replacement=True, generator=None):
def __init__(self, weights: Sequence[float], num_samples: int,
replacement: bool = True, generator=None) -> None:
if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
@ -171,12 +184,11 @@ class WeightedRandomSampler(Sampler):
return self.num_samples
class BatchSampler(Sampler):
class BatchSampler(Sampler[List[int]]):
r"""Wraps another sampler to yield a mini-batch of indices.
Args:
sampler (Sampler or Iterable): Base sampler. Can be any iterable object
with ``__len__`` implemented.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``
@ -188,7 +200,7 @@ class BatchSampler(Sampler):
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
"""
def __init__(self, sampler, batch_size, drop_last):
def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None:
# Since collections.abc.Iterable does not check for `__getitem__`, which
# is one way for an object to be an iterable, we don't do an `isinstance`
# check here.
@ -214,7 +226,11 @@ class BatchSampler(Sampler):
yield batch
def __len__(self):
# Can only be called if self.sampler has __len__ implemented
# We cannot enforce this condition, so we turn off typechecking for the
# implementation below.
# Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
if self.drop_last:
return len(self.sampler) // self.batch_size
return len(self.sampler) // self.batch_size # type: ignore
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore
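
Since Sampler is now generic in the index type it yields, a custom sampler can subclass Sampler[int] and annotate __iter__ accordingly; BatchSampler then groups those indices into List[int] batches. A sketch with a hypothetical EvenIndexSampler, not part of this diff:

    from typing import Iterator, Sized

    from torch.utils.data import BatchSampler, Sampler

    class EvenIndexSampler(Sampler[int]):
        # Hypothetical sampler yielding only the even indices of a sized data source.
        def __init__(self, data_source: Sized) -> None:
            self.data_source = data_source

        def __iter__(self) -> Iterator[int]:
            return iter(range(0, len(self.data_source), 2))

        def __len__(self) -> int:
            return (len(self.data_source) + 1) // 2

    batches = list(BatchSampler(EvenIndexSampler(range(10)), batch_size=2, drop_last=False))
    # batches == [[0, 2], [4, 6], [8]]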

View File

@ -1,38 +0,0 @@
from typing import Iterator, Optional, Sequence, List, TypeVar, Generic, Sized
from ... import Tensor
T_co = TypeVar('T_co', covariant=True)
class Sampler(Generic[T_co]):
def __init__(self, data_source: Sized) -> None: ...
def __iter__(self) -> Iterator[T_co]: ...
def __len__(self) -> int: ...
class SequentialSampler(Sampler[int]):
data_source: Sized
pass
class RandomSampler(Sampler[int]):
data_source: Sized
replacement: bool
num_samples: int
def __init__(self, data_source: Sized, replacement: bool=..., num_samples: Optional[int]=...) -> None: ...
class SubsetRandomSampler(Sampler[int]):
indices: Sequence[int]
def __init__(self, indices: Sequence[int]) -> None: ...
class WeightedRandomSampler(Sampler[int]):
weights: Tensor
num_samples: int
replacement: bool
def __init__(self, weights: Sequence[float], num_samples: int, replacement: bool=...) -> None: ...
class BatchSampler(Sampler[List[int]]):
sampler: Sampler[int]
batch_size: int
drop_last: bool
def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None: ...