[ONNX] Enable models tests (#38791)
Summary: PR to enable model tests that have been fixed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/38791
Reviewed By: hl475
Differential Revision: D21732498
Pulled By: houseroad
fbshipit-source-id: f417f9d4124ef5a663dc666d5c2ed6ba013b26a4
This commit is contained in: parent b789c1790f, commit 7f1c9886cd
@@ -43,10 +43,6 @@ if [[ $PARALLEL == 1 ]]; then
  args+=("3")
fi

# Skipped tests
args+=("-k")
args+=('not (TestOperators and test_full_like) and not (TestOperators and test_zeros_like) and not (TestOperators and test_ones_like) and not (TestModels and test_vgg16) and not (TestModels and test_vgg16_bn) and not (TestModels and test_vgg19) and not (TestModels and test_vgg19_bn)')

# These exclusions are for tests that take a long time / a lot of GPU
# memory to run; they should be passing (and you will test them if you
# run them locally)
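The "-k" arguments above hand pytest a boolean keyword expression that deselects the named tests. As a rough, hypothetical illustration (not part of this commit), the same exclusion logic can be driven from Python via pytest.main; the test path and test names below are examples only.

    # Hypothetical sketch of running tests with a "-k" exclusion expression.
    import pytest

    exit_code = pytest.main([
        "test/onnx",  # directory to collect tests from (example path)
        "-k",
        "not (TestModels and test_vgg16) and not (TestModels and test_vgg19)",
    ])
    print("pytest exit code:", exit_code)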
@@ -70,7 +70,6 @@ class TestModels(TestCase):
        self.exportTest(toC(SRResNet(rescale_factor=4, n_filters=64, n_blocks=8)), toC(x))

    @skipIfNoLapack
    @unittest.skip("This model is broken, see https://github.com/pytorch/pytorch/issues/18429")
    def test_super_resolution(self):
        x = Variable(
            torch.randn(BATCH_SIZE, 1, 224, 224).fill_(1.0)
@@ -83,26 +82,29 @@ class TestModels(TestCase):
        )
        self.exportTest(toC(alexnet()), toC(x))

    @unittest.skip("Waiting for https://github.com/pytorch/pytorch/pull/3100")
    def test_mnist(self):
        x = Variable(torch.randn(BATCH_SIZE, 1, 28, 28).fill_(1.0))
        self.exportTest(toC(MNIST()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg16(self):
        # VGG 16-layer model (configuration "D")
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg16()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg16_bn(self):
        # VGG 16-layer model (configuration "D") with batch normalization
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg16_bn()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg19(self):
        # VGG 19-layer model (configuration "E")
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg19()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg19_bn(self):
        # VGG 19-layer model (configuration 'E') with batch normalization
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
@@ -131,11 +133,10 @@ class TestModels(TestCase):
        sqnet_v1_1 = SqueezeNet(version=1.1)
        self.exportTest(toC(sqnet_v1_1), toC(x))

    @unittest.skip("Temporary - waiting for https://github.com/onnx/onnx/pull/1773.")
    def test_densenet(self):
        # Densenet-121 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(densenet121()), toC(x))
        self.exportTest(toC(densenet121()), toC(x), rtol=1e-2, atol=1e-5)

    def test_dcgan_netD(self):
        netD = _netD(1)
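The densenet check now passes explicit tolerances to exportTest. As a hedged, self-contained illustration of what rtol=1e-2 / atol=1e-5 permit (the values and the comparison below are made up; run_model_test itself is not reproduced here):

    import numpy as np

    pytorch_out = np.array([1.0000, 2.0000, 3.0000])
    exported_out = np.array([1.0050, 1.9990, 3.0002])  # hypothetical backend output

    # assert_allclose checks |actual - desired| <= atol + rtol * |desired| element-wise.
    np.testing.assert_allclose(exported_out, pytorch_out, rtol=1e-2, atol=1e-5)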
@@ -11,7 +11,7 @@ from test_pytorch_onnx_onnxruntime import run_model_test


def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
    opset_versions = opset_versions if opset_versions else [7, 8, 9, 10]
    opset_versions = opset_versions if opset_versions else [7, 8, 9, 10, 11, 12]
    for opset_version in opset_versions:
        self.opset_version = opset_version
        run_model_test(self, model, False,
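exportTest re-runs the same model once per opset version in the list. A minimal sketch of that pattern against the public export API (assuming the onnx package is installed; MyModel and the opset list below are examples, not the test suite's models):

    import io

    import onnx
    import torch

    class MyModel(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    model = MyModel().eval()
    dummy = torch.randn(1, 3)

    for opset_version in (9, 10, 11, 12):
        buf = io.BytesIO()
        # Serialize the model under the requested opset.
        torch.onnx.export(model, dummy, buf, opset_version=opset_version)
        proto = onnx.load_model_from_string(buf.getvalue())
        print(opset_version, [imp.version for imp in proto.opset_import])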
@@ -21,6 +21,7 @@ import unittest
'''Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
    --no-onnx: no onnx python dependence
    --produce-onnx-test-data: generate onnx test data
    --accept: accept onnx updates and overwrite models
'''

_onnx_test = False  # flag to produce onnx test cases.
@@ -11,8 +11,6 @@ from verify import verify

from test_pytorch_common import TestCase, run_tests

import unittest


class TestVerify(TestCase):
    maxDiff = None
@@ -98,7 +96,6 @@ class TestVerify(TestCase):
        x = torch.tensor([1, 2])
        self.assertVerifyExpectFail(MyModel(), x, backend)

    @unittest.skip("Indexing is broken by #3725")
    def test_embedded_constant_difference(self):
        class MyModel(Module):
            def __init__(self):