pytorch/test/onnx/test_verify.py
Justin Chu 773d80747c [ONNX] Clean up unit tests, rename files and improve import style (#81141)
- Rename `test_pytorch_common` -> `pytorch_test_common`, `test_onnx_common` -> `onnx_test_common`, removing the test_ prefix to show that the files are not test cases
- Remove import * in `test_pytorch_common` and adjust to import from `testing._internal.common_utils` (where functions are actually defined) instead
- Import modules only in `test_pytorch_onnx_onnxruntime` (too many to handle in a single PR in other tests) (The skips are exceptions)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81141
Approved by: https://github.com/BowenBao
2022-07-12 00:00:49 +00:00

110 lines
3.3 KiB
Python

# Owner(s): ["module: onnx"]
from verify import verify
import caffe2.python.onnx.backend as backend
import torch
from torch.autograd import Function
from torch.nn import Module, Parameter
from torch.testing._internal import common_utils
class TestVerify(common_utils.TestCase):
maxDiff = None
def assertVerifyExpectFail(self, *args, **kwargs):
try:
verify(*args, **kwargs)
except AssertionError as e:
if str(e):
# substring a small piece of string because the exact message
# depends on system's formatting settings
# self.assertExpected(str(e)[:60])
# NB: why we comment out the above check? because numpy keeps
# changing the error format, and we have to keep updating the
# expect files let's relax this constraint
return
else:
raise
# Don't put this in the try block; the AssertionError will catch it
self.assertTrue(False, msg="verify() did not fail when expected to")
def test_result_different(self):
class BrokenAdd(Function):
@staticmethod
def symbolic(g, a, b):
return g.op("Add", a, b)
@staticmethod
def forward(ctx, a, b):
return a.sub(b) # yahaha! you found me!
class MyModel(Module):
def forward(self, x, y):
return BrokenAdd().apply(x, y)
x = torch.tensor([1, 2])
y = torch.tensor([3, 4])
self.assertVerifyExpectFail(MyModel(), (x, y), backend)
def test_jumbled_params(self):
class MyModel(Module):
def __init__(self):
super().__init__()
def forward(self, x):
y = x * x
self.param = Parameter(torch.tensor([2.0]))
return y
x = torch.tensor([1, 2])
with self.assertRaisesRegex(RuntimeError, "state_dict changed"):
verify(MyModel(), x, backend)
def test_dynamic_model_structure(self):
class MyModel(Module):
def __init__(self):
super().__init__()
self.iters = 0
def forward(self, x):
if self.iters % 2 == 0:
r = x * x
else:
r = x + x
self.iters += 1
return r
x = torch.tensor([1, 2])
self.assertVerifyExpectFail(MyModel(), x, backend)
def test_embedded_constant_difference(self):
class MyModel(Module):
def __init__(self):
super().__init__()
self.iters = 0
def forward(self, x):
r = x[self.iters % 2]
self.iters += 1
return r
x = torch.tensor([[1, 2], [3, 4]])
self.assertVerifyExpectFail(MyModel(), x, backend)
def test_explicit_test_args(self):
class MyModel(Module):
def forward(self, x):
if x.data.sum() == 1.0:
return x + x
else:
return x * x
x = torch.tensor([[6, 2]])
y = torch.tensor([[2, -1]])
self.assertVerifyExpectFail(MyModel(), x, backend, test_args=[(y,)])
if __name__ == "__main__":
common_utils.run_tests()