Mirror of https://github.com/zebrajr/pytorch.git
ScriptModule ONNX export
* Export for control flow nodes
* Add pretty-print capability for ONNX export testing
* Update tests and handling of multiple GraphProto names
* Maybe bugfix?
* Factor out code from export and pretty print
31 lines
1.4 KiB
C++
#include "torch/csrc/onnx/init.h"
|
|
#include "torch/csrc/onnx/onnx.pb.h"
|
|
#include "torch/csrc/onnx/onnx.h"
|
|
|
|
namespace torch { namespace onnx {
|
|
void initONNXBindings(PyObject* module) {
|
|
auto m = py::handle(module).cast<py::module>();
|
|
auto onnx = m.def_submodule("_onnx");
|
|
py::enum_<onnx_TensorProto_DataType>(onnx, "TensorProtoDataType")
|
|
.value("UNDEFINED", onnx_TensorProto_DataType_UNDEFINED)
|
|
.value("FLOAT", onnx_TensorProto_DataType_FLOAT)
|
|
.value("UINT8", onnx_TensorProto_DataType_UINT8)
|
|
.value("INT8", onnx_TensorProto_DataType_INT8)
|
|
.value("UINT16", onnx_TensorProto_DataType_UINT16)
|
|
.value("INT16", onnx_TensorProto_DataType_INT16)
|
|
.value("INT32", onnx_TensorProto_DataType_INT32)
|
|
.value("INT64", onnx_TensorProto_DataType_INT64)
|
|
.value("STRING", onnx_TensorProto_DataType_STRING)
|
|
.value("BOOL", onnx_TensorProto_DataType_BOOL)
|
|
.value("FLOAT16", onnx_TensorProto_DataType_FLOAT16)
|
|
.value("DOUBLE", onnx_TensorProto_DataType_DOUBLE)
|
|
.value("UINT32", onnx_TensorProto_DataType_UINT32)
|
|
.value("UINT64", onnx_TensorProto_DataType_UINT64)
|
|
.value("COMPLEX64", onnx_TensorProto_DataType_COMPLEX64)
|
|
.value("COMPLEX128", onnx_TensorProto_DataType_COMPLEX128);
|
|
|
|
py::class_<ModelProto>(onnx, "ModelProto")
|
|
.def("prettyPrint", &ModelProto::prettyPrint);
|
|
}
|
|
}} // namespace torch::onnx
|
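For context, initONNXBindings follows the usual pybind11 pattern: it casts the PyObject* handle supplied by its caller to a py::module, creates an "_onnx" submodule on it, and registers an enum and a class there, rather than defining its own extension module. The sketch below shows the same binding pattern as a self-contained module. The module name example_onnx, the DataType enum, and the Model struct are placeholders standing in for the generated onnx_TensorProto_DataType constants and the ModelProto wrapper; they are not part of the file above.

#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

// Placeholder enum and class standing in for the generated
// onnx_TensorProto_DataType constants and the ModelProto wrapper.
enum class DataType { UNDEFINED = 0, FLOAT = 1, INT64 = 7 };

struct Model {
  std::string prettyPrint() const { return "graph(...)"; }
};

// Standalone module for illustration only; initONNXBindings instead
// attaches its bindings to the module handle passed in by torch's
// main C extension.
PYBIND11_MODULE(example_onnx, m) {
  auto onnx = m.def_submodule("_onnx");

  py::enum_<DataType>(onnx, "TensorProtoDataType")
      .value("UNDEFINED", DataType::UNDEFINED)
      .value("FLOAT", DataType::FLOAT)
      .value("INT64", DataType::INT64);

  py::class_<Model>(onnx, "ModelProto")
      .def(py::init<>())
      .def("prettyPrint", &Model::prettyPrint);
}

With this standalone sketch, Python code would reach the enum as example_onnx._onnx.TensorProtoDataType.FLOAT; the real bindings end up under whatever module torch passes into initONNXBindings. Attaching to an existing handle keeps the ONNX bindings inside torch's single compiled extension instead of requiring a separate one.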