Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary:

* Support propagating `dim_param` in ONNX by encoding it as a `ShapeSymbol` in the `SymbolicShape` of outputs. If export is called with `dynamic_axes` provided, shape inference starts with these axes set as dynamic.
* Add a new test file, `test_pytorch_onnx_shape_inference.py`, reusing all test cases from `test_pytorch_onnx_onnxruntime.py` but focusing on validating the shapes of all nodes in the graph. It is not yet enabled in CI, since there are still quite a few existing issues and corner cases to fix. The test defaults to running only at opset 12.
* Bug fixes, e.g. for div, _len, and the peephole.cpp passes for PackPadded and LogSoftmaxCrossEntropy.
* This PR depends on existing PRs such as #44332.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44920
Reviewed By: eellison
Differential Revision: D23958398
Pulled By: bzinodev
fbshipit-source-id: 00479d9bd19c867d526769a15ba97ec16d56e51d
import unittest
import onnxruntime  # noqa
import torch

import numpy as np
import io
import itertools
import copy

from torch.nn.utils import rnn as rnn_utils
from model_defs.lstm_flattening_result import LstmFlatteningResult
from model_defs.rnn_model_with_packed_sequence import RnnModelWithPackedSequence
from test_pytorch_common import (skipIfUnsupportedMinOpsetVersion, disableScriptTest,
                                 skipIfUnsupportedOpsetVersion, skipIfNoLapack,
                                 skipIfUnsupportedMaxOpsetVersion, skipIfONNXShapeInference)
from test_pytorch_common import BATCH_SIZE
from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
from typing import List
import model_defs.word_language_model as word_language_model
import torchvision
import onnx

def to_numpy(tensor):
    if tensor.requires_grad:
        return tensor.detach().cpu().numpy()
    else:
        return tensor.cpu().numpy()

def convert_to_onnx(model, input=None, opset_version=9, example_outputs=None,
                    do_constant_folding=True, keep_initializers_as_inputs=True,
                    dynamic_axes=None, input_names=None, output_names=None,
                    fixed_batch_size=False, training=None,
                    onnx_shape_inference=False,
                    use_new_jit_passes=False):
    # export the model to ONNX
    f = io.BytesIO()
    input_copy = copy.deepcopy(input)
    torch.onnx._export(model, input_copy, f,
                       opset_version=opset_version,
                       example_outputs=example_outputs,
                       do_constant_folding=do_constant_folding,
                       keep_initializers_as_inputs=keep_initializers_as_inputs,
                       dynamic_axes=dynamic_axes,
                       input_names=input_names, output_names=output_names,
                       fixed_batch_size=fixed_batch_size, training=training,
                       onnx_shape_inference=onnx_shape_inference,
                       use_new_jit_passes=use_new_jit_passes)

    # compute onnxruntime output prediction
    ort_sess = onnxruntime.InferenceSession(f.getvalue())
    return ort_sess

def run_ort(ort_sess, input):
    input_copy = copy.deepcopy(input)
    input, _ = torch.jit._flatten(input_copy)
    inputs = list(map(to_numpy, input))

    ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(inputs))
    ort_outs = ort_sess.run(None, ort_inputs)

    return ort_outs

def ort_compare_with_pytorch(ort_outs, output, rtol, atol):
    output, _ = torch.jit._flatten(output)
    outputs = list(map(to_numpy, output))

    # compare onnxruntime and PyTorch results
    assert len(outputs) == len(ort_outs), "number of outputs differ"
    [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)]

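# Added note: a minimal usage sketch of the helpers above (hypothetical model,
# names, and shapes; not part of the original suite). The round-trip is
# export -> run in ONNX Runtime -> compare against eager PyTorch:
#
#   sess = convert_to_onnx(model, input=(torch.randn(2, 3),), opset_version=12,
#                          input_names=['input_1'], dynamic_axes={'input_1': [0]},
#                          onnx_shape_inference=True)
#   ort_outs = run_ort(sess, (torch.randn(5, 3),))
#   ort_compare_with_pytorch(ort_outs, model(torch.randn(5, 3)), rtol=1e-3, atol=1e-7)
#
# With dynamic_axes provided, axis 0 of input_1 is exported as a symbolic
# dim_param rather than a fixed integer, so the same session can accept a
# different batch size at run time.
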
def run_model_test(self, model, batch_size=2, state_dict=None,
                   input=None, use_gpu=True, rtol=0.001, atol=1e-7,
                   example_outputs=None, do_constant_folding=True,
                   dynamic_axes=None, test_with_inputs=None,
                   input_names=None, output_names=None,
                   fixed_batch_size=False):
    model.eval()

    if input is None:
        input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)

    with torch.no_grad():
        if isinstance(input, torch.Tensor):
            input = (input,)
        # In-place operators will update input tensor data as well.
        # Thus inputs are replicated before every forward call.
        input_copy = copy.deepcopy(input)
        output = model(*input_copy)
        if isinstance(output, torch.Tensor):
            output = (output,)

        ort_sess = convert_to_onnx(model, input=input, opset_version=self.opset_version,
                                   example_outputs=output, do_constant_folding=do_constant_folding,
                                   keep_initializers_as_inputs=self.keep_initializers_as_inputs,
                                   dynamic_axes=dynamic_axes, input_names=input_names,
                                   output_names=output_names, fixed_batch_size=fixed_batch_size, training=None,
                                   onnx_shape_inference=self.onnx_shape_inference,
                                   use_new_jit_passes=self.use_new_jit_passes)

        ort_outs = run_ort(ort_sess, input)
        ort_compare_with_pytorch(ort_outs, output, rtol, atol)

        # If additional test inputs are provided, run the ONNX
        # model with these inputs and check the outputs.
        if test_with_inputs is not None:
            for test_input in test_with_inputs:
                if isinstance(test_input, torch.Tensor):
                    test_input = (test_input,)
                test_input_copy = copy.deepcopy(test_input)
                output = model(*test_input_copy)
                if isinstance(output, torch.Tensor):
                    output = (output,)
                ort_outs = run_ort(ort_sess, test_input)
                ort_compare_with_pytorch(ort_outs, output, rtol, atol)


class TestONNXRuntime(unittest.TestCase):
    from torch.onnx.symbolic_helper import _export_onnx_opset_version
    opset_version = _export_onnx_opset_version
    keep_initializers_as_inputs = True  # For IR version 3 type export.
    use_new_jit_passes = False  # For testing main code-path
    onnx_shape_inference = False

    def setUp(self):
        torch.manual_seed(0)
        onnxruntime.set_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)
        np.random.seed(seed=0)
        self.is_script_test_enabled = True

    def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=True,
                 batch_size=2, use_gpu=True, dynamic_axes=None, test_with_inputs=None,
                 input_names=None, output_names=None, fixed_batch_size=False):
        def _run_test(m):
            return run_model_test(self, m, batch_size=batch_size,
                                  input=input, use_gpu=use_gpu, rtol=rtol, atol=atol,
                                  do_constant_folding=do_constant_folding,
                                  dynamic_axes=dynamic_axes, test_with_inputs=test_with_inputs,
                                  input_names=input_names, output_names=output_names,
                                  fixed_batch_size=fixed_batch_size)
        if self.is_script_test_enabled and self.use_new_jit_passes:
            script_model = torch.jit.script(model)
            _run_test(script_model)

        _run_test(model)

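    # Added note: other export configurations are typically exercised by
    # subclassing this TestCase and overriding the class attributes above,
    # e.g. (hypothetical sketch):
    #
    #   class TestONNXRuntime_opset12(TestONNXRuntime):
    #       opset_version = 12
    #       onnx_shape_inference = True
    #
    # so the same tests run once per opset/feature combination.
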
    def run_model_test_with_external_data(self, model, input, rtol=0.001, atol=1e-7,
                                          example_outputs=None, do_constant_folding=True,
                                          dynamic_axes=None, input_names=None, output_names=None,
                                          ort_optim_on=True):
        import os
        import tempfile

        model.eval()
        with torch.no_grad():
            if isinstance(input, torch.Tensor):
                input = (input,)
            # In-place operators will update input tensor data as well.
            # Thus inputs are replicated before every forward call.
            input_copy = copy.deepcopy(input)
            output = model(*input_copy)
            if isinstance(output, torch.Tensor):
                output = (output,)

            # export the model to ONNX
            with tempfile.TemporaryDirectory() as tmpdirname:
                model_file_name = os.path.join(tmpdirname, 'model.onnx')
                input_copy = copy.deepcopy(input)
                torch.onnx.export(model, input_copy, model_file_name,
                                  opset_version=self.opset_version,
                                  example_outputs=output,
                                  verbose=False,
                                  do_constant_folding=do_constant_folding,
                                  keep_initializers_as_inputs=self.keep_initializers_as_inputs,
                                  dynamic_axes=dynamic_axes,
                                  input_names=input_names, output_names=output_names,
                                  use_external_data_format=True)
                # compute onnxruntime output prediction
                ort_sess_opt = onnxruntime.SessionOptions()
                ort_sess_opt.graph_optimization_level = \
                    onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED if ort_optim_on else \
                    onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
                ort_sess = onnxruntime.InferenceSession(model_file_name, sess_options=ort_sess_opt)
                input_copy = copy.deepcopy(input)
                ort_outs = run_ort(ort_sess, input_copy)
                ort_compare_with_pytorch(ort_outs, output, rtol, atol)

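    # Added note: use_external_data_format=True stores large tensors in
    # separate files alongside the .onnx file instead of inlining them in the
    # model proto, which is how ONNX sidesteps protobuf's 2 GB message limit.
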
    @skipIfUnsupportedMinOpsetVersion(9)  # Because external data format was released with Opset 9.
    def test_embedding_model_with_external_data(self):
        class LargeModel(torch.nn.Module):
            def __init__(self):
                super(LargeModel, self).__init__()
                dim = 15
                n = 4 * 100
                self.emb = torch.nn.Embedding(n, dim)
                self.lin1 = torch.nn.Linear(dim, 1)
                self.seq = torch.nn.Sequential(
                    self.emb,
                    self.lin1,
                )

            def forward(self, input):
                return self.seq(input)

        model = LargeModel()
        x = torch.tensor([2], dtype=torch.long)
        self.run_model_test_with_external_data(model, x)

    @skipIfUnsupportedMinOpsetVersion(9)  # Because external data format was released with Opset 9.
    def test_mobilenet_v2_with_external_data(self):
        model = torchvision.models.mobilenet_v2(pretrained=True)
        x = torch.randn(2, 3, 224, 224, requires_grad=True)
        # We turn Onnx Runtime optimization off in this test, because the
        # external data format is not supported in the ORT optimizer.
        # Once that support is added, we can set ort_optim_on=True (default).
        self.run_model_test_with_external_data(model, x, rtol=1e-3, atol=1e-5,
                                               ort_optim_on=False)

    @skipIfUnsupportedMinOpsetVersion(9)  # Because external data format was released with Opset 9.
    def test_attribute_with_external_data(self):
        class LargeModel(torch.nn.Module):
            def forward(self, x):
                return x + torch.ones(2, 1024)

        x = torch.randn(2, 1)
        self.run_model_test_with_external_data(LargeModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)  # Because external data format was released with Opset 9.
    @unittest.skip("Enable this once large model with subgraph is supported in ORT")
    def test_subgraph_with_external_data(self):
        class LargeModel(torch.nn.Module):
            def forward(self, x):
                for i in range(x.size(0)):
                    x = x + torch.ones(2, 1024)
                return x

        x = torch.randn(2, 1)
        self.run_model_test_with_external_data(torch.jit.script(LargeModel()), x)

    def test_fuse_conv_bn1d(self):
        class Fuse(torch.nn.Module):
            def __init__(self):
                super(Fuse, self).__init__()
                self.conv = torch.nn.Conv1d(16, 33, 3, stride=2)
                self.bn = torch.nn.BatchNorm1d(33)

            def forward(self, x):
                out = self.conv(x)
                return self.bn(out)

        model = Fuse()
        x = torch.randn(20, 16, 50, requires_grad=True)
        self.run_test(model, (x,))

    def test_fuse_conv_bn2d(self):
        class Fuse(torch.nn.Module):
            def __init__(self):
                super(Fuse, self).__init__()
                self.conv = torch.nn.Conv2d(3, 2, kernel_size=1, stride=2, padding=3, bias=False)
                self.bn = torch.nn.BatchNorm2d(2)

            def forward(self, x):
                out = self.conv(x)
                return self.bn(out)

        model = Fuse()
        x = torch.randn(2, 3, 2, 2, requires_grad=True)
        self.run_test(model, (x,))

    def test_fuse_conv_bn3d(self):
        class Fuse(torch.nn.Module):
            def __init__(self):
                super(Fuse, self).__init__()
                self.conv = torch.nn.Conv3d(3, 2, (3, 5, 2), stride=(2, 1, 1), padding=(3, 2, 0), bias=False)
                self.bn = torch.nn.BatchNorm3d(2)

            def forward(self, x):
                out = self.conv(x)
                return self.bn(out)

        model = Fuse()
        x = torch.randn(2, 3, 10, 50, 100, requires_grad=True)
        self.run_test(model, (x,), rtol=1e-3, atol=1e-6)

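    # Added note on the conv-bn fusion tests above: folding BatchNorm into the
    # preceding convolution rewrites the conv parameters as
    #   w' = w * gamma / sqrt(running_var + eps)
    #   b' = (b - running_mean) * gamma / sqrt(running_var + eps) + beta
    # so the fused graph is numerically close to, but not bit-identical with,
    # the unfused one; hence the relaxed rtol/atol in the 3d case.
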
    def test_reshape_constant_fold(self):
        class Reshape(torch.nn.Module):
            def __init__(self):
                super(Reshape, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                scale_1 = self.weight.reshape(1, -1, 1, 1)
                return x * scale_1

        x = torch.randn(4, 5)
        self.run_test(Reshape(), (x,), rtol=1e-3, atol=1e-5)

    def run_word_language_model(self, model_name):
        ntokens = 50
        emsize = 5
        nhid = 5
        nlayers = 5
        dropout = 0.2
        tied = False
        batchsize = 5
        model = word_language_model.RNNModel(model_name, ntokens, emsize,
                                             nhid, nlayers, dropout, tied,
                                             batchsize)
        x = torch.arange(0, ntokens).long().view(-1, batchsize)
        # Only the CPU version is supported, since the tracer does not work
        # for RNNs on GPU.
        self.run_test(model, (x, model.hidden))

    @skipIfUnsupportedMinOpsetVersion(11)
    @disableScriptTest()  # Faster RCNN model is not scriptable
    def test_faster_rcnn(self):
        model = torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200,
                                                                                 max_size=300)
        model.eval()
        x = torch.randn(2, 3, 200, 300, requires_grad=True)
        self.run_test(model, (x,), rtol=1e-3, atol=1e-5)

    def get_image_from_url(self, url):
        import os
        from urllib.parse import urlsplit
        from urllib import request
        from PIL import Image
        from torchvision import transforms
        from torch._utils_internal import get_writable_path

        filename = os.path.basename(urlsplit(url)[2])
        data_dir = get_writable_path(os.path.join(os.path.dirname(__file__)))
        path = os.path.join(data_dir, filename)
        data = request.urlopen(url, timeout=15).read()
        with open(path, 'wb') as f:
            f.write(data)
        image = Image.open(path).convert("RGB")
        image = image.resize((300, 200), Image.BILINEAR)
        to_tensor = transforms.ToTensor()
        return to_tensor(image)

    def get_test_images(self):
        image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
        image = self.get_image_from_url(url=image_url)
        images = [image]
        return images

    @skipIfUnsupportedMinOpsetVersion(11)
    @disableScriptTest()
    def test_mask_rcnn(self):
        model = torchvision.models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200,
                                                                             max_size=300)
        images = self.get_test_images()
        self.run_test(model, (images,), rtol=1e-3, atol=1e-5)

    @skipIfUnsupportedMinOpsetVersion(11)
    @disableScriptTest()
    def test_keypoint_rcnn(self):
        model = torchvision.models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200,
                                                                                     max_size=300)
        images = self.get_test_images()
        self.run_test(model, (images,), rtol=1e-3, atol=1e-5)

    @disableScriptTest()
    def test_word_language_model_RNN_TANH(self):
        self.run_word_language_model("RNN_TANH")

    @disableScriptTest()
    def test_word_language_model_RNN_RELU(self):
        self.run_word_language_model("RNN_RELU")

    @disableScriptTest()
    def test_word_language_model_LSTM(self):
        self.run_word_language_model("LSTM")

    @disableScriptTest()
    def test_word_language_model_GRU(self):
        self.run_word_language_model("GRU")

    def test_index_1d(self):
        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[0]

        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), m1)

    def test_index_2d_1dimslice(self):
        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[0:1, :]

        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), m1)

    def test_index_2d_sliceint(self):
        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[1, :]

        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), m1)

    def test_index_2d_neg_slice(self):
        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[0:-1, :]

        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), m1)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_index_mask(self):
        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[torch.tensor([0, 1, 0], dtype=torch.uint8)]

        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), m1)

        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[torch.tensor([0, 1, 0], dtype=torch.bool)]

        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), m1)

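    # Added note: indexing with a uint8 tensor (the first model above) is the
    # legacy byte-mask form, deprecated in PyTorch in favor of torch.bool
    # masks, which the second model exercises.
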
    @disableScriptTest()
    def test_dict(self):
        class MyModel(torch.nn.Module):
            def forward(self, x_in):
                x_out = {}
                x_out["test_key_out"] = torch.add(x_in[list(x_in.keys())[0]], list(x_in.keys())[0])
                return x_out

        x = {torch.tensor(1.): torch.randn(1, 2, 3)}
        self.run_test(MyModel(), (x,))

    @disableScriptTest()
    def test_dict_str(self):
        class MyModel(torch.nn.Module):
            def forward(self, x_in):
                x_out = {}
                x_out["test_key_out"] = torch.add(x_in["test_key_in"], 2.)
                return x_out

        x = {"test_key_in": torch.randn(1, 2, 3)}
        self.run_test(MyModel(), (x,))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_cste_script(self):
        class MyModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.zeros(x.size(0)), torch.ones((x.size(1), x.size(0)), dtype=torch.int64)

        x = torch.randn(3, 4)
        self.run_test(MyModel(), x)

    def test_scalar_tensor(self):
        class test(torch.nn.Module):
            def forward(self, input):
                return torch.scalar_tensor(input.size(0)), \
                    torch.scalar_tensor(input.size(1), dtype=torch.int64)

        x = torch.randn(2, 3, 4)
        y = torch.randn(7, 8, 9)
        model = test()
        self.run_test(model, x, test_with_inputs=[y],
                      input_names=['input_1'],
                      dynamic_axes={'input_1': [0, 1, 2]})

    def test_tensor(self):
        class ScalarInputModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return torch.tensor(input.shape[1])

        x = torch.randn(3, 4)
        self.run_test(ScalarInputModel(), x)

        class TensorInputModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return torch.tensor([input.shape[0], input.shape[1]])

        x = torch.randn(3, 4)
        self.run_test(TensorInputModel(), x)

        class FloatInputModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return torch.tensor([float(input)])

        x = torch.randn(1)
        self.run_test(FloatInputModel(), x)

        class InputWithDtypeModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return torch.tensor(input.shape[1], dtype=torch.long)

        x = torch.randn(3, 4)
        self.run_test(InputWithDtypeModel(), x)

        class MixedInputModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return torch.tensor([input.shape[0], int(input)])

        x = torch.randn(1)
        self.run_test(MixedInputModel(), x)

    def test_hardtanh(self):
        model = torch.nn.Hardtanh(-1.5, 2.5)
        x = torch.arange(-5, 5).to(dtype=torch.float32)
        self.run_test(model, x)

    def test_hardtanh_script_with_default_values(self):
        class MyModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.nn.functional.hardtanh(x)

        x = torch.arange(-5, 5).to(dtype=torch.float32)
        self.run_test(MyModel(), x)

    def test_clamp(self):
        class ClampModel(torch.nn.Module):
            def forward(self, x):
                return x.clamp(-0.5, 0.5)

        x = torch.randn(3, 4)
        self.run_test(ClampModel(), x)

        class ClampMinModel(torch.nn.Module):
            def forward(self, x):
                return x.clamp(min=-0.5)

        x = torch.randn(3, 4)
        self.run_test(ClampMinModel(), x)

        class ClampMaxModel(torch.nn.Module):
            def forward(self, x):
                return x.clamp(max=0.5)

        x = torch.randn(3, 4)
        self.run_test(ClampMaxModel(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_clamp_dyn(self):
        class ClampMaxModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x.clamp(None, x.size(0))

        x = torch.arange(16).view(4, 4).float()
        self.run_test(ClampMaxModel(), x)

        class ClampMinModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x.clamp(x.size(0), None)

        x = torch.arange(16).view(4, 4).float()
        self.run_test(ClampMinModel(), x)

        class ClampMinMaxModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x.clamp(x.size(0), x.size(1))

        x = torch.arange(16).view(2, 8).float()
        self.run_test(ClampMinMaxModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_full_trace(self):
        class FullModel(torch.nn.Module):
            def forward(self, x):
                return torch.full((3, 4), x, dtype=torch.long)

        x = torch.tensor(12)
        self.run_test(FullModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_full_script(self):
        class FullModelScripting(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.full((3, 4), x, dtype=torch.long)

        x = torch.tensor(12)
        self.run_test(FullModelScripting(), x)

    def test_fuse_addmm(self):
        class AddmmModel(torch.nn.Module):
            def forward(self, x):
                return torch.mm(x, x) + x

        x = torch.ones(3, 3)
        self.run_test(AddmmModel(), x)

    def test_maxpool(self):
        model = torch.nn.MaxPool1d(2, stride=1)
        x = torch.randn(20, 16, 50)
        self.run_test(model, x)

    def test_conv(self):
        class TraceModel(torch.nn.Module):
            def __init__(self):
                super(TraceModel, self).__init__()
                self.conv1 = torch.nn.Conv1d(16, 33, 3, stride=2)
                self.conv2 = torch.nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
                self.conv3 = torch.nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))

            def forward(self, input1, input2, input3):
                return self.conv1(input1), self.conv2(input2), self.conv3(input3)

        class ScriptModel(torch.jit.ScriptModule):
            def __init__(self):
                super(ScriptModel, self).__init__()
                self.conv1 = torch.nn.Conv1d(16, 33, 3, stride=2)
                self.conv2 = torch.nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
                self.conv3 = torch.nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))

            @torch.jit.script_method
            def forward(self, input1, input2, input3):
                return self.conv1(input1), self.conv2(input2), self.conv3(input3)

        x1 = torch.randn(20, 16, 50)
        x2 = torch.randn(20, 16, 50, 100)
        x3 = torch.randn(20, 16, 10, 50, 100)

        self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)
        self.run_test(ScriptModel(), (x1, x2, x3), atol=10e-5)

    def test_conv_shape_inference(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.conv2 = torch.nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))

            def forward(self, input):
                return self.conv2(input) + 2

        x = torch.randn(20, 16, 50, 100)
        self.run_test(Model(), x, atol=10e-5,
                      input_names=['x'],
                      dynamic_axes={'x': [0]})

    def test_conv_transpose(self):
        class TraceModel(torch.nn.Module):
            def __init__(self):
                super(TraceModel, self).__init__()
                self.conv1 = torch.nn.ConvTranspose1d(16, 33, 3, stride=2)
                self.conv2 = torch.nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
                self.conv3 = torch.nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))

            def forward(self, input1, input2, input3):
                return self.conv1(input1), self.conv2(input2), self.conv3(input3)

        class ScriptModel(torch.jit.ScriptModule):
            def __init__(self):
                super(ScriptModel, self).__init__()
                self.conv1 = torch.nn.ConvTranspose1d(16, 33, 3, stride=2)
                self.conv2 = torch.nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
                self.conv3 = torch.nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))

            @torch.jit.script_method
            def forward(self, input1, input2, input3):
                return self.conv1(input1), self.conv2(input2), self.conv3(input3)

        x1 = torch.randn(20, 16, 50)
        x2 = torch.randn(20, 16, 50, 100)
        x3 = torch.randn(20, 16, 10, 50, 100)

        self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)
        self.run_test(ScriptModel(), (x1, x2, x3), atol=10e-5)

    # Conversion of Transpose depends on the input shape being known.
    # The following test only works when ONNX shape inference is enabled.
    @skipIfONNXShapeInference(False)
    def test_transpose_infer_shape(self):
        class TransposeModule(torch.jit.ScriptModule):
            def __init__(self):
                super(TransposeModule, self).__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)

            @torch.jit.script_method
            def forward(self, x):
                x = self.conv(x)
                return x.transpose(0, 1)

        x = torch.randn(32, 3, 64, 64)
        self.run_test(TransposeModule(), x)

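    # Added note: aten::transpose exports to an ONNX Transpose node whose
    # `perm` attribute must enumerate every dimension, so the exporter needs
    # the input's rank. After the conv above, that rank is only known through
    # ONNX shape inference, which is why the test is gated on it.
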
    def squeeze_model_tests(self, d, x1, x2):
        class Squeeze(torch.nn.Module):
            def __init__(self, d):
                super(Squeeze, self).__init__()
                self.d = d

            def forward(self, x):
                if self.d is not None:
                    return torch.squeeze(x, dim=self.d)
                else:
                    return torch.squeeze(x)

        x2 = [] if x2 is None else [x2]
        self.run_test(Squeeze(d), x1, input_names=['input'], dynamic_axes={'input': {0: '0', 1: '1', 2: '2'}}, test_with_inputs=x2)

    def test_squeeze_without_no_op(self):
        x = torch.randn(2, 1, 4)
        self.squeeze_model_tests(1, x, None)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_squeeze(self):
        x_squeeze = torch.randn(2, 1, 4)
        x_noop = torch.randn(2, 2, 3)
        self.squeeze_model_tests(1, x_squeeze, x_noop)

    def test_squeeze_neg_without_no_op(self):
        x = torch.randn(2, 1, 4)
        self.squeeze_model_tests(-2, x, None)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_squeeze_neg(self):
        x_squeeze = torch.randn(2, 1, 4)
        x_noop = torch.randn(2, 2, 3)
        self.squeeze_model_tests(-2, x_squeeze, x_noop)

    def test_squeeze_all_dims(self):
        x_squeeze = torch.randn(2, 1, 4)
        x_noop = torch.randn(2, 2, 3)
        self.squeeze_model_tests(None, x_squeeze, x_noop)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_squeeze_no_op(self):
        x_noop = torch.randn(2, 1, 4)
        x_squeeze = torch.randn(2, 2, 1)
        self.squeeze_model_tests(2, x_noop, x_squeeze)

    def test_squeeze_no_op_without_additional_inputs(self):
        x_noop = torch.randn(2, 1, 4)
        self.squeeze_model_tests(2, x_noop, None)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_squeeze_runtime_dim(self):
        class Squeeze(torch.nn.Module):
            def forward(self, d1, d2):
                t = torch.zeros(d1[0], d2[0])
                return t.squeeze(0)

        d1 = torch.tensor([1])
        d3 = torch.tensor([3])
        d4 = torch.tensor([4])
        self.run_test(Squeeze(), (d1, d4), test_with_inputs=[(d3, d4)])
        self.run_test(Squeeze(), (d3, d4), test_with_inputs=[(d1, d3)])

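    # Added note (interpretation): whether squeezing a runtime-sized dimension
    # is a no-op cannot be decided statically, so the export presumably needs
    # data-dependent control flow; that would explain why the no-op and
    # runtime-dim cases above require opset 11.
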
    def test_unsqueeze(self):
        class Unsqueeze(torch.nn.Module):
            def forward(self, x):
                return torch.unsqueeze(x, dim=-2)

        x = torch.randn(2, 3, 4)
        self.run_test(Unsqueeze(), x)

    def test_maxpool_default_stride(self):
        class MaxPoolModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.max_pool2d(x, 2)

        model = MaxPoolModel()
        x = torch.randn(10, 20, 16, 50)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(8)
    def test_maxpool_adaptive(self):
        model = torch.nn.AdaptiveMaxPool1d((5), return_indices=False)
        x = torch.randn(20, 16, 50, requires_grad=True)
        self.run_test(model, x)

    def test_maxpool_2d(self):
        model = torch.nn.MaxPool2d(5, padding=(1, 2))
        x = torch.randn(1, 20, 16, 50, requires_grad=True)
        self.run_test(model, x)

    def test_maxpool_1d_ceil(self):
        model = torch.nn.MaxPool1d(3, 2, ceil_mode=True)
        x = torch.randn(20, 16, 50)
        self.run_test(model, x)

    def test_maxpool_2d_ceil(self):
        model = torch.nn.MaxPool2d(3, 2, ceil_mode=True)
        x = torch.randn(20, 16, 50, 32)
        self.run_test(model, x)

    def test_maxpool_3d_ceil(self):
        model = torch.nn.MaxPool3d(3, 2, ceil_mode=True)
        x = torch.randn(20, 16, 50, 44, 31)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(8)
    @disableScriptTest()  # Functional module not scriptable
    def test_maxpool_with_indices(self):
        model = torch.nn.MaxPool1d(2, stride=1, return_indices=True)
        x = torch.randn(20, 16, 50)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_maxpool_dilation(self):
        model = torch.nn.MaxPool1d(2, stride=1, dilation=2)
        x = torch.randn(20, 16, 50)
        self.run_test(model, x)

    def test_avgpool_default_stride(self):
        class AvgPoolModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.avg_pool2d(x, 2)

        model = AvgPoolModel()
        x = torch.randn(10, 20, 16, 50)
        self.run_test(model, x)

    def test_avgpool(self):
        model = torch.nn.AvgPool1d(2, stride=1)
        x = torch.randn(20, 16, 50)
        self.run_test(model, x)

    def test_avgpool_1d_ceil(self):
        model = torch.nn.AvgPool1d(3, 2, ceil_mode=True)
        x = torch.randn(1, 1, 7)
        self.run_test(model, x)

    def test_avgpool_2d_ceil(self):
        model = torch.nn.AvgPool2d(3, 2, ceil_mode=True)
        x = torch.randn(20, 16, 50, 32)
        self.run_test(model, x)

    def test_avgpool_3d_ceil(self):
        model = torch.nn.AvgPool3d(3, 2, ceil_mode=True)
        x = torch.randn(20, 16, 50, 44, 31)
        self.run_test(model, x)

    def test_arithmetic(self):
        class ArithmeticModule(torch.nn.Module):
            def forward(self, x):
                x = x + 2
                x = x - 4
                x = x * 6
                x = x / 8
                return x

        x = torch.randn(2, 3, 4)
        self.run_test(ArithmeticModule(), x)

    # In scripting, the first transpose node does not carry shape and dtype info.
    # The following test only works when ONNX shape inference is enabled.
    @skipIfONNXShapeInference(False)
    def test_arithmetic_infer_dtype(self):
        class ArithmeticModule(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                x = x.t()
                x = x + 2
                x = x - 4
                x = x * 6
                x = x / 8
                return x

        x = torch.randn(2, 3)
        self.run_test(ArithmeticModule(), x)

    def test_floor_div(self):
        class FloorDivModule(torch.nn.Module):
            def forward(self, x, y):
                return x // 3, x // 2., \
                    x.to(dtype=torch.float64) // 3, x.to(dtype=torch.float64) // 2., \
                    x.to(dtype=torch.int64) // 3, x.to(dtype=torch.int64) // 2., \
                    x // (y + 1.).to(dtype=torch.int64), x // y, \
                    x.to(dtype=torch.float64) // y.to(dtype=torch.int64), x.to(dtype=torch.float64) // y.to(dtype=torch.float64), \
                    x.to(dtype=torch.int64) // y.to(dtype=torch.int64), x.to(dtype=torch.int64) // y

        x = torch.randn(2, 3, 4)
        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4)
        self.run_test(FloorDivModule(), (x, y))

    def test_floor_div_script(self):
        class FloorDivModule(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, y):
                return x // 3, x // 2., x // y

        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 3, 4)
        self.run_test(FloorDivModule(), (x, y))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_floordiv(self):
        class FloordivModule(torch.nn.Module):
            def forward(self, x):
                return x.new_zeros(x.size(2) // x.size(1))

        x = torch.randn(2, 3, 4)
        self.run_test(FloordivModule(), (x,))

    def test_div(self):
        class DivModule(torch.nn.Module):
            def forward(self, x, y):
                return x / y, torch.true_divide(x, y)

        x = torch.randn(2, 3, 4).to(torch.int)
        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
        self.run_test(DivModule(), (x, y))
        self.run_test(DivModule(), (x.float(), y.float()))

    # Note: div cannot (generally) be exported via scripting
    # since its type promotion logic is dependent on knowing the scalar types
    # of the input tensors. That is, the ONNX graph is dependent on the
    # data type of the inputs. This makes it appropriate for tracing only.
    def test_div_promotion_trace(self):
        class DivModule(torch.nn.Module):
            def forward(self, x, y):
                return x / y, torch.true_divide(x, y)

        x = torch.randn(2, 3, 4).to(torch.int)
        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)

        prev_default = torch.get_default_dtype()

        torch.set_default_dtype(torch.float)
        self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))

        torch.set_default_dtype(torch.double)
        self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))

        torch.set_default_dtype(prev_default)

    # In scripting, x and y do not carry shape and dtype info.
    # The following test only works when ONNX shape inference is enabled.
    @skipIfONNXShapeInference(False)
    def test_div_promotion_script(self):
        class DivModule(torch.nn.Module):
            def forward(self, x, y):
                # Add transpose to hide shape/type information.
                # Otherwise shape and type are still available from the input.
                x = x.transpose(1, 2)
                y = y.transpose(1, 2)
                return x / y, torch.true_divide(x, y)

        x = torch.randn(2, 3, 4).to(torch.int)
        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)

        prev_default = torch.get_default_dtype()

        # 1. x, y are int, and the output is float.
        # This can be handled by the default case, where both are cast to float.
        # It works even if the types of x and y are unknown.
        torch.set_default_dtype(torch.float)
        self.run_test(torch.jit.script(DivModule()), (x, y))

        # 2. x, y are int, and the output is double.
        # This can be handled by the default case, where both are cast to double.
        # It works even if the types of x and y are unknown.
        torch.set_default_dtype(torch.double)
        self.run_test(torch.jit.script(DivModule()), (x, y))

        # 3. x is int, y is double, and the output is double.
        # This can only be handled when the types of both x and y are known.
        torch.set_default_dtype(prev_default)
        x = torch.randn(2, 3, 4).to(torch.int)
        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double)
        self.run_test(torch.jit.script(DivModule()), (x, y))

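    # Added illustration of the promotion rule discussed above (not itself a
    # test): with the default dtype set to float32,
    #   (torch.tensor(3) / torch.tensor(2)).dtype == torch.float32
    # i.e. true division of two int tensors yields the default float dtype,
    # which tracing bakes into the graph but scripting must recover from the
    # inputs' scalar types.
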
    def test_slice_trace(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x[0:1]

        x = torch.randn(3)
        self.run_test(MyModule(), x)

    def test_slice_neg(self):
        class NegSlice(torch.nn.Module):
            def forward(self, x):
                return x[-1:]

        x = torch.randn(3, 4, 5)
        self.run_test(NegSlice(), x)

    def test_slice_neg_large(self):
        class NegSlice(torch.nn.Module):
            def forward(self, x):
                return x[:, :, -3:-1, :, -1]

        x = torch.randn(3, 4, 5, 6, 7)
        self.run_test(NegSlice(), x)

    def test_slice_neg_large_negone(self):
        class NegSlice(torch.nn.Module):
            def forward(self, x):
                return x[:, :, :, :, -1]

        x = torch.randn(3, 4, 5, 6, 7)
        self.run_test(NegSlice(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_slice_with_input_index(self):
        class InputIndexSlice(torch.nn.Module):
            def forward(self, x, y):
                x[:y.size(0), 0, :] = y
                return x

        x = torch.zeros((56, 6, 256))
        y = torch.rand((22, 256))
        self.run_test(InputIndexSlice(), (x, y))

    @skipIfUnsupportedMinOpsetVersion(10)
    @disableScriptTest()  # scripting tuple/list append
    def test_slice_dynamic(self):
        class DynamicSliceExportMod(torch.nn.Module):
            def forward(self, x):
                results = []
                for i in range(4):
                    results.append(x[:x.size(0) - i, i:x.size(2), i:3])
                return tuple(results)

        x = torch.rand(5, 5, 5)
        y = torch.randn(6, 7, 8)
        self.run_test(DynamicSliceExportMod(), x, test_with_inputs=[y],
                      input_names=['input_1'],
                      output_names=['output_1'],
                      dynamic_axes={'input_1': [0, 1, 2],
                                    'output_1': [0, 1, 2]})

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_slice_dynamic_script(self):
        class DynamicSliceModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x[1:x.size(1)]

        x = torch.rand(1, 2)
        self.run_test(DynamicSliceModel(), x)

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_slice_dynamic_shape_script(self):
        class DynamicSliceModel(torch.nn.Module):
            def forward(self, x):
                return x.new_zeros(x.shape[1:x.size(2)])

        x = torch.rand(1, 2, 3, 4)
        self.run_test(DynamicSliceModel(), x)

    @skipIfUnsupportedMinOpsetVersion(10)
    @disableScriptTest()  # scripting tuple/list append
    def test_slice_dynamic_to_end(self):
        class DynamicSliceExportMod(torch.nn.Module):
            def forward(self, x):
                results = []
                for i in range(4):
                    results.append(x[:, i:, x.size(2) - 5])
                return tuple(results)

        x = torch.rand(5, 5, 5)
        self.run_test(DynamicSliceExportMod(), x,
                      dynamic_axes={'input_1': [0, 1, 2],
                                    'output_1': [0, 1, 2]})

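    # Added note: the dynamic slicing tests above are gated on opset 10
    # because ONNX Slice-10 moved starts/ends/axes/steps from node attributes
    # to runtime inputs, which is what makes data-dependent slice bounds
    # expressible.
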
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_arange_dynamic(self):
        class ArangeModel(torch.nn.Module):
            def forward(self, input):
                return torch.arange(input.shape[0]), \
                    torch.arange(12), \
                    torch.arange(start=input.shape[0], end=input.shape[0] + 5)

        x = torch.randn(5, 3, 2)
        y = torch.randn(8, 3, 2)
        self.run_test(ArangeModel(), x, test_with_inputs=[y],
                      input_names=['input_1'],
                      output_names=['output_1', 'output_2', 'output_3'],
                      dynamic_axes={'input_1': [0],
                                    'output_1': [0]})
        self.run_test(torch.jit.script(ArangeModel()), x,
                      test_with_inputs=[y], input_names=['input_1'],
                      output_names=['output_1', 'output_2', 'output_3'],
                      dynamic_axes={'input_1': [0],
                                    'output_1': [0]})

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_dynamic_arange_out(self):
        class ArangeOutModel(torch.nn.Module):
            def forward(self, end):
                out_t = torch.tensor([1], dtype=torch.int64)
                return torch.arange(end, out=out_t)

        x = torch.tensor(8)
        self.run_test(ArangeOutModel(), (x))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_dynamic_arange_start_out(self):
        class ArangeStartOutModel(torch.nn.Module):
            def forward(self, start, end):
                out_t = torch.tensor([1], dtype=torch.int64)
                return torch.arange(start.size(0), end, out=out_t)

        x = torch.randn(2, 3, 4)
        y = torch.tensor(8)
        self.run_test(ArangeStartOutModel(), (x, y))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_arange(self):
        class ArangeModel(torch.nn.Module):
            def forward(self, start, end):
                return torch.arange(start.size(0), end, 1.5, dtype=torch.int64)

        x = torch.randn(2, 3, 4)
        y = torch.tensor(8.5, dtype=torch.float)
        self.run_test(ArangeModel(), (x, y))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_arange_out(self):
        class ArangeOutModel(torch.nn.Module):
            def forward(self, end):
                out_t = torch.tensor([1], dtype=torch.float)
                return torch.arange(end, out=out_t)

        x = torch.tensor(8.5, dtype=torch.float)
        self.run_test(ArangeOutModel(), (x))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_arange_start_out(self):
        class ArangeStartOutModel(torch.nn.Module):
            def forward(self, start, end):
                out_t = torch.tensor([1], dtype=torch.float)
                return torch.arange(start.size(0), end, out=out_t)

        x = torch.randn(2, 3, 4)
        y = torch.tensor(8.5, dtype=torch.float)
        self.run_test(ArangeStartOutModel(), (x, y))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_arange_no_type(self):
        class ArangeModel(torch.nn.Module):
            def forward(self, end):
                return torch.arange(end), \
                    torch.arange(0, end)

        x = torch.tensor(6.2, dtype=torch.float)
        self.run_test(ArangeModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_size(self):
        class SizeModel(torch.nn.Module):
            def forward(self, input):
                return torch.arange(input.size(0)), torch.arange(input.size(-1)), torch.ones(input.shape)

        x = torch.randn(5, 3, 2)
        self.run_test(SizeModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    @disableScriptTest()  # x.stride() not scriptable
    def test_as_strided(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                chunk_size = list(x.size())
                chunk_size[1] = chunk_size[1] * 2 - 1
                chunk_stride = list(x.stride())
                chunk_stride[1] = chunk_stride[1] // 2
                return x.as_strided((3, 3, 3), (1, 4, 2), storage_offset=2), x.as_strided(chunk_size, chunk_stride)

        x = torch.randn(5, 8, 7)
        self.run_test(Model(), x)

    @disableScriptTest()  # Ellipses followed by tensor indexing not scriptable
    def test_tensor_index_advanced_indexing_ellipsis(self):
        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[..., torch.tensor([2, 1]), torch.tensor([0, 3])]

        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), (m1,))

    def test_tensor_index_advanced_indexing(self):
        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[:, torch.tensor([[0, 2], [1, 1]]), :, torch.tensor([2, 1]), torch.tensor([0, 3])]

        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), (m1,))

        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[:, torch.tensor([0, 2]), None, 2:4, torch.tensor([[1, 3], [4, 0]])]

        self.run_test(MyModel(), (m1,))

        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[:, torch.tensor([0, 2]), torch.tensor([1]), 2:4, torch.tensor([[1], [4]])]

        self.run_test(MyModel(), (m1,))

    def test_tensor_index_advanced_indexing_consecutive(self):
        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[:, torch.tensor([0, 2]), torch.tensor([[1, 3], [4, 0]]), None]

        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), (m1,))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_index_put(self):
        class IndexPutModel(torch.nn.Module):
            def forward(self, x, ind, update):
                x[ind] = update
                return x

        x = torch.randn(3, 4)
        ind = torch.tensor([1], dtype=torch.long)
        update = torch.ones(4)
        self.run_test(IndexPutModel(), (x, ind, update))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_index_put_accumulate(self):
        class IndexPutModel(torch.nn.Module):
            def forward(self, x, ind, update):
                return x.index_put((ind, ), update, accumulate=True)

        x = torch.randn(3, 4)
        ind = torch.tensor([2], dtype=torch.long)
        update = torch.ones(4)
        self.run_test(IndexPutModel(), (x, ind, update))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_index_put_slice_index(self):
        class IndexPutModel(torch.nn.Module):
            def forward(self, x, update):
                x[1:2, 1:3, torch.tensor([1])] += update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.tensor([10, 15]).view(1, 2, 1)
        self.run_test(IndexPutModel(), (x, update))

        class IndexPutModel2(torch.nn.Module):
            def forward(self, x, update):
                x[torch.tensor([0, 2]), torch.tensor([1, 2])] += update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.randn(2, 5)
        self.run_test(IndexPutModel2(), (x, update))

        class IndexPutModel3(torch.nn.Module):
            def forward(self, x, update):
                x[torch.tensor([0, 2]), 1:2] += update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.tensor([10, 15]).view(2, 1, 1)
        self.run_test(IndexPutModel3(), (x, update))

        class IndexPutModel4(torch.nn.Module):
            def forward(self, x, update):
                x[torch.tensor([0, 2]), 2] += update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.tensor([10, 15]).view(2, 1)
        self.run_test(IndexPutModel4(), (x, update))

        class IndexPutModel5(torch.nn.Module):
            def forward(self, x, update):
                x[1:3, torch.tensor([0, 2]), 2] += update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.tensor([10, 15]).view(2, 1)
        self.run_test(IndexPutModel5(), (x, update))

        class IndexPutModel6(torch.nn.Module):
            def forward(self, x, update):
                x[1:3, 0] = update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.arange(2 * 5).to(torch.float).view(2, 5)
        self.run_test(IndexPutModel6(), (x, update))

        class IndexPutModel7(torch.nn.Module):
            def forward(self, x, update):
                x[1:, 0] = update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.arange(2 * 5).to(torch.float).view(2, 5)
        self.run_test(IndexPutModel7(), (x, update))

        class IndexPutModel8(torch.nn.Module):
            def forward(self, x, update):
                x[:3, 0] = update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.arange(3 * 5).to(torch.float).view(3, 5)
        self.run_test(IndexPutModel8(), (x, update))

    @skipIfUnsupportedMinOpsetVersion(11)
    @disableScriptTest()  # Ellipses followed by tensor indexing not scriptable
    def test_index_put_ellipsis(self):
        class IndexPutModel(torch.nn.Module):
            def forward(self, x, update):
                x[..., torch.tensor([2, 1, 3]), 2:4] += update
                return x

        x = torch.randn(3, 4, 5, 6, 7)
        update = torch.randn(3, 1, 1, 3, 2)
        self.run_test(IndexPutModel(), (x, update))

        class IndexPutModel2(torch.nn.Module):
            def forward(self, x, update):
                x[2, ..., torch.tensor([2, 1, 3]), 2:4] += update
                return x

        x = torch.randn(3, 4, 5, 6, 7)
        update = torch.randn(4, 1, 3, 2)
        self.run_test(IndexPutModel2(), (x, update))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_copy_(self):
        class CopyModel(torch.nn.Module):
            def forward(self, x, data):
                x[1:3] = data
                return x

        x = torch.randn(3, 4)
        update = torch.randn(2, 4)
        self.run_test(CopyModel(), (x, update))

        # mixed slice and select
        class CopyModel2(torch.nn.Module):
            def forward(self, x, data):
                x[1:3, 0] = data
                return x

        x = torch.randn(3, 4)
        update = torch.tensor([0], dtype=torch.float32)
        self.run_test(CopyModel2(), (x, update))

        update = torch.tensor([2, 3], dtype=torch.float32)
        self.run_test(CopyModel2(), (x, update))

        update = torch.randn(2)
        self.run_test(CopyModel2(), (x, update))

        class CopyModel3(torch.nn.Module):
            def forward(self, x, data):
                x[1, 1:3] = data
                return x

        x = torch.randn(3, 4)
        update = torch.tensor([0], dtype=torch.float32)
        self.run_test(CopyModel3(), (x, update))

        update = torch.tensor([2, 3], dtype=torch.float32)
        self.run_test(CopyModel3(), (x, update))

        update = torch.randn(2)
        self.run_test(CopyModel3(), (x, update))

        class CopyModel4(torch.nn.Module):
            def forward(self, x, ind, data):
                x[ind] = data
                return x

        x = torch.randn(3, 4)
        ind = torch.tensor(2)
        data = torch.randn(4)
        self.run_test(CopyModel4(), (x, ind, data))

    @skipIfUnsupportedMinOpsetVersion(11)
    @disableScriptTest()  # Model not scriptable (output with shape doesn't match the broadcast shape)
    def test_copy_tracing(self):
        class CopyModel(torch.nn.Module):
            def forward(self, x, data):
                x[1, 1:3] = data
                return x

        x = torch.randn(3, 4)
        update = torch.randn(1, 2)
        self.run_test(CopyModel(), (x, update))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_copy_ellipsis(self):
        class CopyModel(torch.nn.Module):
            def forward(self, x, update):
                x[..., 1] = update
                return x

        x = torch.randn(2, 3, 4)
        update = torch.ones(1)
        self.run_test(CopyModel(), (x, update))

        x = torch.randn(2, 3, 4, 5, 6)
        update = torch.ones(1)
        self.run_test(CopyModel(), (x, update))

    @skipIfUnsupportedMinOpsetVersion(11)
    @disableScriptTest()  # Missing input size (with ellipsis indexing)
    def test_copy_ellipsis_tracing(self):
        class CopyModel(torch.nn.Module):
            def forward(self, x, update):
                x[2, ..., 1:3] = update
                return x

        x = torch.randn(3, 4, 5, 6)

        update = torch.ones(1)
        self.run_test(CopyModel(), (x, update))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_flip(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return torch.flip(x, dims=[0])

        x = torch.tensor(np.arange(6.0).reshape(2, 3))
        self.run_test(MyModule(), x)

    def test_random(self):
        class RandN(torch.nn.Module):
            def forward(self, x):
                return torch.mul(x, (torch.randn(2, 3, 4) + x).size(0))

        x = torch.randn(2, 3, 4)
        self.run_test(RandN(), x)

        class Rand(torch.nn.Module):
            def forward(self, x):
                return torch.mul(x, (torch.rand(2, 3, 4) + x).size(0))

        x = torch.randn(2, 3, 4)
        self.run_test(Rand(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    @disableScriptTest()  # symbolic update for randn
    def test_random_dynamic_size(self):
        class RandN(torch.nn.Module):
            def forward(self, x):
                return torch.mul(x, torch.randn(x.size()).size(1))

        x = torch.randn(2, 3, 4)
        self.run_test(RandN(), x)

        class Rand(torch.nn.Module):
            def forward(self, x):
                return torch.mul(x, torch.rand(x.size()).size(1))

        x = torch.randn(2, 3, 4)
        self.run_test(Rand(), x)

    def test_random_like(self):
        class RandNLike(torch.nn.Module):
            def forward(self, x):
                return torch.mul(x, torch.randn_like(x).size(0))

        x = torch.randn(2, 3, 4)
        self.run_test(RandNLike(), x)
        self.run_test(torch.jit.script(RandNLike()), x)

        class RandLike(torch.nn.Module):
            def forward(self, x):
                return torch.mul(x, torch.rand_like(x).size(0))

        x = torch.randn(2, 3, 4)
        self.run_test(RandLike(), x)
        self.run_test(torch.jit.script(RandLike()), x)

    def test_random_like_dtype(self):
        class RandNLike(torch.nn.Module):
            def forward(self, x):
                return torch.mul(x.to(torch.double), torch.randn_like(x, dtype=torch.double).size(0))

        x = torch.randn(2, 3, 4)
        self.run_test(RandNLike(), x)

        class RandLike(torch.nn.Module):
            def forward(self, x):
                return torch.mul(x.to(torch.double), torch.rand_like(x, dtype=torch.double).size(0))

        x = torch.randn(2, 3, 4)
        self.run_test(RandLike(), x)

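    # Added note on the random-op tests above: PyTorch and ONNX Runtime use
    # different RNGs, so these tests never compare random values directly;
    # they only consume the shape of the random tensor (via .size()), which
    # is deterministic and therefore comparable across backends.
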
    def _interpolate(self, x, mode, use_size, is_upsample, align_corners=False):
        class MyModel(torch.nn.Module):
            def forward(self, x):
                scale = 2.3 if is_upsample else 0.5
                if len(x.size()) == 3:
                    scale_array = 2.3
                if len(x.size()) == 4:
                    scale_array = [2.3, 5.1]
                if len(x.size()) == 5:
                    scale_array = [3.3, 2.3, 5.1]
                if use_size:
                    size_array = [int(float(v) * scale) for v in x.size()[2:]]
                    if align_corners:
                        return torch.nn.functional.interpolate(x, mode=mode, size=size_array[0], align_corners=True), \
                            torch.nn.functional.interpolate(x, mode=mode, size=size_array, align_corners=True)
                    return torch.nn.functional.interpolate(x, mode=mode, size=size_array[0]), \
                        torch.nn.functional.interpolate(x, mode=mode, size=size_array)
                if align_corners:
                    return torch.nn.functional.interpolate(x, mode=mode, scale_factor=scale,
                                                           align_corners=True, recompute_scale_factor=False), \
                        torch.nn.functional.interpolate(x, mode=mode, scale_factor=scale_array,
                                                        align_corners=True, recompute_scale_factor=False)
                return torch.nn.functional.interpolate(x, mode=mode,
                                                       scale_factor=scale, recompute_scale_factor=False), \
                    torch.nn.functional.interpolate(x, mode=mode,
                                                    scale_factor=scale_array, recompute_scale_factor=False)

        self.run_test(MyModel(), x)

    def _interpolate_script(self, x, mode, use_size, is_upsample, align_corners=False):
        class MyModel(torch.jit.ScriptModule):
            __constants__ = ['mode', 'use_size', 'is_upsample', 'size', 'scale', 'size_array', 'scale_array', 'align_corners']

            def __init__(self, mode, use_size, is_upsample, align_corners):
                super(MyModel, self).__init__()
                self.mode = mode
                self.use_size = use_size
                self.is_upsample = is_upsample
                self.align_corners = align_corners
                self.scale = 2.0 if self.is_upsample else 0.5
                self.size = 24 if self.is_upsample else 2
                if x.dim() == 3:
                    self.scale_array = [2.3]
                    self.size_array = [16]
                elif x.dim() == 4:
                    self.scale_array = [2.3, 3.1]
                    self.size_array = [16, 32]
                else:
                    self.scale_array = [2.3, 3.1, 4.6]
                    self.size_array = [16, 32, 64]

            @torch.jit.script_method
            def forward(self, x):
                if self.use_size:
                    if self.align_corners:
                        return torch.nn.functional.interpolate(x, mode=self.mode, size=self.size, align_corners=True), \
                            torch.nn.functional.interpolate(x, mode=self.mode, size=self.size_array, align_corners=True)
                    return torch.nn.functional.interpolate(x, mode=self.mode, size=self.size), \
                        torch.nn.functional.interpolate(x, mode=self.mode, size=self.size_array)
                if self.align_corners:
                    return torch.nn.functional.interpolate(x, mode=self.mode,
                                                           scale_factor=self.scale, recompute_scale_factor=False), \
                        torch.nn.functional.interpolate(x, mode=self.mode,
                                                        scale_factor=self.scale_array, recompute_scale_factor=False)
                return torch.nn.functional.interpolate(x, mode=self.mode,
                                                       scale_factor=self.scale, recompute_scale_factor=False), \
                    torch.nn.functional.interpolate(x, mode=self.mode,
                                                    scale_factor=self.scale_array, recompute_scale_factor=False)

        model = MyModel(mode, use_size, is_upsample, align_corners)
        self.run_test(model, x, atol=1e-6)

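    # Added note: interpolate maps to ONNX Upsample before opset 10 and to
    # Resize from opset 10 on; opset 11 Resize added
    # coordinate_transformation_mode, which is what makes align_corners and
    # the linear/cubic variants exportable. The mode filtering in
    # _interpolate_tests below reflects these limits.
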
    def _interpolate_tests(self, is_upsample):
        # - cubic mode is not supported for opsets below 11;
        # - linear mode does not match for opsets below 11;
        modes = ["nearest", "linear", "bicubic"]
        if self.opset_version < 11:
            modes = ["nearest"]
        x = [torch.randn(1, 2, 6, requires_grad=True),
             torch.randn(1, 2, 4, 6, requires_grad=True),
             torch.randn(1, 2, 4, 4, 6, requires_grad=True)]

        for mode in modes:
            for xi in x:
                mode_i = mode
                # TODO: enable bicubic downsample when ORT precision loss fixed
                if mode == "bicubic" and xi.dim() != 4:
                    continue
                elif mode == "linear":
                    if xi.dim() == 3:
                        # TODO: enable when linear mode is implemented for 1d inputs in ORT
                        continue
                    elif xi.dim() == 4:
                        mode_i = "bilinear"
                    elif xi.dim() == 5:
                        # TODO: enable when linear mode is implemented for 3d inputs in ORT
                        mode_i = "trilinear"
                        continue
                self._interpolate(xi, mode_i, True, is_upsample)
                # test with align_corners if supported
                if mode != 'nearest':
                    self._interpolate(xi, mode_i, True, is_upsample, True)
                    self._interpolate_script(xi, mode_i, True, is_upsample, True)
                # The following cases require dynamic sizes/scales,
                # which is not supported for opset_version < 9.
                if self.opset_version >= 9:
                    self._interpolate_script(xi, mode_i, True, is_upsample)
                    self._interpolate(xi, mode_i, False, is_upsample)
                    # test with align_corners if supported
                    if mode != 'nearest':
                        self._interpolate(xi, mode_i, False, is_upsample, True)
                        self._interpolate_script(xi, mode_i, False, is_upsample, True)
                    self._interpolate_script(xi, mode_i, False, is_upsample)

    @disableScriptTest()
    def test_interpolate_upsample(self):
        self._interpolate_tests(True)

    @disableScriptTest()
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_interpolate_function_substitution(self):
        class ScriptModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.nn.functional.interpolate(x, mode="nearest", scale_factor=2.)

        class ScriptModule(torch.jit.ScriptModule):
            def __init__(self):
                super(ScriptModule, self).__init__()
                self.submodule = ScriptModel()

            @torch.jit.script_method
            def forward(self, input):
                return self.submodule(input)

        x = torch.randn(1, 2, 4, 4, 6)
        self.run_test(ScriptModule(), (x,))

        @torch.jit.script
        def script_method(x):
            return torch.nn.functional.interpolate(x, mode="nearest", scale_factor=2.)

        class TracingModule(torch.nn.Module):
            def forward(self, x):
                return script_method(x)

        self.run_test(TracingModule(), (x,))

    @skipIfUnsupportedMinOpsetVersion(10)
    @disableScriptTest()
    def test_interpolate_downsample(self):
        self._interpolate_tests(False)

    @skipIfUnsupportedMinOpsetVersion(11)
    @disableScriptTest()
    def test_interpolate_no_shape(self):
        class MyModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, y):
                x = torch.add(x, x)
                out1 = torch.nn.functional.interpolate(x, mode="bilinear", size=(16, 16), align_corners=False)
                out2 = torch.nn.functional.interpolate(x, mode="nearest", size=(int(y.size(0)), int(y.size(1))))
                return out1, out2

        x = torch.randn(1, 2, 4, 4, requires_grad=True)
        y = torch.randn(16, 16, requires_grad=True)
        self.run_test(MyModel(), (x, y))

@disableScriptTest()
|
|
def test_groupnorm(self):
|
|
model = torch.nn.GroupNorm(3, 6, 0.002)
|
|
x = torch.randn(4, 6, 180, 180, 180)
|
|
self.run_test(model, x)
|
|
|
|
model = torch.nn.GroupNorm(1, 6, 0.002)
|
|
x = torch.randn(4, 6, 180, 180)
|
|
self.run_test(model, x)
|
|
|
|
model = torch.nn.GroupNorm(6, 6, 0.002)
|
|
x = torch.randn(4, 6, 180, 180)
|
|
self.run_test(model, x)
|
|
|
|
@disableScriptTest()
|
|
def test_groupnorm_noaffine(self):
|
|
model = torch.nn.GroupNorm(4, 8, 0.002, affine=False)
|
|
x = torch.randn(3, 8, 224, 224)
|
|
self.run_test(model, x)
|
|
|
|
model = torch.nn.GroupNorm(1, 6, 0.002, affine=False)
|
|
x = torch.randn(4, 6, 180, 180)
|
|
self.run_test(model, x)
|
|
|
|
model = torch.nn.GroupNorm(6, 6, 0.002, affine=False)
|
|
x = torch.randn(4, 6, 180, 180)
|
|
self.run_test(model, x)
|
|
|
|
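
    # The GroupNorm tests above exercise the exporter's decomposition; for
    # readers, a hedged reference of the computation itself (the affine=False
    # case), written with plain tensor ops. This helper is an illustrative
    # sketch added for clarity, not part of the exporter.
    @staticmethod
    def _groupnorm_reference(x, num_groups, eps=0.002):
        n = x.shape[0]
        # normalize over everything inside each of the `num_groups` groups
        g = x.reshape(n, num_groups, -1)
        mean = g.mean(dim=-1, keepdim=True)
        var = g.var(dim=-1, unbiased=False, keepdim=True)
        return ((g - mean) / torch.sqrt(var + eps)).reshape(x.shape)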

    def test_std(self):
        class StandardDeviation(torch.nn.Module):
            def forward(self, input):
                return torch.std(input, unbiased=False)

        x = torch.randn(2, 3, 4)
        model = StandardDeviation()
        self.run_test(model, x)

    def test_pow(self):
        class PowModule(torch.nn.Module):
            def forward(self, x, y):
                return x.pow(y)

        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 3, 4)
        self.run_test(PowModule(), (x, y))

        x = torch.randint(10, (2, 3, 4))
        y = torch.randint(10, (2, 3, 4)).to(dtype=torch.int32)
        self.run_test(PowModule(), (x, y))

        x = torch.randint(10, (2, 3, 4))
        y = torch.randint(10, (2, 3, 4))
        self.run_test(PowModule(), (x, y))

        x = torch.randn(2, 3, 4).to(dtype=torch.float64)
        y = torch.randint(10, (2, 3, 4))
        self.run_test(PowModule(), (x, y))

    def test_std_along_dims(self):
        class StandardDeviation(torch.nn.Module):
            def forward(self, input):
                return torch.std(input, dim=(0, 1), unbiased=False)

        x = torch.randn(2, 3, 4)
        model = StandardDeviation()
        self.run_test(model, x)

    def test_std_keepdim(self):
        class StandardDeviation(torch.nn.Module):
            def forward(self, input):
                return torch.std(input, dim=(0, 1), unbiased=False, keepdim=True)

        x = torch.randn(2, 3, 4)
        model = StandardDeviation()
        self.run_test(model, x)
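
    # For reference: with unbiased=False, the std computed above is the square
    # root of the mean squared deviation. A hedged, illustrative helper (not
    # used by the tests) making that explicit:
    @staticmethod
    def _std_reference(x, dim, keepdim=False):
        mean = x.mean(dim=dim, keepdim=True)
        return torch.sqrt(((x - mean) ** 2).mean(dim=dim, keepdim=keepdim))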

    def test_bitshift(self):
        class BitshiftModel(torch.nn.Module):
            def forward(self, input, input2):
                return input >> 1, input << 3.1, \
                    input2 >> torch.tensor([1, 2]), input2 << 4.2
        input = torch.arange(24, dtype=torch.float32).reshape(3, 4, 2)
        input2 = torch.arange(24, dtype=torch.int64).reshape(3, 4, 2)
        self.run_test(BitshiftModel(), (input, input2))

    def test_bitshift_other_fp(self):
        class BitshiftModel(torch.nn.Module):
            def forward(self, input):
                return input << 2.4
        input = torch.arange(24, dtype=torch.int64).reshape(3, 4, 2)
        self.run_test(BitshiftModel(), input)

    # uint8 not implemented in ORT for Mul used in
    # exporting bitshift for opset_version < 10
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_bitshift_uint8(self):
        class BitshiftModel(torch.nn.Module):
            def forward(self, input, input2):
                return input >> 1, input << 3., \
                    input2 >> torch.tensor([1, 2], dtype=torch.uint8), input2 << 4.
        input = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
        input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
        self.run_test(BitshiftModel(), (input, input2))

    def test_narrow(self):
        class NarrowModel(torch.nn.Module):
            def forward(self, input):
                return torch.narrow(input, 0, 0, 2)

        x = torch.randn(3, 3, requires_grad=True)
        self.run_test(NarrowModel(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_narrow_dynamic(self):
        class NarrowModel(torch.nn.Module):
            def forward(self, input):
                return torch.narrow(input, 0, 0, input.shape[0] - 1)

        x = torch.randn(3, 3, requires_grad=True)
        self.run_test(NarrowModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_index_fill(self):
        class IndexFillModel(torch.nn.Module):
            def forward(self, input):
                index = torch.tensor([2, 0])
                return input.index_fill(2, index, -1)

        x = torch.randn(3, 4, 5, requires_grad=True)
        self.run_test(IndexFillModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_index_copy(self):
        class IndexCopyModel(torch.nn.Module):
            def forward(self, input):
                index = torch.tensor([2, 0])
                source = torch.ones(3, 2, 5)
                return input.index_copy(1, index, source)

        x = torch.randn(3, 4, 5, requires_grad=True)
        self.run_test(IndexCopyModel(), x)

    def test_select(self):
        class Select(torch.nn.Module):
            def forward(self, x):
                return x[:, 1]

        x = torch.randn(3, 4)
        self.run_test(Select(), x)

    def test_select_negative_index(self):
        class Select(torch.nn.Module):
            def forward(self, x):
                return x[:, -1]

        x = torch.randn(3, 4)
        self.run_test(Select(), x)

    # TODO: enable for opset 10 when the ONNXRuntime version is updated

    def test_index_select_constant_scaler_index(self):
        class IndexSelectScalerIndexModel(torch.nn.Module):
            def forward(self, x):
                index = 2
                return torch.index_select(x, 1, torch.tensor(index))
        x = torch.randn(3, 4)
        self.run_test(IndexSelectScalerIndexModel(), x)

    def test_index_select_scaler_index(self):
        class IndexSelectScalerIndexModel(torch.nn.Module):
            def __init__(self, index_base):
                super(IndexSelectScalerIndexModel, self).__init__()
                self.index_base = torch.tensor(index_base)

            def forward(self, x, index_offset):
                index = self.index_base + index_offset
                return torch.index_select(x, 1, index)
        x = torch.randn(3, 4)
        offset = 2
        index_offset = torch.tensor(offset)
        base = 1
        self.run_test(IndexSelectScalerIndexModel(base), (x, index_offset))

    def test_take(self):
        class TakeModel(torch.nn.Module):
            def forward(self, x, y):
                return torch.take(x, y)

        x = torch.randn(6, 4, 3, 3)
        y = torch.tensor([4, 1, 7, 15, 63])
        self.run_test(TakeModel(), (x, y))

    def test_topk(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return torch.topk(x, 3)

        x = torch.arange(1., 6., requires_grad=True)
        self.run_test(MyModule(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_topk_smallest_unsorted(self):
        class MyModule(torch.nn.Module):
            def forward(self, x, k):
                # When sorted=False, the order of elements in the output
                # tensors is not expected to match between PyTorch and ORT
                topk_unsorted = torch.topk(x, k, largest=False, sorted=False)
                topk_sorted = torch.topk(x, k, largest=False, sorted=True)
                return topk_sorted, torch.sort(topk_unsorted.values).values

        x = torch.arange(1., 6., requires_grad=True)
        k = torch.tensor(3)
        self.run_test(MyModule(), (x, k))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_topk_script(self):
        class MyModuleDynamic(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, k):
                return torch.topk(x, k)

        x = torch.arange(1., 6., requires_grad=True)
        k = torch.tensor(3)
        self.run_test(MyModuleDynamic(), [x, k])

    @skipIfUnsupportedOpsetVersion([7])
    def test_normalize(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.normalize(x)

        x = torch.randn(3, 3)
        self.run_test(Model(), x)

    def test_layer_norm(self):
        model = torch.nn.LayerNorm([10, 10])
        x = torch.randn(20, 5, 10, 10)
        self.run_test(model, x)

    def test_batchnorm1d(self):
        x = torch.randn(10, 10)
        model = torch.nn.BatchNorm1d(10, affine=True)
        self.run_test(model, x)

        x = torch.randn(10, 10, 128)
        self.run_test(model, x)

    def test_batchnorm1d_noaffine(self):
        x = torch.randn(10, 10)
        model = torch.nn.BatchNorm1d(10, affine=False)
        self.run_test(model, x)

        x = torch.randn(10, 10, 128)
        self.run_test(model, x)

    def test_batchnorm2d(self):
        x = torch.randn(10, 3, 128, 128)
        model = torch.nn.BatchNorm2d(3, affine=True)
        self.run_test(model, x)

    def test_batchnorm2d_noaffine(self):
        x = torch.randn(10, 3, 128, 128)
        model = torch.nn.BatchNorm2d(3, affine=False)
        self.run_test(model, x)

    def test_batchnorm3d(self):
        x = torch.randn(10, 3, 128, 128, 128)
        model = torch.nn.BatchNorm3d(3, affine=True)
        self.run_test(model, x)

    def test_batchnorm3d_noaffine(self):
        x = torch.randn(10, 3, 128, 128, 128)
        model = torch.nn.BatchNorm3d(3, affine=False)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_scatter_with_scalar(self):
        class ScatterModel(torch.nn.Module):
            def forward(self, input, indices):
                values = 1.0
                return input.scatter(1, indices, values)

        input = torch.tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], dtype=torch.float64)
        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
        self.run_test(ScatterModel(), input=(input, indices))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_scatter_with_scalar_different_types(self):
        # Tests the case when the scalar src (the update value) type differs
        # from the self type. This happens only with a scalar src - PyTorch
        # does not allow it when src is a tensor.
        class ScatterModel(torch.nn.Module):
            def forward(self, input, indices):
                values = 1.0
                return input.scatter(1, indices, values)

        input = torch.tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], dtype=torch.float32)
        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
        self.run_test(ScatterModel(), input=(input, indices))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_scatter(self):
        class ScatterModel(torch.nn.Module):
            def forward(self, input, indices, values):
                return input.scatter(1, indices, values)

        input = torch.tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]])
        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
        self.run_test(ScatterModel(), input=(input, indices, values))

        input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        indices = torch.tensor([[1, 0], [0, 2], [0, 1]], dtype=torch.int64)
        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
        self.run_test(ScatterModel(), (input, indices, values))

        input = torch.zeros(3, 4, 5, 6)
        indices = torch.tensor([[1, 0], [0, 2], [0, 1]], dtype=torch.int64)
        indices = indices.view(3, 2, 1, 1).expand(3, 2, 5, 6)
        values = torch.arange(3 * 2 * 5 * 6, dtype=torch.float32).view(3, 2, 5, 6)
        self.run_test(ScatterModel(), (input, indices, values))

        input = torch.zeros(3, 4, 2)
        indices = torch.tensor([[[1, 0], [0, 2]], [[1, 1], [0, 1]], [[2, 1], [2, 2]]])
        values = torch.arange(3 * 2 * 2, dtype=torch.float32).view(3, 2, 2)
        self.run_test(ScatterModel(), (input, indices, values))
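
    # For readers of the scatter tests above, a hedged reference of what
    # scatter along dim 1 computes for 2-D tensors (the ONNX counterpart is
    # ScatterElements at opset 11+, Scatter before that). Illustrative only.
    @staticmethod
    def _scatter_dim1_reference(input, indices, values):
        out = input.clone()
        for i in range(indices.shape[0]):
            for j in range(indices.shape[1]):
                out[i, indices[i, j]] = values[i, j]
        return out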

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_scatter_add(self):
        class ScatterModel(torch.nn.Module):
            def forward(self, input, indices, values):
                return input.scatter_add(1, indices, values)

        input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
        self.run_test(ScatterModel(), input=(input, indices, values))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_one_hot(self):
        class OneHot(torch.nn.Module):
            def __init__(self, num_classes):
                super().__init__()
                self.num_classes = num_classes

            def forward(self, x):
                return torch.nn.functional.one_hot(x, self.num_classes)

        x = torch.arange(10)
        self.run_test(OneHot(15), (x))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_gather(self):
        class GatherModel(torch.nn.Module):
            def forward(self, input, indices):
                return input.gather(1, indices)

        input = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
        self.run_test(GatherModel(), input=(input, indices))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_expand(self):
        class ExpandModel(torch.nn.Module):
            def forward(self, input):
                return input.expand(2, 3, -1)

        input = torch.randn(2, 1, 4)
        self.run_test(ExpandModel(), input=(input))

        class ExpandInferDimModel(torch.nn.Module):
            def forward(self, input):
                return input.expand(-1, input.size(0))

        input = torch.randn(3, 1)
        self.run_test(ExpandInferDimModel(), input=(input))

        class ExpandTensorSizeModel(torch.nn.Module):
            def forward(self, input, size):
                return input.expand(size)

        input = torch.randn(3,)
        size = torch.tensor(-1)
        self.run_test(ExpandTensorSizeModel(), input=(input, size))

    def test_multinomial(self):
        class Multinomial(torch.nn.Module):
            def forward(self, weight):
                return torch.multinomial(weight, 3, replacement=True)

        class MultinomialNoReplacement(torch.nn.Module):
            def forward(self, weight):
                return torch.multinomial(weight, 1)

        weight = torch.tensor([[0, 10, 0, 0], [0, 0, 100, 0]], dtype=torch.float)
        self.run_test(Multinomial(), (weight,))
        self.run_test(MultinomialNoReplacement(), (weight,))

    def _test_reduced_ops(self, op):
        class ReducedOpModule(torch.nn.Module):
            def forward(self, input):
                return op(input, dim=-1)

        if op != torch.mean:  # torch.mean only supports float types
            x = torch.randint(10, (4, 4), dtype=torch.uint8)
            self.run_test(ReducedOpModule(), x)

            x = torch.randint(10, (4, 4), dtype=torch.int8)
            self.run_test(ReducedOpModule(), x)

            x = torch.randint(10, (4, 4), dtype=torch.int16)
            self.run_test(ReducedOpModule(), x)

            x = torch.randint(10, (4, 4), dtype=torch.int32)
            self.run_test(ReducedOpModule(), x)

            x = torch.randint(10, (4, 4), dtype=torch.int64)
            self.run_test(ReducedOpModule(), x)

        # torch.mean only supports float types;
        # ORT does not support ReduceProd for double
        if op != torch.prod and op != torch.mean:
            x = torch.randn(4, 5, dtype=torch.double)
            self.run_test(ReducedOpModule(), x)

        if op != torch.prod:  # torch.prod not implemented for Half
            x = torch.randn(4, 4, dtype=torch.half)
            self.run_test(ReducedOpModule(), x)

        x = torch.randn(4, 5, dtype=torch.float)
        self.run_test(ReducedOpModule(), x)

    def test_reduced_sum(self):
        return self._test_reduced_ops(op=torch.sum)

    def test_reduced_mean(self):
        return self._test_reduced_ops(op=torch.mean)

    def test_reduced_prod(self):
        return self._test_reduced_ops(op=torch.prod)

    def test_reduced_min_max(self):
        class ReducedMinMaxModule(torch.nn.Module):
            def forward(self, input):
                return torch.min(input, dim=-1)[0], torch.max(input, dim=0)[0]
        x = torch.randint(10, (4, 4), dtype=torch.int32)
        self.run_test(ReducedMinMaxModule(), x)

        x = torch.randint(10, (4, 4), dtype=torch.int64)
        self.run_test(ReducedMinMaxModule(), x)

        x = torch.randn(4, 5, dtype=torch.float)
        self.run_test(ReducedMinMaxModule(), x)

    def test_reduce_log_sum_exp(self):
        class ReduceLogSumExpModel(torch.nn.Module):
            def forward(self, input):
                a = torch.logsumexp(input, dim=0)
                b = torch.logsumexp(input, dim=(0, 1))
                return a + b

        x = torch.randn(4, 4, requires_grad=True)
        self.run_test(ReduceLogSumExpModel(), x)

    def test_softmax(self):
        for i in range(-4, 3):
            model = torch.nn.Softmax(dim=i)
            input = torch.randn(3, 4, 5, 6)
            self.run_test(model, input)

            class SoftmaxUnknownRank(torch.nn.Module):
                def __init__(self, i):
                    super().__init__()
                    self.softmax = torch.nn.Softmax(dim=i)

                def forward(self, x):
                    return self.softmax(x.reshape(3, 4, 5, 6))

            model = torch.jit.script(SoftmaxUnknownRank(i))
            self.run_test(model, input)

    def test_softmax_large_values(self):
        input = torch.tensor([[-1e12, -1e12, -1e12], [1e12, 0.0, -5.0], [3.0, 4.0, 5.0]])
        for i in range(-2, 1):
            model = torch.nn.Softmax(dim=i)
            self.run_test(model, input)

            class SoftmaxUnknownRank(torch.nn.Module):
                def __init__(self, i):
                    super().__init__()
                    self.softmax = torch.nn.Softmax(dim=i)

                def forward(self, x):
                    return self.softmax(x.reshape(3, 3))

            model = torch.jit.script(SoftmaxUnknownRank(i))
            self.run_test(model, input)
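
    # The large-value inputs above only work because softmax is computed in a
    # numerically stable way. A hedged sketch of the standard max-shift
    # formulation (illustrative; not necessarily the exporter's decomposition):
    @staticmethod
    def _stable_softmax_reference(x, dim):
        # subtracting the row max makes exp() safe for inputs like +/-1e12
        shifted = x - x.max(dim=dim, keepdim=True).values
        e = shifted.exp()
        return e / e.sum(dim=dim, keepdim=True)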

    def test_logsoftmax(self):
        for i in range(2, 7):
            model = torch.nn.LogSoftmax(dim=i - 1)
            dims = [2] * (i - 2) + [3, 4]
            input = torch.ones(*dims, requires_grad=True)
            self.run_test(model, input)

    def test_logsoftmax_dim(self):
        for i in range(-4, 3):
            model = torch.nn.LogSoftmax(dim=i)
            input = torch.randn(3, 4, 5, 6)
            self.run_test(model, input)

    @skipIfUnsupportedMinOpsetVersion(9)
    @disableScriptTest()  # scripting prim_dtype
    def test_lstm_no_hidden(self):
        class LSTMModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.rnn = torch.nn.LSTM(input_size=16, hidden_size=16)

            def forward(self, x):
                return self.rnn(x)

        input = torch.randn((10, 16, 16))
        self.run_test(LSTMModel(), (input,))

    @skipIfUnsupportedMinOpsetVersion(9)
    @disableScriptTest()
    def test_lstm(self):
        model = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)
        input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
        h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
        c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
        self.run_test(model, (input, (h0, c0)))

    @skipIfUnsupportedMinOpsetVersion(9)
    @disableScriptTest()
    def test_lstm_default_init_state(self):
        model = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)
        input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
        self.run_test(model, input)

    @skipIfUnsupportedMinOpsetVersion(9)
    @disableScriptTest()  # LSTMModel not scriptable
    def test_lstm_fixed_batch_size(self):
        class LSTMModel(torch.nn.Module):
            def __init__(self):
                super(LSTMModel, self).__init__()
                self.lstm = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)

            def forward(self, input):
                batch_size = input.size()[1]
                h0_np = np.ones([1, batch_size, RNN_HIDDEN_SIZE]).astype(np.float32)
                c0_np = np.ones([1, batch_size, RNN_HIDDEN_SIZE]).astype(np.float32)
                h0 = torch.from_numpy(h0_np)
                c0 = torch.from_numpy(c0_np)
                return self.lstm(input, (h0, c0))

        input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
        # verify with a different input of the same batch size
        input2 = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
        self.run_test(LSTMModel(), input, fixed_batch_size=True, test_with_inputs=[input2])

    @skipIfUnsupportedMinOpsetVersion(9)
    @disableScriptTest()
    def test_lstm_post_fix_init_state(self):
        class LSTMModel(torch.nn.Module):
            def __init__(self):
                super(LSTMModel, self).__init__()
                self.lstm = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE,
                                          1, bidirectional=False)

            def forward(self, input):
                batch_size = input.size()[1]
                h0_np = np.ones([1, batch_size, RNN_HIDDEN_SIZE]).astype(np.float32)
                c0_np = np.ones([1, batch_size, RNN_HIDDEN_SIZE]).astype(np.float32)
                h0 = torch.from_numpy(h0_np)
                c0 = torch.from_numpy(c0_np)
                return self.lstm(input, (h0, c0))

        model = LSTMModel()
        input = torch.randn(RNN_SEQUENCE_LENGTH, 1, RNN_INPUT_SIZE)
        # verify with a different input of a different batch size
        input2 = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
        self.run_test(model, input, dynamic_axes={'input': {0: 'seq', 1: 'batch'}},
                      test_with_inputs=[input2])
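
    # A hedged sketch of what `dynamic_axes` (used above) does to the exported
    # graph: the named axis becomes a symbolic `dim_param` instead of a fixed
    # `dim_value`. Helper and tensor names here are illustrative only.
    @staticmethod
    def _example_dynamic_axes_dim_param():
        f = io.BytesIO()
        torch.onnx.export(torch.nn.Linear(4, 2), torch.randn(3, 4), f,
                          input_names=['inp'], dynamic_axes={'inp': {0: 'batch'}})
        dim0 = onnx.load_model_from_string(f.getvalue()).graph.input[0].type.tensor_type.shape.dim[0]
        return dim0.dim_param  # expected to be 'batch' rather than a fixed 3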

    @disableScriptTest()
    def test_lstm_constant_folding(self):
        class LstmNet(torch.nn.Module):
            def __init__(self, input_size, hidden_size, num_layers, bidirectional):
                super(LstmNet, self).__init__()
                self.lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional)

            def forward(self, input, initial_state):
                return self.lstm(input, initial_state)

        def get_LstmNet_model_and_inputs(input_size, hidden_size, num_layers, batch_size,
                                         seq_len, bidirectional):
            num_directions = 2 if bidirectional else 1
            model = LstmNet(input_size, hidden_size, num_layers, bidirectional)
            input = torch.randn(seq_len, batch_size, input_size)
            h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
            c0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
            return model, (input, (h0, c0))

        batch_size1 = 3
        model1, input1 = get_LstmNet_model_and_inputs(7, 3, 2, batch_size1, 5, True)
        self.run_test(model1, input1, do_constant_folding=True)

        batch_size2 = 4
        model2, input2 = get_LstmNet_model_and_inputs(5, 4, 3, batch_size2, 7, False)
        self.run_test(model2, input2, do_constant_folding=True)

    @skipIfUnsupportedMinOpsetVersion(9)
    @disableScriptTest()
    def test_lstm_no_bias(self):
        class LstmNet(torch.nn.Module):
            def __init__(self, num_layers, bidirectional):
                super(LstmNet, self).__init__()
                self.lstm = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, num_layers, bias=False, bidirectional=bidirectional)

            def forward(self, input, initial_state):
                return self.lstm(input, initial_state)

        def get_LstmNet_model_and_inputs(num_layers, bidirectional):
            input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
            num_directions = 2 if bidirectional else 1
            model = LstmNet(num_layers, bidirectional)
            h0 = torch.randn(num_layers * num_directions, BATCH_SIZE, RNN_HIDDEN_SIZE)
            c0 = torch.randn(num_layers * num_directions, BATCH_SIZE, RNN_HIDDEN_SIZE)
            return model, (input, (h0, c0))

        num_layers = [1, 1, 2, 3]
        bidirectional = [True, False, True, False]
        models_and_inputs = [get_LstmNet_model_and_inputs(n, b) for n, b in zip(num_layers, bidirectional)]
        for model, input in models_and_inputs:
            self.run_test(model, input)

    @disableScriptTest()
    def test_rnn_no_bias(self):
        def make_model(layers, packed_sequence):
            batch_first = packed_sequence == 2
            model = torch.nn.RNN(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers, bidirectional=False,
                                 batch_first=batch_first, bias=False)

            if packed_sequence == 1:
                model = RnnModelWithPackedSequence(model, False)
            if packed_sequence == 2:
                model = RnnModelWithPackedSequence(model, True)
            return model

        def make_input(batch_size, layers, packed_sequence):
            batch_first = packed_sequence == 2
            seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
            seq_lengths = list(reversed(sorted(map(int, seq_lengths))))
            inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
            inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
            inputs = [inputs]

            h0 = torch.randn(layers, batch_size, RNN_HIDDEN_SIZE)
            inputs.append(h0)
            if packed_sequence != 0:
                inputs.append(torch.IntTensor(seq_lengths))
            if len(inputs) == 1:
                input = inputs[0]
            else:
                input = tuple(inputs)
            return input

        layers = [1, 3, 1, 3, 1, 3]
        packed_sequence = [0, 0, 1, 1, 2, 2]
        models = [make_model(l, p) for l, p in zip(layers, packed_sequence)]
        inputs = [make_input(RNN_BATCH_SIZE, l, p) for l, p in zip(layers, packed_sequence)]

        for model, input in zip(models, inputs):
            self.run_test(model, input, batch_size=RNN_BATCH_SIZE)

    def test_gru_no_bias(self):
        class GruNet(torch.nn.Module):
            def __init__(self, input_size, hidden_size, num_layers, bidirectional):
                super(GruNet, self).__init__()
                self.mygru = torch.nn.GRU(input_size, hidden_size, num_layers, bidirectional=bidirectional, bias=False)

            def forward(self, input, initial_state):
                out = self.mygru(input, initial_state)
                return out

        def get_GruNet_model_and_inputs(input_size, hidden_size, num_layers, batch_size,
                                        seq_len, bidirectional):
            num_directions = 2 if bidirectional else 1
            model = GruNet(input_size, hidden_size, num_layers, bidirectional)
            input = torch.randn(seq_len, batch_size, input_size)
            h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
            return model, (input, h0)

        input_size = [7, 5]
        hidden_size = [3, 4]
        num_layers = [2, 3]
        batch_size = [3, 4]
        seq_len = [5, 7]
        bidirectional = [True, False]
        models_and_inputs = [get_GruNet_model_and_inputs(i, h, n, b, s, bi)
                             for i, h, n, b, s, bi in zip(input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional)]
        for model, input in models_and_inputs:
            self.run_test(model, input, do_constant_folding=True)

    def test_gru_constant_folding(self):
        class GruNet(torch.nn.Module):
            def __init__(self, input_size, hidden_size, num_layers, bidirectional):
                super(GruNet, self).__init__()
                self.mygru = torch.nn.GRU(input_size, hidden_size, num_layers, bidirectional=bidirectional)

            def forward(self, input, initial_state):
                out = self.mygru(input, initial_state)
                return out

        def get_GruNet_model_and_inputs(input_size, hidden_size, num_layers, batch_size,
                                        seq_len, bidirectional):
            num_directions = 2 if bidirectional else 1
            model = GruNet(input_size, hidden_size, num_layers, bidirectional)
            input = torch.randn(seq_len, batch_size, input_size)
            h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
            return model, (input, h0)

        batch_size1 = 3
        model1, input1 = get_GruNet_model_and_inputs(7, 3, 2, batch_size1, 5, True)
        self.run_test(model1, input1, do_constant_folding=True)

        batch_size2 = 4
        model2, input2 = get_GruNet_model_and_inputs(5, 4, 3, batch_size2, 7, False)
        self.run_test(model2, input2, do_constant_folding=True)

    @skipIfUnsupportedMinOpsetVersion(8)
    def test_max_tensors(self):
        class MaxModel(torch.nn.Module):
            def forward(self, input, other):
                return torch.max(input, other)

        model = MaxModel()
        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 1, requires_grad=True)
        self.run_test(model, (x, y))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_arange_end(self):
        class ArangeScript(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

        x = torch.randn(3, 4, requires_grad=True)
        outputs = ArangeScript()(x)
        self.run_test(ArangeScript(), x)

        class ArangeModel(torch.nn.Module):
            def forward(self, a):
                return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

        self.run_test(ArangeModel(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_arange_end_notype(self):
        class ArangeScript(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return torch.arange(a.size(0))

        x = torch.randn(3, 4, requires_grad=True)
        outputs = ArangeScript()(x)
        self.run_test(ArangeScript(), x)

        class ArangeModel(torch.nn.Module):
            def forward(self, a):
                return torch.arange(a.size(0))

        self.run_test(ArangeModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_arange_start_end(self):
        class ArangeScript(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a

        x = torch.randn(3, 4, requires_grad=True)
        self.run_test(ArangeScript(), x)

        class ArangeModel(torch.nn.Module):
            def forward(self, a):
                return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a

        self.run_test(ArangeModel(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_arange_start_end_notype(self):
        class ArangeScript(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return torch.arange(2.7, a.size(0) + 2).view(-1, 1) + a

        x = torch.randn(3, 4, requires_grad=True)
        self.run_test(ArangeScript(), x)

        class ArangeModel(torch.nn.Module):
            def forward(self, a):
                return torch.arange(2.7, a.size(0) + 2).view(-1, 1) + a

        self.run_test(ArangeModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_arange_start_end_step(self):
        class ArangeScript(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a

        x = torch.randn(3, 4, requires_grad=True)
        self.run_test(ArangeScript(), x)

        class ArangeModel(torch.nn.Module):
            def forward(self, a):
                return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a

        self.run_test(ArangeModel(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_arange_start_end_step_notype(self):
        class ArangeScript(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return torch.arange(2.7, a.size(0) * a.size(1) + 2, a.size(1)).view(-1, 1) + a

        x = torch.randn(3, 4, requires_grad=True)
        self.run_test(ArangeScript(), x)

        class ArangeModel(torch.nn.Module):
            def forward(self, a):
                return torch.arange(2.7, a.size(0) * a.size(1) + 2, a.size(1)).view(-1, 1) + a

        self.run_test(ArangeModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test__dim_arange(self):
        class DimArange(torch.nn.Module):
            def forward(self, input):
                return torch._dim_arange(input, 1)

        x = torch.ones(5, 6)
        self.run_test(DimArange(), x)

    def _test_compare_ops(self, model, num_inputs):
        x_float = torch.randn(1, 2, 3, 4, requires_grad=True)
        x_int = torch.randint(10, (3, 4), dtype=torch.int32)
        if num_inputs > 1:
            y_float = torch.randn(1, 2, 3, 4, requires_grad=True)
            y_int = torch.randint(10, (3, 4), dtype=torch.int32)
            self.run_test(model, (x_float, y_float))
            self.run_test(model, (x_float, y_int))
            self.run_test(model, (x_int, y_float))
            self.run_test(model, (x_int, y_int))
        else:
            self.run_test(model, x_float)
            self.run_test(model, x_int)

    def test_gt(self):
        class GreaterModel(torch.nn.Module):
            def forward(self, input, other):
                return input > other
        self._test_compare_ops(GreaterModel(), 2)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_ge(self):
        class GreaterOrEqualModel(torch.nn.Module):
            def forward(self, input, other):
                return input >= other
        self._test_compare_ops(GreaterOrEqualModel(), 2)

    def test_gt_scalar(self):
        class GreaterModel(torch.nn.Module):
            def forward(self, input):
                return input > 1
        self._test_compare_ops(GreaterModel(), 1)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_ge_scalar(self):
        class GreaterOrEqualModel(torch.nn.Module):
            def forward(self, input):
                return input >= 1
        self._test_compare_ops(GreaterOrEqualModel(), 1)

    def test_lt(self):
        class LessModel(torch.nn.Module):
            def forward(self, input, other):
                return input < other
        self._test_compare_ops(LessModel(), 2)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_le(self):
        class LessOrEqualModel(torch.nn.Module):
            def forward(self, input, other):
                return input <= other
        self._test_compare_ops(LessOrEqualModel(), 2)

    def test_lt_scalar(self):
        class LessModel(torch.nn.Module):
            def forward(self, input):
                return input < 1
        self._test_compare_ops(LessModel(), 1)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_le_scalar(self):
        class LessOrEqualModel(torch.nn.Module):
            def forward(self, input):
                return input <= 1
        self._test_compare_ops(LessOrEqualModel(), 1)

    def test_matmul(self):
        class MatmulModel(torch.nn.Module):
            def forward(self, input, other):
                return torch.matmul(input, other)

        x = torch.randn(3, 4, requires_grad=True)
        y = torch.randn(4, 5, requires_grad=True)
        self.run_test(MatmulModel(), (x, y))

        x = torch.randint(10, (3, 4))
        y = torch.randint(10, (4, 5))
        self.run_test(MatmulModel(), (x, y))

    def test_matmul_batch(self):
        class MatmulModel(torch.nn.Module):
            def forward(self, input, other):
                return torch.matmul(input, other)

        x = torch.randn(2, 3, 4, requires_grad=True)
        y = torch.randn(2, 4, 5, requires_grad=True)
        self.run_test(MatmulModel(), (x, y))

        x = torch.randint(10, (2, 3, 4))
        y = torch.randint(10, (2, 4, 5))
        self.run_test(MatmulModel(), (x, y))

    def _argmin_argmax_model(self, input):
        class ArgminArgmaxModel(torch.nn.Module):
            def forward(self, input):
                return torch.argmin(input), \
                    torch.argmax(input), \
                    torch.argmin(input, keepdim=True), \
                    torch.argmax(input, keepdim=True)

        self.run_test(ArgminArgmaxModel(), input)

    def test_argmin_argmax(self):
        input = torch.randn(7, 3, 5)
        self._argmin_argmax_model(input)

    # Argmin and Argmax with "select_last_index" are not supported before opset 12.
    # "select_last_index" was added in opset 12 to deal with the corner case where
    # the same value appears multiple times in the tensor.
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_argmin_argmax_select_last_index(self):
        input = torch.tensor([[1., 2., 3.],
                              [1., 1., 2.]])
        self._argmin_argmax_model(input)

        input = torch.ones(7, 3, 5)
        self._argmin_argmax_model(input)

    def test_repeat(self):
        class RepeatModel(torch.nn.Module):
            def forward(self, x, y):
                x2 = x.repeat(y.shape[0], 1)
                y1 = y.view(-1, 1)
                return x2 + y1

        x = torch.tensor([1, 2, 3])
        y = torch.tensor([4, 5, 8, 9])
        self.run_test(RepeatModel(), (x, y))

    def test_view(self):
        class ViewModel(torch.nn.Module):
            def forward(self, input):
                return input.view(4, 24)

        x = torch.randint(10, (4, 2, 3, 4), dtype=torch.int32)
        self.run_test(ViewModel(), x)

    def test_view_dynamic(self):
        class ViewModel(torch.nn.Module):
            def forward(self, input, other):
                return input.view(other.shape)

        x = torch.randn(2, 3, 4)
        shape = torch.randn(6, 4)
        self.run_test(ViewModel(), (x, shape))

    def test_view_dynamic_zero_dim(self):
        class ViewModel(torch.nn.Module):
            def forward(self, input):
                input = input.view(-1, 2)
                return input.view(1, -1)

        x = torch.ones(2)
        another_x = torch.empty((0,))
        self.run_test(ViewModel(), x, test_with_inputs=[another_x],
                      input_names=['input_1'], dynamic_axes={'input_1': [0, ]})

    def test_view_as(self):
        class ViewModel(torch.nn.Module):
            def forward(self, input, other):
                return input.view_as(other)

        x = torch.randn(2, 3, 4)
        y = torch.randn(6, 4)
        self.run_test(ViewModel(), (x, y))

    @disableScriptTest()  # ONNX shape inference failure in the if/else block for Gemm
    def test_weight_norm(self):
        model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=1)
        x = torch.randn(3, 4, 5, requires_grad=True)
        self.run_test(model, x)

        model = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, 3))
        x = torch.randn(1, 1, 5, requires_grad=True)
        self.run_test(model, x)

        model = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, 3), dim=-2)
        x = torch.randn(1, 1, 5, requires_grad=True)
        self.run_test(model, x)

        model = torch.nn.utils.weight_norm(torch.nn.Conv1d(3, 6, 3), name='weight')
        x = torch.randn(3, 3, 5, requires_grad=True)
        self.run_test(model, x)

    @disableScriptTest()  # ONNX shape inference failure in the if/else block for Gemm
    def test_weight_norm_nodim(self):
        model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=None)
        x = torch.randn(3, 4, 5, requires_grad=True)
        self.run_test(model, x)

    def test_flatten(self):
        class FlattenModel(torch.nn.Module):
            def forward(self, input):
                return torch.flatten(input)

        x = torch.randint(10, (1, 2, 3, 4))
        self.run_test(FlattenModel(), x)

    def test_flatten2d(self):
        class FlattenModel(torch.nn.Module):
            def forward(self, input):
                return torch.flatten(input, 1)

        x = torch.randint(10, (1, 2, 3, 4))
        self.run_test(FlattenModel(), x)

    def test_flatten2d_neg(self):
        class FlattenModel(torch.nn.Module):
            def forward(self, x):
                return torch.flatten(x, 1, -1), torch.flatten(x, 0, -2), torch.flatten(x, 1, -2)

        x = torch.randint(10, (1, 2, 3, 4))
        self.run_test(FlattenModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_flatten_dynamic_axes(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return torch.flatten(x, start_dim=2, end_dim=3)

        batch_size = 3
        x = torch.randn(batch_size, 5, 4, 5)
        y = torch.randn(5, 5, 4, 5)
        model = MyModule()
        self.run_test(model, x, test_with_inputs=[y],
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={'input': {0: 'batch_size'},
                                    'output': {0: 'batch_size'}})
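
    # A hedged end-to-end illustration of why the dynamic batch axis above
    # matters: the same exported graph accepts a different batch size at run
    # time. Helper and tensor names are ours, not part of the test suite.
    @staticmethod
    def _example_flatten_dynamic_batch():
        class Flat(torch.nn.Module):
            def forward(self, x):
                return torch.flatten(x, start_dim=2, end_dim=3)

        f = io.BytesIO()
        torch.onnx.export(Flat(), torch.randn(3, 5, 4, 5), f, input_names=['input'],
                          dynamic_axes={'input': {0: 'batch_size'}})
        sess = onnxruntime.InferenceSession(f.getvalue())
        # batch size 7 at run time, although the model was traced with 3
        return sess.run(None, {'input': np.random.randn(7, 5, 4, 5).astype(np.float32)})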

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_getitem(self):
        class GetItemModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, y, z, ind):
                # this will create prim::ListConstruct(x, y, z) + aten::__getitem__
                arr = [x, y, z]
                return arr[ind]

        x = torch.randn(3, 4, 5)
        y = torch.randn(1, 4, 5)
        z = torch.randn(2, 4, 5)
        ind = torch.tensor(1, dtype=torch.long)
        self.run_test(GetItemModel(), (x, y, z, ind))

        ind = torch.tensor(-2, dtype=torch.long)
        self.run_test(GetItemModel(), (x, y, z, ind))

    def test_unbind(self):
        class UnbindModel(torch.nn.Module):
            def forward(self, input):
                _, out, _ = input.unbind()
                return out

        x = torch.randn(3, 4, 5)
        self.run_test(UnbindModel(), x)

        class UnbindModel2(torch.nn.Module):
            def forward(self, input):
                _, out, _, _ = input.unbind(1)
                return out

        x = torch.randn(3, 4, 5)
        self.run_test(UnbindModel2(), x)

        class UnbindModel3(torch.nn.Module):
            def forward(self, input):
                _, out, _, _ = input.unbind(-2)
                return out

        x = torch.randn(3, 4, 5)
        self.run_test(UnbindModel3(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_len(self):
        class LenModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return len(input.unbind()) + input

        x = torch.randn(4, 5)
        self.run_test(LenModel(), x, input_names=['input'], dynamic_axes={'input': {0: 'seq'}},
                      test_with_inputs=(torch.randn(5, 5),))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_len_list(self):
        class LenListModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return torch.ones(len(input.shape))

        x = torch.randn(4, 5)
        self.run_test(LenListModel(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_unbind_dynamic(self):
        class UnbindModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return input.unbind()[1]

        x = torch.randn(3, 4, 5)
        self.run_test(UnbindModel(), x)

        class UnbindModel2(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return input.unbind(-1)[1]

        x = torch.randn(3, 4, 5)
        self.run_test(UnbindModel2(), x)

    def test_split(self):
        class SplitModel(torch.nn.Module):
            def forward(self, input):
                out1, out2, out3 = input.split([2, 1, 2])
                return out1, out2, out3

        x = torch.randn(5, 4, 3)
        self.run_test(SplitModel(), x)

        class SplitModel2(torch.nn.Module):
            def forward(self, input):
                out1, out2, out3 = input.split([2, 1, 1], -2)
                return out1, out2, out3

        x = torch.randn(5, 4, 3)
        self.run_test(SplitModel2(), x)

        class SplitModel3(torch.nn.Module):
            def forward(self, input):
                out1, out2, out3 = input.split([2, 1, 2])
                return out3, out1

        x = torch.randn(5, 4, 3)
        self.run_test(torch.jit.script(SplitModel3()), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    @disableScriptTest()
    def test_split_size_as_list(self):
        class SplitModel(torch.nn.Module):
            def forward(self, input, split_sizes: List[int]):
                out = []
                split_list: List[torch.Tensor] = input.split(split_sizes)

                for ob in split_list:
                    out.append(ob)
                return torch.cat(out, dim=0)

        x = torch.randn(6, 4, 3)
        split_sizes = [torch.tensor(2), torch.tensor(4)]
        self.run_test(SplitModel(), (x, split_sizes))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_split_size_with_slice(self):
        class SplitModule(torch.nn.Module):
            def forward(self, x, y, t):
                splits = (x.size(1), y.size(1))
                out, out2 = torch.split(t, splits, dim=1)
                return out, out2

        x = torch.randn(2, 3)
        y = torch.randn(2, 4)
        t = torch.randn(2, 7)
        self.run_test(SplitModule(), (x, y, t))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_split_dynamic(self):
        class SplitModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return input.split(2)[1]

        x = torch.randn(5, 4, 3)
        self.run_test(SplitModel(), x)

        class SplitModel2(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return input.split(2, -3)[1]

        x = torch.randn(5, 4, 3)
        self.run_test(SplitModel2(), x)

    def test_concat(self):
        class ConcatModel(torch.nn.Module):
            def forward(self, x, y, z):
                return torch.cat((x, y, z))

        x = torch.randn(3, 4, 5)
        y = torch.randn(1, 4, 5)
        z = torch.randn(2, 4, 5)
        self.run_test(ConcatModel(), (x, y, z))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_concat_dynamic(self):
        class ConcatDynamicModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.cat(x.unbind())

        x = torch.randn(4, 5, 6)
        self.run_test(ConcatDynamicModel(), x)

    def test_stack(self):
        class StackModel(torch.nn.Module):
            def forward(self, x, y, z):
                return torch.stack((x, y, z), 1)

        x = torch.randn(3, 4, 5)
        y = torch.randn(3, 4, 5)
        z = torch.randn(3, 4, 5)
        self.run_test(StackModel(), (x, y, z))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_stack_dynamic(self):
        class StackDynamicModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.stack(x.unbind(), 1)

        x = torch.randn(4, 5, 6)
        self.run_test(StackDynamicModel(), x)

    def test_loop_dynamic(self):
        class LoopModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                for i in range(x.size(2)):
                    x = x + i
                return x

        model = LoopModel()
        inputs = torch.zeros(1, 2, 3, dtype=torch.long)
        self.run_test(model, inputs)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_loop_nested(self):
        class NestedLoopsModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                for i in range(5):
                    a = 0
                    while a < 4:
                        a += 1
                    x = x + a
                return x

        model = NestedLoopsModel()
        inputs = torch.zeros(1, 2, 3, dtype=torch.long)
        self.run_test(model, inputs)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_loop_with_list(self):
        class ListLoopModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                res = []
                res1 = []
                arr = x.split([3, 4, 1, 1, 2, 3, 2], 0)
                res2 = torch.zeros(3, 4, dtype=torch.long)
                res3 = []
                res4 = []
                for i in range(len(arr)):
                    res = res.append(arr[i].sum(0, False))
                    res1 = res1.append(arr[-1 - i].sum(0, False))
                    res2 += 1
                    res3 = res3 + [arr[i].sum(0, False)]
                    res4 += [arr[-1 - i].sum(0, False)]
                return torch.stack(res), torch.stack(res1), res2, torch.stack(res3), torch.stack(res4)

        model = ListLoopModel()
        inputs = torch.randn(16)
        self.run_test(model, inputs)

    @skipIfONNXShapeInference(False)
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_loop_transpose(self):
        class LoopModel(torch.nn.Module):
            def forward(self, x):
                res = torch.zeros_like(x[0])
                for i in range(x.size(0)):
                    res += x[0].transpose(0, 1)
                return res

        model = torch.jit.script(LoopModel())
        x = torch.randn(5, 3, 3)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_list(self):
        class ListModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                tensors = x.unbind()
                res = []
                res.append(tensors[0])
                res.append(tensors[1])
                res.pop(1)

                res.insert(0, tensors[1])
                res.append(tensors[2])
                res += [tensors[3], tensors[4]]
                res = res + [tensors[5]]
                return torch.ones(len(res))

        model = ListModel()
        inputs = torch.randn(16, 1)
        self.run_test(model, inputs)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_tensor_factories(self):
        class TensorFactory(torch.nn.Module):
            def forward(self, x):
                return torch.zeros(x.size()) + torch.ones(x.size())

        x = torch.randn(2, 3, 4)
        self.run_test(TensorFactory(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_tensor_factories_script(self):
        class TensorFactory(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.zeros(x.shape, dtype=torch.float) + torch.ones(x.shape, dtype=torch.float)

        x = torch.randn(2, 3, 4)
        self.run_test(TensorFactory(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_tensor_like_factories_script(self):
        class TensorFactory(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                zeros = torch.zeros_like(x, dtype=torch.float, layout=torch.strided, device=torch.device('cpu'))
                ones = torch.ones_like(x, dtype=torch.float, layout=torch.strided, device=torch.device('cpu'))
                return zeros + ones

        x = torch.randn(2, 3, 4)
        self.run_test(TensorFactory(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_eye(self):
        class TensorFactory(torch.nn.Module):
            def forward(self, x):
                return torch.eye(x.size()[1], 3), torch.eye(4, 4, dtype=torch.long), torch.eye(x.size()[1], 2, dtype=torch.long)

        x = torch.randn(2, 3, 4)
        another_x = torch.randn(5, 6, 7)
        self.run_test(TensorFactory(), x, test_with_inputs=[another_x],
                      input_names=['input_1'], dynamic_axes={'input_1': [0, 1, 2]})

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_inplace_zero(self):
        class Zero_(torch.nn.Module):
            def forward(self, x):
                return x.zero_(), x

        x = torch.randn(2, 3, 4)
        self.run_test(Zero_(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_new_zeros(self):
        class Zero_(torch.nn.Module):
            def forward(self, x):
                return x.new_zeros(x.shape[1:2]), x.new_zeros(x.shape[2:], dtype=torch.long)

        x = torch.randn(2, 3, 4)
        self.run_test(Zero_(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_list_pass(self):
        class Slice(torch.nn.Module):
            def forward(self, x, y):
                return x.new_zeros(x.shape[2:] + y.shape[1:])

        x = torch.randn(2, 3, 4, 5)
        y = torch.randn(1, 2, 3, 4)
        self.run_test(Slice(), (x, y))

        class Size(torch.nn.Module):
            def forward(self, x, y):
                return x.new_zeros(x.shape + y.shape)

        x = torch.randn(2, 3, 4)
        y = torch.randn(1, 2, 3)
        self.run_test(Size(), (x, y))

        class Array(torch.nn.Module):
            def forward(self, x, y):
                arr1 = [x.shape[0], x.shape[1], 2]
                arr2 = [y.shape[0], y.shape[1]]
                return x.new_zeros(arr1 + arr2)

        x = torch.randn(2, 3, 4)
        y = torch.randn(1, 2, 3)
        self.run_test(Array(), (x, y))

        class List(torch.nn.Module):
            def forward(self, x, y):
                l1 = list(x.shape)
                l2 = list(y.shape)
                return x.new_zeros(l1 + l2)

        x = torch.randn(2, 3, 4)
        y = torch.randn(1, 2, 3)
        self.run_test(List(), (x, y))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_new_empty(self):
        class Empty(torch.nn.Module):
            def forward(self, x):
                return x.new_empty(x.shape[0]).fill_(0), x.new_empty(x.shape[0], dtype=torch.long) * 0

        x = torch.randn(2, 3, 4)
        self.run_test(Empty(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_new_full(self):
        class Full(torch.nn.Module):
            def forward(self, x):
                return x.new_full(x.shape[1:2], 5), x.new_full(x.shape[0:1], 1.3, dtype=torch.long)

        x = torch.randn(2, 3, 4)
        self.run_test(Full(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_inplace_list(self):
        class Arithmetic(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, y):
                return torch.cat([x.add_(3), y.fill_(0)])

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        self.run_test(Arithmetic(), (x, y))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_inplace_fill(self):
        class Fill_(torch.nn.Module):
            def forward(self, x):
                return x.fill_(3), x

        x = torch.randn(2, 3, 4)
        self.run_test(Fill_(), x)

    def test_inplace_arithmetic(self):
        class Arithmetic(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, y):
                x.add_(3)
                y.mul_(x)
                return x, y

        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 3, 4)
        self.run_test(Arithmetic(), (x, y))

    @disableScriptTest()
    def test_sort(self):
        class SortModel(torch.nn.Module):
            def forward(self, x):
                out = []
                for i in range(-2, 2):
                    out.append(torch.sort(x, dim=i, descending=True))
                return out

        x = torch.randn(3, 4)
        self.run_test(SortModel(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    @disableScriptTest()
    def test_sort_ascending(self):
        class SortModel(torch.nn.Module):
            def forward(self, x):
                out = []
                for i in range(-2, 2):
                    out.append(torch.sort(x, dim=i, descending=False))
                return out

        x = torch.randn(3, 4)
        self.run_test(SortModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_masked_fill(self):
        class MaskedFillModel(torch.nn.Module):
            def forward(self, x):
                mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.uint8)
                return x.masked_fill(mask, 2)

        x = torch.zeros(4, 2, 3, requires_grad=True)
        self.run_test(MaskedFillModel(), x)

        class MaskedFillModel2(torch.nn.Module):
            def forward(self, x):
                return x.masked_fill(x > 3, -1)

        x = torch.arange(16).view(2, 2, 4).to(torch.float32)
        self.run_test(MaskedFillModel2(), x)
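
    # masked_fill maps naturally onto a select/Where pattern; a hedged
    # one-line reference of the same computation (illustrative, not the
    # exporter's actual lowering):
    @staticmethod
    def _masked_fill_reference(x, mask, value):
        return torch.where(mask.to(torch.bool), torch.full_like(x, value), x)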

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_masked_fill_inplace(self):
        class MaskedFillModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.uint8)
                x.masked_fill_(mask, 2)
                return x

        x = torch.zeros(4, 2, 3, requires_grad=True)
        self.run_test(MaskedFillModel(), x)

        class MaskedFillModel2(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                x.masked_fill_(x > 3, -1)
                return x

        x = torch.arange(16).view(2, 2, 4).to(torch.float32)
        self.run_test(MaskedFillModel2(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_masked_scatter(self):
        class MaskedScatterModel(torch.nn.Module):
            def forward(self, x):
                return torch.masked_scatter(x, x.ge(0.5), torch.ones(100, 100) * 5)

        x = torch.randn(3, 4, 5, requires_grad=True)
        self.run_test(MaskedScatterModel(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_masked_select(self):
        class MaskedSelectModel(torch.nn.Module):
            def forward(self, x):
                return torch.masked_select(x, x.ge(0.5))

        x = torch.randn(3, 4, 5, requires_grad=True)
        self.run_test(MaskedSelectModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_pixel_shuffle(self):
        class PixelShuffle(torch.nn.Module):
            def forward(self, x):
                return torch.pixel_shuffle(x, upscale_factor=2)

        x = torch.randn(2, 16, 4, 3, requires_grad=True)
        self.run_test(PixelShuffle(), x)
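
    # Hedged reference for pixel_shuffle as the reshape/transpose pattern it
    # amounts to (the ONNX-level counterpart is DepthToSpace). Illustrative
    # only; helper name and structure are ours.
    @staticmethod
    def _pixel_shuffle_reference(x, r):
        n, c, h, w = x.shape
        # split channels into (c_out, r, r), then interleave r into H and W
        x = x.reshape(n, c // (r * r), r, r, h, w)
        return x.permute(0, 1, 4, 2, 5, 3).reshape(n, c // (r * r), h * r, w * r)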

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_scalar_type(self):
        class ArithmeticModel(torch.nn.Module):
            def forward(self, x):
                return x.size(0) * 2 * x

        x = torch.ones(2, 3, dtype=torch.float32)
        self.run_test(ArithmeticModel(), x)

        class ReciprocalModel(torch.nn.Module):
            def forward(self, x):
                return torch.reciprocal(x)

        x = torch.tensor([2.0, 4.0], dtype=torch.double)
        self.run_test(ReciprocalModel(), x)

        class ComparisonModel(torch.nn.Module):
            def forward(self, x, y):
                a = torch.tensor([12.0])
                return x.lt(1.5) & y.le(2) & x.le(1), x.gt(y), x.lt(y), a.ge(x.size(0))

        x = torch.ones(2, 3, dtype=torch.int32)
        y = torch.ones(2, 3, dtype=torch.float32)
        self.run_test(ComparisonModel(), (x, y))

        class MatMulModel(torch.nn.Module):
            def forward(self, x):
                return (torch.mm(x, x) + x + torch.mm(x, x) + x)

        x = torch.ones(3, 3)
        self.run_test(MatMulModel(), x)

        class AddMMModel(torch.nn.Module):
            def forward(self, x):
                return torch.mm(x, x) + x

        x = torch.ones(3, 3)
        self.run_test(AddMMModel(), x)

        class FullModel(torch.nn.Module):
            # an Add op is used when exporting torch.full here
            def forward(self, x):
                return torch.full((3, 4), x)
        x = torch.tensor(12.)
        self.run_test(FullModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    @disableScriptTest()  # dtype mismatch
    def test_full_like(self):
        class FullLikeModel(torch.nn.Module):
            def forward(self, x):
                return torch.full_like(x, 4)

        x = torch.tensor(12)
        self.run_test(FullLikeModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    @disableScriptTest()  # dtype mismatch
    def test_full_like_value(self):
        class FullLikeModel(torch.nn.Module):
            def forward(self, x, y):
                out = y + 2
                return torch.full_like(x, out)

        x = torch.tensor(12)
        y = torch.tensor(2)
        self.run_test(FullLikeModel(), (x, y))

    def test_l1_norm(self):
        class NormModel(torch.nn.Module):
            def forward(self, x):
                return torch.norm(x, p=1, dim=-1, keepdim=False)

        x = torch.randn(4, 2, 3, requires_grad=True)
        self.run_test(NormModel(), x)

    def test_l2_norm(self):
        class NormModel(torch.nn.Module):
            def forward(self, x):
                return torch.norm(x, p=2, dim=-2, keepdim=False)

        x = torch.randn(4, 2, 3, requires_grad=True)
        self.run_test(NormModel(), x)

    def test_frobenius_norm(self):
        class NormModel(torch.nn.Module):
            def forward(self, x):
                return torch.norm(x, p="fro", dim=0, keepdim=False)

        x = torch.randn(4, 2, 3, requires_grad=True)
        self.run_test(NormModel(), x)

    def test_frobenius_norm_keepdim(self):
        class NormModel(torch.nn.Module):
            def forward(self, x):
                return torch.norm(x, p="fro", dim=(0, 1), keepdim=True)

        x = torch.randn(4, 2, 3, requires_grad=True)
        self.run_test(NormModel(), x)
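
    # Hedged reference for the Frobenius norm reductions above: the square
    # root of the sum of squares over the reduced dims. Illustrative only.
    @staticmethod
    def _frobenius_norm_reference(x, dim, keepdim=False):
        return torch.sqrt((x * x).sum(dim=dim, keepdim=keepdim))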

    def test_unfold(self):
        class UnfoldModel(torch.nn.Module):
            def forward(self, x):
                return x.unfold(dimension=2, size=2, step=2)

        x = torch.randn(4, 2, 3, requires_grad=True)
        self.run_test(UnfoldModel(), x)

    @skipIfONNXShapeInference(False)
    def test_unfold_infer_shape(self):
        class UnfoldModule(torch.jit.ScriptModule):
            def __init__(self):
                super(UnfoldModule, self).__init__()
                self.conv = torch.nn.Conv1d(3, 1, 3, stride=2)

            @torch.jit.script_method
            def forward(self, x):
                x = self.conv(x)
                return x.unfold(dimension=2, size=2, step=2)

        x = torch.randn(32, 3, 64)
        self.run_test(UnfoldModule(), x)
def test_remainder(self):
|
|
class RemainderModel(torch.nn.Module):
|
|
def forward(self, input, other):
|
|
return torch.remainder(input, other)
|
|
|
|
x = torch.randn(4, 2, 3)
|
|
y = torch.randn(1, 2, 1)
|
|
self.run_test(RemainderModel(), (x, y))
|
|
|
|
def test_remainder_scalar(self):
|
|
class RemainderModel(torch.nn.Module):
|
|
def forward(self, input):
|
|
return torch.remainder(input, 2.55)
|
|
|
|
x = torch.randint(10, (2, 3))
|
|
self.run_test(RemainderModel(), x)
|
|
|
|
    @skipIfUnsupportedMinOpsetVersion(10)
    def test_fmod(self):
        class FModModel(torch.nn.Module):
            def forward(self, input, other):
                return torch.fmod(input, other)

        x = torch.randn(4, 2, 3)
        y = torch.randn(1, 2, 1)
        self.run_test(FModModel(), (x, y))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_fmod_scalar(self):
        class FModModel(torch.nn.Module):
            def forward(self, input):
                return torch.fmod(input, 2.55)

        x = torch.randint(10, (2, 3))
        self.run_test(FModModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_gelu(self):
        class GeluModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.gelu(x)

        x = torch.randn(2, 4, 5, 6, requires_grad=True)
        self.run_test(GeluModel(), x)

    def test_add_inplace(self):
        class InplaceAddModel(torch.nn.Module):
            def forward(self, x):
                x += 12
                return x

        x = torch.randn(4, 2, 3, requires_grad=True)
        self.run_test(InplaceAddModel(), x)

    def test_rsqrt(self):
        class RsqrtModel(torch.nn.Module):
            def forward(self, x):
                return x.rsqrt()

        x = torch.randn(4, 2, 3, requires_grad=True, dtype=torch.float64)
        self.run_test(RsqrtModel(), x)

    def test_rsqrt_zeros(self):
        class RsqrtModel(torch.nn.Module):
            def forward(self, x):
                return x.rsqrt()

        x = torch.zeros(4, 2, 3, requires_grad=True, dtype=torch.float64)
        self.run_test(RsqrtModel(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_unique(self):
        class UniqueModel(torch.nn.Module):
            def forward(self, x):
                return torch.unique(x, sorted=True, return_inverse=False, return_counts=True)

        x = torch.tensor([1, 3, 2, 3], dtype=torch.long)
        self.run_test(UniqueModel(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_unique_along_dim(self):
        class UniqueModel(torch.nn.Module):
            def forward(self, x):
                return torch.unique(x, dim=0, sorted=True, return_inverse=True, return_counts=False)

        x = torch.tensor([1, 3, 2, 3], dtype=torch.long)
        self.run_test(UniqueModel(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_cumsum(self):
        class CumSum(torch.nn.Module):
            def forward(self, input):
                return torch.cumsum(input, dim=0)

        x = torch.randn(2, 3, 4)
        model = CumSum()
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_cumsum_with_cast(self):
        class CumSum(torch.nn.Module):
            def forward(self, input):
                return torch.cumsum(input, dim=0, dtype=torch.float32)

        model = CumSum()
        x = torch.tensor([2, 3, 4], dtype=torch.int32)
        self.run_test(model, x)
        x = torch.tensor([False, True, True])
        self.run_test(model, x)

    @disableScriptTest()  # error in propagate as assign input shape
    @skipIfUnsupportedMinOpsetVersion(10)
    @skipIfUnsupportedOpsetVersion([12])  # Due to ONNX Loop shape inference issue
    def test_embedding_bag(self):
        model = torch.nn.EmbeddingBag(10, 5, mode='sum', scale_grad_by_freq=True)
        input = torch.randint(10, (7,))
        offset = torch.tensor([0, 2, 5, 6])
        self.run_test(model, (input, offset))

        model = torch.nn.EmbeddingBag(10, 5, mode='sum', include_last_offset=True)
        input = torch.randint(10, (7,))
        offset = torch.tensor([0, 2, 5, 6])
        self.run_test(model, (input, offset))

        model = torch.nn.EmbeddingBag(10, 5, mode='max')
        input = torch.randint(10, (7, 5))
        self.run_test(model, input)

    @disableScriptTest()  # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast
    @skipIfUnsupportedMinOpsetVersion(10)
    @skipIfUnsupportedOpsetVersion([12])  # Due to ONNX Loop shape inference issue
    def test_embedding_bag_1d_per_sample_weights(self):
        class EmbeddingModel(torch.nn.Module):
            def forward(self, embedding_matrix, input, offset, weights):
                return torch.nn.functional.embedding_bag(input, embedding_matrix, offsets=offset,
                                                         mode='sum', per_sample_weights=weights)

        model = EmbeddingModel()
        x = torch.randint(7, (6,))
        w = torch.randn(6)
        offset = torch.tensor([0, 2, 5])
        embedding_matrix = torch.rand(10, 15)
        self.run_test(model, (embedding_matrix, x, offset, w))

    @disableScriptTest()  # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast
    @skipIfUnsupportedMinOpsetVersion(10)
    @skipIfUnsupportedOpsetVersion([12])  # Due to ONNX Loop shape inference issue
    def test_embedding_bag_2d_per_sample_weights(self):
        class EmbeddingModel(torch.nn.Module):
            def forward(self, embedding_matrix, input, weights):
                return torch.nn.functional.embedding_bag(input, embedding_matrix,
                                                         mode='sum', per_sample_weights=weights)

        embedding_matrix = torch.rand(10, 15)
        model = EmbeddingModel()
        x = torch.randint(7, (2, 3))
        w = torch.randn(2, 3)
        self.run_test(model, (embedding_matrix, x, w))

    @disableScriptTest()  # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast
    @skipIfUnsupportedMinOpsetVersion(11)
    @unittest.skip("Due to ONNX Loop shape inference issue.")
    def test_embedding_bag_dynamic_input(self):
        class EmbeddingModel1D(torch.nn.Module):
            def forward(self, embedding_matrix, input, weights, offsets):
                return torch.nn.functional.embedding_bag(input, embedding_matrix, offsets=offsets,
                                                         mode='sum', per_sample_weights=weights)

        model = EmbeddingModel1D()
        x = torch.randint(7, (6,))
        w = torch.randn(6)
        offsets = torch.tensor([0, 2, 5], dtype=torch.long)
        embedding_matrix = torch.rand(10, 15)
        x2 = torch.randint(7, (2,))
        w2 = torch.randn(2)
        embedding_matrix2 = torch.rand(12, 25)
        offsets2 = torch.tensor([0], dtype=torch.long)
        # input_names listed in the same order as the forward() arguments
        self.run_test(model, (embedding_matrix, x, w, offsets),
                      test_with_inputs=[(embedding_matrix2, x2, w2, offsets2)],
                      input_names=['embedding_matrix', 'x', 'w', 'offsets'],
                      dynamic_axes={'embedding_matrix': [0, 1], 'x': [0], 'offsets': [0], 'w': [0]})

        class EmbeddingModel2D(torch.nn.Module):
            def forward(self, embedding_matrix, input, weights):
                return torch.nn.functional.embedding_bag(input, embedding_matrix,
                                                         mode='sum', per_sample_weights=weights)

        model = EmbeddingModel2D()
        x = torch.randint(7, (2, 3))
        w = torch.randn(2, 3)
        embedding_matrix = torch.rand(10, 15)
        x2 = torch.randint(7, (3, 5))
        w2 = torch.randn(3, 5)
        embedding_matrix2 = torch.rand(12, 25)
        self.run_test(model, (embedding_matrix, x, w),
                      test_with_inputs=[(embedding_matrix2, x2, w2)],
                      input_names=['embedding_matrix', 'x', 'w'],
                      dynamic_axes={'embedding_matrix': [0, 1], 'x': [0, 1], 'w': [0, 1]})

    @skipIfUnsupportedMinOpsetVersion(8)
    def test_meshgrid(self):
        class Meshgrid(torch.nn.Module):
            def forward(self, x, y, z):
                output1, output2, output3 = torch.meshgrid(x, y, z)
                return output1, output2, output3

        x = torch.randn(3, requires_grad=True)
        y = torch.zeros(4, requires_grad=True)
        z = torch.randn(5, requires_grad=True)
        self.run_test(Meshgrid(), (x, y, z))

    @skipIfUnsupportedMinOpsetVersion(8)
    def test_meshgrid_scalar(self):
        class Meshgrid(torch.nn.Module):
            def forward(self, x, y, z):
                output1, output2, output3 = torch.meshgrid(x, y, z)
                return output1, output2, output3

        x = torch.ones(3, requires_grad=True)
        y = torch.zeros(4, requires_grad=True)
        z = torch.tensor(2.0)
        self.run_test(Meshgrid(), (x, y, z))

    def test_baddbmm(self):
        class MyModule(torch.nn.Module):
            def forward(self, input, batch1, batch2):
                return torch.baddbmm(input, batch1, batch2, alpha=torch.tensor(5), beta=3.5)

        x = torch.randn(10, 3, 5)
        batch1 = torch.randn(10, 3, 4)
        batch2 = torch.randn(10, 4, 5)
        model = MyModule()
        self.run_test(model, (x, batch1, batch2))

    def test_baddbmm_dynamic(self):
        class MyModule(torch.nn.Module):
            def forward(self, input, batch1, batch2, alpha, beta):
                return torch.baddbmm(input, batch1, batch2, alpha=alpha, beta=beta)

        x = torch.randn(10, 3, 5)
        batch1 = torch.randn(10, 3, 4)
        batch2 = torch.randn(10, 4, 5)
        alpha = torch.tensor(5)
        beta = torch.tensor(3.5)
        model = MyModule()
        self.run_test(model, (x, batch1, batch2, alpha, beta))

    def test_numel(self):
        class MyModule(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return input.numel() * input

        x = torch.randn(2, 3, 5)
        model = MyModule()
        self.run_test(model, (x,))

    def test_numel_empty(self):
        class MyModule(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return input.numel() * input

        x = torch.randn(0)
        model = MyModule()
        self.run_test(model, (x,))

    def test_cast_to(self):
        class MyModule(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input, other):
                return input.to(other) + other

        x = torch.randn(2, 3, 4)
        y = torch.tensor([1], dtype=torch.int64)
        model = MyModule()
        self.run_test(model, (x, y))

    def test_cast_to_bool(self):
        class MyModule(torch.nn.Module):
            def forward(self, input, other):
                return torch.cat((input.to(other), other), 0)

        x = torch.randn(2, 3, 4)
        y = torch.zeros([2, 3, 4], dtype=torch.bool)
        model = MyModule()
        self.run_test(model, (x, y))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_ones_bool(self):
        class MyModule(torch.nn.Module):
            def forward(self, input):
                true = torch.ones(input.shape, dtype=torch.bool)
                return input.to(true) & true

        x = torch.randn(2, 3, 4)
        model = MyModule()
        self.run_test(model, x)

    def test_log(self):
        class Log(torch.nn.Module):
            def forward(self, input):
                return torch.log(input)

        x = torch.rand(2, 3, 4)
        model = Log()
        self.run_test(model, x)

    def test_log1p(self):
        class Log1p(torch.nn.Module):
            def forward(self, input):
                return torch.log1p(input)

        x = torch.rand(2, 3, 4)
        model = Log1p()
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_round(self):
        class Round(torch.nn.Module):
            def forward(self, x):
                return torch.round(x)

        x = torch.tensor([0.9920, -1.0362, -1.5000, 3.5000], requires_grad=True)
        self.run_test(Round(), x)

    def test_constant_pad(self):
        model = torch.nn.ConstantPad1d(2, 3.5)
        x = torch.randn(2, 4, 4)
        self.run_test(model, x)

        model = torch.nn.ConstantPad2d((3, 0, 2, 1), 3.5)
        x = torch.randn(2, 2, 4, 4)
        self.run_test(model, x)

    # Dynamic padding is added in opset 11
    @skipIfUnsupportedMinOpsetVersion(11)
    @disableScriptTest()  # Functional module not scriptable
    def test_pad_types(self):
        # Test for different pad integer types
        class Pad(torch.nn.Module):
            def forward(self, x, pad):
                return torch.nn.functional.pad(x, pad)

        x = torch.randn(2, 2, 4, 4)
        y = pad = (torch.tensor(2, dtype=torch.int32), torch.tensor(4, dtype=torch.int32))
        self.run_test(Pad(), (x, y))

        y = pad = (torch.tensor(2, dtype=torch.int64), torch.tensor(4, dtype=torch.int64))
        self.run_test(Pad(), (x, y))

    @skipIfUnsupportedMaxOpsetVersion(10)
    def test_unsupported_pad(self):
        class Pad(torch.nn.Module):
            def forward(self, x, pad):
                return torch.nn.functional.pad(x, pad)

        def run():
            x = torch.randn(2, 2, 4, 4)
            y = pad = (torch.tensor(2, dtype=torch.int32), torch.tensor(4, dtype=torch.int32))
            p = Pad()
            f = io.BytesIO()
            torch.onnx._export(p, (x, y), f)

        with self.assertRaises(RuntimeError) as cm:
            run()

        the_exception = cm.exception
        self.assertEqual('Unsupported: ONNX export of Pad in opset 9. The sizes of the padding must be constant. ' +
                         'Please try opset version 11.', the_exception.args[0])

    @disableScriptTest()  # export prim::Uninitialized
    def test_reflection_pad(self):
        model = torch.nn.ReflectionPad1d(2)
        x = torch.randn(2, 4, 4)
        self.run_test(model, x)

        model = torch.nn.ReflectionPad2d((3, 0, 2, 1))
        x = torch.randn(2, 2, 4, 4)
        self.run_test(model, x)

    @disableScriptTest()  # export prim::Uninitialized
    def test_replication_pad(self):
        model = torch.nn.ReplicationPad1d(2)
        x = torch.randn(2, 4, 4)
        self.run_test(model, x)

        model = torch.nn.ReplicationPad2d((3, 0, 2, 1))
        x = torch.randn(2, 2, 4, 4)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(11)
    @disableScriptTest()  # export prim::Uninitialized
    def test_im2col(self):
        class Unfold(torch.nn.Module):
            def forward(self, input):
                return torch.nn.functional.unfold(input, kernel_size=(10, 15), dilation=2, padding=5, stride=3), \
                    torch.nn.functional.unfold(input, kernel_size=(2, 2), dilation=1, padding=0, stride=3), \
                    torch.nn.functional.unfold(input, kernel_size=(1, 1), dilation=5, padding=2, stride=3)

        x = torch.rand(1, 1, 200, 100)
        self.run_test(Unfold(), x)

    @skipIfNoLapack
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_det(self):
        class Det(torch.nn.Module):
            def forward(self, x):
                return torch.det(x)

        x = torch.randn(2, 3, 5, 5)
        self.run_test(Det(), x)

    # This test checks that the output scalar type in the ONNX graph is not null.
    # https://github.com/pytorch/pytorch/issues/28607
    @skipIfUnsupportedMinOpsetVersion(10)
    def test_trace_script(self):
        @torch.jit.script
        def center_slice_helper(input, h_offset):
            return input[:, h_offset:]

        class CenterCrop(torch.nn.Module):
            def forward(self, input):
                return center_slice_helper(input, torch.tensor(input.shape[1] - 1))

        x = torch.randn(3, 4)
        self.run_test(CenterCrop(), x)

    @skipIfNoLapack
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_logdet(self):
        class LogDet(torch.nn.Module):
            def forward(self, x):
                return torch.logdet(x)

        x = torch.randn(2, 3, 5, 5)
        self.run_test(LogDet(), x)

    def test_dim(self):
        class DimModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                out = input * 2
                out *= out.dim()
                return out

        empty_input = torch.randn(0, requires_grad=True)
        multi_dim_input = torch.randn(1, 2, 3, requires_grad=True)
        self.run_test(DimModel(), empty_input)
        self.run_test(DimModel(), multi_dim_input)

    @skipIfUnsupportedMinOpsetVersion(12)
    @disableScriptTest()  # variable number of inputs not scriptable
    def test_einsum(self):
        class EinsumModelBatchDiagonal(torch.nn.Module):
            def forward(self, *tensor_list):
                eqn = '...ii ->...i'
                return torch.einsum(eqn, *tensor_list)

        x = torch.randn(3, 5, 5)
        self.run_test(EinsumModelBatchDiagonal(), input=(x,))

        class EinsumModelBatchMatmul(torch.nn.Module):
            def forward(self, *tensor_list):
                eqn = 'bij, bjk -> bik'
                return torch.einsum(eqn, *tensor_list)

        x = torch.randn(5, 2, 3)
        y = torch.randn(5, 3, 4)
        self.run_test(EinsumModelBatchMatmul(), input=(x, y))

        class EinsumModelInnerProd(torch.nn.Module):
            def forward(self, *tensor_list):
                eqn = 'i,i'
                return torch.einsum(eqn, *tensor_list)

        x = torch.randn(5)
        y = torch.randn(5)
        self.run_test(EinsumModelInnerProd(), input=(x, y))

        class EinsumModelTranspose(torch.nn.Module):
            def forward(self, *tensor_list):
                eqn = 'ij->ji'
                return torch.einsum(eqn, *tensor_list)

        x = torch.randn(3, 4)
        self.run_test(EinsumModelTranspose(), input=(x,))

    @skipIfUnsupportedMinOpsetVersion(12)
    @disableScriptTest()  # shape/type inference
    def test_crossentropyloss(self):
        for ignore_index in [-100, 1]:
            x = torch.randn(3, 5)
            y = torch.empty(3, dtype=torch.long).random_(5)
            y[y == 1] = ignore_index

            self._crossentropyloss(x, y, ignore_index)

            x = torch.randn(3, 5, 2)
            y = torch.empty(3, 2, dtype=torch.long).random_(5)
            y[y == 1] = ignore_index
            self._crossentropyloss(x, y, ignore_index)

            x = torch.randn(3, 5, 2, 7)
            y = torch.empty(3, 2, 7, dtype=torch.long).random_(5)
            y[y == 1] = ignore_index
            self._crossentropyloss(x, y, ignore_index)

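    # Helper that exercises CrossEntropyLoss across reductions
    # (none/sum/mean), with and without class weights, for a given
    # ignore_index.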
    def _crossentropyloss(self, x, y, ignore_index):
        class CrossEntropyLossNone(torch.nn.Module):
            def __init__(self, ignore_index):
                super(CrossEntropyLossNone, self).__init__()
                if ignore_index == -100:
                    self.loss = torch.nn.CrossEntropyLoss(reduction='none')
                else:
                    self.loss = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=ignore_index)

            def forward(self, input, target):
                return self.loss(input, target)

        self.run_test(CrossEntropyLossNone(ignore_index), input=(x, y))

        class CrossEntropyLossNoneWeight(torch.nn.Module):
            def __init__(self, ignore_index):
                super(CrossEntropyLossNoneWeight, self).__init__()
                if ignore_index == -100:
                    self.loss = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.randn(5))
                else:
                    self.loss = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.randn(5), ignore_index=ignore_index)

            def forward(self, input, target):
                return self.loss(input, target)

        self.run_test(CrossEntropyLossNoneWeight(ignore_index), input=(x, y))

        class CrossEntropyLossSum(torch.nn.Module):
            def __init__(self, ignore_index):
                super(CrossEntropyLossSum, self).__init__()
                if ignore_index == -100:
                    self.loss = torch.nn.CrossEntropyLoss(reduction='sum')
                else:
                    self.loss = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=ignore_index)

            def forward(self, input, target):
                return self.loss(input, target)

        self.run_test(CrossEntropyLossSum(ignore_index), input=(x, y))

        class CrossEntropyLossSumWeight(torch.nn.Module):
            def __init__(self, ignore_index):
                super(CrossEntropyLossSumWeight, self).__init__()
                if ignore_index == -100:
                    self.loss = torch.nn.CrossEntropyLoss(reduction='sum', weight=torch.randn(5))
                else:
                    self.loss = torch.nn.CrossEntropyLoss(reduction='sum', weight=torch.randn(5), ignore_index=ignore_index)

            def forward(self, input, target):
                return self.loss(input, target)

        self.run_test(CrossEntropyLossSumWeight(ignore_index), input=(x, y))

        class CrossEntropyLossMean(torch.nn.Module):
            def __init__(self, ignore_index):
                super(CrossEntropyLossMean, self).__init__()
                if ignore_index == -100:
                    self.loss = torch.nn.CrossEntropyLoss()
                else:
                    self.loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)

            def forward(self, input, target):
                return self.loss(input, target)

        self.run_test(CrossEntropyLossMean(ignore_index), input=(x, y))

        class CrossEntropyLossMeanWeight(torch.nn.Module):
            def __init__(self, ignore_index):
                super(CrossEntropyLossMeanWeight, self).__init__()
                if ignore_index == -100:
                    self.loss = torch.nn.CrossEntropyLoss(weight=torch.randn(5))
                else:
                    self.loss = torch.nn.CrossEntropyLoss(weight=torch.randn(5), ignore_index=ignore_index)

            def forward(self, input, target):
                return self.loss(input, target)

        self.run_test(CrossEntropyLossMeanWeight(ignore_index), input=(x, y))

    @skipIfUnsupportedMinOpsetVersion(9)
    @disableScriptTest()  # Output dtype mismatch
    def test_kldiv_loss(self):
        x = torch.randn(5)
        y = torch.randn(5)
        self._kldiv_loss(x, y)

        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        self._kldiv_loss(x, y)

        x = torch.randn(2, 3, 5, 7)
        y = torch.randn(2, 3, 5, 7)
        self._kldiv_loss(x, y)

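    # Helper that exercises KLDivLoss across reductions (none, mean, sum,
    # batchmean) with both log_target settings.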
    def _kldiv_loss(self, x, y):
        class KLDivLossNone(torch.nn.Module):
            def __init__(self):
                super(KLDivLossNone, self).__init__()
                self.loss = torch.nn.KLDivLoss(reduction='none', log_target=True)

            def forward(self, input, target):
                return self.loss(input, target)

        self.run_test(KLDivLossNone(), input=(x, y))

        class KLDivLossMean(torch.nn.Module):
            def __init__(self):
                super(KLDivLossMean, self).__init__()
                self.loss = torch.nn.KLDivLoss(reduction='mean', log_target=False)

            def forward(self, input, target):
                return self.loss(input, target)

        self.run_test(KLDivLossMean(), input=(x, y))

        class KLDivLossSum(torch.nn.Module):
            def __init__(self):
                super(KLDivLossSum, self).__init__()
                self.loss = torch.nn.KLDivLoss(reduction='sum', log_target=True)

            def forward(self, input, target):
                return self.loss(input, target)

        self.run_test(KLDivLossSum(), input=(x, y))

        class KLDivLossBatchMean(torch.nn.Module):
            def __init__(self):
                super(KLDivLossBatchMean, self).__init__()
                self.loss = torch.nn.KLDivLoss(reduction='batchmean', log_target=False)

            def forward(self, input, target):
                return self.loss(input, target)

        self.run_test(KLDivLossBatchMean(), input=(x, y))

        class KLDivLossMiniBatchMean(torch.nn.Module):
            def __init__(self):
                super(KLDivLossMiniBatchMean, self).__init__()
                self.loss = torch.nn.KLDivLoss(reduction='batchmean', size_average=False, log_target=True)

            def forward(self, input, target):
                return self.loss(input, target)

        self.run_test(KLDivLossMiniBatchMean(), input=(x, y))

    @skipIfUnsupportedMinOpsetVersion(12)
    @disableScriptTest()  # shape/type inference
    def test_nllloss(self):
        class NLLModel(torch.nn.Module):
            def __init__(self):
                super(NLLModel, self).__init__()
                self.loss = torch.nn.NLLLoss(reduction='none')
                self.m = torch.nn.LogSoftmax(dim=1)

            def forward(self, input, target):
                output = self.loss(self.m(2 * input), target)
                return output

        N, C = 5, 4
        input = torch.randn(N, 16)
        target = torch.empty(N, dtype=torch.long).random_(0, C)

        # using test data containing default ignore_index=-100
        target[target == 1] = -100
        self.run_test(NLLModel(), (input, target))

    @skipIfUnsupportedMinOpsetVersion(12)
    @disableScriptTest()  # shape/type inference
    def test_nllloss_2d_none(self):
        class NLLModel(torch.nn.Module):
            def __init__(self):
                super(NLLModel, self).__init__()
                self.loss = torch.nn.NLLLoss(reduction='none')
                self.conv = torch.nn.Conv2d(16, C, (3, 3))
                self.m = torch.nn.LogSoftmax(dim=1)

            def forward(self, input, target):
                output = self.loss(self.m(self.conv(input)), target)
                return output

        N, C = 5, 4
        input = torch.randn(N, 16, 10, 10)
        target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)

        # using test data containing default ignore_index=-100
        target[target == 1] = -100
        self.run_test(NLLModel(), (input, target))

    @skipIfUnsupportedMinOpsetVersion(12)
    @disableScriptTest()  # shape/type inference
    def test_nllloss_2d_mean(self):
        class NLLModel(torch.nn.Module):
            def __init__(self):
                super(NLLModel, self).__init__()
                self.loss = torch.nn.NLLLoss(reduction='mean')
                self.conv = torch.nn.Conv2d(16, C, (3, 3))
                self.m = torch.nn.LogSoftmax(dim=1)

            def forward(self, input, target):
                output = self.loss(self.m(self.conv(input)), target)
                return output

        N, C = 5, 4
        input = torch.randn(N, 16, 10, 10)
        target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)

        # using test data containing default ignore_index=-100
        target[target == 1] = -100
        self.run_test(NLLModel(), (input, target))

    @skipIfUnsupportedMinOpsetVersion(12)
    @disableScriptTest()  # shape/type inference
    def test_nllloss_2d_sum(self):
        class NLLModel(torch.nn.Module):
            def __init__(self):
                super(NLLModel, self).__init__()
                self.loss = torch.nn.NLLLoss(reduction='sum')
                self.conv = torch.nn.Conv2d(16, C, (3, 3))
                self.m = torch.nn.LogSoftmax(dim=1)

            def forward(self, input, target):
                output = self.loss(self.m(self.conv(input)), target)
                return output

        N, C = 5, 4
        input = torch.randn(N, 16, 10, 10)
        target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)

        # using test data containing default ignore_index=-100
        target[target == 1] = -100
        self.run_test(NLLModel(), (input, target))

    @skipIfUnsupportedMinOpsetVersion(12)
    @disableScriptTest()  # shape/type inference
    def test_nllloss_2d_mean_weights(self):
        class NLLModel(torch.nn.Module):
            def __init__(self):
                super(NLLModel, self).__init__()
                self.loss = torch.nn.NLLLoss(reduction='mean', weight=torch.randn(C))
                self.conv = torch.nn.Conv2d(16, C, (3, 3))
                self.m = torch.nn.LogSoftmax(dim=1)

            def forward(self, input, target):
                output = self.loss(self.m(self.conv(input)), target)
                return output

        N, C = 5, 4
        input = torch.randn(N, 16, 10, 10)
        target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)

        # using test data containing default ignore_index=-100
        target[target == 1] = -100
        self.run_test(NLLModel(), (input, target))

    @skipIfUnsupportedMinOpsetVersion(12)
    @disableScriptTest()  # shape/type inference
    def test_nllloss_2d_mean_ignore_index(self):
        class NLLModel(torch.nn.Module):
            def __init__(self):
                super(NLLModel, self).__init__()
                self.loss = torch.nn.NLLLoss(reduction='mean', ignore_index=1)
                self.conv = torch.nn.Conv2d(16, C, (3, 3))
                self.m = torch.nn.LogSoftmax(dim=1)

            def forward(self, input, target):
                output = self.loss(self.m(self.conv(input)), target)
                return output

        N, C = 5, 4
        input = torch.randn(N, 16, 10, 10)
        target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
        self.run_test(NLLModel(), (input, target))

    @skipIfUnsupportedMinOpsetVersion(12)
    @disableScriptTest()  # shape/type inference
    def test_nllloss_2d_mean_ignore_index_weights(self):
        class NLLModel(torch.nn.Module):
            def __init__(self):
                super(NLLModel, self).__init__()
                self.loss = torch.nn.NLLLoss(reduction='mean', weight=torch.randn(C), ignore_index=1)
                self.conv = torch.nn.Conv2d(16, C, (3, 3))
                self.m = torch.nn.LogSoftmax(dim=1)

            def forward(self, input, target):
                output = self.loss(self.m(self.conv(input)), target)
                return output

        N, C = 5, 4
        input = torch.randn(N, 16, 10, 10)
        target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
        self.run_test(NLLModel(), (input, target))

    def test_torch_mm(self):
        class M(torch.nn.Module):
            def forward(self, mat1, mat2):
                mm = torch.mm(mat1, mat2)
                return mm

        mat1 = torch.randn(2, 3)
        mat2 = torch.randn(3, 3)
        self.run_test(M(), input=(mat1, mat2))

    @skipIfUnsupportedMinOpsetVersion(9)  # Because where op is not supported for opset < 9.
    def test_where_with_bool_tensor(self):
        class M(torch.nn.Module):
            def forward(self, mat1, mat2):
                out = torch.where(mat1 > 0, mat1, mat2)
                return out

        mat1 = torch.randn(2, 3)
        mat2 = torch.ones(2, 3)
        self.run_test(M(), input=(mat1, mat2))

    @skipIfUnsupportedMinOpsetVersion(9)  # Because where op is not supported for opset < 9.
    def test_where_with_byte_tensor(self):
        class M(torch.nn.Module):
            def forward(self, cond, mat1, mat2):
                out = torch.where(cond, mat1, mat2)
                return out

        cond = torch.ones(2, 3, dtype=torch.uint8)
        cond[1, 2] = 0
        mat1 = torch.randn(2, 3)
        mat2 = torch.ones(2, 3)
        self.run_test(M(), input=(cond, mat1, mat2))

    def test_dropout(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.dropout = torch.nn.Dropout(0.3)

            def forward(self, x):
                dropout = self.dropout(x)
                return dropout

        x = torch.randn(10, 3, 53)
        self.run_test(M(), x)

    def test_shape_constant_fold(self):
        class ShapeModule(torch.nn.Module):
            def __init__(self):
                super(ShapeModule, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                shape = self.weight.shape[0]
                return x + shape

        x = torch.randn(2, 5)
        self.run_test(ShapeModule(), (x,), rtol=1e-3, atol=1e-5)

    @skipIfUnsupportedMinOpsetVersion(12)
    def test_celu(self):
        class Celu(torch.nn.Module):
            def __init__(self):
                super(Celu, self).__init__()
                self.celu = torch.nn.CELU(alpha=1.0)

            def forward(self, input):
                return self.celu(input)

        input = torch.randn(2)
        self.run_test(Celu(), (input,))

    @skipIfUnsupportedMinOpsetVersion(12)
    def test_celu_default(self):
        class Celu(torch.nn.Module):
            def __init__(self):
                super(Celu, self).__init__()
                self.celu = torch.nn.CELU()

            def forward(self, input):
                return self.celu(input)

        input = torch.randn(2)
        self.run_test(Celu(), (input,))

    @skipIfUnsupportedMinOpsetVersion(12)
    def test_celu_alpha(self):
        class Celu(torch.nn.Module):
            def __init__(self):
                super(Celu, self).__init__()
                self.celu = torch.nn.CELU(alpha=2.)

            def forward(self, input):
                return self.celu(input)

        input = torch.randn(2)
        self.run_test(Celu(), (input,))

    @skipIfUnsupportedMinOpsetVersion(12)
    def test_celu_cast(self):
        class Celu(torch.nn.Module):
            def __init__(self):
                super(Celu, self).__init__()
                self.celu = torch.nn.CELU()

            def forward(self, input):
                return self.celu(input)

        input = torch.randn(2, 5, 7, dtype=torch.float64)
        self.run_test(Celu(), (input,))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_where(self):
        class Model(torch.nn.Module):
            def forward(self, cond, input, other):
                return torch.where(cond, input, other)

        x = torch.randint(0, 1, (2, 3, 4), dtype=torch.bool)
        y = torch.randn(2, 1, 4)
        z = torch.ones(2, 3, 1)
        self.run_test(Model(), (x, y, z))

    @skipIfUnsupportedMinOpsetVersion(9)
    @disableScriptTest()  # symbolic update needed for unbind: ONNX export of unbind with dynamic number of outputs
    def test_where_condition(self):
        class Model1(torch.nn.Module):
            def forward(self, input):
                return torch.stack(torch.where(input > 0.5), dim=1)

        x = torch.randint(0, 2, (2, 3, 4), dtype=bool)
        self.run_test(Model1(), x)

        class Model2(torch.nn.Module):
            def forward(self, input, other):
                return torch.stack(torch.where(input > other), dim=1)

        x = torch.randint(0, 1, (2, 3, 4), dtype=bool)
        y = torch.randint(1, 2, (2, 3, 4), dtype=bool)
        self.run_test(Model2(), (x, y))

    def test_empty_branch(self):
        class EmptyBranchModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                out = input + 1
                if out.dim() > 2:
                    if out.dim() > 3:
                        out += 3
                    else:
                        pass
                else:
                    pass
                return out

        x = torch.randn(1, 2, 3, requires_grad=True)
        self.run_test(EmptyBranchModel(), x)

    @skipIfONNXShapeInference(False)
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_if_transpose(self):
        class IfModel(torch.nn.Module):
            def forward(self, x):
                x = x.transpose(0, 1)
                if x.size(0) == 2:
                    return x.transpose(0, 1)
                else:
                    return x

        x = torch.randn(2, 3)
        self.run_test(torch.jit.script(IfModel()), x,
                      output_names=['output_1'],
                      dynamic_axes={'output_1': [0, 1]})

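    # Corrupt ir_version on an otherwise valid exported proto and verify that
    # torch._C._check_onnx_proto rejects it with a RuntimeError.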
    def test_onnx_proto_checker(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super(Model, self).__init__()

            def forward(self, x):
                return 2 * x

        x = torch.randn(1, 2, 3, requires_grad=True)
        f = io.BytesIO()
        torch.onnx._export(Model(), x, f)
        model = onnx.load(f)
        model.ir_version = 0

        def check_proto():
            torch._C._check_onnx_proto(model.SerializeToString())

        self.assertRaises(RuntimeError, check_proto)

    @disableScriptTest()  # dtype mismatch
    def test_split_tensor_scalar(self):
        class SplitModel(torch.nn.Module):
            def forward(self, x):
                return torch.split(x, x.size(1))

        x = torch.randn(1, 2, 3, requires_grad=True)
        self.run_test(SplitModel(), x)

    def test_split_tensor_multi(self):
        class SplitModel(torch.nn.Module):
            def forward(self, x):
                return torch.split(x, torch.ones(3))

        x = torch.randn(1, 2, 3, requires_grad=True)

        def run_model():
            SplitModel(x)

        self.assertRaises(TypeError, run_model)

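    # RNN test helpers. `packed_sequence` selects the input format:
    # 0 = padded tensor only, 1 = PackedSequence, 2 = PackedSequence with
    # batch_first=True (wrapped via RnnModelWithPackedSequence).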
    def _dispatch_rnn_test(self, name, *args, **kwargs):
        if name == 'elman':
            self._elman_rnn_test(*args, **kwargs)
        if name == 'lstm':
            self._lstm_test(*args, **kwargs)
        if name == 'gru':
            self._gru_test(*args, **kwargs)

    def _elman_rnn_test(self, layers, nonlinearity, bidirectional,
                        initial_state, packed_sequence, dropout):
        batch_first = True if packed_sequence == 2 else False
        model = torch.nn.RNN(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers, nonlinearity=nonlinearity,
                             bidirectional=bidirectional, dropout=dropout, batch_first=batch_first)

        if packed_sequence == 1:
            model = RnnModelWithPackedSequence(model, False)
        if packed_sequence == 2:
            model = RnnModelWithPackedSequence(model, True)

        def make_input(batch_size):
            seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
            seq_lengths = list(reversed(sorted(map(int, seq_lengths))))
            inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
            inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
            inputs = [inputs]

            directions = 2 if bidirectional else 1

            if initial_state:
                h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
                inputs.append(h0)
            if packed_sequence != 0:
                inputs.append(torch.IntTensor(seq_lengths))
            if len(inputs) == 1:
                input = inputs[0]
            else:
                input = tuple(inputs)
            return input

        input = make_input(RNN_BATCH_SIZE)
        self.run_test(model, input, batch_size=RNN_BATCH_SIZE)

        # test that the model still runs with a different batch size
        other_input = make_input(RNN_BATCH_SIZE + 1)
        self.run_test(model, other_input, batch_size=RNN_BATCH_SIZE + 1)

    def _lstm_test(self, layers, bidirectional, initial_state,
                   packed_sequence, dropout):
        batch_first = True if packed_sequence == 2 else False
        model = LstmFlatteningResult(
            RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers,
            bidirectional=bidirectional, dropout=dropout, batch_first=batch_first)
        if packed_sequence == 1:
            model = RnnModelWithPackedSequence(model, False)
        if packed_sequence == 2:
            model = RnnModelWithPackedSequence(model, True)

        def make_input(batch_size):
            seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
            seq_lengths = list(reversed(sorted(map(int, seq_lengths))))
            inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
            inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
            inputs = [inputs]

            directions = 2 if bidirectional else 1

            if initial_state:
                h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
                c0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
                inputs.append((h0, c0))
            if packed_sequence != 0:
                inputs.append(torch.IntTensor(seq_lengths))
            if len(inputs) == 1:
                input = inputs[0]
            else:
                input = tuple(inputs)
            return input

        input = make_input(RNN_BATCH_SIZE)
        self.run_test(model, input, batch_size=RNN_BATCH_SIZE)

        # test that the model still runs with a different batch size
        other_input = make_input(RNN_BATCH_SIZE + 1)
        self.run_test(model, other_input, batch_size=RNN_BATCH_SIZE + 1)

    def _gru_test(self, layers, bidirectional, initial_state,
                  packed_sequence, dropout):
        batch_first = True if packed_sequence == 2 else False
        model = torch.nn.GRU(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers, bidirectional=bidirectional, dropout=dropout,
                             batch_first=batch_first)
        if packed_sequence == 1:
            model = RnnModelWithPackedSequence(model, False)
        if packed_sequence == 2:
            model = RnnModelWithPackedSequence(model, True)

        def make_input(batch_size):
            seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
            seq_lengths = list(reversed(sorted(map(int, seq_lengths))))
            inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
            inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
            inputs = [inputs]

            directions = 2 if bidirectional else 1

            if initial_state:
                h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
                inputs.append(h0)
            if packed_sequence != 0:
                inputs.append(torch.IntTensor(seq_lengths))
            if len(inputs) == 1:
                input = inputs[0]
            else:
                input = tuple(inputs)
            return input

        input = make_input(RNN_BATCH_SIZE)
        self.run_test(model, input, batch_size=RNN_BATCH_SIZE)

        # test that the model still runs with a different batch size
        other_input = make_input(RNN_BATCH_SIZE + 1)
        self.run_test(model, other_input, batch_size=RNN_BATCH_SIZE + 1)

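    # Note: at opset 10 the exporter lowers fake_quantize_per_tensor_affine to
    # a QuantizeLinear/DequantizeLinear pair (an assumption about the symbolic;
    # the test itself only checks numerical parity against onnxruntime).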
    @skipIfUnsupportedMinOpsetVersion(10)
    def test_fake_quantize_per_tensor(self):
        class FakeQuantizePerTensorModel(torch.nn.Module):
            def forward(self, input):
                scale = 1. / 127
                zero_point = 0
                quant_min = -128
                quant_max = 127
                return torch.fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max)

        x = torch.randn(6, 4, 3, 3)
        self.run_test(FakeQuantizePerTensorModel(), x)

    @skipIfUnsupportedMinOpsetVersion(12)
    def test_dropout_training(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.dropout = torch.nn.Dropout(0.4)

            def forward(self, x):
                dropout = self.dropout(x)
                return dropout

        model = MyModule()
        x = torch.randn(10)

        model.train()

        ort_sess = convert_to_onnx(model, input=(x,), opset_version=self.opset_version,
                                   training=torch.onnx.TrainingMode.TRAINING)
        ort_outs = run_ort(ort_sess, input=(x,))
        assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0])))

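    # Dropout masks are random and differ between PyTorch and onnxruntime, so
    # compare the fraction of zeroed elements instead of element-wise values.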
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_dropout_training_zero(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.dropout = torch.nn.Dropout(0.5)

            def forward(self, x):
                dropout = self.dropout(x)
                return dropout

        model = MyModule()

        # ensure there are no zeros in the input
        x = torch.randn(10, 3, 128, 128)
        y = x.numpy()
        y_mask = np.where(y == 0, 1, y)
        input = torch.from_numpy(y_mask)
        nb_elements = torch.numel(input)

        model.train()

        ort_sess = convert_to_onnx(model, input=(x,), opset_version=self.opset_version,
                                   training=torch.onnx.TrainingMode.TRAINING)
        ort_outs = run_ort(ort_sess, input=(x,))

        y = model(input)
        output = y.cpu().numpy()

        ort_mask = np.where(ort_outs[0] != 0, 1, 0)
        pyt_mask = np.where(output != 0, 1, 0)

        ratio_pytorch = np.sum(pyt_mask) / nb_elements
        ratio_ort = np.sum(ort_mask) / nb_elements

        np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01)

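    # Export the same Conv+BatchNorm model in TRAINING and EVAL mode and check
    # that the two ONNX graphs produce matching outputs in onnxruntime.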
    def test_conv_bn(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
                self.bn = torch.nn.BatchNorm2d(16, affine=True)

            def forward(self, x):
                x = self.conv(x)
                bn = self.bn(x)
                return bn

        model = MyModule()
        x = torch.randn(10, 3, 128, 128)
        ort_sess1 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version,
                                    training=torch.onnx.TrainingMode.TRAINING)
        ort_outs1 = run_ort(ort_sess1, input=(x,))
        ort_sess2 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version,
                                    training=torch.onnx.TrainingMode.EVAL)
        ort_outs2 = run_ort(ort_sess2, input=(x,))
        for ort_out1, ort_out2 in zip(ort_outs1, ort_outs2):
            np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001)

    def test_multiple_conv_bn(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
                self.conv2 = torch.nn.Conv2d(64, 2, kernel_size=1, stride=1, padding=0, bias=False)
                self.conv3 = torch.nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1, bias=False)
                self.bn = torch.nn.BatchNorm2d(64)
                self.bn2 = torch.nn.BatchNorm2d(2)
                self.relu = torch.nn.ReLU(inplace=True)
                self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn(x)
                x = self.relu(x)
                x = self.maxpool(x)
                x = self.conv2(x)
                x = self.bn2(x)
                x = self.relu(x)
                x = self.conv3(x)
                x = self.bn2(x)
                x = self.relu(x)
                return x

        model = MyModule()
        x = torch.randn(2, 3, 224, 224)
        ort_sess1 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version,
                                    training=torch.onnx.TrainingMode.TRAINING)
        ort_outs1 = run_ort(ort_sess1, input=(x,))
        ort_sess2 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version,
                                    training=torch.onnx.TrainingMode.EVAL)
        ort_outs2 = run_ort(ort_sess2, input=(x,))
        for ort_out1, ort_out2 in zip(ort_outs1, ort_outs2):
            np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001)

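# Generates one test method for a single RNN configuration and attaches it to
# TestONNXRuntime; setup_rnn_tests() below drives this over the full cartesian
# product of options.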
def make_test(name, base, layer, bidirectional, initial_state,
              variable_length, dropout,
              **extra_kwargs):
    test_name = str('_'.join([
        'test', name, layer[1],
        bidirectional[1], initial_state[1],
        variable_length[1], dropout[1]
    ]))

    # Cannot export with older opsets because of the 'ConstantFill' op.
    # ConstantFill was a temporary op removed in opset 8 and is no longer
    # supported by onnxruntime.
    @disableScriptTest()  # Test code not scriptable
    @skipIfUnsupportedMinOpsetVersion(9)
    def f(self):
        self._dispatch_rnn_test(
            base,
            layers=layer[0],
            bidirectional=bidirectional[0],
            initial_state=initial_state[0],
            packed_sequence=variable_length[0],
            dropout=dropout[0],
            **extra_kwargs)

    f.__name__ = test_name
    setattr(TestONNXRuntime, f.__name__, f)

def setup_rnn_tests():
    layers_opts = [
        (1, 'unilayer'),
        (3, 'trilayer')
    ]
    bidirectional_opts = [
        (False, 'forward'),
        (True, 'bidirectional')
    ]
    initial_state_opts = [
        (True, 'with_initial_state'),
        (False, 'no_initial_state')
    ]
    variable_length_opts = [
        (0, 'without_sequence_lengths'),
        (1, 'with_variable_length_sequences'),
        (2, 'with_batch_first_sequence_lengths')
    ]
    dropout_opts = [
        (0.2, 'with_dropout'),
        (0.0, 'without_dropout')
    ]
    test_count = 0
    for (layer, bidirectional, initial_state, variable_length, dropout) in \
        itertools.product(
            layers_opts,
            bidirectional_opts,
            initial_state_opts,
            variable_length_opts,
            dropout_opts,):

        for base, name, extra_kwargs in (
                ('elman', 'elman_relu', {'nonlinearity': 'relu'}),
                ('elman', 'elman_tanh', {'nonlinearity': 'tanh'}),
                ('lstm', 'lstm', {}),
                ('gru', 'gru', {})
        ):
            make_test(name, base, layer, bidirectional, initial_state,
                      variable_length, dropout,
                      **extra_kwargs)
            test_count += 1

    # sanity check that a representative example does exist
    TestONNXRuntime.test_gru_trilayer_forward_with_initial_state_without_sequence_lengths_with_dropout

    # make sure no one accidentally disables all the tests without noticing:
    # 2 layer counts x 2 directions x 2 initial-state opts x 3 sequence
    # formats x 2 dropout opts x 4 RNN variants = 192
    if test_count != 192:
        raise ValueError('Expected 192 tests but found {}'.format(test_count))


setup_rnn_tests()

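# The variants below re-run the entire TestONNXRuntime suite under different
# export settings by rebuilding the class dict with type(); each assignment
# overrides class attributes such as opset_version or
# keep_initializers_as_inputs.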
# opset 7 tests
TestONNXRuntime_opset7 = type(str("TestONNXRuntime_opset7"),
                              (unittest.TestCase,),
                              dict(TestONNXRuntime.__dict__, opset_version=7))

# opset 8 tests
TestONNXRuntime_opset8 = type(str("TestONNXRuntime_opset8"),
                              (unittest.TestCase,),
                              dict(TestONNXRuntime.__dict__, opset_version=8))


# opset 10 tests
TestONNXRuntime_opset10 = type(str("TestONNXRuntime_opset10"),
                               (unittest.TestCase,),
                               dict(TestONNXRuntime.__dict__, opset_version=10))

# opset 11 tests
TestONNXRuntime_opset11 = type(str("TestONNXRuntime_opset11"),
                               (unittest.TestCase,),
                               dict(TestONNXRuntime.__dict__, opset_version=11))

# opset 12 tests
TestONNXRuntime_opset12 = type(str("TestONNXRuntime_opset12"),
                               (unittest.TestCase,),
                               dict(TestONNXRuntime.__dict__, opset_version=12))

# opset 9 tests, with keep_initializers_as_inputs=False for
# IR version 4 style export.
TestONNXRuntime_opset9_IRv4 = type(str("TestONNXRuntime_opset9_IRv4"),
                                   (unittest.TestCase,),
                                   dict(TestONNXRuntime.__dict__,
                                        keep_initializers_as_inputs=False))


# opset 10 tests, with keep_initializers_as_inputs=False for
# IR version 4 style export.
TestONNXRuntime_opset10_IRv4 = type(str("TestONNXRuntime_opset10_IRv4"),
                                    (unittest.TestCase,),
                                    dict(TestONNXRuntime.__dict__, opset_version=10,
                                         keep_initializers_as_inputs=False))


# opset 11 tests, with keep_initializers_as_inputs=False for
# IR version 4 style export.
TestONNXRuntime_opset11_IRv4 = type(str("TestONNXRuntime_opset11_IRv4"),
                                    (unittest.TestCase,),
                                    dict(TestONNXRuntime.__dict__, opset_version=11,
                                         keep_initializers_as_inputs=False))

# opset 12 tests, with keep_initializers_as_inputs=False for
# IR version 4 style export.
TestONNXRuntime_opset12_IRv4 = type(str("TestONNXRuntime_opset12_IRv4"),
                                    (unittest.TestCase,),
                                    dict(TestONNXRuntime.__dict__, opset_version=12,
                                         keep_initializers_as_inputs=False))


# opset 9 tests, with use_new_jit_passes=True for using new jit API,
# and with keep_initializers_as_inputs=False for IR version 4 style export.
TestONNXRuntime_opset9_IRv4_new_jit_API = type(str("TestONNXRuntime_opset9_IRv4_new_jit_API"),
                                               (unittest.TestCase,),
                                               dict(TestONNXRuntime.__dict__,
                                                    keep_initializers_as_inputs=False,
                                                    use_new_jit_passes=True,
                                                    onnx_shape_inference=True))


# opset 12 tests, with use_new_jit_passes=True for using new jit API,
# and keep_initializers_as_inputs=False for IR version 4 style export.
TestONNXRuntime_opset12_IRv4_new_jit_API = type(str("TestONNXRuntime_opset12_IRv4_new_jit_API"),
                                                (unittest.TestCase,),
                                                dict(TestONNXRuntime.__dict__, opset_version=12,
                                                     keep_initializers_as_inputs=False,
                                                     use_new_jit_passes=True,
                                                     onnx_shape_inference=True))


# opset 12 tests, with onnx_shape_inference=True.
TestONNXRuntime_opset12_onnx_shape_inference = type(str("TestONNXRuntime_opset12_onnx_shape_inference"),
                                                    (unittest.TestCase,),
                                                    dict(TestONNXRuntime.__dict__, opset_version=12,
                                                         onnx_shape_inference=True))


if __name__ == '__main__':
    unittest.main()