Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23630 This is temporary and won't be needed with the new serialization format. But for now, since the main module gets its name from the archive name, we need this for safety; otherwise something like `torch.jit.save("torch.pt")` will break things. Test Plan: Imported from OSS Reviewed By: jamesr66a Differential Revision: D16592404 Pulled By: suo fbshipit-source-id: b538dc3438a80ea7bca14d84591ecd63f4b1289f
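For context: the main module's name is derived from the archive's file stem, so saving under a name that collides with a reserved qualifier (like `torch`) can break re-import. A minimal round-trip sketch against this file's C++ API, assuming `m` is a previously created script::Module (hypothetical):

    m.save("torch.pt"); // archive stem "torch" becomes the main module's name
    torch::jit::script::ExtraFilesMap extra;
    torch::jit::script::Module loaded =
        torch::jit::load("torch.pt", c10::nullopt, extra);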
468 lines
16 KiB
C++
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/util/type_resolver_util.h>

#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/import.h>
#include <torch/csrc/jit/import_export_helpers.h>
#include <torch/csrc/jit/import_source.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/script/script_type_parser.h>
#include <torch/csrc/jit/source_range_serialization.h>
#include <torch/csrc/jit/source_range_serialization_impl.h>

#include "caffe2/core/common.h"
#include "caffe2/core/types.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/proto/torch_pb.h"
#include "caffe2/serialize/file_adapter.h"
#include "caffe2/serialize/inline_container.h"
#include "caffe2/serialize/istream_adapter.h"

#include <ATen/ATen.h>

#include <fstream>
#include <functional>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace torch {
namespace jit {

using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;

namespace {

struct ClassResolver : public script::Resolver {
  explicit ClassResolver(std::shared_ptr<script::CompilationUnit> cu)
      : cu_(std::move(cu)) {}
  TypePtr resolveType(const std::string& name, const SourceRange& loc)
      const override {
    return cu_->get_type(c10::QualifiedName(name));
  }

 private:
  std::shared_ptr<script::CompilationUnit> cu_;
};

// Deserializer that loads script modules from .pt files. The file content is
// written by PyTorchStreamWriter; for details of the container format see
// caffe2/serialize/inline_container.h. The bulk of the records are tensor
// data, and the final record is a ModelDef protobuf (defined in
// caffe2/proto/torch.proto) serialized as JSON; it carries all of the
// model's metadata.
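// For orientation, the records this deserializer reads from the archive
// (tensor record keys come from TensorDef.data().key()):
//
//   model.json       JSON-serialized torch::ModelDef; read first, carries
//                    all metadata
//   <tensor keys>    raw storage bytes, one record per shared storage
//   attributes.pkl   pickled attribute values, present when
//                    proto_version >= 2
//   extra/<name>     caller-requested extra files
//   <code paths>     TorchScript source, located via
//                    ImportExportHelpers::qualifierToPath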
class ScriptModuleDeserializer final {
 public:
  ScriptModuleDeserializer(
      std::shared_ptr<script::CompilationUnit> cu,
      std::unique_ptr<PyTorchStreamReader> reader)
      : compilation_unit_(cu),
        reader_(std::move(reader)) {}

  script::Module deserialize(
      c10::optional<at::Device> device,
      script::ExtraFilesMap& extra_files);

 private:
  at::Tensor loadTensor(
      const torch::TensorDef& tensor_proto,
      std::unordered_map<std::string, at::Storage>& storageMap);

  script::Module convertModule(const torch::ModuleDef& module_def);

  void loadTensorTable(torch::ModelDef* model_def);
  std::vector<IValue> loadPickleArchive(const std::string& name);
  void importCallback(const std::string& qualifier);
  void moduleSetState(const script::Module& module, IValue state);

  std::shared_ptr<script::CompilationUnit> compilation_unit_;

  std::unique_ptr<PyTorchStreamReader> reader_;
  c10::optional<at::Device> device_;
  std::vector<std::string> moduleStack_;

  std::vector<at::Tensor> tensor_table_;
  std::vector<IValue> pickled_ivalues_;

  std::unordered_set<std::string> imported_libs_;
};

script::Module ScriptModuleDeserializer::deserialize(
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  C10_LOG_API_USAGE_ONCE("torch.script.load");
  torch::ModelDef model_def;
  at::DataPtr data_ptr;
  size_t data_size;
  std::tie(data_ptr, data_size) = reader_->getRecord("model.json");
  // NB: cannot use JsonStringToMessage since fbcode's protobuf is too old;
  // instead transcode the JSON to the binary wire format with
  // JsonToBinaryString, which behaves consistently with JsonStringToMessage.
  std::string url_prefix = "type.googleapis.com";
  std::unique_ptr<::google::protobuf::util::TypeResolver> resolver(
      ::google::protobuf::util::NewTypeResolverForDescriptorPool(
          url_prefix, model_def.GetDescriptor()->file()->pool()));
  std::string json_string = std::string(
      static_cast<char*>(data_ptr.get()),
      static_cast<char*>(data_ptr.get()) + data_size);
  std::string binary_string;
  ::google::protobuf::util::JsonParseOptions opts;
  opts.ignore_unknown_fields = true;
  auto convert_result = ::google::protobuf::util::JsonToBinaryString(
      resolver.get(),
      url_prefix + "/" + model_def.GetDescriptor()->full_name(),
      json_string,
      &binary_string,
      opts);
  if (!convert_result.ok()) {
    std::stringstream ss;
    ss << convert_result;
    AT_ERROR(ss.str());
  }
  AT_ASSERTM(
      model_def.ParseFromString(binary_string),
      "JSON transcoder produced invalid protobuf output.");
  device_ = device;

  const auto& module_def = model_def.main_module();

  // Load the extra files requested by the caller, if present in the archive.
  for (const auto& kv : extra_files) {
    const std::string& key = "extra/" + kv.first;
    if (reader_->hasFile(key)) {
      at::DataPtr meta_ptr;
      size_t meta_size;
      std::tie(meta_ptr, meta_size) = reader_->getRecord(key);
      extra_files[kv.first] =
          std::string(static_cast<char*>(meta_ptr.get()), meta_size);
    }
  }

  loadTensorTable(&model_def);
  if (model_def.proto_version() >= 2) {
    pickled_ivalues_ = loadPickleArchive("attributes.pkl");
  }

  return convertModule(module_def);
}

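// Reads every TensorDef in the model into tensor_table_. Order matters:
// ParameterDef.tensor_id and tensor references inside the pickled attribute
// archive are positional indices into this table.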
void ScriptModuleDeserializer::loadTensorTable(torch::ModelDef* model_def) {
  std::unordered_map<std::string, at::Storage> storageMap;
  for (const torch::TensorDef& tensor : model_def->tensors()) {
    tensor_table_.emplace_back(loadTensor(tensor, storageMap));
  }
}

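// Unpickles an entire record (here, "attributes.pkl") into a vector of
// IValues. Class types encountered during unpickling are resolved against
// compilation_unit_, importing their TorchScript source on demand through
// importCallback.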
std::vector<IValue> ScriptModuleDeserializer::loadPickleArchive(
    const std::string& name) {
  at::DataPtr attributes_ptr;
  size_t attributes_size;
  std::tie(attributes_ptr, attributes_size) = reader_->getRecord(name);
  Unpickler unpickler(
      attributes_ptr.get(),
      attributes_size,
      &tensor_table_,
      [&](const c10::QualifiedName& qn) {
        importCallback(qn.prefix());
        return c10::StrongTypePtr(
            compilation_unit_, compilation_unit_->get_class(qn));
      });
  return unpickler.parse_ivalue_list();
}

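// Materializes a single tensor from its TensorDef. Each storage record is
// read once and cached in storageMap, so multiple tensors viewing the same
// storage remain shared after deserialization. For CUDA targets the bytes
// are staged in a CPU tensor and then copied to the destination device.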
at::Tensor ScriptModuleDeserializer::loadTensor(
    const torch::TensorDef& tensor_proto,
    std::unordered_map<std::string, at::Storage>& storageMap) {
  std::vector<int64_t> dims(
      tensor_proto.dims().begin(), tensor_proto.dims().end());
  std::vector<int64_t> strides(
      tensor_proto.strides().begin(), tensor_proto.strides().end());
  auto type = at::typeMetaToScalarType(
      caffe2::DataTypeToTypeMeta(tensor_proto.data_type()));
  if (tensor_proto.is_quantized()) {
    type = toQIntType(type);
  }
  const std::string& record_key = tensor_proto.data().key();
  AT_ASSERT(tensor_proto.has_device() && !tensor_proto.device().empty());
  at::Device device(tensor_proto.device());
  if (device_.has_value()) {
    // Override the serialized device if the user provided map_location.
    device = device_.value();
  }

  auto storage_it = storageMap.find(record_key);
  if (storage_it == storageMap.end()) {
    at::DataPtr storage_ptr;
    uint64_t record_size;
    std::tie(storage_ptr, record_size) = reader_->getRecord(record_key);
    auto cpu_storage = at::Storage(
        at::CPU(type).typeMeta(),
        record_size / at::CPU(type).typeMeta().itemsize(),
        std::move(storage_ptr),
        /*allocator=*/nullptr,
        /*resizable=*/false); // NB: no allocator is set for this storage
    if (device.type() == at::DeviceType::CPU) {
      storage_it =
          storageMap.insert(std::make_pair(record_key, cpu_storage)).first;
    } else if (device.type() == at::DeviceType::CUDA) {
      at::Tensor cpu_tensor =
          at::empty({0}, at::CPU(type).options()).set_(cpu_storage);
      at::Storage cuda_storage =
          cpu_tensor.to(device, cpu_tensor.scalar_type()).storage();
      storage_it =
          storageMap.insert(std::make_pair(record_key, cuda_storage)).first;
    } else {
      AT_ERROR(
          "supported devices include CPU and CUDA, but got ",
          at::DeviceTypeName(device.type(), false));
    }
  }
  if (storage_it->second.device().type() != device.type() ||
      (device.has_index() &&
       storage_it->second.device().index() != device.index())) {
    std::stringstream oss;
    oss << "storage previously was specified with device "
        << storage_it->second.device() << " but now is specified with device "
        << device << std::endl;
    AT_ERROR(oss.str());
  }

  at::Tensor result;

  if (device.type() == at::DeviceType::CPU) {
    if (tensor_proto.is_quantized()) {
      result = at::_empty_affine_quantized(
                   {0},
                   type,
                   tensor_proto.scale(),
                   tensor_proto.zero_point())
                   .set_(storage_it->second, tensor_proto.offset(), dims, strides);
    } else {
      result =
          at::empty({0}, at::CPU(type).options())
              .set_(storage_it->second, tensor_proto.offset(), dims, strides);
    }
  } else if (device.type() == at::DeviceType::CUDA) {
    result =
        at::empty(
            {0}, c10::TensorOptions(type).device(storage_it->second.device()))
            .set_(storage_it->second, tensor_proto.offset(), dims, strides);
  }
  AT_ASSERT(result.defined());

  result = autograd::make_variable(result, tensor_proto.requires_grad());

  return result;
}

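// Compiles the serialized TorchScript source for `qualifier`, at most once
// per qualifier (tracked in imported_libs_). Dependencies discovered during
// compilation are imported recursively through the same callback.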
void ScriptModuleDeserializer::importCallback(const std::string& qualifier) {
  if (imported_libs_.count(qualifier)) {
    return;
  }
  imported_libs_.insert(qualifier);
  std::function<void(const std::string&)> import_callback =
      [this](const std::string& qualifier) { importCallback(qualifier); };
  const std::string path = ImportExportHelpers::qualifierToPath(qualifier);
  at::DataPtr data;
  size_t size;
  std::tie(data, size) = reader_->getRecord(path);
  auto src = std::make_shared<Source>(
      std::string(static_cast<const char*>(data.get()), size), path, 0);
  script::import_libs(
      compilation_unit_, qualifier, src, tensor_table_, import_callback);
}

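// Runs the module's `__setstate__` on the unpickled state. Two schemas are
// accepted; as TorchScript, they would look roughly like this (illustrative
// only; `T` is whatever the matching `__getstate__` returned):
//
//   def __setstate__(self) -> None: ...             # num_inputs == 1
//   def __setstate__(self, state: T) -> None: ...   # num_inputs == 2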
void ScriptModuleDeserializer::moduleSetState(
    const script::Module& module,
    IValue state) {
  auto setstate = module.find_method("__setstate__");

  TORCH_CHECK(
      setstate,
      "Cannot call '__setstate__' method because"
      " it does not exist");

  // TODO: once modules are first class in the interpreter and methods are
  // not lowered, change this to `module->run_method("__setstate__", {state});`
  if (setstate->num_inputs() == 1) {
    setstate->run({module.module_object()});
  } else if (setstate->num_inputs() == 2) {
    setstate->run({module.module_object(), state});
  } else {
    AT_ERROR("Unexpected schema on '__setstate__'");
  }
}

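// Recursively converts a ModuleDef and its submodules into a script::Module.
// The qualified name is accumulated on moduleStack_; e.g. a ModuleDef named
// "sub.conv" pushes the atoms {"sub", "conv"} for the duration of the call.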
script::Module ScriptModuleDeserializer::convertModule(
    const torch::ModuleDef& module_def) {
  // HACK: the current model exporter can create module_defs whose names are
  // not valid Python identifiers (they may contain `.`), so split the name
  // into its atoms rather than using it verbatim.
  const auto atoms = c10::QualifiedName(module_def.name()).atoms();
  const size_t numPushed = atoms.size();
  for (const auto& atom : atoms) {
    moduleStack_.emplace_back(atom);
  }
  auto module =
      script::Module(c10::QualifiedName(moduleStack_), compilation_unit_);
  for (int i = 0; i < module_def.submodules_size(); ++i) {
    const torch::ModuleDef& sub_def = module_def.submodules(i);
    auto submodule = convertModule(sub_def);
    module.register_module(sub_def.name(), submodule);
  }
  for (int i = 0; i < module_def.parameters_size(); ++i) {
    const torch::ParameterDef& param_def = module_def.parameters(i);
    at::Tensor tensor = tensor_table_.at(param_def.tensor_id());
    if (param_def.is_buffer()) {
      module.register_buffer(param_def.name(), tensor);
    } else {
      module.register_parameter(param_def.name(), tensor, /*is_buffer=*/false);
    }
  }
  script::ScriptTypeParser typeParser(
      std::make_shared<ClassResolver>(compilation_unit_));
  for (int i = 0; i < module_def.attributes_size(); ++i) {
    const torch::AttributeDef& attr_def = module_def.attributes(i);
    if (module.find_buffer(attr_def.name())) {
      // TODO: handle this above so this can be removed
      continue;
    }

    IValue ivalue;
    if (attr_def.id() >= 0) {
      ivalue = pickled_ivalues_.at(attr_def.id());
    }
    // Otherwise the attribute has no value in the table and stays None for
    // now; after __setstate__ runs, we verify below that no non-Optional
    // attribute was left as None.

    module.register_attribute(
        attr_def.name(), typeParser.parseType(attr_def.type()), ivalue);
  }

  // If present, load the table of source ranges from the original
  // generating code.
  std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr;
  if (module_def.has_torchscript_debug_arena()) {
    at::DataPtr data;
    size_t size;
    std::tie(data, size) =
        reader_->getRecord(module_def.torchscript_debug_arena().key());

    gen_ranges =
        std::make_shared<ConcreteSourceRangeUnpickler>(std::move(data), size);
  }

  if (module_def.has_torchscript_arena()) {
    at::DataPtr data;
    size_t size;
    std::tie(data, size) =
        reader_->getRecord(module_def.torchscript_arena().key());
    auto src = std::make_shared<Source>(
        std::string(static_cast<const char*>(data.get()), size),
        module_def.torchscript_arena().key(),
        1,
        std::move(gen_ranges));

    std::function<void(const std::string&)> import_callback =
        [&, this](const std::string& qualifier) { importCallback(qualifier); };
    script::import_methods(module, src, tensor_table_, import_callback);
  }

  if (module_def.has_get_state_attribute_id()) {
    moduleSetState(
        module, pickled_ivalues_.at(module_def.get_state_attribute_id()));
  }

  for (const auto& slot : module.get_attributes()) {
    // Verify that all the non-optional attributes have been initialized.
    // TODO: Issue #20497
    if (slot.type()->kind() != TypeKind::OptionalType) {
      TORCH_CHECK(
          !slot.value().isNone(),
          "The field '",
          slot.name(),
          "' was left uninitialized after __setstate__, but expected a ",
          "value of type '",
          slot.type()->python_str(),
          "'");
    }
  }

  for (size_t i = 0; i < numPushed; i++) {
    moduleStack_.pop_back();
  }
  return module;
}

} // namespace

script::Module import_ir_module(
    std::shared_ptr<script::CompilationUnit> cu,
    std::istream& in,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  auto reader = torch::make_unique<PyTorchStreamReader>(&in);
  ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
  return deserializer.deserialize(device, extra_files);
}

script::Module import_ir_module(
    std::shared_ptr<script::CompilationUnit> cu,
    const std::string& filename,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  auto reader = torch::make_unique<PyTorchStreamReader>(filename);
  ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
  return deserializer.deserialize(device, extra_files);
}

script::Module import_ir_module(
    std::shared_ptr<script::CompilationUnit> cu,
    std::unique_ptr<ReadAdapterInterface> rai,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
  ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
  return deserializer.deserialize(device, extra_files);
}

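// Public load() entry points. Each one wraps its input in a read adapter
// and defers to ScriptModuleDeserializer. A minimal usage sketch (file and
// extra-file names are hypothetical):
//
//   script::ExtraFilesMap extra{{"metadata.json", ""}};
//   script::Module m = load("model.pt", at::kCPU, extra);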
script::Module load(
    std::istream& in,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  std::unique_ptr<IStreamAdapter> rai =
      caffe2::make_unique<IStreamAdapter>(&in);
  auto module = load(std::move(rai), device, extra_files);
  return module;
}

script::Module load(
    const std::string& filename,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  std::unique_ptr<FileAdapter> rai = caffe2::make_unique<FileAdapter>(filename);
  auto module = load(std::move(rai), device, extra_files);
  return module;
}

script::Module load(
    std::unique_ptr<ReadAdapterInterface> rai,
    c10::optional<c10::Device> device,
    script::ExtraFilesMap& extra_files) {
  auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
  auto cu = std::make_shared<script::CompilationUnit>();
  ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
  return deserializer.deserialize(device, extra_files);
}

} // namespace jit
} // namespace torch