[package] make torch.package produce unified format (#51826)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51826

Looks like this:

```
resnet.pt
├── .data                       # data folder, named so it can't clash with
│   │                           # torch.package code modules; names/extensions
│   │                           # are added automatically to avoid naming conflicts
│   ├── 94286146172688.storage  # tensor data
│   ├── 94286146172784.storage
│   ├── extern_modules          # torch.package metadata
│   ├── version                 # version metadata
│   └── ...
├── model                       # pickled model created with
│   └── model.pkl               # exporter.save_pickle('model', 'model.pkl', resnet_model)
└── torchvision                 # all code dependencies for the packaged pickle
    └── models                  # models are captured as source files
        ├── resnet.py
        └── utils.py
```

Since `version` is hardcoded in our zip reader/writer implementation, add it as an option that defaults to `"version"` but accepts other locations for the version metadata.

Test Plan: Imported from OSS

Reviewed By: zdevito

Differential Revision: D26295649

Pulled By: suo

fbshipit-source-id: 2d75feeb7de0f78196b4d0b6e2b814a7d58bd1dd
This commit is contained in:
parent 85b25257ff
commit c357f8b826
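For context, a hedged sketch of producing and inspecting such an archive. `save_pickle` and the `resnet.pt` layout come from the summary above; the `zipfile` inspection and the model choice are illustrative, and a real export may additionally need extern/mock rules for C-extension dependencies (not shown in this diff):

```python
# Minimal sketch, assuming the torch.package API as of this PR.
import zipfile

from torch.package import PackageExporter
from torchvision.models import resnet18

exporter = PackageExporter('resnet.pt')
exporter.save_pickle('model', 'model.pkl', resnet18())
exporter.close()

# PyTorchStreamWriter prefixes records with an archive root directory,
# so entries print as e.g. resnet/.data/version, resnet/model/model.pkl.
with zipfile.ZipFile('resnet.pt') as zf:
    for name in zf.namelist():
        print(name)
```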
caffe2/serialize/inline_container.cc:

```diff
@@ -115,7 +115,12 @@ void PyTorchStreamReader::init() {
   // version check
   at::DataPtr version_ptr;
   size_t version_size;
-  std::tie(version_ptr, version_size) = getRecord("version");
+  if (hasRecord(".data/version")) {
+    std::tie(version_ptr, version_size) = getRecord(".data/version");
+  } else {
+    TORCH_CHECK(hasRecord("version"))
+    std::tie(version_ptr, version_size) = getRecord("version");
+  }
   std::string version(static_cast<const char*>(version_ptr.get()), version_size);
   version_ = caffe2::stoull(version);
   AT_ASSERTM(
@@ -357,7 +362,11 @@ void PyTorchStreamWriter::writeEndOfFile() {
   // Rewrites version info
   std::string version = c10::to_string(version_);
   version.push_back('\n');
-  writeRecord("version", version.c_str(), version.size());
+  if (version_ >= 0x6L) {
+    writeRecord(".data/version", version.c_str(), version.size());
+  } else {
+    writeRecord("version", version.c_str(), version.size());
+  }
 
   AT_ASSERT(!finalized_);
   finalized_ = true;
```
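On the read side this means new archives are recognized first and legacy archives still load. A rough Python mirror of that fallback using plain `zipfile` (the helper name and the suffix matching are illustrative, not a torch API):

```python
# Hedged sketch: mirrors the C++ read-side fallback above.
import zipfile

def read_format_version(path: str) -> int:
    with zipfile.ZipFile(path) as zf:
        names = zf.namelist()
        # Records sit under an archive root directory, so match suffixes.
        # Prefer the new `.data/version` record, then the legacy `version`.
        for suffix in ('.data/version', 'version'):
            for name in names:
                if name.endswith(suffix):
                    return int(zf.read(name).decode().strip())
    raise RuntimeError('archive contains no version record')
```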
caffe2/serialize/versions.h:

```diff
@@ -6,7 +6,7 @@ namespace caffe2 {
 namespace serialize {
 
 constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
-constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L;
+constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
 
 // Versions (i.e. why was the version number bumped?)
 
@@ -46,6 +46,7 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L;
 // (a versioned symbol preserves the historic behavior of versions 1--3)
 // 5. (Dynamic) Stops torch.full inferring a floating point dtype
 //    when given bool or integer fill values.
+// 6. Write version string to `.data/version` instead of `version`.
 constexpr uint64_t kProducedFileFormatVersion = 0x3L;
 
 // The version we write when the archive contains bytecode.
```
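Bumping `kMaxSupportedFileFormatVersion` is what lets the reader accept version-6 archives at all; anything outside the supported range is rejected. A hedged sketch of that range check (local names, not the C++ symbols):

```python
# Hedged sketch of the reader-side range check these constants drive
# (the real check is an AT_ASSERTM in PyTorchStreamReader::init).
K_MIN_SUPPORTED = 0x1
K_MAX_SUPPORTED = 0x6  # raised from 0x5 by this change

def check_supported(version: int) -> None:
    if not (K_MIN_SUPPORTED <= version <= K_MAX_SUPPORTED):
        raise RuntimeError(
            f'unsupported file format version {version}; supported '
            f'range is [{K_MIN_SUPPORTED}, {K_MAX_SUPPORTED}]')
```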
torch/_C/__init__.pyi:

```diff
@@ -591,6 +591,7 @@ class PyTorchFileWriter(object):
     def __init__(self, buffer: BinaryIO) -> None: ...
     def write_record(self, name: str, data: bytes, size: _int) -> None: ...
     def write_end_of_file(self) -> None: ...
+    def set_min_version(self, version: _int) -> None: ...
     ...
 
 def _jit_get_inline_everything_mode() -> _bool: ...
```
torch/csrc/jit/python/init.cpp:

```diff
@@ -868,6 +868,7 @@ void initJITBindings(PyObject* module) {
              const char* data,
              size_t size) { return self.writeRecord(name, data, size); })
       .def("write_end_of_file", &PyTorchStreamWriter::writeEndOfFile)
+      .def("set_min_version", &PyTorchStreamWriter::setMinVersion)
       .def(
           "write_record",
           [](PyTorchStreamWriter& self,
```
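With the stub and binding in place, `set_min_version` is callable from Python. A hedged round-trip sketch (file and record names are illustrative; the `.data/version` outcome follows from the writer branch shown earlier):

```python
# Hedged usage sketch of the new binding: forcing the minimum format
# version to 6 makes writeEndOfFile emit `.data/version`.
import zipfile
import torch

writer = torch._C.PyTorchFileWriter('unified.pt')
writer.set_min_version(6)
writer.write_record('payload.bin', b'abc', 3)
writer.write_end_of_file()
del writer  # make sure the underlying stream is flushed and closed

with zipfile.ZipFile('unified.pt') as zf:
    assert any(n.endswith('.data/version') for n in zf.namelist())
```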
torch/package/exporter.py:

```diff
@@ -74,6 +74,7 @@ class PackageExporter:
             self.buffer = f
 
         self.zip_file = torch._C.PyTorchFileWriter(f)
+        self.zip_file.set_min_version(6)
         self.serialized_storages : Dict[str, Any] = {}
         self.external : List[str] = []
         self.provided : Dict[str, bool] = {}
@@ -427,7 +428,7 @@ node [shape=box];
 
         # Write each tensor to a file named tensor/the_tensor_key in the zip archive
         for key in sorted(self.serialized_storages.keys()):
-            name = 'data/{}'.format(key)
+            name = f'.data/{key}.storage'
             storage = self.serialized_storages[key]
             # location information is saved in python, but to actually
             # get the data from non cpu tensors we need to move them over first
@@ -436,7 +437,7 @@ node [shape=box];
             num_bytes = storage.size() * storage.element_size()
             self.zip_file.write_record(name, storage.data_ptr(), num_bytes)
         contents = ('\n'.join(self.external) + '\n')
-        self._write('extern_modules', contents)
+        self._write('.data/extern_modules', contents)
         del self.zip_file
         if self.buffer:
             self.buffer.flush()
```
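The `.data/extern_modules` record written here is nothing more than a newline-joined list of module names, which is what the importer splits back apart. A tiny illustrative check (module names are examples):

```python
# Hedged sketch of the record format written above: one extern module
# name per line, with a trailing newline.
external = ['torch', 'io']
contents = '\n'.join(external) + '\n'
assert contents == 'torch\nio\n'
# The importer recovers the list with splitlines(keepends=False).
assert contents.splitlines(keepends=False) == external
```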
torch/package/importer.py:

```diff
@@ -149,7 +149,7 @@ class PackageImporter:
         loaded_storages = {}
 
         def load_tensor(data_type, size, key, location, restore_location):
-            name = f'data/{key}'
+            name = f'.data/{key}.storage'
             dtype = data_type(0).dtype
 
             storage = self.zip_reader.get_storage_from_record(name, size, dtype).storage()
@@ -191,7 +191,7 @@ class PackageImporter:
         return self._mangler.parent_name()
 
     def _read_extern(self):
-        return self.zip_reader.get_record('extern_modules').decode('utf-8').splitlines(keepends=False)
+        return self.zip_reader.get_record('.data/extern_modules').decode('utf-8').splitlines(keepends=False)
 
     def _make_module(self, name: str, filename: Optional[str], is_package: bool, parent: str):
         mangled_filename = self._mangler.mangle(filename) if filename else None
```
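Note that `_read_extern` assumes the new layout only; packages written before this change would need a fallback. A hypothetical backward-compatible variant (the helper and its error handling are illustrative, not part of this PR):

```python
# Hypothetical tolerant variant of _read_extern: try the new record
# location first, then the legacy one.
def read_extern_modules(zip_reader):
    for record in ('.data/extern_modules', 'extern_modules'):
        try:
            data = zip_reader.get_record(record)
        except RuntimeError:  # record absent in this layout
            continue
        return data.decode('utf-8').splitlines(keepends=False)
    return []
```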