pytorch/caffe2/onnx/onnx_exporter.h
Yinghai Lu a87d475c2f Do not rename net boundary inputs/outputs during ssaRewrite. (#17545)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17545

This diff avoids renaming boundary inputs of net during onnxifi transform.
It also removes adding mappings for the initializer during onnxifi op creation.
Thus gets rid of the mapped ws creation during onnxifi op creation.

Reviewed By: zrphercule

Differential Revision: D14243161

fbshipit-source-id: 6eafa920c45f6a6bfacbbb443e8e84cf9778644c
2019-03-06 14:26:58 -08:00

139 lines
4.6 KiB
C++

#pragma once
#include "caffe2/core/common.h"
#include "caffe2/core/tensor.h"
#include "caffe2/onnx/helper.h"
#include "caffe2/proto/caffe2_pb.h"
#include "onnx/onnx_pb.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace caffe2 {
namespace onnx {
// Bring the ONNX protobuf message types used by this header into
// caffe2::onnx so they can be spelled without qualification. Inside this
// namespace the unqualified name TensorProto therefore means the ONNX type;
// the Caffe2 one must be written as caffe2::TensorProto.
//
// These are declared directly in the named namespace instead of inside an
// unnamed namespace: an unnamed namespace in a header is a distinct namespace
// in every translation unit that includes it, which is a well-known header
// anti-pattern. Using-declarations introduce no entity of their own, so the
// visible names are unchanged for code that includes this header.
using ::ONNX_NAMESPACE::AttributeProto;
using ::ONNX_NAMESPACE::GraphProto;
using ::ONNX_NAMESPACE::ModelProto;
using ::ONNX_NAMESPACE::NodeProto;
using ::ONNX_NAMESPACE::TensorProto;
// Return type of the operator-conversion routines declared below: `first`
// holds ONNX nodes, `second` holds ONNX tensors.
// NOTE(review): the tensors are presumably constant inputs/initializers that
// the emitted nodes refer to — confirm against the implementation.
using ConvertedResult =
std::pair<std::vector<NodeProto>, std::vector<TensorProto>>;
// Rewrite Caffe2 nets into SSA (static single assignment) form. The external
// output names of the predict net are preserved, and net boundary
// inputs/outputs are not renamed by the rewrite.
//
// init_net: initialization net, passed by mutable pointer.
// pred_net: predict net, passed by mutable pointer.
// NOTE(review): both nets are presumably rewritten in place through these
// pointers, and the returned string->string map presumably relates blob names
// before and after the rewrite — confirm both (including the direction of the
// mapping) against the implementation.
CAFFE2_API std::unordered_map<std::string, std::string> SsaRewrite(
caffe2::NetDef* init_net,
caffe2::NetDef* pred_net);
// Takes a Caffe2 tensor data type and returns an ONNX TensorProto data type.
// What happens for Caffe2 types without an ONNX counterpart is not visible
// from this declaration.
// NOTE(review): unlike SsaRewrite and OnnxExporter in this header, this
// declaration carries no CAFFE2_API macro — confirm that is intentional.
::ONNX_NAMESPACE::TensorProto::DataType Caffe2TypeToOnnxType(
caffe2::TensorProto::DataType t);
class CAFFE2_API OnnxExporter {
using SpecialOpConverter = ConvertedResult (OnnxExporter::*)(
const caffe2::OperatorDef&,
const std::unordered_map<std::string, caffe2::TensorShape>&);
public:
OnnxExporter(DummyName* dummy = nullptr) {
if (dummy) {
dummy_ = std::shared_ptr<DummyName>(dummy, [](DummyName*) {});
} else {
dummy_ = std::make_shared<DummyName>();
}
}
ConvertedResult Caffe2OpToOnnxNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
void InitOpToTensorProto(const caffe2::OperatorDef& def, TensorProto* tensor);
private:
ConvertedResult CommonCaffe2OpToOnnxNodes(const caffe2::OperatorDef& def);
ConvertedResult CreateArgMaxMinOpNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateBinaryElementwiseOpNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateCastNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateElementwiseLinearNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateConvPoolNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateGemmNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateReshapeNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateSliceNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateChannelShuffleNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateReduceMeanNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateConcatNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateMergeDimNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateLrnNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateUpsampleNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
// \brief Check black listed arguemnts where we won't pass down when
// converting to ONNX node
bool IsBlackListed(const caffe2::Argument& arg);
// \brief Convert Caffe2 argument to Onnx attribute
void CopyCaffe2ArgToOnnxAttr(
AttributeProto* attr,
const std::string& op_type,
const caffe2::Argument& arg);
// LUT getters
const std::unordered_map<std::string, std::string>& get_renamed_operators()
const;
const std::unordered_map<std::string, std::string>& get_renamed_attrs() const;
const std::
unordered_map<std::string, std::unordered_map<std::string, std::string>>&
get_per_op_renamed_attrs() const;
const std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>&
get_special_operators() const;
// Dummy name generator
std::shared_ptr<DummyName> dummy_;
};
} // namespace onnx
} // namespace caffe2