mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Update default opset to 18 (#156023)
Update default opset for the torchscript exporter to 18 to match the dynamo exporter, because support was actaully added and tested in https://github.com/pytorch/pytorch/pull/118828. In the next version we should plan to update to opset 21 or higher. This change also removes the hard limit on the torchscript exporter for more flexibility. Pull Request resolved: https://github.com/pytorch/pytorch/pull/156023 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
39c605e8b3
commit
f810e98143
|
|
@ -236,12 +236,6 @@ MIN_ONNX_OPSET_VERSION = 9
|
||||||
MAX_ONNX_OPSET_VERSION = _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET
|
MAX_ONNX_OPSET_VERSION = _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET
|
||||||
TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1)
|
TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1)
|
||||||
|
|
||||||
# The min onnx opset version to test for
|
|
||||||
FX_MIN_ONNX_OPSET_VERSION = 18
|
|
||||||
# The max onnx opset version to test for
|
|
||||||
FX_MAX_ONNX_OPSET_VERSION = 18
|
|
||||||
FX_TESTED_OPSETS = range(FX_MIN_ONNX_OPSET_VERSION, FX_MAX_ONNX_OPSET_VERSION + 1)
|
|
||||||
|
|
||||||
BOOL_TYPES = (torch.bool,)
|
BOOL_TYPES = (torch.bool,)
|
||||||
|
|
||||||
INT_TYPES = (
|
INT_TYPES = (
|
||||||
|
|
|
||||||
|
|
@ -87,8 +87,8 @@ namespace {
|
||||||
namespace onnx_torch = ::torch::onnx;
|
namespace onnx_torch = ::torch::onnx;
|
||||||
namespace onnx = ::ONNX_NAMESPACE;
|
namespace onnx = ::ONNX_NAMESPACE;
|
||||||
|
|
||||||
const static int kInvalidOpsetVersion = -1;
|
constexpr int kInvalidOpsetVersion = -1;
|
||||||
const static int kMainOpsetVersion = 20;
|
constexpr int kMainOpsetVersion = 23;
|
||||||
// Based on OP_SET_ID_VERSION_MAP in
|
// Based on OP_SET_ID_VERSION_MAP in
|
||||||
// https://github.com/onnx/onnx/blob/master/onnx/helper.py.
|
// https://github.com/onnx/onnx/blob/master/onnx/helper.py.
|
||||||
constexpr static std::array<int64_t, kMainOpsetVersion + 1>
|
constexpr static std::array<int64_t, kMainOpsetVersion + 1>
|
||||||
|
|
@ -114,6 +114,9 @@ constexpr static std::array<int64_t, kMainOpsetVersion + 1>
|
||||||
8, // opset 18
|
8, // opset 18
|
||||||
9, // opset 19
|
9, // opset 19
|
||||||
9, // opset 20
|
9, // opset 20
|
||||||
|
10, // opset 21
|
||||||
|
10, // opset 22
|
||||||
|
11, // opset 23
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string getNodeStackTraceString(const Node* n) {
|
std::string getNodeStackTraceString(const Node* n) {
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,9 @@ ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
|
||||||
|
|
||||||
ONNX_BASE_OPSET = 9
|
ONNX_BASE_OPSET = 9
|
||||||
ONNX_MIN_OPSET = 7
|
ONNX_MIN_OPSET = 7
|
||||||
ONNX_MAX_OPSET = 20
|
ONNX_MAX_OPSET = 23
|
||||||
ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 20
|
ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 20
|
||||||
# ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py
|
ONNX_DEFAULT_OPSET = 18
|
||||||
ONNX_DEFAULT_OPSET = 17
|
|
||||||
ONNX_CONSTANT_FOLDING_MIN_OPSET = 9
|
ONNX_CONSTANT_FOLDING_MIN_OPSET = 9
|
||||||
|
|
||||||
PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues"
|
PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues"
|
||||||
|
|
|
||||||
|
|
@ -54,11 +54,6 @@ class _InternalGlobals:
|
||||||
|
|
||||||
@export_onnx_opset_version.setter
|
@export_onnx_opset_version.setter
|
||||||
def export_onnx_opset_version(self, value: int):
|
def export_onnx_opset_version(self, value: int):
|
||||||
supported_versions = range(
|
|
||||||
_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1
|
|
||||||
)
|
|
||||||
if value not in supported_versions:
|
|
||||||
raise ValueError(f"Unsupported ONNX opset version: {value}")
|
|
||||||
self._export_onnx_opset_version = value
|
self._export_onnx_opset_version = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
||||||
|
|
@ -159,7 +159,7 @@ def export_compat(
|
||||||
export_params=export_params,
|
export_params=export_params,
|
||||||
input_names=input_names,
|
input_names=input_names,
|
||||||
output_names=output_names,
|
output_names=output_names,
|
||||||
opset_version=17, # TODO(justinchuby): Hard coded to 17 for now
|
opset_version=opset_version,
|
||||||
dynamic_axes=dynamic_axes,
|
dynamic_axes=dynamic_axes,
|
||||||
keep_initializers_as_inputs=keep_initializers_as_inputs,
|
keep_initializers_as_inputs=keep_initializers_as_inputs,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -353,9 +353,9 @@ def export(
|
||||||
|
|
||||||
Models exported this way are probably runnable only by Caffe2.
|
Models exported this way are probably runnable only by Caffe2.
|
||||||
|
|
||||||
opset_version (int, default 17): The version of the
|
opset_version (int, default 18): The version of the
|
||||||
`default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
|
`default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
|
||||||
to target. Must be >= 7 and <= 17.
|
to target. Must be >= 7.
|
||||||
do_constant_folding: Apply the constant-folding optimization.
|
do_constant_folding: Apply the constant-folding optimization.
|
||||||
Constant-folding will replace some of the ops that have all constant inputs
|
Constant-folding will replace some of the ops that have all constant inputs
|
||||||
with pre-computed constant nodes.
|
with pre-computed constant nodes.
|
||||||
|
|
@ -1393,10 +1393,7 @@ def _export(
|
||||||
if opset_version is None:
|
if opset_version is None:
|
||||||
opset_version = _constants.ONNX_DEFAULT_OPSET
|
opset_version = _constants.ONNX_DEFAULT_OPSET
|
||||||
|
|
||||||
# torch.onnx.export does not support opset versions >=18
|
|
||||||
if opset_version > _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET:
|
if opset_version > _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET:
|
||||||
# We do not want to fail because we should still allow users to create
|
|
||||||
# custom symbolic functions for opset>17
|
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"Exporting to ONNX opset version {opset_version} is not supported. "
|
f"Exporting to ONNX opset version {opset_version} is not supported. "
|
||||||
f"by 'torch.onnx.export()'. "
|
f"by 'torch.onnx.export()'. "
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user