import io import torch from ._utils import _type, _cuda from torch.types import Storage from typing import Any, TypeVar, Type, Union, cast import copy import collections from functools import lru_cache T = TypeVar('T', bound='Union[_StorageBase, _TypedStorage]') class _StorageBase(object): _cdata: Any is_cuda: bool = False is_sparse: bool = False is_sparse_csr: bool = False device: torch.device def __init__(self, *args, **kwargs): ... # noqa: E704 def __len__(self) -> int: ... # noqa: E704 def __getitem__(self, idx): ... # noqa: E704 def copy_(self, source: T) -> T: ... # noqa: E704 def nbytes(self) -> int: ... # noqa: E704 def size(self) -> int: return self.nbytes() def type(self, dtype: str = None, non_blocking: bool = False) -> T: ... # noqa: E704 def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # noqa: E704 def element_size(self) -> int: ... # noqa: E704 def get_device(self) -> int: ... # noqa: E704 def data_ptr(self) -> int: ... # noqa: E704 # Defined in torch/csrc/generic/StorageSharing.cpp def _share_filename_(self): ... # noqa: E704 def _share_fd_(self): ... # noqa: E704 @classmethod def _new_using_filename(cls: Type[T], size: int) -> T: ... # noqa: E704 @classmethod def _new_using_fd(cls: Type[T], size: int) -> T: ... # noqa: E704 def __str__(self): content = ' ' + '\n '.join(str(self[i]) for i in range(len(self))) return content + f'\n[{torch.typename(self)} of size {len(self)}]' def __repr__(self): return str(self) def __iter__(self): return iter(map(lambda i: self[i], range(self.size()))) def __copy__(self): return self.clone() def __deepcopy__(self, memo): memo = memo.setdefault('torch', {}) if self._cdata in memo: return memo[self._cdata] new_storage = self.clone() memo[self._cdata] = new_storage return new_storage def __reduce__(self): b = io.BytesIO() torch.save(self, b, _use_new_zipfile_serialization=False) return (_load_from_bytes, (b.getvalue(),)) def __sizeof__(self): return super(_StorageBase, self).__sizeof__() + self.size() def clone(self): """Returns a copy of this storage""" device = self.get_device() if self.is_cuda else -1 with torch.cuda.device(device): return type(self)(self.nbytes()).copy_(self) def tolist(self): """Returns a list containing the elements of this storage""" return list(self) def cpu(self): """Returns a CPU copy of this storage if it's not already on the CPU""" return _type(self, getattr(torch, self.__class__.__name__)) def _to(self, dtype): storage = torch.tensor([], dtype=torch.uint8, device=self.device).set_(cast(Storage, self)).to(dtype).storage() if storage.data_ptr() == self.data_ptr(): storage = storage.clone() return storage def double(self): """Casts this storage to double type""" return self._to(torch.double) def float(self): """Casts this storage to float type""" return self._to(torch.float) def half(self): """Casts this storage to half type""" return self._to(torch.half) def long(self): """Casts this storage to long type""" return self._to(torch.long) def int(self): """Casts this storage to int type""" return self._to(torch.int) def short(self): """Casts this storage to short type""" return self._to(torch.short) def char(self): """Casts this storage to char type""" return self._to(torch.int8) def byte(self): """Casts this storage to byte type""" return self._to(torch.uint8) def bool(self): """Casts this storage to bool type""" return self._to(torch.bool) def bfloat16(self): """Casts this storage to bfloat16 type""" return self._to(torch.bfloat16) def complex_double(self): """Casts this storage to complex double type""" return self._to(torch.cdouble) def complex_float(self): """Casts this storage to complex float type""" return self._to(torch.cfloat) def pin_memory(self): """Copies the storage to pinned memory, if it's not already pinned.""" if self.is_cuda: raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned") import torch.cuda allocator = torch.cuda._host_allocator() # type: ignore[attr-defined] return type(self)(self.size(), allocator=allocator).copy_(self) def share_memory_(self): """Moves the storage to shared memory. This is a no-op for storages already in shared memory and for CUDA storages, which do not need to be moved for sharing across processes. Storages in shared memory cannot be resized. Returns: self """ from torch.multiprocessing import get_sharing_strategy if self.is_cuda: pass # CUDA doesn't use POSIX shared memory elif get_sharing_strategy() == 'file_system': self._share_filename_() else: self._share_fd_() return self @classmethod def _new_shared(cls, size): """Creates a new storage in shared memory with the same data type""" from torch.multiprocessing import get_sharing_strategy if cls.is_cuda: return cls(size) elif get_sharing_strategy() == 'file_system': return cls._new_using_filename(size) else: return cls._new_using_fd(size) def _untyped(self): return self def _load_from_bytes(b): return torch.load(io.BytesIO(b)) _StorageBase.type = _type # type: ignore[assignment] _StorageBase.cuda = _cuda # type: ignore[assignment] @lru_cache(maxsize=None) def _dtype_to_storage_type_map(): return { torch.double: 'DoubleStorage', torch.float: 'FloatStorage', torch.half: 'HalfStorage', torch.long: 'LongStorage', torch.int: 'IntStorage', torch.int16: 'ShortStorage', torch.int8: 'CharStorage', torch.uint8: 'ByteStorage', torch.bool: 'BoolStorage', torch.bfloat16: 'BFloat16Storage', torch.cdouble: 'ComplexDoubleStorage', torch.cfloat: 'ComplexFloatStorage', torch.qint8: 'QInt8Storage', torch.qint32: 'QInt32Storage', torch.quint8: 'QUInt8Storage', torch.quint4x2: 'QUInt4x2Storage', torch.quint2x4: 'QUInt2x4Storage', } @lru_cache(maxsize=None) def _storage_type_to_dtype_map(): dtype_map = { val: key for key, val in _dtype_to_storage_type_map().items()} return dtype_map class _TypedStorage: is_sparse = False def fill_(self, value): self[0:len(self)] = value return self def __init__(self, *args, **kwargs): arg_error_msg = ( f'{type(self)} constructor received an invalid combination ' f'of arguments - got args={tuple(type(arg) for arg in args)}, ' f'kwargs={ {key: type(val) for key, val in kwargs.items()} }, but ' 'expected one of:\n' ' * no arguments\n' ' * (int size)\n' ' * (Sequence data)\n') if type(self) == _TypedStorage: arg_error_msg += ' * (wrap_storage=<_UntypedStorage>, dtype=)' else: arg_error_msg += ' * (wrap_storage=<_UntypedStorage>)' if 'wrap_storage' in kwargs: assert len(args) == 0, ( "No positional arguments should be given when using " "'wrap_storage'") if type(self) == _TypedStorage: assert 'dtype' in kwargs, ( "When using 'wrap_storage', 'dtype' also must be specified") assert len(kwargs) == 2, ( "Only 'wrap_storage' and 'dtype' should be given, but got: " f"{kwargs}") dtype = kwargs['dtype'] assert isinstance(dtype, torch.dtype) self.dtype = dtype else: assert hasattr(self, 'dtype') assert len(kwargs) == 1, ( f"Only 'wrap_storage' should be given, but got: {kwargs.keys()}") dtype = self.dtype storage = kwargs['wrap_storage'] if not isinstance(storage, (torch._UntypedStorage, torch.cuda._UntypedStorage)): raise TypeError(arg_error_msg) if type(self) != _TypedStorage and storage.__module__ != self.__module__: raise TypeError(( arg_error_msg + f'\n`storage` `module {storage.__module__}` does not match ' f'module of {type(self)}')) self._storage = storage else: assert type(self) != _TypedStorage, ( "Calling __init__ this way is only supported in _TypedStorage's " "child classes. _TypedStorage can only be directly instantiated " "when kwargs 'wrap_storage' and 'dtype' are given.") assert len(kwargs) == 0, "invalid keyword arguments" def isint(x): try: int(x) except TypeError: return False return True if len(args) == 0: self._storage = eval(self.__module__)._UntypedStorage() elif len(args) == 1 and isint(args[0]): self._storage = eval(self.__module__)._UntypedStorage(int(args[0]) * self.element_size()) elif len(args) == 1 and isinstance(args[0], collections.abc.Sequence): if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]: interpret_dtypes = { torch.quint8: torch.uint8, torch.quint4x2: torch.uint8, torch.quint2x4: torch.uint8, torch.qint32: torch.int32, torch.qint8: torch.int8 } tmp_tensor = torch.tensor( args[0], dtype=interpret_dtypes[self.dtype], device='cuda' if eval(self.__module__) is torch.cuda else 'cpu') else: tmp_tensor = torch.tensor( args[0], dtype=self.dtype, device='cuda' if eval(self.__module__) is torch.cuda else 'cpu') self._storage = tmp_tensor.storage()._untyped() else: raise TypeError(arg_error_msg) @property def is_cuda(self): return self._storage.device.type == 'cuda' def _untyped(self): return self._storage def _new_wrapped_storage(self, untyped_storage): module = eval(untyped_storage.__module__) assert type(untyped_storage) == module._UntypedStorage if type(self) == _TypedStorage: return _TypedStorage(wrap_storage=untyped_storage, dtype=self.dtype) else: # NOTE: We need to use the module of untyped_storage in case self's # module is different, e.g. if self is on CPU and untyped_storage # is on CUDA, and vice versa return getattr(module, type(self).__name__)(wrap_storage=untyped_storage) def __len__(self): return self._storage.nbytes() // self.element_size() def _maybe_wrap_index(self, idx, is_stop=False): if idx is None: if is_stop: return self.size() else: return 0 else: if type(idx) != int: raise TypeError( f"can't index a {type(self)} with {type(idx)}") if is_stop: if (idx > self.size()) or (idx < -self.size()): raise IndexError( f'index {idx} out of range for storage of size {self.size()}') if idx > 0: return idx else: return idx % self.size() else: if (idx >= self.size()) or (idx < -self.size()): raise IndexError( f'index {idx} out of range for storage of size {self.size()}') return idx % self.size() def __setitem__(self, idx, value): if not isinstance(idx, (int, slice)): raise RuntimeError(f"can't index a {type(self)} with {type(idx)}") if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]: interpret_dtypes = { torch.quint8: torch.uint8, torch.quint4x2: torch.uint8, torch.quint2x4: torch.uint8, torch.qint32: torch.int32, torch.qint8: torch.int8 } tmp_dtype = interpret_dtypes[self.dtype] tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self.device).set_(_TypedStorage( wrap_storage=self._storage, dtype=tmp_dtype)) else: tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self) tmp_tensor[idx] = value def __getitem__(self, idx): # NOTE: Before _TypedStorage existed, indexing with a slice used to be # possible for Storage objects. However, it would return # a storage view, which would be a hassle to implement in _TypedStorage, # so it was disabled if isinstance(idx, slice): raise RuntimeError('slices are only supported in _UntypedStorage.__getitem__') elif not isinstance(idx, int): raise RuntimeError(f"can't index a {type(self)} with {type(idx)}") if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]: interpret_dtypes = { torch.quint8: torch.uint8, torch.quint4x2: torch.uint8, torch.quint2x4: torch.uint8, torch.qint32: torch.int32, torch.qint8: torch.int8 } return _TypedStorage( wrap_storage=self._storage, dtype=interpret_dtypes[self.dtype])[idx] idx_wrapped = self._maybe_wrap_index(idx) tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self) return tmp_tensor[idx_wrapped].item() def copy_(self, source: T, non_blocking=None): self._storage.copy_(source._untyped(), non_blocking) return self def nbytes(self): return self._storage.nbytes() def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]: if dtype is None: return '.'.join([self.__module__, type(self).__name__]) else: return self._storage.type(dtype, non_blocking) def cuda(self, device=None, non_blocking=False, **kwargs) -> T: cuda_storage = self._storage.cuda(device, non_blocking, **kwargs) return self._new_wrapped_storage(cuda_storage) def element_size(self): return torch._utils._element_size(self.dtype) def get_device(self) -> int: return self._storage.get_device() def __str__(self): data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size())) if type(self) == _TypedStorage: return data_str + ( f'\n[{torch.typename(self)} with dtype {self.dtype} ' f'of size {len(self)}]') else: return data_str + f'\n[{torch.typename(self)} of size {len(self)}]' def __repr__(self): return str(self) def __iter__(self): return iter(map(lambda i: self[i], range(self.size()))) def __copy__(self): return self._new_wrapped_storage(copy.copy(self._storage)) def __deepcopy__(self, memo): return self._new_wrapped_storage(copy.deepcopy(self._storage, memo)) def __sizeof__(self): return super(_TypedStorage, self).__sizeof__() + self.nbytes() def clone(self): """Returns a copy of this storage""" return self._new_wrapped_storage(self._storage.clone()) def tolist(self): """Returns a list containing the elements of this storage""" return list(self) def cpu(self): """Returns a CPU copy of this storage if it's not already on the CPU""" return self._new_wrapped_storage(self._storage.cpu()) def pin_memory(self): """Coppies the storage to pinned memory, if it's not already pinned.""" return self._new_wrapped_storage(self._storage.pin_memory()) def share_memory_(self): """Moves the storage to shared memory. This is a no-op for storages already in shared memory and for CUDA storages, which do not need to be moved for sharing across processes. Storages in shared memory cannot be resized. Returns: self """ self._storage.share_memory_() return self @classmethod def _new_shared(cls, size): """Creates a new storage in shared memory with the same data type""" module = eval(cls.__module__) untyped_storage = module._UntypedStorage._new_shared(size * cls().element_size()) return cls(wrap_storage=untyped_storage) @property def _cdata(self): return self._storage._cdata @property def device(self): return self._storage.device def size(self): return len(self) def pickle_storage_type(self): try: return _dtype_to_storage_type_map()[self.dtype] except KeyError: raise KeyError(f'dtype {self.dtype} is not recognized') def __reduce__(self): b = io.BytesIO() torch.save(self, b, _use_new_zipfile_serialization=False) return (_load_from_bytes, (b.getvalue(),)) def data_ptr(self): return self._storage.data_ptr() def resize_(self, size): self._storage.resize_(size * self.element_size()) @classmethod def _free_weak_ref(cls, *args, **kwargs): return eval(cls.__module__)._UntypedStorage._free_weak_ref(*args, **kwargs) def _weak_ref(self, *args, **kwargs): return self._storage._weak_ref(*args, **kwargs) @classmethod def from_buffer(cls, *args, **kwargs): if cls == _TypedStorage: raise RuntimeError( 'from_buffer: only supported for subclasses of _TypedStorage') if 'dtype' in kwargs or len(args) == 5: raise RuntimeError(( "from_buffer: 'dtype' can only be specified in " "_UntypedStorage.from_buffer")) kwargs['dtype'] = cls().dtype untyped_storage = eval(cls.__module__)._UntypedStorage.from_buffer(*args, **kwargs) return cls(wrap_storage=untyped_storage) def _to(self, dtype): storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype).storage() if storage.data_ptr() == self.data_ptr(): storage = storage.clone() return storage def double(self): """Casts this storage to double type""" return self._to(torch.double) def float(self): """Casts this storage to float type""" return self._to(torch.float) def half(self): """Casts this storage to half type""" return self._to(torch.half) def long(self): """Casts this storage to long type""" return self._to(torch.long) def int(self): """Casts this storage to int type""" return self._to(torch.int) def short(self): """Casts this storage to short type""" return self._to(torch.short) def char(self): """Casts this storage to char type""" return self._to(torch.int8) def byte(self): """Casts this storage to byte type""" return self._to(torch.uint8) def bool(self): """Casts this storage to bool type""" return self._to(torch.bool) def bfloat16(self): """Casts this storage to bfloat16 type""" return self._to(torch.bfloat16) def complex_double(self): """Casts this storage to complex double type""" return self._to(torch.cdouble) def complex_float(self): """Casts this storage to complex float type""" return self._to(torch.cfloat) @classmethod def from_file(cls, filename, shared, size): if cls == _TypedStorage: raise RuntimeError('from_file can only be called on derived classes') untyped_storage = eval(cls.__module__)._UntypedStorage.from_file( filename, shared, size * torch._utils._element_size(cls.dtype)) storage = cls(wrap_storage=untyped_storage) return storage @classmethod def _expired(cls, *args, **kwargs): return eval(cls.__module__)._UntypedStorage._expired(*args, **kwargs) def is_pinned(self): return self._storage.is_pinned() def _write_file(self, *args, **kwargs): return self._storage._write_file(*args, **kwargs) def _set_from_file(self, *args, **kwargs): return self._storage._set_from_file(*args, **kwargs) def _set_cdata(self, *args, **kwargs): return self._storage._set_cdata(*args, **kwargs) def _share_cuda_(self, *args, **kwargs): return self._storage._share_cuda_(*args, **kwargs) def is_shared(self): return self._storage.is_shared() @classmethod def _new_shared_cuda(cls, *args, **kwargs): return eval(cls.__module__)._UntypedStorage._new_shared_cuda(*args, **kwargs) @classmethod def _new_with_weak_ptr(cls, *args, **kwargs): return eval(cls.__module__)._UntypedStorage._new_with_weak_ptr(*args, **kwargs) def _share_filename_(self, *args, **kwargs): manager_handle, storage_handle, size = self._storage._share_filename_(*args, **kwargs) return manager_handle, storage_handle, size // self.element_size() @classmethod def _new_shared_filename(cls, manager, obj, size): bytes_size = size * torch._utils._element_size(cls.dtype) return cls(wrap_storage=eval(cls.__module__)._UntypedStorage._new_shared_filename(manager, obj, bytes_size)) def _shared_decref(self): self._storage._shared_decref() return self @classmethod def _release_ipc_counter(cls, *args, **kwargs): return eval(cls.__module__)._UntypedStorage._release_ipc_counter(*args, **kwargs) def _shared_incref(self, *args, **kwargs): return self._storage._shared_incref(*args, **kwargs) def _share_fd_(self, *args, **kwargs): fd, size = self._storage._share_fd_(*args, **kwargs) return fd, size // self.element_size() def _get_dtype_from_pickle_storage_type(pickle_storage_type: str): try: return _storage_type_to_dtype_map()[pickle_storage_type] except KeyError: raise KeyError( f'pickle storage type "{pickle_storage_type}" is not recognized')