pytorch/test/test_show_pickle.py
David Reiss a054d05707 Add torch.utils.show_pickle for showing pickle contents in saved models (#35168)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35168

Sometimes when a saved model isn't working, it's nice to be able to look
at the contents of the pickle files.  Unfortunately, pickletools output
isn't particularly readable, and unpickling is often either not possible
or runs so much post-processing code that it's not possible to tell
exactly what is present in the pickled data.

This script uses a custom Unpickler to unpickle (almost) any data into
stub objects that have no dependency on torch or any other runtime types
and suppress (almost) any postprocessing code.
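
To illustrate the stub idea, here is a minimal sketch using only the standard library. It is not the code in torch.utils.show_pickle: the names `Stub`, `StubUnpickler`, and `load_stubs` are hypothetical, and it assumes that overriding `find_class` and `persistent_load` is enough for the data at hand.

```python
import argparse
import io
import pickle


class Stub:
    """Inert stand-in for any pickled object (hypothetical helper)."""

    path = "?"

    def __new__(cls, *args, **kwargs):
        # Record constructor arguments instead of running real init logic.
        obj = object.__new__(cls)
        obj.args = args
        obj.state = None
        return obj

    def __init__(self, *args, **kwargs):
        # Accept anything so a REDUCE-style call with arguments cannot fail.
        pass

    def __setstate__(self, state):
        # Keep the raw state; no post-processing is run.
        self.state = state

    def __repr__(self):
        return f"{self.path}{self.args!r} state={self.state!r}"


class StubUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Never import `module`; hand back a fresh Stub subclass instead.
        return type(name, (Stub,), {"path": f"{module}.{name}"})

    def persistent_load(self, pid):
        # Assumption: persistent ids are simply surfaced as tagged tuples.
        return ("persistent", pid)


def load_stubs(data: bytes):
    """Unpickle `data` into Stub objects without importing its classes."""
    return StubUnpickler(io.BytesIO(data)).load()


if __name__ == "__main__":
    blob = pickle.dumps(argparse.Namespace(weight=2.0))
    print(load_stubs(blob))
```

The sketch does not handle list-like or dict-like stubs (pickles that append items or set keys on the reconstructed object), which a real tool would also need to cover.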

As a convenience, the wrapper can search through zip files, supporting
command lines like

`python -m torch.utils.show_pickle /path/to/model.pt1@*/data.pkl`
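
A rough sketch of how such a `zipfile@pattern` argument could be resolved, assuming the part after the first `@` is an fnmatch-style pattern over member names. The function name `read_member` and the choice to return the first match are assumptions for illustration, not the wrapper's documented behavior.

```python
import fnmatch
import os
import tempfile
import zipfile


def read_member(spec: str) -> bytes:
    """Return the bytes named by `spec`: a plain path or `zip@pattern`."""
    if "@" not in spec:
        # Plain file: read it directly.
        with open(spec, "rb") as handle:
            return handle.read()

    # Split once, so the pattern itself may contain further '@' characters.
    zip_path, pattern = spec.split("@", 1)
    with zipfile.ZipFile(zip_path) as zf:
        for name in zf.namelist():
            if fnmatch.fnmatch(name, pattern):
                return zf.read(name)
    raise FileNotFoundError(f"no member matching {pattern!r} in {zip_path!r}")


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmpdir:
        archive = os.path.join(tmpdir, "model.zip")
        with zipfile.ZipFile(archive, "w") as zf:
            zf.writestr("model/constants.pkl", b"constants")
            zf.writestr("model/data.pkl", b"data")
        print(read_member(archive + "@*/data.pkl"))
```

Since `*` in fnmatch also matches `/`, the pattern `*/data.pkl` finds `data.pkl` under any top-level directory name in the archive.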

When the module is invoked as main, we also install a hack in pprint to
allow semi-reasonable formatting of our stub objects.

Test Plan: Ran it on a data.pkl, constants.pkl, and a debug pkl

Differential Revision: D20842550

Pulled By: dreiss

fbshipit-source-id: ef662d8915fc5795039054d1f8fef2e1c51cf40a
2020-04-03 15:11:20 -07:00


import unittest
import io
import tempfile

import torch
import torch.utils.show_pickle

from torch.testing._internal.common_utils import IS_WINDOWS


class TestShowPickle(unittest.TestCase):

    @unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows")
    def test_scripted_model(self):
        class MyCoolModule(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return x * self.weight

        m = torch.jit.script(MyCoolModule(torch.tensor([2.0])))

        with tempfile.NamedTemporaryFile() as tmp:
            torch.jit.save(m, tmp)
            tmp.flush()
            buf = io.StringIO()
            torch.utils.show_pickle.main(["", tmp.name + "@*/data.pkl"], output_stream=buf)
            output = buf.getvalue()
            self.assertRegex(output, "MyCoolModule")
            self.assertRegex(output, "weight")


if __name__ == '__main__':
    unittest.main()