diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc
index 3d9701274ba..0859e2d1695 100644
--- a/caffe2/serialize/inline_container.cc
+++ b/caffe2/serialize/inline_container.cc
@@ -115,7 +115,12 @@ void PyTorchStreamReader::init() {
   // version check
   at::DataPtr version_ptr;
   size_t version_size;
-  std::tie(version_ptr, version_size) = getRecord("version");
+  if (hasRecord(".data/version")) {
+    std::tie(version_ptr, version_size) = getRecord(".data/version");
+  } else {
+    TORCH_CHECK(hasRecord("version"));
+    std::tie(version_ptr, version_size) = getRecord("version");
+  }
   std::string version(static_cast<char*>(version_ptr.get()), version_size);
   version_ = caffe2::stoull(version);
   AT_ASSERTM(
@@ -357,7 +362,11 @@ void PyTorchStreamWriter::writeEndOfFile() {
   // Rewrites version info
   std::string version = c10::to_string(version_);
   version.push_back('\n');
-  writeRecord("version", version.c_str(), version.size());
+  if (version_ >= 0x6L) {
+    writeRecord(".data/version", version.c_str(), version.size());
+  } else {
+    writeRecord("version", version.c_str(), version.size());
+  }
 
   AT_ASSERT(!finalized_);
   finalized_ = true;
diff --git a/caffe2/serialize/versions.h b/caffe2/serialize/versions.h
index 7466bc0cc24..7e00b6adeba 100644
--- a/caffe2/serialize/versions.h
+++ b/caffe2/serialize/versions.h
@@ -6,7 +6,7 @@ namespace caffe2 {
 namespace serialize {
 
 constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
-constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L;
+constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
 
 // Versions (i.e. why was the version number bumped?)
 
@@ -46,6 +46,7 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L;
 //     (a versioned symbol preserves the historic behavior of versions 1--3)
 // 5. (Dynamic) Stops torch.full inferring a floating point dtype
 //    when given bool or integer fill values.
+// 6. Write version string to `.data/version` instead of `version`.
 constexpr uint64_t kProducedFileFormatVersion = 0x3L;
 
 // The version we write when the archive contains bytecode.
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index adb4809adf9..1916289cc36 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -591,6 +591,7 @@ class PyTorchFileWriter(object):
     def __init__(self, buffer: BinaryIO) -> None: ...
     def write_record(self, name: str, data: bytes, size: _int) -> None: ...
    def write_end_of_file(self) -> None: ...
+    def set_min_version(self, version: _int) -> None: ...
     ...
 
 def _jit_get_inline_everything_mode() -> _bool: ...
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 23aa92f2229..0a773fd72ce 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -868,6 +868,7 @@ void initJITBindings(PyObject* module) {
              const char* data,
              size_t size) { return self.writeRecord(name, data, size); })
       .def("write_end_of_file", &PyTorchStreamWriter::writeEndOfFile)
+      .def("set_min_version", &PyTorchStreamWriter::setMinVersion)
       .def(
           "write_record",
           [](PyTorchStreamWriter& self,
diff --git a/torch/package/exporter.py b/torch/package/exporter.py
index 4cb4c5dbac1..62150f9bcc5 100644
--- a/torch/package/exporter.py
+++ b/torch/package/exporter.py
@@ -74,6 +74,7 @@ class PackageExporter:
             self.buffer = f
 
         self.zip_file = torch._C.PyTorchFileWriter(f)
+        self.zip_file.set_min_version(6)
         self.serialized_storages : Dict[str, Any] = {}
         self.external : List[str] = []
         self.provided : Dict[str, bool] = {}
@@ -427,7 +428,7 @@ node [shape=box];
 
         # Write each tensor to a file named tensor/the_tensor_key in the zip archive
         for key in sorted(self.serialized_storages.keys()):
-            name = 'data/{}'.format(key)
+            name = f'.data/{key}.storage'
             storage = self.serialized_storages[key]
             # location information is saved in python, but to actually
             # get the data from non cpu tensors we need to move them over first
@@ -436,7 +437,7 @@ node [shape=box];
             num_bytes = storage.size() * storage.element_size()
             self.zip_file.write_record(name, storage.data_ptr(), num_bytes)
         contents = ('\n'.join(self.external) + '\n')
-        self._write('extern_modules', contents)
+        self._write('.data/extern_modules', contents)
         del self.zip_file
         if self.buffer:
             self.buffer.flush()
diff --git a/torch/package/importer.py b/torch/package/importer.py
index 074ef987d02..f142fa87edf 100644
--- a/torch/package/importer.py
+++ b/torch/package/importer.py
@@ -149,7 +149,7 @@ class PackageImporter:
         loaded_storages = {}
 
         def load_tensor(data_type, size, key, location, restore_location):
-            name = f'data/{key}'
+            name = f'.data/{key}.storage'
             dtype = data_type(0).dtype
 
             storage = self.zip_reader.get_storage_from_record(name, size, dtype).storage()
@@ -191,7 +191,7 @@ class PackageImporter:
         return self._mangler.parent_name()
 
     def _read_extern(self):
-        return self.zip_reader.get_record('extern_modules').decode('utf-8').splitlines(keepends=False)
+        return self.zip_reader.get_record('.data/extern_modules').decode('utf-8').splitlines(keepends=False)
 
     def _make_module(self, name: str, filename: Optional[str], is_package: bool, parent: str):
         mangled_filename = self._mangler.mangle(filename) if filename else None
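
For context, a minimal usage sketch of the writer binding exposed by this patch, assuming only the APIs shown above (PyTorchFileWriter, set_min_version, write_record, write_end_of_file); the archive name "example.pt" and the record payload are placeholders, not part of the change:

    import torch

    with open("example.pt", "wb") as f:
        writer = torch._C.PyTorchFileWriter(f)
        # Opt in to file format version 6 so the version record is written
        # to ".data/version" instead of the legacy "version" record.
        writer.set_min_version(6)
        payload = b"torchvision\n"
        # Package metadata such as extern_modules now lives under .data/.
        writer.write_record(".data/extern_modules", payload, len(payload))
        writer.write_end_of_file()

On read, PyTorchStreamReader::init() first looks for ".data/version" and only falls back to the old "version" record when it is absent, so archives produced either way stay loadable.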