mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Enable ONNX constant folding for opset 11. (#29011)
Summary: Currently ONNX constant folding (`do_constant_folding=True` arg in `torch.onnx.export` API) supports only opset 9 and 10 of ONNX. Opset 11 support was recently introduced in the ONNX exporter. For opset 11, it is currently a no-op. This change enables ONNX constant folding for opset 11. Specifically there are three main changes: 1) Turn on constant folding ONNX pass for opset 11. 2) Enable constant folding tests in `test/onnx/test_utility_funs.py` and `test/onnx/test_pytorch_onnx_onnxruntime.py` for opset 11. Pull Request resolved: https://github.com/pytorch/pytorch/pull/29011 Reviewed By: hl475 Differential Revision: D18306998 Pulled By: houseroad fbshipit-source-id: afeed21ca29e01c278612e51dacd93397dd6e2d8
This commit is contained in:
parent
ee21142e40
commit
bc91e19861
|
|
@ -1760,7 +1760,7 @@ TestONNXRuntime_opset11 = type(str("TestONNXRuntime_opset11"),
|
||||||
dict(TestONNXRuntime.__dict__, opset_version=11))
|
dict(TestONNXRuntime.__dict__, opset_version=11))
|
||||||
|
|
||||||
|
|
||||||
# opset 10 tests, with keep_initializers_as_inputs=False for
|
# opset 9 tests, with keep_initializers_as_inputs=False for
|
||||||
# IR version 4 style export.
|
# IR version 4 style export.
|
||||||
TestONNXRuntime_opset9_IRv4 = type(str("TestONNXRuntime_opset9_IRv4"),
|
TestONNXRuntime_opset9_IRv4 = type(str("TestONNXRuntime_opset9_IRv4"),
|
||||||
(unittest.TestCase,),
|
(unittest.TestCase,),
|
||||||
|
|
@ -1776,5 +1776,13 @@ TestONNXRuntime_opset10_IRv4 = type(str("TestONNXRuntime_opset10_IRv4"),
|
||||||
keep_initializers_as_inputs=False))
|
keep_initializers_as_inputs=False))
|
||||||
|
|
||||||
|
|
||||||
|
# opset 11 tests, with keep_initializers_as_inputs=False for
|
||||||
|
# IR version 4 style export.
|
||||||
|
TestONNXRuntime_opset11_IRv4 = type(str("TestONNXRuntime_opset11_IRv4"),
|
||||||
|
(unittest.TestCase,),
|
||||||
|
dict(TestONNXRuntime.__dict__, opset_version=11,
|
||||||
|
keep_initializers_as_inputs=False))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
|
|
@ -246,5 +246,11 @@ TestUtilityFuns_opset10 = type(str("TestUtilityFuns_opset10"),
|
||||||
dict(TestUtilityFuns.__dict__, opset_version=10))
|
dict(TestUtilityFuns.__dict__, opset_version=10))
|
||||||
|
|
||||||
|
|
||||||
|
# opset 11 tests
|
||||||
|
TestUtilityFuns_opset11 = type(str("TestUtilityFuns_opset11"),
|
||||||
|
(TestCase,),
|
||||||
|
dict(TestUtilityFuns.__dict__, opset_version=11))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_tests()
|
run_tests()
|
||||||
|
|
|
||||||
|
|
@ -117,13 +117,13 @@ c10::optional<at::Tensor> runTorchSlice_opset9(const Node* node,
|
||||||
c10::optional<at::Tensor> runTorchSlice_opset10(const Node* node,
|
c10::optional<at::Tensor> runTorchSlice_opset10(const Node* node,
|
||||||
std::vector<at::Tensor>& inputTensorValues) {
|
std::vector<at::Tensor>& inputTensorValues) {
|
||||||
if (inputTensorValues.size() < 3 || inputTensorValues.size() > 5) {
|
if (inputTensorValues.size() < 3 || inputTensorValues.size() > 5) {
|
||||||
std::cerr << "Warning: Constant folding - Invalid number of inputs found for opset 10 onnx::Slice op. "
|
std::cerr << "Warning: Constant folding - Invalid number of inputs found for opset 10 or 11 onnx::Slice op. "
|
||||||
<< "Constant folding not applied." << std::endl;
|
<< "Constant folding not applied." << std::endl;
|
||||||
return c10::nullopt;
|
return c10::nullopt;
|
||||||
}
|
}
|
||||||
// Checking validity of 'starts' and 'ends' input
|
// Checking validity of 'starts' and 'ends' input
|
||||||
if (inputTensorValues[1].sizes().size() != 1 || inputTensorValues[2].sizes().size() != 1) {
|
if (inputTensorValues[1].sizes().size() != 1 || inputTensorValues[2].sizes().size() != 1) {
|
||||||
std::cerr << "Warning: Constant folding - Invalid 'starts' or 'ends' inputs found for opset 10 onnx::Slice op. "
|
std::cerr << "Warning: Constant folding - Invalid 'starts' or 'ends' inputs found for opset 10 or 11 onnx::Slice op. "
|
||||||
<< "Constant folding not applied." << std::endl;
|
<< "Constant folding not applied." << std::endl;
|
||||||
return c10::nullopt;
|
return c10::nullopt;
|
||||||
}
|
}
|
||||||
|
|
@ -198,10 +198,10 @@ c10::optional<at::Tensor> runTorchBackendForOnnx(
|
||||||
int opset_version) {
|
int opset_version) {
|
||||||
at::Tensor updated_val;
|
at::Tensor updated_val;
|
||||||
if (node->kind() == onnx::Slice) {
|
if (node->kind() == onnx::Slice) {
|
||||||
if (opset_version == 9) {
|
if (opset_version == ONNX_OPSET_9) {
|
||||||
return runTorchSlice_opset9(node, inputTensorValues);
|
return runTorchSlice_opset9(node, inputTensorValues);
|
||||||
}
|
}
|
||||||
else if (opset_version == 10) {
|
else if (opset_version == ONNX_OPSET_10 || opset_version == ONNX_OPSET_11) {
|
||||||
return runTorchSlice_opset10(node, inputTensorValues);
|
return runTorchSlice_opset10(node, inputTensorValues);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
@ -329,9 +329,10 @@ std::vector<Node*> getOnnxConstParentsToRemove(Node* node) {
|
||||||
// nodes can be lifted so we run them earlier, before the usual parameters are
|
// nodes can be lifted so we run them earlier, before the usual parameters are
|
||||||
// known.
|
// known.
|
||||||
void ConstantFoldONNX(Block* b, ParamMap& paramsDict, int opset_version) {
|
void ConstantFoldONNX(Block* b, ParamMap& paramsDict, int opset_version) {
|
||||||
if (opset_version != 9 && opset_version != 10) {
|
if (opset_version != ONNX_OPSET_9 && opset_version != ONNX_OPSET_10 &&
|
||||||
|
opset_version != ONNX_OPSET_11) {
|
||||||
// Number of elements of 'axes' and 'ends' 1-D input tensors should be the same
|
// Number of elements of 'axes' and 'ends' 1-D input tensors should be the same
|
||||||
std::cerr << "Warning: Constant folding supported for only opsets 9 and 10. "
|
std::cerr << "Warning: Constant folding supported for only opsets 9, 10, and 11. "
|
||||||
<< "Constant folding not applied." << std::endl;
|
<< "Constant folding not applied." << std::endl;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,9 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
|
const int ONNX_OPSET_9 = 9;
|
||||||
|
const int ONNX_OPSET_10 = 10;
|
||||||
|
const int ONNX_OPSET_11 = 11;
|
||||||
void ConstantFoldONNX(Block* b, std::map<std::string, at::Tensor>& paramDict, int opset_version);
|
void ConstantFoldONNX(Block* b, std::map<std::string, at::Tensor>& paramDict, int opset_version);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -282,7 +282,7 @@ def _model_to_graph(model, args, verbose=False, training=False,
|
||||||
param_names = input_and_param_names[len(input_and_param_names) - len(params):]
|
param_names = input_and_param_names[len(input_and_param_names) - len(params):]
|
||||||
params_dict = dict(zip(param_names, params))
|
params_dict = dict(zip(param_names, params))
|
||||||
|
|
||||||
if do_constant_folding and _export_onnx_opset_version in [9, 10]:
|
if do_constant_folding and _export_onnx_opset_version in [9, 10, 11]:
|
||||||
params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict,
|
params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict,
|
||||||
_export_onnx_opset_version)
|
_export_onnx_opset_version)
|
||||||
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
|
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user