mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Since caffe2 and torch have been consolidated, CAFFE2_API should be merged with TORCH_API. Addresses a TODO. Manually edited some references of the removed `CAFFE2_API`: * `CONTRIBUTING.md` * `caffe2/proto/CMakeLists.txt` * `cmake/ProtoBuf.cmake` * `c10/macros/Export.h` * `torch/csrc/WindowsTorchApiMacro.h` Pull Request resolved: https://github.com/pytorch/pytorch/pull/49496 Reviewed By: malfet, samestep Differential Revision: D25600726 Pulled By: janeyx99 fbshipit-source-id: 7e068d959e397ac183c097d7e9a9afeca5ddd782
146 lines, 4.7 KiB, C++
#pragma once
|
|
|
|
#include "caffe2/core/common.h"
#include "caffe2/core/tensor.h"
#include "caffe2/onnx/helper.h"
#include "caffe2/proto/caffe2_pb.h"
#include "onnx/onnx_pb.h"

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
|
|
|
|
namespace caffe2 {
|
|
namespace onnx {
|
|
|
|
namespace {
|
|
using ::ONNX_NAMESPACE::AttributeProto;
|
|
using ::ONNX_NAMESPACE::GraphProto;
|
|
using ::ONNX_NAMESPACE::ModelProto;
|
|
using ::ONNX_NAMESPACE::NodeProto;
|
|
using ::ONNX_NAMESPACE::TensorProto;
|
|
} // namespace
|
|
|
|
using ConvertedResult =
|
|
std::pair<std::vector<NodeProto>, std::vector<TensorProto>>;
|
|
|
|
// Useful utility function
|
|
void rewriteSubnet(
|
|
Argument* arg,
|
|
std::map<std::string, std::string> oldname_to_newname);
|
|
|
|
// Rewrite Caffe2 nets into SSA forms. Notice that we will preserve the external
|
|
// output names for predict net.
|
|
TORCH_API std::unordered_map<std::string, std::string> SsaRewrite(
|
|
caffe2::NetDef* init_net,
|
|
caffe2::NetDef* pred_net,
|
|
bool PreserveInPlaceOps = true);
|
|
|
|
::ONNX_NAMESPACE::TensorProto::DataType Caffe2TypeToOnnxType(
|
|
caffe2::TensorProto::DataType t);
|
|
|
|
class TORCH_API OnnxExporter {
|
|
using SpecialOpConverter = ConvertedResult (OnnxExporter::*)(
|
|
const caffe2::OperatorDef&,
|
|
const std::unordered_map<std::string, caffe2::TensorShape>&);
|
|
|
|
public:
|
|
OnnxExporter(DummyName* dummy = nullptr) {
|
|
if (dummy) {
|
|
dummy_ = std::shared_ptr<DummyName>(dummy, [](DummyName*) {});
|
|
} else {
|
|
dummy_ = std::make_shared<DummyName>();
|
|
}
|
|
}
|
|
|
|
ConvertedResult Caffe2OpToOnnxNodes(
|
|
const caffe2::OperatorDef& def,
|
|
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
|
|
|
|
void InitOpToTensorProto(const caffe2::OperatorDef& def, TensorProto* tensor);
|
|
|
|
private:
|
|
ConvertedResult CommonCaffe2OpToOnnxNodes(const caffe2::OperatorDef& def);
|
|
|
|
ConvertedResult CreateArgMaxMinOpNodes(
|
|
const caffe2::OperatorDef& def,
|
|
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
|
|
|
|
ConvertedResult CreateBinaryElementwiseOpNodes(
|
|
const caffe2::OperatorDef& def,
|
|
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
|
|
|
|
ConvertedResult CreateCastNodes(
|
|
const caffe2::OperatorDef& def,
|
|
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
|
|
|
|
ConvertedResult CreateElementwiseLinearNodes(
|
|
const caffe2::OperatorDef& def,
|
|
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
|
|
|
|
ConvertedResult CreateConvPoolNodes(
|
|
const caffe2::OperatorDef& def,
|
|
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
|
|
|
|
ConvertedResult CreateGemmNodes(
|
|
const caffe2::OperatorDef& def,
|
|
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
|
|
|
|
ConvertedResult CreateReshapeNodes(
|
|
const caffe2::OperatorDef& def,
|
|
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
|
|
|
|
ConvertedResult CreateSliceNodes(
|
|
const caffe2::OperatorDef& def,
|
|
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
|
|
|
|
ConvertedResult CreateChannelShuffleNodes(
|
|
const caffe2::OperatorDef& def,
|
|
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
|
|
|
|
ConvertedResult CreateReduceMeanNodes(
|
|
const caffe2::OperatorDef& def,
|
|
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
|
|
|
|
ConvertedResult CreateConcatNodes(
|
|
const caffe2::OperatorDef& def,
|
|
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
|
|
|
|
ConvertedResult CreateMergeDimNodes(
|
|
const caffe2::OperatorDef& def,
|
|
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
|
|
|
|
ConvertedResult CreateLrnNodes(
|
|
const caffe2::OperatorDef& def,
|
|
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
|
|
|
|
ConvertedResult CreateUpsampleNodes(
|
|
const caffe2::OperatorDef& def,
|
|
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
|
|
|
|
// \brief Check block listed arguments where we won't pass down when
|
|
// converting to ONNX node
|
|
bool IsBlockListed(const caffe2::Argument& arg);
|
|
|
|
// \brief Convert Caffe2 argument to Onnx attribute
|
|
void CopyCaffe2ArgToOnnxAttr(
|
|
AttributeProto* attr,
|
|
const std::string& op_type,
|
|
const caffe2::Argument& arg);
|
|
|
|
// LUT getters
|
|
const std::unordered_map<std::string, std::string>& get_renamed_operators()
|
|
const;
|
|
const std::unordered_map<std::string, std::string>& get_renamed_attrs() const;
|
|
const std::
|
|
unordered_map<std::string, std::unordered_map<std::string, std::string>>&
|
|
get_per_op_renamed_attrs() const;
|
|
const std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>&
|
|
get_special_operators() const;
|
|
|
|
// Dummy name generator
|
|
std::shared_ptr<DummyName> dummy_;
|
|
};
|
|
} // namespace onnx
|
|
} // namespace caffe2
|