Enable ONNX constant folding for opset 11. (#29011)

Summary:
Currently ONNX constant folding (`do_constant_folding=True` arg in `torch.onnx.export` API) supports only opset 9 and 10 of ONNX. Opset 11 support was recently introduced in the ONNX exporter. For opset 11, it is currently a no-op. This change enables ONNX constant folding for opset 11. Specifically there are three main changes:
1) Turn on constant folding ONNX pass for opset 11.
2) Enable constant folding tests in `test/onnx/test_utility_funs.py` and `test/onnx/test_pytorch_onnx_onnxruntime.py` for opset 11.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29011

Reviewed By: hl475

Differential Revision: D18306998

Pulled By: houseroad

fbshipit-source-id: afeed21ca29e01c278612e51dacd93397dd6e2d8
This commit is contained in:
Spandan Tiwari 2019-11-05 23:20:58 -08:00 committed by Facebook Github Bot
parent ee21142e40
commit bc91e19861
5 changed files with 26 additions and 8 deletions

View File

@ -1760,7 +1760,7 @@ TestONNXRuntime_opset11 = type(str("TestONNXRuntime_opset11"),
dict(TestONNXRuntime.__dict__, opset_version=11))
# opset 10 tests, with keep_initializers_as_inputs=False for
# opset 9 tests, with keep_initializers_as_inputs=False for
# IR version 4 style export.
TestONNXRuntime_opset9_IRv4 = type(str("TestONNXRuntime_opset9_IRv4"),
(unittest.TestCase,),
@ -1776,5 +1776,13 @@ TestONNXRuntime_opset10_IRv4 = type(str("TestONNXRuntime_opset10_IRv4"),
keep_initializers_as_inputs=False))
# opset 11 tests, with keep_initializers_as_inputs=False for
# IR version 4 style export.
TestONNXRuntime_opset11_IRv4 = type(str("TestONNXRuntime_opset11_IRv4"),
(unittest.TestCase,),
dict(TestONNXRuntime.__dict__, opset_version=11,
keep_initializers_as_inputs=False))
if __name__ == '__main__':
unittest.main()

View File

@ -246,5 +246,11 @@ TestUtilityFuns_opset10 = type(str("TestUtilityFuns_opset10"),
dict(TestUtilityFuns.__dict__, opset_version=10))
# opset 11 tests
TestUtilityFuns_opset11 = type(str("TestUtilityFuns_opset11"),
(TestCase,),
dict(TestUtilityFuns.__dict__, opset_version=11))
if __name__ == '__main__':
run_tests()

View File

@ -117,13 +117,13 @@ c10::optional<at::Tensor> runTorchSlice_opset9(const Node* node,
c10::optional<at::Tensor> runTorchSlice_opset10(const Node* node,
std::vector<at::Tensor>& inputTensorValues) {
if (inputTensorValues.size() < 3 || inputTensorValues.size() > 5) {
std::cerr << "Warning: Constant folding - Invalid number of inputs found for opset 10 onnx::Slice op. "
std::cerr << "Warning: Constant folding - Invalid number of inputs found for opset 10 or 11 onnx::Slice op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
// Checking validity of 'starts' and 'ends' input
if (inputTensorValues[1].sizes().size() != 1 || inputTensorValues[2].sizes().size() != 1) {
std::cerr << "Warning: Constant folding - Invalid 'starts' or 'ends' inputs found for opset 10 onnx::Slice op. "
std::cerr << "Warning: Constant folding - Invalid 'starts' or 'ends' inputs found for opset 10 or 11 onnx::Slice op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
@ -198,10 +198,10 @@ c10::optional<at::Tensor> runTorchBackendForOnnx(
int opset_version) {
at::Tensor updated_val;
if (node->kind() == onnx::Slice) {
if (opset_version == 9) {
if (opset_version == ONNX_OPSET_9) {
return runTorchSlice_opset9(node, inputTensorValues);
}
else if (opset_version == 10) {
else if (opset_version == ONNX_OPSET_10 || opset_version == ONNX_OPSET_11) {
return runTorchSlice_opset10(node, inputTensorValues);
}
else {
@ -329,9 +329,10 @@ std::vector<Node*> getOnnxConstParentsToRemove(Node* node) {
// nodes can be lifted so we run them earlier, before the usual parameters are
// known.
void ConstantFoldONNX(Block* b, ParamMap& paramsDict, int opset_version) {
if (opset_version != 9 && opset_version != 10) {
if (opset_version != ONNX_OPSET_9 && opset_version != ONNX_OPSET_10 &&
opset_version != ONNX_OPSET_11) {
// Number of elements of 'axes' and 'ends' 1-D input tensors should be the same
std::cerr << "Warning: Constant folding supported for only opsets 9 and 10. "
std::cerr << "Warning: Constant folding supported for only opsets 9, 10, and 11. "
<< "Constant folding not applied." << std::endl;
return;
}

View File

@ -5,6 +5,9 @@
namespace torch {
namespace jit {
const int ONNX_OPSET_9 = 9;
const int ONNX_OPSET_10 = 10;
const int ONNX_OPSET_11 = 11;
void ConstantFoldONNX(Block* b, std::map<std::string, at::Tensor>& paramDict, int opset_version);
}

View File

@ -282,7 +282,7 @@ def _model_to_graph(model, args, verbose=False, training=False,
param_names = input_and_param_names[len(input_and_param_names) - len(params):]
params_dict = dict(zip(param_names, params))
if do_constant_folding and _export_onnx_opset_version in [9, 10]:
if do_constant_folding and _export_onnx_opset_version in [9, 10, 11]:
params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict,
_export_onnx_opset_version)
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)