mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Fix a couple of issues with scripting inplace indexing in prepare_inplace_ops_for_onnx pass. 1- Tracing index copy (such as cases like x[1:3] = data) already applies broadcasting on rhs if needed. The broadcasting node (aten::expand) is missing in scripting cases. 2- Inplace indexing with ellipsis (aten::copy_) is replaced with aten::index_put and then handled with slice+select in this pass. Support for negative indices is added for this op. Shape inference is also enabled for scripting tests using the new JIT API. A few more tests are enabled for scripting. Pull Request resolved: https://github.com/pytorch/pytorch/pull/44351 Reviewed By: ezyang Differential Revision: D23880267 Pulled By: bzinodev fbshipit-source-id: 78b33444633eb7ae0fbabc7415e3b16001f5207f
46 lines
1.6 KiB
Python
46 lines
1.6 KiB
Python
import unittest
|
|
import onnxruntime # noqa
|
|
|
|
from test_models import TestModels
|
|
from test_pytorch_onnx_onnxruntime import run_model_test
|
|
import torch
|
|
|
|
|
|
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
    """Check ``model`` against ONNX Runtime for a range of opset versions.

    This function is installed as the ``exportTest`` method of the TestCase
    classes built in this module.

    For every opset version it sets ``self.opset_version`` and calls
    ``run_model_test`` on ``model``. When the test class enables scripting
    (``is_script_test_enabled``) and the opset version is greater than 11,
    it also runs a second check:

    - compile the model with ``torch.jit.script``;
    - call ``run_model_test`` on the scripted model, passing the outputs of
      the un-scripted model as ``example_outputs``.

    Args:
        self: the running ``unittest.TestCase`` instance.
        model: the model under test. It must be callable with ``inputs`` and
            accepted by ``torch.jit.script`` (presumably a
            ``torch.nn.Module`` -- not checked here).
        inputs: forwarded to ``run_model_test`` as ``input``. For the
            scripted check it is also passed to ``model`` as a single
            positional argument.
        rtol: relative tolerance forwarded to ``run_model_test``.
        atol: absolute tolerance forwarded to ``run_model_test``.
        opset_versions: opset versions to test. A falsy value (``None`` or
            an empty sequence) selects opsets 7 through 12.
    """
    opset_versions = opset_versions or [7, 8, 9, 10, 11, 12]

    for opset_version in opset_versions:
        # run_model_test reads the target opset from the test instance
        # (presumably; it is stored here rather than passed as an argument).
        self.opset_version = opset_version
        run_model_test(self, model, False,
                       input=inputs, rtol=rtol, atol=atol)

        # Scripted-model coverage is only exercised for opsets above 11.
        if self.is_script_test_enabled and opset_version > 11:
            # Enable the new JIT passes and ONNX shape inference on *this*
            # test instance. Assigning them on the module-level ``TestModels``
            # class would mutate a different class than the one running here:
            # the running instance would be unaffected, and the flags would
            # leak into every later ``TestModels`` test.
            self.use_new_jit_passes = True
            self.onnx_shape_inference = True

            # Reference outputs from the un-scripted model.
            outputs = model(inputs)
            script_model = torch.jit.script(model)
            run_model_test(self, script_model, False, example_outputs=outputs,
                           input=inputs, rtol=rtol, atol=atol)
|
|
|
|
|
|
# Rebind ``TestModels`` to a fresh unittest.TestCase subclass. The new class:
# - copies every attribute and test method of the imported ``TestModels``;
# - replaces ``exportTest`` with the module-level ``exportTest`` function;
# - keeps the scripted-model checks switched off.
TestModels = type(
    "TestModels",
    (unittest.TestCase,),
    {
        **TestModels.__dict__,
        "is_script_test_enabled": False,
        "exportTest": exportTest,
    },
)
|
|
|
|
|
|
# Model tests for scripting with the new JIT APIs and ONNX shape inference.
# The class carries the same attributes and test methods as ``TestModels``,
# but switches on, at class level:
# - the scripted-model checks;
# - the new JIT passes;
# - ONNX shape inference.
TestModels_new_jit_API = type(
    "TestModels_new_jit_API",
    (unittest.TestCase,),
    {
        **TestModels.__dict__,
        "exportTest": exportTest,
        "is_script_test_enabled": True,
        "use_new_jit_passes": True,
        "onnx_shape_inference": True,
    },
)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the unittest TestCase classes found in this script's module namespace.
    unittest.main(module="__main__")
|