mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[package] track storages across lifetime of PackageExporter (#59735)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59735 1. Fixes ABA storage identity problem during serialization for `torch.package` by keeping reference of serialized storages through lifetime of `PackageExporter` to prevent reuse of memory address. Achieved by extending logic used in solution to mobile's same issue. 2. Adds determinism to naming scheme of serialized storages in export code paths which utilize `tensor_cdata_naming_scheme`(introduced 2nd mapping in `StorageContext`, now maps `storage cdata ptr` -> `unique id`, `unique id` -> `c10::Storage`) 3. Additionally uses presence of a storage in the `StorageContext` instance as marker for if a storage has been serialized or not, removing the need to scan the `PythonStreamWriter` for presence of the storage's serialization file Test Plan: Imported from OSS Reviewed By: suo Differential Revision: D29075276 Pulled By: Lilyjjo fbshipit-source-id: 15a5c30b1de99c5bd7079388f2db9b6ece2eca12
This commit is contained in:
parent
eb2f535689
commit
0dd90cceaf
|
|
@ -375,7 +375,7 @@ def import_ir_module_from_buffer(
|
|||
def _import_ir_module_from_package(
|
||||
cu: CompilationUnit,
|
||||
reader: PyTorchFileReader,
|
||||
storage_context: StorageContext,
|
||||
storage_context: DeserializationStorageContext,
|
||||
map_location: Union[_device, str, None],
|
||||
ts_id: str
|
||||
) -> ScriptModule: ...
|
||||
|
|
@ -436,14 +436,22 @@ class ScriptModuleSerializer(object):
|
|||
def __init__(self, export_writer: PyTorchFileWriter) -> None: ...
|
||||
def serialize(self, model: ScriptModule, script_module_id: _int) -> None: ...
|
||||
def write_files(self) -> None: ...
|
||||
def storage_context(self) -> SerializationStorageContext: ...
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/jit/python/script_init.cpp
|
||||
class StorageContext(object):
|
||||
class SerializationStorageContext(object):
|
||||
def __init__(self) -> None: ...
|
||||
def has_storage(self, storage: Storage) -> _bool: ...
|
||||
def get_or_add_storage(self, storage: Storage) -> _int: ...
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/jit/python/script_init.cpp
|
||||
class DeserializationStorageContext(object):
|
||||
def __init__(self) -> None: ...
|
||||
def get_storage(self, name: str, dtype: _dtype) -> Tensor: ...
|
||||
def has_storage(self, name: str) -> _bool: ...
|
||||
def add_storage(self, name: str, tensor: Tensor) -> None: ...
|
||||
def add_storage(self, name: str, tensor: Tensor) -> _int: ...
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/jit/python/script_init.cpp
|
||||
|
|
|
|||
|
|
@ -210,8 +210,8 @@ void writeArchiveV5(
|
|||
const std::string& archive_name,
|
||||
const std::string& archive_dir,
|
||||
const std::string& tensor_dir,
|
||||
bool tensor_cdata_naming_scheme,
|
||||
StorageContext& storage_context) {
|
||||
bool use_storage_context,
|
||||
SerializationStorageContext& storage_context) {
|
||||
std::vector<char> data;
|
||||
// Vector to capture the run-time class types during pickling the IValues
|
||||
std::vector<c10::ClassTypePtr> memoizedClassTypes;
|
||||
|
|
@ -225,12 +225,12 @@ void writeArchiveV5(
|
|||
&memoizedClassTypes,
|
||||
[&](const at::Tensor& tensor) {
|
||||
// returns a string to use in picker.cpp as storage obj key
|
||||
if (tensor_cdata_naming_scheme) {
|
||||
if (use_storage_context) {
|
||||
std::string string_id =
|
||||
std::to_string(reinterpret_cast<std::intptr_t>(
|
||||
tensor.storage().unsafeGetStorageImpl()));
|
||||
tensor_names.push_back(string_id + ".storage");
|
||||
storage_context.addStorage(string_id, tensor.storage());
|
||||
storage_context.getOrAddStorage(tensor.storage());
|
||||
} else {
|
||||
tensor_names.push_back(std::to_string(tensor_names.size()));
|
||||
}
|
||||
|
|
@ -250,7 +250,7 @@ void writeArchiveV5(
|
|||
for (const auto& td : data_pickle.tensorData()) {
|
||||
WriteableTensorData writable_td = getWriteableTensorData(td);
|
||||
std::string fname = tensor_dir + tensor_names[i++];
|
||||
if (tensor_cdata_naming_scheme &&
|
||||
if (use_storage_context &&
|
||||
std::find(
|
||||
pre_serialized_files.begin(), pre_serialized_files.end(), fname) !=
|
||||
pre_serialized_files.end()) {
|
||||
|
|
@ -329,14 +329,14 @@ std::stringstream backport_v6_to_v5(std::stringstream& input_model_stream) {
|
|||
|
||||
update_bytecode_version(bytecode_values, kBytecodeVersionV5);
|
||||
auto bytecode_tuple = c10::ivalue::Tuple::create(std::move(bytecode_values));
|
||||
StorageContext storage_context;
|
||||
SerializationStorageContext storage_context;
|
||||
writeArchiveV5(
|
||||
writer_bytecode,
|
||||
c10::ivalue::Tuple::create(constants_values),
|
||||
/*archive_name=*/"constants",
|
||||
/*archive_dir=*/"",
|
||||
/*tensor_dir=*/"constants/",
|
||||
/*tensor_cdata_naming_scheme=*/true,
|
||||
/*use_storage_context=*/true,
|
||||
storage_context);
|
||||
writeArchiveV5(
|
||||
writer_bytecode,
|
||||
|
|
@ -344,7 +344,7 @@ std::stringstream backport_v6_to_v5(std::stringstream& input_model_stream) {
|
|||
/*archive_name=*/"bytecode",
|
||||
/*archive_dir=*/"",
|
||||
/*tensor_dir=*/"constants/",
|
||||
/*tensor_cdata_naming_scheme=*/true,
|
||||
/*use_storage_context=*/true,
|
||||
storage_context);
|
||||
|
||||
return ouput_model_stream;
|
||||
|
|
|
|||
|
|
@ -1126,12 +1126,14 @@ void initJITBindings(PyObject* module) {
|
|||
|
||||
// Used by torch.Package to coordinate deserialization of storages across
|
||||
// ScriptModules and eager modules
|
||||
py::class_<StorageContext, std::shared_ptr<StorageContext>>(
|
||||
m, "StorageContext")
|
||||
py::class_<
|
||||
DeserializationStorageContext,
|
||||
std::shared_ptr<DeserializationStorageContext>>(
|
||||
m, "DeserializationStorageContext")
|
||||
.def(py::init<>())
|
||||
.def(
|
||||
"get_storage",
|
||||
[](StorageContext& self,
|
||||
[](DeserializationStorageContext& self,
|
||||
const std::string& name,
|
||||
py::object data_type_obj) {
|
||||
c10::Storage storage = self.getStorage(name);
|
||||
|
|
@ -1147,12 +1149,12 @@ void initJITBindings(PyObject* module) {
|
|||
})
|
||||
.def(
|
||||
"add_storage",
|
||||
[](StorageContext& self,
|
||||
[](DeserializationStorageContext& self,
|
||||
const std::string& name,
|
||||
const at::Tensor& tensor) {
|
||||
self.addStorage(name, tensor.storage());
|
||||
return self.addStorage(name, tensor.storage());
|
||||
})
|
||||
.def("has_storage", &StorageContext::hasStorage);
|
||||
.def("has_storage", &DeserializationStorageContext::hasStorage);
|
||||
|
||||
m.def(
|
||||
"_jit_get_operation",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#include <torch/csrc/jit/python/script_init.h>
|
||||
|
||||
#include <torch/csrc/Device.h>
|
||||
#include <torch/csrc/DynamicTypes.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/frontend/ir_emitter.h>
|
||||
#include <torch/csrc/jit/frontend/sugared_value.h>
|
||||
|
|
@ -987,14 +988,25 @@ void initJitScriptBindings(PyObject* module) {
|
|||
pyIValueDeepcopy(IValue(self._ivalue()), memo).toObject());
|
||||
});
|
||||
|
||||
// Used by torch.Package to save TS objects in unified format
|
||||
// Used by torch.package to save ScriptModule objects in unified format.
|
||||
py::class_<ScriptModuleSerializer>(m, "ScriptModuleSerializer")
|
||||
.def(py::init<caffe2::serialize::PyTorchStreamWriter&>())
|
||||
.def("serialize", &ScriptModuleSerializer::serialize_unified_format)
|
||||
.def(
|
||||
"write_files",
|
||||
&ScriptModuleSerializer::writeFiles,
|
||||
py::arg("code_dir") = ".data/ts_code/code/");
|
||||
py::arg("code_dir") = ".data/ts_code/code/")
|
||||
.def("storage_context", &ScriptModuleSerializer::storage_context);
|
||||
|
||||
// Used by torch.package to coordinate sharing of storages between eager
|
||||
// and ScriptModules.
|
||||
py::class_<
|
||||
SerializationStorageContext,
|
||||
std::shared_ptr<SerializationStorageContext>>(
|
||||
m, "SerializationStorageContext")
|
||||
.def(py::init<SerializationStorageContext&>())
|
||||
.def("has_storage", &SerializationStorageContext::hasStorage)
|
||||
.def("get_or_add_storage", &SerializationStorageContext::getOrAddStorage);
|
||||
|
||||
// torch.jit.ScriptModule is a subclass of this C++ object.
|
||||
// Methods here are prefixed with _ since they should not be
|
||||
|
|
@ -1674,7 +1686,8 @@ void initJitScriptBindings(PyObject* module) {
|
|||
"_import_ir_module_from_package",
|
||||
[](std::shared_ptr<CompilationUnit> cu,
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> reader,
|
||||
std::shared_ptr<torch::jit::StorageContext> storage_context,
|
||||
std::shared_ptr<torch::jit::DeserializationStorageContext>
|
||||
storage_context,
|
||||
py::object map_location,
|
||||
std::string ts_id) {
|
||||
c10::optional<at::Device> optional_device;
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@
|
|||
#include <caffe2/serialize/inline_container.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/serialization/import.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <torch/csrc/jit/serialization/python_print.h>
|
||||
#include <torch/csrc/jit/serialization/storage_context.h>
|
||||
#include <torch/csrc/jit/serialization/type_name_uniquer.h>
|
||||
#include <torch/csrc/onnx/onnx.h>
|
||||
|
||||
|
|
@ -70,6 +70,7 @@ class TORCH_API ScriptModuleSerializer {
|
|||
bool bytecode_format,
|
||||
bool save_mobile_debug_info);
|
||||
void serialize_unified_format(Module& module, uint64_t script_module_id);
|
||||
SerializationStorageContext& storage_context();
|
||||
|
||||
~ScriptModuleSerializer() = default;
|
||||
|
||||
|
|
@ -86,7 +87,7 @@ class TORCH_API ScriptModuleSerializer {
|
|||
const std::string& archive_name,
|
||||
const std::string& archive_dir,
|
||||
const std::string& tensor_dir,
|
||||
bool tensor_cdata_naming_scheme = false);
|
||||
bool use_storage_context = false);
|
||||
void updateSourceRangeTags(const SourceRangeRecords& ranges);
|
||||
|
||||
caffe2::serialize::PyTorchStreamWriter& writer_;
|
||||
|
|
@ -100,8 +101,9 @@ class TORCH_API ScriptModuleSerializer {
|
|||
OrderedDict<std::string, PythonPrint> file_streams_;
|
||||
// Used to keep references of storages around during serialization to solve
|
||||
// for ABA memory reuse problem hit when storages are created/destroyed
|
||||
// during serializaiton process.
|
||||
StorageContext storage_context_;
|
||||
// during serialization process. Also used to coordinate sharing of storages
|
||||
// between Script and eager modules in torch.package.
|
||||
SerializationStorageContext storage_context_;
|
||||
|
||||
// Uniquely identifies a SourceRange in a model.
|
||||
// SourceRanges are associated with Nodes of Graphs.
|
||||
|
|
|
|||
|
|
@ -441,7 +441,7 @@ void ScriptModuleSerializer::serialize(
|
|||
/*archive_name=*/"constants",
|
||||
/*archive_dir=*/"",
|
||||
/*tensor_dir=*/"constants/",
|
||||
/*tensor_cdata_naming_scheme=*/true);
|
||||
/*use_storage_context=*/true);
|
||||
|
||||
writeByteCode(module, save_mobile_debug_info);
|
||||
writeMobileMetadata(module, extra_files);
|
||||
|
|
@ -463,11 +463,13 @@ void ScriptModuleSerializer::writeArchive(
|
|||
const std::string& archive_name,
|
||||
const std::string& archive_dir,
|
||||
const std::string& tensor_dir,
|
||||
bool tensor_cdata_naming_scheme) {
|
||||
bool use_storage_context) {
|
||||
std::vector<char> data;
|
||||
// Vector to capture the run-time class types during pickling the IValues
|
||||
std::vector<c10::ClassTypePtr> memoizedClassTypes;
|
||||
std::vector<std::string> tensor_names;
|
||||
// tensors that are already serialized in use_storage_context
|
||||
std::unordered_set<std::string> serialized_tensors;
|
||||
Pickler data_pickle(
|
||||
[&](const char* buf, size_t size) {
|
||||
data.insert(data.end(), buf, buf + size);
|
||||
|
|
@ -479,12 +481,19 @@ void ScriptModuleSerializer::writeArchive(
|
|||
&memoizedClassTypes,
|
||||
[&](const at::Tensor& tensor) {
|
||||
// returns a string to use in picker.cpp as storage obj key
|
||||
if (tensor_cdata_naming_scheme) {
|
||||
std::string string_id =
|
||||
std::to_string(reinterpret_cast<std::intptr_t>(
|
||||
tensor.storage().unsafeGetStorageImpl()));
|
||||
tensor_names.push_back(string_id + ".storage");
|
||||
storage_context_.addStorage(string_id, tensor.storage());
|
||||
if (use_storage_context) {
|
||||
bool already_serialized =
|
||||
storage_context_.hasStorage(tensor.storage());
|
||||
std::string tensor_name =
|
||||
std::to_string(
|
||||
storage_context_.getOrAddStorage(tensor.storage())) +
|
||||
".storage";
|
||||
if (already_serialized) {
|
||||
// this case is hit when storage has been serialized already
|
||||
// from a torch.package context
|
||||
serialized_tensors.insert(tensor_name);
|
||||
}
|
||||
tensor_names.push_back(tensor_name);
|
||||
} else {
|
||||
tensor_names.push_back(std::to_string(tensor_names.size()));
|
||||
}
|
||||
|
|
@ -498,20 +507,18 @@ void ScriptModuleSerializer::writeArchive(
|
|||
std::string prefix = archive_name + "/";
|
||||
|
||||
TORCH_INTERNAL_ASSERT(tensor_names.size() == data_pickle.tensorData().size());
|
||||
const std::vector<std::string>& pre_serialized_files =
|
||||
writer_.getAllWrittenRecords();
|
||||
|
||||
for (const auto& td : data_pickle.tensorData()) {
|
||||
WriteableTensorData writable_td = getWriteableTensorData(td);
|
||||
std::string fname = tensor_dir + tensor_names[i++];
|
||||
if (tensor_cdata_naming_scheme &&
|
||||
std::find(
|
||||
pre_serialized_files.begin(), pre_serialized_files.end(), fname) !=
|
||||
pre_serialized_files.end()) {
|
||||
std::string tensor_name = tensor_names[i++];
|
||||
if (use_storage_context && serialized_tensors.count(tensor_name)) {
|
||||
// storage has been serialzed already, skip
|
||||
continue;
|
||||
}
|
||||
writer_.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
|
||||
writer_.writeRecord(
|
||||
tensor_dir + tensor_name,
|
||||
writable_td.data(),
|
||||
writable_td.sizeInBytes());
|
||||
}
|
||||
|
||||
std::string fname = archive_dir + archive_name + ".pkl";
|
||||
|
|
@ -650,7 +657,7 @@ void ScriptModuleSerializer::writeByteCode(
|
|||
/*archive_name=*/"bytecode",
|
||||
/*archive_dir=*/"",
|
||||
/*tensor_dir=*/"constants/",
|
||||
/*tensor_cdata_naming_scheme=*/true);
|
||||
/*use_storage_context=*/true);
|
||||
|
||||
auto debug_info_telements = Tup(std::move(debug_info_elements));
|
||||
|
||||
|
|
@ -760,7 +767,7 @@ void ScriptModuleSerializer::serialize_unified_format(
|
|||
"data",
|
||||
archive_dir,
|
||||
/*tensor_dir=*/".data/",
|
||||
/*tensor_cdata_naming_scheme=*/true);
|
||||
/*use_storage_context=*/true);
|
||||
// Then we serialize all code info.
|
||||
convertTypes(module.type());
|
||||
// The tensor constants from the code are written to a separate archive
|
||||
|
|
@ -772,12 +779,16 @@ void ScriptModuleSerializer::serialize_unified_format(
|
|||
"constants",
|
||||
archive_dir,
|
||||
/*tensor_dir=*/".data/",
|
||||
/*tensor_cdata_naming_scheme=*/true);
|
||||
/*use_storage_context=*/true);
|
||||
|
||||
// Note: writeFiles() call needs to be made in addition to calling this
|
||||
// function to have the code actually saved (tensors are saved)
|
||||
}
|
||||
|
||||
SerializationStorageContext& ScriptModuleSerializer::storage_context() {
|
||||
return storage_context_;
|
||||
}
|
||||
|
||||
void ExportModule(
|
||||
const Module& module,
|
||||
std::ostream& out,
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ class ScriptModuleDeserializer final {
|
|||
std::shared_ptr<PyTorchStreamReader> reader,
|
||||
std::string pickle_dir_prefix,
|
||||
std::string tensor_dir_prefix,
|
||||
std::shared_ptr<StorageContext> storage_context)
|
||||
std::shared_ptr<DeserializationStorageContext> storage_context)
|
||||
: compilation_unit_(std::move(cu)),
|
||||
reader_(std::move(reader)),
|
||||
storage_context_(std::move(storage_context)),
|
||||
|
|
@ -116,7 +116,7 @@ class ScriptModuleDeserializer final {
|
|||
|
||||
std::shared_ptr<CompilationUnit> compilation_unit_;
|
||||
std::shared_ptr<PyTorchStreamReader> reader_;
|
||||
std::shared_ptr<StorageContext> storage_context_;
|
||||
std::shared_ptr<DeserializationStorageContext> storage_context_;
|
||||
c10::optional<at::Device> device_;
|
||||
std::vector<at::IValue> constants_table_;
|
||||
std::string code_prefix_;
|
||||
|
|
@ -291,7 +291,7 @@ Module import_ir_module(
|
|||
Module import_ir_module(
|
||||
std::shared_ptr<CompilationUnit> cu,
|
||||
std::shared_ptr<PyTorchStreamReader> reader,
|
||||
std::shared_ptr<StorageContext> storage_context,
|
||||
std::shared_ptr<DeserializationStorageContext> storage_context,
|
||||
c10::optional<at::Device> device,
|
||||
std::string ts_id) {
|
||||
ScriptModuleDeserializer deserializer(
|
||||
|
|
|
|||
|
|
@ -17,29 +17,6 @@ class ReadAdapterInterface;
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// used in torch.package deserialization
|
||||
class TORCH_API StorageContext {
|
||||
public:
|
||||
explicit StorageContext() = default;
|
||||
|
||||
void addStorage(const std::string& name, c10::Storage storage) {
|
||||
storage_map_.insert({name, storage});
|
||||
}
|
||||
|
||||
bool hasStorage(const std::string& name) {
|
||||
return storage_map_.find(name) != storage_map_.end();
|
||||
}
|
||||
|
||||
c10::Storage getStorage(const std::string& name) {
|
||||
TORCH_INTERNAL_ASSERT(hasStorage(name));
|
||||
return storage_map_.find(name)->second;
|
||||
}
|
||||
~StorageContext() = default;
|
||||
|
||||
private:
|
||||
std::map<std::string, c10::Storage> storage_map_;
|
||||
};
|
||||
|
||||
TORCH_API Module import_ir_module(
|
||||
std::shared_ptr<CompilationUnit> cu,
|
||||
const std::string& filename,
|
||||
|
|
@ -65,7 +42,7 @@ TORCH_API Module import_ir_module(
|
|||
TORCH_API Module import_ir_module(
|
||||
std::shared_ptr<CompilationUnit> cu,
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> reader,
|
||||
std::shared_ptr<torch::jit::StorageContext> storage_context,
|
||||
std::shared_ptr<torch::jit::DeserializationStorageContext> storage_context,
|
||||
c10::optional<at::Device> device,
|
||||
std::string ts_id /* torchscript identifier inside package */);
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ IValue readArchiveAndTensors(
|
|||
c10::optional<ObjLoader> obj_loader,
|
||||
c10::optional<at::Device> device,
|
||||
caffe2::serialize::PyTorchStreamReader& stream_reader,
|
||||
std::shared_ptr<StorageContext> storage_context) {
|
||||
std::shared_ptr<DeserializationStorageContext> storage_context) {
|
||||
std::string picklename = pickle_prefix + archive_name + ".pkl";
|
||||
at::DataPtr pickle_ptr;
|
||||
size_t pickle_size = 0;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/serialization/import.h>
|
||||
#include <torch/csrc/jit/serialization/unpickler.h>
|
||||
#include <memory>
|
||||
|
||||
|
|
@ -21,7 +20,7 @@ TORCH_API IValue readArchiveAndTensors(
|
|||
c10::optional<ObjLoader> obj_loader,
|
||||
c10::optional<at::Device> device,
|
||||
caffe2::serialize::PyTorchStreamReader& stream_reader,
|
||||
std::shared_ptr<StorageContext> storage_context = nullptr);
|
||||
std::shared_ptr<DeserializationStorageContext> storage_context = nullptr);
|
||||
|
||||
bool check_zip_file(
|
||||
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai);
|
||||
|
|
|
|||
79
torch/csrc/jit/serialization/storage_context.h
Normal file
79
torch/csrc/jit/serialization/storage_context.h
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/core/ivalue.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// Used in torch.package and TorchScript serialization to coordinate
|
||||
// sharing of storages between models. Also used to create deterministic
|
||||
// naming for storages.
|
||||
class TORCH_API SerializationStorageContext {
|
||||
public:
|
||||
explicit SerializationStorageContext() = default;
|
||||
|
||||
uint64_t getOrAddStorage(c10::Storage storage) {
|
||||
if (!hasStorage(storage)) {
|
||||
uint64_t size = storage_id_map_.size();
|
||||
storage_id_map_[storage] = size;
|
||||
}
|
||||
return storage_id_map_[storage];
|
||||
}
|
||||
|
||||
bool hasStorage(c10::Storage storage) {
|
||||
return storage_id_map_.find(storage) != storage_id_map_.end();
|
||||
}
|
||||
|
||||
~SerializationStorageContext() = default;
|
||||
|
||||
private:
|
||||
class StorageSerializationHash {
|
||||
public:
|
||||
size_t operator()(const c10::Storage& storage) const {
|
||||
return std::hash<void*>()(
|
||||
reinterpret_cast<void*>(storage.unsafeGetStorageImpl()));
|
||||
}
|
||||
};
|
||||
|
||||
class StorageSerializationEqual {
|
||||
public:
|
||||
bool operator()(const c10::Storage& lhs, const c10::Storage& rhs) const {
|
||||
return lhs.unsafeGetStorageImpl() == rhs.unsafeGetStorageImpl();
|
||||
}
|
||||
};
|
||||
|
||||
std::unordered_map<
|
||||
c10::Storage,
|
||||
uint64_t,
|
||||
StorageSerializationHash,
|
||||
StorageSerializationEqual>
|
||||
storage_id_map_;
|
||||
};
|
||||
|
||||
// Used in torch.package and TorchScript deserialization to coordinate
|
||||
// sharing of storages between models.
|
||||
class TORCH_API DeserializationStorageContext {
|
||||
public:
|
||||
explicit DeserializationStorageContext() = default;
|
||||
|
||||
void addStorage(const std::string& name, c10::Storage storage) {
|
||||
TORCH_INTERNAL_ASSERT(!hasStorage(name));
|
||||
name_storage_map_.insert({name, storage});
|
||||
}
|
||||
|
||||
bool hasStorage(const std::string& name) {
|
||||
return name_storage_map_.find(name) != name_storage_map_.end();
|
||||
}
|
||||
|
||||
c10::Storage getStorage(const std::string& name) {
|
||||
TORCH_INTERNAL_ASSERT(hasStorage(name));
|
||||
return name_storage_map_.find(name)->second;
|
||||
}
|
||||
~DeserializationStorageContext() = default;
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, c10::Storage> name_storage_map_;
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -5,8 +5,8 @@
|
|||
#endif
|
||||
#include <torch/csrc/jit/api/function_impl.h>
|
||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||
#include <torch/csrc/jit/serialization/import.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <torch/csrc/jit/serialization/storage_context.h>
|
||||
#include <torch/csrc/jit/serialization/unpickler.h>
|
||||
#include <string>
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ using TypeResolver =
|
|||
using ObjLoader = std::function<
|
||||
c10::intrusive_ptr<c10::ivalue::Object>(at::StrongTypePtr, IValue)>;
|
||||
|
||||
class StorageContext;
|
||||
class DeserializationStorageContext;
|
||||
|
||||
// [unpickler refactor] there is some cruft around PickleOpCode::BUILD,
|
||||
// PickleOpCode::NEWOBJ, and the last_opcode_ member below that should be
|
||||
|
|
@ -51,7 +51,7 @@ class TORCH_API Unpickler {
|
|||
std::function<at::DataPtr(const std::string&)> read_record,
|
||||
c10::optional<at::Device> device,
|
||||
bool use_storage_device = false,
|
||||
std::shared_ptr<StorageContext> storage_context = nullptr)
|
||||
std::shared_ptr<DeserializationStorageContext> storage_context = nullptr)
|
||||
: reader_(std::move(reader)),
|
||||
tensor_table_(),
|
||||
type_resolver_(std::move(type_resolver)),
|
||||
|
|
@ -157,7 +157,7 @@ class TORCH_API Unpickler {
|
|||
|
||||
// Used for torch.package to enable sharing of storages across
|
||||
// ScriptModules and eager modules
|
||||
std::shared_ptr<StorageContext> storage_context_;
|
||||
std::shared_ptr<DeserializationStorageContext> storage_context_;
|
||||
|
||||
// See [type tag serialization]
|
||||
uint64_t version_;
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ from typing import (
|
|||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Union,
|
||||
)
|
||||
from urllib.parse import quote
|
||||
|
|
@ -192,7 +191,6 @@ class PackageExporter:
|
|||
self.zip_file = torch._C.PyTorchFileWriter(f)
|
||||
self.zip_file.set_min_version(6)
|
||||
self.serialized_reduces: Dict[int, Any] = {}
|
||||
self.serialized_storages: Set[str] = set()
|
||||
|
||||
# A graph tracking all the modules and pickle objects added to this
|
||||
# package and the dependencies between them.
|
||||
|
|
@ -202,6 +200,7 @@ class PackageExporter:
|
|||
self.dependency_graph = DiGraph()
|
||||
self.verbose = verbose
|
||||
self.script_module_serializer = torch._C.ScriptModuleSerializer(self.zip_file)
|
||||
self.storage_context = self.script_module_serializer.storage_context()
|
||||
|
||||
# These are OrderedDicts for compatibility with RemovableHandle.
|
||||
# Generic OrderedDict type annotations are not present until 3.7.
|
||||
|
|
@ -776,20 +775,19 @@ node [shape=box];
|
|||
def _persistent_id(self, obj):
|
||||
if torch.is_storage(obj):
|
||||
storage_type = normalize_storage_type(type(obj))
|
||||
obj_key = str(obj._cdata)
|
||||
location = location_tag(obj)
|
||||
name = f".data/{obj_key}.storage"
|
||||
|
||||
if name not in self.serialized_storages:
|
||||
# check to see if storage was previously serialized
|
||||
serialized_files = self.zip_file.get_all_written_records()
|
||||
if name not in serialized_files:
|
||||
if obj.device.type != "cpu":
|
||||
obj = obj.cpu()
|
||||
num_bytes = obj.size() * obj.element_size()
|
||||
self.zip_file.write_record(name, obj.data_ptr(), num_bytes)
|
||||
self.serialized_storages.add(name)
|
||||
return ("storage", storage_type, obj_key, location, obj.size())
|
||||
# serialize storage if not already written
|
||||
storage_present = self.storage_context.has_storage(obj)
|
||||
storage_id = self.storage_context.get_or_add_storage(obj)
|
||||
if not storage_present:
|
||||
if obj.device.type != "cpu":
|
||||
obj = obj.cpu()
|
||||
num_bytes = obj.size() * obj.element_size()
|
||||
self.zip_file.write_record(
|
||||
f".data/{storage_id}.storage", obj.data_ptr(), num_bytes
|
||||
)
|
||||
return ("storage", storage_type, storage_id, location, obj.size())
|
||||
|
||||
if hasattr(obj, "__reduce_package__"):
|
||||
if _gate_torchscript_serialization and isinstance(
|
||||
|
|
|
|||
|
|
@ -173,10 +173,10 @@ class PackageImporter(Importer):
|
|||
restore_location = _get_restore_location(map_location)
|
||||
loaded_storages = {}
|
||||
loaded_reduces = {}
|
||||
storage_context = torch._C.StorageContext()
|
||||
storage_context = torch._C.DeserializationStorageContext()
|
||||
|
||||
def load_tensor(data_type, size, key, location, restore_location):
|
||||
name = f"{key}.storage"
|
||||
name = f"{int(key)}.storage"
|
||||
dtype = data_type(0).dtype
|
||||
|
||||
if storage_context.has_storage(name):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user