pytorch/torch/csrc/onnx/init.cpp
Bowen Bao 02e35ce17b [ONNX] Update onnx function export with comments and clean up (#66817) (#67803)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67803

* Addresses comments from #63589

[ONNX] remove torch::onnx::PRODUCER_VERSION (#67107)

Use constants from version.h instead.
This simplifies things since we no longer have to update
PRODUCER_VERSION for each release.

Also add TORCH_VERSION to version.h so that a string is available for
this purpose.

[ONNX] Set `ir_version` based on opset_version. (#67128)

This increases the odds that the exported ONNX model will be usable.
Before this change, we were setting the IR version to a value which may
be higher than what the model consumer supports.

Also some minor clean-up in the test code:
* Fix string replacement.
* Use a temporary file so that the test does not leave files behind in
  the current working directory.

Test Plan: Imported from OSS

Reviewed By: msaroufim

Differential Revision: D32181306

Pulled By: malfet

fbshipit-source-id: 02f136d34ef8f664ade0bc1985a584f0e8c2b663

Co-authored-by: BowenBao <bowbao@microsoft.com>
Co-authored-by: Gary Miguel <garymiguel@microsoft.com>
Co-authored-by: Nikita Shulga <nshulga@fb.com>
2021-11-05 10:35:35 -07:00

50 lines
2.2 KiB
C++

#include <onnx/onnx_pb.h>
#include <torch/csrc/onnx/init.h>
#include <torch/csrc/onnx/onnx.h>
#include <torch/version.h>
namespace torch {
namespace onnx {

// Creates the `_onnx` submodule on the given Python module and populates it
// with the enums and constants that the Python-side ONNX exporter reads.
//
// @param module  The parent Python module object. It is cast to py::module,
//                so it is expected to be a module; the submodule is attached
//                to it via def_submodule("_onnx").
void initONNXBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  auto onnx = m.def_submodule("_onnx");

  // Python mirror of ONNX's TensorProto.DataType, exposed as
  // `_onnx.TensorProtoDataType`. Each entry maps a Python-visible name to the
  // protobuf enum constant compiled into this build.
  py::enum_<::ONNX_NAMESPACE::TensorProto_DataType>(onnx, "TensorProtoDataType")
      .value("UNDEFINED", ::ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED)
      .value("FLOAT", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT)
      .value("UINT8", ::ONNX_NAMESPACE::TensorProto_DataType_UINT8)
      .value("INT8", ::ONNX_NAMESPACE::TensorProto_DataType_INT8)
      .value("UINT16", ::ONNX_NAMESPACE::TensorProto_DataType_UINT16)
      .value("INT16", ::ONNX_NAMESPACE::TensorProto_DataType_INT16)
      .value("INT32", ::ONNX_NAMESPACE::TensorProto_DataType_INT32)
      .value("INT64", ::ONNX_NAMESPACE::TensorProto_DataType_INT64)
      .value("STRING", ::ONNX_NAMESPACE::TensorProto_DataType_STRING)
      .value("BOOL", ::ONNX_NAMESPACE::TensorProto_DataType_BOOL)
      .value("FLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)
      .value("DOUBLE", ::ONNX_NAMESPACE::TensorProto_DataType_DOUBLE)
      .value("UINT32", ::ONNX_NAMESPACE::TensorProto_DataType_UINT32)
      .value("UINT64", ::ONNX_NAMESPACE::TensorProto_DataType_UINT64)
      .value("COMPLEX64", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64)
      .value("COMPLEX128", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128)
      // BFLOAT16 (value 16) is part of ONNX's TensorProto.DataType. Without
      // this binding, Python code had no way to name it through this enum.
      .value("BFLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16);

  // Exposes the C++ OperatorExportTypes enum (from the included ONNX headers)
  // as `_onnx.OperatorExportTypes`.
  py::enum_<OperatorExportTypes>(onnx, "OperatorExportTypes")
      .value("ONNX", OperatorExportTypes::ONNX)
      .value("ONNX_ATEN", OperatorExportTypes::ONNX_ATEN)
      .value("ONNX_ATEN_FALLBACK", OperatorExportTypes::ONNX_ATEN_FALLBACK)
      .value("ONNX_FALLTHROUGH", OperatorExportTypes::ONNX_FALLTHROUGH);

  // Exposes the C++ TrainingMode enum as `_onnx.TrainingMode`.
  py::enum_<TrainingMode>(onnx, "TrainingMode")
      .value("EVAL", TrainingMode::EVAL)
      .value("PRESERVE", TrainingMode::PRESERVE)
      .value("TRAINING", TrainingMode::TRAINING);

  // Taken from TORCH_VERSION in torch/version.h, so the value tracks the
  // build's version without manual updates each release. Presumably used as
  // the producer version recorded in exported models.
  onnx.attr("PRODUCER_VERSION") = py::str(TORCH_VERSION);

  // Exposes, as a Python bool, whether this build was compiled with the
  // PYTORCH_ONNX_CAFFE2_BUNDLE macro defined.
#ifdef PYTORCH_ONNX_CAFFE2_BUNDLE
  onnx.attr("PYTORCH_ONNX_CAFFE2_BUNDLE") = true;
#else
  onnx.attr("PYTORCH_ONNX_CAFFE2_BUNDLE") = false;
#endif
}

} // namespace onnx
} // namespace torch