mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Enable ONNX constant folding for opset 11. (#29011)
Summary: Currently ONNX constant folding (`do_constant_folding=True` arg in `torch.onnx.export` API) supports only opset 9 and 10 of ONNX. Opset 11 support was recently introduced in the ONNX exporter. For opset 11, it is currently a no-op. This change enables ONNX constant folding for opset 11. Specifically there are three main changes: 1) Turn on constant folding ONNX pass for opset 11. 2) Enable constant folding tests in `test/onnx/test_utility_funs.py` and `test/onnx/test_pytorch_onnx_onnxruntime.py` for opset 11. Pull Request resolved: https://github.com/pytorch/pytorch/pull/29011 Reviewed By: hl475 Differential Revision: D18306998 Pulled By: houseroad fbshipit-source-id: afeed21ca29e01c278612e51dacd93397dd6e2d8
This commit is contained in:
parent
ee21142e40
commit
bc91e19861
|
|
@ -1760,7 +1760,7 @@ TestONNXRuntime_opset11 = type(str("TestONNXRuntime_opset11"),
|
|||
dict(TestONNXRuntime.__dict__, opset_version=11))
|
||||
|
||||
|
||||
# opset 10 tests, with keep_initializers_as_inputs=False for
|
||||
# opset 9 tests, with keep_initializers_as_inputs=False for
|
||||
# IR version 4 style export.
|
||||
TestONNXRuntime_opset9_IRv4 = type(str("TestONNXRuntime_opset9_IRv4"),
|
||||
(unittest.TestCase,),
|
||||
|
|
@ -1776,5 +1776,13 @@ TestONNXRuntime_opset10_IRv4 = type(str("TestONNXRuntime_opset10_IRv4"),
|
|||
keep_initializers_as_inputs=False))
|
||||
|
||||
|
||||
# opset 11 tests, with keep_initializers_as_inputs=False for
|
||||
# IR version 4 style export.
|
||||
TestONNXRuntime_opset11_IRv4 = type(str("TestONNXRuntime_opset11_IRv4"),
|
||||
(unittest.TestCase,),
|
||||
dict(TestONNXRuntime.__dict__, opset_version=11,
|
||||
keep_initializers_as_inputs=False))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -246,5 +246,11 @@ TestUtilityFuns_opset10 = type(str("TestUtilityFuns_opset10"),
|
|||
dict(TestUtilityFuns.__dict__, opset_version=10))
|
||||
|
||||
|
||||
# opset 11 tests
|
||||
TestUtilityFuns_opset11 = type(str("TestUtilityFuns_opset11"),
|
||||
(TestCase,),
|
||||
dict(TestUtilityFuns.__dict__, opset_version=11))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -117,13 +117,13 @@ c10::optional<at::Tensor> runTorchSlice_opset9(const Node* node,
|
|||
c10::optional<at::Tensor> runTorchSlice_opset10(const Node* node,
|
||||
std::vector<at::Tensor>& inputTensorValues) {
|
||||
if (inputTensorValues.size() < 3 || inputTensorValues.size() > 5) {
|
||||
std::cerr << "Warning: Constant folding - Invalid number of inputs found for opset 10 onnx::Slice op. "
|
||||
std::cerr << "Warning: Constant folding - Invalid number of inputs found for opset 10 or 11 onnx::Slice op. "
|
||||
<< "Constant folding not applied." << std::endl;
|
||||
return c10::nullopt;
|
||||
}
|
||||
// Checking validity of 'starts' and 'ends' input
|
||||
if (inputTensorValues[1].sizes().size() != 1 || inputTensorValues[2].sizes().size() != 1) {
|
||||
std::cerr << "Warning: Constant folding - Invalid 'starts' or 'ends' inputs found for opset 10 onnx::Slice op. "
|
||||
std::cerr << "Warning: Constant folding - Invalid 'starts' or 'ends' inputs found for opset 10 or 11 onnx::Slice op. "
|
||||
<< "Constant folding not applied." << std::endl;
|
||||
return c10::nullopt;
|
||||
}
|
||||
|
|
@ -198,10 +198,10 @@ c10::optional<at::Tensor> runTorchBackendForOnnx(
|
|||
int opset_version) {
|
||||
at::Tensor updated_val;
|
||||
if (node->kind() == onnx::Slice) {
|
||||
if (opset_version == 9) {
|
||||
if (opset_version == ONNX_OPSET_9) {
|
||||
return runTorchSlice_opset9(node, inputTensorValues);
|
||||
}
|
||||
else if (opset_version == 10) {
|
||||
else if (opset_version == ONNX_OPSET_10 || opset_version == ONNX_OPSET_11) {
|
||||
return runTorchSlice_opset10(node, inputTensorValues);
|
||||
}
|
||||
else {
|
||||
|
|
@ -329,9 +329,10 @@ std::vector<Node*> getOnnxConstParentsToRemove(Node* node) {
|
|||
// nodes can be lifted so we run them earlier, before the usual parameters are
|
||||
// known.
|
||||
void ConstantFoldONNX(Block* b, ParamMap& paramsDict, int opset_version) {
|
||||
if (opset_version != 9 && opset_version != 10) {
|
||||
if (opset_version != ONNX_OPSET_9 && opset_version != ONNX_OPSET_10 &&
|
||||
opset_version != ONNX_OPSET_11) {
|
||||
// Number of elements of 'axes' and 'ends' 1-D input tensors should be the same
|
||||
std::cerr << "Warning: Constant folding supported for only opsets 9 and 10. "
|
||||
std::cerr << "Warning: Constant folding supported for only opsets 9, 10, and 11. "
|
||||
<< "Constant folding not applied." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,9 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
const int ONNX_OPSET_9 = 9;
|
||||
const int ONNX_OPSET_10 = 10;
|
||||
const int ONNX_OPSET_11 = 11;
|
||||
void ConstantFoldONNX(Block* b, std::map<std::string, at::Tensor>& paramDict, int opset_version);
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -282,7 +282,7 @@ def _model_to_graph(model, args, verbose=False, training=False,
|
|||
param_names = input_and_param_names[len(input_and_param_names) - len(params):]
|
||||
params_dict = dict(zip(param_names, params))
|
||||
|
||||
if do_constant_folding and _export_onnx_opset_version in [9, 10]:
|
||||
if do_constant_folding and _export_onnx_opset_version in [9, 10, 11]:
|
||||
params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict,
|
||||
_export_onnx_opset_version)
|
||||
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user