#include "torch/csrc/jit/export.h" #include "torch/csrc/jit/serialization.h" #include "torch/csrc/autograd/symbolic.h" #include "onnx/onnx_pb.h" #include "torch/csrc/onnx/onnx.h" #include "torch/csrc/utils/functional.h" #include #include "torch/csrc/jit/passes/dead_code_elimination.h" #include #include #include #include #include #include #include namespace torch { namespace jit { namespace { namespace onnx_torch = ::torch::onnx; namespace onnx = ::ONNX_NAMESPACE; std::string getExportableSchemaStringForMethod(const script::Method& method) { const auto& schema = method.getSchema(); for (const auto& argument : schema.arguments) { AT_CHECK( !argument.default_value, "Default arguments in script graphs may currently not be exported."); } std::ostringstream stream; stream << schema; return stream.str(); } std::string getNodeStackTraceString(const Node* n) { std::stringstream ss; if (n->getSourceLocation()) { n->getSourceLocation()->highlight(ss); } else { ss << ""; } return ss.str(); } void validateBlock(Block *b, onnx_torch::OperatorExportTypes operator_export_type) { for (auto node : b->nodes()) { for (Block *sub_block : node->blocks()) { validateBlock(sub_block, operator_export_type); } // Macro'ed so we get a marginally better line number on failed export #define FAIL_EXPORT(name) \ throw std::runtime_error(std::string("ONNX export failed: ") + name + "\n\nGraph we tried to export:\n" + b->owningGraph()->toString()); IR_IF(node, PythonOp) auto py_node = static_cast(value); FAIL_EXPORT( "Couldn't export Python operator " + py_node->name() + "\n\nDefined at:\n" + getNodeStackTraceString(node)) IR_ELSE() // Special error messages for certain types of operators if (node->kind() == aten::expand) { if (operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK) { WithInsertPoint guard(node); auto* new_node = b->owningGraph()->insertNode( b->owningGraph()->create(Symbol(::torch::jit::onnx::ATen), node->inputs(), node->outputs().size())); for (size_t i = 0; i < node->outputs().size(); ++i) { node->output(i)->replaceAllUsesWith(new_node->output(i)); } new_node->s_(Symbol::fromQualString("attr::operator"), "expand"); } else { FAIL_EXPORT( "Could not export a broadcasted operation; ONNX likely does not support this form of broadcasting.\n\nBroadcast occurred at:\n" + getNodeStackTraceString(node)); } } if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) { FAIL_EXPORT( "Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" + getNodeStackTraceString(node)); } bool is_aten_fallback = operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK; if (!node->kind().is_onnx() && !is_aten_fallback && node->kind() != prim::Undefined) { FAIL_EXPORT( "Couldn't export operator " + node->kind().toDisplayString() + "\n\nDefined at:\n" + getNodeStackTraceString(node)); } IR_END() #undef FAIL_EXPORT } } void validateGraph(const std::shared_ptr& graph, onnx_torch::OperatorExportTypes operator_export_type) { validateBlock(graph->block(), operator_export_type); EliminateDeadCode(graph); } class EncoderBase { public: EncoderBase(onnx_torch::OperatorExportTypes operator_export_type, bool strip_doc); onnx::ModelProto get_model_proto() { return model_proto_; } protected: void EncodeGraph(onnx::GraphProto *graph_proto, const std::shared_ptr &graph, const std::vector &initializers = {}); void EncodeBlock(onnx::GraphProto *graph_proto, const Block *block, const std::vector 
&initializers = {}); virtual void EncodeTensor(onnx::TensorProto *tensor_proto, const at::Tensor &tensor, const at::optional external_ref = {}) = 0; virtual void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto, const Value* n) {}; virtual void EncodeValueInfo(onnx::GraphProto *graph_proto, onnx::ValueInfoProto* v, const Value* n); void AddAttribute(onnx::NodeProto *node_proto, const jit::Node *node, const jit::Symbol name); onnx::ModelProto model_proto_; size_t num_blocks_; onnx_torch::OperatorExportTypes operator_export_type_; bool strip_doc_; }; onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) { switch(at_type) { case at::kDouble: return onnx::TensorProto_DataType_DOUBLE; case at::kFloat: return onnx::TensorProto_DataType_FLOAT; case at::kHalf: return onnx::TensorProto_DataType_FLOAT16; case at::kByte: return onnx::TensorProto_DataType_UINT8; case at::kChar: return onnx::TensorProto_DataType_INT8; case at::kShort: return onnx::TensorProto_DataType_INT16; case at::kInt: return onnx::TensorProto_DataType_INT32; case at::kLong: return onnx::TensorProto_DataType_INT64; default: AT_ERROR("unexpected tensor scalar type"); } } EncoderBase::EncoderBase(onnx_torch::OperatorExportTypes operator_export_type, bool strip_doc) : num_blocks_(0), operator_export_type_(operator_export_type), strip_doc_(strip_doc) { model_proto_.set_producer_name("pytorch"); model_proto_.set_ir_version(onnx::IR_VERSION); model_proto_.set_producer_version("0.4"); } void EncoderBase::EncodeValueInfo( onnx::GraphProto *graph_proto, onnx::ValueInfoProto* v, const Value* n) { v->set_name(n->uniqueName()); onnx::TypeProto* t = v->mutable_type(); onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type(); onnx::TensorShapeProto* shape = tensor_type->mutable_shape(); if (CompleteTensorTypePtr node_type = n->type()->cast()) { const std::vector& sizes = node_type->sizes(); for (size_t i = 0; i < sizes.size(); i++) { shape->add_dim(); shape->mutable_dim(i)->set_dim_value(sizes[i]); } tensor_type->set_elem_type(ATenTypeToOnnxType(node_type->scalarType())); } else { tensor_type->set_elem_type(onnx::TensorProto_DataType_UNDEFINED); } } void EncoderBase::EncodeGraph( onnx::GraphProto *graph_proto, const std::shared_ptr &graph, const std::vector &initializers) { EncodeBlock(graph_proto, graph->block(), initializers); } void EncoderBase::EncodeBlock( onnx::GraphProto *graph_proto, const Block *block, const std::vector &initializers) { JIT_ASSERT(graph_proto != nullptr); std::string block_name = "torch-jit-export"; if (num_blocks_) { block_name += std::to_string(num_blocks_); } num_blocks_++; graph_proto->set_name(block_name); for (auto input : block->inputs()) { onnx::ValueInfoProto* v = graph_proto->add_input(); EncodeValueInfo(graph_proto, v, input); } for (auto output : block->outputs()) { onnx::ValueInfoProto* v = graph_proto->add_output(); EncodeValueInfo(graph_proto, v, output); } for (auto node : block->nodes()) { bool is_raw_export = operator_export_type_ == onnx_torch::OperatorExportTypes::RAW; if (node->kind() == prim::Undefined && !is_raw_export) { // Undefined nodes are used to implement optional inputs. One // way to "not provide" an optional input is to create an // Undefined node, and pass its output as that input. 
      continue;
    }
    auto p_n = graph_proto->add_node();
    if (node->getSourceLocation() && !strip_doc_) {
      std::stringstream ss;
      node->getSourceLocation()->highlight(ss);
      p_n->set_doc_string(ss.str());
    }
    for (auto input : node->inputs()) {
      if (input->node()->kind() == prim::Undefined && !is_raw_export) {
        p_n->add_input("");
      } else {
        p_n->add_input(input->uniqueName());
      }
    }
    for (auto output : node->outputs()) {
      p_n->add_output(output->uniqueName());
      EncodeIntermediateValueInfo(graph_proto, output);
    }
    if (is_raw_export) {
      JIT_ASSERT(!node->kind().is_onnx());
      p_n->set_domain(node->kind().domainString());
    } else if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) {
      JIT_ASSERT(node->kind().is_onnx());
    }
    p_n->set_op_type(node->kind().toUnqualString());
    for (auto attr_name : node->attributeNames()) {
      AddAttribute(p_n, node, attr_name);
    }
    if (is_raw_export && node->blocks().size() > 0) {
      auto blocks = p_n->add_attribute();
      blocks->set_name("_blocks");
      blocks->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
      for (auto block : node->blocks()) {
        auto graph = blocks->add_graphs();
        EncodeBlock(graph, block, initializers);
      }
    }
    if (node->kind() == torch::jit::onnx::Loop) {
      JIT_ASSERT(node->blocks().size() == 1);

      auto body = p_n->add_attribute();
      body->set_name("body");
      body->set_type(onnx::AttributeProto_AttributeType_GRAPH);
      auto g = body->mutable_g();
      EncodeBlock(g, node->blocks()[0]);
    }
    if (node->kind() == torch::jit::onnx::If) {
      JIT_ASSERT(node->blocks().size() == 2);

      auto true_branch = p_n->add_attribute();
      true_branch->set_name("then_branch");
      true_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
      auto true_g = true_branch->mutable_g();
      EncodeBlock(true_g, node->blocks()[0]);

      auto false_branch = p_n->add_attribute();
      false_branch->set_name("else_branch");
      false_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
      auto false_g = false_branch->mutable_g();
      EncodeBlock(false_g, node->blocks()[1]);
    }
  }
  auto num_initializers = initializers.size();
  JIT_ASSERT(block->inputs().size() >= num_initializers);
  size_t inputs_count = block->inputs().size() - num_initializers;
  for (auto & tensor : initializers) {
    // TODO: stop using positions to determine which initializers
    // match to which inputs
    std::string name = graph_proto->input(inputs_count++).name();
    auto p = graph_proto->add_initializer();
    p->set_name(name);
    EncodeTensor(p, tensor, name);
  }
}

void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *node, const jit::Symbol name) {
  auto attr = node_proto->add_attribute();
  JIT_ASSERT(name.is_attr());
  attr->set_name(name.toUnqualString());
  switch(node->kindOf(name)) {
    case AttributeKind::f:
      attr->set_f(node->f(name));
      attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
      break;
    case AttributeKind::fs:
      attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
      for (auto & v : node->fs(name))
        attr->add_floats(v);
      break;
    case AttributeKind::i:
      attr->set_type(onnx::AttributeProto_AttributeType_INT);
      attr->set_i(node->i(name));
      break;
    case AttributeKind::is:
      attr->set_type(onnx::AttributeProto_AttributeType_INTS);
      for (auto & v : node->is(name))
        attr->add_ints(v);
      break;
    case AttributeKind::s:
      attr->set_type(onnx::AttributeProto_AttributeType_STRING);
      attr->set_s(node->s(name));
      break;
    case AttributeKind::ss:
      attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
      for (auto & v : node->ss(name))
        attr->add_strings(v);
      break;
    case AttributeKind::t: {
      attr->set_type(onnx::AttributeProto_AttributeType_TENSOR);
      auto t = attr->mutable_t();
      EncodeTensor(t, node->t(name));
    } break;
    case AttributeKind::ts:
      attr->set_type(onnx::AttributeProto_AttributeType_TENSORS);
      for (auto & v : node->ts(name)) {
        auto t = attr->add_tensors();
        EncodeTensor(t, v);
      }
      break;
    case AttributeKind::g: {
      attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
      auto g = attr->mutable_g();
      EncodeGraph(g, node->g(name));
    } break;
    case AttributeKind::gs:
      attr->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
      for (auto & v : node->gs(name)) {
        auto g = attr->add_graphs();
        EncodeGraph(g, v);
      }
      break;
    default:
      throw std::runtime_error("unexpected attribute kind");
  }
}

class GraphEncoder: public EncoderBase {
 public:
  GraphEncoder(const std::shared_ptr<Graph> &graph,
               int64_t onnx_opset_version,
               onnx_torch::OperatorExportTypes operator_export_type,
               const std::vector<at::Tensor> &initializers,
               bool defer_weight_export,
               bool strip_doc);

  RawDataExportMap get_raw_data_export_map() {
    return raw_data_export_map_;
  }

 private:
  virtual void EncodeTensor(onnx::TensorProto *tensor_proto,
                            const at::Tensor &tensor,
                            const at::optional<std::string> external_ref = {}) override;

  RawDataExportMap raw_data_export_map_;
  bool defer_weight_export_;
};

GraphEncoder::GraphEncoder(
    const std::shared_ptr<Graph> &graph,
    int64_t onnx_opset_version,
    onnx_torch::OperatorExportTypes operator_export_type,
    const std::vector<at::Tensor> &initializers,
    bool defer_weight_export,
    bool strip_doc)
    : EncoderBase(operator_export_type, strip_doc),
      defer_weight_export_(defer_weight_export) {
  if (operator_export_type != onnx_torch::OperatorExportTypes::RAW) {
    validateGraph(graph, operator_export_type);
  }

  auto* imp = model_proto_.add_opset_import();
  // This is the version of the ONNX operator set we are targeting
  imp->set_version(onnx_opset_version);

  EncodeGraph(model_proto_.mutable_graph(), graph, initializers);
}

void GraphEncoder::EncodeTensor(
    onnx::TensorProto *tensor_proto,
    const at::Tensor &tensor,
    const at::optional<std::string> external_ref) {
  for (auto d : tensor.sizes()) {
    tensor_proto->add_dims(d);
  }
  tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.type().scalarType()));
  // CPU's HalfTensor doesn't have contiguous(), so call contiguous() before
  // moving the tensor to the CPU.
  auto t = tensor.contiguous().cpu();
  // Add a buffer to the raw_data_export_map for the caller to dump into an
  // external data store. If external_ref is not specified, we instead dump
  // the contiguous data into the protobuf itself
  if (defer_weight_export_ && external_ref) {
    // For now, we use the name of the tensor as the external lookup name to
    // avoid ONNX protobuf changes.
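    // Sketch of the resulting contract (the consumer lives outside this
    // file): raw_data holds the sentinel "__EXTERNAL", and the caller is
    // expected to retrieve the actual bytes by name from the map returned
    // by get_raw_data_export_map(), e.g. (hypothetical caller code):
    //   RawDataExportMap m = encoder.get_raw_data_export_map();
    //   at::Tensor weights = m.at(tensor_proto->name());
    // and then write them to its own external data store.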
    JIT_ASSERT(external_ref.value() == tensor_proto->name());
    JIT_ASSERT(raw_data_export_map_.count(external_ref.value()) == 0);
    raw_data_export_map_[external_ref.value()] = t;
    tensor_proto->set_raw_data("__EXTERNAL");
  } else {
    JIT_ASSERT(t.is_contiguous());
    tensor_proto->set_raw_data(std::string(
        static_cast<char*>(t.data_ptr()),
        t.type().elementSizeInBytes() * t.numel()));
  }
}

class ModuleEncoder: public EncoderBase {
 public:
  ModuleEncoder(const script::Module &module, std::ostream& out);

 private:
  void EncodeModule(onnx::GraphProto *graph_proto, const script::Module &module);

  void EncodeParameters(onnx::GraphProto *graph_proto,
                        const script::Module &module,
                        const std::string prefix);

  void EncodeParameter(onnx::TensorProto *tensor_proto,
                       const script::NamedParameter &parameter,
                       const std::string prefix);

  void EncodeMethods(onnx::GraphProto *graph_proto,
                     const script::Module &module,
                     const std::string prefix);

  void EncodeMethod(onnx::NodeProto *node_proto,
                    script::Method &method,
                    const std::string prefix);

  virtual void EncodeTensor(onnx::TensorProto *tensor_proto,
                            const at::Tensor &tensor,
                            const at::optional<std::string> external_ref = {}) override;

  virtual void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto,
                                           const Value* n) override;

  virtual void EncodeValueInfo(onnx::GraphProto *graph_proto,
                               onnx::ValueInfoProto* v,
                               const Value* n) override;

  void EncodeTypeInfo(onnx::GraphProto *graph_proto,
                      onnx::ValueInfoProto* v,
                      const TypePtr& type,
                      const std::string& name);

  PyTorchStreamWriter stream_writer_;
  // Used to deduplicate tensor storages
  std::unordered_map<at::StorageImpl*, uint64_t> storage_dedup_map_;

  // Used to keep track of Parameter names so Methods can refer to them
  std::unordered_map<at::Tensor*, std::string> parameter_map_;

  // Used to create sequential dummy names for node types
  size_t type_counter_ = 0;
};

ModuleEncoder::ModuleEncoder(
    const script::Module &module,
    std::ostream& out)
    : EncoderBase(onnx_torch::OperatorExportTypes::RAW, false),
      stream_writer_(out) {
  model_proto_.set_doc_string("THIS PROTO IS NOT STANDARD ONNX");
  EncodeModule(model_proto_.mutable_graph(), module);
}

void ModuleEncoder::EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto,
                                                const Value *n) {
  auto v = graph_proto->add_value_info();
  EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName());
}

void ModuleEncoder::EncodeTypeInfo(
    onnx::GraphProto *graph_proto,
    onnx::ValueInfoProto* v,
    const TypePtr& type,
    const std::string& name) {
  v->set_name(name);
  onnx::TypeProto* type_proto = v->mutable_type();
  onnx::TypeProto_Tensor* tensortype_proto = type_proto->mutable_tensor_type();
  onnx::TensorShapeProto* shape_proto = tensortype_proto->mutable_shape();

  // Use TypeProto fields to encode types.
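  // The scheme below repurposes existing TypeProto fields rather than
  // extending the protobuf: the JIT type kind goes into `denotation` as a
  // string, and any structural detail (dimension count, sizes/strides,
  // container element types) is packed into the TensorShapeProto dims.
  // Illustratively, a 3-D TensorType of floats with unknown sizes is
  // stored as denotation = "TensorType" with three placeholder dims of
  // value 1 and elem_type = FLOAT.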
  // denotation stores the type as a string
  auto kind = type->kind();
  if (kind == TypeKind::DynamicType) {
    type_proto->set_denotation("DynamicType");
    tensortype_proto->set_elem_type(onnx::TensorProto_DataType_UNDEFINED);
  } else if (kind == TypeKind::TensorType) {
    type_proto->set_denotation("TensorType");
    // encode the number of dimensions by pushing that number of ones into the shape proto
    auto tensor_type = type->expect<TensorType>();
    for (int i = 0; i < tensor_type->dim(); i++) {
      shape_proto->add_dim();
      shape_proto->mutable_dim(i)->set_dim_value(1);
    }
    tensortype_proto->set_elem_type(ATenTypeToOnnxType(tensor_type->scalarType()));
  } else if (kind == TypeKind::CompleteTensorType) {
    type_proto->set_denotation("CompleteTensorType");
    CompleteTensorTypePtr node_type = type->cast<CompleteTensorType>();

    // store the sizes and strides in the dims field of TensorShapeProto
    size_t i = 0;
    for (auto &size : node_type->sizes()) {
      shape_proto->add_dim();
      shape_proto->mutable_dim(i)->set_dim_value(size);
      i++;
    }
    for (auto &stride : node_type->strides()) {
      shape_proto->add_dim();
      shape_proto->mutable_dim(i)->set_dim_value(stride);
      i++;
    }
    tensortype_proto->set_elem_type(ATenTypeToOnnxType(node_type->scalarType()));
  } else if (kind == TypeKind::TupleType) {
    type_proto->set_denotation("TupleType");
    TupleTypePtr node_type = type->cast<TupleType>();
    auto elements = node_type->elements();

    // Generate a name for and encode each subtype in the value_info field of the GraphProto.
    for (size_t i = 0; i < elements.size(); i++) {
      std::string name = "#" + std::to_string(type_counter_++);
      shape_proto->add_dim();
      shape_proto->mutable_dim(i)->set_dim_param(name);
      onnx::ValueInfoProto* subtype_proto = graph_proto->add_value_info();
      EncodeTypeInfo(graph_proto, subtype_proto, elements[i], name);
    }
  } else if (kind == TypeKind::ListType) {
    type_proto->set_denotation("ListType");
    ListTypePtr node_type = type->cast<ListType>();

    // Generate a name for and encode the subtype in the value_info field of the GraphProto.
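    // The element type is stored out-of-line: a fresh dummy name ("#0",
    // "#1", ...) is minted from type_counter_, referenced here through
    // dim_param, and a ValueInfoProto carrying that name is appended to
    // the graph. A decoder can then follow the dim_param reference to
    // recover the nested type; e.g. List[List[int]] would (illustratively)
    // emit value_info entries "#0" (the inner ListType) and "#1" (IntType).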
std::string name = "#" + std::to_string(type_counter_++); shape_proto->add_dim(); shape_proto->mutable_dim(0)->set_dim_param(name); onnx::ValueInfoProto* subtype_proto = graph_proto->add_value_info(); EncodeTypeInfo(graph_proto, subtype_proto, node_type->getElementType(), name); } else if (kind == TypeKind::NumberType) { type_proto->set_denotation("NumberType"); } else if (kind == TypeKind::FloatType) { type_proto->set_denotation("FloatType"); } else if (kind == TypeKind::IntType) { type_proto->set_denotation("IntType"); } else if (kind == TypeKind::BoolType) { type_proto->set_denotation("BoolType"); } else if (kind == TypeKind::NoneType) { type_proto->set_denotation("NoneType"); } else if (kind == TypeKind::GeneratorType) { type_proto->set_denotation("GeneratorType"); } else if (kind == TypeKind::StringType) { type_proto->set_denotation("StringType"); } else if (kind == TypeKind::VarType) { type_proto->set_denotation("TypeVar:" + type->expect()->name()); } else if (kind == TypeKind::WorldType) { type_proto->set_denotation("WorldType"); } else { throw std::runtime_error("unexpected type kind"); } } void ModuleEncoder::EncodeValueInfo( onnx::GraphProto *graph_proto, onnx::ValueInfoProto* v, const Value* n) { EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName()); } void ModuleEncoder::EncodeModule( onnx::GraphProto *graph_proto, const script::Module &module) { EncodeParameters(graph_proto, module, ""); EncodeMethods(graph_proto, module, ""); auto str = model_proto_.SerializeAsString(); stream_writer_.writeRecord(str.data(), str.size()); } void ModuleEncoder::EncodeParameters( onnx::GraphProto *graph_proto, const script::Module &module, const std::string prefix) { // Encode each parameter as a initializer in the proto for (auto ¶meter : module.get_parameters()) { auto tensor_proto = graph_proto->add_initializer(); EncodeParameter(tensor_proto, parameter.value, prefix); } for (auto &submodule : module.get_modules()) { EncodeParameters(graph_proto, *submodule.value.module, prefix + submodule.key + "."); } } void ModuleEncoder::EncodeParameter( onnx::TensorProto *tensor_proto, const script::NamedParameter ¶meter, const std::string prefix) { auto tensor = parameter.slot(); // Name will be prefixed by submodule. e.g. submodule_foo.parameter_bar auto name = prefix + parameter.name; tensor_proto->set_name(name); parameter_map_[tensor] = name; // Parameters have these fields, but tensors do not tensor_proto->add_int64_data(parameter.is_buffer); tensor_proto->add_int64_data(tensor->requires_grad()); EncodeTensor(tensor_proto, *tensor, name); } void ModuleEncoder::EncodeMethods( onnx::GraphProto *graph_proto, const script::Module &module, const std::string prefix) { // Encode each parameter as a initializer in the proto for (auto &method : module.get_methods()) { auto node_proto = graph_proto->add_node(); EncodeMethod(node_proto, *method.value, prefix); } for (auto &submodule : module.get_modules()) { EncodeMethods(graph_proto, *submodule.value.module, prefix + submodule.key + "."); } } void ModuleEncoder::EncodeMethod( onnx::NodeProto *node_proto, script::Method &method, const std::string prefix) { node_proto->set_name(prefix + method.name()); if (method.is_optimized()) { // mark that this method was optimized node_proto->set_domain("optimized"); } // We store the schema string in the docstring. 
  node_proto->set_doc_string(getExportableSchemaStringForMethod(method));

  // Store member_inputs of Method in input
  for (auto &member_input : method.params()) {
    auto it = parameter_map_.find(member_input);
    JIT_ASSERT(it != parameter_map_.end());
    node_proto->add_input(it->second);
  }

  auto attr_proto = node_proto->add_attribute();
  attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH);

  for (auto node : method.graph()->nodes()) {
    if (node->kind() == prim::PythonOp) {
      auto py_node = static_cast<PythonOp*>(node);
      throw std::runtime_error(
          "Couldn't export Python operator " + py_node->name() +
          "\n\nDefined at:\n" + getNodeStackTraceString(node));
    }
  }
  EncodeBlock(attr_proto->mutable_g(), method.graph()->block(), {});
}

void ModuleEncoder::EncodeTensor(
    onnx::TensorProto *tensor_proto,
    const at::Tensor &tensor,
    const at::optional<std::string> external_ref) {
  auto storage_ptr = tensor.storage().unsafeGetStorageImpl();
  auto dedup_it = storage_dedup_map_.find(storage_ptr);
  if (dedup_it != storage_dedup_map_.end()) {
    tensor_proto->add_int64_data(dedup_it->second);
  } else {
    at::Tensor t = tensor;
    if (tensor.storage().device_type() == at::DeviceType::CUDA) {
      // NB: This new tensor is created to support cuda tensors.
      // Storages can be mutated when converting tensors from cuda to cpu,
      // and we need a cpu tensor to copy data from.
      t = at::getType(tensor).tensor(
          tensor.storage(),
          /* storageOffset = */ 0,
          /* size = */ { static_cast<int64_t>(tensor.type().elementSizeInBytes() *
                                              tensor.storage().size()) },
          /* strides = */ { 1 })
          .cpu();
    }

    auto record_number = stream_writer_.writeRecord(
        static_cast<const char*>(t.storage().data()),
        t.type().elementSizeInBytes() * t.storage().size());
    tensor_proto->add_int64_data(record_number);
    storage_dedup_map_[storage_ptr] = record_number;
  }

  for (auto &d : tensor.sizes()) {
    tensor_proto->add_dims(d);
  }
  tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.type().scalarType()));

  tensor_proto->add_int64_data(tensor.storage_offset());
  for (auto &d : tensor.strides()) {
    tensor_proto->add_int64_data(d);
  }
}

// Pretty printing
namespace {
constexpr char indent_char = ' ';
constexpr size_t indent_multiplier = 2;

std::string idt(size_t indent) {
  return std::string(indent * indent_multiplier, indent_char);
}

std::string nlidt(size_t indent) {
  return std::string("\n") + idt(indent);
}

void dump(const onnx::TensorProto& tensor, std::ostream& stream) {
  stream << "TensorProto shape: [";
  for (int i = 0; i < tensor.dims_size(); ++i) {
    stream << tensor.dims(i) << (i == tensor.dims_size() - 1 ? "" : " ");
  }
  stream << "]";
}

void dump(const onnx::TensorShapeProto& shape, std::ostream& stream) {
  for (int i = 0; i < shape.dim_size(); ++i) {
    auto &dim = shape.dim(i);
    if (dim.has_dim_value()) {
      stream << dim.dim_value();
    } else {
      stream << "?";
    }
    stream << (i == shape.dim_size() - 1 ? "" : " ");
  }
}

void dump(const onnx::TypeProto_Tensor& tensor_type, std::ostream& stream) {
  stream << "Tensor dims: ";
  dump(tensor_type.shape(), stream);
}

void dump(const onnx::TypeProto& type, std::ostream& stream) {
  dump(type.tensor_type(), stream);
}

void dump(const onnx::ValueInfoProto& value_info, std::ostream& stream) {
  stream << "{name: \"" << value_info.name() << "\", type:";
  dump(value_info.type(), stream);
  stream << "}";
}

void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent);

void dump(const onnx::AttributeProto& attr, std::ostream& stream, size_t indent) {
  stream << "{ name: '" << attr.name() << "', type: ";
  if (attr.has_f()) {
    stream << "float, value: " << attr.f();
  } else if (attr.has_i()) {
    stream << "int, value: " << attr.i();
  } else if (attr.has_s()) {
    stream << "string, value: '" << attr.s() << "'";
  } else if (attr.has_g()) {
    stream << "graph, value:\n";
    dump(attr.g(), stream, indent+1);
    stream << nlidt(indent);
  } else if (attr.has_t()) {
    stream << "tensor, value:";
    dump(attr.t(), stream);
  } else if (attr.floats_size()) {
    stream << "floats, values: [";
    for (int i = 0; i < attr.floats_size(); ++i)
      stream << attr.floats(i) << (i == attr.floats_size() - 1 ? "" : " ");
    stream << "]";
  } else if (attr.ints_size()) {
    stream << "ints, values: [";
    for (int i = 0; i < attr.ints_size(); ++i)
      stream << attr.ints(i) << (i == attr.ints_size() - 1 ? "" : " ");
    stream << "]";
  } else if (attr.strings_size()) {
    stream << "strings, values: [";
    for (int i = 0; i < attr.strings_size(); ++i)
      stream << "'" << attr.strings(i) << "'"
             << (i == attr.strings_size() - 1 ? "" : " ");
    stream << "]";
  } else if (attr.tensors_size()) {
    stream << "tensors, values: [";
    for (auto& t : attr.tensors()) {
      dump(t, stream);
    }
    stream << "]";
  } else if (attr.graphs_size()) {
    stream << "graphs, values: [";
    for (auto& g : attr.graphs()) {
      dump(g, stream, indent+1);
    }
    stream << "]";
  } else {
    stream << "UNKNOWN";
  }
  stream << "}";
}

void dump(const onnx::NodeProto& node, std::ostream& stream, size_t indent) {
  stream << "Node {type: \"" << node.op_type() << "\", inputs: [";
  for (int i = 0; i < node.input_size(); ++i) {
    stream << node.input(i) << (i == node.input_size() - 1 ? "" : ",");
  }
  stream << "], outputs: [";
  for (int i = 0; i < node.output_size(); ++i) {
    stream << node.output(i) << (i == node.output_size() - 1 ? "" : ",");
  }
  stream << "], attributes: [";
  for (int i = 0; i < node.attribute_size(); ++i) {
    dump(node.attribute(i), stream, indent+1);
    stream << (i == node.attribute_size() - 1 ? "" : ",");
  }
  stream << "]}";
}

void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent) {
  stream << idt(indent) << "GraphProto {" << nlidt(indent+1)
         << "name: \"" << graph.name() << "\"" << nlidt(indent+1)
         << "inputs: [";
  for (int i = 0; i < graph.input_size(); ++i) {
    dump(graph.input(i), stream);
    stream << (i == graph.input_size() - 1 ? "" : ",");
  }
  stream << "]" << nlidt(indent+1)
         << "outputs: [";
  for (int i = 0; i < graph.output_size(); ++i) {
    dump(graph.output(i), stream);
    stream << (i == graph.output_size() - 1 ? "" : ",");
  }
  stream << "]" << nlidt(indent+1)
         << "initializers: [";
  for (int i = 0; i < graph.initializer_size(); ++i) {
    dump(graph.initializer(i), stream);
    stream << (i == graph.initializer_size() - 1 ? "" : ",");
  }
  stream << "]" << nlidt(indent+1)
         << "nodes: [" << nlidt(indent+2);
  for (int i = 0; i < graph.node_size(); ++i) {
    dump(graph.node(i), stream, indent+2);
    if (i != graph.node_size() - 1) stream << "," << nlidt(indent+2);
  }
  stream << nlidt(indent+1) << "]\n" << idt(indent) << "}\n";
}

void dump(const onnx::OperatorSetIdProto& operator_set_id, std::ostream& stream) {
  stream << "OperatorSetIdProto { domain: " << operator_set_id.domain() << "}";
}

void dump(const onnx::ModelProto& model, std::ostream& stream, size_t indent) {
  stream << idt(indent)
         << "ModelProto {" << nlidt(indent+1)
         << "producer_name: \"" << model.producer_name() << "\"" << nlidt(indent+1)
         << "domain: \"" << model.domain() << "\"" << nlidt(indent+1)
         << "doc_string: \"" << model.doc_string() << "\"";
  if (model.has_graph()) {
    stream << nlidt(indent+1) << "graph:\n";
    dump(model.graph(), stream, indent+2);
  }
  if (model.opset_import_size()) {
    stream << idt(indent+1) << "opset_import: [";
    for (auto &opset_imp : model.opset_import()) {
      dump(opset_imp, stream);
    }
    stream << "],\n";
  }
  stream << idt(indent) << "}\n";
}
} // namespace

std::string prettyPrint(const onnx::ModelProto& model) {
  std::stringstream ss;
  dump(model, ss, 0);
  return ss.str();
}

} // namespace

std::string PrettyPrintExportedGraph(
    const std::shared_ptr<Graph>& graph,
    const std::vector<at::Tensor> &initializers,
    int64_t onnx_opset_version,
    bool defer_weight_export,
    ::torch::onnx::OperatorExportTypes operator_export_type,
    bool google_printer) {
  auto graph_encoder = GraphEncoder(
      graph, onnx_opset_version, operator_export_type, initializers,
      defer_weight_export, true);
  if (google_printer) {
    return graph_encoder.get_model_proto().DebugString();
  }
  return prettyPrint(graph_encoder.get_model_proto());
}

// When operator_export_type is RAW (export_raw_ir), IR ops are exported
// without being turned into ONNX ops. The output uses the ONNX protobuf
// format, but the ops do not conform to the ONNX op specification, so the
// output is not interpretable by an ONNX-compatible framework. However,
// PyTorch or libtorch will be able to import the IR and play it back.
std::tuple<std::string, RawDataExportMap> ExportGraph(
    const std::shared_ptr<Graph>& graph,
    const std::vector<at::Tensor> &initializers,
    int64_t onnx_opset_version,
    bool defer_weight_export,
    ::torch::onnx::OperatorExportTypes operator_export_type) {
  auto graph_encoder = GraphEncoder(
      graph, onnx_opset_version, operator_export_type, initializers,
      defer_weight_export, false);
  return std::make_tuple(graph_encoder.get_model_proto().SerializeAsString(),
                         graph_encoder.get_raw_data_export_map());
}

void ExportModule(const script::Module& module, std::ostream& out) {
  ModuleEncoder(module, out);
}

void ExportModule(const script::Module& module, const std::string &filename) {
  std::ofstream out(filename, std::ios_base::binary);
  ExportModule(module, out);
}

}} // namespace torch::jit