diff --git a/mypy.ini b/mypy.ini index 4ef24b6013d..05831581ffc 100644 --- a/mypy.ini +++ b/mypy.ini @@ -317,9 +317,6 @@ ignore_errors = True [mypy-torch.nn.functional] ignore_errors = True -[mypy-torch.serialization] -ignore_errors = True - [mypy-torch.utils] ignore_errors = True diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index e04362163f1..1ddc9d79b33 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2,10 +2,11 @@ import torch from torch import Tensor -from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar, Type +from typing import (Any, BinaryIO, Callable, ContextManager, Iterator, List, NamedTuple, + Optional, overload, Sequence, Tuple, TypeVar, Type, Union) from torch._six import inf -from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Number, Device +from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage import builtins @@ -90,9 +91,6 @@ class qscheme: ... # Defined in torch/csrc/utils/tensor_qschemes.cpp per_tensor_affine: qscheme = ... -# Defined in torch/csrc/generic/Storage.cpp -class Storage: ... - # Defined in torch/csrc/autograd/python_function.cpp class _FunctionBase(object): ... @@ -169,6 +167,24 @@ class FileCheck(object): # TODO ... +# Defined in torch/csrc/jit/python/init.cpp +class PyTorchFileReader(object): + @overload + def __init__(self, name: str) -> None: ... + @overload + def __init__(self, buffer: BinaryIO) -> None: ... + def get_record(self, name: str) -> bytes: ... + ... + +class PyTorchFileWriter(object): + @overload + def __init__(self, name: str) -> None: ... + @overload + def __init__(self, buffer: BinaryIO) -> None: ... + def write_record(self, name: str, data: bytes, size: _int) -> None: ... + def write_end_of_file(self) -> None: ... + ... 
+ # Defined in torch/csrc/Generator.cpp class Generator(object): device: _device diff --git a/torch/__init__.py b/torch/__init__.py index 4937fba4a45..fd381643363 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -311,11 +311,6 @@ def is_deterministic(): """ return _C._get_deterministic() -# If you edit these imports, please update torch/__init__.py.in as well -from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed -from .serialization import save, load -from ._tensor_str import set_printoptions - ################################################################################ # Define Storage and Tensor classes ################################################################################ @@ -388,6 +383,10 @@ _storage_classes = { # The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings() _tensor_classes: Set[Type] = set() +# If you edit these imports, please update torch/__init__.py.in as well +from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed +from .serialization import save, load +from ._tensor_str import set_printoptions ################################################################################ # Initialize extension diff --git a/torch/serialization.py b/torch/serialization.py index c19a148608e..62642a30b0e 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -12,6 +12,8 @@ from contextlib import closing, contextmanager from ._utils import _import_dotted_name from ._six import string_classes as _string_classes from torch._utils_internal import get_source_lines_and_file +from torch.types import Storage +from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union import copyreg import pickle import pathlib @@ -26,7 +28,6 @@ MAGIC_NUMBER = 0x1950a86a20f9469cfc6c PROTOCOL_VERSION = 1001 STORAGE_KEY_SEPARATOR = ',' - class SourceChangeWarning(Warning): pass @@ -41,7 +42,7 @@ def mkdtemp(): _package_registry = [] -def 
_is_zipfile(f): +def _is_zipfile(f) -> bool: # This is a stricter implementation than zipfile.is_zipfile(). # zipfile.is_zipfile() is True if the magic number appears anywhere in the # binary. Since we expect the files here to be generated by torch.save or @@ -160,7 +161,7 @@ register_package(10, _cpu_tag, _cpu_deserialize) register_package(20, _cuda_tag, _cuda_deserialize) -def location_tag(storage): +def location_tag(storage: Storage): for _, tagger, _ in _package_registry: location = tagger(storage) if location: @@ -237,29 +238,30 @@ def _open_file_like(name_or_buffer, mode): class _open_zipfile_reader(_opener): - def __init__(self, name_or_buffer): + def __init__(self, name_or_buffer) -> None: super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer)) class _open_zipfile_writer_file(_opener): - def __init__(self, name): + def __init__(self, name) -> None: super(_open_zipfile_writer_file, self).__init__(torch._C.PyTorchFileWriter(str(name))) - def __exit__(self, *args): + def __exit__(self, *args) -> None: self.file_like.write_end_of_file() class _open_zipfile_writer_buffer(_opener): - def __init__(self, buffer): + def __init__(self, buffer) -> None: self.buffer = buffer super(_open_zipfile_writer_buffer, self).__init__(torch._C.PyTorchFileWriter(buffer)) - def __exit__(self, *args): + def __exit__(self, *args) -> None: self.file_like.write_end_of_file() self.buffer.flush() def _open_zipfile_writer(name_or_buffer): + container: Type[_opener] if _is_path(name_or_buffer): container = _open_zipfile_writer_file else: @@ -267,7 +269,7 @@ def _open_zipfile_writer(name_or_buffer): return container(name_or_buffer) -def _is_compressed_file(f): +def _is_compressed_file(f) -> bool: compress_modules = ['gzip'] try: return f.__module__ in compress_modules @@ -291,7 +293,7 @@ def _should_read_directly(f): return False -def _check_seekable(f): +def _check_seekable(f) -> bool: def raise_err_msg(patterns, e): for p in patterns: @@ -307,8 +309,9 @@ def 
_check_seekable(f): return True except (io.UnsupportedOperation, AttributeError) as e: raise_err_msg(["seek", "tell"], e) + return False -def _check_dill_version(pickle_module): +def _check_dill_version(pickle_module) -> None: '''Checks if using dill as the pickle module, and if so, checks if it is the correct version. If dill version is lower than 0.3.1, a ValueError is raised. @@ -327,7 +330,8 @@ def _check_dill_version(pickle_module): pickle_module.__version__ )) -def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True): +def save(obj, f: Union[str, os.PathLike, BinaryIO], + pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None: """Saves an object to a disk file. See also: :ref:`recommend-saving-models` @@ -370,12 +374,12 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_ne _legacy_save(obj, opened_file, pickle_module, pickle_protocol) -def _legacy_save(obj, f, pickle_module, pickle_protocol): +def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None: import torch.nn as nn serialized_container_types = {} serialized_storages = {} - def persistent_id(obj): + def persistent_id(obj: Any) -> Optional[Tuple]: # FIXME: the docs say that persistent_id should only return a string # but torch store returns tuples. This works only in the binary protocol # see @@ -396,6 +400,8 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol): return ('module', obj, source_file, source) elif torch.is_storage(obj): + view_metadata: Optional[Tuple[str, int, int]] + obj = cast(Storage, obj) storage_type = normalize_storage_type(type(obj)) # Offset is always 0, but we keep it for backwards compatibility # with the old serialization format (which supported storage views) @@ -589,20 +595,20 @@ def load(f, map_location=None, pickle_module=pickle, **pickle_load_args): def _get_layout(name): """Get layout extension object from its string representation.
""" - cache = _get_layout.cache + cache = _get_layout.cache # type: ignore[attr-defined] if not cache: for v in torch.__dict__.values(): if isinstance(v, torch.layout): cache[str(v)] = v return cache[name] - -_get_layout.cache = {} +# There is not yet a good way to type annotate function attributes: https://github.com/python/mypy/issues/2087 +_get_layout.cache = {} # type: ignore[attr-defined] copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),))) def _legacy_load(f, map_location, pickle_module, **pickle_load_args): - deserialized_objects = {} + deserialized_objects: Dict[int, Any] = {} restore_location = _get_restore_location(map_location) @@ -648,7 +654,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args): warnings.warn(msg, SourceChangeWarning) def legacy_load(f): - deserialized_objects = {} + deserialized_objects: Dict[int, Any] = {} def persistent_load(saved_id): if isinstance(saved_id, tuple): @@ -777,7 +783,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args): return result -def _maybe_decode_ascii(bytes_str): +def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str: # When using encoding='bytes' in Py3, some **internal** keys stored as # strings in Py2 are loaded as bytes. This function decodes them with # ascii encoding, one that Py3 uses by default. diff --git a/torch/types.py b/torch/types.py index be86dfd27f8..ef3c68e985c 100644 --- a/torch/types.py +++ b/torch/types.py @@ -1,5 +1,5 @@ import torch -from typing import Union, Sequence, List, Tuple +from typing import Any, List, Sequence, Tuple, Union import builtins @@ -29,3 +29,15 @@ Number = Union[builtins.int, builtins.float, builtins.bool] # literal device object). This nomenclature is consistent with PythonArgParser.
# None means use the default device (typically CPU) Device = Union[_device, str, None] + +# Storage protocol implemented by ${Type}StorageBase classes +class Storage(object): + _cdata: int + + def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool) -> None: + ... + + def size(self) -> int: + ... + + ...