pytorch/torch/csrc/onnx/init.cpp
Roy Li f908b2b919 Use google protobuf in pytorch onnx import/export
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/8469

Reviewed By: houseroad

Differential Revision: D9102041

Pulled By: li-roy

fbshipit-source-id: 805c473745d181b71c7deebf0b9afd0f0849ba4f
2018-08-01 12:54:41 -07:00

34 lines
1.7 KiB
C++

#include "torch/csrc/onnx/init.h"
#include "torch/csrc/onnx/onnx.h"
#include "onnx/onnx.pb.h"
namespace torch { namespace onnx {

// Exposes ONNX-related enums to Python.
//
// Wraps `module` in a handle, casts it to py::module, creates a child
// submodule named "_onnx" on it, and registers two enums there:
//   * "TensorProtoDataType" - the values of
//     ::ONNX_NAMESPACE::TensorProto_DataType listed below.
//   * "OperatorExportTypes" - the values of OperatorExportTypes listed below.
//
// Enumerators are registered in the same order as they are written here.
void initONNXBindings(PyObject* module) {
  // Local shorthand for the namespace that holds the ONNX protobuf types.
  namespace onnx_proto = ::ONNX_NAMESPACE;
  using ExportType = OperatorExportTypes;

  auto parent_module = py::handle(module).cast<py::module>();
  auto onnx_submodule = parent_module.def_submodule("_onnx");

  // Tensor element types.
  py::enum_<onnx_proto::TensorProto_DataType> data_type_enum(
      onnx_submodule, "TensorProtoDataType");
  data_type_enum.value("UNDEFINED", onnx_proto::TensorProto_DataType_UNDEFINED);
  data_type_enum.value("FLOAT", onnx_proto::TensorProto_DataType_FLOAT);
  data_type_enum.value("UINT8", onnx_proto::TensorProto_DataType_UINT8);
  data_type_enum.value("INT8", onnx_proto::TensorProto_DataType_INT8);
  data_type_enum.value("UINT16", onnx_proto::TensorProto_DataType_UINT16);
  data_type_enum.value("INT16", onnx_proto::TensorProto_DataType_INT16);
  data_type_enum.value("INT32", onnx_proto::TensorProto_DataType_INT32);
  data_type_enum.value("INT64", onnx_proto::TensorProto_DataType_INT64);
  data_type_enum.value("STRING", onnx_proto::TensorProto_DataType_STRING);
  data_type_enum.value("BOOL", onnx_proto::TensorProto_DataType_BOOL);
  data_type_enum.value("FLOAT16", onnx_proto::TensorProto_DataType_FLOAT16);
  data_type_enum.value("DOUBLE", onnx_proto::TensorProto_DataType_DOUBLE);
  data_type_enum.value("UINT32", onnx_proto::TensorProto_DataType_UINT32);
  data_type_enum.value("UINT64", onnx_proto::TensorProto_DataType_UINT64);
  data_type_enum.value("COMPLEX64", onnx_proto::TensorProto_DataType_COMPLEX64);
  data_type_enum.value("COMPLEX128", onnx_proto::TensorProto_DataType_COMPLEX128);

  // Operator export modes.
  py::enum_<ExportType> export_type_enum(onnx_submodule, "OperatorExportTypes");
  export_type_enum.value("ONNX", ExportType::ONNX);
  export_type_enum.value("ONNX_ATEN", ExportType::ONNX_ATEN);
  export_type_enum.value("ONNX_ATEN_FALLBACK", ExportType::ONNX_ATEN_FALLBACK);
  export_type_enum.value("RAW", ExportType::RAW);
}

}} // namespace torch::onnx