mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
To better track models after serialization, this change writes a serialization_id (a UUID) to the inline container. Having this ID enables traceability of a model across saving and loading events.
serialization_id is generated as a new UUID every time serialization takes place. It can be thought of as a model snapshot identifier at the time of serialization.
Test Plan:
```
buck2 test @//mode/dev //caffe2/caffe2/serialize:inline_container_test
```
Local tests:
```
buck2 run @//mode/opt //scripts/atannous:example_pytorch_package
buck2 run @//mode/opt //scripts/atannous:example_pytorch
buck2 run @//mode/opt //scripts/atannous:example_pytorch_script
```
```
$ unzip -l output.pt
Archive: output.pt
Length Date Time Name
--------- ---------- ----- ----
36 00-00-1980 00:00 output/.data/serialization_id
358 00-00-1980 00:00 output/extra/producer_info.json
58 00-00-1980 00:00 output/data.pkl
261 00-00-1980 00:00 output/code/__torch__.py
326 00-00-1980 00:00 output/code/__torch__.py.debug_pkl
4 00-00-1980 00:00 output/constants.pkl
2 00-00-1980 00:00 output/version
--------- -------
1045 7 files
```
```
unzip -p output.pt "output/.data/serialization_id"
a9f903df-cbf6-40e3-8068-68086167ec60
```
Differential Revision: D45683657
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100994
Approved by: https://github.com/davidberard98
64 lines
1.8 KiB
Python
64 lines
1.8 KiB
Python
import os.path
|
|
from glob import glob
|
|
from typing import cast
|
|
|
|
import torch
|
|
from torch.types import Storage
|
|
|
|
# Name of the record that holds a package's serialization id; it is looked up
# relative to the package root by DirectoryReader.serialization_id().
__serialization_id_record_name__ = ".data/serialization_id"
|
|
|
|
|
|
# DirectoryReader.get_storage_from_record hands back one of these wrappers
# rather than a bare storage. Per the original author's note, the reader method
# it imitates returns a tensor, so callers presumably call ``.storage()`` on
# the result -- TODO confirm against PyTorchFileReader.
class _HasStorage:
    """Thin holder that exposes a wrapped storage through a ``storage()`` method."""

    def __init__(self, storage):
        """Keep a reference to ``storage``; nothing is copied or validated."""
        self._storage = storage

    def storage(self):
        """Return the exact object that was passed to the constructor."""
        return self._storage
|
|
|
|
|
|
class DirectoryReader:
    """
    Class to allow PackageImporter to operate on unzipped packages. Methods
    copy the behavior of the internal PyTorchFileReader class (which is used for
    accessing packages in all other cases).

    N.B.: ScriptObjects are not depickleable or accessible via this DirectoryReader
    class due to ScriptObjects requiring an actual PyTorchFileReader instance.
    """

    def __init__(self, directory):
        """Store ``directory`` as the root against which record names are resolved.

        Args:
            directory: path of the unzipped package on disk.
        """
        self.directory = directory

    def get_record(self, name):
        """Return the full contents of the record ``name`` as ``bytes``.

        Args:
            name: record path relative to the package root.

        Raises:
            OSError: propagated from ``open`` when the file cannot be read
                (for example ``FileNotFoundError`` for a missing record).
        """
        filename = f"{self.directory}/{name}"
        with open(filename, "rb") as f:
            return f.read()

    def get_storage_from_record(self, name, numel, dtype):
        """Build a storage from the file backing record ``name``.

        Args:
            name: record path relative to the package root.
            numel: number of elements the storage should hold.
            dtype: element dtype; only used to compute the size in bytes.

        Returns:
            A ``_HasStorage`` wrapping the result of
            ``torch.UntypedStorage.from_file``; call ``.storage()`` on it to
            obtain the storage itself.
        """
        filename = f"{self.directory}/{name}"
        # Total size in bytes = bytes per element of ``dtype`` * element count.
        nbytes = torch._utils._element_size(dtype) * numel
        # ``cast`` only informs the type checker; at runtime this is still the
        # ``torch.UntypedStorage`` class, whose ``from_file`` is invoked below.
        storage = cast(Storage, torch.UntypedStorage)
        return _HasStorage(storage.from_file(filename=filename, nbytes=nbytes))

    def has_record(self, path):
        """Return ``True`` if ``path`` names a regular file under the package root.

        Directories and missing paths yield ``False``.
        """
        full_path = os.path.join(self.directory, path)
        return os.path.isfile(full_path)

    def get_all_records(
        self,
    ):
        """List every file below the package root, relative to that root.

        Returns:
            A list of relative paths (directories are omitted) in the order
            ``glob`` produces them. With ``glob``'s default settings the ``**``
            pattern does not match dot-prefixed entries, so files under hidden
            directories such as ``.data`` are not listed.
        """
        # Function-scope import: the module only brings ``glob.glob`` into scope.
        from glob import escape

        root = os.fspath(self.directory)
        # Escape the root so that ``[``, ``*`` or ``?`` in the directory name
        # are matched literally instead of being interpreted as wildcards.
        pattern = f"{escape(root)}/**"
        # ``relpath`` (rather than slicing off ``len(directory) + 1`` characters)
        # stays correct when ``directory`` carries a trailing separator.
        return [
            os.path.relpath(filename, root)
            for filename in glob(pattern, recursive=True)
            if not os.path.isdir(filename)
        ]

    def serialization_id(
        self,
    ):
        """Return the package's serialization id as a ``str``.

        Returns:
            The UTF-8 decoded contents of the serialization-id record, or the
            empty string when the package has no such record. Both outcomes
            are ``str`` so callers never have to handle a bytes/str mix.
        """
        if not self.has_record(__serialization_id_record_name__):
            return ""
        return self.get_record(__serialization_id_record_name__).decode("utf-8")
|