pytorch/test/onnx/test_models_onnxruntime.py
Negin Raoof 95a97e51b5 [ONNX] Improve scripting inplace indexing ops (#44351)
Summary:
Fix a couple of issues with scripting inplace indexing in prepare_inplace_ops_for_onnx pass.
1- Tracing index copy (such as cases like x[1:3] = data) already applies broadcasting on the rhs if needed. The broadcasting node (aten::expand) is missing in scripting cases.

2- Inplace indexing with an ellipsis (aten::copy_) is replaced with aten::index_put and then handled with slice+select in this pass.
Support for negative indices for this op has been added.

Shape inference is also enabled for scripting tests using new JIT API.
A few more tests are enabled for scripting.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44351

Reviewed By: ezyang

Differential Revision: D23880267

Pulled By: bzinodev

fbshipit-source-id: 78b33444633eb7ae0fbabc7415e3b16001f5207f
2020-09-28 00:32:36 -07:00

46 lines
1.6 KiB
Python

import unittest
import onnxruntime # noqa
from test_models import TestModels
from test_pytorch_onnx_onnxruntime import run_model_test
import torch
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
    """Check ``model`` with ``run_model_test`` once per requested opset version.

    Args:
        self: Test-case instance. ``opset_version`` is written on it before
            every run and ``is_script_test_enabled`` is read from it.
        model: Callable model. It is invoked as ``model(inputs)`` and passed
            to ``torch.jit.script`` on the scripting path.
        inputs: Forwarded to ``run_model_test`` as ``input``.
        rtol: Relative tolerance forwarded to ``run_model_test``.
        atol: Absolute tolerance forwarded to ``run_model_test``.
        opset_versions: Opset versions to iterate over. A falsy value
            (``None`` or an empty sequence) selects versions 7 through 12.
    """
    tolerances = dict(rtol=rtol, atol=atol)

    for opset in opset_versions or [7, 8, 9, 10, 11, 12]:
        self.opset_version = opset
        run_model_test(self, model, False, input=inputs, **tolerances)

        # The scripted variant only runs for classes that opt in, and only
        # for opset versions above 11.
        if not (self.is_script_test_enabled and opset > 11):
            continue

        # NOTE(review): these flags are written on the module-level TestModels
        # class rather than on ``self``, so they stay set after this call
        # returns -- confirm that this is intended.
        TestModels.use_new_jit_passes = True
        TestModels.onnx_shape_inference = True

        # Run the model eagerly first; its outputs are handed to the scripted
        # run as ``example_outputs``.
        eager_outputs = model(inputs)
        scripted = torch.jit.script(model)
        run_model_test(self, scripted, False, example_outputs=eager_outputs,
                       input=inputs, **tolerances)
# Rebuild TestModels as a unittest.TestCase subclass that carries over every
# attribute of the imported TestModels class, swaps in the exportTest defined
# in this file, and keeps the scripting path of exportTest switched off.
TestModels = type(
    "TestModels",
    (unittest.TestCase,),
    {
        **TestModels.__dict__,
        "is_script_test_enabled": False,
        "exportTest": exportTest,
    },
)
# Model tests for scripting with the new JIT APIs and shape inference: the same
# attributes as TestModels, but with the scripting path of exportTest enabled
# and both feature flags set on the class.
TestModels_new_jit_API = type(
    "TestModels_new_jit_API",
    (unittest.TestCase,),
    {
        **TestModels.__dict__,
        "exportTest": exportTest,
        "is_script_test_enabled": True,
        "use_new_jit_passes": True,
        "onnx_shape_inference": True,
    },
)
# Run every test case defined in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()