mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: This appears to be a bug in test_arange that was not revealed by older versions of onnxruntime. TL;DR: the test tries to update the exported ONNX model to accept dynamically sized input, but it is written incorrectly, so the exported model's input is still fixed-size. Meanwhile, the version of ORT in CI doesn't validate whether the model input size matches the input data, so this error was not caught. Affecting CI in https://github.com/pytorch/pytorch/pull/25797 Pull Request resolved: https://github.com/pytorch/pytorch/pull/26320 Reviewed By: hl475 Differential Revision: D17406442 Pulled By: houseroad fbshipit-source-id: a09ad4b925ccbed0b71342f5aaa7878e1c4a5a2d
1132 lines
42 KiB
Python
1132 lines
42 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import unittest
|
|
import onnxruntime # noqa
|
|
import torch
|
|
|
|
import numpy as np
|
|
import io
|
|
import itertools
|
|
|
|
from torch.nn.utils import rnn as rnn_utils
|
|
from model_defs.lstm_flattening_result import LstmFlatteningResult
|
|
from model_defs.rnn_model_with_packed_sequence import RnnModelWithPackedSequence
|
|
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion
|
|
from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
|
|
import model_defs.word_language_model as word_language_model
|
|
|
|
|
|
def ort_test_with_input(ort_sess, input, output, rtol, atol):
    """Run ``ort_sess`` on ``input`` and compare its results with ``output``.

    ``input`` and ``output`` are flattened with ``torch.jit._flatten`` and
    converted to NumPy arrays. The flattened inputs are fed positionally,
    keyed by the names reported by ``ort_sess.get_inputs()``.

    Args:
        ort_sess: object exposing ``get_inputs()`` and ``run(output_names, feed)``.
        input: (possibly nested) tensors fed to the session.
        output: (possibly nested) reference tensors the session results must match.
        rtol: relative tolerance passed to ``np.testing.assert_allclose``.
        atol: absolute tolerance passed to ``np.testing.assert_allclose``.

    Raises:
        AssertionError: if the number of outputs differs, or if any pair of
            outputs is not close within the given tolerances.
        IndexError: if the session reports fewer inputs than were supplied.
    """
    flat_inputs, _ = torch.jit._flatten(input)
    flat_outputs, _ = torch.jit._flatten(output)

    def to_numpy(tensor):
        # detach() is harmless on tensors that do not require grad, so a
        # single code path covers both cases.
        return tensor.detach().cpu().numpy()

    inputs = [to_numpy(tensor) for tensor in flat_inputs]
    outputs = [to_numpy(tensor) for tensor in flat_outputs]

    # Query the session's input metadata once; index (rather than zip) so a
    # session with too few inputs still fails loudly.
    input_meta = ort_sess.get_inputs()
    ort_inputs = {input_meta[i].name: value for i, value in enumerate(inputs)}
    ort_outs = ort_sess.run(None, ort_inputs)

    # compare onnxruntime and PyTorch results
    assert len(outputs) == len(ort_outs), "number of outputs differ"

    for expected, actual in zip(outputs, ort_outs):
        np.testing.assert_allclose(expected, actual, rtol=rtol, atol=atol)
|
|
|
|
|
|
def run_model_test(self, model, batch_size=2, state_dict=None,
                   input=None, use_gpu=True, rtol=0.001, atol=1e-7,
                   example_outputs=None, do_constant_folding=True,
                   dynamic_axes=None, test_with_inputs=None,
                   input_names=None, output_names=None):
    """Export ``model`` to ONNX and compare ONNX Runtime with PyTorch.

    The model is put in eval mode, run on ``input`` (a random
    ``(batch_size, 3, 224, 224)`` tensor when ``input`` is None), exported to
    an in-memory buffer with ``torch.onnx.export`` and loaded into an
    ``onnxruntime.InferenceSession``. The session results are compared with
    the PyTorch results through ``ort_test_with_input``; the same comparison
    is repeated for every entry of ``test_with_inputs``.

    ``self`` supplies ``opset_version`` and ``keep_initializers_as_inputs``.
    ``state_dict``, ``use_gpu`` and ``example_outputs`` are accepted but not
    read by this function.
    """
    def as_tuple(value):
        # A bare tensor becomes a one-element tuple; anything else is kept.
        return (value,) if isinstance(value, torch.Tensor) else value

    model.eval()

    if input is None:
        input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)

    with torch.no_grad():
        input = as_tuple(input)
        output = as_tuple(model(*input))

        # export the model to ONNX
        buffer = io.BytesIO()
        torch.onnx.export(
            model, input, buffer,
            opset_version=self.opset_version,
            example_outputs=output,
            do_constant_folding=do_constant_folding,
            keep_initializers_as_inputs=self.keep_initializers_as_inputs,
            dynamic_axes=dynamic_axes,
            input_names=input_names, output_names=output_names)

        # compute onnxruntime output prediction
        ort_sess = onnxruntime.InferenceSession(buffer.getvalue())
        ort_test_with_input(ort_sess, input, output, rtol, atol)

        # if additional test inputs are provided, run the onnx
        # model with these inputs and check the outputs
        if test_with_inputs is not None:
            for test_input in test_with_inputs:
                test_input = as_tuple(test_input)
                output = as_tuple(model(*test_input))
                ort_test_with_input(ort_sess, test_input, output, rtol, atol)
|
|
|
|
|
|
class TestONNXRuntime(unittest.TestCase):
|
|
from torch.onnx.symbolic_helper import _export_onnx_opset_version
|
|
opset_version = _export_onnx_opset_version
|
|
keep_initializers_as_inputs = True # For IR version 3 type export.
|
|
|
|
def setUp(self):
    """Seed the torch RNG (and CUDA RNGs when available) and NumPy's RNG with 0."""
    seed = 0
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed=seed)
|
|
|
|
def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=True,
             batch_size=2, use_gpu=True, dynamic_axes=None, test_with_inputs=None,
             input_names=None, output_names=None):
    """Forward every argument unchanged to the module-level ``run_model_test``."""
    options = dict(
        batch_size=batch_size,
        input=input,
        use_gpu=use_gpu,
        rtol=rtol,
        atol=atol,
        do_constant_folding=do_constant_folding,
        dynamic_axes=dynamic_axes,
        test_with_inputs=test_with_inputs,
        input_names=input_names,
        output_names=output_names,
    )
    run_model_test(self, model, **options)
|
|
|
|
def run_word_language_model(self, model_name):
    """Build a small ``word_language_model.RNNModel`` of kind ``model_name`` and run it.

    The model input is the token ids ``0..49`` reshaped to rows of 5, paired
    with the model's ``hidden`` attribute.
    """
    ntokens, batchsize = 50, 5
    emsize = nhid = nlayers = 5
    dropout, tied = 0.2, False
    model = word_language_model.RNNModel(
        model_name, ntokens, emsize, nhid, nlayers, dropout, tied, batchsize)
    tokens = torch.arange(0, ntokens).long().view(-1, batchsize)
    # Only support CPU version, since tracer is not working in GPU RNN.
    self.run_test(model, (tokens, model.hidden))
|
|
|
|
def test_word_language_model_RNN_TANH(self):
    # Word language model built with the "RNN_TANH" cell type.
    cell_type = "RNN_TANH"
    self.run_word_language_model(cell_type)
|
|
|
|
def test_word_language_model_RNN_RELU(self):
    # Word language model built with the "RNN_RELU" cell type.
    cell_type = "RNN_RELU"
    self.run_word_language_model(cell_type)
|
|
|
|
def test_word_language_model_LSTM(self):
    # Word language model built with the "LSTM" cell type.
    cell_type = "LSTM"
    self.run_word_language_model(cell_type)
|
|
|
|
def test_word_language_model_GRU(self):
    # Word language model built with the "GRU" cell type.
    cell_type = "GRU"
    self.run_word_language_model(cell_type)
|
|
|
|
def test_index_1d(self):
    # Integer index on the first dimension.
    def take_first(t):
        return t[0]

    self._test_index_generic(take_first)
|
|
|
|
def test_index_2d_1dimslice(self):
    # Length-one slice on dim 0, full slice on dim 1.
    def index(t):
        return t[0:1, :]

    self._test_index_generic(index)
|
|
|
|
def test_index_2d_sliceint(self):
    # Integer index on dim 0, full slice on dim 1.
    def index(t):
        return t[1, :]

    self._test_index_generic(index)
|
|
|
|
def test_index_2d_neg_slice(self):
    # Slice with a negative end on dim 0, full slice on dim 1.
    def index(t):
        return t[0:-1, :]

    self._test_index_generic(index)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_full_trace(self):
    # torch.full with the fill value taken from the module input.
    class FullModel(torch.nn.Module):
        def forward(self, x):
            return torch.full((3, 4), x, dtype=torch.long)

    fill_value = torch.tensor(12)
    self.run_test(FullModel(), fill_value)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_full_script(self):
    # Same as test_full_trace, but the module is a ScriptModule.
    class FullModelScripting(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return torch.full((3, 4), x, dtype=torch.long)

    fill_value = torch.tensor(12)
    self.run_test(FullModelScripting(), fill_value)
|
|
|
|
def test_maxpool(self):
    # MaxPool1d(kernel 2, stride 1) on a random (20, 16, 50) input.
    pool = torch.nn.MaxPool1d(2, stride=1)
    self.run_test(pool, torch.randn(20, 16, 50))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(8)
def test_maxpool_adaptive(self):
    # AdaptiveMaxPool1d with output size 5, indices not returned.
    pool = torch.nn.AdaptiveMaxPool1d(5, return_indices=False)
    self.run_test(pool, torch.randn(20, 16, 50, requires_grad=True))
|
|
|
|
def test_maxpool_2d(self):
    # MaxPool2d(kernel 5) with asymmetric padding (1, 2).
    pool = torch.nn.MaxPool2d(5, padding=(1, 2))
    self.run_test(pool, torch.randn(1, 20, 16, 50, requires_grad=True))
|
|
|
|
def test_maxpool_1d_ceil(self):
    # MaxPool1d(kernel 3, stride 2) with ceil_mode enabled.
    pool = torch.nn.MaxPool1d(3, 2, ceil_mode=True)
    self.run_test(pool, torch.randn(20, 16, 50))
|
|
|
|
def test_maxpool_2d_ceil(self):
    # MaxPool2d(kernel 3, stride 2) with ceil_mode enabled.
    pool = torch.nn.MaxPool2d(3, 2, ceil_mode=True)
    self.run_test(pool, torch.randn(20, 16, 50, 32))
|
|
|
|
def test_maxpool_3d_ceil(self):
    # MaxPool3d(kernel 3, stride 2) with ceil_mode enabled.
    pool = torch.nn.MaxPool3d(3, 2, ceil_mode=True)
    self.run_test(pool, torch.randn(20, 16, 50, 44, 31))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(8)
def test_maxpool_with_indices(self):
    # MaxPool1d(kernel 2, stride 1) that also returns the argmax indices.
    pool = torch.nn.MaxPool1d(2, stride=1, return_indices=True)
    self.run_test(pool, torch.randn(20, 16, 50))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_maxpool_dilation(self):
    # MaxPool1d(kernel 2, stride 1) with dilation 2.
    pool = torch.nn.MaxPool1d(2, stride=1, dilation=2)
    self.run_test(pool, torch.randn(20, 16, 50))
|
|
|
|
def test_avgpool(self):
    # AvgPool1d(kernel 2, stride 1) on a random (20, 16, 50) input.
    pool = torch.nn.AvgPool1d(2, stride=1)
    self.run_test(pool, torch.randn(20, 16, 50))
|
|
|
|
def test_avgpool_1d_ceil(self):
    # AvgPool1d(kernel 3, stride 2) with ceil_mode on a (1, 1, 7) input.
    pool = torch.nn.AvgPool1d(3, 2, ceil_mode=True)
    self.run_test(pool, torch.randn(1, 1, 7))
|
|
|
|
def test_avgpool_2d_ceil(self):
    # AvgPool2d(kernel 3, stride 2) with ceil_mode enabled.
    pool = torch.nn.AvgPool2d(3, 2, ceil_mode=True)
    self.run_test(pool, torch.randn(20, 16, 50, 32))
|
|
|
|
def test_avgpool_3d_ceil(self):
    # AvgPool3d(kernel 3, stride 2) with ceil_mode enabled.
    pool = torch.nn.AvgPool3d(3, 2, ceil_mode=True)
    self.run_test(pool, torch.randn(20, 16, 50, 44, 31))
|
|
|
|
def test_arithmetic(self):
    # Chain of scalar add, sub, mul and div applied to the input.
    class ArithmeticModule(torch.nn.Module):
        def forward(self, x):
            # Same operations, in the same order: +2, -4, *6, /8.
            return ((x + 2) - 4) * 6 / 8

    data = torch.randn(2, 3, 4)
    self.run_test(ArithmeticModule(), data)
|
|
|
|
def test_slice_trace(self):
    # Static slice [0:1] of a 1-D tensor.
    class MyModule(torch.nn.Module):
        def forward(self, x):
            return x[0:1]

    data = torch.randn(3)
    self.run_test(MyModule(), data)
|
|
|
|
def test_slice_neg(self):
    # Slice with a negative start on the first dimension.
    class NegSlice(torch.nn.Module):
        def forward(self, x):
            return x[-1:]

    data = torch.randn(3, 4, 5)
    self.run_test(NegSlice(), data)
|
|
|
|
def test_slice_neg_large(self):
    # Negative integer index (-3) on the last of five dimensions.
    class NegSlice(torch.nn.Module):
        def forward(self, x):
            return x[:, :, :, :, -3]

    data = torch.randn(3, 4, 5, 6, 7)
    self.run_test(NegSlice(), data)
|
|
|
|
@unittest.skip('https://github.com/pytorch/pytorch/issues/10984')
def test_slice_neg_large_negone(self):
    # Index -1 on the last of five dimensions (skipped, see linked issue).
    class NegSlice(torch.nn.Module):
        def forward(self, x):
            return x[:, :, :, :, -1]

    data = torch.randn(3, 4, 5, 6, 7)
    self.run_test(NegSlice(), data)
|
|
|
|
def test_slice_dynamic(self):
    # Slices whose bounds depend on the input's sizes, for offsets 0..3.
    class DynamicSliceExportMod(torch.nn.Module):
        def forward(self, x):
            return tuple(
                x[:x.size(0) - i, i:x.size(2), i:3] for i in range(4)
            )

    data = torch.rand(5, 5, 5)
    self.run_test(DynamicSliceExportMod(), data)
|
|
|
|
def test_slice_dynamic_script(self):
    # Scripted slice from 1 up to the input's first-dimension size.
    class DynamicSliceModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return x[1:x.size(0)]

    data = torch.rand(1, 2)
    self.run_test(DynamicSliceModel(), data)
|
|
|
|
def test_slice_dynamic_to_end(self):
    # Open-ended slices on dim 1 combined with a size-derived index on dim 2.
    class DynamicSliceExportMod(torch.nn.Module):
        def forward(self, x):
            return tuple(x[:, i:, x.size(2) - 5] for i in range(4))

    data = torch.rand(5, 5, 5)
    self.run_test(DynamicSliceExportMod(), data)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange(self):
    # arange driven by the input's first dimension; the model is exported
    # with dim 0 of 'input_1' and 'output_1' marked dynamic and is then
    # re-run on an input with a different first dimension.
    class ArangeModel(torch.nn.Module):
        def forward(self, input):
            first = torch.arange(input.shape[0])
            second = torch.arange(12)
            third = torch.arange(start=input.shape[0], end=input.shape[0] + 5)
            return first, second, third

    export_input = torch.randn(5, 3, 2)
    other_input = torch.randn(8, 3, 2)
    dynamic_axes = {'input_1': [0], 'output_1': [0]}
    self.run_test(ArangeModel(), export_input, test_with_inputs=[other_input],
                  input_names=['input_1'],
                  output_names=['output_1', 'output_2', 'output_3'],
                  dynamic_axes=dynamic_axes)
|
|
|
|
def _test_index_generic(self, fn):
    """Wrap ``fn`` in a module and run it on a random (3, 4, 5, 6, 7) tensor."""
    class MyModel(torch.nn.Module):
        def forward(self, input):
            return fn(input)

    sample = torch.randn(3, 4, 5, 6, 7)
    self.run_test(MyModel(), sample)
|
|
|
|
def test_tensor_index_advanced_indexing(self):
    # Several mixes of tensor indices, slices, Ellipsis and None.
    indexers = (
        lambda t: t[:, torch.tensor([[0, 2], [1, 1]]), :, torch.tensor([2, 1]), torch.tensor([0, 3])],
        lambda t: t[..., torch.tensor([2, 1]), torch.tensor([0, 3])],
        lambda t: t[:, torch.tensor([0, 2]), None, 2:4, torch.tensor([[1, 3], [4, 0]])],
        lambda t: t[:, torch.tensor([0, 2]), torch.tensor([1]), 2:4, torch.tensor([[1], [4]])],
    )
    for indexer in indexers:
        self._test_index_generic(indexer)
|
|
|
|
def test_tensor_index_advanced_indexing_consecutive(self):
    # Two adjacent tensor indices followed by None.
    def index(t):
        return t[:, torch.tensor([0, 2]), torch.tensor([[1, 3], [4, 0]]), None]

    self._test_index_generic(index)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_flip(self):
    # torch.flip along dim 0 of a 2x3 tensor holding 0.0 .. 5.0.
    class MyModule(torch.nn.Module):
        def forward(self, x):
            return torch.flip(x, dims=[0])

    values = np.arange(6.0).reshape(2, 3)
    self.run_test(MyModule(), torch.tensor(values))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_interpolate_scale(self):
    # Nearest-neighbour interpolation with scale_factor=2.
    class MyModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.interpolate(x, mode="nearest", scale_factor=2)

    data = torch.randn(1, 2, 3, 4, requires_grad=True)
    self.run_test(MyModel(), data)
|
|
|
|
# NOTE: Supported in onnxruntime master, enable this after 0.5 release.
|
|
@skipIfUnsupportedOpsetVersion([10, 11])
def test_interpolate_output_size(self):
    # Nearest-neighbour interpolation to a fixed (6, 8) output size.
    class MyModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.interpolate(x, mode="nearest", size=(6, 8))

    data = torch.randn(1, 2, 3, 4, requires_grad=True)
    self.run_test(MyModel(), data)
|
|
|
|
# NOTE: Supported in onnxruntime master, enable this after 0.5 release.
|
|
@skipIfUnsupportedOpsetVersion([10, 11])
def test_interpolate_upsample(self):
    # Nearest-neighbour upsampling to twice the input's trailing sizes.
    class MyModel(torch.nn.Module):
        def forward(self, x):
            doubled = [v * 2 for v in x.size()[2:]]
            # work around for now: turn the dynamic sizes into constant
            target = [int(d) for d in doubled]
            return torch.nn.functional.interpolate(x, mode="nearest", size=target)

    data = torch.randn(1, 2, 3, 4, requires_grad=True)
    self.run_test(MyModel(), data)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_interpolate_downsample(self):
    # Nearest-neighbour interpolation with a per-dimension scale list.
    class MyModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.interpolate(x, mode="nearest", scale_factor=[1, 1, 0.5, 0.5])

    data = torch.randn(1, 2, 3, 4, requires_grad=True)
    self.run_test(MyModel(), data)
|
|
|
|
def test_std(self):
    # torch.std over all elements with unbiased=False.
    class StandardDeviation(torch.nn.Module):
        def forward(self, input):
            return torch.std(input, unbiased=False)

    data = torch.randn(2, 3, 4)
    self.run_test(StandardDeviation(), data)
|
|
|
|
def test_std_along_dims(self):
    # torch.std over dims (0, 1) with unbiased=False.
    class StandardDeviation(torch.nn.Module):
        def forward(self, input):
            return torch.std(input, dim=(0, 1), unbiased=False)

    data = torch.randn(2, 3, 4)
    self.run_test(StandardDeviation(), data)
|
|
|
|
def test_std_keepdim(self):
    # torch.std over dims (0, 1) with unbiased=False and keepdim=True.
    class StandardDeviation(torch.nn.Module):
        def forward(self, input):
            return torch.std(input, dim=(0, 1), unbiased=False, keepdim=True)

    data = torch.randn(2, 3, 4)
    self.run_test(StandardDeviation(), data)
|
|
|
|
@skipIfUnsupportedOpsetVersion([7, 8])
def test_interpolate_upsample_dynamic_sizes(self):
    # Nearest-neighbour upsampling whose target size is computed from x.size().
    class MyModel(torch.nn.Module):
        def forward(self, x):
            target = [v * 2 for v in x.size()[2:]]
            return torch.nn.functional.interpolate(x, mode="nearest", size=target)

    data = torch.randn(1, 2, 3, 4, requires_grad=True)
    self.run_test(MyModel(), data)
|
|
|
|
def test_narrow(self):
    # torch.narrow on dim 0, start 0, length 2.
    class NarrowModel(torch.nn.Module):
        def forward(self, input):
            return torch.narrow(input, 0, 0, 2)

    data = torch.randn(3, 3, requires_grad=True)
    self.run_test(NarrowModel(), data)
|
|
|
|
# TODO: enable for opset 10 when ONNXRuntime version will be updated
|
|
def test_index_select_constant_scaler_index(self):
    # index_select along dim 1 with a constant zero-dim index tensor.
    class IndexSelectScalerIndexModel(torch.nn.Module):
        def forward(self, x):
            index = 2
            return torch.index_select(x, 1, torch.tensor(index))

    data = torch.randn(3, 4)
    self.run_test(IndexSelectScalerIndexModel(), data)
|
|
|
|
def test_index_select_scaler_index(self):
    # index_select along dim 1 where the zero-dim index is a stored base
    # plus an offset supplied as a second model input.
    class IndexSelectScalerIndexModel(torch.nn.Module):
        def __init__(self, index_base):
            super(IndexSelectScalerIndexModel, self).__init__()
            self.index_base = torch.tensor(index_base)

        def forward(self, x, index_offset):
            return torch.index_select(x, 1, self.index_base + index_offset)

    data = torch.randn(3, 4)
    base, offset = 1, 2
    index_offset = torch.tensor(offset)
    self.run_test(IndexSelectScalerIndexModel(base), (data, index_offset))
|
|
|
|
# TODO: enable for opset 10 when ONNXRuntime version will be updated
|
|
@skipIfUnsupportedOpsetVersion([10, 11])
def test_topk(self):
    # torch.topk with a constant k of 3.
    class MyModule(torch.nn.Module):
        def forward(self, x):
            return torch.topk(x, 3)

    values = torch.arange(1., 6., requires_grad=True)
    self.run_test(MyModule(), values)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
@skipIfUnsupportedOpsetVersion([11])
def test_topk_script(self):
    # Scripted torch.topk where k is passed in as a second input.
    class MyModuleDynamic(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x, k):
            return torch.topk(x, k)

    values = torch.arange(1., 6., requires_grad=True)
    top_k = torch.tensor(3)
    self.run_test(MyModuleDynamic(), [values, top_k])
|
|
|
|
def test_layer_norm(self):
    # LayerNorm over the trailing [10, 10] dimensions.
    norm = torch.nn.LayerNorm([10, 10])
    self.run_test(norm, torch.randn(20, 5, 10, 10))
|
|
|
|
def test_multinomial(self):
    # torch.multinomial with and without replacement. Each weight row has a
    # single non-zero entry.
    class Multinomial(torch.nn.Module):
        def forward(self, weight):
            return torch.multinomial(weight, 3, replacement=True)

    class MultinomialNoReplacement(torch.nn.Module):
        def forward(self, weight):
            return torch.multinomial(weight, 1)

    weight = torch.tensor([[0, 10, 0, 0], [0, 0, 100, 0]], dtype=torch.float)
    for module in (Multinomial(), MultinomialNoReplacement()):
        self.run_test(module, (weight,))
|
|
|
|
def test_reduce_log_sum_exp(self):
    # Sum of logsumexp over dim 0 and logsumexp over dims (0, 1).
    class ReduceLogSumExpModel(torch.nn.Module):
        def forward(self, input):
            over_rows = torch.logsumexp(input, dim=0)
            over_all = torch.logsumexp(input, dim=(0, 1))
            return over_rows + over_all

    data = torch.randn(4, 4, requires_grad=True)
    self.run_test(ReduceLogSumExpModel(), data)
|
|
|
|
def test_lstm_constant_folding(self):
    # LSTM export with constant folding enabled, for one bidirectional and
    # one unidirectional configuration.
    class LstmNet(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_layers, bidirectional):
            super(LstmNet, self).__init__()
            self.lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional)

        def forward(self, input, initial_state):
            return self.lstm(input, initial_state)

    def get_LstmNet_model_and_inputs(input_size, hidden_size, num_layers, batch_size,
                                     seq_len, bidirectional):
        num_directions = 2 if bidirectional else 1
        state_shape = (num_layers * num_directions, batch_size, hidden_size)
        model = LstmNet(input_size, hidden_size, num_layers, bidirectional)
        input = torch.randn(seq_len, batch_size, input_size)
        h0 = torch.randn(*state_shape)
        c0 = torch.randn(*state_shape)
        return model, (input, (h0, c0))

    # (input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional)
    configs = (
        (7, 3, 2, 3, 5, True),
        (5, 4, 3, 4, 7, False),
    )
    for config in configs:
        model, inputs = get_LstmNet_model_and_inputs(*config)
        self.run_test(model, inputs, do_constant_folding=True)
|
|
|
|
def test_gru_constant_folding(self):
    # GRU export with constant folding enabled, for one bidirectional and
    # one unidirectional configuration.
    class GruNet(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_layers, bidirectional):
            super(GruNet, self).__init__()
            self.mygru = torch.nn.GRU(input_size, hidden_size, num_layers, bidirectional=bidirectional)

        def forward(self, input, initial_state):
            return self.mygru(input, initial_state)

    def get_GruNet_model_and_inputs(input_size, hidden_size, num_layers, batch_size,
                                    seq_len, bidirectional):
        num_directions = 2 if bidirectional else 1
        model = GruNet(input_size, hidden_size, num_layers, bidirectional)
        input = torch.randn(seq_len, batch_size, input_size)
        h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
        return model, (input, h0)

    # (input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional)
    configs = (
        (7, 3, 2, 3, 5, True),
        (5, 4, 3, 4, 7, False),
    )
    for config in configs:
        model, inputs = get_GruNet_model_and_inputs(*config)
        self.run_test(model, inputs, do_constant_folding=True)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(8)
def test_max_tensors(self):
    # Element-wise torch.max of a (4, 4) and a (4, 1) tensor.
    class MaxModel(torch.nn.Module):
        def forward(self, input, other):
            return torch.max(input, other)

    model = MaxModel()
    lhs = torch.randn(4, 4, requires_grad=True)
    rhs = torch.randn(4, 1, requires_grad=True)
    self.run_test(model, (lhs, rhs))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_end(self):
    # torch.arange(end) with the end taken from the input's first dimension,
    # once as a scripted module and once as a regular (traced) module.
    # The former unused `outputs = ArangeScript()(x)` forward pass was
    # dropped: run_test already runs the model itself.
    class ArangeScript(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

    x = torch.randn(3, 4, requires_grad=True)
    self.run_test(ArangeScript(), x)

    class ArangeModel(torch.nn.Module):
        def forward(self, a):
            return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

    self.run_test(ArangeModel(), x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_start_end(self):
    # torch.arange(start, end) with the end derived from the input's first
    # dimension, once as a scripted module and once as a regular module.
    # The former unused `outputs = ArangeScript()(x)` forward pass was
    # dropped: run_test already runs the model itself.
    class ArangeScript(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a

    x = torch.randn(3, 4, requires_grad=True)
    self.run_test(ArangeScript(), x)

    class ArangeModel(torch.nn.Module):
        def forward(self, a):
            return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a

    self.run_test(ArangeModel(), x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_start_end_step(self):
    # torch.arange(start, end, step) with end and step derived from the
    # input's sizes, once as a scripted module and once as a regular module.
    # The former unused `outputs = ArangeScript()(x)` forward pass was
    # dropped: run_test already runs the model itself.
    class ArangeScript(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a

    x = torch.randn(3, 4, requires_grad=True)
    self.run_test(ArangeScript(), x)

    class ArangeModel(torch.nn.Module):
        def forward(self, a):
            return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a

    self.run_test(ArangeModel(), x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test__dim_arange(self):
    # torch._dim_arange along dim 1 of a (5, 6) tensor of ones.
    class DimArange(torch.nn.Module):
        def forward(self, input):
            return torch._dim_arange(input, 1)

    ones = torch.ones(5, 6)
    self.run_test(DimArange(), ones)
|
|
|
|
def test_gt(self):
    # Element-wise `>` between two tensors, for float and int32 inputs.
    class GreaterModel(torch.nn.Module):
        def forward(self, input, other):
            return input > other

    lhs = torch.randn(1, 2, 3, 4, requires_grad=True)
    rhs = torch.randn(1, 2, 3, 4, requires_grad=True)
    self.run_test(GreaterModel(), (lhs, rhs))

    lhs = torch.randint(10, (3, 4), dtype=torch.int32)
    rhs = torch.randint(10, (3, 4), dtype=torch.int32)
    self.run_test(GreaterModel(), (lhs, rhs))
|
|
|
|
def test_gt_scalar(self):
    # Element-wise `> 1` against a Python scalar, for float and int32 inputs.
    class GreaterModel(torch.nn.Module):
        def forward(self, input):
            return input > 1

    float_data = torch.randn(1, 2, 3, 4, requires_grad=True)
    self.run_test(GreaterModel(), float_data)

    int_data = torch.randint(10, (3, 4), dtype=torch.int32)
    self.run_test(GreaterModel(), int_data)
|
|
|
|
def test_lt(self):
    # Element-wise `<` between two tensors, for float and int32 inputs.
    class LessModel(torch.nn.Module):
        def forward(self, input, other):
            # Fixed: this used `>`, which only duplicated test_gt and left
            # the less-than export untested.
            return input < other

    x = torch.randn(1, 2, 3, 4, requires_grad=True)
    y = torch.randn(1, 2, 3, 4, requires_grad=True)
    self.run_test(LessModel(), (x, y))

    x = torch.randint(10, (3, 4), dtype=torch.int32)
    y = torch.randint(10, (3, 4), dtype=torch.int32)
    self.run_test(LessModel(), (x, y))
|
|
|
|
def test_matmul(self):
    # 2-D torch.matmul, for float inputs and for torch.randint inputs.
    class MatmulModel(torch.nn.Module):
        def forward(self, input, other):
            return torch.matmul(input, other)

    lhs = torch.randn(3, 4, requires_grad=True)
    rhs = torch.randn(4, 5, requires_grad=True)
    self.run_test(MatmulModel(), (lhs, rhs))

    lhs = torch.randint(10, (3, 4))
    rhs = torch.randint(10, (4, 5))
    self.run_test(MatmulModel(), (lhs, rhs))
|
|
|
|
def test_matmul_batch(self):
    # Batched (3-D) torch.matmul, for float inputs and for torch.randint inputs.
    class MatmulModel(torch.nn.Module):
        def forward(self, input, other):
            return torch.matmul(input, other)

    lhs = torch.randn(2, 3, 4, requires_grad=True)
    rhs = torch.randn(2, 4, 5, requires_grad=True)
    self.run_test(MatmulModel(), (lhs, rhs))

    lhs = torch.randint(10, (2, 3, 4))
    rhs = torch.randint(10, (2, 4, 5))
    self.run_test(MatmulModel(), (lhs, rhs))
|
|
|
|
def test_view(self):
    # view() of a (4, 2, 3, 4) int32 tensor as (4, 24).
    class ViewModel(torch.nn.Module):
        def forward(self, input):
            return input.view(4, 24)

    data = torch.randint(10, (4, 2, 3, 4), dtype=torch.int32)
    self.run_test(ViewModel(), data)
|
|
|
|
def test_flatten(self):
    # torch.flatten over all dimensions.
    class FlattenModel(torch.nn.Module):
        def forward(self, input):
            return torch.flatten(input)

    data = torch.randint(10, (1, 2, 3, 4))
    self.run_test(FlattenModel(), data)
|
|
|
|
def test_flatten2d(self):
    # torch.flatten starting at dim 1.
    class FlattenModel(torch.nn.Module):
        def forward(self, input):
            return torch.flatten(input, 1)

    data = torch.randint(10, (1, 2, 3, 4))
    self.run_test(FlattenModel(), data)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_tensor_factories(self):
    # torch.zeros + torch.ones, both sized from the input's size().
    class TensorFactory(torch.nn.Module):
        def forward(self, x):
            zeros = torch.zeros(x.size())
            ones = torch.ones(x.size())
            return zeros + ones

    data = torch.randn(2, 3, 4)
    self.run_test(TensorFactory(), data)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_tensor_factories_script(self):
    # Scripted torch.zeros + torch.ones, sized from the input's shape.
    class TensorFactory(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return torch.zeros(x.shape, dtype=torch.float) + torch.ones(x.shape, dtype=torch.float)

    data = torch.randn(2, 3, 4)
    self.run_test(TensorFactory(), data)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_tensor_like_factories_script(self):
    # Scripted zeros_like + ones_like with explicit dtype, layout and device.
    class TensorFactory(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            zeros = torch.zeros_like(x, dtype=torch.float, layout=torch.strided, device=torch.device('cpu'))
            ones = torch.ones_like(x, dtype=torch.float, layout=torch.strided, device=torch.device('cpu'))
            return zeros + ones

    data = torch.randn(2, 3, 4)
    self.run_test(TensorFactory(), data)
|
|
|
|
@skipIfUnsupportedOpsetVersion([11])
def test_sort(self):
    # Descending torch.sort along a dimension stored on the module.
    class SortModel(torch.nn.Module):
        def __init__(self, dim):
            super(SortModel, self).__init__()
            self.dim = dim

        def forward(self, x):
            return torch.sort(x, dim=self.dim, descending=True)

    sort_dim = 1
    data = torch.randn(3, 4)
    self.run_test(SortModel(sort_dim), data)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_masked_fill(self):
    # masked_fill with a constant uint8 mask, then with a mask computed
    # from the input itself.
    class MaskedFillModel(torch.nn.Module):
        def forward(self, x):
            mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.uint8)
            return x.masked_fill(mask, 2)

    zeros = torch.zeros(4, 2, 3, requires_grad=True)
    self.run_test(MaskedFillModel(), zeros)

    class MaskedFillModel2(torch.nn.Module):
        def forward(self, x):
            return x.masked_fill(x > 3, -1)

    ramp = torch.arange(16).view(2, 2, 4).to(torch.float32)
    self.run_test(MaskedFillModel2(), ramp)
|
|
|
|
@unittest.skip("Enable this once depthToSpace attr 'mode' is supported in ORT")
@skipIfUnsupportedMinOpsetVersion(9)
def test_pixel_shuffle(self):
    # torch.pixel_shuffle with upscale_factor=2 (currently skipped).
    class PixelShuffle(torch.nn.Module):
        def forward(self, x):
            return torch.pixel_shuffle(x, upscale_factor=2)

    data = torch.randn(2, 16, 4, 3, requires_grad=True)
    self.run_test(PixelShuffle(), data)
|
|
|
|
def test_frobenius_norm(self):
    # Frobenius norm over dim 0 without keepdim.
    class NormModel(torch.nn.Module):
        def forward(self, x):
            return torch.norm(x, p="fro", dim=0, keepdim=False)

    data = torch.randn(4, 2, 3, requires_grad=True)
    self.run_test(NormModel(), data)
|
|
|
|
def test_frobenius_norm_keepdim(self):
    # Frobenius norm over dims (0, 1) with keepdim=True.
    class NormModel(torch.nn.Module):
        def forward(self, x):
            return torch.norm(x, p="fro", dim=(0, 1), keepdim=True)

    data = torch.randn(4, 2, 3, requires_grad=True)
    self.run_test(NormModel(), data)
|
|
|
|
def test_rsqrt(self):
    # rsqrt on a random float64 tensor.
    class RsqrtModel(torch.nn.Module):
        def forward(self, x):
            return x.rsqrt()

    data = torch.randn(4, 2, 3, requires_grad=True).to(dtype=torch.float64)
    self.run_test(RsqrtModel(), data)
|
|
|
|
def test_rsqrt_zeros(self):
    # rsqrt on an all-zero float64 tensor.
    class RsqrtModel(torch.nn.Module):
        def forward(self, x):
            return x.rsqrt()

    zeros = torch.zeros(4, 2, 3, requires_grad=True).to(dtype=torch.float64)
    self.run_test(RsqrtModel(), zeros)
|
|
|
|
# TODO: enable opset 11 test once ORT support for unique is in
|
|
@skipIfUnsupportedOpsetVersion([11])
@skipIfUnsupportedMinOpsetVersion(11)
def test_unique(self):
    # Sorted torch.unique returning counts but not the inverse indices.
    class UniqueModel(torch.nn.Module):
        def forward(self, x):
            return torch.unique(x, sorted=True, return_inverse=False, return_counts=True)

    values = torch.tensor([1, 3, 2, 3], dtype=torch.long)
    self.run_test(UniqueModel(), values)
|
|
|
|
# TODO: enable opset 11 test once ORT support for unique is in
|
|
@skipIfUnsupportedOpsetVersion([11])
@skipIfUnsupportedMinOpsetVersion(11)
def test_unique_along_dim(self):
    # Sorted torch.unique along dim 0 returning the inverse indices but not counts.
    class UniqueModel(torch.nn.Module):
        def forward(self, x):
            return torch.unique(x, dim=0, sorted=True, return_inverse=True, return_counts=False)

    values = torch.tensor([1, 3, 2, 3], dtype=torch.long)
    self.run_test(UniqueModel(), values)
|
|
|
|
# TODO: enable opset 11 test once ORT support for cumsum is in
|
|
@skipIfUnsupportedOpsetVersion([11])
@skipIfUnsupportedMinOpsetVersion(11)
def test_cumsum(self):
    # torch.cumsum along dim 0.
    class CumSum(torch.nn.Module):
        def forward(self, input):
            return torch.cumsum(input, dim=0)

    data = torch.randn(2, 3, 4)
    self.run_test(CumSum(), data)
|
|
|
|
def test_log(self):
    # torch.log on values drawn with torch.rand.
    class Log(torch.nn.Module):
        def forward(self, input):
            return torch.log(input)

    data = torch.rand(2, 3, 4)
    self.run_test(Log(), data)
|
|
|
|
def test_log1p(self):
    # torch.log1p on values drawn with torch.rand.
    class Log1p(torch.nn.Module):
        def forward(self, input):
            return torch.log1p(input)

    data = torch.rand(2, 3, 4)
    self.run_test(Log1p(), data)
|
|
|
|
# TODO: remove the skip tag once ORT implementation is in place
|
|
@skipIfUnsupportedMinOpsetVersion(11)
@skipIfUnsupportedOpsetVersion([11])
def test_round(self):
    # torch.round on a few values, including exact .5 cases.
    class Round(torch.nn.Module):
        def forward(self, x):
            return torch.round(x)

    values = torch.tensor([0.9920, -1.0362, -1.5000, 3.5000], requires_grad=True)
    self.run_test(Round(), values)
|
|
|
|
def _dispatch_rnn_test(self, name, *args, **kwargs):
    """Call the RNN test helper selected by ``name`` with the given arguments.

    ``'elman'``, ``'lstm'`` and ``'gru'`` map to ``_elman_rnn_test``,
    ``_lstm_test`` and ``_gru_test``; any other name does nothing.
    """
    handlers = (
        ('elman', self._elman_rnn_test),
        ('lstm', self._lstm_test),
        ('gru', self._gru_test),
    )
    for key, handler in handlers:
        if name == key:
            handler(*args, **kwargs)
|
|
|
|
def _elman_rnn_test(self, layers, nonlinearity, bidirectional,
                    initial_state, packed_sequence, dropout):
    """Run a ``torch.nn.RNN`` through ``run_test`` at two batch sizes.

    ``packed_sequence`` selects the variant: 0 uses the bare RNN, 1 wraps it
    in ``RnnModelWithPackedSequence(model, False)``, and 2 builds a
    batch-first RNN wrapped in ``RnnModelWithPackedSequence(model, True)``.
    When ``initial_state`` is truthy an ``h0`` tensor is passed as an input.
    """
    batch_first = packed_sequence == 2
    model = torch.nn.RNN(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers, nonlinearity=nonlinearity,
                         bidirectional=bidirectional, dropout=dropout, batch_first=batch_first)

    if packed_sequence in (1, 2):
        model = RnnModelWithPackedSequence(model, batch_first)

    def make_input(batch_size):
        # Random sequence lengths, longest first.
        lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
        lengths = sorted(map(int, lengths), reverse=True)
        sequences = [torch.randn(length, RNN_INPUT_SIZE) for length in lengths]
        model_args = [rnn_utils.pad_sequence(sequences, batch_first=batch_first)]

        num_directions = 2 if bidirectional else 1

        if initial_state:
            model_args.append(torch.randn(num_directions * layers, batch_size, RNN_HIDDEN_SIZE))
        if packed_sequence != 0:
            model_args.append(torch.IntTensor(lengths))
        # A lone tensor is passed bare, several inputs as a tuple.
        return model_args[0] if len(model_args) == 1 else tuple(model_args)

    # The second pass tests that the model still runs with a different batch size.
    for batch_size in (RNN_BATCH_SIZE, RNN_BATCH_SIZE + 1):
        self.run_test(model, make_input(batch_size), batch_size=batch_size)
|
|
|
|
def _lstm_test(self, layers, bidirectional, initial_state,
               packed_sequence, dropout):
    """Build a ``LstmFlatteningResult`` model and run it at two batch sizes.

    Args:
        layers: Number of layers given to ``LstmFlatteningResult``.
        bidirectional: Forwarded as the ``bidirectional`` argument.
        initial_state: When truthy, a random ``(h0, c0)`` pair is added to
            the model inputs.
        packed_sequence: 0 feeds only the padded tensor; 1 and 2 wrap the
            model in ``RnnModelWithPackedSequence`` (flag ``False`` and
            ``True`` respectively) and add the sequence lengths to the
            inputs; 2 additionally enables ``batch_first``.
        dropout: Forwarded as the ``dropout`` argument.
    """
    use_batch_first = (packed_sequence == 2)
    model = LstmFlatteningResult(
        RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers,
        bidirectional=bidirectional, dropout=dropout, batch_first=use_batch_first)
    if packed_sequence in (1, 2):
        # The wrapper flag is False for mode 1 and True for mode 2, which is
        # exactly the batch_first value computed above.
        model = RnnModelWithPackedSequence(model, use_batch_first)

    def make_input(batch_size):
        # Draw random lengths and order them from longest to shortest.
        drawn = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
        seq_lengths = sorted(map(int, drawn), reverse=True)

        sequences = [torch.randn(length, RNN_INPUT_SIZE) for length in seq_lengths]
        collected = [rnn_utils.pad_sequence(sequences, batch_first=use_batch_first)]

        if initial_state:
            directions = 2 if bidirectional else 1
            state_shape = (directions * layers, batch_size, RNN_HIDDEN_SIZE)
            h0 = torch.randn(*state_shape)
            c0 = torch.randn(*state_shape)
            collected.append((h0, c0))
        if packed_sequence != 0:
            collected.append(torch.IntTensor(seq_lengths))

        # One input is passed bare, several as a tuple.
        return collected[0] if len(collected) == 1 else tuple(collected)

    first_input = make_input(RNN_BATCH_SIZE)
    self.run_test(model, first_input, batch_size=RNN_BATCH_SIZE)

    # test that the model still runs with a different batch size
    bigger_batch = RNN_BATCH_SIZE + 1
    second_input = make_input(bigger_batch)
    self.run_test(model, second_input, batch_size=bigger_batch)
|
|
|
|
def _gru_test(self, layers, bidirectional, initial_state,
              packed_sequence, dropout):
    """Build a ``torch.nn.GRU`` and pass it to ``run_test`` at two batch sizes.

    Args:
        layers: Number of layers given to ``torch.nn.GRU``.
        bidirectional: Forwarded as the ``bidirectional`` argument.
        initial_state: When truthy, a random ``h0`` tensor is added to the
            model inputs.
        packed_sequence: 0 feeds only the padded tensor; 1 wraps the model in
            ``RnnModelWithPackedSequence(model, False)``; 2 wraps it with
            ``True`` and also enables ``batch_first``. Any non-zero value
            adds the sequence lengths to the model inputs.
        dropout: Forwarded as the ``dropout`` argument.
    """
    batch_first = packed_sequence == 2
    gru = torch.nn.GRU(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers,
                       bidirectional=bidirectional, dropout=dropout,
                       batch_first=batch_first)
    model = gru
    if packed_sequence == 1:
        model = RnnModelWithPackedSequence(gru, False)
    elif packed_sequence == 2:
        model = RnnModelWithPackedSequence(gru, True)

    def make_input(batch_size):
        raw_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
        seq_lengths = [int(n) for n in raw_lengths]
        seq_lengths.sort(reverse=True)  # longest sequence first

        sequences = []
        for length in seq_lengths:
            sequences.append(torch.randn(length, RNN_INPUT_SIZE))
        inputs = [rnn_utils.pad_sequence(sequences, batch_first=batch_first)]

        if initial_state:
            directions = 2 if bidirectional else 1
            inputs.append(torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE))
        if packed_sequence != 0:
            inputs.append(torch.IntTensor(seq_lengths))

        # A single input is passed bare; several are bundled into a tuple.
        return inputs[0] if len(inputs) == 1 else tuple(inputs)

    base_input = make_input(RNN_BATCH_SIZE)
    self.run_test(model, base_input, batch_size=RNN_BATCH_SIZE)

    # test that the model still runs with a different batch size
    larger_input = make_input(RNN_BATCH_SIZE + 1)
    self.run_test(model, larger_input, batch_size=RNN_BATCH_SIZE + 1)
|
|
|
|
|
|
def make_test(name, base, layer, bidirectional, initial_state,
              variable_length, dropout,
              **extra_kwargs):
    """Create one RNN test method and attach it to ``TestONNXRuntime``.

    ``layer``, ``bidirectional``, ``initial_state``, ``variable_length`` and
    ``dropout`` are indexable pairs: element 0 is the value forwarded to
    ``_dispatch_rnn_test`` and element 1 is the label used in the test name.

    Args:
        name: Model label placed in the generated test name.
        base: RNN family passed to ``_dispatch_rnn_test``.
        layer: Pair forwarded as ``layers``.
        bidirectional: Pair forwarded as ``bidirectional``.
        initial_state: Pair forwarded as ``initial_state``.
        variable_length: Pair forwarded as ``packed_sequence``.
        dropout: Pair forwarded as ``dropout``.
        **extra_kwargs: Extra keyword arguments forwarded to
            ``_dispatch_rnn_test``.
    """
    options = (layer, bidirectional, initial_state, variable_length, dropout)
    name_parts = ['test', name] + [option[1] for option in options]
    test_name = str('_'.join(name_parts))

    # Cannot export with older opsets because of 'ConstantFill' op.
    # ConstantFill was a temporary op removed at opset 8 and is no longer
    # supported by onnxruntime.
    @skipIfUnsupportedMinOpsetVersion(9)
    def f(self):
        self._dispatch_rnn_test(
            base,
            layers=layer[0],
            bidirectional=bidirectional[0],
            initial_state=initial_state[0],
            packed_sequence=variable_length[0],
            dropout=dropout[0],
            **extra_kwargs)

    f.__name__ = test_name
    setattr(TestONNXRuntime, f.__name__, f)
|
|
|
|
|
|
def setup_rnn_tests():
    """Generate the RNN tests and register them through ``make_test``.

    One test is created for every combination of layer count, direction,
    initial state, sequence-length mode and dropout, for each model variant
    (Elman with relu, Elman with tanh, LSTM, GRU). Bidirectional Elman
    combinations are currently left out.

    Raises:
        AttributeError: If the representative GRU test was not registered.
        ValueError: If the number of generated tests is not 144.
    """
    layers_opts = [(1, 'unilayer'), (3, 'trilayer')]
    bidirectional_opts = [(False, 'forward'), (True, 'bidirectional')]
    initial_state_opts = [(True, 'with_initial_state'), (False, 'no_initial_state')]
    variable_length_opts = [
        (0, 'without_sequence_lengths'),
        (1, 'with_variable_length_sequences'),
        (2, 'with_batch_first_sequence_lengths'),
    ]
    dropout_opts = [(0.2, 'with_dropout'), (0.0, 'without_dropout')]

    # (family passed to the dispatcher, label used in the test name, extra kwargs)
    model_variants = (
        ('elman', 'elman_relu', {'nonlinearity': u'relu'}),
        ('elman', 'elman_tanh', {'nonlinearity': u'tanh'}),
        ('lstm', 'lstm', {}),
        ('gru', 'gru', {}),
    )

    option_grid = itertools.product(
        layers_opts,
        bidirectional_opts,
        initial_state_opts,
        variable_length_opts,
        dropout_opts,
    )

    test_count = 0
    for layer, bidirectional, initial_state, variable_length, dropout in option_grid:
        for base, name, extra_kwargs in model_variants:
            # This is a hack to skip elman_rnn bidirectional tests for now
            # TODO: Revert this once elman_rnn bidirectional issue is fixed
            skipped = base == 'elman' and bidirectional[1] == 'bidirectional'
            if not skipped:
                make_test(name, base, layer, bidirectional, initial_state,
                          variable_length, dropout,
                          **extra_kwargs)
                test_count += 1

    # sanity check that a representative example does exist
    getattr(TestONNXRuntime,
            'test_gru_trilayer_forward_with_initial_state_without_sequence_lengths_with_dropout')

    # make sure no one accidentally disables all the tests without noticing.
    # The full matrix would hold 192 tests; 144 remain while the bidirectional
    # Elman variants are skipped.
    # TODO: Revert this once elman_rnn bidirectional issue is fixed
    if test_count != 144:
        raise ValueError('Expected 144 tests but found {}'.format(test_count))
|
|
|
|
|
|
# Register the generated RNN tests first: each class below is built from a
# copy of TestONNXRuntime.__dict__, so it only contains the attributes that
# exist at the moment the copy is taken.
setup_rnn_tests()

# opset 7 tests
TestONNXRuntime_opset7 = type(
    str("TestONNXRuntime_opset7"),
    (unittest.TestCase,),
    dict(TestONNXRuntime.__dict__, opset_version=7))

# opset 8 tests
TestONNXRuntime_opset8 = type(
    str("TestONNXRuntime_opset8"),
    (unittest.TestCase,),
    dict(TestONNXRuntime.__dict__, opset_version=8))


# opset 10 tests
TestONNXRuntime_opset10 = type(
    str("TestONNXRuntime_opset10"),
    (unittest.TestCase,),
    dict(TestONNXRuntime.__dict__, opset_version=10))

# opset 11 tests
TestONNXRuntime_opset11 = type(
    str("TestONNXRuntime_opset11"),
    (unittest.TestCase,),
    dict(TestONNXRuntime.__dict__, opset_version=11))
|
|
|
|
|
|
# opset 9 tests (no opset_version override, so TestONNXRuntime's own value
# applies), with keep_initializers_as_inputs=False for
# IR version 4 style export.
|
|
# Same class body as TestONNXRuntime, with only keep_initializers_as_inputs
# overridden.
TestONNXRuntime_opset9_IRv4 = type(
    str("TestONNXRuntime_opset9_IRv4"),
    (unittest.TestCase,),
    dict(TestONNXRuntime.__dict__,
         keep_initializers_as_inputs=False))


# opset 10 tests, with keep_initializers_as_inputs=False for
# IR version 4 style export.
TestONNXRuntime_opset10_IRv4 = type(
    str("TestONNXRuntime_opset10_IRv4"),
    (unittest.TestCase,),
    dict(TestONNXRuntime.__dict__,
         opset_version=10,
         keep_initializers_as_inputs=False))
|
|
|
|
|
|
# Run the unittest runner only when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|