mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/38157 This removes the error-prone process of assembling `torch/__init__.pyi` (and frequently forgetting to expose things), since we can now simply rely on the true source file to get things done. Most of the old codegen in gen_pyi.py is now rerouted to various files: - `torch/_C/__init__.pyi` (the dumping pile of all misc bindings) - `torch/_C/_nn.pyi` (NN function bindings) - `torch/_C/_VariableFunctions.pyi` (torch function bindings) `torch.types` grew a bunch more definitions that previously were defined in `torch/__init__.pyi`. Some miscellaneous changes: - Fixed a bug where we treated a single TensorList argument as implying that varargs are accepted. This is actually only supported on IntList. This means we can correctly generate a stub for dequantize. - Add missing manual stub for nonzero - Switched torch/onnx/operators.py to refer directly to the `_C` module, since mypy apparently does not consider methods prefixed with underscores to be re-exported. This may be a recurring theme; maybe we need to find a better way to solve it. Because I was really lazy, I dumped namedtuple definitions in both `torch._C` and `torch._C._VariableFunctions`. This is definitely wrong. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Differential Revision: D21497400 Pulled By: ezyang fbshipit-source-id: 07b126141c82efaca37be27c07255cb2b9b3f064
47 lines
2.0 KiB
Python
47 lines
2.0 KiB
Python
from typing import Any, Callable, TypeVar, Generic, overload, Sequence, List, Optional
|
|
from . import Dataset, Sampler
|
|
|
|
from torch.utils.data._utils.worker import get_worker_info as get_worker_info
|
|
|
|
# Invariant type variable used by the generic callable aliases below.
T = TypeVar('T')

# Covariant type variable: the element type a DataLoader produces.
T_co = TypeVar('T_co', covariant=True)

# Signature of a worker initialisation callback: one int in, nothing out.
_worker_init_fn_t = Callable[[int], None]

# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`,
# but there is currently no way to have that type parameter set to a default
# value if the user doesn't pass in a custom 'collate_fn'.
# See https://github.com/python/mypy/issues/3737.
# Hence the collated result is typed as `Any`; the alias stays generic in `T`
# (the type of each element of the batch list).
_collate_fn_t = Callable[[List[T]], Any]
|
|
|
|
def default_collate(
    batch: List[T],
) -> Any:
    # Stub signature only: takes a list whose elements share one type `T`
    # and returns an untyped (`Any`) collated result.
    ...
|
|
|
|
class DataLoader(Generic[T_co]):
    # Typing stub: declares the loader's public attributes and constructor
    # overloads; every method body is `...`, so there is no logic here.
    dataset: Dataset[T_co]
    batch_size: int
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float

    # Overload 1: batches are formed from `batch_size` / `shuffle` /
    # `sampler` / `drop_last`.
    # `collate_fn` and `worker_init_fn` are Optional so that passing `None`
    # explicitly (their documented runtime default) type-checks, matching how
    # `sampler` is already declared.
    @overload
    def __init__(self, dataset: Dataset[T_co], batch_size: int = ..., shuffle: bool = ...,
                 sampler: Optional[Sampler[int]] = ..., num_workers: int = ...,
                 collate_fn: Optional[_collate_fn_t] = ...,
                 pin_memory: bool = ..., drop_last: bool = ..., timeout: float = ...,
                 worker_init_fn: Optional[_worker_init_fn_t] = ...) -> None: ...

    # Overload 2: batches come from a `batch_sampler` yielding sequences of
    # indices; this overload deliberately omits batch_size / shuffle /
    # sampler / drop_last.
    @overload
    def __init__(self, dataset: Dataset[T_co],
                 batch_sampler: Optional[Sampler[Sequence[int]]] = ...,
                 num_workers: int = ..., collate_fn: Optional[_collate_fn_t] = ...,
                 pin_memory: bool = ..., timeout: float = ...,
                 worker_init_fn: Optional[_worker_init_fn_t] = ...) -> None: ...

    def __len__(self) -> int: ...

    # We quote '_BaseDataLoaderIter' because it isn't defined yet, and its
    # definition can't be moved up since '_BaseDataLoaderIter' references
    # 'DataLoader'. mypy 0.720 and newer use a semantic analyzer that makes
    # the quoting unnecessary, but we keep it to support older mypy versions.
    def __iter__(self) -> '_BaseDataLoaderIter': ...
|
|
|
|
class _BaseDataLoaderIter:
    # Typing stub for the object returned by `DataLoader.__iter__`; every
    # body below is `...`, so only the signatures are declared.

    # Built from the DataLoader it iterates over.
    def __init__(self, loader: DataLoader) -> None: ...

    # Iterator protocol: `__iter__` is declared to return this same type.
    def __iter__(self) -> _BaseDataLoaderIter: ...

    # Each step yields an item whose type this stub does not track.
    def __next__(self) -> Any: ...

    def __len__(self) -> int: ...
|