mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Type-annotate serialization.py (#40862)
Summary: Move Storage class from __init__.pyi.in to types.py and make it a protocol, since this is not a real class Expose `PyTorchFileReader` and `PyTorchFileWriter` native classes Ignore function attributes, as there are yet no good way to type annotate those, see https://github.com/python/mypy/issues/2087 Pull Request resolved: https://github.com/pytorch/pytorch/pull/40862 Differential Revision: D22344743 Pulled By: malfet fbshipit-source-id: 95cdb6f980ee79383960f306223e170c63df3232
This commit is contained in:
parent
9fa1f27968
commit
591fffc524
3
mypy.ini
3
mypy.ini
|
|
@ -317,9 +317,6 @@ ignore_errors = True
|
|||
[mypy-torch.nn.functional]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.serialization]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.utils]
|
||||
ignore_errors = True
|
||||
|
||||
|
|
|
|||
|
|
@ -2,10 +2,11 @@
|
|||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar, Type
|
||||
from typing import (Any, BinaryIO, Callable, ContextManager, Iterator, List, NamedTuple,
|
||||
Optional, overload, Sequence, Tuple, TypeVar, Type, Union)
|
||||
from torch._six import inf
|
||||
|
||||
from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Number, Device
|
||||
from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage
|
||||
|
||||
import builtins
|
||||
|
||||
|
|
@ -90,9 +91,6 @@ class qscheme: ...
|
|||
# Defined in torch/csrc/utils/tensor_qschemes.cpp
|
||||
per_tensor_affine: qscheme = ...
|
||||
|
||||
# Defined in torch/csrc/generic/Storage.cpp
|
||||
class Storage: ...
|
||||
|
||||
# Defined in torch/csrc/autograd/python_function.cpp
|
||||
class _FunctionBase(object):
|
||||
...
|
||||
|
|
@ -169,6 +167,24 @@ class FileCheck(object):
|
|||
# TODO
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/jit/python/init.cpp
|
||||
class PyTorchFileReader(object):
|
||||
@overload
|
||||
def __init__(self, name: str) -> None: ...
|
||||
@overload
|
||||
def __init__(self, buffer: BinaryIO) -> None: ...
|
||||
def get_record(self, name: str) -> bytes: ...
|
||||
...
|
||||
|
||||
class PyTorchFileWriter(object):
|
||||
@overload
|
||||
def __init__(self, name: str) -> None: ...
|
||||
@overload
|
||||
def __init__(self, buffer: BinaryIO) -> None: ...
|
||||
def write_record(self, name: str, data: bytes, size: _int) -> None: ...
|
||||
def write_end_of_file(self) -> None: ...
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/Generator.cpp
|
||||
class Generator(object):
|
||||
device: _device
|
||||
|
|
|
|||
|
|
@ -311,11 +311,6 @@ def is_deterministic():
|
|||
"""
|
||||
return _C._get_deterministic()
|
||||
|
||||
# If you edit these imports, please update torch/__init__.py.in as well
|
||||
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
|
||||
from .serialization import save, load
|
||||
from ._tensor_str import set_printoptions
|
||||
|
||||
################################################################################
|
||||
# Define Storage and Tensor classes
|
||||
################################################################################
|
||||
|
|
@ -388,6 +383,10 @@ _storage_classes = {
|
|||
# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
|
||||
_tensor_classes: Set[Type] = set()
|
||||
|
||||
# If you edit these imports, please update torch/__init__.py.in as well
|
||||
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
|
||||
from .serialization import save, load
|
||||
from ._tensor_str import set_printoptions
|
||||
|
||||
################################################################################
|
||||
# Initialize extension
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ from contextlib import closing, contextmanager
|
|||
from ._utils import _import_dotted_name
|
||||
from ._six import string_classes as _string_classes
|
||||
from torch._utils_internal import get_source_lines_and_file
|
||||
from torch.types import Storage
|
||||
from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union
|
||||
import copyreg
|
||||
import pickle
|
||||
import pathlib
|
||||
|
|
@ -26,7 +28,6 @@ MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
|
|||
PROTOCOL_VERSION = 1001
|
||||
STORAGE_KEY_SEPARATOR = ','
|
||||
|
||||
|
||||
class SourceChangeWarning(Warning):
|
||||
pass
|
||||
|
||||
|
|
@ -41,7 +42,7 @@ def mkdtemp():
|
|||
_package_registry = []
|
||||
|
||||
|
||||
def _is_zipfile(f):
|
||||
def _is_zipfile(f) -> bool:
|
||||
# This is a stricter implementation than zipfile.is_zipfile().
|
||||
# zipfile.is_zipfile() is True if the magic number appears anywhere in the
|
||||
# binary. Since we expect the files here to be generated by torch.save or
|
||||
|
|
@ -160,7 +161,7 @@ register_package(10, _cpu_tag, _cpu_deserialize)
|
|||
register_package(20, _cuda_tag, _cuda_deserialize)
|
||||
|
||||
|
||||
def location_tag(storage):
|
||||
def location_tag(storage: Storage):
|
||||
for _, tagger, _ in _package_registry:
|
||||
location = tagger(storage)
|
||||
if location:
|
||||
|
|
@ -237,29 +238,30 @@ def _open_file_like(name_or_buffer, mode):
|
|||
|
||||
|
||||
class _open_zipfile_reader(_opener):
|
||||
def __init__(self, name_or_buffer):
|
||||
def __init__(self, name_or_buffer) -> None:
|
||||
super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
|
||||
|
||||
|
||||
class _open_zipfile_writer_file(_opener):
|
||||
def __init__(self, name):
|
||||
def __init__(self, name) -> None:
|
||||
super(_open_zipfile_writer_file, self).__init__(torch._C.PyTorchFileWriter(str(name)))
|
||||
|
||||
def __exit__(self, *args):
|
||||
def __exit__(self, *args) -> None:
|
||||
self.file_like.write_end_of_file()
|
||||
|
||||
|
||||
class _open_zipfile_writer_buffer(_opener):
|
||||
def __init__(self, buffer):
|
||||
def __init__(self, buffer) -> None:
|
||||
self.buffer = buffer
|
||||
super(_open_zipfile_writer_buffer, self).__init__(torch._C.PyTorchFileWriter(buffer))
|
||||
|
||||
def __exit__(self, *args):
|
||||
def __exit__(self, *args) -> None:
|
||||
self.file_like.write_end_of_file()
|
||||
self.buffer.flush()
|
||||
|
||||
|
||||
def _open_zipfile_writer(name_or_buffer):
|
||||
container: Type[_opener]
|
||||
if _is_path(name_or_buffer):
|
||||
container = _open_zipfile_writer_file
|
||||
else:
|
||||
|
|
@ -267,7 +269,7 @@ def _open_zipfile_writer(name_or_buffer):
|
|||
return container(name_or_buffer)
|
||||
|
||||
|
||||
def _is_compressed_file(f):
|
||||
def _is_compressed_file(f) -> bool:
|
||||
compress_modules = ['gzip']
|
||||
try:
|
||||
return f.__module__ in compress_modules
|
||||
|
|
@ -291,7 +293,7 @@ def _should_read_directly(f):
|
|||
return False
|
||||
|
||||
|
||||
def _check_seekable(f):
|
||||
def _check_seekable(f) -> bool:
|
||||
|
||||
def raise_err_msg(patterns, e):
|
||||
for p in patterns:
|
||||
|
|
@ -307,8 +309,9 @@ def _check_seekable(f):
|
|||
return True
|
||||
except (io.UnsupportedOperation, AttributeError) as e:
|
||||
raise_err_msg(["seek", "tell"], e)
|
||||
return False
|
||||
|
||||
def _check_dill_version(pickle_module):
|
||||
def _check_dill_version(pickle_module) -> None:
|
||||
'''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
|
||||
If dill version is lower than 0.3.1, a ValueError is raised.
|
||||
|
||||
|
|
@ -327,7 +330,8 @@ def _check_dill_version(pickle_module):
|
|||
pickle_module.__version__
|
||||
))
|
||||
|
||||
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True):
|
||||
def save(obj, f: Union[str, os.PathLike, BinaryIO],
|
||||
pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:
|
||||
"""Saves an object to a disk file.
|
||||
|
||||
See also: :ref:`recommend-saving-models`
|
||||
|
|
@ -370,12 +374,12 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_ne
|
|||
_legacy_save(obj, opened_file, pickle_module, pickle_protocol)
|
||||
|
||||
|
||||
def _legacy_save(obj, f, pickle_module, pickle_protocol):
|
||||
def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
|
||||
import torch.nn as nn
|
||||
serialized_container_types = {}
|
||||
serialized_storages = {}
|
||||
|
||||
def persistent_id(obj):
|
||||
def persistent_id(obj: Any) -> Optional[Tuple]:
|
||||
# FIXME: the docs say that persistent_id should only return a string
|
||||
# but torch store returns tuples. This works only in the binary protocol
|
||||
# see
|
||||
|
|
@ -396,6 +400,8 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol):
|
|||
return ('module', obj, source_file, source)
|
||||
|
||||
elif torch.is_storage(obj):
|
||||
view_metadata: Optional[Tuple[str, int, int]]
|
||||
obj = cast(Storage, obj)
|
||||
storage_type = normalize_storage_type(type(obj))
|
||||
# Offset is always 0, but we keep it for backwards compatibility
|
||||
# with the old serialization format (which supported storage views)
|
||||
|
|
@ -589,20 +595,20 @@ def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
|
|||
def _get_layout(name):
|
||||
"""Get layout extension object from its string representation.
|
||||
"""
|
||||
cache = _get_layout.cache
|
||||
cache = _get_layout.cache # type: ignore[attr-defined]
|
||||
if not cache:
|
||||
for v in torch.__dict__.values():
|
||||
if isinstance(v, torch.layout):
|
||||
cache[str(v)] = v
|
||||
return cache[name]
|
||||
|
||||
|
||||
_get_layout.cache = {}
|
||||
# There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087
|
||||
_get_layout.cache = {} # type: ignore[attr-defined]
|
||||
copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
|
||||
|
||||
|
||||
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
||||
deserialized_objects = {}
|
||||
deserialized_objects: Dict[int, Any] = {}
|
||||
|
||||
restore_location = _get_restore_location(map_location)
|
||||
|
||||
|
|
@ -648,7 +654,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
|||
warnings.warn(msg, SourceChangeWarning)
|
||||
|
||||
def legacy_load(f):
|
||||
deserialized_objects = {}
|
||||
deserialized_objects: Dict[int, Any] = {}
|
||||
|
||||
def persistent_load(saved_id):
|
||||
if isinstance(saved_id, tuple):
|
||||
|
|
@ -777,7 +783,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
|||
return result
|
||||
|
||||
|
||||
def _maybe_decode_ascii(bytes_str):
|
||||
def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
|
||||
# When using encoding='bytes' in Py3, some **internal** keys stored as
|
||||
# strings in Py2 are loaded as bytes. This function decodes them with
|
||||
# ascii encoding, one that Py3 uses by default.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
from typing import Union, Sequence, List, Tuple
|
||||
from typing import Any, List, Sequence, Tuple, Union
|
||||
|
||||
import builtins
|
||||
|
||||
|
|
@ -29,3 +29,15 @@ Number = Union[builtins.int, builtins.float, builtins.bool]
|
|||
# literal device object). This nomenclature is consistent with PythonArgParser.
|
||||
# None means use the default device (typically CPU)
|
||||
Device = Union[_device, str, None]
|
||||
|
||||
# Storage protocol implemented by ${Type}StorageBase classes
|
||||
class Storage(object):
|
||||
_cdata: int
|
||||
|
||||
def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool) -> None:
|
||||
...
|
||||
|
||||
def size(self) -> int:
|
||||
...
|
||||
|
||||
...
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user