Use the torch.proto to store script module (#13736)

Summary:
Operate on protobuf messages directly in the serializer/deserializer.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13736

Reviewed By: dzhulgakov

Differential Revision: D13028487

Pulled By: houseroad

fbshipit-source-id: e578474008874f00f2a22f0a2ffd85f52643881a
Authored by Lu Fang on 2018-11-14 00:19:08 -08:00; committed by Facebook Github Bot
parent 2871d3951f
commit e2a7d43dfd
7 changed files with 448 additions and 190 deletions
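
For orientation, here is a minimal round-trip sketch of the paths this commit changes; the file name, the function wrapper, and the module handle are illustrative, not part of the PR.

#include <fstream>
#include <memory>
#include "torch/csrc/jit/export.h"
#include "torch/csrc/jit/import.h"

void round_trip_example(const torch::jit::script::Module& module) {
  // Export: tensor storages are written as individual records, followed by a
  // JSON-encoded torch::ModelDef as the last record of the container.
  torch::jit::ExportModule(module, "model.pt");

  // Import: read the last record, transcode the JSON back into a ModelDef,
  // then rebuild parameters and methods module by module.
  std::ifstream in("model.pt", std::ios_base::binary);
  std::shared_ptr<torch::jit::script::Module> loaded = torch::jit::load(in);
  (void)loaded; // use the loaded module here
}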

View File

@ -146,6 +146,14 @@ inline int stoi(const string& str) {
return n;
}
inline uint64_t stoull(const string& str) {
std::stringstream ss;
uint64_t n = 0;
ss << str;
ss >> n;
return n;
}
inline double stod(const string& str, std::size_t* pos = 0) {
std::stringstream ss;
ss << str;
@ -164,6 +172,7 @@ inline double stod(const string& str, std::size_t* pos = 0) {
#define CAFFE2_TESTONLY_WE_ARE_USING_CUSTOM_STRING_FUNCTIONS 0
using std::to_string;
using std::stoi;
using std::stoull;
using std::stod;
#endif // defined(__ANDROID__) || defined(CAFFE2_FORCE_STD_STRING_FALLBACK_TEST)
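
One consumer of the new stoull fallback is the import path further below: ExternalDataProto.record_id is stored as a decimal string and converted back to an integer record key. A hedged sketch, assuming the fallback header above is in scope; the helper name and the example value are invented.

#include <cstdint>
#include <string>

uint64_t parse_record_id(const std::string& record_id_str) {
  // record_id values are produced by PyTorchStreamWriter::writeRecord and
  // round-tripped through a string field, e.g. "42" -> 42.
  return caffe2::stoull(record_id_str);
}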

View File

@ -29,9 +29,11 @@ message ExternalDataProto {
optional SourceType source_type = 1 [default = INLINE_CONTAINER];
// used together with type
optional string record_id = 2;
// the size of the entire record (in bytes)
optional uint64 record_size = 5;
// the offset of the starting point; the content may be shared between
// multiple tensors
optional int64 offset = 3;
optional int64 offset = 3 [default = 0];
// the strides of the content
repeated int64 strides = 4;
}
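
For illustration, this is roughly how the serializer in export.cpp (below) fills these fields for a 2x3 float tensor whose storage went into container record 3; the concrete numbers are invented.

#include "caffe2/proto/caffe2_pb.h"

void fill_external_data_example(caffe2::ExternalDataProto* external_data) {
  // source_type keeps its INLINE_CONTAINER default: the bytes live in the
  // same PyTorch container file as the serialized model.
  external_data->set_record_id("3");                      // key returned by writeRecord
  external_data->set_record_size(2 * 3 * sizeof(float));  // entire record, in bytes
  external_data->set_offset(0);                           // tensor starts at the front of its storage
  external_data->add_strides(3);                          // row-major strides of a 2x3 tensor
  external_data->add_strides(1);
}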

View File

@ -13,6 +13,8 @@ message ParameterDef {
// tensor type parameter
optional caffe2.TensorProto tensor = 3;
// objects other than tensors will be added here
optional string name = 4;
}
message MethodDef {
@ -26,6 +28,8 @@ message MethodDef {
// if both exist, we reconstruct the graph from torch_script
optional caffe2.NetDef graph = 2;
optional string torch_script = 3;
// temporary place to store the methods of jit script modules
optional bytes onnx_proto = 101;
// inputs and outputs are inferred from graph or script
}
@ -50,6 +54,8 @@ message ModuleDef {
// the names of inputs and outputs of the module are inferred
// from the main method.
optional string name = 6;
}
enum ProtoVersion {
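
To show how the new messages nest, a hedged sketch mirroring what ScriptModuleSerializer in export.cpp (below) produces; the parameter name and the serialized method bytes are illustrative.

#include <string>
#include "caffe2/proto/torch_pb.h"

void build_model_def_example(const std::string& serialized_method_proto) {
  torch::ModelDef model_def;
  model_def.set_name("script-model");
  model_def.set_producer_name("pytorch");
  torch::ModuleDef* main_module = model_def.mutable_main_module();
  main_module->set_name("");                 // top-level module
  torch::ParameterDef* param = main_module->add_parameters();
  param->set_name("weight");                 // illustrative parameter
  param->set_is_buffer(false);
  torch::MethodDef* method = main_module->add_methods();
  // Methods are (for now) carried as a serialized, hacked ONNX ModelProto.
  method->set_onnx_proto(serialized_method_proto);
}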

View File

@ -88,10 +88,12 @@ constexpr uint64_t kFieldAlignment =
64L; // 64 byte alignment supports up to AVX512 for mmap
// Reader-specific constants
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x1L;
// FileFormatVersion 1 was used in the PyTorch 1.0 RC, which stored a hacked ONNX proto.
constexpr uint64_t kMinSupportedFileFormatVersion = 0x2L;
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x2L;
// Writer-specific constants
constexpr uint64_t kFileFormatVersion = 0x1L;
constexpr uint64_t kFileFormatVersion = 0x2L;
constexpr char kPadValue = -17; // 0xEF
} // namespace
@ -109,7 +111,8 @@ class PyTorchStreamReader final {
AT_ASSERTM(
file_size_ % kFieldAlignment == 0,
"File length is not a multiple of the alignment"
" size. Is this a valid PyTorch model file?");
" size. Is this a valid PyTorch model file? File size: ",
caffe2::to_string(file_size_));
readAndValidateFileHeader();
}
@ -203,6 +206,13 @@ class PyTorchStreamReader final {
" be corrupted or is not actually a PyTorch file.");
// magic number mismatch in PyTorch file.
uint64_t file_format_version = read64BitIntegerLittleEndian();
AT_ASSERTM(
file_format_version >= kMinSupportedFileFormatVersion,
"Attempted to read a PyTorch file with version ",
caffe2::to_string(file_format_version),
", but the minimum supported version for reading is ",
caffe2::to_string(kMinSupportedFileFormatVersion),
". Your PyTorch script module file is too old. Please re-export it again.");
AT_ASSERTM(
file_format_version <= kMaxSupportedFileFormatVersion,
"Attempted to read a PyTorch file with version ",

View File

@ -1,3 +1,6 @@
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/util/type_resolver_util.h>
#include "torch/csrc/jit/export.h"
#include "torch/csrc/autograd/symbolic.h"
#include "torch/csrc/onnx/onnx.h"
@ -6,17 +9,21 @@
#include <torch/csrc/jit/assertions.h>
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "caffe2/core/types.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/proto/torch_pb.h"
#include "caffe2/serialize/inline_container.h"
#include "onnx/onnx_pb.h"
#include <ATen/ATen.h>
#include "c10/util/Optional.h"
#include <memory>
#include <vector>
#include <string>
#include <sstream>
#include <fstream>
#include <memory>
#include <sstream>
#include <stack>
#include <string>
#include <vector>
namespace torch { namespace jit {
@ -426,29 +433,20 @@ void GraphEncoder::EncodeTensor(
}
}
class ModuleEncoder: public EncoderBase {
class MethodEncoder : public EncoderBase {
public:
ModuleEncoder(const script::Module &module,
std::ostream& out);
MethodEncoder(
const script::Method& method,
std::string* torch_script,
std::unordered_map<const void*, uint64_t>* storage_map,
std::unordered_map<at::Tensor*, std::string>* parameter_map,
PyTorchStreamWriter* writer);
private:
void EncodeModule(onnx::GraphProto *graph_proto, const script::Module &module);
void EncodeParameters(onnx::GraphProto *graph_proto,
const script::Module &module,
const std::string prefix);
void EncodeParameter(onnx::TensorProto *tensor_proto,
const script::NamedParameter &parameter,
const std::string prefix);
void EncodeMethods(onnx::GraphProto *graph_proto,
const script::Module &module,
const std::string prefix);
void EncodeMethod(onnx::NodeProto *node_proto,
script::Method &method,
const std::string prefix);
void EncodeMethod(
std::string* torch_script,
const script::Method& method,
const std::string prefix);
void EncodeTensor(
onnx::TensorProto* tensor_proto,
@ -467,27 +465,35 @@ class ModuleEncoder: public EncoderBase {
const TypePtr& type,
const std::string& name);
PyTorchStreamWriter stream_writer_;
PyTorchStreamWriter* stream_writer_;
// Used to deduplicate tensor storages
std::unordered_map<const void*, uint64_t> storage_dedup_map_;
std::unordered_map<const void*, uint64_t>* storage_dedup_map_;
// Used to keep track of Parameter names so Methods can refer to them
std::unordered_map<at::Tensor*, std::string> parameter_map_;
std::unordered_map<at::Tensor*, std::string>* parameter_map_;
// Used to create sequential dummy names for node types
size_t type_counter_ = 0;
};
ModuleEncoder::ModuleEncoder(
const script::Module &module,
std::ostream& out)
: EncoderBase(onnx_torch::OperatorExportTypes::RAW, false),
stream_writer_(&out) {
model_proto_.set_doc_string("THIS PROTO IS NOT STANDARD ONNX");
EncodeModule(model_proto_.mutable_graph(), module);
MethodEncoder::MethodEncoder(
const script::Method& method,
std::string* torch_script,
std::unordered_map<const void*, uint64_t>* storage_map,
std::unordered_map<at::Tensor*, std::string>* parameter_map,
PyTorchStreamWriter* writer)
: EncoderBase(onnx_torch::OperatorExportTypes::RAW, false) {
storage_dedup_map_ = storage_map;
parameter_map_ = parameter_map;
stream_writer_ = writer;
// we already keep the tree structure in the top level module,
// so pass "" as prefix
EncodeMethod(torch_script, method, "");
}
void ModuleEncoder::EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto, const Value *n) {
void MethodEncoder::EncodeIntermediateValueInfo(
onnx::GraphProto* graph_proto,
const Value* n) {
auto v = graph_proto->add_value_info();
EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName());
}
@ -511,8 +517,8 @@ c10::optional<std::string> getBaseTypeDenotation(TypeKind& kind) {
return c10::nullopt;
}
void ModuleEncoder::EncodeTypeInfo(
onnx::GraphProto *graph_proto,
void MethodEncoder::EncodeTypeInfo(
onnx::GraphProto* graph_proto,
onnx::ValueInfoProto* v,
const TypePtr& type,
const std::string& name) {
@ -594,74 +600,20 @@ void ModuleEncoder::EncodeTypeInfo(
}
}
void ModuleEncoder::EncodeValueInfo(
onnx::GraphProto *graph_proto,
void MethodEncoder::EncodeValueInfo(
onnx::GraphProto* graph_proto,
onnx::ValueInfoProto* v,
const Value* n) {
EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName());
}
void ModuleEncoder::EncodeModule(
onnx::GraphProto *graph_proto,
const script::Module &module) {
EncodeParameters(graph_proto, module, "");
EncodeMethods(graph_proto, module, "");
auto str = model_proto_.SerializeAsString();
stream_writer_.writeRecord(str.data(), str.size());
}
void ModuleEncoder::EncodeParameters(
onnx::GraphProto *graph_proto,
const script::Module &module,
const std::string prefix) {
// Encode each parameter as an initializer in the proto
for (auto &parameter : module.get_parameters()) {
auto tensor_proto = graph_proto->add_initializer();
EncodeParameter(tensor_proto, parameter.value(), prefix);
}
for (auto &submodule : module.get_modules()) {
EncodeParameters(graph_proto, *submodule->module, prefix + submodule.key() + ".");
}
}
void ModuleEncoder::EncodeParameter(
onnx::TensorProto *tensor_proto,
const script::NamedParameter &parameter,
const std::string prefix) {
auto tensor = parameter.slot();
// Name will be prefixed by submodule. e.g. submodule_foo.parameter_bar
auto name = prefix + parameter.name;
tensor_proto->set_name(name);
parameter_map_[tensor] = name;
// Parameters have these fields, but tensors do not
tensor_proto->add_int64_data(parameter.is_buffer);
tensor_proto->add_int64_data(tensor->requires_grad());
EncodeTensor(tensor_proto, *tensor, name);
}
void ModuleEncoder::EncodeMethods(
onnx::GraphProto *graph_proto,
const script::Module &module,
const std::string prefix) {
// Encode each method as a node in the proto
for (auto &method : module.get_methods()) {
auto node_proto = graph_proto->add_node();
EncodeMethod(node_proto, *method.value(), prefix);
}
for (auto &submodule : module.get_modules()) {
EncodeMethods(graph_proto, *submodule->module, prefix + submodule.key() + ".");
}
}
void ModuleEncoder::EncodeMethod(
onnx::NodeProto *node_proto,
script::Method &method,
void MethodEncoder::EncodeMethod(
std::string* torch_script,
const script::Method& method,
const std::string prefix) {
onnx::ModelProto model_proto;
model_proto.set_doc_string("THIS PROTO IS NOT STANDARD ONNX");
auto* node_proto = model_proto.mutable_graph()->add_node();
node_proto->set_name(prefix + method.name());
if (method.is_optimized()) {
// mark that this method was optimized
@ -673,8 +625,8 @@ void ModuleEncoder::EncodeMethod(
// Store member_inputs of Method in input
for (auto &member_input : method.params()) {
auto it = parameter_map_.find(member_input);
JIT_ASSERT(it != parameter_map_.end());
auto it = parameter_map_->find(member_input);
JIT_ASSERT(it != parameter_map_->end());
node_proto->add_input(it->second);
}
@ -690,15 +642,16 @@ void ModuleEncoder::EncodeMethod(
}
}
EncodeBlock(attr_proto->mutable_g(), method.graph()->block(), {});
AT_ASSERT(model_proto.SerializeToString(torch_script));
}
void ModuleEncoder::EncodeTensor(
void MethodEncoder::EncodeTensor(
onnx::TensorProto* tensor_proto,
const at::Tensor& tensor,
const c10::optional<std::string> external_ref) {
auto storage_ptr = tensor.storage().unsafeGetStorageImpl();
auto dedup_it = storage_dedup_map_.find(storage_ptr);
if (dedup_it != storage_dedup_map_.end()) {
auto dedup_it = storage_dedup_map_->find(storage_ptr);
if (dedup_it != storage_dedup_map_->end()) {
tensor_proto->add_int64_data(dedup_it->second);
} else {
at::Tensor t = tensor;
@ -714,10 +667,10 @@ void ModuleEncoder::EncodeTensor(
.cpu();
}
auto record_number = stream_writer_.writeRecord(
auto record_number = stream_writer_->writeRecord(
static_cast<char*>(t.storage().data()), t.type().elementSizeInBytes() * t.storage().size());
tensor_proto->add_int64_data(record_number);
storage_dedup_map_[storage_ptr] = record_number;
(*storage_dedup_map_)[storage_ptr] = record_number;
}
for (auto &d : tensor.sizes()) {
@ -899,6 +852,156 @@ void dump(const onnx::ModelProto& model, std::ostream& stream, size_t indent) {
}
stream << idt(indent) << "}\n";
}
class ScriptModuleSerializer final {
public:
ScriptModuleSerializer(const std::string& filename)
: ofs_(
filename,
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary),
writer_(&ofs_) {
// TODO: appropriate support for mmap; right now we still use the stream writer
}
ScriptModuleSerializer(std::ostream* ofs) : ofs_(), writer_(ofs) {}
void serialize(const script::Module& module) {
torch::ModelDef model_def;
convertToModel(module, &model_def);
std::string output;
// NB: cannot use MessageToJsonString, since fbcode's protobuf is too old;
// transcode manually, but keep the output consistent with MessageToJsonString
std::string url_prefix = "type.googleapis.com";
std::unique_ptr<::google::protobuf::util::TypeResolver> resolver(
::google::protobuf::util::NewTypeResolverForDescriptorPool(
url_prefix, model_def.GetDescriptor()->file()->pool()));
::google::protobuf::util::Status convert_result =
::google::protobuf::util::BinaryToJsonString(
resolver.get(),
url_prefix + "/" + model_def.GetDescriptor()->full_name(),
model_def.SerializeAsString(),
&output);
if (!convert_result.ok()) {
std::stringstream ss;
ss << convert_result;
AT_ERROR(ss.str());
}
auto record_id = writer_.writeRecord(output.data(), output.size());
writer_.writeEndOfFile();
}
private:
void convertToModel(
const script::Module& module,
torch::ModelDef* model_def) {
model_def->set_name("script-model");
model_def->set_producer_name("pytorch");
model_def->set_producer_version("1.0"); // TODO: set the producer version
// using appropriate function call
model_def->set_proto_version(torch::ProtoVersion::PROTO_VERSION_NEWEST);
std::string main_module_name = "";
collectParamsInfo(module, main_module_name);
convertModule(module, main_module_name, model_def->mutable_main_module());
}
void collectParamsInfo(
const script::Module& module,
const std::string& prefix) {
for (const auto& elem : module.get_parameters()) {
const script::NamedParameter& param = elem.value();
parameterMap_[param.slot()] = prefix + param.name;
}
for (const auto& elem : module.get_modules()) {
collectParamsInfo(*elem->module, prefix + elem.key() + ".");
}
}
void convertModule(
const script::Module& module,
const std::string& name,
torch::ModuleDef* module_def) {
module_def->set_name(name);
for (const auto& elem : module.get_parameters()) {
torch::ParameterDef* param_def = module_def->add_parameters();
convertParameter(elem.value(), param_def);
}
for (auto& elem : module.get_methods()) {
torch::MethodDef* method_def = module_def->add_methods();
convertMethod(*elem.value(), method_def);
}
for (const auto& elem : module.get_modules()) {
torch::ModuleDef* sub_def = module_def->add_submodules();
convertModule(*elem->module, elem.key(), sub_def);
}
}
void convertParameter(
const script::NamedParameter& param,
torch::ParameterDef* param_def) {
param_def->set_name(param.name);
param_def->set_is_buffer(param.is_buffer);
param_def->set_require_gradient(param.slot()->requires_grad());
convertTensor(*(param.slot()), param_def->mutable_tensor());
}
void convertTensor(
const at::Tensor& tensor,
caffe2::TensorProto* tensor_proto) {
for (auto d : tensor.sizes()) {
tensor_proto->add_dims(d);
}
tensor_proto->set_data_type(caffe2::TypeMetaToDataType(
at::scalarTypeToTypeMeta(tensor.type().scalarType())));
tensor_proto->set_storage_type(caffe2::TensorProto_StorageType_EXTERNAL);
caffe2::ExternalDataProto* external_data =
tensor_proto->mutable_external_data();
for (auto s : tensor.strides()) {
external_data->add_strides(s);
}
external_data->set_offset(tensor.storage_offset());
uint64_t record_size =
tensor.type().elementSizeInBytes() * tensor.storage().size();
external_data->set_record_size(record_size);
auto* key = tensor.storage().unsafeGetStorageImpl();
auto it = storageMap_.find(key);
if (it == storageMap_.end()) {
// TODO HIP support
uint64_t record_id;
if (tensor.storage().device_type() == at::DeviceType::CUDA) {
// NB: This new tensor is created to support cuda tensors.
// Storages can be mutated when converting tensors from cuda to cpu,
// and we need a cpu tensor to copy data from.
at::Tensor t = at::getType(tensor)
._th_tensor(
tensor.storage(),
/* storageOffset = */ 0,
/* size = */
{static_cast<int64_t>(tensor.storage().size())},
/* stride = */ {1})
.cpu();
AT_ASSERT(
t.type().elementSizeInBytes() * t.storage().size() == record_size);
record_id = writer_.writeRecord(
t.storage().data(),
t.type().elementSizeInBytes() * t.storage().size());
} else {
record_id = writer_.writeRecord(tensor.storage().data(), record_size);
}
external_data->set_record_id(caffe2::to_string(record_id));
storageMap_[key] = record_id;
} else {
external_data->set_record_id(caffe2::to_string(it->second));
}
// TODO: handle the device case; set device_detail and load onto the CUDA device
}
void convertMethod(script::Method& method, torch::MethodDef* method_def) {
std::string torch_script;
// TODO encode the real torch script instead of ModelProto
MethodEncoder encoder(
method, &torch_script, &storageMap_, &parameterMap_, &writer_);
method_def->set_onnx_proto(torch_script);
}
std::unordered_map<const void*, uint64_t>
storageMap_; // storage_ptr => record_offset
std::ofstream ofs_;
PyTorchStreamWriter writer_;
std::unordered_map<at::Tensor*, std::string> parameterMap_;
};
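
The NB comments in serialize() above work around fbcode's old protobuf. On a protobuf that ships MessageToJsonString, the same JSON step would reduce to roughly this sketch (not what the PR uses; the helper name is invented).

#include <string>
#include <google/protobuf/util/json_util.h>
#include "caffe2/proto/torch_pb.h"

std::string model_def_to_json(const torch::ModelDef& model_def) {
  std::string output;
  auto status = ::google::protobuf::util::MessageToJsonString(model_def, &output);
  if (!status.ok()) {
    AT_ERROR("could not serialize ModelDef to JSON"); // AT_ERROR as used elsewhere in this file
  }
  return output;
}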
} // namespace
std::string prettyPrint(const onnx::ModelProto& model) {
@ -941,13 +1044,13 @@ std::tuple<std::string, RawDataExportMap> ExportGraph(
}
void ExportModule(const script::Module& module, std::ostream& out) {
ModuleEncoder(module, out);
ScriptModuleSerializer serializer(&out);
serializer.serialize(module);
}
void ExportModule(const script::Module& module, const std::string &filename) {
std::ofstream out(filename, std::ios_base::binary);
ExportModule(module, out);
ScriptModuleSerializer serializer(filename);
serializer.serialize(module);
}
}}

View File

@ -1,9 +1,15 @@
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/util/type_resolver_util.h>
#include "torch/csrc/jit/import.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/utils/functional.h"
#include "torch/csrc/jit/assertions.h"
#include "torch/csrc/jit/operator.h"
#include "caffe2/core/types.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/proto/torch_pb.h"
#include "caffe2/serialize/inline_container.h"
#include "onnx/onnx_pb.h"
@ -22,10 +28,14 @@ namespace onnx = ::ONNX_NAMESPACE;
// IR graph construction
class ModuleDecoder {
class MethodDecoder {
public:
ModuleDecoder(ModuleLookup module_lookup,
std::istream& in);
MethodDecoder(
const onnx::ModelProto& model_proto,
const std::unordered_map<std::string, at::Tensor*>& param_map,
script::Module* parent_module,
std::unordered_map<uint64_t, std::shared_ptr<at::Storage>>* storage_map,
PyTorchStreamReader* stream_reader_);
private:
std::shared_ptr<Graph> buildGraph(const onnx::GraphProto& graph_proto);
@ -46,8 +56,6 @@ class ModuleDecoder {
TypePtr buildType(const onnx::TypeProto& type_proto);
at::Tensor buildParameter(const onnx::TensorProto& tensor_proto);
at::Tensor buildTensorCommon(const onnx::TensorProto& tensor_proto,
const uint64_t record_number,
const int64_t storage_offset,
@ -57,12 +65,13 @@ class ModuleDecoder {
ModuleLookup module_lookup,
const std::string fullname);
PyTorchStreamReader stream_reader_;
std::unordered_map<uint64_t, std::shared_ptr<at::Storage>> storage_map_;
PyTorchStreamReader* stream_reader_;
std::unordered_map<uint64_t, std::shared_ptr<at::Storage>>* storage_map_;
std::unordered_map<std::string, const onnx::TypeProto*> value_type_map_;
};
at::ScalarType ModuleDecoder::onnxTypeToATenType(onnx::TensorProto_DataType onnx_type) {
at::ScalarType MethodDecoder::onnxTypeToATenType(
onnx::TensorProto_DataType onnx_type) {
switch(onnx_type) {
case onnx::TensorProto_DataType_UINT8:
return at::kByte;
@ -85,8 +94,9 @@ at::ScalarType ModuleDecoder::onnxTypeToATenType(onnx::TensorProto_DataType onnx
}
}
void ModuleDecoder::buildBlocks(
const std::vector<onnx::GraphProto>& graphs_, Node* node,
void MethodDecoder::buildBlocks(
const std::vector<onnx::GraphProto>& graphs_,
Node* node,
std::unordered_map<std::string, Value*>& value_map) {
for (auto g_ : graphs_) {
auto block = node->addBlock();
@ -94,7 +104,8 @@ void ModuleDecoder::buildBlocks(
}
}
std::shared_ptr<Graph> ModuleDecoder::buildGraph(const onnx::GraphProto& graph_proto) {
std::shared_ptr<Graph> MethodDecoder::buildGraph(
const onnx::GraphProto& graph_proto) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> value_map;
@ -103,8 +114,10 @@ std::shared_ptr<Graph> ModuleDecoder::buildGraph(const onnx::GraphProto& graph_p
return graph;
}
void ModuleDecoder::buildBlock(const onnx::GraphProto& graph_proto, Block* block,
std::unordered_map<std::string, Value*>& value_map) {
void MethodDecoder::buildBlock(
const onnx::GraphProto& graph_proto,
Block* block,
std::unordered_map<std::string, Value*>& value_map) {
for (auto &subtype : graph_proto.value_info()) {
value_type_map_[subtype.name()] = &subtype.type();
}
@ -190,7 +203,7 @@ void ModuleDecoder::buildBlock(const onnx::GraphProto& graph_proto, Block* block
}
}
TypePtr ModuleDecoder::buildType(const onnx::TypeProto& type_proto) {
TypePtr MethodDecoder::buildType(const onnx::TypeProto& type_proto) {
auto tensortype_proto = type_proto.tensor_type();
auto shape_proto = tensortype_proto.shape();
auto kind = type_proto.denotation();
@ -248,29 +261,21 @@ TypePtr ModuleDecoder::buildType(const onnx::TypeProto& type_proto) {
}
}
void ModuleDecoder::buildValue(Value* value, const onnx::ValueInfoProto& valueinfo_proto) {
void MethodDecoder::buildValue(
Value* value,
const onnx::ValueInfoProto& valueinfo_proto) {
value->setType(buildType(valueinfo_proto.type()));
}
void ModuleDecoder::buildIntermediateValue(Value* value, const std::string& name) {
void MethodDecoder::buildIntermediateValue(
Value* value,
const std::string& name) {
auto it = value_type_map_.find(name);
JIT_ASSERT(it != value_type_map_.end());
value->setType(buildType(*it->second));
}
at::Tensor ModuleDecoder::buildParameter(const onnx::TensorProto& tensor_proto) {
std::vector<int64_t> strides;
// We've stored four other values (is_buffer, requires_grad, record no., storage_offset) before strides; ignore them
std::move(tensor_proto.int64_data().begin() + 4, tensor_proto.int64_data().end(), std::back_inserter(strides));
auto tensor = buildTensorCommon(tensor_proto,
/* record_number = */ tensor_proto.int64_data(2),
/* storage_offset = */ tensor_proto.int64_data(3),
strides);
autograd::Variable var = autograd::make_variable(tensor, /* requires_grad = */ tensor_proto.int64_data(1));
return var;
}
at::Tensor ModuleDecoder::buildTensor(const onnx::TensorProto& tensor_proto) {
at::Tensor MethodDecoder::buildTensor(const onnx::TensorProto& tensor_proto) {
std::vector<int64_t> strides;
// We've stored two other values (record no., storage_offset) before strides; ignore them
std::move(tensor_proto.int64_data().begin() + 2, tensor_proto.int64_data().end(), std::back_inserter(strides));
@ -280,7 +285,7 @@ at::Tensor ModuleDecoder::buildTensor(const onnx::TensorProto& tensor_proto) {
strides);
}
at::Tensor ModuleDecoder::buildTensorCommon(
at::Tensor MethodDecoder::buildTensorCommon(
const onnx::TensorProto& tensor_proto,
const uint64_t record_number,
const int64_t storage_offset,
@ -292,17 +297,18 @@ at::Tensor ModuleDecoder::buildTensorCommon(
std::move(tensor_proto.dims().begin(), tensor_proto.dims().end(), std::back_inserter(dims));
// Find or create the storage
auto storage_it = storage_map_.find(record_number);
if (storage_it == storage_map_.end()) {
auto storage_it = storage_map_->find(record_number);
if (storage_it == storage_map_->end()) {
at::DataPtr storage_ptr;
int64_t size;
std::tie(storage_ptr, size) = stream_reader_.getRecordWithKey(record_number);
std::tie(storage_ptr, size) =
stream_reader_->getRecordWithKey(record_number);
auto storage = std::make_shared<at::Storage>(
at::CPU(type).typeMeta(),
std::move(storage_ptr),
size / at::CPU(type).typeMeta().itemsize(),
nullptr);
storage_map_.insert(std::make_pair(record_number, storage));
storage_map_->insert(std::make_pair(record_number, storage));
return at::CPU(type)._th_tensor(*storage, storage_offset, dims, strides);
}
@ -312,9 +318,8 @@ at::Tensor ModuleDecoder::buildTensorCommon(
// Given a full name of a parameter or method,
// return the parent submodule and local name
std::pair<std::shared_ptr<script::Module>, std::string> ModuleDecoder::parseFullName(
ModuleLookup module_lookup,
const std::string fullname) {
std::pair<std::shared_ptr<script::Module>, std::string> MethodDecoder::
parseFullName(ModuleLookup module_lookup, const std::string fullname) {
AT_ASSERT(!fullname.empty());
std::vector<std::string> vec;
std::stringstream ss(fullname);
@ -328,37 +333,23 @@ std::pair<std::shared_ptr<script::Module>, std::string> ModuleDecoder::parseFull
return std::make_pair(module_lookup(vec), std::move(last));
}
ModuleDecoder::ModuleDecoder(
ModuleLookup module_lookup,
std::istream& in) :
stream_reader_(&in) {
auto model_proto = onnx::ModelProto();
auto record = stream_reader_.getLastRecord();
model_proto.ParsePartialFromArray(std::get<0>(record).get(), std::get<1>(record));
auto graph_proto = model_proto.graph();
std::unordered_map<std::string, at::Tensor*> param_map;
for (auto &tensor_proto : graph_proto.initializer()) {
std::shared_ptr<script::Module> parent_module;
std::string name;
std::tie(parent_module, name) = parseFullName(module_lookup, tensor_proto.name());
auto param = buildParameter(tensor_proto);
parent_module->register_parameter(name, param, /* is_buffer = */ tensor_proto.int64_data(0));
param_map[tensor_proto.name()] = parent_module->parameter_slot(name);
}
for (auto &node_proto : graph_proto.node()) {
std::shared_ptr<script::Module> parent_module;
std::string name;
std::tie(parent_module, name) = parseFullName(module_lookup, node_proto.name());
MethodDecoder::MethodDecoder(
const onnx::ModelProto& model_proto,
const std::unordered_map<std::string, at::Tensor*>& param_map,
script::Module* parent_module,
std::unordered_map<uint64_t, std::shared_ptr<at::Storage>>* storage_map,
PyTorchStreamReader* stream_reader) {
storage_map_ = storage_map;
stream_reader_ = stream_reader;
const auto& graph_proto = model_proto.graph();
for (const auto& node_proto : graph_proto.node()) {
std::vector<at::Tensor*> member_inputs;
for (auto &param_name : node_proto.input()) {
member_inputs.push_back(param_map[param_name]);
const std::string& name = node_proto.name();
for (const auto& param_name : node_proto.input()) {
auto it = param_map.find(param_name);
AT_ASSERTM(it != param_map.end(), "cannot find parameter ", param_name);
member_inputs.push_back(it->second);
}
auto graph = buildGraph(node_proto.attribute(0).g());
// has_domain field has a string iff the method was optimized
parent_module->set_optimized(node_proto.has_domain());
@ -370,22 +361,159 @@ ModuleDecoder::ModuleDecoder(
}
}
// This is a deserializer class which loads script modules from pt files. The
// content of the file is written using PyTorchStreamWriter; for details please
// check caffe2/serialize/inline_container.h. All the records except the last
// one are tensor data, and the last record is a serialized ModelDef, defined
// in caffe2/proto/torch.proto. The ModelDef contains all the metadata of the
// model, and it is serialized as JSON.
class ScriptModuleDeserializer final {
public:
ScriptModuleDeserializer(const std::string& filename)
: ifs_(filename, std::ifstream::in | std::ifstream::binary),
reader_(&ifs_) {
// TODO: appropriate support for mmap; right now we still use the stream reader
}
ScriptModuleDeserializer(std::istream* is) : ifs_(), reader_(is) {}
void deserialize(ModuleLookup module_lookup) {
torch::ModelDef model_def;
at::DataPtr data_ptr;
size_t data_size;
std::tie(data_ptr, data_size) = reader_.getLastRecord();
// NB: cannot use JsonStringToMessage, since fbcode's protobuf is too old;
// transcode manually, but stay consistent with JsonStringToMessage
std::string url_prefix = "type.googleapis.com";
std::unique_ptr<::google::protobuf::util::TypeResolver> resolver(
::google::protobuf::util::NewTypeResolverForDescriptorPool(
url_prefix, model_def.GetDescriptor()->file()->pool()));
std::string json_string = std::string(
static_cast<char*>(data_ptr.get()),
static_cast<char*>(data_ptr.get()) + data_size);
std::string binary_string;
auto convert_result = ::google::protobuf::util::JsonToBinaryString(
resolver.get(),
url_prefix + "/" + model_def.GetDescriptor()->full_name(),
json_string,
&binary_string);
if (!convert_result.ok()) {
std::stringstream ss;
ss << convert_result;
AT_ERROR(ss.str());
}
AT_ASSERTM(
model_def.ParseFromString(binary_string),
"JSON transcoder produced invalid protobuf output.");
moduleLookup_ = module_lookup;
const auto& module_def = model_def.main_module();
collectParamsInfo(module_def, module_def.name());
// TODO: this can be simplified when C++/Python interop lands,
// and submodules will be created the same way in either C++ or Python
std::shared_ptr<script::Module> module = moduleLookup_(moduleStack_);
convertModule(module_def, module.get());
}
private:
void collectParamsInfo(
const torch::ModuleDef& module_def,
const std::string& prefix) {
std::shared_ptr<script::Module> module = moduleLookup_(moduleStack_);
for (int i = 0; i < module_def.parameters_size(); ++i) {
const torch::ParameterDef& param_def = module_def.parameters(i);
at::Tensor tensor = createTensor(param_def.tensor());
autograd::Variable variable =
autograd::make_variable(tensor, param_def.require_gradient());
module->register_parameter(
param_def.name(), variable, param_def.is_buffer());
parameterMap_[prefix + param_def.name()] =
module->parameter_slot(param_def.name());
}
for (int i = 0; i < module_def.submodules_size(); ++i) {
const torch::ModuleDef& sub_def = module_def.submodules(i);
moduleStack_.push_back(sub_def.name());
collectParamsInfo(sub_def, prefix + sub_def.name() + ".");
moduleStack_.pop_back();
}
}
void convertModule(
const torch::ModuleDef& module_def,
script::Module* module) {
for (int i = 0; i < module_def.methods_size(); ++i) {
const torch::MethodDef& method_def = module_def.methods(i);
// TODO: read un-hacked torch script; right now it's a serialized ONNX proto
::ONNX_NAMESPACE::ModelProto method_proto;
AT_ASSERTM(
method_proto.ParseFromString(method_def.onnx_proto()),
"cannot parse method proto (i.e., hacked onnx proto)");
MethodDecoder decoder(
method_proto, parameterMap_, module, &storageMap_, &reader_);
(void)decoder;
}
for (int i = 0; i < module_def.submodules_size(); ++i) {
const torch::ModuleDef& sub_def = module_def.submodules(i);
moduleStack_.push_back(sub_def.name());
std::shared_ptr<script::Module> sub = moduleLookup_(moduleStack_);
convertModule(sub_def, sub.get());
moduleStack_.pop_back();
}
}
at::Tensor createTensor(const caffe2::TensorProto& tensor_proto) {
std::vector<int64_t> dims;
for (int i = 0; i < tensor_proto.dims_size(); ++i) {
dims.push_back(tensor_proto.dims(i));
}
AT_ASSERT(
tensor_proto.storage_type() ==
caffe2::TensorProto_StorageType_EXTERNAL);
const caffe2::ExternalDataProto& external_data =
tensor_proto.external_data();
std::vector<int64_t> strides;
for (int i = 0; i < external_data.strides_size(); ++i) {
strides.push_back(external_data.strides(i));
}
auto type = at::typeMetaToScalarType(
caffe2::DataTypeToTypeMeta(tensor_proto.data_type()));
uint64_t record_id = caffe2::stoull(external_data.record_id());
AT_ASSERT(record_id != 0);
auto storage_it = storageMap_.find(record_id);
if (storage_it == storageMap_.end()) {
at::DataPtr storage_ptr;
uint64_t record_size;
std::tie(storage_ptr, record_size) = reader_.getRecordWithKey(record_id);
AT_ASSERT(record_size == external_data.record_size());
auto storage = std::make_shared<at::Storage>(
at::CPU(type).typeMeta(),
std::move(storage_ptr),
record_size / at::CPU(type).typeMeta().itemsize(),
nullptr); // NB: we didn't set any allocator for the tensor
storageMap_.insert(std::make_pair(record_id, storage));
return at::CPU(type)._th_tensor(
*storage, external_data.offset(), dims, strides);
}
return at::CPU(type)._th_tensor(
*(storage_it->second.get()), external_data.offset(), dims, strides);
}
std::ifstream ifs_;
PyTorchStreamReader reader_;
ModuleLookup moduleLookup_;
std::vector<std::string> moduleStack_;
std::unordered_map<uint64_t, std::shared_ptr<at::Storage>> storageMap_;
std::unordered_map<std::string, at::Tensor*> parameterMap_;
};
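
Mirroring the note in deserialize() above: with a newer protobuf, turning the JSON record back into a ModelDef would be roughly the following sketch (not what the PR uses; the helper name is invented).

#include <string>
#include <google/protobuf/util/json_util.h>
#include "caffe2/proto/torch_pb.h"

torch::ModelDef json_to_model_def(const std::string& json_string) {
  torch::ModelDef model_def;
  auto status =
      ::google::protobuf::util::JsonStringToMessage(json_string, &model_def);
  if (!status.ok()) {
    AT_ERROR("could not parse ModelDef from JSON");
  }
  return model_def;
}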
} // namespace
void import_ir_module(
ModuleLookup module_lookup,
std::istream& in) {
ModuleDecoder decoder(module_lookup, in);
(void)decoder;
ScriptModuleDeserializer deserializer(&in);
deserializer.deserialize(module_lookup);
}
void import_ir_module(
ModuleLookup module_lookup,
const std::string& filename) {
std::ifstream in(filename, std::ios_base::binary);
ModuleDecoder decoder(module_lookup, in);
(void)decoder;
ScriptModuleDeserializer deserializer(filename);
deserializer.deserialize(module_lookup);
}
std::shared_ptr<script::Module> load(std::istream& in) {
@ -402,8 +530,8 @@ std::shared_ptr<script::Module> load(std::istream& in) {
return curr;
};
ModuleDecoder decoder(module_lookup, in);
(void)decoder;
ScriptModuleDeserializer deserializer(&in);
deserializer.deserialize(module_lookup);
return module;
}

View File

@ -151,7 +151,7 @@ struct Method {
return retval;
}
std::vector<at::Tensor*> params() {
std::vector<at::Tensor*> params() const {
return member_inputs;
}
@ -182,7 +182,7 @@ struct Method {
return get_executor().debugDisableAutodiffSubgraphInlining();
}
bool is_optimized() {
bool is_optimized() const {
return optimize;
}