typing: storage (#130669)
This isn't a full typing of the file - it just fixes some uses of an unbound 'T' (a TypeVar used in an output position must also appear in an input position, or the checker has nothing to bind it to).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130669
Approved by: https://github.com/oulgen, https://github.com/Skylion007
This commit is contained in:
parent 8390843eba
commit ea25febfab
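The rule the message alludes to, sketched minimally (hypothetical classes, not from this diff): a TypeVar that appears only in a return annotation is unbound, because the checker has no argument from which to infer it. The diff applies the two standard fixes, a concrete Union for _StorageBase and Self for TypedStorage:

from typing import TypeVar, Union

from typing_extensions import Self

T = TypeVar("T")


class Storage:
    # Unbound use: T occurs only in the output position, so mypy reports
    # "A function returning TypeVar should receive at least one argument
    # containing the same TypeVar"  [type-var]
    def bad_new(self) -> T:  # type: ignore[type-var]
        raise NotImplementedError

    # Bound use: T appears in an input, so it is inferred at the call site.
    def copy_(self, source: T) -> T:
        return source

    # Fix 1: spell out the concrete alternatives the method may return.
    def new_union(self) -> Union["Storage", "FancyStorage"]:
        raise NotImplementedError

    # Fix 2: Self binds to the runtime class, so subclasses keep their type.
    def new_self(self) -> Self:
        return type(self)()


class FancyStorage(Storage):
    pass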
torch/storage.py
@@ -1,4 +1,7 @@
 # mypy: allow-untyped-defs
+
+from __future__ import annotations
+
 import collections
 import copy
 import functools
@@ -6,6 +9,7 @@ import io
 import threading
 import warnings
 from typing import Any, cast, Dict as _Dict, Optional as _Optional, Type, TypeVar, Union
+from typing_extensions import Self

 import torch
 from torch._utils import _to, _type
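A side note on the `from __future__ import annotations` line above: the new return annotations on _StorageBase name TypedStorage, which is defined later in the module. Under PEP 563 deferred evaluation, annotations are stored as strings and never evaluated at definition time, so the forward reference needs no quoting. A minimal sketch with hypothetical classes:

from __future__ import annotations  # PEP 563: annotations become lazy strings

from typing import Union


class Base:
    # Later is defined below this class; without the __future__ import,
    # this annotation would raise NameError when the method is defined.
    def convert(self) -> Union[Base, Later]:
        raise NotImplementedError


class Later:
    pass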
@@ -48,7 +52,7 @@ class _StorageBase:
     def copy_(self, source: T, non_blocking: _Optional[_bool] = None) -> T:
         raise NotImplementedError

-    def new(self) -> T:  # type: ignore[type-var]
+    def new(self) -> Union[_StorageBase, TypedStorage]:
         raise NotImplementedError

     def nbytes(self) -> _int:
@@ -57,10 +61,14 @@ class _StorageBase:
     def size(self) -> _int:
         return self.nbytes()

-    def type(self, dtype: _Optional[str] = None, non_blocking: _bool = False) -> T:  # type: ignore[type-var]
+    def type(
+        self, dtype: _Optional[str] = None, non_blocking: _bool = False
+    ) -> Union[_StorageBase, TypedStorage]:
         return _type(self, dtype, non_blocking)

-    def cuda(self, device=None, non_blocking=False) -> T:  # type: ignore[type-var, misc]  # noqa: E704
+    def cuda(
+        self, device=None, non_blocking=False
+    ) -> Union[_StorageBase, TypedStorage]:
         """Returns a copy of this object in CUDA memory.

         If this object is already in CUDA memory and on the correct device, then
@@ -75,7 +83,7 @@ class _StorageBase:
         device2 = torch.device("cuda", device) if device else torch.device("cuda")
         return self.to(device=device2, non_blocking=non_blocking)

-    def hpu(self, device=None, non_blocking=False) -> T:  # type: ignore[type-var, misc]  # noqa: E704
+    def hpu(self, device=None, non_blocking=False) -> Union[_StorageBase, TypedStorage]:
         """Returns a copy of this object in HPU memory.

         If this object is already in HPU memory and on the correct device, then
@@ -141,7 +149,7 @@ class _StorageBase:
     def _new_with_weak_ptr(cls: Type[T], *args, **kwargs) -> T:
         raise NotImplementedError

-    def _shared_decref(self) -> T:  # type: ignore[type-var]
+    def _shared_decref(self) -> Union[_StorageBase, TypedStorage]:
         raise NotImplementedError

     def _write_file(self, *args, **kwargs):
@@ -150,7 +158,7 @@ class _StorageBase:
     def resize_(self, size: _int):
         raise NotImplementedError

-    def _weak_ref(self, *args, **kwargs) -> T:  # type: ignore[type-var]
+    def _weak_ref(self, *args, **kwargs) -> Union[_StorageBase, TypedStorage]:
         raise NotImplementedError

     def _set_from_file(self, *args, **kwargs):
@@ -185,11 +193,11 @@ class _StorageBase:
         raise NotImplementedError

     @classmethod
-    def from_file(cls, filename, shared, nbytes) -> T:  # type: ignore[type-var]
+    def from_file(cls, filename, shared, nbytes) -> Union[_StorageBase, TypedStorage]:
         raise NotImplementedError

     @classmethod
-    def _expired(cls, *args, **kwargs) -> T:  # type: ignore[type-var]
+    def _expired(cls, *args, **kwargs) -> Union[_StorageBase, TypedStorage]:
         raise NotImplementedError

     def _byteswap(self, *args, **kwargs):
@@ -260,7 +268,9 @@ class _StorageBase:
             storage = storage.clone()
         return storage

-    def to(self, *, device: torch.device, non_blocking: _bool = False) -> T:  # type: ignore[type-var, misc]  # noqa: E704
+    def to(
+        self, *, device: torch.device, non_blocking: _bool = False
+    ) -> Union[_StorageBase, TypedStorage]:
         return _to(self, device, non_blocking)

     def double(self):
@@ -852,12 +862,15 @@ class TypedStorage:
         _warn_typed_storage_removal()
         return self._untyped_storage

-    def _new_wrapped_storage(self, untyped_storage):
+    def _new_wrapped_storage(self, untyped_storage) -> Self:
         assert type(untyped_storage) == torch.UntypedStorage

         if type(self) == TypedStorage:
-            return TypedStorage(
-                wrap_storage=untyped_storage, dtype=self.dtype, _internal=True
+            return cast(
+                Self,
+                TypedStorage(
+                    wrap_storage=untyped_storage, dtype=self.dtype, _internal=True
+                ),
             )
         else:
             return type(self)(wrap_storage=untyped_storage)
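The cast in this hunk covers the one branch the checker cannot prove on its own: constructing TypedStorage directly yields the static type TypedStorage, not Self, even though the type(self) == TypedStorage guard makes the two coincide at runtime. The same pattern in miniature (hypothetical class, not torch API):

from typing import cast

from typing_extensions import Self


class Wrapper:
    def __init__(self, payload: object = None) -> None:
        self.payload = payload

    def rewrap(self, payload: object) -> Self:
        if type(self) is Wrapper:
            # Wrapper(...) is typed as plain Wrapper; the guard above
            # guarantees it equals Self here, so the cast is sound.
            return cast(Self, Wrapper(payload))
        # type(self) has type type[Self], so calling it produces Self.
        return type(self)(payload)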
@@ -982,9 +995,9 @@ class TypedStorage:
     def copy_(self, source: T, non_blocking: _Optional[bool] = None):
         _warn_typed_storage_removal()
         if isinstance(source, TypedStorage):
-            self._untyped_storage.copy_(source._untyped_storage, non_blocking)  # type: ignore[arg-type]
+            self._untyped_storage.copy_(source._untyped_storage, non_blocking)
         else:
-            self._untyped_storage.copy_(source, non_blocking)  # type: ignore[arg-type]
+            self._untyped_storage.copy_(source, non_blocking)
         return self

     def nbytes(self):
@@ -999,7 +1012,7 @@ class TypedStorage:
         self,
         dtype: _Optional[str] = None,
         non_blocking: bool = False,
-    ) -> Union[T, str]:
+    ) -> Union[_StorageBase, TypedStorage, str]:
         _warn_typed_storage_removal()
         if dtype is None:
             legacy_class = self._get_legacy_storage_class()
@@ -1012,7 +1025,7 @@ class TypedStorage:
         else:
             return self._untyped_storage.type(dtype, non_blocking)

-    def cuda(self, device=None, non_blocking=False) -> T:  # type: ignore[misc,type-var]
+    def cuda(self, device=None, non_blocking=False) -> Self:
         _warn_typed_storage_removal()
         if self.dtype in [
             torch.quint8,
@@ -1022,12 +1035,10 @@ class TypedStorage:
             torch.qint8,
         ]:
             raise RuntimeError("Cannot create CUDA storage with quantized dtype")
-        cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(
-            device, non_blocking
-        )
+        cuda_storage = self._untyped_storage.cuda(device, non_blocking)
         return self._new_wrapped_storage(cuda_storage)

-    def hpu(self, device=None, non_blocking=False) -> T:  # type: ignore[misc,type-var]
+    def hpu(self, device=None, non_blocking=False) -> Self:
         _warn_typed_storage_removal()
         if self.dtype in [
             torch.quint8,
@@ -1037,12 +1048,10 @@ class TypedStorage:
             torch.qint8,
         ]:
             raise RuntimeError("Cannot create HPU storage with quantized dtype")
-        hpu_storage: torch.UntypedStorage = self._untyped_storage.hpu(
-            device, non_blocking
-        )
+        hpu_storage = self._untyped_storage.hpu(device, non_blocking)
         return self._new_wrapped_storage(hpu_storage)

-    def to(self, *, device: torch.device, non_blocking: bool = False) -> T:  # type: ignore[type-var, misc]
+    def to(self, *, device: torch.device, non_blocking: bool = False) -> Self:
         _warn_typed_storage_removal()
         if self.dtype in [
             torch.quint8,
@@ -1054,9 +1063,7 @@ class TypedStorage:
             raise RuntimeError(
                 f"Cannot create {device.type.upper()} storage with quantized dtype"
             )
-        to_storage: torch.UntypedStorage = self._untyped_storage.to(
-            device=device, non_blocking=non_blocking
-        )
+        to_storage = self._untyped_storage.to(device=device, non_blocking=non_blocking)
         return self._new_wrapped_storage(to_storage)

     def element_size(self):
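What the switch from unbound T to Self buys callers of cuda/hpu/to: a subclass of TypedStorage keeps its own static type across a device move, where the old annotation degraded to an unresolvable T. A hedged sketch (invented classes, not the real torch hierarchy):

from typing_extensions import Self, reveal_type


class StorageLike:
    def to_device(self, device: str) -> Self:
        # A real implementation would copy data to `device`; returning
        # self unchanged is enough to illustrate the typing.
        return self


class FloatStorageLike(StorageLike):
    pass


moved = FloatStorageLike().to_device("cuda:0")
reveal_type(moved)  # checker infers FloatStorageLike, not StorageLike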
@@ -1385,7 +1392,7 @@ class TypedStorage:
         _warn_typed_storage_removal()
         if cls == TypedStorage:
             raise RuntimeError("from_file can only be called on derived classes")
-        untyped_storage: UntypedStorage = UntypedStorage.from_file(
+        untyped_storage = UntypedStorage.from_file(
             filename, shared, size * torch._utils._element_size(cls.dtype)
         )
         storage = cls(wrap_storage=untyped_storage)