mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70642 Review history on https://github.com/pytorch/pytorch/pull/70014 Test Plan: Imported from OSS Reviewed By: suo Differential Revision: D33414364 Pulled By: PaliC fbshipit-source-id: 7e7ed491c6f16d4fac3a03f7e403935823c03aa6
108 lines
4.1 KiB
Python
108 lines
4.1 KiB
Python
"""isort:skip_file"""
|
|
from pickle import EXT1, EXT2, EXT4, GLOBAL, STACK_GLOBAL, Pickler, PicklingError
|
|
from pickle import _compat_pickle, _extension_registry, _getattribute, _Pickler # type: ignore[attr-defined]
|
|
from struct import pack
|
|
from types import FunctionType
|
|
|
|
from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer
|
|
|
|
|
|
class PackagePickler(_Pickler):
    """Package-aware pickler.

    This behaves the same as a normal pickler, except it uses an `Importer`
    to find objects and modules to save. Objects pickled by reference
    (functions, and anything else routed to ``save_global``) are mapped to a
    ``(module_name, name)`` pair via ``importer.get_name`` and their module is
    loaded via ``importer.import_module`` instead of the standard
    ``__import__``-based lookup.
    """

    def __init__(self, importer: "Importer", *args, **kwargs):
        """Create a pickler that resolves globals through ``importer``.

        Args:
            importer: Object providing ``get_name(obj, name)`` and
                ``import_module(module_name)``; used by ``save_global``.
            *args, **kwargs: Forwarded unchanged to ``pickle._Pickler``
                (e.g. the output file and ``protocol``).
        """
        self.importer = importer
        super().__init__(*args, **kwargs)

        # Make sure the dispatch table copied from _Pickler is up-to-date.
        # Previous issues have been encountered where a library (e.g. dill)
        # mutates _Pickler.dispatch: if PackagePickler copied the table when
        # this module was imported and the library later removed its dispatch
        # entries, PackagePickler would be left with a stale table that may
        # cause unwanted behavior. Copying per instance avoids that.
        self.dispatch = _Pickler.dispatch.copy()
        # Route plain functions through the importer-aware save_global below.
        self.dispatch[FunctionType] = PackagePickler.save_global

    def save_global(self, obj, name=None):
        """Save ``obj`` by reference, as a module name plus qualified name.

        Unfortunately the pickler code is factored in a way that forces us to
        copy/paste ``pickle._Pickler.save_global``. The lookup of the module
        and name is marked CHANGED below: it goes through ``self.importer``
        rather than ``__import__``.

        Args:
            obj: The object to pickle by reference.
            name: Optional name hint, passed through to ``importer.get_name``.

        Raises:
            PicklingError: If the importer cannot locate ``obj``
                (``ObjNotFoundError`` / ``ObjMismatchError``), or if the
                identifier cannot be ASCII-encoded under protocols < 3.
            AttributeError: If ``name`` cannot be resolved on the module
                returned by the importer.
        """
        write = self.write

        # CHANGED: import module from module environment instead of __import__
        try:
            module_name, name = self.importer.get_name(obj, name)
        except (ObjNotFoundError, ObjMismatchError) as err:
            raise PicklingError(f"Can't pickle {obj}: {err}") from None

        module = self.importer.import_module(module_name)
        # Walk the dotted name to find the object that directly holds the
        # final attribute. This is done explicitly instead of through the
        # private ``pickle._getattribute`` helper, which is not a stable
        # interface across CPython versions.
        parent = current = module
        for attr in name.split("."):
            parent, current = current, getattr(current, attr)
        # END CHANGED

        if self.proto >= 2:
            # Globals registered in the copyreg extension registry are written
            # as a compact integer code instead of by name.
            code = _extension_registry.get((module_name, name))
            if code:
                assert code > 0
                if code <= 0xFF:
                    write(EXT1 + pack("<B", code))
                elif code <= 0xFFFF:
                    write(EXT2 + pack("<H", code))
                else:
                    write(EXT4 + pack("<i", code))
                return
        lastname = name.rpartition(".")[2]
        if parent is module:
            name = lastname
        # Non-ASCII identifiers are supported only with protocols >= 3.
        if self.proto >= 4:
            self.save(module_name)
            self.save(name)
            write(STACK_GLOBAL)
        elif parent is not module:
            # Below protocol 4 a nested (dotted) name is emitted as
            # getattr(parent, lastname).
            self.save_reduce(getattr, (parent, lastname))
        elif self.proto >= 3:
            write(
                GLOBAL
                + bytes(module_name, "utf-8")
                + b"\n"
                + bytes(name, "utf-8")
                + b"\n"
            )
        else:
            if self.fix_imports:
                # Map Python 3 names back to their Python 2 equivalents.
                r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
                r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
                if (module_name, name) in r_name_mapping:
                    module_name, name = r_name_mapping[(module_name, name)]
                elif module_name in r_import_mapping:
                    module_name = r_import_mapping[module_name]
            try:
                write(
                    GLOBAL
                    + bytes(module_name, "ascii")
                    + b"\n"
                    + bytes(name, "ascii")
                    + b"\n"
                )
            except UnicodeEncodeError:
                # Report the module *name*: ``module`` is the module object
                # and would render as "<module ...>" in the message.
                raise PicklingError(
                    "can't pickle global identifier '%s.%s' using "
                    "pickle protocol %i" % (module_name, name, self.proto)
                ) from None

        self.memoize(obj)
|
|
|
|
|
|
def create_pickler(data_buf, importer, protocol=4):
    """Build a pickler that writes to ``data_buf`` using ``protocol``.

    Any importer other than ``sys_importer`` needs custom module resolution,
    so it gets a ``PackagePickler`` wrapping that importer. For
    ``sys_importer`` (the normal import system) the C implementation,
    ``pickle.Pickler``, is returned because it is faster.
    """
    if importer is not sys_importer:
        return PackagePickler(importer, data_buf, protocol=protocol)
    # Normal import library system: the stock C pickler suffices.
    return Pickler(data_buf, protocol=protocol)
|