Rename Typed/UntypedStorage to _Typed/_UntypedStorage (#72540)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72540

Reviewed By: jbschlosser

Differential Revision: D34216823

Pulled By: bdhirsh

fbshipit-source-id: 1bc9930ab582771ebf02308e035576cd1a0dbe47
This commit is contained in:
Kurt Mohler 2022-02-15 15:43:57 -08:00 committed by Facebook GitHub Bot
parent bf90af704f
commit 329238f612
21 changed files with 138 additions and 139 deletions

View File

@ -153,7 +153,6 @@ coverage_ignore_classes = [
"LongTensor",
"ShortStorage",
"ShortTensor",
"UntypedStorage",
"cudaStatus",
# torch.distributed.elastic.multiprocessing.errors
"ChildFailedError",

View File

@ -568,8 +568,8 @@ class TestCuda(TestCase):
self.assertTrue(isinstance(q_copy[0], torch.cuda.FloatTensor))
self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor))
self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
self.assertTrue(isinstance(q_copy[3]._storage, torch.cuda.UntypedStorage))
self.assertTrue(isinstance(q_copy[3], torch.storage._TypedStorage))
self.assertTrue(isinstance(q_copy[3]._storage, torch.cuda._UntypedStorage))
q_copy[1].fill_(10)
self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

View File

@ -97,7 +97,7 @@ class SerializationMixin(object):
self.assertTrue(isinstance(c[1], torch.FloatTensor))
self.assertTrue(isinstance(c[2], torch.FloatTensor))
self.assertTrue(isinstance(c[3], torch.FloatTensor))
self.assertTrue(isinstance(c[4], torch.storage.TypedStorage))
self.assertTrue(isinstance(c[4], torch.storage._TypedStorage))
self.assertEqual(c[4].dtype, torch.float)
c[0].fill_(10)
self.assertEqual(c[0], c[2], atol=0, rtol=0)
@ -370,7 +370,7 @@ class SerializationMixin(object):
self.assertTrue(isinstance(c[1], torch.FloatTensor))
self.assertTrue(isinstance(c[2], torch.FloatTensor))
self.assertTrue(isinstance(c[3], torch.FloatTensor))
self.assertTrue(isinstance(c[4], torch.storage.TypedStorage))
self.assertTrue(isinstance(c[4], torch.storage._TypedStorage))
self.assertEqual(c[4].dtype, torch.float32)
c[0].fill_(10)
self.assertEqual(c[0], c[2], atol=0, rtol=0)
@ -620,7 +620,7 @@ class SerializationMixin(object):
a = torch.tensor([], dtype=dtype, device=device)
for other_dtype in get_all_dtypes():
s = torch.TypedStorage(
s = torch._TypedStorage(
wrap_storage=a.storage()._untyped(),
dtype=other_dtype)
save_load_check(a, s)
@ -652,7 +652,7 @@ class SerializationMixin(object):
torch.save([a.storage(), a.imag.storage()], f)
a = torch.randn(10, device=device)
s_bytes = torch.TypedStorage(
s_bytes = torch._TypedStorage(
wrap_storage=a.storage()._untyped(),
dtype=torch.uint8)

View File

@ -1114,7 +1114,7 @@ static PyObject* THPVariable_set_(
at::Storage storage = _r.storage(0, storage_scalar_type, is_typed_storage);
TORCH_CHECK(storage_scalar_type == self.dtype() || !is_typed_storage,
"Expected a Storage of type ", self.dtype(),
" or an UntypedStorage, but got type ", storage_scalar_type,
" or an _UntypedStorage, but got type ", storage_scalar_type,
" for argument 1 'storage'");
auto dispatch_set_ = [](const Tensor& self, Storage source) -> Tensor {
pybind11::gil_scoped_release no_gil;
@ -1130,7 +1130,7 @@ static PyObject* THPVariable_set_(
at::Storage storage = _r.storage(0, storage_scalar_type, is_typed_storage);
TORCH_CHECK(storage_scalar_type == self.dtype() || !is_typed_storage,
"Expected a Storage of type ", self.dtype(),
" or an UntypedStorage, but got type ", storage_scalar_type,
" or an _UntypedStorage, but got type ", storage_scalar_type,
" for argument 1 'storage'");
auto dispatch_set_ = [](const Tensor& self,
Storage source,

View File

@ -482,8 +482,8 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
],
'item': ["def item(self) -> Number: ..."],
'copy_': ["def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."],
'set_': ['def set_(self, storage: Union[Storage, TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...',
'def set_(self, storage: Union[Storage, TypedStorage]) -> Tensor: ...'],
'set_': ['def set_(self, storage: Union[Storage, _TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...',
'def set_(self, storage: Union[Storage, _TypedStorage]) -> Tensor: ...'],
'split': ['def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...',
'def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...'],
'div': ['def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ...'],

View File

@ -13,7 +13,7 @@ from typing_extensions import Literal
from torch._six import inf
from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage
from torch.storage import TypedStorage
from torch.storage import _TypedStorage
import builtins

View File

@ -594,101 +594,101 @@ __all__.extend(['e', 'pi', 'nan', 'inf'])
################################################################################
from ._tensor import Tensor
from .storage import _StorageBase, TypedStorage
from .storage import _StorageBase, _TypedStorage
# NOTE: New <type>Storage classes should never be added. When adding a new
# dtype, use torch.storage.TypedStorage directly.
# dtype, use torch.storage._TypedStorage directly.
class UntypedStorage(_C.ByteStorageBase, _StorageBase):
class _UntypedStorage(_C.ByteStorageBase, _StorageBase):
pass
class ByteStorage(TypedStorage):
class ByteStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.uint8
class DoubleStorage(TypedStorage):
class DoubleStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.double
class FloatStorage(TypedStorage):
class FloatStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.float
class HalfStorage(TypedStorage):
class HalfStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.half
class LongStorage(TypedStorage):
class LongStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.long
class IntStorage(TypedStorage):
class IntStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.int
class ShortStorage(TypedStorage):
class ShortStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.short
class CharStorage(TypedStorage):
class CharStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.int8
class BoolStorage(TypedStorage):
class BoolStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.bool
class BFloat16Storage(TypedStorage):
class BFloat16Storage(_TypedStorage):
@classproperty
def dtype(self):
return torch.bfloat16
class ComplexDoubleStorage(TypedStorage):
class ComplexDoubleStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.cdouble
class ComplexFloatStorage(TypedStorage):
class ComplexFloatStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.cfloat
class QUInt8Storage(TypedStorage):
class QUInt8Storage(_TypedStorage):
@classproperty
def dtype(self):
return torch.quint8
class QInt8Storage(TypedStorage):
class QInt8Storage(_TypedStorage):
@classproperty
def dtype(self):
return torch.qint8
class QInt32Storage(TypedStorage):
class QInt32Storage(_TypedStorage):
@classproperty
def dtype(self):
return torch.qint32
class QUInt4x2Storage(TypedStorage):
class QUInt4x2Storage(_TypedStorage):
@classproperty
def dtype(self):
return torch.quint4x2
class QUInt2x4Storage(TypedStorage):
class QUInt2x4Storage(_TypedStorage):
@classproperty
def dtype(self):
return torch.quint2x4
_storage_classes = {
UntypedStorage, DoubleStorage, FloatStorage, LongStorage, IntStorage,
_UntypedStorage, DoubleStorage, FloatStorage, LongStorage, IntStorage,
ShortStorage, CharStorage, ByteStorage, HalfStorage, BoolStorage,
QUInt8Storage, QInt8Storage, QInt32Storage, BFloat16Storage,
ComplexFloatStorage, ComplexDoubleStorage, QUInt4x2Storage, QUInt2x4Storage,

View File

@ -17,8 +17,8 @@ def _save_storages(importer, obj):
importers = sys_importer
def persistent_id(obj):
if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
if isinstance(obj, torch.storage.TypedStorage):
if torch.is_storage(obj) or isinstance(obj, torch.storage._TypedStorage):
if isinstance(obj, torch.storage._TypedStorage):
# TODO: Once we decide to break serialization FC, we can
# remove this case
storage = obj._storage
@ -59,10 +59,10 @@ def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dt
if typename == 'storage':
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
# stop wrapping with _TypedStorage
storage = serialized_storages[data[0]]
dtype = serialized_dtypes[data[0]]
return torch.storage.TypedStorage(
return torch.storage._TypedStorage(
wrap_storage=storage._untyped(),
dtype=dtype)

View File

@ -109,9 +109,9 @@ class Tensor(torch._C._TensorBase):
else:
raise RuntimeError(f"Unsupported qscheme {self.qscheme()} in deepcopy")
# TODO: Once we decide to break serialization FC, no longer
# need to wrap with TypedStorage
# need to wrap with _TypedStorage
new_tensor = torch._utils._rebuild_qtensor(
torch.storage.TypedStorage(
torch.storage._TypedStorage(
wrap_storage=new_storage._untyped(),
dtype=self.dtype),
self.storage_offset(),
@ -232,9 +232,9 @@ class Tensor(torch._C._TensorBase):
else:
raise RuntimeError(f"Serialization is not supported for tensors of type {self.qscheme()}")
# TODO: Once we decide to break serialization FC, no longer
# need to wrap with TypedStorage
# need to wrap with _TypedStorage
args_qtensor = (
torch.storage.TypedStorage(
torch.storage._TypedStorage(
wrap_storage=self.storage()._untyped(),
dtype=self.dtype),
self.storage_offset(),
@ -267,9 +267,9 @@ class Tensor(torch._C._TensorBase):
return (torch._utils._rebuild_sparse_csr_tensor, args_sparse_csr)
else:
# TODO: Once we decide to break serialization FC, no longer
# need to wrap with TypedStorage
# need to wrap with _TypedStorage
args = (
torch.storage.TypedStorage(
torch.storage._TypedStorage(
wrap_storage=self.storage()._untyped(),
dtype=self.dtype),
self.storage_offset(),
@ -830,9 +830,9 @@ class Tensor(torch._C._TensorBase):
Returns the type of the underlying storage.
"""
# NB: this returns old fashioned TypedStorage, e.g., FloatStorage, as it
# NB: this returns old fashioned _TypedStorage, e.g., FloatStorage, as it
# would be pretty pointless otherwise (it would always return
# UntypedStorage)
# _UntypedStorage)
return type(self.storage())
def refine_names(self, *names):

View File

@ -128,7 +128,7 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
# be a _TypedStorage
def _rebuild_tensor(storage, storage_offset, size, stride):
# first construct a tensor with the correct dtype/device
t = torch.tensor([], dtype=storage.dtype, device=storage._untyped().device)
@ -210,7 +210,7 @@ def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
# be a _TypedStorage
def _rebuild_qtensor(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks):
qscheme = quantizer_params[0]
if qscheme == torch.per_tensor_affine:

View File

@ -66,7 +66,7 @@ PyTypeObject* getPyTypeObject(const at::Storage& storage) {
scalarType);
auto it = attype_to_py_storage_type.find(attype);
TORCH_INTERNAL_ASSERT(it != attype_to_py_storage_type.end(),
"Failed to get the Python type of `UntypedStorage`.");
"Failed to get the Python type of `_UntypedStorage`.");
return it->second;
}
} // namespace
@ -115,10 +115,10 @@ PyTypeObject* loadTypedStorageTypeObject() {
PyObject* storage_module = PyImport_ImportModule("torch.storage");
TORCH_INTERNAL_ASSERT(storage_module && PyModule_Check(storage_module));
PyObject* typed_storage_obj = PyObject_GetAttrString(storage_module, "TypedStorage");
PyObject* typed_storage_obj = PyObject_GetAttrString(storage_module, "_TypedStorage");
TORCH_INTERNAL_ASSERT(typed_storage_obj && PyType_Check(typed_storage_obj));
return reinterpret_cast<PyTypeObject*>(
PyObject_GetAttrString(storage_module, "TypedStorage"));
PyObject_GetAttrString(storage_module, "_TypedStorage"));
}
PyTypeObject* getTypedStorageTypeObject() {
@ -169,7 +169,7 @@ at::Storage createStorageGetType(PyObject* obj, at::ScalarType& scalar_type, boo
}
if (obj_type == storage_type) {
auto& type = *item.second;
// UntypedStorage should always be interpreted with byte dtype
// _UntypedStorage should always be interpreted with byte dtype
scalar_type = at::kByte;
return type.unsafeStorageFromTH(((THPVoidStorage*)obj)->cdata, true);
}

View File

@ -344,7 +344,7 @@ bool THPStorage_(init)(PyObject *module)
void THPStorage_(postInit)(PyObject *module)
{
THPStorageClass = PyObject_GetAttrString(module, "UntypedStorage");
THPStorageClass = PyObject_GetAttrString(module, "_UntypedStorage");
if (!THPStorageClass) throw python_error();
at::Backend backend = at::Backend::CPU;

View File

@ -297,7 +297,7 @@ Tensor internal_new_from_data(
Storage storage = createStorageGetType(data, storage_scalar_type, is_typed_storage);
TORCH_CHECK(!is_typed_storage || storage_scalar_type == scalar_type,
"Expected a Storage of type ", scalar_type,
" or an UntypedStorage, but got ", storage_scalar_type);
" or an _UntypedStorage, but got ", storage_scalar_type);
tensor = at::empty(sizes, at::initialTensorOptions().dtype(is_typed_storage ? storage_scalar_type : inferred_scalar_type).pinned_memory(pin_memory).device(storage.device()));
tensor.set_(storage);
@ -534,7 +534,7 @@ Tensor legacy_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_t
TORCH_CHECK(
storage_scalar_type == scalar_type,
"Expected a Storage of type ", scalar_type,
" or an UntypedStorage, but got type ", storage_scalar_type,
" or an _UntypedStorage, but got type ", storage_scalar_type,
" for argument 1 'storage'");
}
return new_with_storage(options, scalar_type, storage);
@ -596,7 +596,7 @@ Tensor legacy_tensor_new(c10::DispatchKey dispatch_key, at::ScalarType scalar_ty
TORCH_CHECK(
storage_scalar_type == scalar_type,
"Expected a Storage of type ", scalar_type,
" or an UntypedStorage, but got type ", storage_scalar_type,
" or an _UntypedStorage, but got type ", storage_scalar_type,
" for argument 1 'storage'");
}
return new_with_storage(options, scalar_type, storage);

View File

@ -674,72 +674,72 @@ class _CudaBase(object):
__new__ = _lazy_new
from torch.storage import TypedStorage
from torch.storage import _TypedStorage
class UntypedStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
class _UntypedStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
pass
class ByteStorage(TypedStorage):
class ByteStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.uint8
class DoubleStorage(TypedStorage):
class DoubleStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.double
class FloatStorage(TypedStorage):
class FloatStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.float
class HalfStorage(TypedStorage):
class HalfStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.half
class LongStorage(TypedStorage):
class LongStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.long
class IntStorage(TypedStorage):
class IntStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.int
class ShortStorage(TypedStorage):
class ShortStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.short
class CharStorage(TypedStorage):
class CharStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.int8
class BoolStorage(TypedStorage):
class BoolStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.bool
class BFloat16Storage(TypedStorage):
class BFloat16Storage(_TypedStorage):
@classproperty
def dtype(self):
return torch.bfloat16
class ComplexDoubleStorage(TypedStorage):
class ComplexDoubleStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.cdouble
class ComplexFloatStorage(TypedStorage):
class ComplexFloatStorage(_TypedStorage):
@classproperty
def dtype(self):
return torch.cfloat
torch._storage_classes.add(UntypedStorage)
torch._storage_classes.add(_UntypedStorage)
torch._storage_classes.add(DoubleStorage)
torch._storage_classes.add(FloatStorage)
torch._storage_classes.add(LongStorage)

View File

@ -123,7 +123,7 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset)
t = torch._utils._rebuild_tensor(
torch.storage.TypedStorage(wrap_storage=storage._untyped(), dtype=dtype),
torch.storage._TypedStorage(wrap_storage=storage._untyped(), dtype=dtype),
tensor_offset, tensor_size, tensor_stride)
if tensor_cls == torch.nn.parameter.Parameter:
@ -317,16 +317,16 @@ def rebuild_storage_empty(cls):
return cls()
def rebuild_typed_storage(storage, dtype):
return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype)
return torch.storage._TypedStorage(wrap_storage=storage, dtype=dtype)
# Use for torch.storage.TypedStorage
# Use for torch.storage._TypedStorage
def reduce_typed_storage(storage):
return (rebuild_typed_storage, (storage._storage, storage.dtype))
def rebuild_typed_storage_child(storage, storage_type):
return storage_type(wrap_storage=storage)
# Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage
# Use for child classes of torch.storage._TypedStorage, like torch.FloatStorage
def reduce_typed_storage_child(storage):
return (rebuild_typed_storage_child, (storage._storage, type(storage)))
@ -358,12 +358,12 @@ def init_reductions():
ForkingPickler.register(torch.cuda.Event, reduce_event)
for t in torch._storage_classes:
if t.__name__ == 'UntypedStorage':
if t.__name__ == '_UntypedStorage':
ForkingPickler.register(t, reduce_storage)
else:
ForkingPickler.register(t, reduce_typed_storage_child)
ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage)
ForkingPickler.register(torch.storage._TypedStorage, reduce_typed_storage)
for t in torch._tensor_classes:
ForkingPickler.register(t, reduce_tensor)

View File

@ -35,7 +35,7 @@ class DirectoryReader(object):
def get_storage_from_record(self, name, numel, dtype):
filename = f"{self.directory}/{name}"
nbytes = torch._utils._element_size(dtype) * numel
storage = cast(Storage, torch.UntypedStorage)
storage = cast(Storage, torch._UntypedStorage)
return _HasStorage(storage.from_file(filename=filename, nbytes=nbytes))
def has_record(self, path):

View File

@ -849,8 +849,8 @@ class PackageExporter:
)
def _persistent_id(self, obj):
if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
if isinstance(obj, torch.storage.TypedStorage):
if torch.is_storage(obj) or isinstance(obj, torch.storage._TypedStorage):
if isinstance(obj, torch.storage._TypedStorage):
# TODO: Once we decide to break serialization FC, we can
# remove this case
storage = obj._storage

View File

@ -217,8 +217,8 @@ class PackageImporter(Importer):
)
storage = loaded_storages[key]
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
return torch.storage.TypedStorage(
# stop wrapping with _TypedStorage
return torch.storage._TypedStorage(
wrap_storage=storage._untyped(), dtype=dtype
)
elif typename == "reduce_package":

View File

@ -162,7 +162,7 @@ register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)
def location_tag(storage: Union[Storage, torch.storage.TypedStorage]):
def location_tag(storage: Union[Storage, torch.storage._TypedStorage]):
for _, tagger, _ in _package_registry:
location = tagger(storage)
if location:
@ -413,8 +413,8 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
"for correctness upon loading.")
return ('module', obj, source_file, source)
if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
if isinstance(obj, torch.storage.TypedStorage):
if isinstance(obj, torch.storage._TypedStorage) or torch.is_storage(obj):
if isinstance(obj, torch.storage._TypedStorage):
# TODO: Once we decide to break serialization FC, this case
# can be deleted
storage = obj._storage
@ -463,8 +463,8 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
# effectively saving nbytes in this case. We'll be able to load it
# and the tensor back up with no problems in _this_ and future
# versions of pytorch, but in older versions, here's the problem:
# the storage will be loaded up as a UntypedStorage, and then the
# FloatTensor will loaded and the UntypedStorage will be assigned to
# the storage will be loaded up as a _UntypedStorage, and then the
# FloatTensor will loaded and the _UntypedStorage will be assigned to
# it. Since the storage dtype does not match the tensor dtype, this
# will cause an error. If we reverse the list, like `[tensor,
# storage]`, then we will save the `tensor.storage()` as a faked
@ -472,7 +472,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
# dtype-specific numel count that old versions expect. `tensor`
# will be able to load up properly in old versions, pointing to
# a FloatStorage. However, `storage` is still being translated to
# a UntypedStorage, and it will try to resolve to the same
# a _UntypedStorage, and it will try to resolve to the same
# FloatStorage that `tensor` contains. This will also cause an
# error. It doesn't seem like there's any way around this.
# Probably, we just cannot maintain FC for the legacy format if the
@ -539,9 +539,9 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
# see
# https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
# https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
if isinstance(obj, torch.storage._TypedStorage) or torch.is_storage(obj):
if isinstance(obj, torch.storage.TypedStorage):
if isinstance(obj, torch.storage._TypedStorage):
# TODO: Once we decide to break serialization FC, this case
# can be deleted
storage = obj._storage
@ -806,11 +806,11 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
args = pickle_module.load(f, **pickle_load_args)
key, location, storage_type = args
dtype = storage_type.dtype
obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
obj = cast(Storage, torch._UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
obj = restore_location(obj, location)
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
deserialized_objects[key] = torch.storage.TypedStorage(
# stop wrapping with _TypedStorage
deserialized_objects[key] = torch.storage._TypedStorage(
wrap_storage=obj,
dtype=dtype)
@ -820,8 +820,8 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
element_size = torch._utils._element_size(root.dtype)
offset_bytes = offset * element_size
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
deserialized_objects[target_cdata] = torch.storage.TypedStorage(
# stop wrapping with _TypedStorage
deserialized_objects[target_cdata] = torch.storage._TypedStorage(
wrap_storage=root._storage[offset_bytes:offset_bytes + numel * element_size],
dtype=root.dtype)
@ -868,11 +868,11 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
nbytes = numel * torch._utils._element_size(dtype)
if root_key not in deserialized_objects:
obj = cast(Storage, torch.UntypedStorage(nbytes))
obj = cast(Storage, torch._UntypedStorage(nbytes))
obj._torch_load_uninitialized = True
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
deserialized_objects[root_key] = torch.storage.TypedStorage(
# stop wrapping with _TypedStorage
deserialized_objects[root_key] = torch.storage._TypedStorage(
wrap_storage=restore_location(obj, location),
dtype=dtype)
@ -883,8 +883,8 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
view_size_bytes = view_size * torch._utils._element_size(dtype)
if view_key not in deserialized_objects:
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
deserialized_objects[view_key] = torch.storage.TypedStorage(
# stop wrapping with _TypedStorage
deserialized_objects[view_key] = torch.storage._TypedStorage(
wrap_storage=typed_storage._storage[offset_bytes:offset_bytes + view_size_bytes],
dtype=dtype)
res = deserialized_objects[view_key]
@ -994,10 +994,10 @@ def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickl
def load_tensor(dtype, numel, key, location):
name = f'data/{key}'
storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage).storage()._untyped()
storage = zip_file.get_storage_from_record(name, numel, torch._UntypedStorage).storage()._untyped()
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
loaded_storages[key] = torch.storage.TypedStorage(
# stop wrapping with _TypedStorage
loaded_storages[key] = torch.storage._TypedStorage(
wrap_storage=restore_location(storage, location),
dtype=dtype)

View File

@ -8,7 +8,7 @@ import copy
import collections
from functools import lru_cache
T = TypeVar('T', bound='Union[_StorageBase, TypedStorage]')
T = TypeVar('T', bound='Union[_StorageBase, _TypedStorage]')
class _StorageBase(object):
_cdata: Any
is_cuda: bool = False
@ -213,7 +213,7 @@ def _storage_type_to_dtype_map():
val: key for key, val in _dtype_to_storage_type_map().items()}
return dtype_map
class TypedStorage:
class _TypedStorage:
is_sparse = False
def fill_(self, value):
@ -229,17 +229,17 @@ class TypedStorage:
' * no arguments\n'
' * (int size)\n'
' * (Sequence data)\n')
if type(self) == TypedStorage:
arg_error_msg += ' * (wrap_storage=<UntypedStorage>, dtype=<torch.dtype>)'
if type(self) == _TypedStorage:
arg_error_msg += ' * (wrap_storage=<_UntypedStorage>, dtype=<torch.dtype>)'
else:
arg_error_msg += ' * (wrap_storage=<UntypedStorage>)'
arg_error_msg += ' * (wrap_storage=<_UntypedStorage>)'
if 'wrap_storage' in kwargs:
assert len(args) == 0, (
"No positional arguments should be given when using "
"'wrap_storage'")
if type(self) == TypedStorage:
if type(self) == _TypedStorage:
assert 'dtype' in kwargs, (
"When using 'wrap_storage', 'dtype' also must be specified")
assert len(kwargs) == 2, (
@ -257,9 +257,9 @@ class TypedStorage:
storage = kwargs['wrap_storage']
if not isinstance(storage, (torch.UntypedStorage, torch.cuda.UntypedStorage)):
if not isinstance(storage, (torch._UntypedStorage, torch.cuda._UntypedStorage)):
raise TypeError(arg_error_msg)
if type(self) != TypedStorage and storage.__module__ != self.__module__:
if type(self) != _TypedStorage and storage.__module__ != self.__module__:
raise TypeError((
arg_error_msg +
f'\n`storage` `module {storage.__module__}` does not match '
@ -267,9 +267,9 @@ class TypedStorage:
self._storage = storage
else:
assert type(self) != TypedStorage, (
"Calling __init__ this way is only supported in TypedStorage's "
"child classes. TypedStorage can only be directly instantiated "
assert type(self) != _TypedStorage, (
"Calling __init__ this way is only supported in _TypedStorage's "
"child classes. _TypedStorage can only be directly instantiated "
"when kwargs 'wrap_storage' and 'dtype' are given.")
assert len(kwargs) == 0, "invalid keyword arguments"
@ -282,10 +282,10 @@ class TypedStorage:
return True
if len(args) == 0:
self._storage = eval(self.__module__).UntypedStorage()
self._storage = eval(self.__module__)._UntypedStorage()
elif len(args) == 1 and isint(args[0]):
self._storage = eval(self.__module__).UntypedStorage(int(args[0]) * self.element_size())
self._storage = eval(self.__module__)._UntypedStorage(int(args[0]) * self.element_size())
elif len(args) == 1 and isinstance(args[0], collections.abc.Sequence):
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
@ -321,10 +321,10 @@ class TypedStorage:
def _new_wrapped_storage(self, untyped_storage):
module = eval(untyped_storage.__module__)
assert type(untyped_storage) == module.UntypedStorage
assert type(untyped_storage) == module._UntypedStorage
if type(self) == TypedStorage:
return TypedStorage(wrap_storage=untyped_storage, dtype=self.dtype)
if type(self) == _TypedStorage:
return _TypedStorage(wrap_storage=untyped_storage, dtype=self.dtype)
else:
# NOTE: We need to use the module of untyped_storage in case self's
# module is different, e.g. if self is on CPU and untyped_storage
@ -371,7 +371,7 @@ class TypedStorage:
torch.qint8: torch.int8
}
tmp_dtype = interpret_dtypes[self.dtype]
tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self.device).set_(TypedStorage(
tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self.device).set_(_TypedStorage(
wrap_storage=self._storage,
dtype=tmp_dtype))
else:
@ -380,12 +380,12 @@ class TypedStorage:
tmp_tensor[idx] = value
def __getitem__(self, idx):
# NOTE: Before TypedStorage existed, indexing with a slice used to be
# NOTE: Before _TypedStorage existed, indexing with a slice used to be
# possible for <type>Storage objects. However, it would return
# a storage view, which would be a hassle to implement in TypedStorage,
# a storage view, which would be a hassle to implement in _TypedStorage,
# so it was disabled
if isinstance(idx, slice):
raise RuntimeError('slices are only supported in UntypedStorage.__getitem__')
raise RuntimeError('slices are only supported in _UntypedStorage.__getitem__')
elif not isinstance(idx, int):
raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
@ -397,7 +397,7 @@ class TypedStorage:
torch.qint32: torch.int32,
torch.qint8: torch.int8
}
return TypedStorage(
return _TypedStorage(
wrap_storage=self._storage,
dtype=interpret_dtypes[self.dtype])[idx]
@ -430,7 +430,7 @@ class TypedStorage:
def __str__(self):
data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
if type(self) == TypedStorage:
if type(self) == _TypedStorage:
return data_str + (
f'\n[{torch.typename(self)} with dtype {self.dtype} '
f'of size {len(self)}]')
@ -450,7 +450,7 @@ class TypedStorage:
return self._new_wrapped_storage(copy.deepcopy(self._storage, memo))
def __sizeof__(self):
return super(TypedStorage, self).__sizeof__() + self.nbytes()
return super(_TypedStorage, self).__sizeof__() + self.nbytes()
def clone(self):
"""Returns a copy of this storage"""
@ -484,7 +484,7 @@ class TypedStorage:
def _new_shared(cls, size):
"""Creates a new storage in shared memory with the same data type"""
module = eval(cls.__module__)
untyped_storage = module.UntypedStorage._new_shared(size * cls().element_size())
untyped_storage = module._UntypedStorage._new_shared(size * cls().element_size())
return cls(wrap_storage=untyped_storage)
@property
@ -517,25 +517,25 @@ class TypedStorage:
@classmethod
def _free_weak_ref(cls, *args, **kwargs):
return eval(cls.__module__).UntypedStorage._free_weak_ref(*args, **kwargs)
return eval(cls.__module__)._UntypedStorage._free_weak_ref(*args, **kwargs)
def _weak_ref(self, *args, **kwargs):
return self._storage._weak_ref(*args, **kwargs)
@classmethod
def from_buffer(cls, *args, **kwargs):
if cls == TypedStorage:
if cls == _TypedStorage:
raise RuntimeError(
'from_buffer: only supported for subclasses of TypedStorage')
'from_buffer: only supported for subclasses of _TypedStorage')
if 'dtype' in kwargs or len(args) == 5:
raise RuntimeError((
"from_buffer: 'dtype' can only be specified in "
"UntypedStorage.from_buffer"))
"_UntypedStorage.from_buffer"))
kwargs['dtype'] = cls().dtype
untyped_storage = eval(cls.__module__).UntypedStorage.from_buffer(*args, **kwargs)
untyped_storage = eval(cls.__module__)._UntypedStorage.from_buffer(*args, **kwargs)
return cls(wrap_storage=untyped_storage)
def _to(self, dtype):
@ -594,9 +594,9 @@ class TypedStorage:
@classmethod
def from_file(cls, filename, shared, size):
if cls == TypedStorage:
if cls == _TypedStorage:
raise RuntimeError('from_file can only be called on derived classes')
untyped_storage = eval(cls.__module__).UntypedStorage.from_file(
untyped_storage = eval(cls.__module__)._UntypedStorage.from_file(
filename,
shared,
size * torch._utils._element_size(cls.dtype))
@ -605,7 +605,7 @@ class TypedStorage:
@classmethod
def _expired(cls, *args, **kwargs):
return eval(cls.__module__).UntypedStorage._expired(*args, **kwargs)
return eval(cls.__module__)._UntypedStorage._expired(*args, **kwargs)
def is_pinned(self):
return self._storage.is_pinned()
@ -627,11 +627,11 @@ class TypedStorage:
@classmethod
def _new_shared_cuda(cls, *args, **kwargs):
return eval(cls.__module__).UntypedStorage._new_shared_cuda(*args, **kwargs)
return eval(cls.__module__)._UntypedStorage._new_shared_cuda(*args, **kwargs)
@classmethod
def _new_with_weak_ptr(cls, *args, **kwargs):
return eval(cls.__module__).UntypedStorage._new_with_weak_ptr(*args, **kwargs)
return eval(cls.__module__)._UntypedStorage._new_with_weak_ptr(*args, **kwargs)
def _share_filename_(self, *args, **kwargs):
manager_handle, storage_handle, size = self._storage._share_filename_(*args, **kwargs)
@ -640,7 +640,7 @@ class TypedStorage:
@classmethod
def _new_shared_filename(cls, manager, obj, size):
bytes_size = size * torch._utils._element_size(cls.dtype)
return cls(wrap_storage=eval(cls.__module__).UntypedStorage._new_shared_filename(manager, obj, bytes_size))
return cls(wrap_storage=eval(cls.__module__)._UntypedStorage._new_shared_filename(manager, obj, bytes_size))
def _shared_decref(self):
self._storage._shared_decref()
@ -648,7 +648,7 @@ class TypedStorage:
@classmethod
def _release_ipc_counter(cls, *args, **kwargs):
return eval(cls.__module__).UntypedStorage._release_ipc_counter(*args, **kwargs)
return eval(cls.__module__)._UntypedStorage._release_ipc_counter(*args, **kwargs)
def _shared_incref(self, *args, **kwargs):
return self._storage._shared_incref(*args, **kwargs)

View File

@ -2133,7 +2133,7 @@ class TestCase(expecttest.TestCase):
),
sequence_types=(
Sequence,
torch.storage.TypedStorage,
torch.storage._TypedStorage,
Sequential,
ModuleList,
ParameterList,