mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18598 ghimport-source-id: c74597e5e7437e94a43c163cee0639b20d0d0c6a Stack from [ghstack](https://github.com/ezyang/ghstack): * **#18598 Turn on F401: Unused import warning.** This was requested by someone at Facebook; this lint is turned on for Facebook by default. "Sure, why not." I had to noqa a number of imports in __init__. Hypothetically we're supposed to use __all__ in this case, but I was too lazy to fix it. Left for future work. Be careful! flake8-2 and flake8-3 behave differently with respect to import resolution for # type: comments. flake8-3 will report an import unused; flake8-2 will not. For now, I just noqa'd all these sites. All the changes were done by hand. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Differential Revision: D14687478 fbshipit-source-id: 30d532381e914091aadfa0d2a5a89404819663e3
152 lines
5.2 KiB
Python
import unittest

import torch
import torch.onnx
import torch.onnx.utils
from torch.autograd import Variable
from torch.onnx import OperatorExportTypes

import caffe2.python.onnx.backend as backend
from torchvision.models.alexnet import alexnet
from torchvision.models.densenet import densenet121
from torchvision.models.inception import inception_v3
from torchvision.models.resnet import resnet50
from torchvision.models.vgg import vgg16, vgg16_bn, vgg19, vgg19_bn

from model_defs.dcgan import _netD, _netG, weights_init, bsz, imgsz, nz
from model_defs.mnist import MNIST
from model_defs.op_test import ConcatNet, DummyNet, PermuteNet, PReluNet
from model_defs.squeezenet import SqueezeNet
from model_defs.srresnet import SRResNet
from model_defs.super_resolution import SuperResolutionNet
from test_pytorch_common import TestCase, run_tests, skipIfNoLapack
from verify import verify
def _to_cuda(x):
    """Move ``x`` (tensor or module) onto the GPU."""
    return x.cuda()


def _identity(x):
    """Leave ``x`` untouched (CPU fallback)."""
    return x


# Decide once at import time whether tests run on GPU; ``toC`` is the
# device-placement helper used by every test below.
toC = _to_cuda if torch.cuda.is_available() else _identity
# Batch dimension shared by all dummy inputs; kept tiny so export tests run fast.
BATCH_SIZE = 2
class TestModels(TestCase):
    """Export well-known vision models to ONNX and verify the results.

    Each test traces its model with the JIT, lints the traced graph, and
    then checks the ONNX export against the caffe2 backend via ``verify``.
    """

    def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7):
        """Trace ``model`` on ``inputs``, lint the graph, and verify the
        export against the caffe2 backend within the given tolerances."""
        trace = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
        torch._C._jit_pass_lint(trace.graph())
        verify(model, inputs, backend, rtol=rtol, atol=atol)

    def test_ops(self):
        x = Variable(
            torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        )
        self.exportTest(toC(DummyNet()), toC(x))

    def test_prelu(self):
        x = Variable(
            torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        )
        self.exportTest(PReluNet(), x)

    def test_concat(self):
        input_a = Variable(torch.randn(BATCH_SIZE, 3))
        input_b = Variable(torch.randn(BATCH_SIZE, 3))
        inputs = ((toC(input_a), toC(input_b)), )
        self.exportTest(toC(ConcatNet()), inputs)

    def test_permute(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 10, 12))
        self.exportTest(PermuteNet(), x)

    @unittest.skip("This model takes too much memory")
    def test_srresnet(self):
        x = Variable(torch.randn(1, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(SRResNet(rescale_factor=4, n_filters=64, n_blocks=8)), toC(x))

    @skipIfNoLapack
    def test_super_resolution(self):
        x = Variable(
            torch.randn(BATCH_SIZE, 1, 224, 224).fill_(1.0)
        )
        self.exportTest(toC(SuperResolutionNet(upscale_factor=3)), toC(x), atol=1e-6)

    def test_alexnet(self):
        x = Variable(
            torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        )
        self.exportTest(toC(alexnet()), toC(x))

    @unittest.skip("Waiting for https://github.com/pytorch/pytorch/pull/3100")
    def test_mnist(self):
        x = Variable(torch.randn(BATCH_SIZE, 1, 28, 28).fill_(1.0))
        self.exportTest(toC(MNIST()), toC(x))

    def test_vgg16(self):
        # VGG 16-layer model (configuration "D")
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg16()), toC(x))

    def test_vgg16_bn(self):
        # VGG 16-layer model (configuration "D") with batch normalization
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg16_bn()), toC(x))

    def test_vgg19(self):
        # VGG 19-layer model (configuration "E")
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg19()), toC(x))

    def test_vgg19_bn(self):
        # VGG 19-layer model (configuration 'E') with batch normalization
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg19_bn()), toC(x))

    def test_resnet(self):
        # ResNet50 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(resnet50()), toC(x), atol=1e-6)

    def test_inception(self):
        x = Variable(
            torch.randn(BATCH_SIZE, 3, 299, 299) + 1.)
        self.exportTest(toC(inception_v3()), toC(x))

    def test_squeezenet(self):
        # SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and
        # <0.5MB model size
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        # Fixed: previously built version 1.1 here despite the v1_0 name
        # and the SqueezeNet-1.0 description above.
        sqnet_v1_0 = SqueezeNet(version=1.0)
        self.exportTest(toC(sqnet_v1_0), toC(x))

        # SqueezeNet 1.1 has 2.4x less computation and slightly fewer params
        # than SqueezeNet 1.0, without sacrificing accuracy.
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        sqnet_v1_1 = SqueezeNet(version=1.1)
        self.exportTest(toC(sqnet_v1_1), toC(x))

    @unittest.skip("Temporary - waiting for https://github.com/onnx/onnx/pull/1773.")
    def test_densenet(self):
        # Densenet-121 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(densenet121()), toC(x))

    def test_dcgan_netD(self):
        netD = _netD(1)
        netD.apply(weights_init)
        # Renamed from ``input`` to avoid shadowing the builtin.
        inp = Variable(torch.Tensor(bsz, 3, imgsz, imgsz).normal_(0, 1))
        self.exportTest(toC(netD), toC(inp))

    def test_dcgan_netG(self):
        netG = _netG(1)
        netG.apply(weights_init)
        # Renamed from ``input`` to avoid shadowing the builtin.
        inp = Variable(torch.Tensor(bsz, nz, 1, 1).normal_(0, 1))
        self.exportTest(toC(netG), toC(inp))
# Entry point: delegate to the shared PyTorch test runner.
if __name__ == "__main__":
    run_tests()