mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72237

Adds a generic zip file reader/writer to torch.package in order to remove the dependency on torch for usages of package that do not involve TorchScript or tensors. This also lets users derive their own classes from the zip file reader/writer classes and plug in custom serialization/deserialization when performance requires it.

https://www.internalfb.com/intern/diff/D35423079/ was reverted because this refactor changed where most of the implementation components of PackageExporter/PackageImporter (such as ModuleActionType_) are defined. This diff also updates the import paths of those components to point to the correct files, compared to D35423079.

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D35423079

Pulled By: PaliC

fbshipit-source-id: 31abc4364d5fd007911cfb67cf36ebfac5d786f4
(cherry picked from commit 023b0d1445e0b1e1bb7a03c660cd62eb9d26d2a6)
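The extension point described above (deriving from the generic zip file reader/writer for custom serialization) could look roughly like the following. This is a minimal sketch, assuming the generic base class is exposed as torch.package._zip_file.PackageZipFileReader with get_record/has_record/get_all_records/close methods; the module path, class name, and method set are inferred from this refactor rather than confirmed by the commit message.

import zipfile
from typing import List

from torch.package._zip_file import PackageZipFileReader  # assumed location


class StdlibZipFileReader(PackageZipFileReader):
    # Hypothetical reader backed by the stdlib zipfile module instead of the
    # TorchScript-aware reader, for packages that carry no tensor payloads.

    def __init__(self, file_name: str):
        self._file_name = file_name
        self._zip = zipfile.ZipFile(file_name)

    def get_record(self, name: str) -> bytes:
        # Raw bytes stored under `name` inside the archive.
        return self._zip.read(name)

    def has_record(self, name: str) -> bool:
        return name in self._zip.namelist()

    def get_all_records(self) -> List[str]:
        return self._zip.namelist()

    def close(self) -> None:
        self._zip.close()

In principle, an importer built on such a reader avoids the TorchScript-specific reader entirely, which is the decoupling from torch that the summary describes.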
103 lines
3.7 KiB
Python
import io

import torch
from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
from torch.package._package_pickler import create_pickler
from torch.package._package_unpickler import PackageUnpickler
from torch.package._zip_file_torchscript import TorchScriptPackageZipFileReader
from torch.serialization import _maybe_decode_ascii


def _save_storages(importer, obj):
    # Collect every storage reachable from `obj` out-of-band so that only
    # lightweight ('storage', index) references land in the pickle stream.
    serialized_storages = []
    serialized_dtypes = []

    # Only honor a torch.package importer; anything else falls back to
    # sys_importer alone.
    importer = importer if isinstance(importer, torch.package.PackageImporter) else None
    importers: Importer
    if importer is not None:
        importers = OrderedImporter(importer, sys_importer)
    else:
        importers = sys_importer

    def persistent_id(obj):
        if torch.is_storage(obj) or isinstance(obj, torch.storage._TypedStorage):
            if isinstance(obj, torch.storage._TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                storage = obj._storage
                dtype = obj.dtype
            else:
                storage = obj
                dtype = torch.uint8

            serialized_storages.append(obj)
            serialized_dtypes.append(dtype)
            return ('storage', len(serialized_storages) - 1)

        if hasattr(obj, "__reduce_deploy__"):
            # Objects that know how to cross the deploy boundary are reduced
            # once and cached by identity.
            if _serialized_reduces.get(id(obj)) is None:
                _serialized_reduces[id(obj)] = (
                    "reduce_deploy",
                    id(obj),
                    *obj.__reduce_deploy__(importers),
                )
            return _serialized_reduces[id(obj)]

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = create_pickler(data_buf, importers)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()

    assert (not importer or isinstance(importer.zip_reader, TorchScriptPackageZipFileReader)), \
        f"importer {importer}'s zip reader is of type {type(importer.zip_reader)} not TorchScriptPackageZipFileReader"
    return data_value, serialized_storages, serialized_dtypes, importer.zip_reader.zip_reader if importer else None


def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
    # Inverse of _save_storages: rebuild the object from its pickle bytes plus
    # the out-of-band storages captured at save time.

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'storage':
            # TODO: Once we decide to break serialization FC, we can
            # stop wrapping with _TypedStorage
            storage = serialized_storages[data[0]]
            dtype = serialized_dtypes[data[0]]
            return torch.storage._TypedStorage(
                wrap_storage=storage._untyped(),
                dtype=dtype)

        if typename == 'reduce_deploy':
            reduce_id, func, args = data
            if reduce_id not in _loaded_reduces:
                _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
            return _loaded_reduces[reduce_id]

        return None

    importer: Importer
    if zip_reader is not None:
        importer = OrderedImporter(_get_package(zip_reader), sys_importer)
    else:
        importer = sys_importer

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load
    result = _deploy_objects[id] = unpickler.load()
    return result


def _get_package(zip_reader):
    # Cache one PackageImporter per zip reader so repeated loads from the same
    # archive share module state.
    if zip_reader not in _raw_packages:
        _raw_packages[zip_reader] = PackageImporter(zip_reader)
    return _raw_packages[zip_reader]


_raw_packages: dict = {}
_deploy_objects: dict = {}
_serialized_reduces: dict = {}
_loaded_reduces: dict = {}
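

if __name__ == "__main__":
    # Usage sketch (not part of the original file): round-trip a plain tensor
    # through the helpers above via the importer=None path, so that no
    # PackageImporter or zip archive is involved. This assumes a torch build
    # in which these private helpers behave as written here; "example" is an
    # arbitrary placeholder for the id that torch::deploy normally supplies.
    tensor = torch.arange(4, dtype=torch.float32)
    data, storages, dtypes, reader = _save_storages(None, tensor)
    restored = _load_storages("example", reader, data, storages, dtypes)
    assert torch.equal(tensor, restored)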