Type-annotate serialization.py (#40862)

Summary:
Move the `Storage` class from `__init__.pyi.in` to `types.py` and make it a protocol, since it is not a real class
Expose the `PyTorchFileReader` and `PyTorchFileWriter` native classes

Ignore function attributes, as there is not yet a good way to type-annotate them; see https://github.com/python/mypy/issues/2087 (a short sketch of the workaround follows the commit metadata below)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40862

Differential Revision: D22344743

Pulled By: malfet

fbshipit-source-id: 95cdb6f980ee79383960f306223e170c63df3232
Nikita Shulga 2020-07-02 07:09:21 -07:00 committed by Facebook GitHub Bot
parent 9fa1f27968
commit 591fffc524
5 changed files with 64 additions and 34 deletions
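
As a sketch of the workaround mentioned in the summary: mypy cannot model attributes attached to function objects, so each access is silenced with a targeted ignore. The function name here is hypothetical; the real instance is _get_layout in torch/serialization.py below.

from typing import Any, Dict

def lookup(name: str) -> Any:
    # Function attributes cannot be declared to mypy
    # (https://github.com/python/mypy/issues/2087), so suppress the
    # attribute check on each access instead of annotating it.
    cache: Dict[str, Any] = lookup.cache  # type: ignore[attr-defined]
    return cache.get(name)

lookup.cache = {}  # type: ignore[attr-defined]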

mypy.ini

@@ -317,9 +317,6 @@ ignore_errors = True
 [mypy-torch.nn.functional]
 ignore_errors = True
 
-[mypy-torch.serialization]
-ignore_errors = True
-
 [mypy-torch.utils]
 ignore_errors = True
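
With the per-module override removed, torch.serialization is now checked under the global mypy settings. A sketch of verifying that locally through mypy's programmatic API (assumes mypy is installed and the working directory is the repo root, so mypy.ini is picked up automatically):

from mypy import api

# Programmatic equivalent of running `mypy torch/serialization.py`.
stdout, stderr, exit_status = api.run(["torch/serialization.py"])
print(stdout)
assert exit_status == 0, "torch.serialization no longer type-checks"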

torch/_C/__init__.pyi.in

@@ -2,10 +2,11 @@
 import torch
 from torch import Tensor
-from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar, Type
+from typing import (Any, BinaryIO, Callable, ContextManager, Iterator, List, NamedTuple,
+                    Optional, overload, Sequence, Tuple, TypeVar, Type, Union)
 from torch._six import inf
 
-from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Number, Device
+from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage
 
 import builtins
@@ -90,9 +91,6 @@ class qscheme: ...
 # Defined in torch/csrc/utils/tensor_qschemes.cpp
 per_tensor_affine: qscheme = ...
 
-# Defined in torch/csrc/generic/Storage.cpp
-class Storage: ...
-
 # Defined in torch/csrc/autograd/python_function.cpp
 class _FunctionBase(object):
     ...
@@ -169,6 +167,24 @@ class FileCheck(object):
     # TODO
     ...
 
+# Defined in torch/csrc/jit/python/init.cpp
+class PyTorchFileReader(object):
+    @overload
+    def __init__(self, name: str) -> None: ...
+    @overload
+    def __init__(self, buffer: BinaryIO) -> None: ...
+    def get_record(self, name: str) -> bytes: ...
+    ...
+
+class PyTorchFileWriter(object):
+    @overload
+    def __init__(self, name: str) -> None: ...
+    @overload
+    def __init__(self, buffer: BinaryIO) -> None: ...
+    def write_record(self, name: str, data: bytes, size: _int) -> None: ...
+    def write_end_of_file(self) -> None: ...
+    ...
+
 # Defined in torch/csrc/Generator.cpp
 class Generator(object):
     device: _device
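
A minimal round-trip sketch of the newly exposed classes, using only the signatures declared in the stub above (the archive path and record name are arbitrary):

import torch

writer = torch._C.PyTorchFileWriter("archive.zip")
payload = b"hello"
# write_record takes the record name, the data, and its length in bytes.
writer.write_record("greeting", payload, len(payload))
writer.write_end_of_file()

reader = torch._C.PyTorchFileReader("archive.zip")
assert reader.get_record("greeting") == payload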

torch/__init__.py

@@ -311,11 +311,6 @@ def is_deterministic():
     """
     return _C._get_deterministic()
 
-# If you edit these imports, please update torch/__init__.py.in as well
-from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
-from .serialization import save, load
-from ._tensor_str import set_printoptions
-
 ################################################################################
 # Define Storage and Tensor classes
 ################################################################################
@@ -388,6 +383,10 @@ _storage_classes = {
 # The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
 _tensor_classes: Set[Type] = set()
 
+# If you edit these imports, please update torch/__init__.py.in as well
+from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
+from .serialization import save, load
+from ._tensor_str import set_printoptions
+
 ################################################################################
 # Initialize extension
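
Moving these imports does not change the public API; they are what re-export the serialization entry points at the top level. For example:

import torch

# torch.save and torch.load resolve to torch/serialization.py through
# the `from .serialization import save, load` line shown above.
torch.save({"step": 1}, "checkpoint.pt")
state = torch.load("checkpoint.pt")
assert state["step"] == 1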

torch/serialization.py

@@ -12,6 +12,8 @@ from contextlib import closing, contextmanager
 from ._utils import _import_dotted_name
 from ._six import string_classes as _string_classes
 from torch._utils_internal import get_source_lines_and_file
+from torch.types import Storage
+from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union
 import copyreg
 import pickle
 import pathlib
@@ -26,7 +28,6 @@ MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
 PROTOCOL_VERSION = 1001
 STORAGE_KEY_SEPARATOR = ','
 
-
 class SourceChangeWarning(Warning):
     pass
@@ -41,7 +42,7 @@ def mkdtemp():
 _package_registry = []
 
-def _is_zipfile(f):
+def _is_zipfile(f) -> bool:
     # This is a stricter implementation than zipfile.is_zipfile().
     # zipfile.is_zipfile() is True if the magic number appears anywhere in the
     # binary. Since we expect the files here to be generated by torch.save or
@@ -160,7 +161,7 @@ register_package(10, _cpu_tag, _cpu_deserialize)
 register_package(20, _cuda_tag, _cuda_deserialize)
 
-def location_tag(storage):
+def location_tag(storage: Storage):
     for _, tagger, _ in _package_registry:
         location = tagger(storage)
         if location:
@@ -237,29 +238,30 @@ def _open_file_like(name_or_buffer, mode):
 class _open_zipfile_reader(_opener):
-    def __init__(self, name_or_buffer):
+    def __init__(self, name_or_buffer) -> None:
         super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
 
 
 class _open_zipfile_writer_file(_opener):
-    def __init__(self, name):
+    def __init__(self, name) -> None:
         super(_open_zipfile_writer_file, self).__init__(torch._C.PyTorchFileWriter(str(name)))
 
-    def __exit__(self, *args):
+    def __exit__(self, *args) -> None:
         self.file_like.write_end_of_file()
 
 
 class _open_zipfile_writer_buffer(_opener):
-    def __init__(self, buffer):
+    def __init__(self, buffer) -> None:
         self.buffer = buffer
         super(_open_zipfile_writer_buffer, self).__init__(torch._C.PyTorchFileWriter(buffer))
 
-    def __exit__(self, *args):
+    def __exit__(self, *args) -> None:
         self.file_like.write_end_of_file()
         self.buffer.flush()
 
 
 def _open_zipfile_writer(name_or_buffer):
+    container: Type[_opener]
     if _is_path(name_or_buffer):
         container = _open_zipfile_writer_file
     else:
@@ -267,7 +269,7 @@ def _open_zipfile_writer(name_or_buffer):
     return container(name_or_buffer)
 
-def _is_compressed_file(f):
+def _is_compressed_file(f) -> bool:
     compress_modules = ['gzip']
     try:
         return f.__module__ in compress_modules
@@ -291,7 +293,7 @@ def _should_read_directly(f):
     return False
 
-def _check_seekable(f):
+def _check_seekable(f) -> bool:
     def raise_err_msg(patterns, e):
         for p in patterns:
@@ -307,8 +309,9 @@ def _check_seekable(f):
         return True
     except (io.UnsupportedOperation, AttributeError) as e:
         raise_err_msg(["seek", "tell"], e)
+    return False
 
-def _check_dill_version(pickle_module):
+def _check_dill_version(pickle_module) -> None:
     '''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
     If dill version is lower than 0.3.1, a ValueError is raised.
@@ -327,7 +330,8 @@ def _check_dill_version(pickle_module):
             pickle_module.__version__
         ))
 
-def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True):
+def save(obj, f: Union[str, os.PathLike, BinaryIO],
+         pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:
     """Saves an object to a disk file.
 
     See also: :ref:`recommend-saving-models`
@@ -370,12 +374,12 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True):
         _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
 
-def _legacy_save(obj, f, pickle_module, pickle_protocol):
+def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
     import torch.nn as nn
     serialized_container_types = {}
     serialized_storages = {}
 
-    def persistent_id(obj):
+    def persistent_id(obj: Any) -> Optional[Tuple]:
         # FIXME: the docs say that persistent_id should only return a string
         # but torch store returns tuples. This works only in the binary protocol
         # see
@@ -396,6 +400,8 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol):
             return ('module', obj, source_file, source)
         elif torch.is_storage(obj):
+            view_metadata: Optional[Tuple[str, int, int]]
+            obj = cast(Storage, obj)
             storage_type = normalize_storage_type(type(obj))
             # Offset is always 0, but we keep it for backwards compatibility
             # with the old serialization format (which supported storage views)
@@ -589,20 +595,20 @@ def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
 def _get_layout(name):
     """Get layout extension object from its string representation.
     """
-    cache = _get_layout.cache
+    cache = _get_layout.cache  # type: ignore[attr-defined]
     if not cache:
         for v in torch.__dict__.values():
             if isinstance(v, torch.layout):
                 cache[str(v)] = v
     return cache[name]
 
-_get_layout.cache = {}
+# There is not yet a good way to type annotate function attributes https://github.com/python/mypy/issues/2087
+_get_layout.cache = {}  # type: ignore[attr-defined]
 copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
 
 def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
-    deserialized_objects = {}
+    deserialized_objects: Dict[int, Any] = {}
 
     restore_location = _get_restore_location(map_location)
@@ -648,7 +654,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
         warnings.warn(msg, SourceChangeWarning)
 
     def legacy_load(f):
-        deserialized_objects = {}
+        deserialized_objects: Dict[int, Any] = {}
 
         def persistent_load(saved_id):
             if isinstance(saved_id, tuple):
@@ -777,7 +783,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
     return result
 
-def _maybe_decode_ascii(bytes_str):
+def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
     # When using encoding='bytes' in Py3, some **internal** keys stored as
     # strings in Py2 are loaded as bytes. This function decodes them with
     # ascii encoding, one that Py3 uses by default.
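
The widened annotation on `save` documents that `f` may be a path or any binary file object, which the implementation already accepted. For example, an in-memory buffer:

import io

import torch

# f: Union[str, os.PathLike, BinaryIO] per the new signature.
buffer = io.BytesIO()
torch.save(torch.ones(2), buffer)
buffer.seek(0)
restored = torch.load(buffer)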

torch/types.py

@@ -1,5 +1,5 @@
 import torch
-from typing import Union, Sequence, List, Tuple
+from typing import Any, List, Sequence, Tuple, Union
 import builtins
@@ -29,3 +29,15 @@ Number = Union[builtins.int, builtins.float, builtins.bool]
 # literal device object). This nomenclature is consistent with PythonArgParser.
 # None means use the default device (typically CPU)
 Device = Union[_device, str, None]
+
+# Storage protocol implemented by ${Type}StorageBase classes
+class Storage(object):
+    _cdata: int
+
+    def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool) -> None:
+        ...
+
+    def size(self) -> int:
+        ...
+    ...
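
Since `Storage` only describes the shape shared by the `${Type}StorageBase` classes, callers narrow to it with `cast` (as `_legacy_save` does above) rather than relying on inheritance. A minimal sketch of that pattern; the helper name is hypothetical:

from typing import cast

import torch
from torch.types import Storage

def storage_size_example(obj: object) -> int:
    # cast() is a no-op at runtime; it only tells the type checker that
    # obj structurally provides the Storage attributes declared above.
    assert torch.is_storage(obj)
    return cast(Storage, obj).size()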