Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67499

Since https://github.com/pytorch/pytorch/pull/62030 landed, storages produced when loading from a pickle are of type TypedStorage. We weren't catching this in our deploy serialization, so tensors were actually getting pickled instead of their storages being shared across interpreters. Since the result is still technically correct, it wasn't caught by any of our tests until someone passed a really big tensor and started OOMing.

ghstack-source-id: 141869521

Test Plan: added unit test

Reviewed By: shunting314

Differential Revision: D32004075

fbshipit-source-id: ef5a80cd3cb1dff0b6b4c1b6c95923e4faab7d50
99 lines
3.3 KiB
Python
import io

import torch
from torch.package._package_pickler import create_pickler
from torch.package._package_unpickler import PackageUnpickler
from torch.package import sys_importer, OrderedImporter, PackageImporter, Importer
from torch.serialization import _maybe_decode_ascii

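# Pickle `obj` for torch::deploy, diverting its storages (and any objects that
# implement __reduce_deploy__) out of the pickle stream so they can be shared
# across interpreters instead of being copied into the payload.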
def _save_storages(importer, obj):
    serialized_storages = []
    serialized_dtypes = []

    importer = importer if isinstance(importer, torch.package.PackageImporter) else None
    importers: Importer
    if importer is not None:
        importers = OrderedImporter(importer, sys_importer)
    else:
        importers = sys_importer

    def persistent_id(obj):
        if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                storage = obj._storage
                dtype = obj.dtype
            else:
                storage = obj
                dtype = torch.uint8

            serialized_storages.append(obj)
            serialized_dtypes.append(dtype)
            return ('storage', len(serialized_storages) - 1)

        if hasattr(obj, "__reduce_deploy__"):
            if _serialized_reduces.get(id(obj)) is None:
                _serialized_reduces[id(obj)] = (
                    "reduce_deploy",
                    id(obj),
                    *obj.__reduce_deploy__(importers),
                )
            return _serialized_reduces[id(obj)]

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = create_pickler(data_buf, importers)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    return data_value, serialized_storages, serialized_dtypes, importer.zip_reader if importer else None

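# Rebuild an object pickled by _save_storages inside another interpreter,
# re-attaching the out-of-band storages and replaying any __reduce_deploy__
# reductions against the package behind `zip_reader`.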
def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'storage':
            # TODO: Once we decide to break serialization FC, we can
            # stop wrapping with TypedStorage
            storage = serialized_storages[data[0]]
            dtype = serialized_dtypes[data[0]]
            return torch.storage.TypedStorage(
                wrap_storage=storage._untyped(),
                dtype=dtype)

        if typename == 'reduce_deploy':
            reduce_id, func, args = data
            if reduce_id not in _loaded_reduces:
                _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
            return _loaded_reduces[reduce_id]

        return None

    importer: Importer
    if zip_reader is not None:
        importer = OrderedImporter(_get_package(zip_reader), sys_importer)
    else:
        importer = sys_importer

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load
    result = _deploy_objects[id] = unpickler.load()
    return result

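# Return (and cache) a PackageImporter for `zip_reader` so repeated loads reuse
# the same package environment.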
def _get_package(zip_reader):
    if zip_reader not in _raw_packages:
        _raw_packages[zip_reader] = PackageImporter(zip_reader)
    return _raw_packages[zip_reader]

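# Module-level caches shared across calls to the helpers above.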
_raw_packages: dict = {}
_deploy_objects: dict = {}
_serialized_reduces: dict = {}
_loaded_reduces: dict = {}
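
# Hypothetical usage sketch (illustrative only, not part of the upstream file):
# round-trip a plain tensor through _save_storages/_load_storages. Assumes a
# torch build contemporary with this commit, where the tensor's storage pickles
# as a TypedStorage and therefore travels out-of-band in `storages` rather than
# being embedded in the pickle bytes.
if __name__ == "__main__":
    t = torch.arange(16, dtype=torch.float32)
    data, storages, dtypes, zip_reader = _save_storages(None, t)
    # The storage is captured by reference, along with its dtype.
    assert len(storages) == 1 and dtypes[0] == torch.float32
    loaded = _load_storages(0, zip_reader, data, storages, dtypes)
    assert torch.equal(loaded, t)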