Use the torch.proto to store script module (#13736)

Summary: Directly operate on protobuf in the serializer/deserializer.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/13736
Reviewed By: dzhulgakov
Differential Revision: D13028487
Pulled By: houseroad
fbshipit-source-id: e578474008874f00f2a22f0a2ffd85f52643881a
This commit is contained in:
parent 2871d3951f
commit e2a7d43dfd
@@ -146,6 +146,14 @@ inline int stoi(const string& str) {
   return n;
 }

+inline uint64_t stoull(const string& str) {
+  std::stringstream ss;
+  uint64_t n = 0;
+  ss << str;
+  ss >> n;
+  return n;
+}
+
 inline double stod(const string& str, std::size_t* pos = 0) {
   std::stringstream ss;
   ss << str;
@@ -164,6 +172,7 @@ inline double stod(const string& str, std::size_t* pos = 0) {
 #define CAFFE2_TESTONLY_WE_ARE_USING_CUSTOM_STRING_FUNCTIONS 0
 using std::to_string;
 using std::stoi;
+using std::stoull;
 using std::stod;
 #endif // defined(__ANDROID__) || defined(CAFFE2_FORCE_STD_STRING_FALLBACK_TEST)

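The new stoull fallback exists because std::stoull is unavailable on the Android toolchain this header targets. A minimal standalone sketch of the same stringstream round-trip, under an illustrative name (stoull_fallback is not the caffe2 symbol):

```cpp
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

// Fallback string-to-integer conversion for toolchains lacking std::stoull,
// mirroring the stringstream round-trip the hunk above adds.
inline uint64_t stoull_fallback(const std::string& str) {
  std::stringstream ss;
  uint64_t n = 0;
  ss << str;  // write the decimal text into the stream
  ss >> n;    // parse it back out as an unsigned 64-bit integer
  return n;
}

int main() {
  std::cout << stoull_fallback("18446744073709551615") << "\n";  // 2^64 - 1
  return 0;
}
```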
@@ -29,9 +29,11 @@ message ExternalDataProto {
   optional SourceType source_type = 1 [default = INLINE_CONTAINER];
   // used together with type
   optional string record_id = 2;
+  // the size of the entire record (in bytes)
+  optional uint64 record_size = 5;
   // the offset of the starting point, the content may be shared between
   // multiple tensors
-  optional int64 offset = 3;
+  optional int64 offset = 3 [default = 0];
   // the strides of the content
   repeated int64 strides = 4;
 }

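record_size, offset, and strides let several tensor views reference one shared storage record. A hedged sketch of that bookkeeping, with a plain struct standing in for ExternalDataProto:

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for ExternalDataProto: one storage record,
// referenced by several tensor views at different offsets/strides.
struct ExternalData {
  std::string record_id;         // which record in the container
  uint64_t record_size = 0;      // size of the whole record, in bytes
  int64_t offset = 0;            // element offset of this view into the storage
  std::vector<int64_t> strides;  // strides of this view
};

int main() {
  // One 4 KiB record holding 1024 floats, shared by two views.
  ExternalData full{"record-7", 4096, /*offset=*/0, {1}};
  ExternalData tail{"record-7", 4096, /*offset=*/512, {1}};
  std::cout << full.record_id << " @ offset " << full.offset << "\n";
  std::cout << tail.record_id << " @ offset " << tail.offset << "\n";
  return 0;
}
```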
@@ -13,6 +13,8 @@ message ParameterDef {
   // tensor type parameter
   optional caffe2.TensorProto tensor = 3;
   // objects other than tensors will be added here
+
+  optional string name = 4;
 }

 message MethodDef {
@@ -26,6 +28,8 @@ message MethodDef {
   // if both exist, we reconstruct the graph from torch_script
   optional caffe2.NetDef graph = 2;
   optional string torch_script = 3;
+  // temporary place to store the methods of jit script modules
+  optional bytes onnx_proto = 101;

   // inputs and outputs are inferred from graph or script
 }
@@ -50,6 +54,8 @@ message ModuleDef {

   // the names of inputs and outputs of the module are inferred
   // from the main method.
+
+  optional string name = 6;
 }

 enum ProtoVersion {
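Parameters of nested submodules are addressed by dot-joined names built while walking the module tree (see collectParamsInfo in the serializer below). A standalone sketch of that prefix recursion, with hypothetical Module/parameter types:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for script::Module and its named parameter slots.
struct Module {
  std::vector<std::pair<std::string, float*>> parameters;
  std::vector<std::pair<std::string, Module*>> submodules;
};

// Walk the module tree, recording each parameter slot under a
// dot-joined full name, e.g. "linear.weight".
void collectParams(const Module& m, const std::string& prefix,
                   std::map<std::string, float*>* out) {
  for (const auto& p : m.parameters)
    (*out)[prefix + p.first] = p.second;
  for (const auto& s : m.submodules)
    collectParams(*s.second, prefix + s.first + ".", out);
}

int main() {
  float w = 1.f, b = 2.f;
  Module inner{{{"weight", &w}}, {}};
  Module root{{{"bias", &b}}, {{"linear", &inner}}};
  std::map<std::string, float*> params;
  collectParams(root, "", &params);
  for (const auto& kv : params) std::cout << kv.first << "\n";  // bias, linear.weight
  return 0;
}
```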
@@ -88,10 +88,12 @@ constexpr uint64_t kFieldAlignment =
     64L; // 64 byte alignment supports up to AVX512 for mmap

 // Reader-specific constants
-constexpr uint64_t kMaxSupportedFileFormatVersion = 0x1L;
+// FileFormatVersion 1 was used in PyTorch 1.0 rc, which is a hacked ONNX proto.
+constexpr uint64_t kMinSupportedFileFormatVersion = 0x2L;
+constexpr uint64_t kMaxSupportedFileFormatVersion = 0x2L;

 // Writer-specific constants
-constexpr uint64_t kFileFormatVersion = 0x1L;
+constexpr uint64_t kFileFormatVersion = 0x2L;
 constexpr char kPadValue = -17; // 0xEF

 } // namespace
@@ -109,7 +111,8 @@ class PyTorchStreamReader final {
     AT_ASSERTM(
         file_size_ % kFieldAlignment == 0,
         "File length is not a multiple of the alignment"
-        " size. Is this a valid PyTorch model file?");
+        " size. Is this a valid PyTorch model file? File size: ",
+        caffe2::to_string(file_size_));
     readAndValidateFileHeader();
   }

@@ -203,6 +206,13 @@ class PyTorchStreamReader final {
         " be corrupted or is not actually a PyTorch file.");
     // magic number mismatch in PyTorch file.
     uint64_t file_format_version = read64BitIntegerLittleEndian();
+    AT_ASSERTM(
+        file_format_version >= kMinSupportedFileFormatVersion,
+        "Attempted to read a PyTorch file with version ",
+        caffe2::to_string(file_format_version),
+        ", but the minimum supported version for reading is ",
+        caffe2::to_string(kMinSupportedFileFormatVersion),
+        ". Your PyTorch script module file is too old. Please re-export it again.");
     AT_ASSERTM(
         file_format_version <= kMaxSupportedFileFormatVersion,
         "Attempted to read a PyTorch file with version ",
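With kMinSupportedFileFormatVersion and kMaxSupportedFileFormatVersion both at 0x2, the reader now rejects the hacked-ONNX version-1 files outright. A self-contained sketch of the version gate, including a portable little-endian 64-bit read (helper names are illustrative, not the inline_container API):

```cpp
#include <cstdint>
#include <iostream>
#include <istream>
#include <sstream>
#include <stdexcept>

constexpr uint64_t kMinSupportedFileFormatVersion = 0x2;
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x2;

// Portable little-endian read of a uint64, independent of host endianness.
uint64_t read64BitLittleEndian(std::istream& in) {
  unsigned char buf[8];
  in.read(reinterpret_cast<char*>(buf), 8);
  uint64_t v = 0;
  for (int i = 7; i >= 0; --i) v = (v << 8) | buf[i];
  return v;
}

void checkVersion(uint64_t v) {
  if (v < kMinSupportedFileFormatVersion)
    throw std::runtime_error("file is too old, please re-export it");
  if (v > kMaxSupportedFileFormatVersion)
    throw std::runtime_error("file is too new for this reader");
}

int main() {
  std::stringstream ss;
  unsigned char two[8] = {2, 0, 0, 0, 0, 0, 0, 0};  // version 2, little endian
  ss.write(reinterpret_cast<char*>(two), 8);
  uint64_t v = read64BitLittleEndian(ss);
  checkVersion(v);
  std::cout << "version " << v << " accepted\n";
  return 0;
}
```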
@@ -1,3 +1,6 @@
+#include <google/protobuf/util/json_util.h>
+#include <google/protobuf/util/type_resolver_util.h>
+
 #include "torch/csrc/jit/export.h"
 #include "torch/csrc/autograd/symbolic.h"
 #include "torch/csrc/onnx/onnx.h"
@@ -6,17 +9,21 @@
 #include <torch/csrc/jit/assertions.h>
 #include "torch/csrc/jit/passes/dead_code_elimination.h"

+#include "caffe2/core/types.h"
+#include "caffe2/proto/caffe2_pb.h"
+#include "caffe2/proto/torch_pb.h"
 #include "caffe2/serialize/inline_container.h"
 #include "onnx/onnx_pb.h"

 #include <ATen/ATen.h>
 #include "c10/util/Optional.h"

-#include <memory>
-#include <vector>
-#include <string>
-#include <sstream>
 #include <fstream>
+#include <memory>
+#include <sstream>
+#include <stack>
+#include <string>
+#include <vector>

 namespace torch { namespace jit {

@@ -426,29 +433,20 @@ void GraphEncoder::EncodeTensor(
   }
 }

-class ModuleEncoder: public EncoderBase {
+class MethodEncoder : public EncoderBase {
  public:
-  ModuleEncoder(const script::Module &module,
-                std::ostream& out);
+  MethodEncoder(
+      const script::Method& method,
+      std::string* torch_script,
+      std::unordered_map<const void*, uint64_t>* storage_map,
+      std::unordered_map<at::Tensor*, std::string>* parameter_map,
+      PyTorchStreamWriter* writer);

  private:
-  void EncodeModule(onnx::GraphProto *graph_proto, const script::Module &module);
-
-  void EncodeParameters(onnx::GraphProto *graph_proto,
-                        const script::Module &module,
-                        const std::string prefix);
-
-  void EncodeParameter(onnx::TensorProto *tensor_proto,
-                       const script::NamedParameter &parameter,
-                       const std::string prefix);
-
-  void EncodeMethods(onnx::GraphProto *graph_proto,
-                     const script::Module &module,
-                     const std::string prefix);
-
-  void EncodeMethod(onnx::NodeProto *node_proto,
-                    script::Method &method,
-                    const std::string prefix);
+  void EncodeMethod(
+      std::string* torch_script,
+      const script::Method& method,
+      const std::string prefix);

   void EncodeTensor(
       onnx::TensorProto* tensor_proto,
@@ -467,27 +465,35 @@ class ModuleEncoder: public EncoderBase {
       const TypePtr& type,
       const std::string& name);

-  PyTorchStreamWriter stream_writer_;
+  PyTorchStreamWriter* stream_writer_;
   // Used to deduplicate tensor storages
-  std::unordered_map<const void*, uint64_t> storage_dedup_map_;
+  std::unordered_map<const void*, uint64_t>* storage_dedup_map_;

   // Used to keep track of Parameter names so Methods can refer to them
-  std::unordered_map<at::Tensor*, std::string> parameter_map_;
+  std::unordered_map<at::Tensor*, std::string>* parameter_map_;

   // Used to create sequential dummy names for node types
   size_t type_counter_ = 0;
 };

-ModuleEncoder::ModuleEncoder(
-    const script::Module &module,
-    std::ostream& out)
-    : EncoderBase(onnx_torch::OperatorExportTypes::RAW, false),
-      stream_writer_(&out) {
-  model_proto_.set_doc_string("THIS PROTO IS NOT STANDARD ONNX");
-  EncodeModule(model_proto_.mutable_graph(), module);
+MethodEncoder::MethodEncoder(
+    const script::Method& method,
+    std::string* torch_script,
+    std::unordered_map<const void*, uint64_t>* storage_map,
+    std::unordered_map<at::Tensor*, std::string>* parameter_map,
+    PyTorchStreamWriter* writer)
+    : EncoderBase(onnx_torch::OperatorExportTypes::RAW, false) {
+  storage_dedup_map_ = storage_map;
+  parameter_map_ = parameter_map;
+  stream_writer_ = writer;
+  // we already keep the tree structure in the top level module,
+  // so pass "" as prefix
+  EncodeMethod(torch_script, method, "");
 }

-void ModuleEncoder::EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto, const Value *n) {
+void MethodEncoder::EncodeIntermediateValueInfo(
+    onnx::GraphProto* graph_proto,
+    const Value* n) {
   auto v = graph_proto->add_value_info();
   EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName());
 }

@@ -511,8 +517,8 @@ c10::optional<std::string> getBaseTypeDenotation(TypeKind& kind) {
   return c10::nullopt;
 }

-void ModuleEncoder::EncodeTypeInfo(
-    onnx::GraphProto *graph_proto,
+void MethodEncoder::EncodeTypeInfo(
+    onnx::GraphProto* graph_proto,
     onnx::ValueInfoProto* v,
     const TypePtr& type,
     const std::string& name) {
@@ -594,74 +600,20 @@ void ModuleEncoder::EncodeTypeInfo(
   }
 }

-void ModuleEncoder::EncodeValueInfo(
-    onnx::GraphProto *graph_proto,
+void MethodEncoder::EncodeValueInfo(
+    onnx::GraphProto* graph_proto,
     onnx::ValueInfoProto* v,
     const Value* n) {
   EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName());
 }

-void ModuleEncoder::EncodeModule(
-    onnx::GraphProto *graph_proto,
-    const script::Module &module) {
-  EncodeParameters(graph_proto, module, "");
-  EncodeMethods(graph_proto, module, "");
-  auto str = model_proto_.SerializeAsString();
-  stream_writer_.writeRecord(str.data(), str.size());
-}
-
-void ModuleEncoder::EncodeParameters(
-    onnx::GraphProto *graph_proto,
-    const script::Module &module,
-    const std::string prefix) {
-  // Encode each parameter as a initializer in the proto
-  for (auto &parameter : module.get_parameters()) {
-    auto tensor_proto = graph_proto->add_initializer();
-    EncodeParameter(tensor_proto, parameter.value(), prefix);
-  }
-
-  for (auto &submodule : module.get_modules()) {
-    EncodeParameters(graph_proto, *submodule->module, prefix + submodule.key() + ".");
-  }
-}
-
-void ModuleEncoder::EncodeParameter(
-    onnx::TensorProto *tensor_proto,
-    const script::NamedParameter &parameter,
-    const std::string prefix) {
-  auto tensor = parameter.slot();
-  // Name will be prefixed by submodule. e.g. submodule_foo.parameter_bar
-  auto name = prefix + parameter.name;
-
-  tensor_proto->set_name(name);
-  parameter_map_[tensor] = name;
-
-  // Parameters have these fields, but tensors do not
-  tensor_proto->add_int64_data(parameter.is_buffer);
-  tensor_proto->add_int64_data(tensor->requires_grad());
-
-  EncodeTensor(tensor_proto, *tensor, name);
-}
-
-void ModuleEncoder::EncodeMethods(
-    onnx::GraphProto *graph_proto,
-    const script::Module &module,
-    const std::string prefix) {
-  // Encode each parameter as a initializer in the proto
-  for (auto &method : module.get_methods()) {
-    auto node_proto = graph_proto->add_node();
-    EncodeMethod(node_proto, *method.value(), prefix);
-  }
-
-  for (auto &submodule : module.get_modules()) {
-    EncodeMethods(graph_proto, *submodule->module, prefix + submodule.key() + ".");
-  }
-}
-
-void ModuleEncoder::EncodeMethod(
-    onnx::NodeProto *node_proto,
-    script::Method &method,
+void MethodEncoder::EncodeMethod(
+    std::string* torch_script,
+    const script::Method& method,
     const std::string prefix) {
+  onnx::ModelProto model_proto;
+  model_proto.set_doc_string("THIS PROTO IS NOT STANDARD ONNX");
+  auto* node_proto = model_proto.mutable_graph()->add_node();
   node_proto->set_name(prefix + method.name());
   if (method.is_optimized()) {
     // mark that this method was optimized
@@ -673,8 +625,8 @@ void ModuleEncoder::EncodeMethod(

   // Store member_inputs of Method in input
   for (auto &member_input : method.params()) {
-    auto it = parameter_map_.find(member_input);
-    JIT_ASSERT(it != parameter_map_.end());
+    auto it = parameter_map_->find(member_input);
+    JIT_ASSERT(it != parameter_map_->end());
     node_proto->add_input(it->second);
   }

@@ -690,15 +642,16 @@ void ModuleEncoder::EncodeMethod(
     }
   }
   EncodeBlock(attr_proto->mutable_g(), method.graph()->block(), {});
+  AT_ASSERT(model_proto.SerializeToString(torch_script));
 }

-void ModuleEncoder::EncodeTensor(
+void MethodEncoder::EncodeTensor(
     onnx::TensorProto* tensor_proto,
     const at::Tensor& tensor,
     const c10::optional<std::string> external_ref) {
   auto storage_ptr = tensor.storage().unsafeGetStorageImpl();
-  auto dedup_it = storage_dedup_map_.find(storage_ptr);
-  if (dedup_it != storage_dedup_map_.end()) {
+  auto dedup_it = storage_dedup_map_->find(storage_ptr);
+  if (dedup_it != storage_dedup_map_->end()) {
     tensor_proto->add_int64_data(dedup_it->second);
   } else {
     at::Tensor t = tensor;
@@ -714,10 +667,10 @@ void ModuleEncoder::EncodeTensor(
           .cpu();
     }

-    auto record_number = stream_writer_.writeRecord(
+    auto record_number = stream_writer_->writeRecord(
        static_cast<char*>(t.storage().data()), t.type().elementSizeInBytes() * t.storage().size());
     tensor_proto->add_int64_data(record_number);
-    storage_dedup_map_[storage_ptr] = record_number;
+    (*storage_dedup_map_)[storage_ptr] = record_number;
   }

   for (auto &d : tensor.sizes()) {
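EncodeTensor writes each distinct storage exactly once: the map is keyed by the storage's address, so aliasing tensors reuse the record number instead of re-serializing the bytes. A standalone sketch of that dedup with a stub writer:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

// Stub writer: pretend each record gets a fresh, increasing id.
struct StubWriter {
  uint64_t next = 1;
  uint64_t writeRecord(const void* /*data*/, size_t /*size*/) { return next++; }
};

int main() {
  StubWriter writer;
  // Keyed by the storage's address, so aliasing tensors share one record.
  std::unordered_map<const void*, uint64_t> storage_dedup_map;

  std::vector<float> shared_storage(1024, 0.f);
  const void* key = shared_storage.data();

  for (int view = 0; view < 3; ++view) {
    auto it = storage_dedup_map.find(key);
    uint64_t record = (it != storage_dedup_map.end())
        ? it->second  // already written: just reference it
        : (storage_dedup_map[key] = writer.writeRecord(
              shared_storage.data(), shared_storage.size() * sizeof(float)));
    std::cout << "view " << view << " -> record " << record << "\n";  // always 1
  }
  return 0;
}
```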
@@ -899,6 +852,156 @@ void dump(const onnx::ModelProto& model, std::ostream& stream, size_t indent) {
   }
   stream << idt(indent) << "}\n";
 }
+
+class ScriptModuleSerializer final {
+ public:
+  ScriptModuleSerializer(const std::string& filename)
+      : ofs_(
+            filename,
+            std::ofstream::out | std::ofstream::trunc | std::ofstream::binary),
+        writer_(&ofs_) {
+    // TODO appropriate support for mmap, right now we still use stream writer
+  }
+  ScriptModuleSerializer(std::ostream* ofs) : ofs_(), writer_(ofs) {}
+  void serialize(const script::Module& module) {
+    torch::ModelDef model_def;
+    convertToModel(module, &model_def);
+    std::string output;
+    // NB: cannot use MessageToJsonString, since fbcode's protobuf is too old
+    // be consistent with MessageToJsonString
+    std::string url_prefix = "type.googleapis.com";
+    std::unique_ptr<::google::protobuf::util::TypeResolver> resolver(
+        ::google::protobuf::util::NewTypeResolverForDescriptorPool(
+            url_prefix, model_def.GetDescriptor()->file()->pool()));
+    ::google::protobuf::util::Status convert_result =
+        ::google::protobuf::util::BinaryToJsonString(
+            resolver.get(),
+            url_prefix + "/" + model_def.GetDescriptor()->full_name(),
+            model_def.SerializeAsString(),
+            &output);
+    if (!convert_result.ok()) {
+      std::stringstream ss;
+      ss << convert_result;
+      AT_ERROR(ss.str());
+    }
+    auto record_id = writer_.writeRecord(output.data(), output.size());
+    writer_.writeEndOfFile();
+  }
+
+ private:
+  void convertToModel(
+      const script::Module& module,
+      torch::ModelDef* model_def) {
+    model_def->set_name("script-model");
+    model_def->set_producer_name("pytorch");
+    model_def->set_producer_version("1.0"); // TODO: set the producer version
+                                            // using appropriate function call
+    model_def->set_proto_version(torch::ProtoVersion::PROTO_VERSION_NEWEST);
+    std::string main_module_name = "";
+    collectParamsInfo(module, main_module_name);
+    convertModule(module, main_module_name, model_def->mutable_main_module());
+  }
+  void collectParamsInfo(
+      const script::Module& module,
+      const std::string& prefix) {
+    for (const auto& elem : module.get_parameters()) {
+      const script::NamedParameter& param = elem.value();
+      parameterMap_[param.slot()] = prefix + param.name;
+    }
+    for (const auto& elem : module.get_modules()) {
+      collectParamsInfo(*elem->module, prefix + elem.key() + ".");
+    }
+  }
+  void convertModule(
+      const script::Module& module,
+      const std::string& name,
+      torch::ModuleDef* module_def) {
+    module_def->set_name(name);
+    for (const auto& elem : module.get_parameters()) {
+      torch::ParameterDef* param_def = module_def->add_parameters();
+      convertParameter(elem.value(), param_def);
+    }
+    for (auto& elem : module.get_methods()) {
+      torch::MethodDef* method_def = module_def->add_methods();
+      convertMethod(*elem.value(), method_def);
+    }
+    for (const auto& elem : module.get_modules()) {
+      torch::ModuleDef* sub_def = module_def->add_submodules();
+      convertModule(*elem->module, elem.key(), sub_def);
+    }
+  }
+  void convertParameter(
+      const script::NamedParameter& param,
+      torch::ParameterDef* param_def) {
+    param_def->set_name(param.name);
+    param_def->set_is_buffer(param.is_buffer);
+    param_def->set_require_gradient(param.slot()->requires_grad());
+    convertTensor(*(param.slot()), param_def->mutable_tensor());
+  }
+  void convertTensor(
+      const at::Tensor& tensor,
+      caffe2::TensorProto* tensor_proto) {
+    for (auto d : tensor.sizes()) {
+      tensor_proto->add_dims(d);
+    }
+    tensor_proto->set_data_type(caffe2::TypeMetaToDataType(
+        at::scalarTypeToTypeMeta(tensor.type().scalarType())));
+    tensor_proto->set_storage_type(caffe2::TensorProto_StorageType_EXTERNAL);
+    caffe2::ExternalDataProto* external_data =
+        tensor_proto->mutable_external_data();
+    for (auto s : tensor.strides()) {
+      external_data->add_strides(s);
+    }
+    external_data->set_offset(tensor.storage_offset());
+    uint64_t record_size =
+        tensor.type().elementSizeInBytes() * tensor.storage().size();
+    external_data->set_record_size(record_size);
+    auto* key = tensor.storage().unsafeGetStorageImpl();
+    auto it = storageMap_.find(key);
+    if (it == storageMap_.end()) {
+      // TODO HIP support
+      uint64_t record_id;
+      if (tensor.storage().device_type() == at::DeviceType::CUDA) {
+        // NB: This new tensor is created to support cuda tensors.
+        // Storages can be mutated when converting tensors from cuda to cpu,
+        // and we need a cpu tensor to copy data from.
+        at::Tensor t = at::getType(tensor)
+                           ._th_tensor(
+                               tensor.storage(),
+                               /* storageOffset = */ 0,
+                               /* size = */
+                               {static_cast<int64_t>(tensor.storage().size())},
+                               /* stride = */ {1})
+                           .cpu();
+        AT_ASSERT(
+            t.type().elementSizeInBytes() * t.storage().size() == record_size);
+        record_id = writer_.writeRecord(
+            t.storage().data(),
+            t.type().elementSizeInBytes() * t.storage().size());
+      } else {
+        record_id = writer_.writeRecord(tensor.storage().data(), record_size);
+      }
+      external_data->set_record_id(caffe2::to_string(record_id));
+      storageMap_[key] = record_id;
+    } else {
+      external_data->set_record_id(caffe2::to_string(it->second));
+    }
+    // TODO handle device case, set the device_detail and load to CUDA device
+  }
+  void convertMethod(script::Method& method, torch::MethodDef* method_def) {
+    std::string torch_script;
+    // TODO encode the real torch script instead of ModelProto
+    MethodEncoder encoder(
+        method, &torch_script, &storageMap_, &parameterMap_, &writer_);
+    method_def->set_onnx_proto(torch_script);
+  }
+  std::unordered_map<const void*, uint64_t>
+      storageMap_; // storage_ptr => record_offset
+  std::ofstream ofs_;
+  PyTorchStreamWriter writer_;
+  std::unordered_map<at::Tensor*, std::string> parameterMap_;
+};
+
 } // namespace

 std::string prettyPrint(const onnx::ModelProto& model) {
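Both the serializer above and the deserializer below transcode between binary protobuf and JSON through a TypeResolver, because the protobuf in fbcode predates MessageToJsonString/JsonStringToMessage. A self-contained round-trip sketch of the same transcoding, using an ordinary generated message (FileDescriptorProto stands in for torch::ModelDef):

```cpp
#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/util/type_resolver.h>
#include <google/protobuf/util/type_resolver_util.h>

#include <iostream>
#include <memory>
#include <string>

int main() {
  google::protobuf::FileDescriptorProto msg;  // stand-in for torch::ModelDef
  msg.set_name("example.proto");

  const std::string url_prefix = "type.googleapis.com";
  std::unique_ptr<google::protobuf::util::TypeResolver> resolver(
      google::protobuf::util::NewTypeResolverForDescriptorPool(
          url_prefix, msg.GetDescriptor()->file()->pool()));
  const std::string type_url =
      url_prefix + "/" + msg.GetDescriptor()->full_name();

  // Write side: binary wire format -> JSON, as in ScriptModuleSerializer.
  std::string json;
  if (!google::protobuf::util::BinaryToJsonString(
           resolver.get(), type_url, msg.SerializeAsString(), &json).ok())
    return 1;
  std::cout << json << "\n";  // {"name":"example.proto"}

  // Read side: JSON -> binary -> message, as in ScriptModuleDeserializer.
  std::string binary;
  if (!google::protobuf::util::JsonToBinaryString(
           resolver.get(), type_url, json, &binary).ok())
    return 1;
  google::protobuf::FileDescriptorProto roundtrip;
  return roundtrip.ParseFromString(binary) ? 0 : 1;
}
```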
@@ -941,13 +1044,13 @@ std::tuple<std::string, RawDataExportMap> ExportGraph(
 }

 void ExportModule(const script::Module& module, std::ostream& out) {
-  ModuleEncoder(module, out);
+  ScriptModuleSerializer serializer(&out);
+  serializer.serialize(module);
 }

 void ExportModule(const script::Module& module, const std::string &filename) {
-  std::ofstream out(filename, std::ios_base::binary);
-
-  ExportModule(module, out);
+  ScriptModuleSerializer serializer(filename);
+  serializer.serialize(module);
 }

 }}
@@ -1,9 +1,15 @@
+#include <google/protobuf/util/json_util.h>
+#include <google/protobuf/util/type_resolver_util.h>
+
 #include "torch/csrc/jit/import.h"
 #include "torch/csrc/jit/ir.h"
 #include "torch/csrc/utils/functional.h"
 #include "torch/csrc/jit/assertions.h"
 #include "torch/csrc/jit/operator.h"

+#include "caffe2/core/types.h"
+#include "caffe2/proto/caffe2_pb.h"
+#include "caffe2/proto/torch_pb.h"
 #include "caffe2/serialize/inline_container.h"
 #include "onnx/onnx_pb.h"

@@ -22,10 +28,14 @@ namespace onnx = ::ONNX_NAMESPACE;

 // IR graph construction

-class ModuleDecoder {
+class MethodDecoder {
  public:
-  ModuleDecoder(ModuleLookup module_lookup,
-                std::istream& in);
+  MethodDecoder(
+      const onnx::ModelProto& model_proto,
+      const std::unordered_map<std::string, at::Tensor*>& param_map,
+      script::Module* parent_module,
+      std::unordered_map<uint64_t, std::shared_ptr<at::Storage>>* storage_map,
+      PyTorchStreamReader* stream_reader_);

  private:
   std::shared_ptr<Graph> buildGraph(const onnx::GraphProto& graph_proto);
@@ -46,8 +56,6 @@ class ModuleDecoder {

   TypePtr buildType(const onnx::TypeProto& type_proto);

-  at::Tensor buildParameter(const onnx::TensorProto& tensor_proto);
-
   at::Tensor buildTensorCommon(const onnx::TensorProto& tensor_proto,
                                const uint64_t record_number,
                                const int64_t storage_offset,
@@ -57,12 +65,13 @@ class ModuleDecoder {
       ModuleLookup module_lookup,
       const std::string fullname);

-  PyTorchStreamReader stream_reader_;
-  std::unordered_map<uint64_t, std::shared_ptr<at::Storage>> storage_map_;
+  PyTorchStreamReader* stream_reader_;
+  std::unordered_map<uint64_t, std::shared_ptr<at::Storage>>* storage_map_;
   std::unordered_map<std::string, const onnx::TypeProto*> value_type_map_;
 };

-at::ScalarType ModuleDecoder::onnxTypeToATenType(onnx::TensorProto_DataType onnx_type) {
+at::ScalarType MethodDecoder::onnxTypeToATenType(
+    onnx::TensorProto_DataType onnx_type) {
   switch(onnx_type) {
     case onnx::TensorProto_DataType_UINT8:
       return at::kByte;
@@ -85,8 +94,9 @@ at::ScalarType ModuleDecoder::onnxTypeToATenType(onnx::TensorProto_DataType onnx
   }
 }

-void ModuleDecoder::buildBlocks(
-    const std::vector<onnx::GraphProto>& graphs_, Node* node,
+void MethodDecoder::buildBlocks(
+    const std::vector<onnx::GraphProto>& graphs_,
+    Node* node,
     std::unordered_map<std::string, Value*>& value_map) {
   for (auto g_ : graphs_) {
     auto block = node->addBlock();
@@ -94,7 +104,8 @@ void MethodDecoder::buildBlocks(
   }
 }

-std::shared_ptr<Graph> ModuleDecoder::buildGraph(const onnx::GraphProto& graph_proto) {
+std::shared_ptr<Graph> MethodDecoder::buildGraph(
+    const onnx::GraphProto& graph_proto) {
   auto graph = std::make_shared<Graph>();
   std::unordered_map<std::string, Value*> value_map;

@@ -103,8 +114,10 @@ std::shared_ptr<Graph> MethodDecoder::buildGraph(const onnx::GraphProto& graph_p
   return graph;
 }

-void ModuleDecoder::buildBlock(const onnx::GraphProto& graph_proto, Block* block,
-                               std::unordered_map<std::string, Value*>& value_map) {
+void MethodDecoder::buildBlock(
+    const onnx::GraphProto& graph_proto,
+    Block* block,
+    std::unordered_map<std::string, Value*>& value_map) {
   for (auto &subtype : graph_proto.value_info()) {
     value_type_map_[subtype.name()] = &subtype.type();
   }
@@ -190,7 +203,7 @@ void MethodDecoder::buildBlock(const onnx::GraphProto& graph_proto, Block* block
   }
 }

-TypePtr ModuleDecoder::buildType(const onnx::TypeProto& type_proto) {
+TypePtr MethodDecoder::buildType(const onnx::TypeProto& type_proto) {
   auto tensortype_proto = type_proto.tensor_type();
   auto shape_proto = tensortype_proto.shape();
   auto kind = type_proto.denotation();
@@ -248,29 +261,21 @@ TypePtr MethodDecoder::buildType(const onnx::TypeProto& type_proto) {
   }
 }

-void ModuleDecoder::buildValue(Value* value, const onnx::ValueInfoProto& valueinfo_proto) {
+void MethodDecoder::buildValue(
+    Value* value,
+    const onnx::ValueInfoProto& valueinfo_proto) {
   value->setType(buildType(valueinfo_proto.type()));
 }

-void ModuleDecoder::buildIntermediateValue(Value* value, const std::string& name) {
+void MethodDecoder::buildIntermediateValue(
+    Value* value,
+    const std::string& name) {
   auto it = value_type_map_.find(name);
   JIT_ASSERT(it != value_type_map_.end());
   value->setType(buildType(*it->second));
 }

-at::Tensor ModuleDecoder::buildParameter(const onnx::TensorProto& tensor_proto) {
-  std::vector<int64_t> strides;
-  // We've stored four other values (is_buffer, requires_grad, record no., storage_offset) before strides; ignore them
-  std::move(tensor_proto.int64_data().begin() + 4, tensor_proto.int64_data().end(), std::back_inserter(strides));
-  auto tensor = buildTensorCommon(tensor_proto,
-                                  /* record_number = */ tensor_proto.int64_data(2),
-                                  /* storage_offset = */ tensor_proto.int64_data(3),
-                                  strides);
-  autograd::Variable var = autograd::make_variable(tensor, /* requires_grad = */ tensor_proto.int64_data(1));
-  return var;
-}
-
-at::Tensor ModuleDecoder::buildTensor(const onnx::TensorProto& tensor_proto) {
+at::Tensor MethodDecoder::buildTensor(const onnx::TensorProto& tensor_proto) {
   std::vector<int64_t> strides;
   // We've stored two other values (record no., storage_offset) before strides; ignore it
   std::move(tensor_proto.int64_data().begin() + 2, tensor_proto.int64_data().end(), std::back_inserter(strides));
@@ -280,7 +285,7 @@ at::Tensor ModuleDecoder::buildTensor(const onnx::TensorProto& tensor_proto) {
       strides);
 }

-at::Tensor ModuleDecoder::buildTensorCommon(
+at::Tensor MethodDecoder::buildTensorCommon(
     const onnx::TensorProto& tensor_proto,
     const uint64_t record_number,
     const int64_t storage_offset,
@@ -292,17 +297,18 @@ at::Tensor MethodDecoder::buildTensorCommon(
   std::move(tensor_proto.dims().begin(), tensor_proto.dims().end(), std::back_inserter(dims));

   // Find or create the storage
-  auto storage_it = storage_map_.find(record_number);
-  if (storage_it == storage_map_.end()) {
+  auto storage_it = storage_map_->find(record_number);
+  if (storage_it == storage_map_->end()) {
     at::DataPtr storage_ptr;
     int64_t size;
-    std::tie(storage_ptr, size) = stream_reader_.getRecordWithKey(record_number);
+    std::tie(storage_ptr, size) =
+        stream_reader_->getRecordWithKey(record_number);
     auto storage = std::make_shared<at::Storage>(
         at::CPU(type).typeMeta(),
         std::move(storage_ptr),
         size / at::CPU(type).typeMeta().itemsize(),
         nullptr);
-    storage_map_.insert(std::make_pair(record_number, storage));
+    storage_map_->insert(std::make_pair(record_number, storage));
     return at::CPU(type)._th_tensor(*storage, storage_offset, dims, strides);
   }

@@ -312,9 +318,8 @@ at::Tensor MethodDecoder::buildTensorCommon(

 // Given a full name of a parameter or method,
 // return the parent submodule and local name
-std::pair<std::shared_ptr<script::Module>, std::string> ModuleDecoder::parseFullName(
-    ModuleLookup module_lookup,
-    const std::string fullname) {
+std::pair<std::shared_ptr<script::Module>, std::string> MethodDecoder::
+    parseFullName(ModuleLookup module_lookup, const std::string fullname) {
   AT_ASSERT(!fullname.empty());
   std::vector<std::string> vec;
   std::stringstream ss(fullname);
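parseFullName turns a dotted full name such as "encoder.layer0.weight" into the owning submodule path plus the local name. A standalone sketch of the split:

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Split "sub1.sub2.weight" into the submodule path {"sub1", "sub2"} and the
// local name "weight", the shape parseFullName feeds to module_lookup.
std::pair<std::vector<std::string>, std::string> splitFullName(
    const std::string& fullname) {
  std::vector<std::string> path;
  std::stringstream ss(fullname);
  std::string piece;
  while (std::getline(ss, piece, '.')) path.push_back(piece);
  std::string local = path.back();
  path.pop_back();
  return {path, local};
}

int main() {
  auto result = splitFullName("encoder.layer0.weight");
  for (const auto& p : result.first) std::cout << p << "/";
  std::cout << " -> " << result.second << "\n";  // encoder/layer0/ -> weight
  return 0;
}
```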
@@ -328,37 +333,23 @@ std::pair<std::shared_ptr<script::Module>, std::string> ModuleDecoder::parseFull
   return std::make_pair(module_lookup(vec), std::move(last));
 }

-ModuleDecoder::ModuleDecoder(
-    ModuleLookup module_lookup,
-    std::istream& in) :
-    stream_reader_(&in) {
-  auto model_proto = onnx::ModelProto();
-  auto record = stream_reader_.getLastRecord();
-  model_proto.ParsePartialFromArray(std::get<0>(record).get(), std::get<1>(record));
-  auto graph_proto = model_proto.graph();
-
-  std::unordered_map<std::string, at::Tensor*> param_map;
-
-  for (auto &tensor_proto : graph_proto.initializer()) {
-    std::shared_ptr<script::Module> parent_module;
-    std::string name;
-    std::tie(parent_module, name) = parseFullName(module_lookup, tensor_proto.name());
-
-    auto param = buildParameter(tensor_proto);
-    parent_module->register_parameter(name, param, /* is_buffer = */ tensor_proto.int64_data(0));
-    param_map[tensor_proto.name()] = parent_module->parameter_slot(name);
-  }
-
-  for (auto &node_proto : graph_proto.node()) {
-    std::shared_ptr<script::Module> parent_module;
-    std::string name;
-    std::tie(parent_module, name) = parseFullName(module_lookup, node_proto.name());
-
+MethodDecoder::MethodDecoder(
+    const onnx::ModelProto& model_proto,
+    const std::unordered_map<std::string, at::Tensor*>& param_map,
+    script::Module* parent_module,
+    std::unordered_map<uint64_t, std::shared_ptr<at::Storage>>* storage_map,
+    PyTorchStreamReader* stream_reader) {
+  storage_map_ = storage_map;
+  stream_reader_ = stream_reader;
+  const auto& graph_proto = model_proto.graph();
+  for (const auto& node_proto : graph_proto.node()) {
     std::vector<at::Tensor*> member_inputs;
-    for (auto &param_name : node_proto.input()) {
-      member_inputs.push_back(param_map[param_name]);
+    const std::string& name = node_proto.name();
+    for (const auto& param_name : node_proto.input()) {
+      auto it = param_map.find(param_name);
+      AT_ASSERTM(it != param_map.end(), "cannot find parameter ", param_name);
+      member_inputs.push_back(it->second);
     }

     auto graph = buildGraph(node_proto.attribute(0).g());
     // has_domain field has a string iff the method was optimized
     parent_module->set_optimized(node_proto.has_domain());
@@ -370,22 +361,159 @@ ModuleDecoder::ModuleDecoder(
   }
 }

+// this is a deserializer class which loads script modules from pt files. the
+// content of the file is written using PyTorchStreamWriter, for details please
+// check caffe2/serialize/inline_container.h. all the records except the last
+// one are tensor data, and the last record is a serialized ModelProto, defined
+// in caffe2/proto/torch.proto. ModelProto contains all the metadata of the
+// model, and it is serialized as json.
+class ScriptModuleDeserializer final {
+ public:
+  ScriptModuleDeserializer(const std::string& filename)
+      : ifs_(filename, std::ifstream::in | std::ifstream::binary),
+        reader_(&ifs_) {
+    // TODO appropriate support for mmap, right now still use stream reader
+  }
+  ScriptModuleDeserializer(std::istream* is) : ifs_(), reader_(is) {}
+  void deserialize(ModuleLookup module_lookup) {
+    torch::ModelDef model_def;
+    at::DataPtr data_ptr;
+    size_t data_size;
+    std::tie(data_ptr, data_size) = reader_.getLastRecord();
+    // NB: cannot use JsonStringToMessage, since fbcode's protobuf is too old
+    // be consistent with JsonStringToMessage
+    std::string url_prefix = "type.googleapis.com";
+    std::unique_ptr<::google::protobuf::util::TypeResolver> resolver(
+        ::google::protobuf::util::NewTypeResolverForDescriptorPool(
+            url_prefix, model_def.GetDescriptor()->file()->pool()));
+    std::string json_string = std::string(
+        static_cast<char*>(data_ptr.get()),
+        static_cast<char*>(data_ptr.get()) + data_size);
+    std::string binary_string;
+    auto convert_result = ::google::protobuf::util::JsonToBinaryString(
+        resolver.get(),
+        url_prefix + "/" + model_def.GetDescriptor()->full_name(),
+        json_string,
+        &binary_string);
+    if (!convert_result.ok()) {
+      std::stringstream ss;
+      ss << convert_result;
+      AT_ERROR(ss.str());
+    }
+    AT_ASSERTM(
+        model_def.ParseFromString(binary_string),
+        "JSON transcoder produced invalid protobuf output.");
+    moduleLookup_ = module_lookup;
+    const auto& module_def = model_def.main_module();
+    collectParamsInfo(module_def, module_def.name());
+    // TODO: this can be simplified when C++/Python interop lands,
+    // and the submodules would be created as the same in either C++ or Python
+    std::shared_ptr<script::Module> module = moduleLookup_(moduleStack_);
+    convertModule(module_def, module.get());
+  }
+
+ private:
+  void collectParamsInfo(
+      const torch::ModuleDef& module_def,
+      const std::string& prefix) {
+    std::shared_ptr<script::Module> module = moduleLookup_(moduleStack_);
+    for (int i = 0; i < module_def.parameters_size(); ++i) {
+      const torch::ParameterDef& param_def = module_def.parameters(i);
+      at::Tensor tensor = createTensor(param_def.tensor());
+      autograd::Variable variable =
+          autograd::make_variable(tensor, param_def.require_gradient());
+      module->register_parameter(
+          param_def.name(), variable, param_def.is_buffer());
+      parameterMap_[prefix + param_def.name()] =
+          module->parameter_slot(param_def.name());
+    }
+    for (int i = 0; i < module_def.submodules_size(); ++i) {
+      const torch::ModuleDef& sub_def = module_def.submodules(i);
+      moduleStack_.push_back(sub_def.name());
+      collectParamsInfo(sub_def, prefix + sub_def.name() + ".");
+      moduleStack_.pop_back();
+    }
+  }
+  void convertModule(
+      const torch::ModuleDef& module_def,
+      script::Module* module) {
+    for (int i = 0; i < module_def.methods_size(); ++i) {
+      const torch::MethodDef& method_def = module_def.methods(i);
+      // TODO read unhacked torch script, right now it's serialized onnx proto
+      ::ONNX_NAMESPACE::ModelProto method_proto;
+      AT_ASSERTM(
+          method_proto.ParseFromString(method_def.onnx_proto()),
+          "cannot parse method proto (i.e., hacked onnx proto)");
+      MethodDecoder decoder(
+          method_proto, parameterMap_, module, &storageMap_, &reader_);
+      (void)decoder;
+    }
+    for (int i = 0; i < module_def.submodules_size(); ++i) {
+      const torch::ModuleDef& sub_def = module_def.submodules(i);
+      moduleStack_.push_back(sub_def.name());
+      std::shared_ptr<script::Module> sub = moduleLookup_(moduleStack_);
+      convertModule(sub_def, sub.get());
+      moduleStack_.pop_back();
+    }
+  }
+  at::Tensor createTensor(const caffe2::TensorProto& tensor_proto) {
+    std::vector<int64_t> dims;
+    for (int i = 0; i < tensor_proto.dims_size(); ++i) {
+      dims.push_back(tensor_proto.dims(i));
+    }
+    AT_ASSERT(
+        tensor_proto.storage_type() ==
+        caffe2::TensorProto_StorageType_EXTERNAL);
+    const caffe2::ExternalDataProto& external_data =
+        tensor_proto.external_data();
+    std::vector<int64_t> strides;
+    for (int i = 0; i < external_data.strides_size(); ++i) {
+      strides.push_back(external_data.strides(i));
+    }
+    auto type = at::typeMetaToScalarType(
+        caffe2::DataTypeToTypeMeta(tensor_proto.data_type()));
+    uint64_t record_id = caffe2::stoull(external_data.record_id());
+    AT_ASSERT(record_id != 0);
+    auto storage_it = storageMap_.find(record_id);
+    if (storage_it == storageMap_.end()) {
+      at::DataPtr storage_ptr;
+      uint64_t record_size;
+      std::tie(storage_ptr, record_size) = reader_.getRecordWithKey(record_id);
+      AT_ASSERT(record_size == external_data.record_size());
+      auto storage = std::make_shared<at::Storage>(
+          at::CPU(type).typeMeta(),
+          std::move(storage_ptr),
+          record_size / at::CPU(type).typeMeta().itemsize(),
+          nullptr); // NB: we didn't set any allocator for the tensor
+      storageMap_.insert(std::make_pair(record_id, storage));
+      return at::CPU(type)._th_tensor(
+          *storage, external_data.offset(), dims, strides);
+    }
+    return at::CPU(type)._th_tensor(
+        *(storage_it->second.get()), external_data.offset(), dims, strides);
+  }
+  std::ifstream ifs_;
+  PyTorchStreamReader reader_;
+  ModuleLookup moduleLookup_;
+  std::vector<std::string> moduleStack_;
+  std::unordered_map<uint64_t, std::shared_ptr<at::Storage>> storageMap_;
+  std::unordered_map<std::string, at::Tensor*> parameterMap_;
+};
+
 } // namespace

 void import_ir_module(
     ModuleLookup module_lookup,
     std::istream& in) {
-  ModuleDecoder decoder(module_lookup, in);
-  (void)decoder;
+  ScriptModuleDeserializer deserializer(&in);
+  deserializer.deserialize(module_lookup);
 }

 void import_ir_module(
     ModuleLookup module_lookup,
     const std::string& filename) {
-  std::ifstream in(filename, std::ios_base::binary);
-
-  ModuleDecoder decoder(module_lookup, in);
-  (void)decoder;
+  ScriptModuleDeserializer deserializer(filename);
+  deserializer.deserialize(module_lookup);
 }

 std::shared_ptr<script::Module> load(std::istream& in) {
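createTensor and buildTensorCommon materialize each record lazily and cache it, so every tensor referencing the same record_id shares one storage. A standalone sketch of that lazy, shared load with a stub reader:

```cpp
#include <cstdint>
#include <iostream>
#include <memory>
#include <unordered_map>
#include <vector>

// Stub for "read this record's bytes from the container by id".
std::shared_ptr<std::vector<char>> readRecord(uint64_t id) {
  std::cout << "loading record " << id << " from disk\n";
  return std::make_shared<std::vector<char>>(4096);
}

int main() {
  // record id -> already-materialized storage, mirroring storageMap_:
  // each record is read once, then shared by every tensor that uses it.
  std::unordered_map<uint64_t, std::shared_ptr<std::vector<char>>> storage_map;

  for (int tensor = 0; tensor < 3; ++tensor) {
    uint64_t record_id = 7;  // all three tensors reference the same record
    auto it = storage_map.find(record_id);
    if (it == storage_map.end())
      it = storage_map.emplace(record_id, readRecord(record_id)).first;
    std::cout << "tensor " << tensor << " uses storage at "
              << static_cast<const void*>(it->second->data()) << "\n";
  }
  return 0;  // "loading record 7" is printed exactly once
}
```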
@@ -402,8 +530,8 @@ std::shared_ptr<script::Module> load(std::istream& in) {
     return curr;
   };

-  ModuleDecoder decoder(module_lookup, in);
-  (void)decoder;
+  ScriptModuleDeserializer deserializer(&in);
+  deserializer.deserialize(module_lookup);

   return module;
 }
@@ -151,7 +151,7 @@ struct Method {
     return retval;
   }

-  std::vector<at::Tensor*> params() {
+  std::vector<at::Tensor*> params() const {
     return member_inputs;
   }

@@ -182,7 +182,7 @@ struct Method {
     return get_executor().debugDisableAutodiffSubgraphInlining();
   }

-  bool is_optimized() {
+  bool is_optimized() const {
     return optimize;
   }
