Add type annotations to torch/storage.py (#46876)

Summary: Fixes https://github.com/pytorch/pytorch/issues/46875

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46876
Reviewed By: glaringlee
Differential Revision: D24758448
Pulled By: ezyang
fbshipit-source-id: afbc19637fbfaa1b0276cdd707043111aee3abc3
This commit is contained in:
parent d0d673b043
commit 90a90ab1d6

Changed files: mypy.ini (3 deletions), torch/storage.py

diff --git a/mypy.ini b/mypy.ini
@@ -137,9 +137,6 @@ ignore_errors = True
 [mypy-torch._appdirs]
 ignore_errors = True
 
-[mypy-torch.storage]
-ignore_errors = True
-
 [mypy-torch._utils]
 ignore_errors = True
 
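Note: deleting this override removes torch.storage from the ignore list, so mypy now type-checks the module like any other typed file. One way to confirm locally (a sketch, assuming mypy is installed and this runs from the repository root where mypy.ini lives):

    from mypy import api

    # api.run returns (report, errors, exit_status); exit status 0 means clean
    report, errors, exit_status = api.run(['--config-file', 'mypy.ini', 'torch/storage.py'])
    print(report or 'no type errors reported', 'exit status:', exit_status)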

diff --git a/torch/storage.py b/torch/storage.py
@@ -2,11 +2,31 @@ import io
 
 import torch
 from ._utils import _type, _cuda
+from typing import Any, TypeVar, Type
 
-
+T = TypeVar('T', bound='_StorageBase')
 class _StorageBase(object):
-    is_cuda = False
-    is_sparse = False
+    _cdata: Any
+    is_cuda: bool = False
+    is_sparse: bool = False
+
+    def __init__(self, *args, **kwargs): ...  # noqa: E704
+    def __len__(self) -> int: ...  # noqa: E704
+    def __getitem__(self, idx): ...  # noqa: E704
+    def copy_(self, source: T) -> T: ...  # noqa: E704
+    def size(self) -> int: ...  # noqa: E704
+    def type(self, dtype: str = None, non_blocking: bool = False) -> T: ...  # noqa: E704
+    def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ...  # noqa: E704
+    def element_size(self) -> int: ...  # noqa: E704
+    def get_device(self) -> int: ...  # noqa: E704
+
+    # Defined in torch/csrc/generic/StorageSharing.cpp
+    def _share_filename_(self): ...  # noqa: E704
+    def _share_fd_(self): ...  # noqa: E704
+    @classmethod
+    def _new_using_filename(cls: Type[T], size: int) -> T: ...  # noqa: E704
+    @classmethod
+    def _new_using_fd(cls: Type[T], size: int) -> T: ...  # noqa: E704
 
     def __str__(self):
         content = ' ' + '\n '.join(str(self[i]) for i in range(len(self)))
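Note: the bound TypeVar is the heart of this hunk; stubs that return T hand back the caller's own subclass rather than the base class. A minimal sketch of the pattern (toy names, not PyTorch code):

    from typing import Any, Type, TypeVar

    T = TypeVar('T', bound='StorageLike')

    class StorageLike:
        _cdata: Any  # opaque handle; in the real class the C++ backend fills it in

        def clone(self: T) -> T:
            # type(self) is Type[T], so subclasses get back their own type
            return type(self)()

        @classmethod
        def _new_empty(cls: Type[T]) -> T:
            # classmethod factories spell the same idea with Type[T]
            return cls()

    class FloatStorageLike(StorageLike):
        pass

    s: FloatStorageLike = FloatStorageLike().clone()     # checks as FloatStorageLike,
    t: FloatStorageLike = FloatStorageLike._new_empty()  # not merely StorageLike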
@@ -104,7 +124,7 @@ class _StorageBase(object):
         if self.is_cuda:
             raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned")
         import torch.cuda
-        allocator = torch.cuda._host_allocator()
+        allocator = torch.cuda._host_allocator()  # type: ignore[attr-defined]
         return type(self)(self.size(), allocator=allocator).copy_(self)
 
     def share_memory_(self):
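Note: mypy cannot resolve _host_allocator on torch.cuda here, and the bracketed form suppresses only the attr-defined diagnostic, unlike a bare # type: ignore, which would hide every diagnostic on the line. A toy reproduction of the same error class (assumed attribute, not PyTorch code):

    import math

    # mypy flags attributes it cannot see statically as [attr-defined],
    # even when they exist at runtime after patching
    math.tau_half = math.tau / 2  # type: ignore[attr-defined]
    print(math.tau_half)          # type: ignore[attr-defined]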
@@ -141,5 +161,5 @@ def _load_from_bytes(b):
     return torch.load(io.BytesIO(b))
 
 
-_StorageBase.type = _type
-_StorageBase.cuda = _cuda
+_StorageBase.type = _type  # type: ignore[assignment]
+_StorageBase.cuda = _cuda  # type: ignore[assignment]
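Note: _StorageBase declares one-line typed stubs for type and cuda, and the real implementations from torch/_utils.py are patched in after the class body. mypy rejects any assignment to a method ("Cannot assign to a method"), so the commit acknowledges it with a scoped ignore; at the time this fell under the assignment code, while newer mypy versions report it as method-assign. A toy sketch of the pattern (assumed names, not PyTorch code):

    class Base:
        def double(self) -> int: ...  # noqa: E704  (typed stub, like _StorageBase.type)

    def _double_impl(self):  # implementation lives outside the class, like _utils._type
        return 2

    # mypy reports "Cannot assign to a method" on the next line
    Base.double = _double_impl  # type: ignore[assignment]
    print(Base().double())      # prints 2 at runtime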