Type-annotate serialization.py (#40862)

Summary:
Move Storage class from __init__.pyi.in to types.py and make it a protocol, since this is not a real class
Expose `PyTorchFileReader` and `PyTorchFileWriter` native classes

Ignore function attributes, as there is not yet a good way to type-annotate those; see https://github.com/python/mypy/issues/2087
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40862

Differential Revision: D22344743

Pulled By: malfet

fbshipit-source-id: 95cdb6f980ee79383960f306223e170c63df3232
This commit is contained in:
Nikita Shulga 2020-07-02 07:09:21 -07:00 committed by Facebook GitHub Bot
parent 9fa1f27968
commit 591fffc524
5 changed files with 64 additions and 34 deletions

View File

@ -317,9 +317,6 @@ ignore_errors = True
[mypy-torch.nn.functional]
ignore_errors = True
[mypy-torch.serialization]
ignore_errors = True
[mypy-torch.utils]
ignore_errors = True

View File

@ -2,10 +2,11 @@
import torch
from torch import Tensor
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar, Type
from typing import (Any, BinaryIO, Callable, ContextManager, Iterator, List, NamedTuple,
Optional, overload, Sequence, Tuple, TypeVar, Type, Union)
from torch._six import inf
from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Number, Device
from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage
import builtins
@ -90,9 +91,6 @@ class qscheme: ...
# Defined in torch/csrc/utils/tensor_qschemes.cpp
per_tensor_affine: qscheme = ...
# Defined in torch/csrc/generic/Storage.cpp
class Storage: ...
# Defined in torch/csrc/autograd/python_function.cpp
class _FunctionBase(object):
...
@ -169,6 +167,24 @@ class FileCheck(object):
# TODO
...
# Defined in torch/csrc/jit/python/init.cpp
class PyTorchFileReader(object):
    # Stub for the native reader (defined in torch/csrc/jit/python/init.cpp)
    # that opens a serialized PyTorch zip archive either from a filesystem
    # path or from an already-open binary stream.
    @overload
    def __init__(self, name: str) -> None: ...
    @overload
    def __init__(self, buffer: BinaryIO) -> None: ...
    # Returns the raw bytes of the archive record with the given name.
    def get_record(self, name: str) -> bytes: ...
    ...
class PyTorchFileWriter(object):
    # Stub for the native writer counterpart of PyTorchFileReader; targets
    # either a filesystem path or an already-open binary stream.
    @overload
    def __init__(self, name: str) -> None: ...
    @overload
    def __init__(self, buffer: BinaryIO) -> None: ...
    # Writes `size` bytes of `data` as an archive record named `name`.
    def write_record(self, name: str, data: bytes, size: _int) -> None: ...
    # Finalizes the archive; callers invoke this before closing the stream
    # (see _open_zipfile_writer_*.__exit__ in serialization.py).
    def write_end_of_file(self) -> None: ...
    ...
# Defined in torch/csrc/Generator.cpp
class Generator(object):
device: _device

View File

@ -311,11 +311,6 @@ def is_deterministic():
"""
return _C._get_deterministic()
# If you edit these imports, please update torch/__init__.py.in as well
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
from .serialization import save, load
from ._tensor_str import set_printoptions
################################################################################
# Define Storage and Tensor classes
################################################################################
@ -388,6 +383,10 @@ _storage_classes = {
# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
_tensor_classes: Set[Type] = set()
# If you edit these imports, please update torch/__init__.py.in as well
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
from .serialization import save, load
from ._tensor_str import set_printoptions
################################################################################
# Initialize extension

View File

@ -12,6 +12,8 @@ from contextlib import closing, contextmanager
from ._utils import _import_dotted_name
from ._six import string_classes as _string_classes
from torch._utils_internal import get_source_lines_and_file
from torch.types import Storage
from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union
import copyreg
import pickle
import pathlib
@ -26,7 +28,6 @@ MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
PROTOCOL_VERSION = 1001
STORAGE_KEY_SEPARATOR = ','
class SourceChangeWarning(Warning):
    """Warning emitted during deserialization.

    NOTE(review): presumably issued when the source code of a saved
    container/module no longer matches the current source (it is passed to
    ``warnings.warn`` alongside a source-change message elsewhere in this
    file) — confirm against the caller.
    """
    pass
@ -41,7 +42,7 @@ def mkdtemp():
_package_registry = []
def _is_zipfile(f):
def _is_zipfile(f) -> bool:
# This is a stricter implementation than zipfile.is_zipfile().
# zipfile.is_zipfile() is True if the magic number appears anywhere in the
# binary. Since we expect the files here to be generated by torch.save or
@ -160,7 +161,7 @@ register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)
def location_tag(storage):
def location_tag(storage: Storage):
for _, tagger, _ in _package_registry:
location = tagger(storage)
if location:
@ -237,29 +238,30 @@ def _open_file_like(name_or_buffer, mode):
class _open_zipfile_reader(_opener):
def __init__(self, name_or_buffer):
def __init__(self, name_or_buffer) -> None:
super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
class _open_zipfile_writer_file(_opener):
def __init__(self, name):
def __init__(self, name) -> None:
super(_open_zipfile_writer_file, self).__init__(torch._C.PyTorchFileWriter(str(name)))
def __exit__(self, *args):
def __exit__(self, *args) -> None:
self.file_like.write_end_of_file()
class _open_zipfile_writer_buffer(_opener):
def __init__(self, buffer):
def __init__(self, buffer) -> None:
self.buffer = buffer
super(_open_zipfile_writer_buffer, self).__init__(torch._C.PyTorchFileWriter(buffer))
def __exit__(self, *args):
def __exit__(self, *args) -> None:
self.file_like.write_end_of_file()
self.buffer.flush()
def _open_zipfile_writer(name_or_buffer):
container: Type[_opener]
if _is_path(name_or_buffer):
container = _open_zipfile_writer_file
else:
@ -267,7 +269,7 @@ def _open_zipfile_writer(name_or_buffer):
return container(name_or_buffer)
def _is_compressed_file(f):
def _is_compressed_file(f) -> bool:
compress_modules = ['gzip']
try:
return f.__module__ in compress_modules
@ -291,7 +293,7 @@ def _should_read_directly(f):
return False
def _check_seekable(f):
def _check_seekable(f) -> bool:
def raise_err_msg(patterns, e):
for p in patterns:
@ -307,8 +309,9 @@ def _check_seekable(f):
return True
except (io.UnsupportedOperation, AttributeError) as e:
raise_err_msg(["seek", "tell"], e)
return False
def _check_dill_version(pickle_module):
def _check_dill_version(pickle_module) -> None:
'''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
If dill version is lower than 0.3.1, a ValueError is raised.
@ -327,7 +330,8 @@ def _check_dill_version(pickle_module):
pickle_module.__version__
))
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True):
def save(obj, f: Union[str, os.PathLike, BinaryIO],
pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:
"""Saves an object to a disk file.
See also: :ref:`recommend-saving-models`
@ -370,12 +374,12 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_ne
_legacy_save(obj, opened_file, pickle_module, pickle_protocol)
def _legacy_save(obj, f, pickle_module, pickle_protocol):
def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
import torch.nn as nn
serialized_container_types = {}
serialized_storages = {}
def persistent_id(obj):
def persistent_id(obj: Any) -> Optional[Tuple]:
# FIXME: the docs say that persistent_id should only return a string
# but torch store returns tuples. This works only in the binary protocol
# see
@ -396,6 +400,8 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol):
return ('module', obj, source_file, source)
elif torch.is_storage(obj):
view_metadata: Optional[Tuple[str, int, int]]
obj = cast(Storage, obj)
storage_type = normalize_storage_type(type(obj))
# Offset is always 0, but we keep it for backwards compatibility
# with the old serialization format (which supported storage views)
@ -589,20 +595,20 @@ def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
def _get_layout(name):
"""Get layout extension object from its string representation.
"""
cache = _get_layout.cache
cache = _get_layout.cache # type: ignore[attr-defined]
if not cache:
for v in torch.__dict__.values():
if isinstance(v, torch.layout):
cache[str(v)] = v
return cache[name]
_get_layout.cache = {}
# There is not yet a good way to type-annotate function attributes; see https://github.com/python/mypy/issues/2087
_get_layout.cache = {}  # type: ignore[attr-defined]
copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
deserialized_objects = {}
deserialized_objects: Dict[int, Any] = {}
restore_location = _get_restore_location(map_location)
@ -648,7 +654,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
warnings.warn(msg, SourceChangeWarning)
def legacy_load(f):
deserialized_objects = {}
deserialized_objects: Dict[int, Any] = {}
def persistent_load(saved_id):
if isinstance(saved_id, tuple):
@ -777,7 +783,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
return result
def _maybe_decode_ascii(bytes_str):
def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
# When using encoding='bytes' in Py3, some **internal** keys stored as
# strings in Py2 are loaded as bytes. This function decodes them with
# ascii encoding, one that Py3 uses by default.

View File

@ -1,5 +1,5 @@
import torch
from typing import Union, Sequence, List, Tuple
from typing import Any, List, Sequence, Tuple, Union
import builtins
@ -29,3 +29,15 @@ Number = Union[builtins.int, builtins.float, builtins.bool]
# literal device object). This nomenclature is consistent with PythonArgParser.
# None means use the default device (typically CPU)
Device = Union[_device, str, None]
# Storage protocol implemented by ${Type}StorageBase classes
class Storage(object):
    # Structural ("protocol") stand-in for the native storage classes; per
    # the commit summary it is implemented by the ${Type}StorageBase classes
    # and is not a real base class.

    # Opaque handle to the underlying C storage object.
    _cdata: int

    # Serializes the storage to ``f``. NOTE(review): presumably ``f`` is a
    # file object or file descriptor depending on ``is_real_file`` — confirm
    # against torch/csrc/generic/Storage implementation.
    def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool) -> None:
        ...
    # Number of elements held by the storage.
    def size(self) -> int:
        ...
    ...