pytorch/test/onnx/test_pytorch_onnx_onnxruntime.py
BowenBao 3da4cea658 [ONNX] Add dim_param support in export with onnx shape inference (#44920)
Summary:
* Support propagating `dim_param` in ONNX by encoding it as `ShapeSymbol` in the `SymbolicShape` of outputs. If export is called with `dynamic_axes` provided, shape inference starts with these axes set as dynamic (see the sketch below).
* Add a new test file `test_pytorch_onnx_shape_inference.py`, reusing all test cases from `test_pytorch_onnx_onnxruntime.py` but focusing on validating shapes for all nodes in the graph. It is not yet enabled in the CI, since there are still quite a few existing issues and corner cases to fix. The test defaults to running only at opset 12.
* Bug fixes, e.g. for div, _len, the peephole.cpp passes for PackPadded, and LogSoftmaxCrossEntropy.
* This PR depends on existing PRs such as 44332.
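
A minimal, illustrative sketch of the new behavior (not part of this diff; `MyModel` and the axis name `batch` are made up): exporting with `dynamic_axes` should surface the symbolic dimension as `dim_param` in the saved model.

    import io
    import onnx
    import torch

    class MyModel(torch.nn.Module):
        def forward(self, x):
            return x + 1

    f = io.BytesIO()
    torch.onnx.export(MyModel(), torch.randn(2, 3), f,
                      input_names=['x'], output_names=['y'],
                      dynamic_axes={'x': {0: 'batch'}},
                      opset_version=12)
    m = onnx.load_model_from_string(f.getvalue())
    dim0 = m.graph.output[0].type.tensor_type.shape.dim[0]
    print(dim0.dim_param)  # expected to print the symbolic name, e.g. 'batch'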

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44920

Reviewed By: eellison

Differential Revision: D23958398

Pulled By: bzinodev

fbshipit-source-id: 00479d9bd19c867d526769a15ba97ec16d56e51d
2020-09-30 21:56:24 -07:00


import unittest
import onnxruntime # noqa
import torch
import numpy as np
import io
import itertools
import copy
from torch.nn.utils import rnn as rnn_utils
from model_defs.lstm_flattening_result import LstmFlatteningResult
from model_defs.rnn_model_with_packed_sequence import RnnModelWithPackedSequence
from test_pytorch_common import (skipIfUnsupportedMinOpsetVersion, disableScriptTest,
skipIfUnsupportedOpsetVersion, skipIfNoLapack,
skipIfUnsupportedMaxOpsetVersion, skipIfONNXShapeInference)
from test_pytorch_common import BATCH_SIZE
from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
from typing import List
import model_defs.word_language_model as word_language_model
import torchvision
import onnx
def to_numpy(tensor):
if tensor.requires_grad:
return tensor.detach().cpu().numpy()
else:
return tensor.cpu().numpy()
def convert_to_onnx(model, input=None, opset_version=9, example_outputs=None,
do_constant_folding=True, keep_initializers_as_inputs=True,
dynamic_axes=None, input_names=None, output_names=None,
fixed_batch_size=False, training=None,
onnx_shape_inference=False,
use_new_jit_passes=False):
# export the model to ONNX
f = io.BytesIO()
input_copy = copy.deepcopy(input)
torch.onnx._export(model, input_copy, f,
opset_version=opset_version,
example_outputs=example_outputs,
do_constant_folding=do_constant_folding,
keep_initializers_as_inputs=keep_initializers_as_inputs,
dynamic_axes=dynamic_axes,
input_names=input_names, output_names=output_names,
fixed_batch_size=fixed_batch_size, training=training,
onnx_shape_inference=onnx_shape_inference,
use_new_jit_passes=use_new_jit_passes)
# compute onnxruntime output prediction
ort_sess = onnxruntime.InferenceSession(f.getvalue())
return ort_sess
def run_ort(ort_sess, input):
input_copy = copy.deepcopy(input)
input, _ = torch.jit._flatten(input_copy)
inputs = list(map(to_numpy, input))
ort_inputs = {ort_sess.get_inputs()[i].name: inp for i, inp in enumerate(inputs)}
ort_outs = ort_sess.run(None, ort_inputs)
return ort_outs
def ort_compare_with_pytorch(ort_outs, output, rtol, atol):
output, _ = torch.jit._flatten(output)
outputs = list(map(to_numpy, output))
# compare onnxruntime and PyTorch results
assert len(outputs) == len(ort_outs), "number of outputs differ"
for out, ort_out in zip(outputs, ort_outs):
    np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol)
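# Illustrative usage of the helpers above (a sketch, not invoked by the
# test suite; `_ExampleAdd` and `_example_roundtrip` are made-up names).
# Note that convert_to_onnx relies on the internal torch.onnx._export API.
class _ExampleAdd(torch.nn.Module):
    def forward(self, x):
        return x + 1

def _example_roundtrip():
    model = _ExampleAdd()
    x = torch.randn(2, 3)
    ort_sess = convert_to_onnx(model, input=(x,), opset_version=11)
    ort_outs = run_ort(ort_sess, (x,))
    ort_compare_with_pytorch(ort_outs, (model(x),), rtol=1e-3, atol=1e-7)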
def run_model_test(self, model, batch_size=2, state_dict=None,
input=None, use_gpu=True, rtol=0.001, atol=1e-7,
example_outputs=None, do_constant_folding=True,
dynamic_axes=None, test_with_inputs=None,
input_names=None, output_names=None,
fixed_batch_size=False):
model.eval()
if input is None:
input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)
with torch.no_grad():
if isinstance(input, torch.Tensor):
input = (input,)
# In-place operators will update input tensor data as well.
# Thus inputs are replicated before every forward call.
input_copy = copy.deepcopy(input)
output = model(*input_copy)
if isinstance(output, torch.Tensor):
output = (output,)
ort_sess = convert_to_onnx(model, input=input, opset_version=self.opset_version,
example_outputs=output, do_constant_folding=do_constant_folding,
keep_initializers_as_inputs=self.keep_initializers_as_inputs,
dynamic_axes=dynamic_axes, input_names=input_names,
output_names=output_names, fixed_batch_size=fixed_batch_size, training=None,
onnx_shape_inference=self.onnx_shape_inference,
use_new_jit_passes=self.use_new_jit_passes)
ort_outs = run_ort(ort_sess, input)
ort_compare_with_pytorch(ort_outs, output, rtol, atol)
# if additional test inputs are provided run the onnx
# model with these inputs and check the outputs
if test_with_inputs is not None:
for test_input in test_with_inputs:
if isinstance(test_input, torch.Tensor):
test_input = (test_input,)
test_input_copy = copy.deepcopy(test_input)
output = model(*test_input_copy)
if isinstance(output, torch.Tensor):
output = (output,)
ort_outs = run_ort(ort_sess, test_input)
ort_compare_with_pytorch(ort_outs, output, rtol, atol)
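# Why run_model_test deep-copies its inputs: an in-place op in forward()
# mutates the caller's tensors, so reusing the originals for the ORT run
# would compare against already-modified references. A minimal sketch
# (_InPlaceAdd is made up for illustration, not used by the tests):
class _InPlaceAdd(torch.nn.Module):
    def forward(self, x):
        return x.add_(1)  # mutates x in place; x and the output alias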
class TestONNXRuntime(unittest.TestCase):
from torch.onnx.symbolic_helper import _export_onnx_opset_version
opset_version = _export_onnx_opset_version
keep_initializers_as_inputs = True # For IR version 3 type export.
use_new_jit_passes = False # For testing main code-path
onnx_shape_inference = False
def setUp(self):
torch.manual_seed(0)
onnxruntime.set_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
np.random.seed(seed=0)
self.is_script_test_enabled = True
def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=True,
batch_size=2, use_gpu=True, dynamic_axes=None, test_with_inputs=None,
input_names=None, output_names=None, fixed_batch_size=False):
def _run_test(m):
return run_model_test(self, m, batch_size=batch_size,
input=input, use_gpu=use_gpu, rtol=rtol, atol=atol,
do_constant_folding=do_constant_folding,
dynamic_axes=dynamic_axes, test_with_inputs=test_with_inputs,
input_names=input_names, output_names=output_names,
fixed_batch_size=fixed_batch_size)
if self.is_script_test_enabled and self.use_new_jit_passes:
script_model = torch.jit.script(model)
_run_test(script_model)
_run_test(model)
def run_model_test_with_external_data(self, model, input, rtol=0.001, atol=1e-7,
example_outputs=None, do_constant_folding=True,
dynamic_axes=None, input_names=None, output_names=None,
ort_optim_on=True):
import os
import tempfile
model.eval()
with torch.no_grad():
if isinstance(input, torch.Tensor):
input = (input,)
# In-place operators will update input tensor data as well.
# Thus inputs are replicated before every forward call.
input_copy = copy.deepcopy(input)
output = model(*input_copy)
if isinstance(output, torch.Tensor):
output = (output,)
# export the model to ONNX
with tempfile.TemporaryDirectory() as tmpdirname:
model_file_name = os.path.join(tmpdirname, 'model.onnx')
input_copy = copy.deepcopy(input)
torch.onnx.export(model, input_copy, model_file_name,
opset_version=self.opset_version,
example_outputs=output,
verbose=False,
do_constant_folding=do_constant_folding,
keep_initializers_as_inputs=self.keep_initializers_as_inputs,
dynamic_axes=dynamic_axes,
input_names=input_names, output_names=output_names,
use_external_data_format=True)
# compute onnxruntime output prediction
ort_sess_opt = onnxruntime.SessionOptions()
ort_sess_opt.graph_optimization_level = \
onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED if ort_optim_on else \
onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
ort_sess = onnxruntime.InferenceSession(model_file_name, sess_options=ort_sess_opt)
input_copy = copy.deepcopy(input)
ort_outs = run_ort(ort_sess, input_copy)
ort_compare_with_pytorch(ort_outs, output, rtol, atol)
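# A minimal sketch of the external-data round trip exercised above
# (illustrative helper, not invoked by any test; the method name is made
# up). Tensors are written as separate files next to the .onnx proto, and
# onnx.load re-reads them from the model's directory.
def _external_data_export_sketch(self, model, inputs):
    import os
    import tempfile
    with tempfile.TemporaryDirectory() as tmpdirname:
        model_file_name = os.path.join(tmpdirname, 'model.onnx')
        torch.onnx.export(model, inputs, model_file_name,
                          opset_version=self.opset_version,
                          use_external_data_format=True)
        return onnx.load(model_file_name)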
@skipIfUnsupportedMinOpsetVersion(9) # Because external data format was released with Opset 9.
def test_embedding_model_with_external_data(self):
class LargeModel(torch.nn.Module):
def __init__(self):
super(LargeModel, self).__init__()
dim = 15
n = 4 * 100
self.emb = torch.nn.Embedding(n, dim)
self.lin1 = torch.nn.Linear(dim, 1)
self.seq = torch.nn.Sequential(
self.emb,
self.lin1,
)
def forward(self, input):
return self.seq(input)
model = LargeModel()
x = torch.tensor([2], dtype=torch.long)
self.run_model_test_with_external_data(model, x)
@skipIfUnsupportedMinOpsetVersion(9) # Because external data format was released with Opset 9.
def test_mobilenet_v2_with_external_data(self):
model = torchvision.models.mobilenet_v2(pretrained=True)
x = torch.randn(2, 3, 224, 224, requires_grad=True)
# We are turning Onnx Runtime optimization off in this test,
# because the external data format is not supported in the ORT optimizer.
# Once that support is added, we can set ort_optim_on=True (the default).
self.run_model_test_with_external_data(model, x, rtol=1e-3, atol=1e-5,
ort_optim_on=False)
@skipIfUnsupportedMinOpsetVersion(9) # Because external data format was released with Opset 9.
def test_attribute_with_external_data(self):
class LargeModel(torch.nn.Module):
def forward(self, x):
return x + torch.ones(2, 1024)
x = torch.randn(2, 1)
self.run_model_test_with_external_data(LargeModel(), x)
@skipIfUnsupportedMinOpsetVersion(9) # Because external data format was released with Opset 9.
@unittest.skip("Enable this once large model with subgraph is supported in ORT")
def test_subgraph_with_external_data(self):
class LargeModel(torch.nn.Module):
def forward(self, x):
for i in range(x.size(0)):
x = x + torch.ones(2, 1024)
return x
x = torch.randn(2, 1)
self.run_model_test_with_external_data(torch.jit.script(LargeModel()), x)
def test_fuse_conv_bn1d(self):
class Fuse(torch.nn.Module):
def __init__(self):
super(Fuse, self).__init__()
self.conv = torch.nn.Conv1d(16, 33, 3, stride=2)
self.bn = torch.nn.BatchNorm1d(33)
def forward(self, x):
out = self.conv(x)
return self.bn(out)
model = Fuse()
x = torch.randn(20, 16, 50, requires_grad=True)
self.run_test(model, (x,))
def test_fuse_conv_bn2d(self):
class Fuse(torch.nn.Module):
def __init__(self):
super(Fuse, self).__init__()
self.conv = torch.nn.Conv2d(3, 2, kernel_size=1, stride=2, padding=3, bias=False)
self.bn = torch.nn.BatchNorm2d(2)
def forward(self, x):
out = self.conv(x)
return self.bn(out)
model = Fuse()
x = torch.randn(2, 3, 2, 2, requires_grad=True)
self.run_test(model, (x,))
def test_fuse_conv_bn3d(self):
class Fuse(torch.nn.Module):
def __init__(self):
super(Fuse, self).__init__()
self.conv = torch.nn.Conv3d(3, 2, (3, 5, 2), stride=(2, 1, 1), padding=(3, 2, 0), bias=False)
self.bn = torch.nn.BatchNorm3d(2)
def forward(self, x):
out = self.conv(x)
return self.bn(out)
model = Fuse()
x = torch.randn(2, 3, 10, 50, 100, requires_grad=True)
self.run_test(model, (x,), rtol=1e-3, atol=1e-6)
def test_reshape_constant_fold(self):
class Reshape(torch.nn.Module):
def __init__(self, ):
super(Reshape, self).__init__()
self.register_buffer("weight", torch.ones(5))
def forward(self, x):
scale_1 = self.weight.reshape(1, -1, 1, 1)
return x * scale_1
x = torch.randn(4, 5)
self.run_test(Reshape(), (x,), rtol=1e-3, atol=1e-5)
def run_word_language_model(self, model_name):
ntokens = 50
emsize = 5
nhid = 5
nlayers = 5
dropout = 0.2
tied = False
batchsize = 5
model = word_language_model.RNNModel(model_name, ntokens, emsize,
nhid, nlayers, dropout, tied,
batchsize)
x = torch.arange(0, ntokens).long().view(-1, batchsize)
# Only the CPU version is supported, since the tracer does not work with GPU RNNs.
self.run_test(model, (x, model.hidden))
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest() # Faster RCNN model is not scriptable
def test_faster_rcnn(self):
model = torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200,
max_size=300)
model.eval()
x = torch.randn(2, 3, 200, 300, requires_grad=True)
self.run_test(model, (x,), rtol=1e-3, atol=1e-5)
def get_image_from_url(self, url):
import os
from urllib.parse import urlsplit
from urllib import request
from PIL import Image
from torchvision import transforms
from torch._utils_internal import get_writable_path
filename = os.path.basename(urlsplit(url)[2])
data_dir = get_writable_path(os.path.join(os.path.dirname(__file__)))
path = os.path.join(data_dir, filename)
data = request.urlopen(url, timeout=15).read()
with open(path, 'wb') as f:
f.write(data)
image = Image.open(path).convert("RGB")
image = image.resize((300, 200), Image.BILINEAR)
to_tensor = transforms.ToTensor()
return to_tensor(image)
def get_test_images(self):
image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
image = self.get_image_from_url(url=image_url)
images = [image]
return images
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_mask_rcnn(self):
model = torchvision.models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200,
max_size=300)
images = self.get_test_images()
self.run_test(model, (images,), rtol=1e-3, atol=1e-5)
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_keypoint_rcnn(self):
model = torchvision.models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200,
max_size=300)
images = self.get_test_images()
self.run_test(model, (images,), rtol=1e-3, atol=1e-5)
@disableScriptTest()
def test_word_language_model_RNN_TANH(self):
self.run_word_language_model("RNN_TANH")
@disableScriptTest()
def test_word_language_model_RNN_RELU(self):
self.run_word_language_model("RNN_RELU")
@disableScriptTest()
def test_word_language_model_LSTM(self):
self.run_word_language_model("LSTM")
@disableScriptTest()
def test_word_language_model_GRU(self):
self.run_word_language_model("GRU")
def test_index_1d(self):
class MyModel(torch.nn.Module):
def forward(self, input):
return input[0]
m1 = torch.randn(3, 4, 5, 6, 7)
self.run_test(MyModel(), m1)
def test_index_2d_1dimslice(self):
class MyModel(torch.nn.Module):
def forward(self, input):
return input[0:1, :]
m1 = torch.randn(3, 4, 5, 6, 7)
self.run_test(MyModel(), m1)
def test_index_2d_sliceint(self):
class MyModel(torch.nn.Module):
def forward(self, input):
return input[1, :]
m1 = torch.randn(3, 4, 5, 6, 7)
self.run_test(MyModel(), m1)
def test_index_2d_neg_slice(self):
class MyModel(torch.nn.Module):
def forward(self, input):
return input[0:-1, :]
m1 = torch.randn(3, 4, 5, 6, 7)
self.run_test(MyModel(), m1)
@skipIfUnsupportedMinOpsetVersion(9)
def test_index_mask(self):
class MyModel(torch.nn.Module):
def forward(self, input):
return input[torch.tensor([0, 1, 0], dtype=torch.uint8)]
m1 = torch.randn(3, 4, 5, 6, 7)
self.run_test(MyModel(), m1)
class MyModel(torch.nn.Module):
def forward(self, input):
return input[torch.tensor([0, 1, 0], dtype=torch.bool)]
m1 = torch.randn(3, 4, 5, 6, 7)
self.run_test(MyModel(), m1)
@disableScriptTest()
def test_dict(self):
class MyModel(torch.nn.Module):
def forward(self, x_in):
x_out = {}
x_out["test_key_out"] = torch.add(x_in[list(x_in.keys())[0]], list(x_in.keys())[0])
return x_out
x = {torch.tensor(1.): torch.randn(1, 2, 3)}
self.run_test(MyModel(), (x,))
@disableScriptTest()
def test_dict_str(self):
class MyModel(torch.nn.Module):
def forward(self, x_in):
x_out = {}
x_out["test_key_out"] = torch.add(x_in["test_key_in"], 2.)
return x_out
x = {"test_key_in": torch.randn(1, 2, 3)}
self.run_test(MyModel(), (x,))
@skipIfUnsupportedMinOpsetVersion(9)
def test_cste_script(self):
class MyModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
return torch.zeros(x.size(0)), torch.ones((x.size(1), x.size(0)), dtype=torch.int64)
x = torch.randn(3, 4)
self.run_test(MyModel(), x)
def test_scalar_tensor(self):
class test(torch.nn.Module):
def forward(self, input):
return torch.scalar_tensor(input.size(0)), \
torch.scalar_tensor(input.size(1), dtype=torch.int64)
x = torch.randn(2, 3, 4)
y = torch.randn(7, 8, 9)
model = test()
self.run_test(model, x, test_with_inputs=[y],
input_names=['input_1'],
dynamic_axes={'input_1': [0, 1, 2]})
def test_tensor(self):
class ScalarInputModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input):
return torch.tensor(input.shape[1])
x = torch.randn(3, 4)
self.run_test(ScalarInputModel(), x)
class TensorInputModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input):
return torch.tensor([input.shape[0], input.shape[1]])
x = torch.randn(3, 4)
self.run_test(TensorInputModel(), x)
class FloatInputModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input):
return torch.tensor([float(input)])
x = torch.randn(1)
self.run_test(FloatInputModel(), x)
class InputWithDtypeModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input):
return torch.tensor(input.shape[1], dtype=torch.long)
x = torch.randn(3, 4)
self.run_test(InputWithDtypeModel(), x)
class MixedInputModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input):
return torch.tensor([input.shape[0], int(input)])
x = torch.randn(1)
self.run_test(MixedInputModel(), x)
def test_hardtanh(self):
model = torch.nn.Hardtanh(-1.5, 2.5)
x = torch.arange(-5, 5).to(dtype=torch.float32)
self.run_test(model, x)
def test_hardtanh_script_with_default_values(self):
class MyModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
return torch.nn.functional.hardtanh(x)
x = torch.arange(-5, 5).to(dtype=torch.float32)
self.run_test(MyModel(), x)
def test_clamp(self):
class ClampModel(torch.nn.Module):
def forward(self, x):
return x.clamp(-0.5, 0.5)
x = torch.randn(3, 4)
self.run_test(ClampModel(), x)
class ClampMinModel(torch.nn.Module):
def forward(self, x):
return x.clamp(min=-0.5)
x = torch.randn(3, 4)
self.run_test(ClampMinModel(), x)
class ClampMaxModel(torch.nn.Module):
def forward(self, x):
return x.clamp(max=0.5)
x = torch.randn(3, 4)
self.run_test(ClampMaxModel(), x)
@skipIfUnsupportedMinOpsetVersion(11)
def test_clamp_dyn(self):
class ClampMaxModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
return x.clamp(None, x.size(0))
x = torch.arange(16).view(4, 4).float()
self.run_test(ClampMaxModel(), x)
class ClampMinModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
return x.clamp(x.size(0), None)
x = torch.arange(16).view(4, 4).float()
self.run_test(ClampMinModel(), x)
class ClampMinMaxModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
return x.clamp(x.size(0), x.size(1))
x = torch.arange(16).view(2, 8).float()
self.run_test(ClampMinMaxModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_full_trace(self):
class FullModel(torch.nn.Module):
def forward(self, x):
return torch.full((3, 4), x, dtype=torch.long)
x = torch.tensor(12)
self.run_test(FullModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_full_script(self):
class FullModelScripting(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
return torch.full((3, 4), x, dtype=torch.long)
x = torch.tensor(12)
self.run_test(FullModelScripting(), x)
def test_fuse_addmm(self):
class AddmmModel(torch.nn.Module):
def forward(self, x):
return torch.mm(x, x) + x
x = torch.ones(3, 3)
self.run_test(AddmmModel(), x)
def test_maxpool(self):
model = torch.nn.MaxPool1d(2, stride=1)
x = torch.randn(20, 16, 50)
self.run_test(model, x)
def test_conv(self):
class TraceModel(torch.nn.Module):
def __init__(self):
super(TraceModel, self).__init__()
self.conv1 = torch.nn.Conv1d(16, 33, 3, stride=2)
self.conv2 = torch.nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
self.conv3 = torch.nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))
def forward(self, input1, input2, input3):
return self.conv1(input1), self.conv2(input2), self.conv3(input3)
class ScriptModel(torch.jit.ScriptModule):
def __init__(self):
super(ScriptModel, self).__init__()
self.conv1 = torch.nn.Conv1d(16, 33, 3, stride=2)
self.conv2 = torch.nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
self.conv3 = torch.nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))
@torch.jit.script_method
def forward(self, input1, input2, input3):
return self.conv1(input1), self.conv2(input2), self.conv3(input3)
x1 = torch.randn(20, 16, 50)
x2 = torch.randn(20, 16, 50, 100)
x3 = torch.randn(20, 16, 10, 50, 100)
self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)
self.run_test(ScriptModel(), (x1, x2, x3), atol=10e-5)
def test_conv_shape_inference(self):
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv2 = torch.nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
def forward(self, input):
return self.conv2(input) + 2
x = torch.randn(20, 16, 50, 100)
self.run_test(Model(), x, atol=10e-5,
input_names=['x'],
dynamic_axes={'x': [0]})
def test_conv_transpose(self):
class TraceModel(torch.nn.Module):
def __init__(self):
super(TraceModel, self).__init__()
self.conv1 = torch.nn.ConvTranspose1d(16, 33, 3, stride=2)
self.conv2 = torch.nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
self.conv3 = torch.nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))
def forward(self, input1, input2, input3):
return self.conv1(input1), self.conv2(input2), self.conv3(input3)
class ScriptModel(torch.jit.ScriptModule):
def __init__(self):
super(ScriptModel, self).__init__()
self.conv1 = torch.nn.ConvTranspose1d(16, 33, 3, stride=2)
self.conv2 = torch.nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
self.conv3 = torch.nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))
@torch.jit.script_method
def forward(self, input1, input2, input3):
return self.conv1(input1), self.conv2(input2), self.conv3(input3)
x1 = torch.randn(20, 16, 50)
x2 = torch.randn(20, 16, 50, 100)
x3 = torch.randn(20, 16, 10, 50, 100)
self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)
self.run_test(ScriptModel(), (x1, x2, x3), atol=10e-5)
# Conversion of Transpose depends on the input shape being known.
# The following test only works when onnx shape inference is enabled.
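# (aten::transpose only swaps two dims, but ONNX Transpose takes a full
# `perm` attribute, so the exporter needs the tensor rank to build it.)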
@skipIfONNXShapeInference(False)
def test_transpose_infer_shape(self):
class TransposeModule(torch.jit.ScriptModule):
def __init__(self):
super(TransposeModule, self).__init__()
self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)
@torch.jit.script_method
def forward(self, x):
x = self.conv(x)
return x.transpose(0, 1)
x = torch.randn(32, 3, 64, 64)
self.run_test(TransposeModule(), x)
def squeeze_model_tests(self, d, x1, x2):
class Squeeze(torch.nn.Module):
def __init__(self, d):
super(Squeeze, self).__init__()
self.d = d
def forward(self, x):
if self.d is not None:
return torch.squeeze(x, dim=self.d)
else:
return torch.squeeze(x)
x2 = [] if x2 is None else [x2]
self.run_test(Squeeze(d), x1, input_names=['input'], dynamic_axes={'input': {0: '0', 1: '1', 2: '2'}}, test_with_inputs=x2)
def test_squeeze_without_no_op(self):
x = torch.randn(2, 1, 4)
self.squeeze_model_tests(1, x, None)
@skipIfUnsupportedMinOpsetVersion(11)
def test_squeeze(self):
x_squeeze = torch.randn(2, 1, 4)
x_noop = torch.randn(2, 2, 3)
self.squeeze_model_tests(1, x_squeeze, x_noop)
def test_squeeze_neg_without_no_op(self):
x = torch.randn(2, 1, 4)
self.squeeze_model_tests(-2, x, None)
@skipIfUnsupportedMinOpsetVersion(11)
def test_squeeze_neg(self):
x_squeeze = torch.randn(2, 1, 4)
x_noop = torch.randn(2, 2, 3)
self.squeeze_model_tests(-2, x_squeeze, x_noop)
def test_squeeze_all_dims(self):
x_squeeze = torch.randn(2, 1, 4)
x_noop = torch.randn(2, 2, 3)
self.squeeze_model_tests(None, x_squeeze, x_noop)
@skipIfUnsupportedMinOpsetVersion(11)
def test_squeeze_no_op(self):
x_noop = torch.randn(2, 1, 4)
x_squeeze = torch.randn(2, 2, 1)
self.squeeze_model_tests(2, x_noop, x_squeeze)
def test_squeeze_no_op_without_additional_inputs(self):
x_noop = torch.randn(2, 1, 4)
self.squeeze_model_tests(2, x_noop, None)
@skipIfUnsupportedMinOpsetVersion(11)
def test_squeeze_runtime_dim(self):
class Squeeze(torch.nn.Module):
def forward(self, d1, d2):
t = torch.zeros(d1[0], d2[0])
return t.squeeze(0)
d1 = torch.tensor([1])
d3 = torch.tensor([3])
d4 = torch.tensor([4])
self.run_test(Squeeze(), (d1, d4), test_with_inputs=[(d3, d4)])
self.run_test(Squeeze(), (d3, d4), test_with_inputs=[(d1, d3)])
def test_unsqueeze(self):
class Unsqueeze(torch.nn.Module):
def forward(self, x):
return torch.unsqueeze(x, dim=-2)
x = torch.randn(2, 3, 4)
self.run_test(Unsqueeze(), x)
def test_maxpool_default_stride(self):
class MaxPoolModel(torch.nn.Module):
def forward(self, x):
return torch.nn.functional.max_pool2d(x, 2)
model = MaxPoolModel()
x = torch.randn(10, 20, 16, 50)
self.run_test(model, x)
@skipIfUnsupportedMinOpsetVersion(8)
def test_maxpool_adaptive(self):
model = torch.nn.AdaptiveMaxPool1d((5), return_indices=False)
x = torch.randn(20, 16, 50, requires_grad=True)
self.run_test(model, x)
def test_maxpool_2d(self):
model = torch.nn.MaxPool2d(5, padding=(1, 2))
x = torch.randn(1, 20, 16, 50, requires_grad=True)
self.run_test(model, x)
def test_maxpool_1d_ceil(self):
model = torch.nn.MaxPool1d(3, 2, ceil_mode=True)
x = torch.randn(20, 16, 50)
self.run_test(model, x)
def test_maxpool_2d_ceil(self):
model = torch.nn.MaxPool2d(3, 2, ceil_mode=True)
x = torch.randn(20, 16, 50, 32)
self.run_test(model, x)
def test_maxpool_3d_ceil(self):
model = torch.nn.MaxPool3d(3, 2, ceil_mode=True)
x = torch.randn(20, 16, 50, 44, 31)
self.run_test(model, x)
@skipIfUnsupportedMinOpsetVersion(8)
@disableScriptTest() # Functional module not scriptable
def test_maxpool_with_indices(self):
model = torch.nn.MaxPool1d(2, stride=1, return_indices=True)
x = torch.randn(20, 16, 50)
self.run_test(model, x)
@skipIfUnsupportedMinOpsetVersion(10)
def test_maxpool_dilation(self):
model = torch.nn.MaxPool1d(2, stride=1, dilation=2)
x = torch.randn(20, 16, 50)
self.run_test(model, x)
def test_avgpool_default_stride(self):
class AvgPoolModel(torch.nn.Module):
def forward(self, x):
return torch.nn.functional.avg_pool2d(x, 2)
model = AvgPoolModel()
x = torch.randn(10, 20, 16, 50)
self.run_test(model, x)
def test_avgpool(self):
model = torch.nn.AvgPool1d(2, stride=1)
x = torch.randn(20, 16, 50)
self.run_test(model, x)
def test_avgpool_1d_ceil(self):
model = torch.nn.AvgPool1d(3, 2, ceil_mode=True)
x = torch.randn(1, 1, 7)
self.run_test(model, x)
def test_avgpool_2d_ceil(self):
model = torch.nn.AvgPool2d(3, 2, ceil_mode=True)
x = torch.randn(20, 16, 50, 32)
self.run_test(model, x)
def test_avgpool_3d_ceil(self):
model = torch.nn.AvgPool3d(3, 2, ceil_mode=True)
x = torch.randn(20, 16, 50, 44, 31)
self.run_test(model, x)
def test_arithmetic(self):
class ArithmeticModule(torch.nn.Module):
def forward(self, x):
x = x + 2
x = x - 4
x = x * 6
x = x / 8
return x
x = torch.randn(2, 3, 4)
self.run_test(ArithmeticModule(), x)
# In scripting the first transpose node does not carry shape and dtype info.
# The following test only works when onnx shape inference is enabled.
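# (Without the dtype of the transposed value, the exporter cannot pick the
# correct type promotion for the scalar arithmetic that follows.)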
@skipIfONNXShapeInference(False)
def test_arithmetic_infer_dtype(self):
class ArithmeticModule(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
x = x.t()
x = x + 2
x = x - 4
x = x * 6
x = x / 8
return x
x = torch.randn(2, 3)
self.run_test(ArithmeticModule(), x)
def test_floor_div(self):
class FloorDivModule(torch.nn.Module):
def forward(self, x, y):
return x // 3, x // 2., \
x.to(dtype=torch.float64) // 3, x.to(dtype=torch.float64) // 2., \
x.to(dtype=torch.int64) // 3, x.to(dtype=torch.int64) // 2., \
x // (y + 1.).to(dtype=torch.int64), x // y, \
x.to(dtype=torch.float64) // y.to(dtype=torch.int64), x.to(dtype=torch.float64) // y.to(dtype=torch.float64), \
x.to(dtype=torch.int64) // y.to(dtype=torch.int64), x.to(dtype=torch.int64) // y
x = torch.randn(2, 3, 4)
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4)
self.run_test(FloorDivModule(), (x, y))
def test_floor_div_script(self):
class FloorDivModule(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x, y):
return x // 3, x // 2., x // y
x = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)
self.run_test(FloorDivModule(), (x, y))
@skipIfUnsupportedMinOpsetVersion(9)
def test_floordiv(self):
class FloordivModule(torch.nn.Module):
def forward(self, x):
return x.new_zeros(x.size(2) // x.size(1))
x = torch.randn(2, 3, 4)
self.run_test(FloordivModule(), (x,))
def test_div(self):
class DivModule(torch.nn.Module):
def forward(self, x, y):
return x / y, torch.true_divide(x, y)
x = torch.randn(2, 3, 4).to(torch.int)
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
self.run_test(DivModule(), (x, y))
self.run_test(DivModule(), (x.float(), y.float()))
# Note: div cannot (generally) be exported via scripting
# since its type promotion logic depends on knowing the scalar types
# of the input tensors. That is, the ONNX graph depends on the
# data type of the inputs. This makes it appropriate for tracing only.
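# For example, int / int promotes to the default float dtype
# (torch.tensor(3) / torch.tensor(2) == tensor(1.5)), so the exported
# graph must hard-code Casts matching the dtypes seen at trace time.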
def test_div_promotion_trace(self):
class DivModule(torch.nn.Module):
def forward(self, x, y):
return x / y, torch.true_divide(x, y)
x = torch.randn(2, 3, 4).to(torch.int)
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
prev_default = torch.get_default_dtype()
torch.set_default_dtype(torch.float)
self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))
torch.set_default_dtype(torch.double)
self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))
torch.set_default_dtype(prev_default)
# In scripting x, y do not carry shape and dtype info.
# The following test only works when onnx shape inference is enabled.
@skipIfONNXShapeInference(False)
def test_div_promotion_script(self):
class DivModule(torch.nn.Module):
def forward(self, x, y):
# Add transpose to hide shape/type information.
# Otherwise shape and type are still available from input.
x = x.transpose(1, 2)
y = y.transpose(1, 2)
return x / y, torch.true_divide(x, y)
x = torch.randn(2, 3, 4).to(torch.int)
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
prev_default = torch.get_default_dtype()
# 1. x, y are int, and output is float.
# This can be handled by the default case, where both are cast to float.
# It works even if the types of x and y are unknown.
torch.set_default_dtype(torch.float)
self.run_test(torch.jit.script(DivModule()), (x, y))
# 2. x, y are int, and output is double.
# This can be handled by the default case, where both are cast to double.
# It works even if the types of x and y are unknown.
torch.set_default_dtype(torch.double)
self.run_test(torch.jit.script(DivModule()), (x, y))
# 3. x is int, y is double, and output is double.
# This can only be handled when the types of both x and y are known.
torch.set_default_dtype(prev_default)
x = torch.randn(2, 3, 4).to(torch.int)
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double)
self.run_test(torch.jit.script(DivModule()), (x, y))
def test_slice_trace(self):
class MyModule(torch.nn.Module):
def forward(self, x):
return x[0:1]
x = torch.randn(3)
self.run_test(MyModule(), x)
def test_slice_neg(self):
class NegSlice(torch.nn.Module):
def forward(self, x):
return x[-1:]
x = torch.randn(3, 4, 5)
self.run_test(NegSlice(), x)
def test_slice_neg_large(self):
class NegSlice(torch.nn.Module):
def forward(self, x):
return x[:, :, -3:-1, :, -1]
x = torch.randn(3, 4, 5, 6, 7)
self.run_test(NegSlice(), x)
def test_slice_neg_large_negone(self):
class NegSlice(torch.nn.Module):
def forward(self, x):
return x[:, :, :, :, -1]
x = torch.randn(3, 4, 5, 6, 7)
self.run_test(NegSlice(), x)
@skipIfUnsupportedMinOpsetVersion(11)
def test_slice_with_input_index(self):
class InputIndexSlice(torch.nn.Module):
def forward(self, x, y):
x[:y.size(0), 0, :] = y
return x
x = torch.zeros((56, 6, 256))
y = torch.rand((22, 256))
self.run_test(InputIndexSlice(), (x, y))
@skipIfUnsupportedMinOpsetVersion(10)
@disableScriptTest() # scripting tuple/list append
def test_slice_dynamic(self):
class DynamicSliceExportMod(torch.nn.Module):
def forward(self, x):
results = []
for i in range(4):
results.append(x[:x.size(0) - i, i:x.size(2), i:3])
return tuple(results)
x = torch.rand(5, 5, 5)
y = torch.randn(6, 7, 8)
self.run_test(DynamicSliceExportMod(), x, test_with_inputs=[y],
input_names=['input_1'],
output_names=['output_1'],
dynamic_axes={'input_1': [0, 1, 2],
'output_1': [0, 1, 2]})
@skipIfUnsupportedMinOpsetVersion(10)
def test_slice_dynamic_script(self):
class DynamicSliceModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
return x[1:x.size(1)]
x = torch.rand(1, 2)
self.run_test(DynamicSliceModel(), x)
@skipIfUnsupportedMinOpsetVersion(10)
def test_slice_dynamic_shape_script(self):
class DynamicSliceModel(torch.nn.Module):
def forward(self, x):
return x.new_zeros(x.shape[1:x.size(2)])
x = torch.rand(1, 2, 3, 4)
self.run_test(DynamicSliceModel(), x)
@skipIfUnsupportedMinOpsetVersion(10)
@disableScriptTest() # scripting tuple/list append
def test_slice_dynamic_to_end(self):
class DynamicSliceExportMod(torch.nn.Module):
def forward(self, x):
results = []
for i in range(4):
results.append(x[:, i:, x.size(2) - 5])
return tuple(results)
x = torch.rand(5, 5, 5)
self.run_test(DynamicSliceExportMod(), x,
dynamic_axes={'input_1': [0, 1, 2],
'output_1': [0, 1, 2]})
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_dynamic(self):
class ArangeModel(torch.nn.Module):
def forward(self, input):
return torch.arange(input.shape[0]), \
torch.arange(12), \
torch.arange(start=input.shape[0], end=input.shape[0] + 5)
x = torch.randn(5, 3, 2)
y = torch.randn(8, 3, 2)
self.run_test(ArangeModel(), x, test_with_inputs=[y],
input_names=['input_1'],
output_names=['output_1', 'output_2', 'output_3'],
dynamic_axes={'input_1': [0],
'output_1': [0]})
self.run_test(torch.jit.script(ArangeModel()), x,
test_with_inputs=[y], input_names=['input_1'],
output_names=['output_1', 'output_2', 'output_3'],
dynamic_axes={'input_1': [0],
'output_1': [0]})
@skipIfUnsupportedMinOpsetVersion(9)
def test_dynamic_arange_out(self):
class ArangeOutModel(torch.nn.Module):
def forward(self, end):
out_t = torch.tensor([1], dtype=torch.int64)
return torch.arange(end, out=out_t)
x = torch.tensor(8)
self.run_test(ArangeOutModel(), (x))
@skipIfUnsupportedMinOpsetVersion(9)
def test_dynamic_arange_start_out(self):
class ArangeStartOutModel(torch.nn.Module):
def forward(self, start, end):
out_t = torch.tensor([1], dtype=torch.int64)
return torch.arange(start.size(0), end, out=out_t)
x = torch.randn(2, 3, 4)
y = torch.tensor(8)
self.run_test(ArangeStartOutModel(), (x, y))
@skipIfUnsupportedMinOpsetVersion(11)
def test_arange(self):
class ArangeModel(torch.nn.Module):
def forward(self, start, end):
return torch.arange(start.size(0), end, 1.5, dtype=torch.int64)
x = torch.randn(2, 3, 4)
y = torch.tensor(8.5, dtype=torch.float)
self.run_test(ArangeModel(), (x, y))
@skipIfUnsupportedMinOpsetVersion(11)
def test_arange_out(self):
class ArangeOutModel(torch.nn.Module):
def forward(self, end):
out_t = torch.tensor([1], dtype=torch.float)
return torch.arange(end, out=out_t)
x = torch.tensor(8.5, dtype=torch.float)
self.run_test(ArangeOutModel(), (x))
@skipIfUnsupportedMinOpsetVersion(11)
def test_arange_start_out(self):
class ArangeStartOutModel(torch.nn.Module):
def forward(self, start, end):
out_t = torch.tensor([1], dtype=torch.float)
return torch.arange(start.size(0), end, out=out_t)
x = torch.randn(2, 3, 4)
y = torch.tensor(8.5, dtype=torch.float)
self.run_test(ArangeStartOutModel(), (x, y))
@skipIfUnsupportedMinOpsetVersion(11)
def test_arange_no_type(self):
class ArangeModel(torch.nn.Module):
def forward(self, end):
return torch.arange(end), \
torch.arange(0, end)
x = torch.tensor(6.2, dtype=torch.float)
self.run_test(ArangeModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_size(self):
class SizeModel(torch.nn.Module):
def forward(self, input):
return torch.arange(input.size(0)), torch.arange(input.size(-1)), torch.ones(input.shape)
x = torch.randn(5, 3, 2)
self.run_test(SizeModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest() # x.stride() not scriptable
def test_as_strided(self):
class Model(torch.nn.Module):
def forward(self, x):
chunk_size = list(x.size())
chunk_size[1] = chunk_size[1] * 2 - 1
chunk_stride = list(x.stride())
chunk_stride[1] = chunk_stride[1] // 2
return x.as_strided((3, 3, 3), (1, 4, 2), storage_offset=2), x.as_strided(chunk_size, chunk_stride)
x = torch.randn(5, 8, 7)
self.run_test(Model(), x)
@disableScriptTest() # Ellipses followed by tensor indexing not scriptable
def test_tensor_index_advanced_indexing_ellipsis(self):
class MyModel(torch.nn.Module):
def forward(self, input):
return input[..., torch.tensor([2, 1]), torch.tensor([0, 3])]
m1 = torch.randn(3, 4, 5, 6, 7)
self.run_test(MyModel(), (m1,))
def test_tensor_index_advanced_indexing(self):
class MyModel(torch.nn.Module):
def forward(self, input):
return input[:, torch.tensor([[0, 2], [1, 1]]), :, torch.tensor([2, 1]), torch.tensor([0, 3])]
m1 = torch.randn(3, 4, 5, 6, 7)
self.run_test(MyModel(), (m1,))
class MyModel(torch.nn.Module):
def forward(self, input):
return input[:, torch.tensor([0, 2]), None, 2:4, torch.tensor([[1, 3], [4, 0]])]
self.run_test(MyModel(), (m1,))
class MyModel(torch.nn.Module):
def forward(self, input):
return input[:, torch.tensor([0, 2]), torch.tensor([1]), 2:4, torch.tensor([[1], [4]])]
self.run_test(MyModel(), (m1,))
def test_tensor_index_advanced_indexing_consecutive(self):
class MyModel(torch.nn.Module):
def forward(self, input):
return input[:, torch.tensor([0, 2]), torch.tensor([[1, 3], [4, 0]]), None]
m1 = torch.randn(3, 4, 5, 6, 7)
self.run_test(MyModel(), (m1,))
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put(self):
class IndexPutModel(torch.nn.Module):
def forward(self, x, ind, update):
x[ind] = update
return x
x = torch.randn(3, 4)
ind = torch.tensor([1], dtype=torch.long)
update = torch.ones(4)
self.run_test(IndexPutModel(), (x, ind, update))
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_accumulate(self):
class IndexPutModel(torch.nn.Module):
def forward(self, x, ind, update):
return x.index_put((ind, ), update, accumulate=True)
x = torch.randn(3, 4)
ind = torch.tensor([2], dtype=torch.long)
update = torch.ones(4)
self.run_test(IndexPutModel(), (x, ind, update))
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_slice_index(self):
class IndexPutModel(torch.nn.Module):
def forward(self, x, update):
x[1:2, 1:3, torch.tensor([1])] += update
return x
x = torch.randn(3, 4, 5)
update = torch.tensor([10, 15]).view(1, 2, 1)
self.run_test(IndexPutModel(), (x, update))
class IndexPutModel2(torch.nn.Module):
def forward(self, x, update):
x[torch.tensor([0, 2]), torch.tensor([1, 2])] += update
return x
x = torch.randn(3, 4, 5)
update = torch.randn(2, 5)
self.run_test(IndexPutModel2(), (x, update))
class IndexPutModel3(torch.nn.Module):
def forward(self, x, update):
x[torch.tensor([0, 2]), 1:2] += update
return x
x = torch.randn(3, 4, 5)
update = torch.tensor([10, 15]).view(2, 1, 1)
self.run_test(IndexPutModel3(), (x, update))
class IndexPutModel4(torch.nn.Module):
def forward(self, x, update):
x[torch.tensor([0, 2]), 2] += update
return x
x = torch.randn(3, 4, 5)
update = torch.tensor([10, 15]).view(2, 1)
self.run_test(IndexPutModel4(), (x, update))
class IndexPutModel5(torch.nn.Module):
def forward(self, x, update):
x[1:3, torch.tensor([0, 2]), 2] += update
return x
x = torch.randn(3, 4, 5)
update = torch.tensor([10, 15]).view(2, 1)
self.run_test(IndexPutModel5(), (x, update))
class IndexPutModel6(torch.nn.Module):
def forward(self, x, update):
x[1:3, 0] = update
return x
x = torch.randn(3, 4, 5)
update = torch.arange(2 * 5).to(torch.float).view(2, 5)
self.run_test(IndexPutModel6(), (x, update))
class IndexPutModel7(torch.nn.Module):
def forward(self, x, update):
x[1:, 0] = update
return x
x = torch.randn(3, 4, 5)
update = torch.arange(2 * 5).to(torch.float).view(2, 5)
self.run_test(IndexPutModel7(), (x, update))
class IndexPutModel8(torch.nn.Module):
def forward(self, x, update):
x[:3, 0] = update
return x
x = torch.randn(3, 4, 5)
update = torch.arange(3 * 5).to(torch.float).view(3, 5)
self.run_test(IndexPutModel8(), (x, update))
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest() # Ellipses followed by tensor indexing not scriptable
def test_index_put_ellipsis(self):
class IndexPutModel(torch.nn.Module):
def forward(self, x, update):
x[..., torch.tensor([2, 1, 3]), 2:4] += update
return x
x = torch.randn(3, 4, 5, 6, 7)
update = torch.randn(3, 1, 1, 3, 2)
self.run_test(IndexPutModel(), (x, update))
class IndexPutModel2(torch.nn.Module):
def forward(self, x, update):
x[2, ..., torch.tensor([2, 1, 3]), 2:4] += update
return x
x = torch.randn(3, 4, 5, 6, 7)
update = torch.randn(4, 1, 3, 2)
self.run_test(IndexPutModel2(), (x, update))
@skipIfUnsupportedMinOpsetVersion(11)
def test_copy_(self):
class CopyModel(torch.nn.Module):
def forward(self, x, data):
x[1:3] = data
return x
x = torch.randn(3, 4)
update = torch.randn(2, 4)
self.run_test(CopyModel(), (x, update))
# mixed slice and select
class CopyModel2(torch.nn.Module):
def forward(self, x, data):
x[1:3, 0] = data
return x
x = torch.randn(3, 4)
update = torch.tensor([0], dtype=torch.float32)
self.run_test(CopyModel2(), (x, update))
update = torch.tensor([2, 3], dtype=torch.float32)
self.run_test(CopyModel2(), (x, update))
update = torch.randn(2)
self.run_test(CopyModel2(), (x, update))
class CopyModel3(torch.nn.Module):
def forward(self, x, data):
x[1, 1:3] = data
return x
x = torch.randn(3, 4)
update = torch.tensor([0], dtype=torch.float32)
self.run_test(CopyModel3(), (x, update))
update = torch.tensor([2, 3], dtype=torch.float32)
self.run_test(CopyModel3(), (x, update))
update = torch.randn(2)
self.run_test(CopyModel3(), (x, update))
class CopyModel4(torch.nn.Module):
def forward(self, x, ind, data):
x[ind] = data
return x
x = torch.randn(3, 4)
ind = torch.tensor(2)
data = torch.randn(4)
self.run_test(CopyModel4(), (x, ind, data))
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest() # Model not scriptable (output with shape doesn't match the broadcast shape)
def test_copy_tracing(self):
class CopyModel(torch.nn.Module):
def forward(self, x, data):
x[1, 1:3] = data
return x
x = torch.randn(3, 4)
update = torch.randn(1, 2)
self.run_test(CopyModel(), (x, update))
@skipIfUnsupportedMinOpsetVersion(11)
def test_copy_ellipsis(self):
class CopyModel(torch.nn.Module):
def forward(self, x, update):
x[..., 1] = update
return x
x = torch.randn(2, 3, 4)
update = torch.ones(1)
self.run_test(CopyModel(), (x, update))
x = torch.randn(2, 3, 4, 5, 6)
update = torch.ones(1)
self.run_test(CopyModel(), (x, update))
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest() # Missing input size (with ellipsis indexing)
def test_copy_ellipsis_tracing(self):
class CopyModel(torch.nn.Module):
def forward(self, x, update):
x[2, ..., 1:3] = update
return x
x = torch.randn(3, 4, 5, 6)
update = torch.ones(1)
self.run_test(CopyModel(), (x, update))
@skipIfUnsupportedMinOpsetVersion(10)
def test_flip(self):
class MyModule(torch.nn.Module):
def forward(self, x):
return torch.flip(x, dims=[0])
x = torch.tensor(np.arange(6.0).reshape(2, 3))
self.run_test(MyModule(), x)
def test_random(self):
class RandN(torch.nn.Module):
def forward(self, x):
return torch.mul(x, (torch.randn(2, 3, 4) + x).size(0))
x = torch.randn(2, 3, 4)
self.run_test(RandN(), x)
class Rand(torch.nn.Module):
def forward(self, x):
return torch.mul(x, (torch.rand(2, 3, 4) + x).size(0))
x = torch.randn(2, 3, 4)
self.run_test(Rand(), x)
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest() # symbolic update for randn
def test_random_dynamic_size(self):
class RandN(torch.nn.Module):
def forward(self, x):
return torch.mul(x, torch.randn(x.size()).size(1))
x = torch.randn(2, 3, 4)
self.run_test(RandN(), x)
class Rand(torch.nn.Module):
def forward(self, x):
return torch.mul(x, torch.rand(x.size()).size(1))
x = torch.randn(2, 3, 4)
self.run_test(Rand(), x)
def test_random_like(self):
class RandNLike(torch.nn.Module):
def forward(self, x):
return torch.mul(x, torch.randn_like(x).size(0))
x = torch.randn(2, 3, 4)
self.run_test(RandNLike(), x)
self.run_test(torch.jit.script(RandNLike()), x)
class RandLike(torch.nn.Module):
def forward(self, x):
return torch.mul(x, torch.rand_like(x).size(0))
x = torch.randn(2, 3, 4)
self.run_test(RandLike(), x)
self.run_test(torch.jit.script(RandLike()), x)
def test_random_like_dtype(self):
class RandNLike(torch.nn.Module):
def forward(self, x):
return torch.mul(x.to(torch.double), torch.randn_like(x, dtype=torch.double).size(0))
x = torch.randn(2, 3, 4)
self.run_test(RandNLike(), x)
class RandLike(torch.nn.Module):
def forward(self, x):
return torch.mul(x.to(torch.double), torch.rand_like(x, dtype=torch.double).size(0))
x = torch.randn(2, 3, 4)
self.run_test(RandLike(), x)
def _interpolate(self, x, mode, use_size, is_upsample, align_corners=False):
class MyModel(torch.nn.Module):
def forward(self, x):
scale = 2.3 if is_upsample else 0.5
if len(x.size()) == 3:
scale_array = 2.3
if len(x.size()) == 4:
scale_array = [2.3, 5.1]
if len(x.size()) == 5:
scale_array = [3.3, 2.3, 5.1]
if use_size:
size_array = [int(float(v) * scale) for v in x.size()[2:]]
if align_corners:
return torch.nn.functional.interpolate(x, mode=mode, size=size_array[0], align_corners=True), \
torch.nn.functional.interpolate(x, mode=mode, size=size_array, align_corners=True)
return torch.nn.functional.interpolate(x, mode=mode, size=size_array[0]), \
torch.nn.functional.interpolate(x, mode=mode, size=size_array)
if align_corners:
return torch.nn.functional.interpolate(x, mode=mode, scale_factor=scale,
align_corners=True, recompute_scale_factor=False), \
torch.nn.functional.interpolate(x, mode=mode, scale_factor=scale_array,
align_corners=True, recompute_scale_factor=False)
return torch.nn.functional.interpolate(x, mode=mode,
scale_factor=scale, recompute_scale_factor=False), \
torch.nn.functional.interpolate(x, mode=mode,
scale_factor=scale_array, recompute_scale_factor=False)
self.run_test(MyModel(), x)
def _interpolate_script(self, x, mode, use_size, is_upsample, align_corners=False):
class MyModel(torch.jit.ScriptModule):
__constants__ = ['mode', 'use_size', 'is_upsample', 'size', 'scale', 'size_array', 'scale_array', 'align_corners']
def __init__(self, mode, use_size, is_upsample, align_corners):
super(MyModel, self).__init__()
self.mode = mode
self.use_size = use_size
self.is_upsample = is_upsample
self.align_corners = align_corners
self.scale = 2.0 if self.is_upsample else 0.5
self.size = 24 if self.is_upsample else 2
if x.dim() == 3:
self.scale_array = [2.3]
self.size_array = [16]
elif x.dim() == 4:
self.scale_array = [2.3, 3.1]
self.size_array = [16, 32]
else:
self.scale_array = [2.3, 3.1, 4.6]
self.size_array = [16, 32, 64]
@torch.jit.script_method
def forward(self, x):
if self.use_size:
if self.align_corners:
return torch.nn.functional.interpolate(x, mode=self.mode, size=self.size, align_corners=True), \
torch.nn.functional.interpolate(x, mode=self.mode, size=self.size_array, align_corners=True)
return torch.nn.functional.interpolate(x, mode=self.mode, size=self.size), \
torch.nn.functional.interpolate(x, mode=self.mode, size=self.size_array)
if self.align_corners:
return torch.nn.functional.interpolate(x, mode=self.mode,
scale_factor=self.scale, recompute_scale_factor=False), \
torch.nn.functional.interpolate(x, mode=self.mode,
scale_factor=self.scale_array, recompute_scale_factor=False)
return torch.nn.functional.interpolate(x, mode=self.mode,
scale_factor=self.scale, recompute_scale_factor=False), \
torch.nn.functional.interpolate(x, mode=self.mode,
scale_factor=self.scale_array, recompute_scale_factor=False)
model = MyModel(mode, use_size, is_upsample, align_corners)
self.run_test(model, x, atol=1e-6)
def _interpolate_tests(self, is_upsample):
# - cubic mode is not supported for opsets below 11;
# - linear mode does not match for opsets below 11;
modes = ["nearest", "linear", "bicubic"]
if self.opset_version < 11:
modes = ["nearest"]
x = [torch.randn(1, 2, 6, requires_grad=True),
torch.randn(1, 2, 4, 6, requires_grad=True),
torch.randn(1, 2, 4, 4, 6, requires_grad=True)]
for mode in modes:
for xi in x:
mode_i = mode
# TODO: enable bicubic downsample when ORT precision loss fixed
if mode == "bicubic" and xi.dim() != 4:
continue
elif mode == "linear":
if xi.dim() == 3:
# TODO : enable when linear mode is implemented for 1d inputs in ORT
continue
elif xi.dim() == 4:
mode_i = "bilinear"
elif xi.dim() == 5:
# TODO : enable when linear mode is implemented for 3d inputs in ORT
mode_i = "trilinear"
continue
self._interpolate(xi, mode_i, True, is_upsample)
# test with align_corners if supported
if mode != 'nearest':
self._interpolate(xi, mode_i, True, is_upsample, True)
self._interpolate_script(xi, mode_i, True, is_upsample, True)
# The following cases require dynamic sizes/scales,
# which is not supported for opset_version < 9.
if self.opset_version >= 9:
self._interpolate_script(xi, mode_i, True, is_upsample)
self._interpolate(xi, mode_i, False, is_upsample)
# test with align_corners if supported
if mode != 'nearest':
self._interpolate(xi, mode_i, False, is_upsample, True)
self._interpolate_script(xi, mode_i, False, is_upsample, True)
self._interpolate_script(xi, mode_i, False, is_upsample)
@disableScriptTest()
def test_interpolate_upsample(self):
self._interpolate_tests(True)
@disableScriptTest()
@skipIfUnsupportedMinOpsetVersion(9)
def test_interpolate_function_substitution(self):
class ScriptModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
return torch.nn.functional.interpolate(x, mode="nearest", scale_factor=2.)
class ScriptModule(torch.jit.ScriptModule):
def __init__(self):
super(ScriptModule, self).__init__()
self.submodule = ScriptModel()
@torch.jit.script_method
def forward(self, input):
return self.submodule(input)
x = torch.randn(1, 2, 4, 4, 6)
self.run_test(ScriptModule(), (x,))
@torch.jit.script
def script_method(x):
return torch.nn.functional.interpolate(x, mode="nearest", scale_factor=2.)
class TracingModule(torch.nn.Module):
def forward(self, x):
return script_method(x)
self.run_test(TracingModule(), (x,))
@skipIfUnsupportedMinOpsetVersion(10)
@disableScriptTest()
def test_interpolate_downsample(self):
self._interpolate_tests(False)
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_interpolate_no_shape(self):
class MyModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x, y):
x = torch.add(x, x)
out1 = torch.nn.functional.interpolate(x, mode="bilinear", size=(16, 16), align_corners=False)
out2 = torch.nn.functional.interpolate(x, mode="nearest", size=(int(y.size(0)), int(y.size(1))))
return out1, out2
x = torch.randn(1, 2, 4, 4, requires_grad=True)
y = torch.randn(16, 16, requires_grad=True)
self.run_test(MyModel(), (x, y))
@disableScriptTest()
def test_groupnorm(self):
model = torch.nn.GroupNorm(3, 6, 0.002)
x = torch.randn(4, 6, 180, 180, 180)
self.run_test(model, x)
model = torch.nn.GroupNorm(1, 6, 0.002)
x = torch.randn(4, 6, 180, 180)
self.run_test(model, x)
model = torch.nn.GroupNorm(6, 6, 0.002)
x = torch.randn(4, 6, 180, 180)
self.run_test(model, x)
@disableScriptTest()
def test_groupnorm_noaffine(self):
model = torch.nn.GroupNorm(4, 8, 0.002, affine=False)
x = torch.randn(3, 8, 224, 224)
self.run_test(model, x)
model = torch.nn.GroupNorm(1, 6, 0.002, affine=False)
x = torch.randn(4, 6, 180, 180)
self.run_test(model, x)
model = torch.nn.GroupNorm(6, 6, 0.002, affine=False)
x = torch.randn(4, 6, 180, 180)
self.run_test(model, x)
def test_std(self):
class StandardDeviation(torch.nn.Module):
def forward(self, input):
return torch.std(input, unbiased=False)
x = torch.randn(2, 3, 4)
model = StandardDeviation()
self.run_test(model, x)
def test_pow(self):
class PowModule(torch.nn.Module):
def forward(self, x, y):
return x.pow(y)
x = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)
self.run_test(PowModule(), (x, y))
x = torch.randint(10, (2, 3, 4))
y = torch.randint(10, (2, 3, 4)).to(dtype=torch.int32)
self.run_test(PowModule(), (x, y))
x = torch.randint(10, (2, 3, 4))
y = torch.randint(10, (2, 3, 4))
self.run_test(PowModule(), (x, y))
x = torch.randn(2, 3, 4).to(dtype=torch.float64)
y = torch.randint(10, (2, 3, 4))
self.run_test(PowModule(), (x, y))
def test_std_along_dims(self):
class StandardDeviation(torch.nn.Module):
def forward(self, input):
return torch.std(input, dim=(0, 1), unbiased=False)
x = torch.randn(2, 3, 4)
model = StandardDeviation()
self.run_test(model, x)
def test_std_keepdim(self):
class StandardDeviation(torch.nn.Module):
def forward(self, input):
return torch.std(input, dim=(0, 1), unbiased=False, keepdim=True)
x = torch.randn(2, 3, 4)
model = StandardDeviation()
self.run_test(model, x)
def test_bitshift(self):
class BitshiftModel(torch.nn.Module):
def forward(self, input, input2):
return input >> 1, input << 3.1, \
input2 >> torch.tensor([1, 2]), input2 << 4.2
input = torch.arange(24, dtype=torch.float32).reshape(3, 4, 2)
input2 = torch.arange(24, dtype=torch.int64).reshape(3, 4, 2)
self.run_test(BitshiftModel(), (input, input2))
def test_bitshift_other_fp(self):
class BitshiftModel(torch.nn.Module):
def forward(self, input):
return input << 2.4
input = torch.arange(24, dtype=torch.int64).reshape(3, 4, 2)
self.run_test(BitshiftModel(), input)
# uint8 is not implemented in ORT for the Mul op used when
# exporting bitshift for opset_version < 10
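# (For older opsets the exporter lowers `x << k` roughly to Mul(x, 2**k)
# and `x >> k` to Div(x, 2**k), hence the dependence on Mul support.)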
@skipIfUnsupportedMinOpsetVersion(11)
def test_bitshift_uint8(self):
class BitshiftModel(torch.nn.Module):
def forward(self, input, input2):
return input >> 1, input << 3., \
input2 >> torch.tensor([1, 2], dtype=torch.uint8), input2 << 4.
input = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
self.run_test(BitshiftModel(), (input, input2))
def test_narrow(self):
class NarrowModel(torch.nn.Module):
def forward(self, input):
return torch.narrow(input, 0, 0, 2)
x = torch.randn(3, 3, requires_grad=True)
self.run_test(NarrowModel(), x)
@skipIfUnsupportedMinOpsetVersion(11)
def test_narrow_dynamic(self):
class NarrowModel(torch.nn.Module):
def forward(self, input):
return torch.narrow(input, 0, 0, input.shape[0] - 1)
x = torch.randn(3, 3, requires_grad=True)
self.run_test(NarrowModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_index_fill(self):
class IndexFillModel(torch.nn.Module):
def forward(self, input):
index = torch.tensor([2, 0])
return input.index_fill(2, index, -1)
x = torch.randn(3, 4, 5, requires_grad=True)
self.run_test(IndexFillModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_index_copy(self):
class IndexCopyModel(torch.nn.Module):
def forward(self, input):
index = torch.tensor([2, 0])
source = torch.ones(3, 2, 5)
return input.index_copy(1, index, source)
x = torch.randn(3, 4, 5, requires_grad=True)
self.run_test(IndexCopyModel(), x)
def test_select(self):
class Select(torch.nn.Module):
def forward(self, x):
return x[:, 1]
x = torch.randn(3, 4)
self.run_test(Select(), x)
def test_select_negative_index(self):
class Select(torch.nn.Module):
def forward(self, x):
return x[:, -1]
x = torch.randn(3, 4)
self.run_test(Select(), x)
# TODO: enable for opset 10 when the ONNXRuntime version is updated
def test_index_select_constant_scaler_index(self):
class IndexSelectScalerIndexModel(torch.nn.Module):
def forward(self, x):
index = 2
return torch.index_select(x, 1, torch.tensor(index))
x = torch.randn(3, 4)
self.run_test(IndexSelectScalerIndexModel(), x)
def test_index_select_scaler_index(self):
class IndexSelectScalerIndexModel(torch.nn.Module):
def __init__(self, index_base):
super(IndexSelectScalerIndexModel, self).__init__()
self.index_base = torch.tensor(index_base)
def forward(self, x, index_offset):
index = self.index_base + index_offset
return torch.index_select(x, 1, index)
x = torch.randn(3, 4)
offset = 2
index_offset = torch.tensor(offset)
base = 1
self.run_test(IndexSelectScalerIndexModel(base), (x, index_offset))
def test_take(self):
class TakeModel(torch.nn.Module):
def forward(self, x, y):
return torch.take(x, y)
x = torch.randn(6, 4, 3, 3)
y = torch.tensor([4, 1, 7, 15, 63])
self.run_test(TakeModel(), (x, y))
def test_topk(self):
class MyModule(torch.nn.Module):
def forward(self, x):
return torch.topk(x, 3)
x = torch.arange(1., 6., requires_grad=True)
self.run_test(MyModule(), x)
@skipIfUnsupportedMinOpsetVersion(11)
def test_topk_smallest_unsorted(self):
class MyModule(torch.nn.Module):
def forward(self, x, k):
# When sorted=False, the order of elements in the output tensors
# is not expected to match between PyTorch and ORT
topk_unsorted = torch.topk(x, k, largest=False, sorted=False)
topk_sorted = torch.topk(x, k, largest=False, sorted=True)
return topk_sorted, torch.sort(topk_unsorted.values).values
x = torch.arange(1., 6., requires_grad=True)
k = torch.tensor(3)
self.run_test(MyModule(), (x, k))
@skipIfUnsupportedMinOpsetVersion(10)
def test_topk_script(self):
class MyModuleDynamic(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x, k):
return torch.topk(x, k)
x = torch.arange(1., 6., requires_grad=True)
k = torch.tensor(3)
self.run_test(MyModuleDynamic(), [x, k])
@skipIfUnsupportedOpsetVersion([7])
def test_normalize(self):
class Model(torch.nn.Module):
def forward(self, x):
return torch.nn.functional.normalize(x)
x = torch.randn(3, 3)
self.run_test(Model(), x)
def test_layer_norm(self):
model = torch.nn.LayerNorm([10, 10])
x = torch.randn(20, 5, 10, 10)
self.run_test(model, x)
def test_batchnorm1d(self):
x = torch.randn(10, 10)
model = torch.nn.BatchNorm1d(10, affine=True)
self.run_test(model, x)
x = torch.randn(10, 10, 128)
self.run_test(model, x)
def test_batchnorm1d_noaffine(self):
x = torch.randn(10, 10)
model = torch.nn.BatchNorm1d(10, affine=False)
self.run_test(model, x)
x = torch.randn(10, 10, 128)
self.run_test(model, x)
def test_batchnorm2d(self):
x = torch.randn(10, 3, 128, 128)
model = torch.nn.BatchNorm2d(3, affine=True)
self.run_test(model, x)
def test_batchnorm2d_noaffine(self):
x = torch.randn(10, 3, 128, 128)
model = torch.nn.BatchNorm2d(3, affine=False)
self.run_test(model, x)
def test_batchnorm3d(self):
x = torch.randn(10, 3, 128, 128, 128)
model = torch.nn.BatchNorm3d(3, affine=True)
self.run_test(model, x)
def test_batchnorm3d_noaffine(self):
x = torch.randn(10, 3, 128, 128, 128)
model = torch.nn.BatchNorm3d(3, affine=False)
self.run_test(model, x)
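    # scatter is exported as ONNX Scatter up to opset 10 and as
    # ScatterElements from opset 11 onward.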
@skipIfUnsupportedMinOpsetVersion(9)
def test_scatter_with_scalar(self):
class ScatterModel(torch.nn.Module):
def forward(self, input, indices):
values = 1.0
return input.scatter(1, indices, values)
input = torch.tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], dtype=torch.float64)
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
self.run_test(ScatterModel(), input=(input, indices))
@skipIfUnsupportedMinOpsetVersion(9)
def test_scatter_with_scalar_different_types(self):
# Tests the case when scalar src (updates values) type is different
# from self type. Happens only with scalar src - PyTorch does not
# allow this when src is a tensor.
class ScatterModel(torch.nn.Module):
def forward(self, input, indices):
values = 1.0
return input.scatter(1, indices, values)
input = torch.tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], dtype=torch.float32)
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
self.run_test(ScatterModel(), input=(input, indices))
@skipIfUnsupportedMinOpsetVersion(9)
def test_scatter(self):
class ScatterModel(torch.nn.Module):
def forward(self, input, indices, values):
return input.scatter(1, indices, values)
input = torch.tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]])
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
self.run_test(ScatterModel(), input=(input, indices, values))
input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
indices = torch.tensor([[1, 0], [0, 2], [0, 1]], dtype=torch.int64)
values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
self.run_test(ScatterModel(), (input, indices, values))
input = torch.zeros(3, 4, 5, 6)
indices = torch.tensor([[1, 0], [0, 2], [0, 1]], dtype=torch.int64)
indices = indices.view(3, 2, 1, 1).expand(3, 2, 5, 6)
values = torch.arange(3 * 2 * 5 * 6, dtype=torch.float32).view(3, 2, 5, 6)
self.run_test(ScatterModel(), (input, indices, values))
input = torch.zeros(3, 4, 2)
indices = torch.tensor([[[1, 0], [0, 2]], [[1, 1], [0, 1]], [[2, 1], [2, 2]]])
values = torch.arange(3 * 2 * 2, dtype=torch.float32).view(3, 2, 2)
self.run_test(ScatterModel(), (input, indices, values))
@skipIfUnsupportedMinOpsetVersion(9)
def test_scatter_add(self):
class ScatterModel(torch.nn.Module):
def forward(self, input, indices, values):
return input.scatter_add(1, indices, values)
input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
self.run_test(ScatterModel(), input=(input, indices, values))
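    # ONNX introduced the OneHot operator in opset 9, hence the version gate.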
@skipIfUnsupportedMinOpsetVersion(9)
def test_one_hot(self):
class OneHot(torch.nn.Module):
def __init__(self, num_classes):
super().__init__()
self.num_classes = num_classes
def forward(self, x):
return torch.nn.functional.one_hot(x, self.num_classes)
x = torch.arange(10)
        self.run_test(OneHot(15), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_gather(self):
class GatherModel(torch.nn.Module):
def forward(self, input, indices):
return input.gather(1, indices)
input = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
self.run_test(GatherModel(), input=(input, indices))
@skipIfUnsupportedMinOpsetVersion(9)
def test_expand(self):
class ExpandModel(torch.nn.Module):
def forward(self, input):
return input.expand(2, 3, -1)
input = torch.randn(2, 1, 4)
        self.run_test(ExpandModel(), input=input)
class ExpandInferDimModel(torch.nn.Module):
def forward(self, input):
return input.expand(-1, input.size(0))
input = torch.randn(3, 1)
        self.run_test(ExpandInferDimModel(), input=input)
class ExpandTensorSizeModel(torch.nn.Module):
def forward(self, input, size):
return input.expand(size)
input = torch.randn(3,)
size = torch.tensor(-1)
self.run_test(ExpandTensorSizeModel(), input=(input, size))
def test_multinomial(self):
class Multinomial(torch.nn.Module):
def forward(self, weight):
return torch.multinomial(weight, 3, replacement=True)
class MultinomialNoReplacement(torch.nn.Module):
def forward(self, weight):
return torch.multinomial(weight, 1)
weight = torch.tensor([[0, 10, 0, 0], [0, 0, 100, 0]], dtype=torch.float)
self.run_test(Multinomial(), (weight,))
self.run_test(MultinomialNoReplacement(), (weight,))
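    # Helper that applies a reduction op over dim=-1 across integer and
    # floating-point dtypes, skipping combinations that are unsupported
    # (torch.mean on integer types, ReduceProd on double in ORT, and
    # torch.prod on half).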
def _test_reduced_ops(self, op):
class ReducedOpModule(torch.nn.Module):
def forward(self, input):
return op(input, dim=-1)
if op != torch.mean: # torch.mean only supports float types
x = torch.randint(10, (4, 4), dtype=torch.uint8)
self.run_test(ReducedOpModule(), x)
x = torch.randint(10, (4, 4), dtype=torch.int8)
self.run_test(ReducedOpModule(), x)
x = torch.randint(10, (4, 4), dtype=torch.int16)
self.run_test(ReducedOpModule(), x)
x = torch.randint(10, (4, 4), dtype=torch.int32)
self.run_test(ReducedOpModule(), x)
x = torch.randint(10, (4, 4), dtype=torch.int64)
self.run_test(ReducedOpModule(), x)
# torch.mean only supports float types
        # ORT does not support ReduceProd for double
if op != torch.prod and op != torch.mean:
x = torch.randn(4, 5, dtype=torch.double)
self.run_test(ReducedOpModule(), x)
if op != torch.prod: # torch.prod not implemented for Half
x = torch.randn(4, 4, dtype=torch.half)
self.run_test(ReducedOpModule(), x)
x = torch.randn(4, 5, dtype=torch.float)
self.run_test(ReducedOpModule(), x)
def test_reduced_sum(self):
return self._test_reduced_ops(op=torch.sum)
def test_reduced_mean(self):
return self._test_reduced_ops(op=torch.mean)
def test_reduced_prod(self):
return self._test_reduced_ops(op=torch.prod)
def test_reduced_min_max(self):
class ReducedMinMaxModule(torch.nn.Module):
def forward(self, input):
return torch.min(input, dim=-1)[0], torch.max(input, dim=0)[0]
x = torch.randint(10, (4, 4), dtype=torch.int32)
self.run_test(ReducedMinMaxModule(), x)
x = torch.randint(10, (4, 4), dtype=torch.int64)
self.run_test(ReducedMinMaxModule(), x)
x = torch.randn(4, 5, dtype=torch.float)
self.run_test(ReducedMinMaxModule(), x)
def test_reduce_log_sum_exp(self):
class ReduceLogSumExpModel(torch.nn.Module):
def forward(self, input):
a = torch.logsumexp(input, dim=0)
b = torch.logsumexp(input, dim=(0, 1))
return a + b
x = torch.randn(4, 4, requires_grad=True)
self.run_test(ReduceLogSumExpModel(), x)
def test_softmax(self):
for i in range(-4, 3):
model = torch.nn.Softmax(dim=i)
input = torch.randn(3, 4, 5, 6)
self.run_test(model, input)
class SoftmaxUnknownRank(torch.nn.Module):
def __init__(self, i):
super().__init__()
self.softmax = torch.nn.Softmax(dim=i)
def forward(self, x):
return self.softmax(x.reshape(3, 4, 5, 6))
model = torch.jit.script(SoftmaxUnknownRank(i))
self.run_test(model, input)
def test_softmax_large_values(self):
input = torch.tensor([[-1e12, -1e12, -1e12], [1e12, 0.0, -5.0], [3.0, 4.0, 5.0]])
for i in range(-2, 1):
model = torch.nn.Softmax(dim=i)
self.run_test(model, input)
class SoftmaxUnknownRank(torch.nn.Module):
def __init__(self, i):
super().__init__()
self.softmax = torch.nn.Softmax(dim=i)
def forward(self, x):
return self.softmax(x.reshape(3, 3))
model = torch.jit.script(SoftmaxUnknownRank(i))
self.run_test(model, input)
def test_logsoftmax(self):
        for i in range(2, 7):
model = torch.nn.LogSoftmax(dim=i - 1)
dims = [2] * (i - 2) + [3, 4]
input = torch.ones(*dims, requires_grad=True)
self.run_test(model, input)
def test_logsoftmax_dim(self):
for i in range(-4, 3):
model = torch.nn.LogSoftmax(dim=i)
input = torch.randn(3, 4, 5, 6)
self.run_test(model, input)
@skipIfUnsupportedMinOpsetVersion(9)
    @disableScriptTest()  # scripting prim::dtype
def test_lstm_no_hidden(self):
class LSTMModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.rnn = torch.nn.LSTM(input_size=16, hidden_size=16)
def forward(self, x):
return self.rnn(x)
input = torch.randn((10, 16, 16))
self.run_test(LSTMModel(), (input,))
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest()
def test_lstm(self):
model = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)
input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
self.run_test(model, (input, (h0, c0)))
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest()
def test_lstm_default_init_state(self):
model = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)
input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
self.run_test(model, input)
@skipIfUnsupportedMinOpsetVersion(9)
    @disableScriptTest()  # LSTMModel not scriptable
def test_lstm_fixed_batch_size(self):
class LSTMModel(torch.nn.Module):
def __init__(self):
super(LSTMModel, self).__init__()
self.lstm = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)
def forward(self, input):
batch_size = input.size()[1]
h0_np = np.ones([1, batch_size, RNN_HIDDEN_SIZE]).astype(np.float32)
c0_np = np.ones([1, batch_size, RNN_HIDDEN_SIZE]).astype(np.float32)
h0 = torch.from_numpy(h0_np)
c0 = torch.from_numpy(c0_np)
return self.lstm(input, (h0, c0))
input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
        # verify with a different input of the same batch size
input2 = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
self.run_test(LSTMModel(), input, fixed_batch_size=True, test_with_inputs=[input2])
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest()
def test_lstm_post_fix_init_state(self):
class LSTMModel(torch.nn.Module):
def __init__(self):
super(LSTMModel, self).__init__()
self.lstm = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE,
1, bidirectional=False)
def forward(self, input):
batch_size = input.size()[1]
h0_np = np.ones([1, batch_size, RNN_HIDDEN_SIZE]).astype(np.float32)
c0_np = np.ones([1, batch_size, RNN_HIDDEN_SIZE]).astype(np.float32)
h0 = torch.from_numpy(h0_np)
c0 = torch.from_numpy(c0_np)
return self.lstm(input, (h0, c0))
model = LSTMModel()
input = torch.randn(RNN_SEQUENCE_LENGTH, 1, RNN_INPUT_SIZE)
        # verify with a different input of a different batch size
input2 = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
self.run_test(model, input, dynamic_axes={'input' : {0 : 'seq', 1 : 'batch'}},
test_with_inputs=[input2])
@disableScriptTest()
def test_lstm_constant_folding(self):
class LstmNet(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_layers, bidirectional):
super(LstmNet, self).__init__()
self.lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional)
def forward(self, input, initial_state):
return self.lstm(input, initial_state)
def get_LstmNet_model_and_inputs(input_size, hidden_size, num_layers, batch_size,
seq_len, bidirectional):
num_directions = 2 if bidirectional else 1
model = LstmNet(input_size, hidden_size, num_layers, bidirectional)
input = torch.randn(seq_len, batch_size, input_size)
h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
c0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
return model, (input, (h0, c0))
batch_size1 = 3
model1, input1 = get_LstmNet_model_and_inputs(7, 3, 2, batch_size1, 5, True)
self.run_test(model1, input1, do_constant_folding=True)
batch_size2 = 4
model2, input2 = get_LstmNet_model_and_inputs(5, 4, 3, batch_size2, 7, False)
self.run_test(model2, input2, do_constant_folding=True)
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest()
def test_lstm_no_bias(self):
class LstmNet(torch.nn.Module):
def __init__(self, num_layers, bidirectional):
super(LstmNet, self).__init__()
self.lstm = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, num_layers, bias=False, bidirectional=bidirectional)
def forward(self, input, initial_state):
return self.lstm(input, initial_state)
def get_LstmNet_model_and_inputs(num_layers, bidirectional):
input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
num_directions = 2 if bidirectional else 1
model = LstmNet(num_layers, bidirectional)
h0 = torch.randn(num_layers * num_directions, BATCH_SIZE, RNN_HIDDEN_SIZE)
c0 = torch.randn(num_layers * num_directions, BATCH_SIZE, RNN_HIDDEN_SIZE)
return model, (input, (h0, c0))
num_layers = [1, 1, 2, 3]
bidirectional = [True, False, True, False]
models_and_inputs = [get_LstmNet_model_and_inputs(n, b) for n, b in zip(num_layers, bidirectional)]
for model, input in models_and_inputs:
self.run_test(model, input)
@disableScriptTest()
def test_rnn_no_bias(self):
def make_model(layers, packed_sequence):
            batch_first = packed_sequence == 2
model = torch.nn.RNN(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers, bidirectional=False,
batch_first=batch_first, bias=False)
if packed_sequence == 1:
model = RnnModelWithPackedSequence(model, False)
if packed_sequence == 2:
model = RnnModelWithPackedSequence(model, True)
return model
def make_input(batch_size, layers, packed_sequence):
            batch_first = packed_sequence == 2
            seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
            seq_lengths = sorted(map(int, seq_lengths), reverse=True)
inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
inputs = [inputs]
h0 = torch.randn(layers, batch_size, RNN_HIDDEN_SIZE)
inputs.append(h0)
if packed_sequence != 0:
inputs.append(torch.IntTensor(seq_lengths))
if len(inputs) == 1:
input = inputs[0]
else:
input = tuple(inputs)
return input
layers = [1, 3, 1, 3, 1, 3]
packed_sequence = [0, 0, 1, 1, 2, 2]
models = [make_model(l, p) for l, p in zip(layers, packed_sequence)]
inputs = [make_input(RNN_BATCH_SIZE, l, p) for l, p in zip(layers, packed_sequence)]
for model, input in zip(models, inputs):
self.run_test(model, input, batch_size=RNN_BATCH_SIZE)
def test_gru_no_bias(self):
class GruNet(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_layers, bidirectional):
super(GruNet, self).__init__()
self.mygru = torch.nn.GRU(input_size, hidden_size, num_layers, bidirectional=bidirectional, bias=False)
def forward(self, input, initial_state):
out = self.mygru(input, initial_state)
return out
def get_GruNet_model_and_inputs(input_size, hidden_size, num_layers, batch_size,
seq_len, bidirectional):
num_directions = 2 if bidirectional else 1
model = GruNet(input_size, hidden_size, num_layers, bidirectional)
input = torch.randn(seq_len, batch_size, input_size)
h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
return model, (input, h0)
input_size = [7, 5]
hidden_size = [3, 4]
num_layers = [2, 3]
batch_size = [3, 4]
seq_len = [5, 7]
bidirectional = [True, False]
models_and_inputs = [get_GruNet_model_and_inputs(i, h, n, b, s, bi)
for i, h, n, b, s, bi in zip(input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional)]
for model, input in models_and_inputs:
self.run_test(model, input, do_constant_folding=True)
def test_gru_constant_folding(self):
class GruNet(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_layers, bidirectional):
super(GruNet, self).__init__()
self.mygru = torch.nn.GRU(input_size, hidden_size, num_layers, bidirectional=bidirectional)
def forward(self, input, initial_state):
out = self.mygru(input, initial_state)
return out
def get_GruNet_model_and_inputs(input_size, hidden_size, num_layers, batch_size,
seq_len, bidirectional):
num_directions = 2 if bidirectional else 1
model = GruNet(input_size, hidden_size, num_layers, bidirectional)
input = torch.randn(seq_len, batch_size, input_size)
h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
return model, (input, h0)
batch_size1 = 3
model1, input1 = get_GruNet_model_and_inputs(7, 3, 2, batch_size1, 5, True)
self.run_test(model1, input1, do_constant_folding=True)
batch_size2 = 4
model2, input2 = get_GruNet_model_and_inputs(5, 4, 3, batch_size2, 7, False)
self.run_test(model2, input2, do_constant_folding=True)
@skipIfUnsupportedMinOpsetVersion(8)
def test_max_tensors(self):
class MaxModel(torch.nn.Module):
def forward(self, input, other):
return torch.max(input, other)
model = MaxModel()
x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 1, requires_grad=True)
self.run_test(model, (x, y))
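    # With an explicit dtype, arange can be exported from opset 9; the
    # *_notype variants below rely on the ONNX Range operator added in opset 11.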
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_end(self):
class ArangeScript(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, a):
return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
x = torch.randn(3, 4, requires_grad=True)
outputs = ArangeScript()(x)
self.run_test(ArangeScript(), x)
class ArangeModel(torch.nn.Module):
def forward(self, a):
return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
self.run_test(ArangeModel(), x)
@skipIfUnsupportedMinOpsetVersion(11)
def test_arange_end_notype(self):
class ArangeScript(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, a):
return torch.arange(a.size(0))
x = torch.randn(3, 4, requires_grad=True)
outputs = ArangeScript()(x)
self.run_test(ArangeScript(), x)
class ArangeModel(torch.nn.Module):
def forward(self, a):
return torch.arange(a.size(0))
self.run_test(ArangeModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_start_end(self):
class ArangeScript(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, a):
return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
x = torch.randn(3, 4, requires_grad=True)
self.run_test(ArangeScript(), x)
class ArangeModel(torch.nn.Module):
def forward(self, a):
return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
self.run_test(ArangeModel(), x)
@skipIfUnsupportedMinOpsetVersion(11)
def test_arange_start_end_notype(self):
class ArangeScript(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, a):
return torch.arange(2.7, a.size(0) + 2).view(-1, 1) + a
x = torch.randn(3, 4, requires_grad=True)
self.run_test(ArangeScript(), x)
class ArangeModel(torch.nn.Module):
def forward(self, a):
return torch.arange(2.7, a.size(0) + 2).view(-1, 1) + a
self.run_test(ArangeModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_start_end_step(self):
class ArangeScript(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, a):
return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a
x = torch.randn(3, 4, requires_grad=True)
self.run_test(ArangeScript(), x)
class ArangeModel(torch.nn.Module):
def forward(self, a):
return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a
self.run_test(ArangeModel(), x)
@skipIfUnsupportedMinOpsetVersion(11)
def test_arange_start_end_step_notype(self):
class ArangeScript(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, a):
return torch.arange(2.7, a.size(0) * a.size(1) + 2, a.size(1)).view(-1, 1) + a
x = torch.randn(3, 4, requires_grad=True)
self.run_test(ArangeScript(), x)
class ArangeModel(torch.nn.Module):
def forward(self, a):
return torch.arange(2.7, a.size(0) * a.size(1) + 2, a.size(1)).view(-1, 1) + a
self.run_test(ArangeModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test__dim_arange(self):
class DimArange(torch.nn.Module):
def forward(self, input):
return torch._dim_arange(input, 1)
x = torch.ones(5, 6)
self.run_test(DimArange(), x)
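    # Helper that runs a comparison model on float and int32 inputs; for
    # binary models it also covers the mixed-dtype combinations.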
def _test_compare_ops(self, model, num_inputs):
x_float = torch.randn(1, 2, 3, 4, requires_grad=True)
x_int = torch.randint(10, (3, 4), dtype=torch.int32)
if num_inputs > 1:
y_float = torch.randn(1, 2, 3, 4, requires_grad=True)
y_int = torch.randint(10, (3, 4), dtype=torch.int32)
self.run_test(model, (x_float, y_float))
self.run_test(model, (x_float, y_int))
self.run_test(model, (x_int, y_float))
self.run_test(model, (x_int, y_int))
else:
self.run_test(model, x_float)
self.run_test(model, x_int)
def test_gt(self):
class GreaterModel(torch.nn.Module):
def forward(self, input, other):
return input > other
self._test_compare_ops(GreaterModel(), 2)
@skipIfUnsupportedMinOpsetVersion(9)
def test_ge(self):
class GreaterOrEqualModel(torch.nn.Module):
def forward(self, input, other):
return input >= other
self._test_compare_ops(GreaterOrEqualModel(), 2)
def test_gt_scalar(self):
class GreaterModel(torch.nn.Module):
def forward(self, input):
return input > 1
self._test_compare_ops(GreaterModel(), 1)
@skipIfUnsupportedMinOpsetVersion(9)
def test_ge_scalar(self):
class GreaterOrEqualModel(torch.nn.Module):
def forward(self, input):
return input >= 1
self._test_compare_ops(GreaterOrEqualModel(), 1)
def test_lt(self):
class LessModel(torch.nn.Module):
def forward(self, input, other):
                return input < other
self._test_compare_ops(LessModel(), 2)
@skipIfUnsupportedMinOpsetVersion(9)
def test_le(self):
class LessOrEqualModel(torch.nn.Module):
def forward(self, input, other):
return input <= other
self._test_compare_ops(LessOrEqualModel(), 2)
def test_lt_scalar(self):
class LessModel(torch.nn.Module):
def forward(self, input):
return input < 1
self._test_compare_ops(LessModel(), 1)
@skipIfUnsupportedMinOpsetVersion(9)
def test_le_scalar(self):
class LessOrEqualModel(torch.nn.Module):
def forward(self, input):
return input <= 1
self._test_compare_ops(LessOrEqualModel(), 1)
def test_matmul(self):
class MatmulModel(torch.nn.Module):
def forward(self, input, other):
return torch.matmul(input, other)
x = torch.randn(3, 4, requires_grad=True)
y = torch.randn(4, 5, requires_grad=True)
self.run_test(MatmulModel(), (x, y))
x = torch.randint(10, (3, 4))
y = torch.randint(10, (4, 5))
self.run_test(MatmulModel(), (x, y))
def test_matmul_batch(self):
class MatmulModel(torch.nn.Module):
def forward(self, input, other):
return torch.matmul(input, other)
x = torch.randn(2, 3, 4, requires_grad=True)
y = torch.randn(2, 4, 5, requires_grad=True)
self.run_test(MatmulModel(), (x, y))
x = torch.randint(10, (2, 3, 4))
y = torch.randint(10, (2, 4, 5))
self.run_test(MatmulModel(), (x, y))
def _argmin_argmax_model(self, input):
class ArgminArgmaxModel(torch.nn.Module):
def forward(self, input):
return torch.argmin(input), \
torch.argmax(input), \
torch.argmin(input, keepdim=True), \
torch.argmax(input, keepdim=True)
self.run_test(ArgminArgmaxModel(), input)
def test_argmin_argmax(self):
input = torch.randn(7, 3, 5)
self._argmin_argmax_model(input)
    # Argmin and Argmax with "select_last_index" are not supported before opset 12.
    # "select_last_index" was added in opset 12 to handle the corner case where
    # the same value appears multiple times in the tensor.
@skipIfUnsupportedMinOpsetVersion(12)
def test_argmin_argmax_select_last_index(self):
input = torch.tensor([[1., 2., 3.],
[1., 1., 2.]])
self._argmin_argmax_model(input)
input = torch.ones(7, 3, 5)
self._argmin_argmax_model(input)
def test_repeat(self):
class RepeatModel(torch.nn.Module):
def forward(self, x, y):
x2 = x.repeat(y.shape[0], 1)
y1 = y.view(-1, 1)
return x2 + y1
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 8, 9])
self.run_test(RepeatModel(), (x, y))
def test_view(self):
class ViewModel(torch.nn.Module):
def forward(self, input):
return input.view(4, 24)
x = torch.randint(10, (4, 2, 3, 4), dtype=torch.int32)
self.run_test(ViewModel(), x)
def test_view_dynamic(self):
class ViewModel(torch.nn.Module):
def forward(self, input, other):
return input.view(other.shape)
x = torch.randn(2, 3, 4)
shape = torch.randn(6, 4)
self.run_test(ViewModel(), (x, shape))
def test_view_dynamic_zero_dim(self):
class ViewModel(torch.nn.Module):
def forward(self, input):
input = input.view(-1, 2)
return input.view(1, -1)
x = torch.ones(2)
another_x = torch.empty((0,))
self.run_test(ViewModel(), x, test_with_inputs=[another_x],
input_names=['input_1'], dynamic_axes={'input_1': [0, ]})
def test_view_as(self):
class ViewModel(torch.nn.Module):
def forward(self, input, other):
return input.view_as(other)
x = torch.randn(2, 3, 4)
y = torch.randn(6, 4)
self.run_test(ViewModel(), (x, y))
@disableScriptTest() # ONNX Shape inference failure in if/else block for Gemm
def test_weight_norm(self):
model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=1)
x = torch.randn(3, 4, 5, requires_grad=True)
self.run_test(model, x)
model = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, 3))
x = torch.randn(1, 1, 5, requires_grad=True)
self.run_test(model, x)
model = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, 3), dim=-2)
x = torch.randn(1, 1, 5, requires_grad=True)
self.run_test(model, x)
model = torch.nn.utils.weight_norm(torch.nn.Conv1d(3, 6, 3), name='weight')
x = torch.randn(3, 3, 5, requires_grad=True)
self.run_test(model, x)
@disableScriptTest() # ONNX Shape inference failure in if/else block for Gemm
def test_weight_norm_nodim(self):
model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=None)
x = torch.randn(3, 4, 5, requires_grad=True)
self.run_test(model, x)
def test_flatten(self):
class FlattenModel(torch.nn.Module):
def forward(self, input):
return torch.flatten(input)
x = torch.randint(10, (1, 2, 3, 4))
self.run_test(FlattenModel(), x)
def test_flatten2d(self):
class FlattenModel(torch.nn.Module):
def forward(self, input):
return torch.flatten(input, 1)
x = torch.randint(10, (1, 2, 3, 4))
self.run_test(FlattenModel(), x)
def test_flatten2d_neg(self):
class FlattenModel(torch.nn.Module):
def forward(self, x):
return torch.flatten(x, 1, -1), torch.flatten(x, 0, -2), torch.flatten(x, 1, -2)
x = torch.randint(10, (1, 2, 3, 4))
self.run_test(FlattenModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_flatten_dynamic_axes(self):
class MyModule(torch.nn.Module):
def forward(self, x):
return torch.flatten(x, start_dim=2, end_dim=3)
batch_size = 3
x = torch.randn(batch_size, 5, 4, 5)
y = torch.randn(5, 5, 4, 5)
model = MyModule()
self.run_test(model, x, test_with_inputs=[y],
input_names=['input'],
output_names=['output'],
dynamic_axes={'input' : {0 : 'batch_size'},
'output' : {0 : 'batch_size'}})
@skipIfUnsupportedMinOpsetVersion(11)
def test_getitem(self):
class GetItemModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x, y, z, ind):
# this will create prim::ListConstruct(x, y, z) + aten::__getitem__
arr = [x, y, z]
return arr[ind]
x = torch.randn(3, 4, 5)
y = torch.randn(1, 4, 5)
z = torch.randn(2, 4, 5)
ind = torch.tensor(1, dtype=torch.long)
self.run_test(GetItemModel(), (x, y, z, ind))
ind = torch.tensor(-2, dtype=torch.long)
self.run_test(GetItemModel(), (x, y, z, ind))
def test_unbind(self):
class UnbindModel(torch.nn.Module):
def forward(self, input):
_, out, _ = input.unbind()
return out
x = torch.randn(3, 4, 5)
self.run_test(UnbindModel(), x)
class UnbindModel2(torch.nn.Module):
def forward(self, input):
_, out, _, _ = input.unbind(1)
return out
x = torch.randn(3, 4, 5)
self.run_test(UnbindModel2(), x)
class UnbindModel3(torch.nn.Module):
def forward(self, input):
_, out, _, _ = input.unbind(-2)
return out
x = torch.randn(3, 4, 5)
self.run_test(UnbindModel3(), x)
@skipIfUnsupportedMinOpsetVersion(11)
def test_len(self):
class LenModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input):
return len(input.unbind()) + input
x = torch.randn(4, 5)
self.run_test(LenModel(), x, input_names=['input'], dynamic_axes={'input': {0: 'seq'}},
test_with_inputs=(torch.randn(5, 5),))
@skipIfUnsupportedMinOpsetVersion(9)
def test_len_list(self):
class LenListModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input):
return torch.ones(len(input.shape))
x = torch.randn(4, 5)
self.run_test(LenListModel(), x)
@skipIfUnsupportedMinOpsetVersion(11)
def test_unbind_dynamic(self):
class UnbindModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input):
return input.unbind()[1]
x = torch.randn(3, 4, 5)
self.run_test(UnbindModel(), x)
class UnbindModel2(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input):
return input.unbind(-1)[1]
x = torch.randn(3, 4, 5)
self.run_test(UnbindModel2(), x)
def test_split(self):
class SplitModel(torch.nn.Module):
def forward(self, input):
out1, out2, out3 = input.split([2, 1, 2])
return out1, out2, out3
x = torch.randn(5, 4, 3)
self.run_test(SplitModel(), x)
class SplitModel2(torch.nn.Module):
def forward(self, input):
out1, out2, out3 = input.split([2, 1, 1], -2)
return out1, out2, out3
x = torch.randn(5, 4, 3)
self.run_test(SplitModel2(), x)
class SplitModel3(torch.nn.Module):
def forward(self, input):
out1, out2, out3 = input.split([2, 1, 2])
return out3, out1
x = torch.randn(5, 4, 3)
self.run_test(torch.jit.script(SplitModel3()), x)
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_split_size_as_list(self):
class SplitModel(torch.nn.Module):
def forward(self, input, split_sizes: List[int]):
out = []
split_list: List[torch.Tensor] = input.split(split_sizes)
for ob in split_list:
out.append(ob)
return torch.cat(out, dim=0)
x = torch.randn(6, 4, 3)
split_sizes = [torch.tensor(2), torch.tensor(4)]
self.run_test(SplitModel(), (x, split_sizes))
@skipIfUnsupportedMinOpsetVersion(11)
def test_split_size_with_slice(self):
class SplitModule(torch.nn.Module):
def forward(self, x, y, t):
splits = (x.size(1), y.size(1))
out, out2 = torch.split(t, splits, dim=1)
return out, out2
x = torch.randn(2, 3)
y = torch.randn(2, 4)
t = torch.randn(2, 7)
self.run_test(SplitModule(), (x, y, t))
@skipIfUnsupportedMinOpsetVersion(11)
def test_split_dynamic(self):
class SplitModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input):
return input.split(2)[1]
x = torch.randn(5, 4, 3)
self.run_test(SplitModel(), x)
class SplitModel2(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input):
return input.split(2, -3)[1]
x = torch.randn(5, 4, 3)
self.run_test(SplitModel2(), x)
def test_concat(self):
class ConcatModel(torch.nn.Module):
def forward(self, x, y, z):
return torch.cat((x, y, z))
x = torch.randn(3, 4, 5)
y = torch.randn(1, 4, 5)
z = torch.randn(2, 4, 5)
self.run_test(ConcatModel(), (x, y, z))
@skipIfUnsupportedMinOpsetVersion(11)
def test_concat_dynamic(self):
class ConcatDynamicModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
return torch.cat(x.unbind())
x = torch.randn(4, 5, 6)
self.run_test(ConcatDynamicModel(), x)
def test_stack(self):
class StackModel(torch.nn.Module):
def forward(self, x, y, z):
return torch.stack((x, y, z), 1)
x = torch.randn(3, 4, 5)
y = torch.randn(3, 4, 5)
z = torch.randn(3, 4, 5)
self.run_test(StackModel(), (x, y, z))
@skipIfUnsupportedMinOpsetVersion(11)
def test_stack_dynamic(self):
class StackDynamicModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
return torch.stack(x.unbind(), 1)
x = torch.randn(4, 5, 6)
self.run_test(StackDynamicModel(), x)
def test_loop_dynamic(self):
class LoopModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
for i in range(x.size(2)):
x = x + i
return x
model = LoopModel()
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
self.run_test(model, inputs)
@skipIfUnsupportedMinOpsetVersion(9)
def test_loop_nested(self):
class NestedLoopsModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
for i in range(5):
a = 0
while a < 4:
a += 1
x = x + a
return x
model = NestedLoopsModel()
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
self.run_test(model, inputs)
@skipIfUnsupportedMinOpsetVersion(11)
def test_loop_with_list(self):
class ListLoopModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
res = []
res1 = []
arr = x.split([3, 4, 1, 1, 2, 3, 2], 0)
res2 = torch.zeros(3, 4, dtype=torch.long)
res3 = []
res4 = []
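                # In TorchScript, list.append returns the mutated list, so the
                # reassignments below are valid under scripting.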
for i in range(len(arr)):
res = res.append(arr[i].sum(0, False))
res1 = res1.append(arr[-1 - i].sum(0, False))
res2 += 1
res3 = res3 + [arr[i].sum(0, False)]
res4 += [arr[-1 - i].sum(0, False)]
return torch.stack(res), torch.stack(res1), res2, torch.stack(res3), torch.stack(res4)
model = ListLoopModel()
inputs = torch.randn(16)
self.run_test(model, inputs)
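    # Needs ONNX shape inference enabled so the shape of the transposed
    # tensor inside the loop body is propagated.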
@skipIfONNXShapeInference(False)
@skipIfUnsupportedMinOpsetVersion(11)
def test_loop_transpose(self):
class LoopModel(torch.nn.Module):
def forward(self, x):
res = torch.zeros_like(x[0])
for i in range(x.size(0)):
res += x[0].transpose(0, 1)
return res
model = torch.jit.script(LoopModel())
x = torch.randn(5, 3, 3)
self.run_test(model, x)
@skipIfUnsupportedMinOpsetVersion(11)
def test_list(self):
class ListModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
tensors = x.unbind()
res = []
res.append(tensors[0])
res.append(tensors[1])
res.pop(1)
res.insert(0, tensors[1])
res.append(tensors[2])
res += [tensors[3], tensors[4]]
res = res + [tensors[5]]
return torch.ones(len(res))
model = ListModel()
inputs = torch.randn(16, 1)
self.run_test(model, inputs)
@skipIfUnsupportedMinOpsetVersion(9)
def test_tensor_factories(self):
class TensorFactory(torch.nn.Module):
def forward(self, x):
return torch.zeros(x.size()) + torch.ones(x.size())
x = torch.randn(2, 3, 4)
self.run_test(TensorFactory(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_tensor_factories_script(self):
class TensorFactory(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
return torch.zeros(x.shape, dtype=torch.float) + torch.ones(x.shape, dtype=torch.float)
x = torch.randn(2, 3, 4)
self.run_test(TensorFactory(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_tensor_like_factories_script(self):
class TensorFactory(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
zeros = torch.zeros_like(x, dtype=torch.float, layout=torch.strided, device=torch.device('cpu'))
ones = torch.ones_like(x, dtype=torch.float, layout=torch.strided, device=torch.device('cpu'))
return zeros + ones
x = torch.randn(2, 3, 4)
self.run_test(TensorFactory(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_eye(self):
class TensorFactory(torch.nn.Module):
def forward(self, x):
return torch.eye(x.size()[1], 3), torch.eye(4, 4, dtype=torch.long), torch.eye(x.size()[1], 2, dtype=torch.long)
x = torch.randn(2, 3, 4)
another_x = torch.randn(5, 6, 7)
self.run_test(TensorFactory(), x, test_with_inputs=[another_x],
input_names=['input_1'], dynamic_axes={'input_1': [0, 1, 2]})
@skipIfUnsupportedMinOpsetVersion(9)
def test_inplace_zero(self):
class Zero_(torch.nn.Module):
def forward(self, x):
return x.zero_(), x
x = torch.randn(2, 3, 4)
self.run_test(Zero_(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_new_zeros(self):
class Zero_(torch.nn.Module):
def forward(self, x):
return x.new_zeros(x.shape[1:2]), x.new_zeros(x.shape[2:], dtype=torch.long)
x = torch.randn(2, 3, 4)
self.run_test(Zero_(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_list_pass(self):
class Slice(torch.nn.Module):
def forward(self, x, y):
return x.new_zeros(x.shape[2:] + y.shape[1:])
x = torch.randn(2, 3, 4, 5)
y = torch.randn(1, 2, 3, 4)
self.run_test(Slice(), (x, y))
class Size(torch.nn.Module):
def forward(self, x, y):
return x.new_zeros(x.shape + y.shape)
x = torch.randn(2, 3, 4)
y = torch.randn(1, 2, 3)
self.run_test(Size(), (x, y))
class Array(torch.nn.Module):
def forward(self, x, y):
arr1 = [x.shape[0], x.shape[1], 2]
arr2 = [y.shape[0], y.shape[1]]
return x.new_zeros(arr1 + arr2)
x = torch.randn(2, 3, 4)
y = torch.randn(1, 2, 3)
self.run_test(Array(), (x, y))
class List(torch.nn.Module):
def forward(self, x, y):
l1 = list(x.shape)
l2 = list(y.shape)
return x.new_zeros(l1 + l2)
x = torch.randn(2, 3, 4)
y = torch.randn(1, 2, 3)
self.run_test(List(), (x, y))
@skipIfUnsupportedMinOpsetVersion(9)
def test_new_empty(self):
        class Empty(torch.nn.Module):
            def forward(self, x):
                return x.new_empty(x.shape[0]).fill_(0), x.new_empty(x.shape[0], dtype=torch.long) * 0
        x = torch.randn(2, 3, 4)
        self.run_test(Empty(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_new_full(self):
class Full(torch.nn.Module):
def forward(self, x):
return x.new_full(x.shape[1:2], 5), x.new_full(x.shape[0:1], 1.3, dtype=torch.long)
x = torch.randn(2, 3, 4)
self.run_test(Full(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_inplace_list(self):
class Arithmetic(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x, y):
return torch.cat([x.add_(3), y.fill_(0)])
x = torch.randn(2, 3)
y = torch.randn(2, 3)
self.run_test(Arithmetic(), (x, y))
@skipIfUnsupportedMinOpsetVersion(9)
def test_inplace_fill(self):
class Fill_(torch.nn.Module):
def forward(self, x):
return x.fill_(3), x
x = torch.randn(2, 3, 4)
self.run_test(Fill_(), x)
def test_inplace_arithmetic(self):
class Arithmetic(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x, y):
x.add_(3)
y.mul_(x)
return x, y
x = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)
self.run_test(Arithmetic(), (x, y))
@disableScriptTest()
def test_sort(self):
class SortModel(torch.nn.Module):
def forward(self, x):
out = []
for i in range(-2, 2):
out.append(torch.sort(x, dim=i, descending=True))
return out
x = torch.randn(3, 4)
self.run_test(SortModel(), x)
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_sort_ascending(self):
class SortModel(torch.nn.Module):
def forward(self, x):
out = []
for i in range(-2, 2):
out.append(torch.sort(x, dim=i, descending=False))
return out
x = torch.randn(3, 4)
self.run_test(SortModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_masked_fill(self):
class MaskedFillModel(torch.nn.Module):
def forward(self, x):
mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.uint8)
return x.masked_fill(mask, 2)
x = torch.zeros(4, 2, 3, requires_grad=True)
self.run_test(MaskedFillModel(), x)
class MaskedFillModel2(torch.nn.Module):
def forward(self, x):
return x.masked_fill(x > 3, -1)
x = torch.arange(16).view(2, 2, 4).to(torch.float32)
self.run_test(MaskedFillModel2(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_masked_fill_inplace(self):
class MaskedFillModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.uint8)
x.masked_fill_(mask, 2)
return x
x = torch.zeros(4, 2, 3, requires_grad=True)
self.run_test(MaskedFillModel(), x)
class MaskedFillModel2(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
x.masked_fill_(x > 3, -1)
return x
x = torch.arange(16).view(2, 2, 4).to(torch.float32)
self.run_test(MaskedFillModel2(), x)
@skipIfUnsupportedMinOpsetVersion(11)
def test_masked_scatter(self):
class MaskedScatterModel(torch.nn.Module):
def forward(self, x):
return torch.masked_scatter(x, x.ge(0.5), torch.ones(100, 100) * 5)
x = torch.randn(3, 4, 5, requires_grad=True)
self.run_test(MaskedScatterModel(), x)
@skipIfUnsupportedMinOpsetVersion(11)
def test_masked_select(self):
class MaskedSelectModel(torch.nn.Module):
def forward(self, x):
return torch.masked_select(x, x.ge(0.5))
x = torch.randn(3, 4, 5, requires_grad=True)
self.run_test(MaskedSelectModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_pixel_shuffle(self):
class PixelShuffle(torch.nn.Module):
def forward(self, x):
return torch.pixel_shuffle(x, upscale_factor=2)
x = torch.randn(2, 16, 4, 3, requires_grad=True)
self.run_test(PixelShuffle(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_scalar_type(self):
class ArithmeticModel(torch.nn.Module):
def forward(self, x):
return x.size(0) * 2 * x
x = torch.ones(2, 3, dtype=torch.float32)
self.run_test(ArithmeticModel(), x)
class ReciprocalModel(torch.nn.Module):
def forward(self, x):
return torch.reciprocal(x)
x = torch.tensor([2.0, 4.0], dtype=torch.double)
self.run_test(ReciprocalModel(), x)
class ComparisonModel(torch.nn.Module):
def forward(self, x, y):
a = torch.tensor([12.0])
return x.lt(1.5) & y.le(2) & x.le(1), x.gt(y), x.lt(y), a.ge(x.size(0))
x = torch.ones(2, 3, dtype=torch.int32)
y = torch.ones(2, 3, dtype=torch.float32)
self.run_test(ComparisonModel(), (x, y))
class MatMulModel(torch.nn.Module):
def forward(self, x):
return (torch.mm(x, x) + x + torch.mm(x, x) + x)
x = torch.ones(3, 3)
self.run_test(MatMulModel(), x)
class AddMMModel(torch.nn.Module):
def forward(self, x):
return torch.mm(x, x) + x
x = torch.ones(3, 3)
self.run_test(AddMMModel(), x)
class FullModel(torch.nn.Module):
# add is used for exporting full
def forward(self, x):
return torch.full((3, 4), x)
x = torch.tensor(12.)
self.run_test(FullModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest() # dtype mismatch
def test_full_like(self):
class FullLikeModel(torch.nn.Module):
def forward(self, x):
return torch.full_like(x, 4)
x = torch.tensor(12)
self.run_test(FullLikeModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest() # dtype mismatch
def test_full_like_value(self):
class FullLikeModel(torch.nn.Module):
def forward(self, x, y):
out = y + 2
return torch.full_like(x, out)
x = torch.tensor(12)
y = torch.tensor(2)
self.run_test(FullLikeModel(), (x, y))
def test_l1_norm(self):
class NormModel(torch.nn.Module):
def forward(self, x):
return torch.norm(x, p=1, dim=-1, keepdim=False)
x = torch.randn(4, 2, 3, requires_grad=True)
self.run_test(NormModel(), x)
def test_l2_norm(self):
class NormModel(torch.nn.Module):
def forward(self, x):
return torch.norm(x, p=2, dim=-2, keepdim=False)
x = torch.randn(4, 2, 3, requires_grad=True)
self.run_test(NormModel(), x)
def test_frobenius_norm(self):
class NormModel(torch.nn.Module):
def forward(self, x):
return torch.norm(x, p="fro", dim=0, keepdim=False)
x = torch.randn(4, 2, 3, requires_grad=True)
self.run_test(NormModel(), x)
def test_frobenius_norm_keepdim(self):
class NormModel(torch.nn.Module):
def forward(self, x):
return torch.norm(x, p="fro", dim=(0, 1), keepdim=True)
x = torch.randn(4, 2, 3, requires_grad=True)
self.run_test(NormModel(), x)
def test_unfold(self):
class UnfoldModel(torch.nn.Module):
def forward(self, x):
return x.unfold(dimension=2, size=2, step=2)
x = torch.randn(4, 2, 3, requires_grad=True)
self.run_test(UnfoldModel(), x)
@skipIfONNXShapeInference(False)
def test_unfold_infer_shape(self):
class UnfoldModule(torch.jit.ScriptModule):
def __init__(self):
super(UnfoldModule, self).__init__()
self.conv = torch.nn.Conv1d(3, 1, 3, stride=2)
@torch.jit.script_method
def forward(self, x):
x = self.conv(x)
return x.unfold(dimension=2, size=2, step=2)
x = torch.randn(32, 3, 64)
self.run_test(UnfoldModule(), x)
def test_remainder(self):
class RemainderModel(torch.nn.Module):
def forward(self, input, other):
return torch.remainder(input, other)
x = torch.randn(4, 2, 3)
y = torch.randn(1, 2, 1)
self.run_test(RemainderModel(), (x, y))
def test_remainder_scalar(self):
class RemainderModel(torch.nn.Module):
def forward(self, input):
return torch.remainder(input, 2.55)
x = torch.randint(10, (2, 3))
self.run_test(RemainderModel(), x)
@skipIfUnsupportedMinOpsetVersion(10)
def test_fmod(self):
class FModModel(torch.nn.Module):
def forward(self, input, other):
return torch.fmod(input, other)
x = torch.randn(4, 2, 3)
y = torch.randn(1, 2, 1)
self.run_test(FModModel(), (x, y))
@skipIfUnsupportedMinOpsetVersion(10)
def test_fmod_scalar(self):
class FModModel(torch.nn.Module):
def forward(self, input):
return torch.fmod(input, 2.55)
x = torch.randint(10, (2, 3))
self.run_test(FModModel(), x)
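    # ONNX has no Gelu operator at these opsets; gelu is exported as its
    # Erf-based decomposition, and Erf requires opset 9.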
@skipIfUnsupportedMinOpsetVersion(9)
def test_gelu(self):
class GeluModel(torch.nn.Module):
def forward(self, x):
return torch.nn.functional.gelu(x)
x = torch.randn(2, 4, 5, 6, requires_grad=True)
self.run_test(GeluModel(), x)
def test_add_inplace(self):
class InplaceAddModel(torch.nn.Module):
def forward(self, x):
x += 12
return x
x = torch.randn(4, 2, 3, requires_grad=True)
self.run_test(InplaceAddModel(), x)
def test_rsqrt(self):
class RsqrtModel(torch.nn.Module):
def forward(self, x):
return x.rsqrt()
x = torch.randn(4, 2, 3, requires_grad=True, dtype=torch.float64)
self.run_test(RsqrtModel(), x)
def test_rsqrt_zeros(self):
class RsqrtModel(torch.nn.Module):
def forward(self, x):
return x.rsqrt()
x = torch.zeros(4, 2, 3, requires_grad=True, dtype=torch.float64)
self.run_test(RsqrtModel(), x)
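    # ONNX introduced the Unique operator in opset 11.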
@skipIfUnsupportedMinOpsetVersion(11)
def test_unique(self):
class UniqueModel(torch.nn.Module):
def forward(self, x):
return torch.unique(x, sorted=True, return_inverse=False, return_counts=True)
x = torch.tensor([1, 3, 2, 3], dtype=torch.long)
self.run_test(UniqueModel(), x)
@skipIfUnsupportedMinOpsetVersion(11)
def test_unique_along_dim(self):
class UniqueModel(torch.nn.Module):
def forward(self, x):
return torch.unique(x, dim=0, sorted=True, return_inverse=True, return_counts=False)
x = torch.tensor([1, 3, 2, 3], dtype=torch.long)
self.run_test(UniqueModel(), x)
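    # ONNX introduced the CumSum operator in opset 11.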
@skipIfUnsupportedMinOpsetVersion(11)
def test_cumsum(self):
class CumSum(torch.nn.Module):
def forward(self, input):
return torch.cumsum(input, dim=0)
x = torch.randn(2, 3, 4)
model = CumSum()
self.run_test(model, x)
@skipIfUnsupportedMinOpsetVersion(11)
def test_cumsum_with_cast(self):
class CumSum(torch.nn.Module):
def forward(self, input):
return torch.cumsum(input, dim=0, dtype=torch.float32)
model = CumSum()
x = torch.tensor([2, 3, 4], dtype=torch.int32)
self.run_test(model, x)
x = torch.tensor([False, True, True])
self.run_test(model, x)
    @disableScriptTest()  # error in propagating assigned input shapes
@skipIfUnsupportedMinOpsetVersion(10)
@skipIfUnsupportedOpsetVersion([12]) # Due to ONNX Loop shape inference issue
def test_embedding_bag(self):
model = torch.nn.EmbeddingBag(10, 5, mode='sum', scale_grad_by_freq=True)
input = torch.randint(10, (7,))
offset = torch.tensor([0, 2, 5, 6])
self.run_test(model, (input, offset))
model = torch.nn.EmbeddingBag(10, 5, mode='sum', include_last_offset=True)
input = torch.randint(10, (7,))
offset = torch.tensor([0, 2, 5, 6])
self.run_test(model, (input, offset))
model = torch.nn.EmbeddingBag(10, 5, mode='max')
input = torch.randint(10, (7, 5))
        self.run_test(model, input)
@disableScriptTest() # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast
@skipIfUnsupportedMinOpsetVersion(10)
@skipIfUnsupportedOpsetVersion([12]) # Due to ONNX Loop shape inference issue
def test_embedding_bag_1d_per_sample_weights(self):
class EmbeddingModel(torch.nn.Module):
def forward(self, embedding_matrix, input, offset, weights):
return torch.nn.functional.embedding_bag(input, embedding_matrix, offsets=offset,
mode='sum', per_sample_weights=weights)
model = EmbeddingModel()
x = torch.randint(7, (6,))
w = torch.randn(6, )
offset = torch.tensor([0, 2, 5])
embedding_matrix = torch.rand(10, 15)
self.run_test(model, (embedding_matrix, x, offset, w))
@disableScriptTest() # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast
@skipIfUnsupportedMinOpsetVersion(10)
@skipIfUnsupportedOpsetVersion([12]) # Due to ONNX Loop shape inference issue
def test_embedding_bag_2d_per_sample_weights(self):
class EmbeddingModel(torch.nn.Module):
def forward(self, embedding_matrix, input, weights):
return torch.nn.functional.embedding_bag(input, embedding_matrix,
mode='sum', per_sample_weights=weights)
embedding_matrix = torch.rand(10, 15)
model = EmbeddingModel()
x = torch.randint(7, (2, 3))
w = torch.randn(2, 3)
self.run_test(model, (embedding_matrix, x, w))
@disableScriptTest() # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast
@skipIfUnsupportedMinOpsetVersion(11)
@unittest.skip("Due to ONNX Loop shape inference issue.")
def test_embedding_bag_dynamic_input(self):
class EmbeddingModel1D(torch.nn.Module):
def forward(self, embedding_matrix, input, weights, offsets):
return torch.nn.functional.embedding_bag(input, embedding_matrix, offsets=offsets,
mode='sum', per_sample_weights=weights)
model = EmbeddingModel1D()
x = torch.randint(7, (6,))
w = torch.randn(6, )
offsets = torch.tensor([0, 2, 5], dtype=torch.long)
embedding_matrix = torch.rand(10, 15)
x2 = torch.randint(7, (2,))
w2 = torch.randn(2, )
embedding_matrix2 = torch.rand(12, 25)
offsets2 = torch.tensor([0, ], dtype=torch.long)
self.run_test(model, (embedding_matrix, x, w, offsets),
test_with_inputs=[(embedding_matrix2, x2, w2, offsets2)],
input_names=['embedding_matrix', 'x', 'offsets', 'w'],
dynamic_axes={'embedding_matrix': [0, 1], 'x': [0], 'offsets': [0], 'w': [0]})
class EmbeddingModel2D(torch.nn.Module):
def forward(self, embedding_matrix, input, weights):
return torch.nn.functional.embedding_bag(input, embedding_matrix,
mode='sum', per_sample_weights=weights)
model = EmbeddingModel2D()
x = torch.randint(7, (2, 3))
w = torch.randn(2, 3)
embedding_matrix = torch.rand(10, 15)
x2 = torch.randint(7, (3, 5))
w2 = torch.randn(3, 5)
embedding_matrix2 = torch.rand(12, 25)
self.run_test(model, (embedding_matrix, x, w),
test_with_inputs=[(embedding_matrix2, x2, w2)],
input_names=['embedding_matrix', 'x', 'w'],
dynamic_axes={'embedding_matrix': [0, 1], 'x': [0, 1], 'w': [0, 1]})
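    # meshgrid export relies on the ONNX Expand operator, added in opset 8.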
@skipIfUnsupportedMinOpsetVersion(8)
def test_meshgrid(self):
class Meshgrid(torch.nn.Module):
def forward(self, x, y, z):
output1, output2, output3 = torch.meshgrid(x, y, z)
return output1, output2, output3
x = torch.randn(3, requires_grad=True)
y = torch.zeros(4, requires_grad=True)
z = torch.randn(5, requires_grad=True)
self.run_test(Meshgrid(), (x, y, z))
@skipIfUnsupportedMinOpsetVersion(8)
def test_meshgrid_scalar(self):
class Meshgrid(torch.nn.Module):
def forward(self, x, y, z):
output1, output2, output3 = torch.meshgrid(x, y, z)
return output1, output2, output3
x = torch.ones(3, requires_grad=True)
y = torch.zeros(4, requires_grad=True)
z = torch.tensor(2.0)
self.run_test(Meshgrid(), (x, y, z))
def test_baddbmm(self):
class MyModule(torch.nn.Module):
def forward(self, input, batch1, batch2):
return torch.baddbmm(input, batch1, batch2, alpha=torch.tensor(5), beta=3.5)
x = torch.randn(10, 3, 5)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)
model = MyModule()
self.run_test(model, (x, batch1, batch2))
def test_baddbmm_dynamic(self):
class MyModule(torch.nn.Module):
def forward(self, input, batch1, batch2, alpha, beta):
return torch.baddbmm(input, batch1, batch2, alpha=alpha, beta=beta)
x = torch.randn(10, 3, 5)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)
alpha = torch.tensor(5)
beta = torch.tensor(3.5)
model = MyModule()
self.run_test(model, (x, batch1, batch2, alpha, beta))
def test_numel(self):
class MyModule(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input):
return input.numel() * input
x = torch.randn(2, 3, 5)
model = MyModule()
self.run_test(model, (x,))
def test_numel_empty(self):
class MyModule(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input):
return input.numel() * input
x = torch.randn(0)
model = MyModule()
self.run_test(model, (x,))
def test_cast_to(self):
class MyModule(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input, other):
return input.to(other) + other
x = torch.randn(2, 3, 4)
y = torch.tensor([1], dtype=torch.int64)
model = MyModule()
self.run_test(model, (x, y))
def test_cast_to_bool(self):
class MyModule(torch.nn.Module):
def forward(self, input, other):
return torch.cat((input.to(other), other), 0)
x = torch.randn(2, 3, 4)
y = torch.zeros([2, 3, 4], dtype=torch.bool)
model = MyModule()
self.run_test(model, (x, y))
@skipIfUnsupportedMinOpsetVersion(9)
def test_ones_bool(self):
class MyModule(torch.nn.Module):
def forward(self, input):
true = torch.ones(input.shape, dtype=torch.bool)
return input.to(true) & true
x = torch.randn(2, 3, 4)
model = MyModule()
self.run_test(model, x)
def test_log(self):
class Log(torch.nn.Module):
def forward(self, input):
return torch.log(input)
x = torch.rand(2, 3, 4)
model = Log()
self.run_test(model, x)
def test_log1p(self):
class Log1p(torch.nn.Module):
def forward(self, input):
return torch.log1p(input)
x = torch.rand(2, 3, 4)
model = Log1p()
self.run_test(model, x)
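    # ONNX introduced the Round operator in opset 11.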
@skipIfUnsupportedMinOpsetVersion(11)
def test_round(self):
class Round(torch.nn.Module):
def forward(self, x):
return torch.round(x)
x = torch.tensor([0.9920, -1.0362, -1.5000, 3.5000], requires_grad=True)
self.run_test(Round(), x)
def test_constant_pad(self):
model = torch.nn.ConstantPad1d(2, 3.5)
x = torch.randn(2, 4, 4)
self.run_test(model, x)
model = torch.nn.ConstantPad2d((3, 0, 2, 1), 3.5)
x = torch.randn(2, 2, 4, 4)
self.run_test(model, x)
# Dynamic padding is added in opset 11
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest() # Functional module not scriptable
def test_pad_types(self):
# Test for different pad integer types
class Pad(torch.nn.Module):
def forward(self, x, pad):
return torch.nn.functional.pad(x, pad)
x = torch.randn(2, 2, 4, 4)
        y = (torch.tensor(2, dtype=torch.int32), torch.tensor(4, dtype=torch.int32))
        self.run_test(Pad(), (x, y))
        y = (torch.tensor(2, dtype=torch.int64), torch.tensor(4, dtype=torch.int64))
        self.run_test(Pad(), (x, y))
@skipIfUnsupportedMaxOpsetVersion(10)
def test_unsupported_pad(self):
class Pad(torch.nn.Module):
def forward(self, x, pad):
return torch.nn.functional.pad(x, pad)
def run():
x = torch.randn(2, 2, 4, 4)
            y = (torch.tensor(2, dtype=torch.int32), torch.tensor(4, dtype=torch.int32))
p = Pad()
f = io.BytesIO()
torch.onnx._export(p, (x, y), f)
with self.assertRaises(RuntimeError) as cm:
run()
the_exception = cm.exception
self.assertEqual('Unsupported: ONNX export of Pad in opset 9. The sizes of the padding must be constant. ' +
'Please try opset version 11.', the_exception.args[0])
@disableScriptTest() # export prim::Uninitialized
def test_reflection_pad(self):
model = torch.nn.ReflectionPad1d(2)
x = torch.randn(2, 4, 4)
self.run_test(model, x)
model = torch.nn.ReflectionPad2d((3, 0, 2, 1))
x = torch.randn(2, 2, 4, 4)
self.run_test(model, x)
@disableScriptTest() # export prim::Uninitialized
def test_replication_pad(self):
model = torch.nn.ReplicationPad1d(2)
x = torch.randn(2, 4, 4)
self.run_test(model, x)
model = torch.nn.ReplicationPad2d((3, 0, 2, 1))
x = torch.randn(2, 2, 4, 4)
self.run_test(model, x)
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest() # export prim::Uninitialized
def test_im2col(self):
class Unfold(torch.nn.Module):
def forward(self, input):
return torch.nn.functional.unfold(input, kernel_size=(10, 15), dilation=2, padding=5, stride=3), \
torch.nn.functional.unfold(input, kernel_size=(2, 2), dilation=1, padding=0, stride=3), \
torch.nn.functional.unfold(input, kernel_size=(1, 1), dilation=5, padding=2, stride=3)
x = torch.rand(1, 1, 200, 100)
self.run_test(Unfold(), x)
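    # ONNX introduced the Det operator in opset 11; torch.det itself
    # requires LAPACK, hence skipIfNoLapack.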
@skipIfNoLapack
@skipIfUnsupportedMinOpsetVersion(11)
def test_det(self):
class Det(torch.nn.Module):
def forward(self, x):
return torch.det(x)
x = torch.randn(2, 3, 5, 5)
self.run_test(Det(), x)
    # This test checks that the output scalar type in the ONNX graph is not null
# https://github.com/pytorch/pytorch/issues/28607
@skipIfUnsupportedMinOpsetVersion(10)
def test_trace_script(self):
@torch.jit.script
def center_slice_helper(input, h_offset):
return input[:, h_offset:]
class CenterCrop(torch.nn.Module):
def forward(self, input):
return center_slice_helper(input, torch.tensor(input.shape[1] - 1))
x = torch.randn(3, 4)
self.run_test(CenterCrop(), x)
@skipIfNoLapack
@skipIfUnsupportedMinOpsetVersion(11)
def test_logdet(self):
class LogDet(torch.nn.Module):
def forward(self, x):
return torch.logdet(x)
x = torch.randn(2, 3, 5, 5)
self.run_test(LogDet(), x)
def test_dim(self):
class DimModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input):
out = input * 2
out *= out.dim()
return out
empty_input = torch.randn(0, requires_grad=True)
multi_dim_input = torch.randn(1, 2, 3, requires_grad=True)
self.run_test(DimModel(), empty_input)
self.run_test(DimModel(), multi_dim_input)
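    # ONNX introduced the Einsum operator in opset 12.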
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest() # variable number of inputs not scriptable
def test_einsum(self):
class EinsumModelBatchDiagonal(torch.nn.Module):
def forward(self, *tensor_list):
eqn = '...ii ->...i'
return torch.einsum(eqn, *tensor_list)
x = torch.randn(3, 5, 5)
self.run_test(EinsumModelBatchDiagonal(), input=(x,))
class EinsumModelBatchMatmul(torch.nn.Module):
def forward(self, *tensor_list):
eqn = 'bij, bjk -> bik'
return torch.einsum(eqn, *tensor_list)
x = torch.randn(5, 2, 3)
y = torch.randn(5, 3, 4)
self.run_test(EinsumModelBatchMatmul(), input=(x, y))
class EinsumModelInnerProd(torch.nn.Module):
def forward(self, *tensor_list):
eqn = 'i,i'
return torch.einsum(eqn, *tensor_list)
x = torch.randn(5)
y = torch.randn(5)
self.run_test(EinsumModelInnerProd(), input=(x, y))
class EinsumModelTranspose(torch.nn.Module):
def forward(self, *tensor_list):
eqn = 'ij->ji'
return torch.einsum(eqn, *tensor_list)
x = torch.randn(3, 4)
self.run_test(EinsumModelTranspose(), input=(x,))
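# For reference, the expected output shapes of the equations above:
#   '...ii->...i'  on (3, 5, 5)             -> (3, 5)    batch diagonal
#   'bij,bjk->bik' on (5, 2, 3), (5, 3, 4)  -> (5, 2, 4) batch matmul
#   'i,i'          on (5,), (5,)            -> scalar    inner product
#   'ij->ji'       on (3, 4)                -> (4, 3)    transpose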
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest() # shape/type inference
def test_crossentropyloss(self):
for ignore_index in [-100, 1]:
x = torch.randn(3, 5)
y = torch.empty(3, dtype=torch.long).random_(5)
y[y == 1] = ignore_index
self._crossentropyloss(x, y, ignore_index)
x = torch.randn(3, 5, 2)
y = torch.empty(3, 2, dtype=torch.long).random_(5)
y[y == 1] = ignore_index
self._crossentropyloss(x, y, ignore_index)
x = torch.randn(3, 5, 2, 7)
y = torch.empty(3, 2, 7, dtype=torch.long).random_(5)
y[y == 1] = ignore_index
self._crossentropyloss(x, y, ignore_index)
def _crossentropyloss(self, x, y, ignore_index):
class CrossEntropyLossNone(torch.nn.Module):
def __init__(self, ignore_index):
super(CrossEntropyLossNone, self).__init__()
if ignore_index == -100:
self.loss = torch.nn.CrossEntropyLoss(reduction='none')
else:
self.loss = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=ignore_index)
def forward(self, input, target):
return self.loss(input, target)
self.run_test(CrossEntropyLossNone(ignore_index), input=(x, y))
class CrossEntropyLossNoneWeight(torch.nn.Module):
def __init__(self, ignore_index):
super(CrossEntropyLossNoneWeight, self).__init__()
if ignore_index == -100:
self.loss = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.randn(5))
else:
self.loss = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.randn(5), ignore_index=ignore_index)
def forward(self, input, target):
return self.loss(input, target)
self.run_test(CrossEntropyLossNoneWeight(ignore_index), input=(x, y))
class CrossEntropyLossSum(torch.nn.Module):
def __init__(self, ignore_index):
super(CrossEntropyLossSum, self).__init__()
if ignore_index == -100:
self.loss = torch.nn.CrossEntropyLoss(reduction='sum')
else:
self.loss = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=ignore_index)
def forward(self, input, target):
return self.loss(input, target)
self.run_test(CrossEntropyLossSum(ignore_index), input=(x, y))
class CrossEntropyLossSumWeight(torch.nn.Module):
def __init__(self, ignore_index):
super(CrossEntropyLossSumWeight, self).__init__()
if ignore_index == -100:
self.loss = torch.nn.CrossEntropyLoss(reduction='sum', weight=torch.randn(5))
else:
self.loss = torch.nn.CrossEntropyLoss(reduction='sum', weight=torch.randn(5), ignore_index=ignore_index)
def forward(self, input, target):
return self.loss(input, target)
self.run_test(CrossEntropyLossSumWeight(ignore_index), input=(x, y))
class CrossEntropyLossMean(torch.nn.Module):
def __init__(self, ignore_index):
super(CrossEntropyLossMean, self).__init__()
if ignore_index == -100:
self.loss = torch.nn.CrossEntropyLoss()
else:
self.loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
def forward(self, input, target):
return self.loss(input, target)
self.run_test(CrossEntropyLossMean(ignore_index), input=(x, y))
class CrossEntropyLossMeanWeight(torch.nn.Module):
def __init__(self, ignore_index):
super(CrossEntropyLossMeanWeight, self).__init__()
if ignore_index == -100:
self.loss = torch.nn.CrossEntropyLoss(weight=torch.randn(5))
else:
self.loss = torch.nn.CrossEntropyLoss(weight=torch.randn(5), ignore_index=ignore_index)
def forward(self, input, target):
return self.loss(input, target)
self.run_test(CrossEntropyLossMeanWeight(ignore_index), input=(x, y))
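# The six wrappers above differ only in their CrossEntropyLoss kwargs.
# A minimal sketch of a factory that could generate them (hypothetical
# helper, not used by these tests):
#   def make_ce_module(**loss_kwargs):
#       class CE(torch.nn.Module):
#           def __init__(self):
#               super(CE, self).__init__()
#               self.loss = torch.nn.CrossEntropyLoss(**loss_kwargs)
#           def forward(self, input, target):
#               return self.loss(input, target)
#       return CE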
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest() # Output dtype mismatch
def test_kldiv_loss(self):
x = torch.randn(5)
y = torch.randn(5)
self._kldiv_loss(x, y)
x = torch.randn(2, 3, 5)
y = torch.randn(2, 3, 5)
self._kldiv_loss(x, y)
x = torch.randn(2, 3, 5, 7)
y = torch.randn(2, 3, 5, 7)
self._kldiv_loss(x, y)
def _kldiv_loss(self, x, y):
class KLDivLossNone(torch.nn.Module):
def __init__(self):
super(KLDivLossNone, self).__init__()
self.loss = torch.nn.KLDivLoss(reduction='none', log_target=True)
def forward(self, input, target):
return self.loss(input, target)
self.run_test(KLDivLossNone(), input=(x, y))
class KLDivLossMean(torch.nn.Module):
def __init__(self):
super(KLDivLossMean, self).__init__()
self.loss = torch.nn.KLDivLoss(reduction='mean', log_target=False)
def forward(self, input, target):
return self.loss(input, target)
self.run_test(KLDivLossMean(), input=(x, y))
class KLDivLossSum(torch.nn.Module):
def __init__(self):
super(KLDivLossSum, self).__init__()
self.loss = torch.nn.KLDivLoss(reduction='sum', log_target=True)
def forward(self, input, target):
return self.loss(input, target)
self.run_test(KLDivLossSum(), input=(x, y))
class KLDivLossBatchMean(torch.nn.Module):
def __init__(self):
super(KLDivLossBatchMean, self).__init__()
self.loss = torch.nn.KLDivLoss(reduction='batchmean', log_target=False)
def forward(self, input, target):
return self.loss(input, target)
self.run_test(KLDivLossBatchMean(), input=(x, y))
class KLDivLossMiniBatchMean(torch.nn.Module):
def __init__(self):
super(KLDivLossMiniBatchMean, self).__init__()
self.loss = torch.nn.KLDivLoss(reduction='batchmean', size_average=False, log_target=True)
def forward(self, input, target):
return self.loss(input, target)
self.run_test(KLDivLossMiniBatchMean(), input=(x, y))
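# For reference, KLDivLoss expects the input in log-space, and the
# pointwise (reduction='none') loss is:
#   target * (target.log() - input)   when log_target=False
#   target.exp() * (target - input)   when log_target=True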
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest() # shape/type inference
def test_nllloss(self):
class NLLModel(torch.nn.Module):
def __init__(self):
super(NLLModel, self).__init__()
self.loss = torch.nn.NLLLoss(reduction='none')
self.m = torch.nn.LogSoftmax(dim=1)
def forward(self, input, target):
output = self.loss(self.m(2 * input), target)
return output
N, C = 5, 4
input = torch.randn(N, 16)
target = torch.empty(N, dtype=torch.long).random_(0, C)
# remap some targets to the default ignore_index (-100)
target[target == 1] = -100
self.run_test(NLLModel(), (input, target))
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest() # shape/type inference
def test_nllloss_2d_none(self):
class NLLModel(torch.nn.Module):
def __init__(self):
super(NLLModel, self).__init__()
self.loss = torch.nn.NLLLoss(reduction='none')
self.conv = torch.nn.Conv2d(16, C, (3, 3))
self.m = torch.nn.LogSoftmax(dim=1)
def forward(self, input, target):
output = self.loss(self.m(self.conv(input)), target)
return output
N, C = 5, 4
input = torch.randn(N, 16, 10, 10)
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
# remap some targets to the default ignore_index (-100)
target[target == 1] = -100
self.run_test(NLLModel(), (input, target))
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest() # shape/type inference
def test_nllloss_2d_mean(self):
class NLLModel(torch.nn.Module):
def __init__(self):
super(NLLModel, self).__init__()
self.loss = torch.nn.NLLLoss(reduction='mean')
self.conv = torch.nn.Conv2d(16, C, (3, 3))
self.m = torch.nn.LogSoftmax(dim=1)
def forward(self, input, target):
output = self.loss(self.m(self.conv(input)), target)
return output
N, C = 5, 4
input = torch.randn(N, 16, 10, 10)
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
# remap some targets to the default ignore_index (-100)
target[target == 1] = -100
self.run_test(NLLModel(), (input, target))
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest() # shape/type inference
def test_nllloss_2d_sum(self):
class NLLModel(torch.nn.Module):
def __init__(self):
super(NLLModel, self).__init__()
self.loss = torch.nn.NLLLoss(reduction='sum')
self.conv = torch.nn.Conv2d(16, C, (3, 3))
self.m = torch.nn.LogSoftmax(dim=1)
def forward(self, input, target):
output = self.loss(self.m(self.conv(input)), target)
return output
N, C = 5, 4
input = torch.randn(N, 16, 10, 10)
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
# remap some targets to the default ignore_index (-100)
target[target == 1] = -100
self.run_test(NLLModel(), (input, target))
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest() # shape/type inference
def test_nllloss_2d_mean_weights(self):
class NLLModel(torch.nn.Module):
def __init__(self):
super(NLLModel, self).__init__()
self.loss = torch.nn.NLLLoss(reduction='mean', weight=torch.randn(C))
self.conv = torch.nn.Conv2d(16, C, (3, 3))
self.m = torch.nn.LogSoftmax(dim=1)
def forward(self, input, target):
output = self.loss(self.m(self.conv(input)), target)
return output
N, C = 5, 4
input = torch.randn(N, 16, 10, 10)
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
# remap some targets to the default ignore_index (-100)
target[target == 1] = -100
self.run_test(NLLModel(), (input, target))
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest() # shape/type inference
def test_nllloss_2d_mean_ignore_index(self):
class NLLModel(torch.nn.Module):
def __init__(self):
super(NLLModel, self).__init__()
self.loss = torch.nn.NLLLoss(reduction='mean', ignore_index=1)
self.conv = torch.nn.Conv2d(16, C, (3, 3))
self.m = torch.nn.LogSoftmax(dim=1)
def forward(self, input, target):
output = self.loss(self.m(self.conv(input)), target)
return output
N, C = 5, 4
input = torch.randn(N, 16, 10, 10)
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
self.run_test(NLLModel(), (input, target))
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest() # shape/type inference
def test_nllloss_2d_mean_ignore_index_weights(self):
class NLLModel(torch.nn.Module):
def __init__(self):
super(NLLModel, self).__init__()
self.loss = torch.nn.NLLLoss(reduction='mean', weight=torch.randn(C), ignore_index=1)
self.conv = torch.nn.Conv2d(16, C, (3, 3))
self.m = torch.nn.LogSoftmax(dim=1)
def forward(self, input, target):
output = self.loss(self.m(self.conv(input)), target)
return output
N, C = 5, 4
input = torch.randn(N, 16, 10, 10)
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
self.run_test(NLLModel(), (input, target))
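# For reference, targets equal to ignore_index contribute nothing to the
# loss, and with reduction='mean' the loss is averaged only over the
# non-ignored targets. The tests above cover both the default (-100) and
# a custom (1) ignore_index.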
def test_torch_mm(self):
class M(torch.nn.Module):
def forward(self, mat1, mat2):
mm = torch.mm(mat1, mat2)
return mm
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
self.run_test(M(), input=(mat1, mat2))
@skipIfUnsupportedMinOpsetVersion(9) # Because where op is not supported for opset < 9.
def test_where_with_bool_tensor(self):
class M(torch.nn.Module):
def forward(self, mat1, mat2):
out = torch.where(mat1 > 0, mat1, mat2)
return out
mat1 = torch.randn(2, 3)
mat2 = torch.ones(2, 3)
self.run_test(M(), input=(mat1, mat2))
@skipIfUnsupportedMinOpsetVersion(9) # Because where op is not supported for opset < 9.
def test_where_with_byte_tensor(self):
class M(torch.nn.Module):
def forward(self, cond, mat1, mat2):
out = torch.where(cond, mat1, mat2)
return out
cond = torch.ones(2, 3, dtype=torch.uint8)
cond[1, 2] = 0
mat1 = torch.randn(2, 3)
mat2 = torch.ones(2, 3)
self.run_test(M(), input=(cond, mat1, mat2))
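# Note: a uint8 condition is still accepted for backward compatibility,
# but torch.where is documented to take a bool condition (see test_where
# below).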
def test_dropout(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.dropout = torch.nn.Dropout(0.3)
def forward(self, x):
dropout = self.dropout(x)
return dropout
x = torch.randn(10, 3, 53)
self.run_test(M(), (x,))
def test_shape_constant_fold(self):
class ShapeModule(torch.nn.Module):
def __init__(self):
super(ShapeModule, self).__init__()
self.register_buffer("weight", torch.ones(5))
def forward(self, x):
shape = self.weight.shape[0]
return x + shape
x = torch.randn(2, 5)
self.run_test(ShapeModule(), (x,), rtol=1e-3, atol=1e-5)
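# For reference, self.weight.shape[0] is a compile-time constant here, so
# constant folding can replace the Shape/Gather subgraph with a constant
# initializer (a sketch of the intent; the exact folded graph depends on
# the opset).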
@skipIfUnsupportedMinOpsetVersion(12)
def test_celu(self):
class Celu(torch.nn.Module):
def __init__(self):
super(Celu, self).__init__()
self.celu = torch.nn.CELU(alpha=1.0)
def forward(self, input):
return self.celu(input)
input = torch.randn(2)
self.run_test(Celu(), (input,))
@skipIfUnsupportedMinOpsetVersion(12)
def test_celu_default(self):
class Celu(torch.nn.Module):
def __init__(self):
super(Celu, self).__init__()
self.celu = torch.nn.CELU()
def forward(self, input):
return self.celu(input)
input = torch.randn(2)
self.run_test(Celu(), (input,))
@skipIfUnsupportedMinOpsetVersion(12)
def test_celu_alpha(self):
class Celu(torch.nn.Module):
def __init__(self):
super(Celu, self).__init__()
self.celu = torch.nn.CELU(alpha=2.)
def forward(self, input):
return self.celu(input)
input = torch.randn(2)
self.run_test(Celu(), (input,))
@skipIfUnsupportedMinOpsetVersion(12)
def test_celu_cast(self):
class Celu(torch.nn.Module):
def __init__(self):
super(Celu, self).__init__()
self.celu = torch.nn.CELU()
def forward(self, input):
return self.celu(input)
input = torch.randn(2, 5, 7, dtype=torch.float64)
self.run_test(Celu(), (input,))
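# For reference, CELU computes
#   celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
# The float64 input above additionally exercises dtype handling, since
# the ONNX Celu op is only defined for float32.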
@skipIfUnsupportedMinOpsetVersion(9)
def test_where(self):
class Model(torch.nn.Module):
def forward(self, cond, input, other):
return torch.where(cond, input, other)
x = torch.randint(0, 2, (2, 3, 4), dtype=torch.bool)
y = torch.randn(2, 1, 4)
z = torch.ones(2, 3, 1)
self.run_test(Model(), (x, y, z))
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest() # symbolic update needed for unbind: ONNX export of unbind with dynamic number of outputs
def test_where_condition(self):
class Model1(torch.nn.Module):
def forward(self, input):
return torch.stack(torch.where(input > 0.5), dim=1)
x = torch.randint(0, 2, (2, 3, 4), dtype=bool)
self.run_test(Model1(), (x,))
class Model2(torch.nn.Module):
def forward(self, input, other):
return torch.stack(torch.where(input > other), dim=1)
x = torch.randint(0, 1, (2, 3, 4), dtype=bool)
y = torch.randint(1, 2, (2, 3, 4), dtype=bool)
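# input > other is False everywhere here, so this exercises the case
# where torch.where returns empty index tensors.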
self.run_test(Model2(), (x, y))
def test_empty_branch(self):
class EmptyBranchModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input):
out = input + 1
if out.dim() > 2:
if out.dim() > 3:
out += 3
else:
pass
else:
pass
return out
x = torch.randn(1, 2, 3, requires_grad=True)
self.run_test(EmptyBranchModel(), x)
@skipIfONNXShapeInference(False)
@skipIfUnsupportedMinOpsetVersion(11)
def test_if_transpose(self):
class IfModel(torch.nn.Module):
def forward(self, x):
x = x.transpose(0, 1)
if x.size(0) == 2:
return x.transpose(0, 1)
else:
return x
x = torch.randn(2, 3)
self.run_test(torch.jit.script(IfModel()), x,
output_names=['output_1'],
dynamic_axes={'output_1': [0, 1]})
def test_onnx_proto_checker(self):
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x):
return 2 * x
x = torch.randn(1, 2, 3, requires_grad=True)
f = io.BytesIO()
torch.onnx._export(Model(), x, f)
model = onnx.load(f)
model.ir_version = 0
def check_proto():
torch._C._check_onnx_proto(model.SerializeToString())
self.assertRaises(RuntimeError, check_proto)
@disableScriptTest() # dtype mismatch
def test_split_tensor_scalar(self):
class SplitModel(torch.nn.Module):
def forward(self, x):
return torch.split(x, x.size(1))
x = torch.randn(1, 2, 3, requires_grad=True)
self.run_test(SplitModel(), x)
def test_split_tensor_multi(self):
class SplitModel(torch.nn.Module):
def forward(self, x):
return torch.split(x, torch.ones(3))
x = torch.randn(1, 2, 3, requires_grad=True)
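# Note: the TypeError below comes from passing x to the Module
# constructor; forward (and thus torch.split with a tensor argument) is
# never invoked.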
def run_model():
SplitModel(x)
self.assertRaises(TypeError, run_model)
def _dispatch_rnn_test(self, name, *args, **kwargs):
if name == 'elman':
self._elman_rnn_test(*args, **kwargs)
if name == 'lstm':
self._lstm_test(*args, **kwargs)
if name == 'gru':
self._gru_test(*args, **kwargs)
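# A minimal usage sketch (mirrors what make_test below generates):
#   self._dispatch_rnn_test('lstm', layers=1, bidirectional=False,
#                           initial_state=True, packed_sequence=0,
#                           dropout=0.0)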
def _elman_rnn_test(self, layers, nonlinearity, bidirectional,
initial_state, packed_sequence, dropout):
batch_first = packed_sequence == 2
model = torch.nn.RNN(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers, nonlinearity=nonlinearity,
bidirectional=bidirectional, dropout=dropout, batch_first=batch_first)
if packed_sequence == 1:
model = RnnModelWithPackedSequence(model, False)
if packed_sequence == 2:
model = RnnModelWithPackedSequence(model, True)
def make_input(batch_size):
seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
seq_lengths = sorted(map(int, seq_lengths), reverse=True)
inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
inputs = [inputs]
directions = 2 if bidirectional else 1
if initial_state:
h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
inputs.append(h0)
if packed_sequence != 0:
inputs.append(torch.IntTensor(seq_lengths))
if len(inputs) == 1:
input = inputs[0]
else:
input = tuple(inputs)
return input
input = make_input(RNN_BATCH_SIZE)
self.run_test(model, input, batch_size=RNN_BATCH_SIZE)
# test that the model still runs with a different batch size
other_input = make_input(RNN_BATCH_SIZE + 1)
self.run_test(model, other_input, batch_size=RNN_BATCH_SIZE + 1)
def _lstm_test(self, layers, bidirectional, initial_state,
packed_sequence, dropout):
batch_first = packed_sequence == 2
model = LstmFlatteningResult(
RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers,
bidirectional=bidirectional, dropout=dropout, batch_first=batch_first)
if packed_sequence == 1:
model = RnnModelWithPackedSequence(model, False)
if packed_sequence == 2:
model = RnnModelWithPackedSequence(model, True)
def make_input(batch_size):
seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
seq_lengths = sorted(map(int, seq_lengths), reverse=True)
inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
inputs = [inputs]
directions = 2 if bidirectional else 1
if initial_state:
h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
c0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
inputs.append((h0, c0))
if packed_sequence != 0:
inputs.append(torch.IntTensor(seq_lengths))
if len(inputs) == 1:
input = inputs[0]
else:
input = tuple(inputs)
return input
input = make_input(RNN_BATCH_SIZE)
self.run_test(model, input, batch_size=RNN_BATCH_SIZE)
# test that the model still runs with a different batch size
other_input = make_input(RNN_BATCH_SIZE + 1)
self.run_test(model, other_input, batch_size=RNN_BATCH_SIZE + 1)
def _gru_test(self, layers, bidirectional, initial_state,
packed_sequence, dropout):
batch_first = packed_sequence == 2
model = torch.nn.GRU(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers, bidirectional=bidirectional, dropout=dropout,
batch_first=batch_first)
if packed_sequence == 1:
model = RnnModelWithPackedSequence(model, False)
if packed_sequence == 2:
model = RnnModelWithPackedSequence(model, True)
def make_input(batch_size):
seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
seq_lengths = sorted(map(int, seq_lengths), reverse=True)
inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
inputs = [inputs]
directions = 2 if bidirectional else 1
if initial_state:
h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
inputs.append(h0)
if packed_sequence != 0:
inputs.append(torch.IntTensor(seq_lengths))
if len(inputs) == 1:
input = inputs[0]
else:
input = tuple(inputs)
return input
input = make_input(RNN_BATCH_SIZE)
self.run_test(model, input, batch_size=RNN_BATCH_SIZE)
# test that the model still runs with a different batch size
other_input = make_input(RNN_BATCH_SIZE + 1)
self.run_test(model, other_input, batch_size=RNN_BATCH_SIZE + 1)
@skipIfUnsupportedMinOpsetVersion(10)
def test_fake_quantize_per_tensor(self):
class FakeQuantizePerTensorModel(torch.nn.Module):
def forward(self, input):
scale = 1. / 127
zero_point = 0
quant_min = -128
quant_max = 127
return torch.fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max)
x = torch.randn(6, 4, 3, 3)
self.run_test(FakeQuantizePerTensorModel(), (x,))
@skipIfUnsupportedMinOpsetVersion(12)
def test_dropout_training(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.dropout = torch.nn.Dropout(0.4)
def forward(self, x):
dropout = self.dropout(x)
return dropout
model = MyModule()
x = torch.randn(10)
model.train()
ort_sess = convert_to_onnx(model, input=(x,), opset_version=self.opset_version,
training=torch.onnx.TrainingMode.TRAINING)
ort_outs = run_ort(ort_sess, input=(x,))
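# In TRAINING mode the exported Dropout is active, so the ORT output
# should differ from the input (some elements zeroed, the rest scaled by
# 1 / (1 - p)).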
assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0])))
@skipIfUnsupportedMinOpsetVersion(12)
def test_dropout_training_zero(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.dropout = torch.nn.Dropout(0.5)
def forward(self, x):
dropout = self.dropout(x)
return dropout
model = MyModule()
# ensure there are no zeros in the input, so that every zero in the
# output can be attributed to dropout
x = torch.randn(10, 3, 128, 128)
y = x.numpy()
y_mask = np.where(y == 0, 1, y)
input = torch.from_numpy(y_mask)
nb_elements = torch.numel(input)
model.train()
ort_sess = convert_to_onnx(model, input=(input,), opset_version=self.opset_version,
training=torch.onnx.TrainingMode.TRAINING)
ort_outs = run_ort(ort_sess, input=(input,))
y = model(input)
output = y.cpu().numpy()
ort_mask = np.where(ort_outs[0] != 0, 1, 0)
pyt_mask = np.where(output != 0, 1, 0)
ratio_pytorch = np.sum(pyt_mask) / nb_elements
ratio_ort = np.sum(ort_mask) / nb_elements
np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01)
def test_conv_bn(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
self.bn = torch.nn.BatchNorm2d(16, affine=True)
def forward(self, x):
x = self.conv(x)
bn = self.bn(x)
return bn
model = MyModule()
x = torch.randn(10, 3, 128, 128)
ort_sess1 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version,
training=torch.onnx.TrainingMode.TRAINING)
ort_outs1 = run_ort(ort_sess1, input=(x,))
ort_sess2 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version,
training=torch.onnx.TrainingMode.EVAL)
ort_outs2 = run_ort(ort_sess2, input=(x,))
[np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in
zip(ort_outs1, ort_outs2)]
def test_multiple_conv_bn(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.conv2 = torch.nn.Conv2d(64, 2, kernel_size=1, stride=1, padding=0, bias=False)
self.conv3 = torch.nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1, bias=False)
self.bn = torch.nn.BatchNorm2d(64)
self.bn2 = torch.nn.BatchNorm2d(2)
self.relu = torch.nn.ReLU(inplace=True)
self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
def forward(self, x):
x = self.conv1(x)
x = self.bn(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.conv3(x)
x = self.bn2(x)
x = self.relu(x)
return x
model = MyModule()
x = torch.randn(2, 3, 224, 224)
ort_sess1 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version,
training=torch.onnx.TrainingMode.TRAINING)
ort_outs1 = run_ort(ort_sess1, input=(x,))
ort_sess2 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version,
training=torch.onnx.TrainingMode.EVAL)
ort_outs2 = run_ort(ort_sess2, input=(x,))
[np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in
zip(ort_outs1, ort_outs2)]
def make_test(name, base, layer, bidirectional, initial_state,
variable_length, dropout,
**extra_kwargs):
test_name = '_'.join([
'test', name, layer[1],
bidirectional[1], initial_state[1],
variable_length[1], dropout[1]
])
# Cannot export with older opsets because of the 'ConstantFill' op.
# ConstantFill was a temporary op, removed in opset 8, and is no longer
# supported by onnxruntime.
@disableScriptTest() # Test code not scriptable
@skipIfUnsupportedMinOpsetVersion(9)
def f(self):
self._dispatch_rnn_test(
base,
layers=layer[0],
bidirectional=bidirectional[0],
initial_state=initial_state[0],
packed_sequence=variable_length[0],
dropout=dropout[0],
**extra_kwargs)
f.__name__ = test_name
setattr(TestONNXRuntime, f.__name__, f)
def setup_rnn_tests():
layers_opts = [
(1, 'unilayer'),
(3, 'trilayer')
]
bidirectional_opts = [
(False, 'forward'),
(True, 'bidirectional')
]
initial_state_opts = [
(True, 'with_initial_state'),
(False, 'no_initial_state')
]
variable_length_opts = [
(0, 'without_sequence_lengths'),
(1, 'with_variable_length_sequences'),
(2, 'with_batch_first_sequence_lengths')
]
dropout_opts = [
(0.2, 'with_dropout'),
(0.0, 'without_dropout')
]
test_count = 0
for (layer, bidirectional, initial_state, variable_length, dropout) in \
itertools.product(
layers_opts,
bidirectional_opts,
initial_state_opts,
variable_length_opts,
dropout_opts,):
for base, name, extra_kwargs in (
('elman', 'elman_relu', {'nonlinearity': u'relu'}),
('elman', 'elman_tanh', {'nonlinearity': u'tanh'}),
('lstm', 'lstm', {}),
('gru', 'gru', {})
):
make_test(name, base, layer, bidirectional, initial_state,
variable_length, dropout,
**extra_kwargs)
test_count += 1
# sanity check that a representative example does exist
TestONNXRuntime.test_gru_trilayer_forward_with_initial_state_without_sequence_lengths_with_dropout
# make sure no one accidentally disables all the tests without
# noticing
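# 2 layer opts x 2 directions x 2 initial-state opts x 3 variable-length
# opts x 2 dropout opts = 48 combinations, each with 4 model variants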
if test_count != 192:
raise ValueError('Expected 192 tests but found {}'.format(test_count))
setup_rnn_tests()
# opset 7 tests
TestONNXRuntime_opset7 = type(str("TestONNXRuntime_opset7"),
(unittest.TestCase,),
dict(TestONNXRuntime.__dict__, opset_version=7))
# opset 8 tests
TestONNXRuntime_opset8 = type(str("TestONNXRuntime_opset8"),
(unittest.TestCase,),
dict(TestONNXRuntime.__dict__, opset_version=8))
# opset 10 tests
TestONNXRuntime_opset10 = type(str("TestONNXRuntime_opset10"),
(unittest.TestCase,),
dict(TestONNXRuntime.__dict__, opset_version=10))
# opset 11 tests
TestONNXRuntime_opset11 = type(str("TestONNXRuntime_opset11"),
(unittest.TestCase,),
dict(TestONNXRuntime.__dict__, opset_version=11))
# opset 12 tests
TestONNXRuntime_opset12 = type(str("TestONNXRuntime_opset12"),
(unittest.TestCase,),
dict(TestONNXRuntime.__dict__, opset_version=12))
# opset 9 tests, with keep_initializers_as_inputs=False for
# IR version 4 style export.
TestONNXRuntime_opset9_IRv4 = type(str("TestONNXRuntime_opset9_IRv4"),
(unittest.TestCase,),
dict(TestONNXRuntime.__dict__,
keep_initializers_as_inputs=False))
# opset 10 tests, with keep_initializers_as_inputs=False for
# IR version 4 style export.
TestONNXRuntime_opset10_IRv4 = type(str("TestONNXRuntime_opset10_IRv4"),
(unittest.TestCase,),
dict(TestONNXRuntime.__dict__, opset_version=10,
keep_initializers_as_inputs=False))
# opset 11 tests, with keep_initializers_as_inputs=False for
# IR version 4 style export.
TestONNXRuntime_opset11_IRv4 = type(str("TestONNXRuntime_opset11_IRv4"),
(unittest.TestCase,),
dict(TestONNXRuntime.__dict__, opset_version=11,
keep_initializers_as_inputs=False))
# opset 12 tests, with keep_initializers_as_inputs=False for
# IR version 4 style export.
TestONNXRuntime_opset12_IRv4 = type(str("TestONNXRuntime_opset12_IRv4"),
(unittest.TestCase,),
dict(TestONNXRuntime.__dict__, opset_version=12,
keep_initializers_as_inputs=False))
# opset 9 tests, with use_new_jit_passes=True for using new jit API,
# and with keep_initializers_as_inputs=False for IR version 4 style export.
TestONNXRuntime_opset9_IRv4_new_jit_API = type(str("TestONNXRuntime_opset9_IRv4_new_jit_API"),
(unittest.TestCase,),
dict(TestONNXRuntime.__dict__,
keep_initializers_as_inputs=False,
use_new_jit_passes=True,
onnx_shape_inference=True))
# opset 12 tests, with use_new_jit_passes=True for using new jit API,
# and keep_initializers_as_inputs=False for IR version 4 style export.
TestONNXRuntime_opset12_IRv4_new_jit_API = type(str("TestONNXRuntime_opset12_IRv4_new_jit_API"),
(unittest.TestCase,),
dict(TestONNXRuntime.__dict__, opset_version=12,
keep_initializers_as_inputs=False,
use_new_jit_passes=True,
onnx_shape_inference=True))
# opset 12 tests, with onnx_shape_inference=True.
TestONNXRuntime_opset12_onnx_shape_inference = type(str("TestONNXRuntime_opset12_onnx_shape_inference"),
(unittest.TestCase,),
dict(TestONNXRuntime.__dict__, opset_version=12,
onnx_shape_inference=True))
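# Each TestONNXRuntime_* variant above can be run on its own via the
# standard unittest CLI, e.g.:
#   python test_pytorch_onnx_onnxruntime.py TestONNXRuntime_opset11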
if __name__ == '__main__':
unittest.main()