diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index bf65f1f65be..4da8850db99 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -375,7 +375,7 @@ def import_ir_module_from_buffer( def _import_ir_module_from_package( cu: CompilationUnit, reader: PyTorchFileReader, - storage_context: StorageContext, + storage_context: DeserializationStorageContext, map_location: Union[_device, str, None], ts_id: str ) -> ScriptModule: ... @@ -436,14 +436,22 @@ class ScriptModuleSerializer(object): def __init__(self, export_writer: PyTorchFileWriter) -> None: ... def serialize(self, model: ScriptModule, script_module_id: _int) -> None: ... def write_files(self) -> None: ... + def storage_context(self) -> SerializationStorageContext: ... ... # Defined in torch/csrc/jit/python/script_init.cpp -class StorageContext(object): +class SerializationStorageContext(object): + def __init__(self) -> None: ... + def has_storage(self, storage: Storage) -> _bool: ... + def get_or_add_storage(self, storage: Storage) -> _int: ... + ... + +# Defined in torch/csrc/jit/python/script_init.cpp +class DeserializationStorageContext(object): def __init__(self) -> None: ... def get_storage(self, name: str, dtype: _dtype) -> Tensor: ... def has_storage(self, name: str) -> _bool: ... - def add_storage(self, name: str, tensor: Tensor) -> None: ... + def add_storage(self, name: str, tensor: Tensor) -> _int: ... ... # Defined in torch/csrc/jit/python/script_init.cpp diff --git a/torch/csrc/jit/mobile/backport_manager.cpp b/torch/csrc/jit/mobile/backport_manager.cpp index 87233ff1e56..91c8548ee7d 100644 --- a/torch/csrc/jit/mobile/backport_manager.cpp +++ b/torch/csrc/jit/mobile/backport_manager.cpp @@ -210,8 +210,8 @@ void writeArchiveV5( const std::string& archive_name, const std::string& archive_dir, const std::string& tensor_dir, - bool tensor_cdata_naming_scheme, - StorageContext& storage_context) { + bool use_storage_context, + SerializationStorageContext& storage_context) { std::vector data; // Vector to capture the run-time class types during pickling the IValues std::vector memoizedClassTypes; @@ -225,12 +225,12 @@ void writeArchiveV5( &memoizedClassTypes, [&](const at::Tensor& tensor) { // returns a string to use in picker.cpp as storage obj key - if (tensor_cdata_naming_scheme) { + if (use_storage_context) { std::string string_id = std::to_string(reinterpret_cast( tensor.storage().unsafeGetStorageImpl())); tensor_names.push_back(string_id + ".storage"); - storage_context.addStorage(string_id, tensor.storage()); + storage_context.getOrAddStorage(tensor.storage()); } else { tensor_names.push_back(std::to_string(tensor_names.size())); } @@ -250,7 +250,7 @@ void writeArchiveV5( for (const auto& td : data_pickle.tensorData()) { WriteableTensorData writable_td = getWriteableTensorData(td); std::string fname = tensor_dir + tensor_names[i++]; - if (tensor_cdata_naming_scheme && + if (use_storage_context && std::find( pre_serialized_files.begin(), pre_serialized_files.end(), fname) != pre_serialized_files.end()) { @@ -329,14 +329,14 @@ std::stringstream backport_v6_to_v5(std::stringstream& input_model_stream) { update_bytecode_version(bytecode_values, kBytecodeVersionV5); auto bytecode_tuple = c10::ivalue::Tuple::create(std::move(bytecode_values)); - StorageContext storage_context; + SerializationStorageContext storage_context; writeArchiveV5( writer_bytecode, c10::ivalue::Tuple::create(constants_values), /*archive_name=*/"constants", /*archive_dir=*/"", /*tensor_dir=*/"constants/", - /*tensor_cdata_naming_scheme=*/true, + /*use_storage_context=*/true, storage_context); writeArchiveV5( writer_bytecode, @@ -344,7 +344,7 @@ std::stringstream backport_v6_to_v5(std::stringstream& input_model_stream) { /*archive_name=*/"bytecode", /*archive_dir=*/"", /*tensor_dir=*/"constants/", - /*tensor_cdata_naming_scheme=*/true, + /*use_storage_context=*/true, storage_context); return ouput_model_stream; diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index f6b0ee74fda..509662f208b 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1126,12 +1126,14 @@ void initJITBindings(PyObject* module) { // Used by torch.Package to coordinate deserialization of storages across // ScriptModules and eager modules - py::class_>( - m, "StorageContext") + py::class_< + DeserializationStorageContext, + std::shared_ptr>( + m, "DeserializationStorageContext") .def(py::init<>()) .def( "get_storage", - [](StorageContext& self, + [](DeserializationStorageContext& self, const std::string& name, py::object data_type_obj) { c10::Storage storage = self.getStorage(name); @@ -1147,12 +1149,12 @@ void initJITBindings(PyObject* module) { }) .def( "add_storage", - [](StorageContext& self, + [](DeserializationStorageContext& self, const std::string& name, const at::Tensor& tensor) { - self.addStorage(name, tensor.storage()); + return self.addStorage(name, tensor.storage()); }) - .def("has_storage", &StorageContext::hasStorage); + .def("has_storage", &DeserializationStorageContext::hasStorage); m.def( "_jit_get_operation", diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 4152f643e54..8d203f3035c 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -987,14 +988,25 @@ void initJitScriptBindings(PyObject* module) { pyIValueDeepcopy(IValue(self._ivalue()), memo).toObject()); }); - // Used by torch.Package to save TS objects in unified format + // Used by torch.package to save ScriptModule objects in unified format. py::class_(m, "ScriptModuleSerializer") .def(py::init()) .def("serialize", &ScriptModuleSerializer::serialize_unified_format) .def( "write_files", &ScriptModuleSerializer::writeFiles, - py::arg("code_dir") = ".data/ts_code/code/"); + py::arg("code_dir") = ".data/ts_code/code/") + .def("storage_context", &ScriptModuleSerializer::storage_context); + + // Used by torch.package to coordinate sharing of storages between eager + // and ScriptModules. + py::class_< + SerializationStorageContext, + std::shared_ptr>( + m, "SerializationStorageContext") + .def(py::init()) + .def("has_storage", &SerializationStorageContext::hasStorage) + .def("get_or_add_storage", &SerializationStorageContext::getOrAddStorage); // torch.jit.ScriptModule is a subclass of this C++ object. // Methods here are prefixed with _ since they should not be @@ -1674,7 +1686,8 @@ void initJitScriptBindings(PyObject* module) { "_import_ir_module_from_package", [](std::shared_ptr cu, std::shared_ptr reader, - std::shared_ptr storage_context, + std::shared_ptr + storage_context, py::object map_location, std::string ts_id) { c10::optional optional_device; diff --git a/torch/csrc/jit/serialization/export.h b/torch/csrc/jit/serialization/export.h index f4f4da3b4dd..12f3e275adf 100644 --- a/torch/csrc/jit/serialization/export.h +++ b/torch/csrc/jit/serialization/export.h @@ -3,9 +3,9 @@ #include #include #include -#include #include #include +#include #include #include @@ -70,6 +70,7 @@ class TORCH_API ScriptModuleSerializer { bool bytecode_format, bool save_mobile_debug_info); void serialize_unified_format(Module& module, uint64_t script_module_id); + SerializationStorageContext& storage_context(); ~ScriptModuleSerializer() = default; @@ -86,7 +87,7 @@ class TORCH_API ScriptModuleSerializer { const std::string& archive_name, const std::string& archive_dir, const std::string& tensor_dir, - bool tensor_cdata_naming_scheme = false); + bool use_storage_context = false); void updateSourceRangeTags(const SourceRangeRecords& ranges); caffe2::serialize::PyTorchStreamWriter& writer_; @@ -100,8 +101,9 @@ class TORCH_API ScriptModuleSerializer { OrderedDict file_streams_; // Used to keep references of storages around during serialization to solve // for ABA memory reuse problem hit when storages are created/destroyed - // during serializaiton process. - StorageContext storage_context_; + // during serialization process. Also used to coordinate sharing of storages + // between Script and eager modules in torch.package. + SerializationStorageContext storage_context_; // Uniquely identifies a SourceRange in a model. // SourceRanges are associated with Nodes of Graphs. diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index fa36e772699..e9c717b0306 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -441,7 +441,7 @@ void ScriptModuleSerializer::serialize( /*archive_name=*/"constants", /*archive_dir=*/"", /*tensor_dir=*/"constants/", - /*tensor_cdata_naming_scheme=*/true); + /*use_storage_context=*/true); writeByteCode(module, save_mobile_debug_info); writeMobileMetadata(module, extra_files); @@ -463,11 +463,13 @@ void ScriptModuleSerializer::writeArchive( const std::string& archive_name, const std::string& archive_dir, const std::string& tensor_dir, - bool tensor_cdata_naming_scheme) { + bool use_storage_context) { std::vector data; // Vector to capture the run-time class types during pickling the IValues std::vector memoizedClassTypes; std::vector tensor_names; + // tensors that are already serialized in use_storage_context + std::unordered_set serialized_tensors; Pickler data_pickle( [&](const char* buf, size_t size) { data.insert(data.end(), buf, buf + size); @@ -479,12 +481,19 @@ void ScriptModuleSerializer::writeArchive( &memoizedClassTypes, [&](const at::Tensor& tensor) { // returns a string to use in picker.cpp as storage obj key - if (tensor_cdata_naming_scheme) { - std::string string_id = - std::to_string(reinterpret_cast( - tensor.storage().unsafeGetStorageImpl())); - tensor_names.push_back(string_id + ".storage"); - storage_context_.addStorage(string_id, tensor.storage()); + if (use_storage_context) { + bool already_serialized = + storage_context_.hasStorage(tensor.storage()); + std::string tensor_name = + std::to_string( + storage_context_.getOrAddStorage(tensor.storage())) + + ".storage"; + if (already_serialized) { + // this case is hit when storage has been serialized already + // from a torch.package context + serialized_tensors.insert(tensor_name); + } + tensor_names.push_back(tensor_name); } else { tensor_names.push_back(std::to_string(tensor_names.size())); } @@ -498,20 +507,18 @@ void ScriptModuleSerializer::writeArchive( std::string prefix = archive_name + "/"; TORCH_INTERNAL_ASSERT(tensor_names.size() == data_pickle.tensorData().size()); - const std::vector& pre_serialized_files = - writer_.getAllWrittenRecords(); for (const auto& td : data_pickle.tensorData()) { WriteableTensorData writable_td = getWriteableTensorData(td); - std::string fname = tensor_dir + tensor_names[i++]; - if (tensor_cdata_naming_scheme && - std::find( - pre_serialized_files.begin(), pre_serialized_files.end(), fname) != - pre_serialized_files.end()) { + std::string tensor_name = tensor_names[i++]; + if (use_storage_context && serialized_tensors.count(tensor_name)) { // storage has been serialzed already, skip continue; } - writer_.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes()); + writer_.writeRecord( + tensor_dir + tensor_name, + writable_td.data(), + writable_td.sizeInBytes()); } std::string fname = archive_dir + archive_name + ".pkl"; @@ -650,7 +657,7 @@ void ScriptModuleSerializer::writeByteCode( /*archive_name=*/"bytecode", /*archive_dir=*/"", /*tensor_dir=*/"constants/", - /*tensor_cdata_naming_scheme=*/true); + /*use_storage_context=*/true); auto debug_info_telements = Tup(std::move(debug_info_elements)); @@ -760,7 +767,7 @@ void ScriptModuleSerializer::serialize_unified_format( "data", archive_dir, /*tensor_dir=*/".data/", - /*tensor_cdata_naming_scheme=*/true); + /*use_storage_context=*/true); // Then we serialize all code info. convertTypes(module.type()); // The tensor constants from the code are written to a separate archive @@ -772,12 +779,16 @@ void ScriptModuleSerializer::serialize_unified_format( "constants", archive_dir, /*tensor_dir=*/".data/", - /*tensor_cdata_naming_scheme=*/true); + /*use_storage_context=*/true); // Note: writeFiles() call needs to be made in addition to calling this // function to have the code actually saved (tensors are saved) } +SerializationStorageContext& ScriptModuleSerializer::storage_context() { + return storage_context_; +} + void ExportModule( const Module& module, std::ostream& out, diff --git a/torch/csrc/jit/serialization/import.cpp b/torch/csrc/jit/serialization/import.cpp index 8bbd9f8a995..8380013d8ce 100644 --- a/torch/csrc/jit/serialization/import.cpp +++ b/torch/csrc/jit/serialization/import.cpp @@ -91,7 +91,7 @@ class ScriptModuleDeserializer final { std::shared_ptr reader, std::string pickle_dir_prefix, std::string tensor_dir_prefix, - std::shared_ptr storage_context) + std::shared_ptr storage_context) : compilation_unit_(std::move(cu)), reader_(std::move(reader)), storage_context_(std::move(storage_context)), @@ -116,7 +116,7 @@ class ScriptModuleDeserializer final { std::shared_ptr compilation_unit_; std::shared_ptr reader_; - std::shared_ptr storage_context_; + std::shared_ptr storage_context_; c10::optional device_; std::vector constants_table_; std::string code_prefix_; @@ -291,7 +291,7 @@ Module import_ir_module( Module import_ir_module( std::shared_ptr cu, std::shared_ptr reader, - std::shared_ptr storage_context, + std::shared_ptr storage_context, c10::optional device, std::string ts_id) { ScriptModuleDeserializer deserializer( diff --git a/torch/csrc/jit/serialization/import.h b/torch/csrc/jit/serialization/import.h index ce223099bc6..fe2e1129c27 100644 --- a/torch/csrc/jit/serialization/import.h +++ b/torch/csrc/jit/serialization/import.h @@ -17,29 +17,6 @@ class ReadAdapterInterface; namespace torch { namespace jit { -// used in torch.package deserialization -class TORCH_API StorageContext { - public: - explicit StorageContext() = default; - - void addStorage(const std::string& name, c10::Storage storage) { - storage_map_.insert({name, storage}); - } - - bool hasStorage(const std::string& name) { - return storage_map_.find(name) != storage_map_.end(); - } - - c10::Storage getStorage(const std::string& name) { - TORCH_INTERNAL_ASSERT(hasStorage(name)); - return storage_map_.find(name)->second; - } - ~StorageContext() = default; - - private: - std::map storage_map_; -}; - TORCH_API Module import_ir_module( std::shared_ptr cu, const std::string& filename, @@ -65,7 +42,7 @@ TORCH_API Module import_ir_module( TORCH_API Module import_ir_module( std::shared_ptr cu, std::shared_ptr reader, - std::shared_ptr storage_context, + std::shared_ptr storage_context, c10::optional device, std::string ts_id /* torchscript identifier inside package */); diff --git a/torch/csrc/jit/serialization/import_read.cpp b/torch/csrc/jit/serialization/import_read.cpp index edc335edf27..1ea16fc8c06 100644 --- a/torch/csrc/jit/serialization/import_read.cpp +++ b/torch/csrc/jit/serialization/import_read.cpp @@ -12,7 +12,7 @@ IValue readArchiveAndTensors( c10::optional obj_loader, c10::optional device, caffe2::serialize::PyTorchStreamReader& stream_reader, - std::shared_ptr storage_context) { + std::shared_ptr storage_context) { std::string picklename = pickle_prefix + archive_name + ".pkl"; at::DataPtr pickle_ptr; size_t pickle_size = 0; diff --git a/torch/csrc/jit/serialization/import_read.h b/torch/csrc/jit/serialization/import_read.h index 73b27242272..504f5dc1293 100644 --- a/torch/csrc/jit/serialization/import_read.h +++ b/torch/csrc/jit/serialization/import_read.h @@ -1,6 +1,5 @@ #pragma once -#include #include #include @@ -21,7 +20,7 @@ TORCH_API IValue readArchiveAndTensors( c10::optional obj_loader, c10::optional device, caffe2::serialize::PyTorchStreamReader& stream_reader, - std::shared_ptr storage_context = nullptr); + std::shared_ptr storage_context = nullptr); bool check_zip_file( std::shared_ptr rai); diff --git a/torch/csrc/jit/serialization/storage_context.h b/torch/csrc/jit/serialization/storage_context.h new file mode 100644 index 00000000000..d94aa479f17 --- /dev/null +++ b/torch/csrc/jit/serialization/storage_context.h @@ -0,0 +1,79 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + +// Used in torch.package and TorchScript serialization to coordinate +// sharing of storages between models. Also used to create deterministic +// naming for storages. +class TORCH_API SerializationStorageContext { + public: + explicit SerializationStorageContext() = default; + + uint64_t getOrAddStorage(c10::Storage storage) { + if (!hasStorage(storage)) { + uint64_t size = storage_id_map_.size(); + storage_id_map_[storage] = size; + } + return storage_id_map_[storage]; + } + + bool hasStorage(c10::Storage storage) { + return storage_id_map_.find(storage) != storage_id_map_.end(); + } + + ~SerializationStorageContext() = default; + + private: + class StorageSerializationHash { + public: + size_t operator()(const c10::Storage& storage) const { + return std::hash()( + reinterpret_cast(storage.unsafeGetStorageImpl())); + } + }; + + class StorageSerializationEqual { + public: + bool operator()(const c10::Storage& lhs, const c10::Storage& rhs) const { + return lhs.unsafeGetStorageImpl() == rhs.unsafeGetStorageImpl(); + } + }; + + std::unordered_map< + c10::Storage, + uint64_t, + StorageSerializationHash, + StorageSerializationEqual> + storage_id_map_; +}; + +// Used in torch.package and TorchScript deserialization to coordinate +// sharing of storages between models. +class TORCH_API DeserializationStorageContext { + public: + explicit DeserializationStorageContext() = default; + + void addStorage(const std::string& name, c10::Storage storage) { + TORCH_INTERNAL_ASSERT(!hasStorage(name)); + name_storage_map_.insert({name, storage}); + } + + bool hasStorage(const std::string& name) { + return name_storage_map_.find(name) != name_storage_map_.end(); + } + + c10::Storage getStorage(const std::string& name) { + TORCH_INTERNAL_ASSERT(hasStorage(name)); + return name_storage_map_.find(name)->second; + } + ~DeserializationStorageContext() = default; + + private: + std::unordered_map name_storage_map_; +}; + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index 5c4a17fe83f..581b94978c4 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -5,8 +5,8 @@ #endif #include #include -#include #include +#include #include #include diff --git a/torch/csrc/jit/serialization/unpickler.h b/torch/csrc/jit/serialization/unpickler.h index da6d4cec41e..f404deee848 100644 --- a/torch/csrc/jit/serialization/unpickler.h +++ b/torch/csrc/jit/serialization/unpickler.h @@ -15,7 +15,7 @@ using TypeResolver = using ObjLoader = std::function< c10::intrusive_ptr(at::StrongTypePtr, IValue)>; -class StorageContext; +class DeserializationStorageContext; // [unpickler refactor] there is some cruft around PickleOpCode::BUILD, // PickleOpCode::NEWOBJ, and the last_opcode_ member below that should be @@ -51,7 +51,7 @@ class TORCH_API Unpickler { std::function read_record, c10::optional device, bool use_storage_device = false, - std::shared_ptr storage_context = nullptr) + std::shared_ptr storage_context = nullptr) : reader_(std::move(reader)), tensor_table_(), type_resolver_(std::move(type_resolver)), @@ -157,7 +157,7 @@ class TORCH_API Unpickler { // Used for torch.package to enable sharing of storages across // ScriptModules and eager modules - std::shared_ptr storage_context_; + std::shared_ptr storage_context_; // See [type tag serialization] uint64_t version_; diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index d3640ba9f43..786ae652f88 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -16,7 +16,6 @@ from typing import ( List, Optional, Sequence, - Set, Union, ) from urllib.parse import quote @@ -192,7 +191,6 @@ class PackageExporter: self.zip_file = torch._C.PyTorchFileWriter(f) self.zip_file.set_min_version(6) self.serialized_reduces: Dict[int, Any] = {} - self.serialized_storages: Set[str] = set() # A graph tracking all the modules and pickle objects added to this # package and the dependencies between them. @@ -202,6 +200,7 @@ class PackageExporter: self.dependency_graph = DiGraph() self.verbose = verbose self.script_module_serializer = torch._C.ScriptModuleSerializer(self.zip_file) + self.storage_context = self.script_module_serializer.storage_context() # These are OrderedDicts for compatibility with RemovableHandle. # Generic OrderedDict type annotations are not present until 3.7. @@ -776,20 +775,19 @@ node [shape=box]; def _persistent_id(self, obj): if torch.is_storage(obj): storage_type = normalize_storage_type(type(obj)) - obj_key = str(obj._cdata) location = location_tag(obj) - name = f".data/{obj_key}.storage" - if name not in self.serialized_storages: - # check to see if storage was previously serialized - serialized_files = self.zip_file.get_all_written_records() - if name not in serialized_files: - if obj.device.type != "cpu": - obj = obj.cpu() - num_bytes = obj.size() * obj.element_size() - self.zip_file.write_record(name, obj.data_ptr(), num_bytes) - self.serialized_storages.add(name) - return ("storage", storage_type, obj_key, location, obj.size()) + # serialize storage if not already written + storage_present = self.storage_context.has_storage(obj) + storage_id = self.storage_context.get_or_add_storage(obj) + if not storage_present: + if obj.device.type != "cpu": + obj = obj.cpu() + num_bytes = obj.size() * obj.element_size() + self.zip_file.write_record( + f".data/{storage_id}.storage", obj.data_ptr(), num_bytes + ) + return ("storage", storage_type, storage_id, location, obj.size()) if hasattr(obj, "__reduce_package__"): if _gate_torchscript_serialization and isinstance( diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index d40601709bd..797aa28cc11 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -173,10 +173,10 @@ class PackageImporter(Importer): restore_location = _get_restore_location(map_location) loaded_storages = {} loaded_reduces = {} - storage_context = torch._C.StorageContext() + storage_context = torch._C.DeserializationStorageContext() def load_tensor(data_type, size, key, location, restore_location): - name = f"{key}.storage" + name = f"{int(key)}.storage" dtype = data_type(0).dtype if storage_context.has_storage(name):