mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fixes https://github.com/pytorch/pytorch/issues/75927. Had to fix some bugs and add some ignores. To check if clean: ``` lintrunner --paths-cmd='git grep -Il .' --take MYPY,MYPYSTRICT ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/76753 Approved by: https://github.com/malfet
27 lines
965 B
Python
27 lines
965 B
Python
import _compat_pickle
|
|
import pickle
|
|
|
|
from .importer import Importer
|
|
|
|
|
|
class PackageUnpickler(pickle._Unpickler): # type: ignore[name-defined]
|
|
"""Package-aware unpickler.
|
|
|
|
This behaves the same as a normal unpickler, except it uses `importer` to
|
|
find any global names that it encounters while unpickling.
|
|
"""
|
|
|
|
def __init__(self, importer: Importer, *args, **kwargs):
|
|
super().__init__(*args, **kwargs)
|
|
self._importer = importer
|
|
|
|
def find_class(self, module, name):
|
|
# Subclasses may override this.
|
|
if self.proto < 3 and self.fix_imports: # type: ignore[attr-defined]
|
|
if (module, name) in _compat_pickle.NAME_MAPPING:
|
|
module, name = _compat_pickle.NAME_MAPPING[(module, name)]
|
|
elif module in _compat_pickle.IMPORT_MAPPING:
|
|
module = _compat_pickle.IMPORT_MAPPING[module]
|
|
mod = self._importer.import_module(module)
|
|
return getattr(mod, name)
|