[package] make torch.package produce unified format (#51826)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51826

Looks like this:
```
resnet.pt
├── .data  # Data folder named so it can't clash with torch.package code modules.
│   │      # Names/extensions automatically added to avoid naming conflicts.
│   ├── 94286146172688.storage   # tensor data
│   ├── 94286146172784.storage
│   ├── extern_modules           # torch.package metadata
│   ├── version                  # version metadata
│   └── ...
├── model  # package pickled model created w/
│   │      # exporter.save_pickle('model', 'model.pkl', resnet_model)
│   └── model.pkl
└── torchvision  # all code dependencies for packaged pickled
    └── models   # models are captured as source files
        ├── resnet.py
        └── utils.py
```

Since `version` is hardcoded in our zip reader/writer implementation,
add it as an option that defaults to "version" but accepts other
locations for putting the version metadata.

Test Plan: Imported from OSS

Reviewed By: zdevito

Differential Revision: D26295649

Pulled By: suo

fbshipit-source-id: 2d75feeb7de0f78196b4d0b6e2b814a7d58bd1dd
This commit is contained in:
Michael Suo 2021-02-09 07:34:05 -08:00 committed by Facebook GitHub Bot
parent 85b25257ff
commit c357f8b826
6 changed files with 20 additions and 7 deletions

View File

@ -115,7 +115,12 @@ void PyTorchStreamReader::init() {
// version check
at::DataPtr version_ptr;
size_t version_size;
std::tie(version_ptr, version_size) = getRecord("version");
if (hasRecord(".data/version")) {
std::tie(version_ptr, version_size) = getRecord(".data/version");
} else {
TORCH_CHECK(hasRecord("version"));
std::tie(version_ptr, version_size) = getRecord("version");
}
std::string version(static_cast<const char*>(version_ptr.get()), version_size);
version_ = caffe2::stoull(version);
AT_ASSERTM(
@ -357,7 +362,11 @@ void PyTorchStreamWriter::writeEndOfFile() {
// Rewrites version info
std::string version = c10::to_string(version_);
version.push_back('\n');
writeRecord("version", version.c_str(), version.size());
if (version_ >= 0x6L) {
writeRecord(".data/version", version.c_str(), version.size());
} else {
writeRecord("version", version.c_str(), version.size());
}
AT_ASSERT(!finalized_);
finalized_ = true;

View File

@ -6,7 +6,7 @@ namespace caffe2 {
namespace serialize {
constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L;
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
// Versions (i.e. why was the version number bumped?)
@ -46,6 +46,7 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L;
// (a versioned symbol preserves the historic behavior of versions 1--3)
// 5. (Dynamic) Stops torch.full inferring a floating point dtype
// when given bool or integer fill values.
// 6. Write version string to `.data/version` instead of `version`.
constexpr uint64_t kProducedFileFormatVersion = 0x3L;
// The version we write when the archive contains bytecode.

View File

@ -591,6 +591,7 @@ class PyTorchFileWriter(object):
def __init__(self, buffer: BinaryIO) -> None: ...
def write_record(self, name: str, data: bytes, size: _int) -> None: ...
def write_end_of_file(self) -> None: ...
def set_min_version(self, version: _int) -> None: ...
...
def _jit_get_inline_everything_mode() -> _bool: ...

View File

@ -868,6 +868,7 @@ void initJITBindings(PyObject* module) {
const char* data,
size_t size) { return self.writeRecord(name, data, size); })
.def("write_end_of_file", &PyTorchStreamWriter::writeEndOfFile)
.def("set_min_version", &PyTorchStreamWriter::setMinVersion)
.def(
"write_record",
[](PyTorchStreamWriter& self,

View File

@ -74,6 +74,7 @@ class PackageExporter:
self.buffer = f
self.zip_file = torch._C.PyTorchFileWriter(f)
self.zip_file.set_min_version(6)
self.serialized_storages : Dict[str, Any] = {}
self.external : List[str] = []
self.provided : Dict[str, bool] = {}
@ -427,7 +428,7 @@ node [shape=box];
# Write each tensor to a file named tensor/the_tensor_key in the zip archive
for key in sorted(self.serialized_storages.keys()):
name = 'data/{}'.format(key)
name = f'.data/{key}.storage'
storage = self.serialized_storages[key]
# location information is saved in python, but to actually
# get the data from non cpu tensors we need to move them over first
@ -436,7 +437,7 @@ node [shape=box];
num_bytes = storage.size() * storage.element_size()
self.zip_file.write_record(name, storage.data_ptr(), num_bytes)
contents = ('\n'.join(self.external) + '\n')
self._write('extern_modules', contents)
self._write('.data/extern_modules', contents)
del self.zip_file
if self.buffer:
self.buffer.flush()

View File

@ -149,7 +149,7 @@ class PackageImporter:
loaded_storages = {}
def load_tensor(data_type, size, key, location, restore_location):
name = f'data/{key}'
name = f'.data/{key}.storage'
dtype = data_type(0).dtype
storage = self.zip_reader.get_storage_from_record(name, size, dtype).storage()
@ -191,7 +191,7 @@ class PackageImporter:
return self._mangler.parent_name()
def _read_extern(self):
return self.zip_reader.get_record('extern_modules').decode('utf-8').splitlines(keepends=False)
return self.zip_reader.get_record('.data/extern_modules').decode('utf-8').splitlines(keepends=False)
def _make_module(self, name: str, filename: Optional[str], is_package: bool, parent: str):
mangled_filename = self._mangler.mangle(filename) if filename else None