pytorch/torch/utils/show_pickle.py
Aaron Gokaslan c5fafe9f48 [BE]: TRY002 - Ban raising vanilla exceptions (#124570)
Adds a ruff lint rule to ban raising raw exceptions. Most of these should at the very least be runtime errors, value errors, type errors, or some other more specific error. There are hundreds of instances of these bad exception types already in the codebase, so I have noqa'd most of them. Hopefully this error code will get committers to rethink what exception type they should raise when they submit a PR.

I also encourage people to gradually go and fix all the existing noqas that have been added so they can be removed over time and our exception typing can be improved.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124570
Approved by: https://github.com/ezyang
2024-04-21 22:26:40 +00:00

151 lines
5.3 KiB
Python

#!/usr/bin/env python3
import sys
import pickle
import struct
import pprint
import zipfile
import fnmatch
from typing import Any, IO, BinaryIO, Union
# Names exported by a star-import of this module.
__all__ = ["FakeObject", "FakeClass", "DumpUnpickler", "main"]
class FakeObject:
    """Printable placeholder for an object described by a pickle.

    Stores the module and name of the class that would have been instantiated,
    the constructor arguments, and whatever state was handed to
    ``__setstate__``. Nothing from the named module is imported or executed.
    """

    def __init__(self, module, name, args):
        self.module = module
        self.name = name
        self.args = args
        # NOTE: We don't distinguish between state never set and state set to None.
        self.state = None

    def __repr__(self):
        state_str = "" if self.state is None else f"(state={self.state!r})"
        return f"{self.module}.{self.name}{self.args!r}{state_str}"

    def __setstate__(self, state):
        # Record the state instead of applying it to ``__dict__``.
        self.state = state

    @staticmethod
    def pp_format(printer, obj, stream, indent, allowance, context, level):
        """Write a multi-line rendering of ``obj`` to ``stream``.

        The signature matches the private formatter callbacks that
        ``printer._format`` is invoked with, so this function can be registered
        in a ``pprint.PrettyPrinter`` dispatch table.

        Output shapes:
          * no args, no state  -> ``repr(obj)``
          * args only          -> ``module.name`` followed by the formatted args
          * state (any args)   -> ``module.name<args>(state=`` newline, the
            formatted state indented one level deeper, then ``)``. With empty
            args, ``<args>`` is the literal ``()``.

        Args:
            printer: ``pprint.PrettyPrinter`` whose ``_format`` and
                ``_indent_per_level`` are used for nested values.
            obj: the ``FakeObject`` to render.
            stream: writable text stream receiving the output.
            indent: current indentation column.
            allowance: columns reserved for trailing closing characters.
            context: recursion-tracking mapping passed through to ``_format``.
            level: current nesting depth.
        """
        if not obj.args and obj.state is None:
            stream.write(repr(obj))
            return
        if obj.state is None:
            stream.write(f"{obj.module}.{obj.name}")
            printer._format(obj.args, stream, indent + 1, allowance + 1, context, level)
            return
        if not obj.args:
            stream.write(f"{obj.module}.{obj.name}()(state=\n")
        else:
            # Both args and state are present: render the args exactly as the
            # args-only branch does, then append the state suffix used by the
            # state-only branch.
            stream.write(f"{obj.module}.{obj.name}")
            printer._format(obj.args, stream, indent + 1, allowance + 1, context, level)
            stream.write("(state=\n")
        indent += printer._indent_per_level
        stream.write(" " * indent)
        printer._format(obj.state, stream, indent, allowance + 1, context, level + 1)
        stream.write(")")
class FakeClass:
    """Placeholder standing in for a class referenced by a pickle.

    Calling an instance, or calling its ``__new__`` attribute, yields a
    ``FakeObject`` that remembers this class's module, name and the arguments
    it was given.
    """

    def __init__(self, module, name):
        self.name = name
        self.module = module
        # Expose ``fake_new`` under the name ``__new__`` on the instance so code
        # that calls ``obj.__new__(obj, *args)`` gets a FakeObject back.
        self.__new__ = self.fake_new  # type: ignore[assignment]

    def __repr__(self):
        return "{}.{}".format(self.module, self.name)

    def __call__(self, *args):
        # Direct "construction": keep every positional argument.
        placeholder = FakeObject(self.module, self.name, args)
        return placeholder

    def fake_new(self, *args):
        # The first positional argument is discarded (presumably the class
        # object passed by a ``cls.__new__(cls, ...)``-style call).
        ctor_args = args[1:]
        return FakeObject(self.module, self.name, ctor_args)
class DumpUnpickler(pickle._Unpickler):  # type: ignore[name-defined]
    """Pure-Python unpickler that produces a printable skeleton of a pickle.

    Class lookups return ``FakeClass`` placeholders and persistent IDs become
    ``FakeObject`` placeholders, so ``find_class`` never imports the module
    named in the pickle stream.

    Args:
        file: binary stream to read the pickle from.
        catch_invalid_utf8: when true, a BINUNICODE string that fails to decode
            is replaced by a ``FakeObject`` describing the error instead of
            raising ``UnicodeDecodeError``.
        **kwargs: forwarded unchanged to ``pickle._Unpickler``.
    """

    def __init__(self, file, *, catch_invalid_utf8=False, **kwargs):
        super().__init__(file, **kwargs)
        self.catch_invalid_utf8 = catch_invalid_utf8

    def find_class(self, module, name):
        # Never resolve the real class; hand back a stand-in.
        return FakeClass(module, name)

    def persistent_load(self, pid):
        # Wrap the persistent ID so it shows up in the dump.
        return FakeObject("pers", "obj", (pid,))

    # Custom objects in TorchScript are able to return invalid UTF-8 strings
    # from their pickle (__getstate__) functions.  Install a custom loader
    # for strings that catches the decode exception and replaces it with
    # a sentinel object.
    def load_binunicode(self):
        # 4-byte little-endian unsigned length prefix, then the payload.
        (length,) = struct.unpack("<I", self.read(4))  # type: ignore[attr-defined]
        if length > sys.maxsize:
            raise Exception("String too long.")  # noqa: TRY002
        raw = self.read(length)  # type: ignore[attr-defined]
        decoded: Any
        try:
            decoded = str(raw, "utf-8", "surrogatepass")
        except UnicodeDecodeError as err:
            if self.catch_invalid_utf8:
                decoded = FakeObject("builtin", "UnicodeDecodeError", (str(err),))
            else:
                raise
        self.append(decoded)  # type: ignore[attr-defined]

    # Private copy of the parent's opcode table with the BINUNICODE handler
    # swapped for the tolerant loader above; the parent's table is untouched.
    dispatch = {
        **pickle._Unpickler.dispatch,  # type: ignore[attr-defined]
        pickle.BINUNICODE[0]: load_binunicode,
    }

    @classmethod
    def dump(cls, in_stream, out_stream):
        """Unpickle ``in_stream``, pretty-print the result, and return it.

        Args:
            in_stream: binary stream containing a pickle.
            out_stream: text stream passed to ``pprint.pprint`` as ``stream``.

        Returns:
            The value produced by ``load()``.
        """
        unpickler = cls(in_stream)
        value = unpickler.load()
        pprint.pprint(value, stream=out_stream)
        return value
def main(argv, output_stream=None):
    """Dump the pickle named by ``argv[1]`` to ``output_stream``.

    ``argv[1]`` is either a path to a pickle file or ``ZIP@MEMBER``, where
    ``MEMBER`` is an exact member name or, if it contains ``*``, an ``fnmatch``
    pattern whose first matching member is shown.

    Args:
        argv: argument list; must have exactly two elements.
        output_stream: stream handed to ``DumpUnpickler.dump``.

    Returns:
        ``2`` after printing usage to stderr when ``argv`` has the wrong length
        and ``output_stream`` is ``None``; otherwise ``None``.

    Raises:
        Exception: if ``argv`` has the wrong length and ``output_stream`` was
            supplied, or if no zip member matches the pattern.
    """
    if len(argv) != 2:
        # Don't spam stderr if not using stdout.
        if output_stream is not None:
            raise Exception("Pass argv of length 2.")  # noqa: TRY002
        usage = (
            "usage: show_pickle PICKLE_FILE\n",
            " PICKLE_FILE can be any of:\n",
            " path to a pickle file\n",
            " file.zip@member.pkl\n",
            " file.zip@*/pattern.*\n",
            " (shell glob pattern for members)\n",
            " (only first match will be shown)\n",
        )
        for usage_line in usage:
            sys.stderr.write(usage_line)
        return 2

    target = argv[1]

    # Plain file on disk.
    if "@" not in target:
        with open(target, "rb") as pickle_file:
            DumpUnpickler.dump(pickle_file, output_stream)
        return

    # ``archive@member``: only the first "@" separates the two parts.
    archive_path, member_pattern = target.split("@", 1)
    with zipfile.ZipFile(archive_path) as archive:
        # Exact member name.
        if "*" not in member_pattern:
            with archive.open(member_pattern) as member_file:
                DumpUnpickler.dump(member_file, output_stream)
            return

        # Glob: take the first member (in archive order) that matches.
        first_match = next(
            (
                info
                for info in archive.infolist()
                if fnmatch.fnmatch(info.filename, member_pattern)
            ),
            None,
        )
        if first_match is None:
            raise Exception(f"Could not find member matching {member_pattern} in {archive_path}")  # noqa: TRY002
        with archive.open(first_match) as member_file:
            DumpUnpickler.dump(member_file, output_stream)
if __name__ == "__main__":
    # Script entry point: teach pprint how to lay out FakeObject instances,
    # then run main() and exit with whatever it returns.
    #
    # This hack works on every version of Python I've tested.
    # I've tested on the following versions:
    #   3.7.4
    # It registers FakeObject.pp_format in PrettyPrinter's private dispatch
    # table, keyed by FakeObject.__repr__.
    formatter_table = pprint.PrettyPrinter._dispatch  # type: ignore[attr-defined]
    formatter_table[FakeObject.__repr__] = FakeObject.pp_format

    exit_status = main(sys.argv)
    sys.exit(exit_status)