diff --git a/test/onnx/expect/TestOperators.test_dyn_arange.expect b/test/onnx/expect/TestOperators.test_dyn_arange.expect
new file mode 100644
index 00000000000..c4f208507fe
--- /dev/null
+++ b/test/onnx/expect/TestOperators.test_dyn_arange.expect
@@ -0,0 +1,129 @@
+ir_version: 4
+producer_name: "pytorch"
+producer_version: "1.2"
+graph {
+  node {
+    output: "1"
+    op_type: "Constant"
+    attribute {
+      name: "value"
+      t {
+        data_type: 7
+        raw_data: "\000\000\000\000\000\000\000\000"
+      }
+      type: TENSOR
+    }
+  }
+  node {
+    input: "0"
+    output: "2"
+    op_type: "Shape"
+  }
+  node {
+    input: "2"
+    input: "1"
+    output: "3"
+    op_type: "Gather"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  node {
+    input: "3"
+    output: "4"
+    op_type: "Unsqueeze"
+    attribute {
+      name: "axes"
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
+    input: "4"
+    output: "5"
+    op_type: "ConstantOfShape"
+    attribute {
+      name: "value"
+      t {
+        dims: 1
+        data_type: 7
+        raw_data: "\001\000\000\000\000\000\000\000"
+      }
+      type: TENSOR
+    }
+  }
+  node {
+    input: "5"
+    output: "6"
+    op_type: "NonZero"
+  }
+  node {
+    input: "6"
+    output: "7"
+    op_type: "Transpose"
+    attribute {
+      name: "perm"
+      ints: 1
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
+    input: "7"
+    output: "8"
+    op_type: "Squeeze"
+    attribute {
+      name: "axes"
+      ints: 1
+      type: INTS
+    }
+  }
+  node {
+    input: "8"
+    output: "9"
+    op_type: "Cast"
+    attribute {
+      name: "to"
+      i: 7
+      type: INT
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "0"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 5
+          }
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 2
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "9"
+    type {
+      tensor_type {
+        elem_type: 7
+        shape {
+          dim {
+            dim_value: 5
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py
index e05a1d4de2f..fa20fef568d 100644
--- a/test/onnx/test_operators.py
+++ b/test/onnx/test_operators.py
@@ -713,6 +713,14 @@ class TestOperators(TestCase):
         inputs = (scores, bbox_deltas, im_info, anchors)
         self.assertONNX(model, inputs)
 
+    def test_dyn_arange(self):
+        class TestModel(torch.nn.Module):
+            def forward(self, input):
+                return torch.arange(input.shape[0])
+
+        input = torch.randn(5, 3, 2)
+        self.assertONNX(TestModel(), input)
+
     def test_layer_norm_aten(self):
         model = torch.nn.LayerNorm([10, 10])
         x = torch.randn(20, 5, 10, 10)
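Note (illustrative, not part of the patch): a rough sketch of what test_dyn_arange exercises. Exporting an arange whose length comes from a traced size should now emit the Constant, Shape, Gather, Unsqueeze, ConstantOfShape, NonZero, Transpose, Squeeze, Cast sequence recorded in the expect file above, rather than a baked-in constant. The DynArange class below is made up for the example.

import io

import onnx
import torch


class DynArange(torch.nn.Module):
    def forward(self, input):
        # the output length follows the first dimension of the input
        return torch.arange(input.shape[0])


f = io.BytesIO()
# opset 9, matching the opset_import in the expect file above
torch.onnx.export(DynArange(), torch.randn(5, 3, 2), f, opset_version=9)

exported = onnx.load_from_string(f.getvalue())
print([node.op_type for node in exported.graph.node])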
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index a32b5ff5143..9ea5f85709b 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -19,9 +19,33 @@ from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_S
 import model_defs.word_language_model as word_language_model
 
 
+def ort_test_with_input(ort_sess, input, output, rtol, atol):
+    input, _ = torch.jit._flatten(input)
+    output, _ = torch.jit._flatten(output)
+
+    def to_numpy(tensor):
+        if tensor.requires_grad:
+            return tensor.detach().cpu().numpy()
+        else:
+            return tensor.cpu().numpy()
+
+    inputs = list(map(to_numpy, input))
+    outputs = list(map(to_numpy, output))
+
+    ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(inputs))
+    ort_outs = ort_sess.run(None, ort_inputs)
+
+    # compare onnxruntime and PyTorch results
+    assert len(outputs) == len(ort_outs), "number of outputs differ"
+
+    # compare onnxruntime and PyTorch results
+    [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)]
+
+
 def run_model_test(self, model, batch_size=2, state_dict=None,
                    input=None, use_gpu=True, rtol=0.001, atol=1e-7,
-                   example_outputs=None, do_constant_folding=True):
+                   example_outputs=None, do_constant_folding=True,
+                   dynamic_axes=None, test_with_inputs=None):
     model.eval()
 
     if input is None:
@@ -40,31 +64,25 @@ def run_model_test(self, model, batch_size=2, state_dict=None,
                           opset_version=self.opset_version,
                           example_outputs=output,
                           do_constant_folding=do_constant_folding,
-                          keep_initializers_as_inputs=self.keep_initializers_as_inputs)
-
-        input, _ = torch.jit._flatten(input)
-        output, _ = torch.jit._flatten(output)
-
-        def to_numpy(tensor):
-            if tensor.requires_grad:
-                return tensor.detach().cpu().numpy()
-            else:
-                return tensor.cpu().numpy()
-
-        inputs = list(map(to_numpy, input))
-        outputs = list(map(to_numpy, output))
+                          keep_initializers_as_inputs=self.keep_initializers_as_inputs,
+                          dynamic_axes=dynamic_axes)
 
         # compute onnxruntime output prediction
        ort_sess = onnxruntime.InferenceSession(f.getvalue())
-        ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(inputs))
-        ort_outs = ort_sess.run(None, ort_inputs)
+        ort_test_with_input(ort_sess, input, output, rtol, atol)
 
-        # compare onnxruntime and PyTorch results
-        assert len(outputs) == len(ort_outs), "number of outputs differ"
+        # if additional test inputs are provided, run the ONNX
+        # model with these inputs and check the outputs
+        if test_with_inputs is not None:
+            for test_input in test_with_inputs:
+                if isinstance(test_input, torch.Tensor):
+                    test_input = (test_input,)
+                output = model(*test_input)
+                if isinstance(output, torch.Tensor):
+                    output = (output,)
 
-        # compare onnxruntime and PyTorch results
-        [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)]
+                ort_test_with_input(ort_sess, test_input, output, rtol, atol)
 
 
 class TestONNXRuntime(unittest.TestCase):
@@ -78,10 +96,12 @@ class TestONNXRuntime(unittest.TestCase):
             torch.cuda.manual_seed_all(0)
         np.random.seed(seed=0)
 
-    def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=True, batch_size=2, use_gpu=True):
+    def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=True,
+                 batch_size=2, use_gpu=True, dynamic_axes=None, test_with_inputs=None):
         run_model_test(self, model, batch_size=batch_size, input=input,
                        use_gpu=use_gpu, rtol=rtol, atol=atol,
-                       do_constant_folding=do_constant_folding)
+                       do_constant_folding=do_constant_folding,
+                       dynamic_axes=dynamic_axes, test_with_inputs=test_with_inputs)
 
     def run_word_language_model(self, model_name):
         ntokens = 50
@@ -280,6 +300,21 @@ class TestONNXRuntime(unittest.TestCase):
         x = torch.rand(5, 5, 5)
         self.run_test(DynamicSliceExportMod(), x)
 
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_arange(self):
+        class ArangeModel(torch.nn.Module):
+            def forward(self, input):
+                return torch.arange(input.shape[0]), \
+                    torch.arange(12), \
+                    torch.arange(start=input.shape[0], end=input.shape[0] + 5)
+
+        x = torch.randn(5, 3, 2)
+        y = torch.randn(8, 3, 2)
+        self.run_test(ArangeModel(), x, test_with_inputs=[y],
+                      dynamic_axes={'input_1': [0],
+                                    'output_1': [0],
+                                    'output_2': [0]})
+
     def _test_index_generic(self, fn):
         class MyModel(torch.nn.Module):
             def __init__(self):
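Note (illustrative, not part of the patch): outside the test harness, the flow that run_model_test now follows with dynamic_axes and test_with_inputs looks roughly like the sketch below — export once with the batch dimension marked dynamic, then feed the same onnxruntime session inputs of different sizes. The 'input' name, single-output model, and tolerances are placeholders.

import io

import numpy as np
import onnxruntime
import torch


class ArangeModel(torch.nn.Module):
    def forward(self, input):
        return torch.arange(input.shape[0])


model = ArangeModel()
x = torch.randn(5, 3, 2)
y = torch.randn(8, 3, 2)  # same rank, different first dimension

f = io.BytesIO()
torch.onnx.export(model, x, f, opset_version=9,
                  input_names=['input'],
                  dynamic_axes={'input': [0]})  # dim 0 of the input is dynamic

ort_sess = onnxruntime.InferenceSession(f.getvalue())
for t in (x, y):
    expected = model(t).numpy()
    (ort_out,) = ort_sess.run(None, {'input': t.numpy()})
    np.testing.assert_allclose(expected, ort_out, rtol=1e-3, atol=1e-7)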
diff --git a/test/test_jit.py b/test/test_jit.py
index d3e5b7580fb..282f1496e5b 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1553,6 +1553,40 @@ graph(%Ra, %Rb):
     def test_trace_size_with_grad(self):
         self.do_trace_size(True)
 
+    def do_trace_arange(self, requires_grad):
+        def arange(x):
+            return torch.arange(x.shape[0])
+
+        def arange_scalar(x):
+            return torch.arange(12)
+
+        def arange_start_end(x):
+            return torch.arange(start=x.shape[0], end=x.shape[0] + 5)
+
+        x = torch.randn(5, 3, 2, requires_grad=requires_grad)
+        y = torch.randn(8, 2, 4, requires_grad=requires_grad)
+
+        # Check that it behaves as expected
+        traced_arange = torch.jit.trace(arange, x)
+        self.assertEqual(traced_arange(y), arange(y))
+        self.assertEqual(traced_arange(x), arange(x))
+
+        traced_arange_scalar = torch.jit.trace(arange_scalar, x)
+        self.assertEqual(traced_arange_scalar(y), arange_scalar(y))
+        self.assertEqual(traced_arange_scalar(x), arange_scalar(x))
+
+        traced_arange_start_end = torch.jit.trace(arange_start_end, x)
+        self.assertEqual(traced_arange_start_end(y), arange_start_end(y))
+        self.assertEqual(traced_arange_start_end(x), arange_start_end(x))
+
+    def test_trace_arange(self):
+        self.do_trace_arange(False)
+
+    # test the different graph_executor path that happens when
+    # gradients are required and sizes are involved
+    def test_trace_arange_with_grad(self):
+        self.do_trace_arange(True)
+
     def test_trace_casts(self):
         casts = [
             lambda x: x.byte(),
@@ -12018,13 +12052,6 @@ a")
             FooMod(), (torch.rand(3, 4),), f,
             operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
 
-    def test_trace_checker_arange_as_constant(self):
-        with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
-            @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(4, 5),)])
-            def foo(x):
-                y = torch.arange(0, x.shape[0]).double()
-                return x + y.unsqueeze(1)
-
     @suppress_warnings
     def test_trace_checker_dot_data(self):
         with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Tensor-valued Constant nodes differed in value '
diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp
index 8ce11ecc25e..200b20a98be 100644
--- a/tools/autograd/templates/python_torch_functions.cpp
+++ b/tools/autograd/templates/python_torch_functions.cpp
@@ -112,7 +112,7 @@ static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* k
   static PythonArgParser parser({
     "arange(Scalar end, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
     "arange(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
-  });
+  }, /*traceable=*/true);
 
   ParsedArgs<9> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
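Note (illustrative, assumes a build with this patch applied): the same behaviour that do_trace_arange asserts is why test_trace_checker_arange_as_constant is deleted. With arange marked traceable, the traced graph keeps the dependence on x.shape[0] instead of baking in a constant, so checking the trace against an input of a different size no longer reports "Graphs differed across invocations!". A minimal sketch:

import torch


def foo(x):
    # the arange length follows x.shape[0], so the trace generalizes across sizes
    y = torch.arange(0, x.shape[0]).double()
    return x + y.unsqueeze(1)


# Previously this combination raised torch.jit.TracingCheckError; with arange
# traced dynamically the check against a (4, 5) input is expected to pass.
traced = torch.jit.trace(foo, torch.rand(3, 4),
                         check_inputs=[(torch.rand(4, 5),)])
print(traced(torch.rand(6, 4)).shape)  # torch.Size([6, 4])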