mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26758 This PR changes the order in which we import classes and functions so that it is no longer necessary for them to be defined in order in a file, or for there to be proper import statements in the exported file. Actually importing a function/class now is driven by the need to resolve the entity during unpickling, type resolution, or value resolution. While this should allow significant simplification to the code that serializes classes, this work has not been done yet in order to avoid inevitable forward compat issues in the transition period. Notes: * Individual functions have been replaced with a SourceImporter object that exposes a resolveType method. This method loads the type if it has not been loaded yet, potentially parsing (but not loading) the file it exists in if that file hasn't been parsed yet. * Some legacy functionality needed to be added as a method to this object since the old format still used some of this logic for class resolution. Test Plan: Imported from OSS Differential Revision: D17558989 Pulled By: zdevito fbshipit-source-id: 7eae3470bcbd388c4de463e3462d527776ed46c6
1003 lines
35 KiB
C++
1003 lines
35 KiB
C++
#include <google/protobuf/util/json_util.h>
|
|
#include <google/protobuf/util/type_resolver_util.h>
|
|
|
|
#include <torch/csrc/autograd/symbolic.h>
|
|
#include <torch/csrc/jit/export.h>
|
|
#include <torch/csrc/onnx/onnx.h>
|
|
|
|
#include <ATen/core/functional.h>
|
|
#include <c10/util/Exception.h>
|
|
#include <torch/csrc/jit/import_export_helpers.h>
|
|
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
|
#include <torch/csrc/jit/passes/python_print.h>
|
|
#include <torch/csrc/jit/pickle.h>
|
|
#include <torch/csrc/jit/source_range_serialization.h>
|
|
#include <torch/csrc/jit/instruction.h>
|
|
|
|
#include <caffe2/core/types.h>
|
|
#include <caffe2/proto/caffe2_pb.h>
|
|
#include <caffe2/proto/torch_pb.h>
|
|
#include <caffe2/serialize/inline_container.h>
|
|
#include <onnx/onnx_pb.h>
|
|
|
|
#include <ATen/ATen.h>
|
|
#include <c10/util/Optional.h>
|
|
|
|
#include <fstream>
|
|
#include <memory>
|
|
#include <set>
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
char const * toString(OpCode op);
|
|
|
|
namespace {
|
|
namespace onnx_torch = ::torch::onnx;
|
|
namespace onnx = ::ONNX_NAMESPACE;
|
|
|
|
namespace {
|
|
ExportModuleExtraFilesHook& GetExtraFilesHook() {
|
|
static ExportModuleExtraFilesHook func = nullptr;
|
|
return func;
|
|
};
|
|
}
|
|
|
|
class ScriptModuleSerializer;
|
|
|
|
std::string getNodeStackTraceString(const Node* n) {
|
|
return n->sourceRange().str();
|
|
}
|
|
|
|
void validateBlock(
|
|
Block* b,
|
|
onnx_torch::OperatorExportTypes operator_export_type) {
|
|
for (auto node : b->nodes()) {
|
|
for (Block* sub_block : node->blocks()) {
|
|
validateBlock(sub_block, operator_export_type);
|
|
}
|
|
// Macro'ed so we get a marginally better line number on failed export
|
|
#define FAIL_EXPORT(name) \
|
|
throw std::runtime_error( \
|
|
std::string("ONNX export failed: ") + name + \
|
|
"\n\nGraph we tried to export:\n" + b->owningGraph()->toString());
|
|
if (node->kind() == prim::PythonOp) {
|
|
auto py_node = static_cast<PythonOp*>(node);
|
|
FAIL_EXPORT(
|
|
"Couldn't export Python operator " + py_node->name() +
|
|
"\n\nDefined at:\n" + getNodeStackTraceString(node))
|
|
} else {
|
|
// Special error messages for certain types of operators
|
|
if (node->kind() == aten::expand) {
|
|
if (operator_export_type ==
|
|
onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK) {
|
|
WithInsertPoint guard(node);
|
|
auto* new_node =
|
|
b->owningGraph()->insertNode(b->owningGraph()->create(
|
|
Symbol(::c10::onnx::ATen),
|
|
node->inputs(),
|
|
node->outputs().size()));
|
|
for (size_t i = 0; i < node->outputs().size(); ++i) {
|
|
node->output(i)->replaceAllUsesWith(new_node->output(i));
|
|
}
|
|
new_node->s_(Symbol::fromQualString("attr::operator"), "expand");
|
|
}
|
|
}
|
|
if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) {
|
|
FAIL_EXPORT(
|
|
"Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" +
|
|
getNodeStackTraceString(node));
|
|
}
|
|
bool is_aten_enabled = operator_export_type ==
|
|
onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK ||
|
|
operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN;
|
|
if (!node->kind().is_onnx() && !node->kind().is_caffe2() &&
|
|
!is_aten_enabled && !node->mustBeNone()) {
|
|
FAIL_EXPORT(
|
|
"Couldn't export operator " + node->kind().toDisplayString() +
|
|
"\n\nDefined at:\n" + getNodeStackTraceString(node));
|
|
}
|
|
}
|
|
#undef FAIL_EXPORT
|
|
}
|
|
}
|
|
|
|
void validateGraph(
|
|
const std::shared_ptr<Graph>& graph,
|
|
onnx_torch::OperatorExportTypes operator_export_type) {
|
|
validateBlock(graph->block(), operator_export_type);
|
|
// this is run on an onnx graph which doesn't have side effects.
|
|
// ignore side effects in dead code elimination.
|
|
EliminateDeadCode(graph->block(), true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
|
|
}
|
|
|
|
class EncoderBase {
|
|
public:
|
|
EncoderBase(
|
|
onnx_torch::OperatorExportTypes operator_export_type,
|
|
bool strip_doc);
|
|
|
|
onnx::ModelProto get_model_proto() {
|
|
return model_proto_;
|
|
}
|
|
|
|
protected:
|
|
// Using std::map instead of std::unordered_map for initializers
|
|
// in EncodeGraph cosntructor so that the order in which initializers
|
|
// get written to the ONNX graph is always the deterministic and
|
|
// predictable. While this is not a ONNX requirement, it is needed
|
|
// for testing purposes in tests that use _export_to_pretty_string()
|
|
// for validating ONNX graphs.
|
|
void EncodeGraph(
|
|
onnx::GraphProto* graph_proto,
|
|
const std::shared_ptr<Graph>& graph,
|
|
const std::map<std::string, at::Tensor>& initializers =
|
|
std::map<std::string, at::Tensor>(),
|
|
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes =
|
|
std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>(),
|
|
bool keep_initializers_as_inputs = true);
|
|
|
|
void EncodeBlock(
|
|
onnx::GraphProto* graph_proto,
|
|
const Block* block,
|
|
const std::map<std::string, at::Tensor>& initializers =
|
|
std::map<std::string, at::Tensor>(),
|
|
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes =
|
|
std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>(),
|
|
bool keep_initializers_as_inputs = true);
|
|
|
|
virtual void EncodeTensor(
|
|
onnx::TensorProto* tensor_proto,
|
|
const at::Tensor& tensor,
|
|
const c10::optional<std::string> external_ref = {}) = 0;
|
|
|
|
virtual void EncodeIntermediateValueInfo(
|
|
onnx::GraphProto* graph_proto,
|
|
const Value* n){}
|
|
|
|
virtual void EncodeValueInfo(
|
|
onnx::GraphProto* graph_proto,
|
|
onnx::ValueInfoProto* v,
|
|
const Value* n,
|
|
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes =
|
|
std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>());
|
|
|
|
void AddAttribute(
|
|
onnx::NodeProto* node_proto,
|
|
const jit::Node* node,
|
|
const jit::Symbol name);
|
|
|
|
onnx::ModelProto model_proto_;
|
|
size_t num_blocks_;
|
|
onnx_torch::OperatorExportTypes operator_export_type_;
|
|
bool strip_doc_;
|
|
std::set<std::string> domains_;
|
|
};
|
|
|
|
// Translates an ATen scalar type into the matching ONNX TensorProto element
// type. Only the dtypes listed below can be exported; any other dtype raises
// "unexpected tensor scalar type".
onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
  onnx::TensorProto_DataType onnx_type = onnx::TensorProto_DataType_UNDEFINED;
  switch (at_type) {
    case at::kDouble:
      onnx_type = onnx::TensorProto_DataType_DOUBLE;
      break;
    case at::kFloat:
      onnx_type = onnx::TensorProto_DataType_FLOAT;
      break;
    case at::kHalf:
      onnx_type = onnx::TensorProto_DataType_FLOAT16;
      break;
    case at::kByte:
      onnx_type = onnx::TensorProto_DataType_UINT8;
      break;
    case at::kChar:
      onnx_type = onnx::TensorProto_DataType_INT8;
      break;
    case at::kShort:
      onnx_type = onnx::TensorProto_DataType_INT16;
      break;
    case at::kInt:
      onnx_type = onnx::TensorProto_DataType_INT32;
      break;
    case at::kLong:
      onnx_type = onnx::TensorProto_DataType_INT64;
      break;
    case at::kBool:
      onnx_type = onnx::TensorProto_DataType_BOOL;
      break;
    default:
      AT_ERROR("unexpected tensor scalar type");
  }
  return onnx_type;
}
|
|
|
|
// Initializes an empty model with fixed producer metadata. Encoding starts
// with zero blocks.
EncoderBase::EncoderBase(
    onnx_torch::OperatorExportTypes operator_export_type,
    bool strip_doc)
    : num_blocks_(0),
      operator_export_type_(operator_export_type),
      strip_doc_(strip_doc) {
  // The IR version is pinned to 4 (01/22/2019) rather than following
  // onnx::IR_VERSION, which keeps test_operators.py stable. Bump it only
  // when it's necessary.
  constexpr int64_t kPinnedIrVersion = 4;
  model_proto_.set_producer_name("pytorch");
  model_proto_.set_ir_version(kPinnedIrVersion);
  // TODO: set the producer version using appropriate function call
  model_proto_.set_producer_version("1.2");
}
|
|
|
|
// Populates `v` with the debug name of `n` and, when statically known, its
// ONNX type:
//  * complete TensorType: the element type plus one shape dimension per
//    size. A dimension listed in dynamic_axes[name] is written as a symbolic
//    dim_param instead of its concrete size;
//  * incomplete TensorType: name only (the type is left untouched);
//  * BoolType: a BOOL tensor type without a shape;
//  * any other type: name only.
// `graph_proto` is not used by this implementation.
void EncoderBase::EncodeValueInfo(
    onnx::GraphProto* graph_proto,
    onnx::ValueInfoProto* v,
    const Value* n,
    const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes) {
  const std::string name = n->debugName();
  v->set_name(name);
  if (TensorTypePtr node_type = n->type()->cast<TensorType>()) {
    if (!node_type->isComplete()) {
      return;
    }
    onnx::TypeProto_Tensor* tensor_type =
        v->mutable_type()->mutable_tensor_type();
    onnx::TensorShapeProto* shape = tensor_type->mutable_shape();

    // Resolve this value's dynamic-axes entry once. The previous code
    // re-hashed `name` up to three times for every dimension.
    const auto axes_it = dynamic_axes.find(name);
    const std::unordered_map<int64_t, std::string>* value_axes =
        axes_it != dynamic_axes.end() ? &axes_it->second : nullptr;

    const std::vector<std::int64_t> sizes =
        node_type->sizes().concrete_sizes().value();
    for (size_t i = 0; i < sizes.size(); ++i) {
      auto* dim = shape->add_dim();
      if (value_axes != nullptr) {
        const auto param_it = value_axes->find(static_cast<int64_t>(i));
        if (param_it != value_axes->end()) {
          dim->set_dim_param(param_it->second);
          continue;
        }
      }
      dim->set_dim_value(sizes[i]);
    }
    tensor_type->set_elem_type(
        ATenTypeToOnnxType(node_type->scalarType().value()));
  } else if (n->type()->cast<BoolType>()) {
    v->mutable_type()->mutable_tensor_type()->set_elem_type(
        ATenTypeToOnnxType(at::kBool));
  }
}
|
|
|
|
// Encodes a whole graph by encoding its top-level block. All arguments are
// forwarded unchanged to EncodeBlock.
void EncoderBase::EncodeGraph(
    onnx::GraphProto* graph_proto,
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers,
    const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes,
    bool keep_initializers_as_inputs) {
  const Block* top_level_block = graph->block();
  EncodeBlock(
      graph_proto,
      top_level_block,
      initializers,
      dynamic_axes,
      keep_initializers_as_inputs);
}
|
|
|
|
// Serializes `block` into `graph_proto`. It writes, in order: the graph name,
// the graph inputs and outputs, one NodeProto per node (recursing into
// sub-blocks), and finally every initializer tensor.
//
// Naming: each call consumes one value of num_blocks_. The first encoded
// block is "torch-jit-export" and later ones get the counter as a suffix.
//
// Inputs: since ONNX IR version 4, initializers need not be graph inputs.
// With keep_initializers_as_inputs=false, block inputs whose debug name
// matches an initializer are left out of the input list. This lets
// backends/optimizers constant-fold parameters. With true, every block input
// is listed.
void EncoderBase::EncodeBlock(
    onnx::GraphProto* graph_proto,
    const Block* block,
    const std::map<std::string, at::Tensor>& initializers,
    const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes,
    bool keep_initializers_as_inputs) {
  AT_ASSERT(graph_proto != nullptr);

  std::string block_name = "torch-jit-export";
  if (num_blocks_ != 0) {
    block_name += std::to_string(num_blocks_);
  }
  ++num_blocks_;
  graph_proto->set_name(block_name);

  for (auto input : block->inputs()) {
    if (!keep_initializers_as_inputs &&
        initializers.count(input->debugName()) != 0) {
      continue;
    }
    EncodeValueInfo(graph_proto, graph_proto->add_input(), input, dynamic_axes);
  }

  for (auto output : block->outputs()) {
    EncodeValueInfo(
        graph_proto, graph_proto->add_output(), output, dynamic_axes);
  }

  const bool is_raw_export =
      operator_export_type_ == onnx_torch::OperatorExportTypes::RAW;

  // Adds a GRAPH-typed attribute called `attr_name` to `node_proto` and
  // encodes `sub_block` into it (default initializers/dynamic_axes).
  auto add_subgraph_attribute = [this](
                                    onnx::NodeProto* node_proto,
                                    const char* attr_name,
                                    const Block* sub_block) {
    auto* attr = node_proto->add_attribute();
    attr->set_name(attr_name);
    attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
    EncodeBlock(attr->mutable_g(), sub_block);
  };

  for (auto node : block->nodes()) {
    // None nodes are used to implement optional inputs. One
    // way to "not provide" an optional input is to create an
    // Undefined node, and pass its output as that input.
    if (node->mustBeNone() && !is_raw_export) {
      continue;
    }

    auto* p_n = graph_proto->add_node();
    if (!strip_doc_) {
      p_n->set_doc_string(node->sourceRange().str());
    }

    // A use of a skipped None node is written as an empty input name.
    for (auto input : node->inputs()) {
      const bool omitted = input->node()->mustBeNone() && !is_raw_export;
      p_n->add_input(omitted ? std::string() : input->debugName());
    }
    for (auto output : node->outputs()) {
      p_n->add_output(output->debugName());
      EncodeIntermediateValueInfo(graph_proto, output);
    }

    const auto kind = node->kind();
    if (!kind.is_onnx()) {
      const std::string domain = kind.domainString();
      p_n->set_domain(domain);
      domains_.insert(domain);
    }

    if (is_raw_export) {
      AT_ASSERT(!node->kind().is_onnx());
    } else if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) {
      AT_ASSERT(
          !node->kind().is_aten() && !node->kind().is_prim() &&
          !node->kind().is_attr());
    }

    p_n->set_op_type(kind.toUnqualString());
    for (auto attr_name : node->attributeNames()) {
      AddAttribute(p_n, node, attr_name);
    }

    if (is_raw_export && !node->blocks().empty()) {
      auto* blocks_attr = p_n->add_attribute();
      blocks_attr->set_name("_blocks");
      blocks_attr->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
      for (auto sub_block : node->blocks()) {
        // NOTE(review): sub-blocks receive this block's initializers, so they
        // are encoded again inside every sub-graph. Confirm this is intended.
        EncodeBlock(blocks_attr->add_graphs(), sub_block, initializers);
      }
    }

    if (kind == ::c10::onnx::Loop) {
      AT_ASSERT(node->blocks().size() == 1);
      add_subgraph_attribute(p_n, "body", node->blocks()[0]);
    }

    if (kind == ::c10::onnx::If) {
      AT_ASSERT(node->blocks().size() == 2);
      add_subgraph_attribute(p_n, "then_branch", node->blocks()[0]);
      add_subgraph_attribute(p_n, "else_branch", node->blocks()[1]);
    }
  }

  AT_ASSERT(block->inputs().size() >= initializers.size());
  for (const auto& name_and_tensor : initializers) {
    auto* initializer = graph_proto->add_initializer();
    initializer->set_name(name_and_tensor.first);
    EncodeTensor(initializer, name_and_tensor.second, name_and_tensor.first);
  }
}
|
|
|
|
// Appends one AttributeProto for attribute `name` of `node` to `node_proto`.
// The attribute's JIT AttributeKind is mapped to the corresponding ONNX
// attribute type. Tensor values go through EncodeTensor (no external ref) and
// graph values through EncodeGraph. Throws std::runtime_error for kinds with
// no mapping.
void EncoderBase::AddAttribute(
    onnx::NodeProto* node_proto,
    const jit::Node* node,
    const jit::Symbol name) {
  auto* attr = node_proto->add_attribute();
  AT_ASSERT(name.is_attr());
  attr->set_name(name.toUnqualString());

  switch (node->kindOf(name)) {
    case AttributeKind::f: {
      attr->set_f(node->f(name));
      attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
      break;
    }
    case AttributeKind::fs: {
      attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
      for (const auto& value : node->fs(name)) {
        attr->add_floats(value);
      }
      break;
    }
    case AttributeKind::i: {
      attr->set_type(onnx::AttributeProto_AttributeType_INT);
      attr->set_i(node->i(name));
      break;
    }
    case AttributeKind::is: {
      attr->set_type(onnx::AttributeProto_AttributeType_INTS);
      for (const auto& value : node->is(name)) {
        attr->add_ints(value);
      }
      break;
    }
    case AttributeKind::s: {
      attr->set_type(onnx::AttributeProto_AttributeType_STRING);
      attr->set_s(node->s(name));
      break;
    }
    case AttributeKind::ss: {
      attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
      for (const auto& value : node->ss(name)) {
        attr->add_strings(value);
      }
      break;
    }
    case AttributeKind::t: {
      attr->set_type(onnx::AttributeProto_AttributeType_TENSOR);
      EncodeTensor(attr->mutable_t(), node->t(name));
      break;
    }
    case AttributeKind::ts: {
      attr->set_type(onnx::AttributeProto_AttributeType_TENSORS);
      for (const auto& tensor : node->ts(name)) {
        EncodeTensor(attr->add_tensors(), tensor);
      }
      break;
    }
    case AttributeKind::g: {
      attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
      EncodeGraph(attr->mutable_g(), node->g(name));
      break;
    }
    case AttributeKind::gs: {
      attr->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
      for (const auto& subgraph : node->gs(name)) {
        EncodeGraph(attr->add_graphs(), subgraph);
      }
      break;
    }
    default:
      throw std::runtime_error("unexpected attribute kind");
  }
}
|
|
|
|
class GraphEncoder : public EncoderBase {
|
|
public:
|
|
GraphEncoder(
|
|
const std::shared_ptr<Graph>& graph,
|
|
int64_t onnx_opset_version,
|
|
onnx_torch::OperatorExportTypes operator_export_type,
|
|
const std::map<std::string, at::Tensor>& initializers,
|
|
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes,
|
|
bool defer_weight_export,
|
|
bool strip_doc,
|
|
bool keep_initializers_as_inputs);
|
|
|
|
RawDataExportMap get_raw_data_export_map() {
|
|
return raw_data_export_map_;
|
|
}
|
|
|
|
private:
|
|
void EncodeTensor(
|
|
onnx::TensorProto* tensor_proto,
|
|
const at::Tensor& tensor,
|
|
const c10::optional<std::string> external_ref = {}) override;
|
|
|
|
RawDataExportMap raw_data_export_map_;
|
|
bool defer_weight_export_;
|
|
};
|
|
|
|
// Builds the complete ModelProto for `graph`:
//  1. validates the graph unless the export type is RAW;
//  2. adds the default-domain opset import at `onnx_opset_version`;
//  3. encodes the graph (and initializers) into model_proto_.graph;
//  4. adds a version-0 opset import for every non-ONNX domain seen in step 3.
GraphEncoder::GraphEncoder(
    const std::shared_ptr<Graph>& graph,
    int64_t onnx_opset_version,
    onnx_torch::OperatorExportTypes operator_export_type,
    const std::map<std::string, at::Tensor>& initializers,
    const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes,
    bool defer_weight_export,
    bool strip_doc,
    bool keep_initializers_as_inputs)
    : EncoderBase(operator_export_type, strip_doc),
      defer_weight_export_(defer_weight_export) {
  const bool is_raw =
      operator_export_type == onnx_torch::OperatorExportTypes::RAW;
  if (!is_raw) {
    validateGraph(graph, operator_export_type);
  }

  // This is the version of ONNX operator set we are targeting.
  model_proto_.add_opset_import()->set_version(onnx_opset_version);

  EncodeGraph(
      model_proto_.mutable_graph(),
      graph,
      initializers,
      dynamic_axes,
      keep_initializers_as_inputs);

  for (const auto& domain : domains_) {
    auto* domain_import = model_proto_.add_opset_import();
    domain_import->set_domain(domain);
    domain_import->set_version(0);
  }
}
|
|
|
|
// Writes `tensor`'s dims and ONNX dtype into `tensor_proto`, then its data:
//  * if weights are deferred and `external_ref` is set, the contiguous CPU
//    tensor is stored in raw_data_export_map_ under `external_ref` (which
//    must equal the proto's name and not already be present). The proto's
//    raw_data is set to the placeholder "__EXTERNAL";
//  * otherwise the tensor bytes are copied into raw_data.
void GraphEncoder::EncodeTensor(
    onnx::TensorProto* tensor_proto,
    const at::Tensor& tensor,
    const c10::optional<std::string> external_ref) {
  for (const auto dim : tensor.sizes()) {
    tensor_proto->add_dims(dim);
  }
  tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.scalar_type()));

  // contiguous() is called before cpu(): the original note says CPU's
  // HalfTensor did not support contiguous(), so the order matters.
  auto t = tensor.contiguous().cpu();

  const bool store_externally = defer_weight_export_ && external_ref;
  if (!store_externally) {
    AT_ASSERT(t.is_contiguous());
    tensor_proto->set_raw_data(std::string(
        static_cast<char*>(t.data_ptr()), t.element_size() * t.numel()));
    return;
  }

  // For now, the tensor's name doubles as the external lookup key, which
  // avoids ONNX protobuf changes.
  AT_ASSERT(external_ref.value() == tensor_proto->name());
  AT_ASSERT(raw_data_export_map_.count(external_ref.value()) == 0);
  raw_data_export_map_[external_ref.value()] = t;
  tensor_proto->set_raw_data("__EXTERNAL");
}
|
|
|
|
class ScriptModuleSerializer {
|
|
public:
|
|
ScriptModuleSerializer(const std::string& filename)
|
|
: writer_(filename.c_str()) {}
|
|
|
|
ScriptModuleSerializer(std::ostream* ofs)
|
|
: ofs_(), writer_(ofs) {}
|
|
|
|
void serialize(
|
|
const script::Module& module,
|
|
const script::ExtraFilesMap& extra_files,
|
|
bool bytecode_format) {
|
|
C10_LOG_API_USAGE_ONCE("torch.script.save");
|
|
writeExtraFiles(module, extra_files);
|
|
// Serialize all code info.
|
|
writeCode(module.type());
|
|
// The tensor constants from the code are written to a separate archive
|
|
// so loading the code does not depend on loading the data
|
|
std::vector<IValue> ivalue_constants(
|
|
constant_table_.begin(), constant_table_.end());
|
|
writeArchive("constants", c10::ivalue::Tuple::create(ivalue_constants));
|
|
if (bytecode_format) {
|
|
writeByteCode(module);
|
|
}
|
|
// finally we serialize the model
|
|
writeArchive("data", module.module_object());
|
|
}
|
|
|
|
private:
|
|
void writeArchive(const std::string& archive_name, const IValue& value) {
|
|
std::vector<char> data;
|
|
Pickler data_pickle(
|
|
[&](const char* buf, size_t size) {
|
|
data.insert(data.end(), buf, buf + size);
|
|
},
|
|
nullptr);
|
|
data_pickle.protocol();
|
|
data_pickle.pushIValue(value);
|
|
data_pickle.stop();
|
|
size_t i = 0;
|
|
for (const auto& td : data_pickle.tensorData()) {
|
|
std::stringstream fname;
|
|
fname << archive_name << "/" << i++;
|
|
writer_.writeRecord(fname.str(), td.data(), td.sizeInBytes());
|
|
}
|
|
std::stringstream fname;
|
|
fname << archive_name << ".pkl";
|
|
writer_.writeRecord(fname.str(), data.data(), data.size());
|
|
}
|
|
|
|
void writeExtraFiles(
|
|
const script::Module& module,
|
|
const script::ExtraFilesMap& extra_files) {
|
|
// Write out extra files.
|
|
for (const auto& kv : extra_files) {
|
|
const std::string key = "extra/" + kv.first;
|
|
writer_.writeRecord(key, kv.second.data(), kv.second.size());
|
|
}
|
|
auto hook = GetExtraFilesHook();
|
|
if (hook) {
|
|
script::ExtraFilesMap hook_files = hook(module);
|
|
for (const auto& kv : hook_files) {
|
|
const std::string key = "extra/" + kv.first;
|
|
writer_.writeRecord(key, kv.second.data(), kv.second.size());
|
|
}
|
|
}
|
|
}
|
|
|
|
void writeCode(const at::NamedTypePtr& root_type) {
|
|
convertNamedType(root_type);
|
|
static const std::string opset_string =
|
|
c10::str("op_version_set = ", CURRENT_OP_VERSION_SET, "\n");
|
|
|
|
// Mapping of filename => src. We need this because multiple clases may go
|
|
// in the same file (e.g. foo.bar.Baz and foo.bar.Qux)
|
|
|
|
// Aggregate classes into files by their qualified names
|
|
std::unordered_map<std::string, std::ostringstream> fileToSrc;
|
|
std::unordered_map<std::string, SourceRangeRecords> fileToDebug;
|
|
for (auto& item : converted_types_) {
|
|
const auto& converted_type = item.key();
|
|
auto& type_info = item.value();
|
|
|
|
// For the type, foo.bar.Baz
|
|
const std::string filename =
|
|
qualifierToArchivePath(converted_type->name()->prefix(), "code/");
|
|
// End state: filename is "foo/bar.py", in which we will define a class
|
|
// named Baz
|
|
auto& stream = fileToSrc[filename];
|
|
|
|
// Adjust the SourceRange offsets since we are concatenating multiple
|
|
// classes to a single file.
|
|
// Need to add opset_string size as an offset because we will be
|
|
// prepending it to the file. (We should remove this opset_version string
|
|
// at some point and stash it in the model.json)
|
|
const auto offset =
|
|
static_cast<size_t>(stream.tellp()) + opset_string.size();
|
|
for (auto& sourceRange : type_info.debug_info) {
|
|
sourceRange.bytes += offset;
|
|
}
|
|
|
|
auto& debugInfo = fileToDebug[filename];
|
|
debugInfo.insert(
|
|
debugInfo.end(),
|
|
type_info.debug_info.begin(),
|
|
type_info.debug_info.end());
|
|
fileToSrc[filename] << type_info.source;
|
|
}
|
|
|
|
for (const auto& item : fileToSrc) {
|
|
const auto& filename = item.first;
|
|
const auto src = item.second.str();
|
|
const auto& debugInfo = fileToDebug.at(filename);
|
|
|
|
// Prepend the opset_version string
|
|
const auto lib_str = c10::str(opset_string, src);
|
|
writer_.writeRecord(
|
|
filename, lib_str.c_str(), lib_str.size(), /*compress=*/true);
|
|
|
|
// Write out the debug information
|
|
std::stringstream debugFilename;
|
|
debugFilename << filename << ".debug_pkl";
|
|
SourceRangePickler source_range_pickler;
|
|
const auto& range_data = source_range_pickler.pickle(debugInfo);
|
|
|
|
writer_.writeRecord(
|
|
debugFilename.str(),
|
|
range_data.data(),
|
|
range_data.size(),
|
|
/*compress=*/true);
|
|
}
|
|
}
|
|
|
|
void writeByteCode(const script::Module& module) {
|
|
auto methods = module.get_methods();
|
|
std::vector<c10::IValue> elements;
|
|
for (const auto& method : methods) {
|
|
const auto& func = method.function();
|
|
torch::jit::Code code(func.graph());
|
|
|
|
// instructions
|
|
std::vector<IValue> inss;
|
|
for (const auto& ins : code.instructions()) {
|
|
TORCH_CHECK(isOpSupportedInMobile(ins.op), toString(ins.op),
|
|
" is not supported in mobile module.");
|
|
std::vector<IValue> insv{toString(ins.op), ins.X, ins.N};
|
|
inss.emplace_back(c10::ivalue::Tuple::create(std::move(insv)));
|
|
}
|
|
auto instructions = c10::ivalue::Tuple::create(std::move(inss));
|
|
auto named_ins = c10::ivalue::Tuple::create({"instructions", instructions});
|
|
|
|
// operators
|
|
std::vector<IValue> opss;
|
|
for (const auto& opname : code.opname_table()) {
|
|
opss.emplace_back(c10::ivalue::Tuple::create({opname.name, opname.overload_name}));
|
|
}
|
|
auto operators = c10::ivalue::Tuple::create(std::move(opss));
|
|
auto named_ops = c10::ivalue::Tuple::create({"operators", operators});
|
|
|
|
// constants
|
|
auto constants = c10::ivalue::Tuple::create(code.constant_table());
|
|
auto named_consts = c10::ivalue::Tuple::create({"constants", constants});
|
|
|
|
// since the register location is embedded into the bytecode, pass the register size
|
|
auto named_regsize = c10::ivalue::Tuple::create({"register_size",
|
|
static_cast<int>(code.register_size())});
|
|
|
|
auto element = c10::ivalue::Tuple::create({named_ins, named_ops, named_consts, named_regsize});
|
|
elements.push_back(c10::ivalue::Tuple::create({func.qualname().qualifiedName(), element}));
|
|
}
|
|
auto telements = c10::ivalue::Tuple::create(std::move(elements));
|
|
writeArchive("bytecode", telements);
|
|
}
|
|
|
|
void convertNamedType(const c10::NamedTypePtr& class_type) {
|
|
if (converted_types_.contains(class_type)) {
|
|
return;
|
|
}
|
|
|
|
std::vector<c10::NamedTypePtr> class_deps;
|
|
std::ostringstream source_stream;
|
|
SourceRangeRecords source_ranges;
|
|
PythonPrint(
|
|
source_stream,
|
|
source_ranges,
|
|
class_type,
|
|
constant_table_,
|
|
class_deps,
|
|
/*enforce_importable=*/true);
|
|
|
|
for (const auto& c : class_deps) {
|
|
if (c == class_type) {
|
|
// Don't re-process this class and enter an infinite loop. We need this
|
|
// because we insert to converted_classes_ post-traversal, so the
|
|
// current class isn't in there yet.
|
|
continue;
|
|
}
|
|
convertNamedType(c);
|
|
}
|
|
// Insert *after* we've traversed the dependencies. This ensures that any
|
|
// given class will appear after its dependencies in the order.
|
|
TypeInfo info{source_stream.str(), std::move(source_ranges)};
|
|
converted_types_.insert(class_type, std::move(info));
|
|
}
|
|
|
|
std::ofstream ofs_;
|
|
caffe2::serialize::PyTorchStreamWriter writer_;
|
|
std::vector<at::Tensor> constant_table_;
|
|
|
|
// all deps used by this module hierarchy
|
|
struct TypeInfo {
|
|
std::string source;
|
|
SourceRangeRecords debug_info;
|
|
};
|
|
OrderedDict<c10::NamedTypePtr, TypeInfo> converted_types_;
|
|
bool bytecode_format_;
|
|
};
|
|
|
|
// Pretty printing for ONNX

// One nesting level is rendered as `indent_multiplier` copies of `indent_char`.
constexpr char indent_char = ' ';
constexpr size_t indent_multiplier = 2;

// Returns the leading whitespace for nesting depth `indent`.
std::string idt(size_t indent) {
  const size_t width = indent_multiplier * indent;
  std::string padding;
  padding.assign(width, indent_char);
  return padding;
}

// Returns a newline followed by the whitespace for nesting depth `indent`.
std::string nlidt(size_t indent) {
  std::string line_start(1, '\n');
  line_start.append(indent * indent_multiplier, indent_char);
  return line_start;
}
|
|
|
|
// Prints only the dimensions of `tensor`, as "TensorProto shape: [d0 d1 ...]".
void dump(const onnx::TensorProto& tensor, std::ostream& stream) {
  stream << "TensorProto shape: [";
  const int rank = tensor.dims_size();
  for (int i = 0; i < rank; ++i) {
    if (i > 0) {
      stream << " ";
    }
    stream << tensor.dims(i);
  }
  stream << "]";
}
|
|
|
|
// Prints the dimensions of `shape` separated by single spaces. A dimension
// without a concrete dim_value (e.g. one carrying a dim_param) prints as "?".
void dump(const onnx::TensorShapeProto& shape, std::ostream& stream) {
  const int rank = shape.dim_size();
  for (int i = 0; i < rank; ++i) {
    if (i > 0) {
      stream << " ";
    }
    const auto& dim = shape.dim(i);
    if (!dim.has_dim_value()) {
      stream << "?";
      continue;
    }
    stream << dim.dim_value();
  }
}
|
|
|
|
// Prints a tensor type as "Tensor dims: " followed by its shape.
void dump(const onnx::TypeProto_Tensor& tensor_type, std::ostream& stream) {
  const auto& tensor_shape = tensor_type.shape();
  stream << "Tensor dims: ";
  dump(tensor_shape, stream);
}
|
|
|
|
// Prints a TypeProto by printing its tensor_type; other type kinds are not
// distinguished.
void dump(const onnx::TypeProto& type, std::ostream& stream) {
  const auto& as_tensor = type.tensor_type();
  dump(as_tensor, stream);
}
|
|
|
|
// Prints a value as {name: "<name>", type:<type>}.
void dump(const onnx::ValueInfoProto& value_info, std::ostream& stream) {
  const auto& value_name = value_info.name();
  stream << "{name: \"" << value_name << "\", type:";
  dump(value_info.type(), stream);
  stream << "}";
}
|
|
|
|
void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent);
|
|
|
|
void dump(
|
|
const onnx::AttributeProto& attr,
|
|
std::ostream& stream,
|
|
size_t indent) {
|
|
stream << "{ name: '" << attr.name() << "', type: ";
|
|
if (attr.has_f()) {
|
|
stream << "float, value: " << attr.f();
|
|
} else if (attr.has_i()) {
|
|
stream << "int, value: " << attr.i();
|
|
} else if (attr.has_s()) {
|
|
stream << "string, value: '" << attr.s() << "'";
|
|
} else if (attr.has_g()) {
|
|
stream << "graph, value:\n";
|
|
dump(attr.g(), stream, indent + 1);
|
|
stream << nlidt(indent);
|
|
} else if (attr.has_t()) {
|
|
stream << "tensor, value:";
|
|
dump(attr.t(), stream);
|
|
} else if (attr.floats_size()) {
|
|
stream << "floats, values: [";
|
|
for (int i = 0; i < attr.floats_size(); ++i)
|
|
stream << attr.floats(i) << (i == attr.floats_size() - 1 ? "" : " ");
|
|
stream << "]";
|
|
} else if (attr.ints_size()) {
|
|
stream << "ints, values: [";
|
|
for (int i = 0; i < attr.ints_size(); ++i)
|
|
stream << attr.ints(i) << (i == attr.ints_size() - 1 ? "" : " ");
|
|
stream << "]";
|
|
} else if (attr.strings_size()) {
|
|
stream << "strings, values: [";
|
|
for (int i = 0; i < attr.strings_size(); ++i)
|
|
stream << "'" << attr.strings(i) << "'"
|
|
<< (i == attr.strings_size() - 1 ? "" : " ");
|
|
stream << "]";
|
|
} else if (attr.tensors_size()) {
|
|
stream << "tensors, values: [";
|
|
for (auto& t : attr.tensors()) {
|
|
dump(t, stream);
|
|
}
|
|
stream << "]";
|
|
} else if (attr.graphs_size()) {
|
|
stream << "graphs, values: [";
|
|
for (auto& g : attr.graphs()) {
|
|
dump(g, stream, indent + 1);
|
|
}
|
|
stream << "]";
|
|
} else {
|
|
stream << "UNKNOWN";
|
|
}
|
|
stream << "}";
|
|
}
|
|
|
|
// Prints a node as Node {type: "<op>", inputs: [...], outputs: [...],
// attributes: [...]}. List items are comma-separated, and attributes are
// printed one level deeper than `indent`.
void dump(const onnx::NodeProto& node, std::ostream& stream, size_t indent) {
  // Emits `count` items via `emit(i)`, separated by commas.
  auto comma_separated = [&stream](int count, const auto& emit) {
    for (int i = 0; i < count; ++i) {
      if (i > 0) {
        stream << ",";
      }
      emit(i);
    }
  };

  stream << "Node {type: \"" << node.op_type() << "\", inputs: [";
  comma_separated(node.input_size(), [&](int i) { stream << node.input(i); });
  stream << "], outputs: [";
  comma_separated(node.output_size(), [&](int i) { stream << node.output(i); });
  stream << "], attributes: [";
  comma_separated(node.attribute_size(), [&](int i) {
    dump(node.attribute(i), stream, indent + 1);
  });
  stream << "]}";
}
|
|
|
|
// Pretty-prints a GraphProto over several lines. The name, inputs, outputs and
// initializers each go on their own line at `indent + 1`. Every node is then
// printed on its own line at `indent + 2`. The output ends with a newline after
// the closing brace.
void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent) {
  stream << idt(indent) << "GraphProto {";
  stream << nlidt(indent + 1) << "name: \"" << graph.name() << "\"";

  stream << nlidt(indent + 1) << "inputs: [";
  bool first = true;
  for (const auto& input : graph.input()) {
    if (!first) {
      stream << ",";
    }
    dump(input, stream);
    first = false;
  }
  stream << "]";

  stream << nlidt(indent + 1) << "outputs: [";
  first = true;
  for (const auto& output : graph.output()) {
    if (!first) {
      stream << ",";
    }
    dump(output, stream);
    first = false;
  }
  stream << "]";

  stream << nlidt(indent + 1) << "initializers: [";
  first = true;
  for (const auto& initializer : graph.initializer()) {
    if (!first) {
      stream << ",";
    }
    dump(initializer, stream);
    first = false;
  }
  stream << "]";

  // The first node's line break is emitted unconditionally, even when the
  // graph has no nodes.
  stream << nlidt(indent + 1) << "nodes: [" << nlidt(indent + 2);
  first = true;
  for (const auto& graph_node : graph.node()) {
    if (!first) {
      stream << "," << nlidt(indent + 2);
    }
    dump(graph_node, stream, indent + 2);
    first = false;
  }
  stream << nlidt(indent + 1) << "]\n" << idt(indent) << "}\n";
}
|
|
|
|
void dump(
|
|
const onnx::OperatorSetIdProto& operator_set_id,
|
|
std::ostream& stream) {
|
|
stream << "OperatorSetIdProto { domain: " << operator_set_id.domain() << "}";
|
|
}
|
|
|
|
// Pretty-prints a ModelProto. Output order:
// - producer_name, domain and doc_string, one per line at `indent + 1`;
// - the graph (if present), dumped at `indent + 2`;
// - the opset imports (if any), as a comma-separated list on one line;
// - a closing brace at `indent`, followed by a newline.
void dump(const onnx::ModelProto& model, std::ostream& stream, size_t indent) {
  stream << idt(indent) << "ModelProto {" << nlidt(indent + 1)
         << "producer_name: \"" << model.producer_name() << "\""
         << nlidt(indent + 1) << "domain: \"" << model.domain() << "\""
         << nlidt(indent + 1) << "doc_string: \"" << model.doc_string() << "\"";
  if (model.has_graph()) {
    stream << nlidt(indent + 1) << "graph:\n";
    // The GraphProto dump terminates its own last line with "\n".
    dump(model.graph(), stream, indent + 2);
  } else {
    // Nothing has ended the doc_string line yet. Finish it so the opset list
    // and the closing brace start on their own lines.
    stream << "\n";
  }
  if (model.opset_import_size()) {
    stream << idt(indent + 1) << "opset_import: [";
    for (int i = 0; i < model.opset_import_size(); ++i) {
      dump(model.opset_import(i), stream);
      // Comma-separate entries, like the other list dumps in this file.
      // Previously, multiple imports were concatenated with no separator.
      stream << (i == model.opset_import_size() - 1 ? "" : ",");
    }
    stream << "],\n";
  }
  stream << idt(indent) << "}\n";
}
|
|
|
|
// Renders `model` into a std::string using the ModelProto dump() overload,
// starting at indentation level 0.
std::string prettyPrint(const onnx::ModelProto& model) {
  std::ostringstream text;
  dump(model, text, 0);
  return text.str();
}
|
|
|
|
} // namespace
|
|
|
|
// Installs `hook` by assigning it to the slot returned by GetExtraFilesHook().
// The slot is presumably a function-local static; confirm at its definition.
// `hook` is a by-value sink parameter, so it is moved into the slot rather than
// copied a second time. ExportModuleExtraFilesHook is presumably a
// std::function; confirm in export.h.
void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) {
  GetExtraFilesHook() = std::move(hook);
}
|
|
|
|
std::string pretty_print_onnx(
|
|
const std::shared_ptr<Graph>& graph,
|
|
const std::map<std::string, at::Tensor>& initializers,
|
|
int64_t onnx_opset_version,
|
|
bool defer_weight_export,
|
|
::torch::onnx::OperatorExportTypes operator_export_type,
|
|
bool google_printer,
|
|
bool keep_initializers_as_inputs) {
|
|
auto graph_encoder = GraphEncoder(
|
|
graph,
|
|
onnx_opset_version,
|
|
operator_export_type,
|
|
initializers,
|
|
std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>{},
|
|
defer_weight_export,
|
|
true,
|
|
keep_initializers_as_inputs);
|
|
if (google_printer) {
|
|
return graph_encoder.get_model_proto().DebugString();
|
|
}
|
|
return prettyPrint(graph_encoder.get_model_proto());
|
|
}
|
|
|
|
// export_raw_ir will export IR ops without turning them into ONNX ops.
|
|
// The output will use the ONNX protobuf format, but the ops will not
|
|
// conform to the ONNX op specification. Thus, the output will not
|
|
// be interpretable by a ONNX-compatible framework. However, PyTorch or
|
|
// libtorch will be able to import the IR and play it back.
|
|
std::tuple<std::string, RawDataExportMap> export_onnx(
|
|
const std::shared_ptr<Graph>& graph,
|
|
const std::map<std::string, at::Tensor>& initializers,
|
|
int64_t onnx_opset_version,
|
|
const std::unordered_map<std::string, std::unordered_map<std::int64_t, std::string>>& dynamic_axes,
|
|
bool defer_weight_export,
|
|
::torch::onnx::OperatorExportTypes operator_export_type,
|
|
bool strip_doc_string,
|
|
bool keep_initializers_as_inputs) {
|
|
auto graph_encoder = GraphEncoder(
|
|
graph,
|
|
onnx_opset_version,
|
|
operator_export_type,
|
|
initializers,
|
|
dynamic_axes,
|
|
defer_weight_export,
|
|
strip_doc_string,
|
|
keep_initializers_as_inputs);
|
|
return std::make_tuple(
|
|
graph_encoder.get_model_proto().SerializeAsString(),
|
|
graph_encoder.get_raw_data_export_map());
|
|
}
|
|
|
|
|
|
void ExportModule(
|
|
const script::Module& module,
|
|
std::ostream& out,
|
|
const script::ExtraFilesMap& extra_files,
|
|
bool bytecode_format) {
|
|
ScriptModuleSerializer serializer(&out);
|
|
serializer.serialize(module, extra_files, bytecode_format);
|
|
}
|
|
|
|
void ExportModule(
|
|
const script::Module& module,
|
|
const std::string& filename,
|
|
const script::ExtraFilesMap& extra_files,
|
|
bool bytecode_format) {
|
|
ScriptModuleSerializer serializer(filename);
|
|
serializer.serialize(module, extra_files, bytecode_format);
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|