typing: storage (#130669)

This isn't a full typing of the file - it just fixes some uses of unbound 'T' (if you use a TypeVar as an output it also needs to be an input).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130669
Approved by: https://github.com/oulgen, https://github.com/Skylion007
This commit is contained in:
Aaron Orenstein 2024-07-15 09:58:09 -07:00 committed by PyTorch MergeBot
parent 8390843eba
commit ea25febfab

View File

@ -1,4 +1,7 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import collections
import copy
import functools
@ -6,6 +9,7 @@ import io
import threading
import warnings
from typing import Any, cast, Dict as _Dict, Optional as _Optional, Type, TypeVar, Union
from typing_extensions import Self
import torch
from torch._utils import _to, _type
@ -48,7 +52,7 @@ class _StorageBase:
def copy_(self, source: T, non_blocking: _Optional[_bool] = None) -> T:
raise NotImplementedError
def new(self) -> T: # type: ignore[type-var]
def new(self) -> Union[_StorageBase, TypedStorage]:
raise NotImplementedError
def nbytes(self) -> _int:
@ -57,10 +61,14 @@ class _StorageBase:
def size(self) -> _int:
return self.nbytes()
def type(self, dtype: _Optional[str] = None, non_blocking: _bool = False) -> T: # type: ignore[type-var]
def type(
self, dtype: _Optional[str] = None, non_blocking: _bool = False
) -> Union[_StorageBase, TypedStorage]:
return _type(self, dtype, non_blocking)
def cuda(self, device=None, non_blocking=False) -> T: # type: ignore[type-var, misc] # noqa: E704
def cuda(
self, device=None, non_blocking=False
) -> Union[_StorageBase, TypedStorage]:
"""Returns a copy of this object in CUDA memory.
If this object is already in CUDA memory and on the correct device, then
@ -75,7 +83,7 @@ class _StorageBase:
device2 = torch.device("cuda", device) if device else torch.device("cuda")
return self.to(device=device2, non_blocking=non_blocking)
def hpu(self, device=None, non_blocking=False) -> T: # type: ignore[type-var, misc] # noqa: E704
def hpu(self, device=None, non_blocking=False) -> Union[_StorageBase, TypedStorage]:
"""Returns a copy of this object in HPU memory.
If this object is already in HPU memory and on the correct device, then
@ -141,7 +149,7 @@ class _StorageBase:
def _new_with_weak_ptr(cls: Type[T], *args, **kwargs) -> T:
raise NotImplementedError
def _shared_decref(self) -> T: # type: ignore[type-var]
def _shared_decref(self) -> Union[_StorageBase, TypedStorage]:
raise NotImplementedError
def _write_file(self, *args, **kwargs):
@ -150,7 +158,7 @@ class _StorageBase:
def resize_(self, size: _int):
raise NotImplementedError
def _weak_ref(self, *args, **kwargs) -> T: # type: ignore[type-var]
def _weak_ref(self, *args, **kwargs) -> Union[_StorageBase, TypedStorage]:
raise NotImplementedError
def _set_from_file(self, *args, **kwargs):
@ -185,11 +193,11 @@ class _StorageBase:
raise NotImplementedError
@classmethod
def from_file(cls, filename, shared, nbytes) -> T: # type: ignore[type-var]
def from_file(cls, filename, shared, nbytes) -> Union[_StorageBase, TypedStorage]:
raise NotImplementedError
@classmethod
def _expired(cls, *args, **kwargs) -> T: # type: ignore[type-var]
def _expired(cls, *args, **kwargs) -> Union[_StorageBase, TypedStorage]:
raise NotImplementedError
def _byteswap(self, *args, **kwargs):
@ -260,7 +268,9 @@ class _StorageBase:
storage = storage.clone()
return storage
def to(self, *, device: torch.device, non_blocking: _bool = False) -> T: # type: ignore[type-var, misc] # noqa: E704
def to(
self, *, device: torch.device, non_blocking: _bool = False
) -> Union[_StorageBase, TypedStorage]:
return _to(self, device, non_blocking)
def double(self):
@ -852,12 +862,15 @@ class TypedStorage:
_warn_typed_storage_removal()
return self._untyped_storage
def _new_wrapped_storage(self, untyped_storage):
def _new_wrapped_storage(self, untyped_storage) -> Self:
assert type(untyped_storage) == torch.UntypedStorage
if type(self) == TypedStorage:
return TypedStorage(
wrap_storage=untyped_storage, dtype=self.dtype, _internal=True
return cast(
Self,
TypedStorage(
wrap_storage=untyped_storage, dtype=self.dtype, _internal=True
),
)
else:
return type(self)(wrap_storage=untyped_storage)
@ -982,9 +995,9 @@ class TypedStorage:
def copy_(self, source: T, non_blocking: _Optional[bool] = None):
_warn_typed_storage_removal()
if isinstance(source, TypedStorage):
self._untyped_storage.copy_(source._untyped_storage, non_blocking) # type: ignore[arg-type]
self._untyped_storage.copy_(source._untyped_storage, non_blocking)
else:
self._untyped_storage.copy_(source, non_blocking) # type: ignore[arg-type]
self._untyped_storage.copy_(source, non_blocking)
return self
def nbytes(self):
@ -999,7 +1012,7 @@ class TypedStorage:
self,
dtype: _Optional[str] = None,
non_blocking: bool = False,
) -> Union[T, str]:
) -> Union[_StorageBase, TypedStorage, str]:
_warn_typed_storage_removal()
if dtype is None:
legacy_class = self._get_legacy_storage_class()
@ -1012,7 +1025,7 @@ class TypedStorage:
else:
return self._untyped_storage.type(dtype, non_blocking)
def cuda(self, device=None, non_blocking=False) -> T: # type: ignore[misc,type-var]
def cuda(self, device=None, non_blocking=False) -> Self:
_warn_typed_storage_removal()
if self.dtype in [
torch.quint8,
@ -1022,12 +1035,10 @@ class TypedStorage:
torch.qint8,
]:
raise RuntimeError("Cannot create CUDA storage with quantized dtype")
cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(
device, non_blocking
)
cuda_storage = self._untyped_storage.cuda(device, non_blocking)
return self._new_wrapped_storage(cuda_storage)
def hpu(self, device=None, non_blocking=False) -> T: # type: ignore[misc,type-var]
def hpu(self, device=None, non_blocking=False) -> Self:
_warn_typed_storage_removal()
if self.dtype in [
torch.quint8,
@ -1037,12 +1048,10 @@ class TypedStorage:
torch.qint8,
]:
raise RuntimeError("Cannot create HPU storage with quantized dtype")
hpu_storage: torch.UntypedStorage = self._untyped_storage.hpu(
device, non_blocking
)
hpu_storage = self._untyped_storage.hpu(device, non_blocking)
return self._new_wrapped_storage(hpu_storage)
def to(self, *, device: torch.device, non_blocking: bool = False) -> T: # type: ignore[type-var, misc]
def to(self, *, device: torch.device, non_blocking: bool = False) -> Self:
_warn_typed_storage_removal()
if self.dtype in [
torch.quint8,
@ -1054,9 +1063,7 @@ class TypedStorage:
raise RuntimeError(
f"Cannot create {device.type.upper()} storage with quantized dtype"
)
to_storage: torch.UntypedStorage = self._untyped_storage.to(
device=device, non_blocking=non_blocking
)
to_storage = self._untyped_storage.to(device=device, non_blocking=non_blocking)
return self._new_wrapped_storage(to_storage)
def element_size(self):
@ -1385,7 +1392,7 @@ class TypedStorage:
_warn_typed_storage_removal()
if cls == TypedStorage:
raise RuntimeError("from_file can only be called on derived classes")
untyped_storage: UntypedStorage = UntypedStorage.from_file(
untyped_storage = UntypedStorage.from_file(
filename, shared, size * torch._utils._element_size(cls.dtype)
)
storage = cls(wrap_storage=untyped_storage)