mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44595 Reviewed By: seemethere Differential Revision: D23670280 Pulled By: walterddr fbshipit-source-id: b32633912f6c8b4606be36b90f901e636567b355
70 lines
2.5 KiB
Python
70 lines
2.5 KiB
Python
# Some standard imports
|
|
import numpy as np
|
|
from torch import nn
|
|
import torch.onnx
|
|
import torch.nn.init as init
|
|
from caffe2.python.model_helper import ModelHelper
|
|
from pytorch_helper import PyTorchModule
|
|
import unittest
|
|
from caffe2.python.core import workspace
|
|
|
|
from test_pytorch_common import skipIfNoLapack
|
|
|
|
|
|
class TestCaffe2Backend(unittest.TestCase):
    """Compare a PyTorch module embedded in a Caffe2 net with native PyTorch.

    The single test builds a small convolutional network, embeds it in a
    Caffe2 ``ModelHelper`` net through ``PyTorchModule``, runs both versions
    on the same input and checks that the outputs agree.
    """

    @skipIfNoLapack
    @unittest.skip("test broken because Lapack was always missing.")
    def test_helper(self):
        """Run the embedded net in Caffe2 and compare against PyTorch.

        NOTE: the ``unittest.skip`` decorator above disables this test
        unconditionally; the body is kept so it can be re-enabled later.
        """

        class SuperResolutionNet(nn.Module):
            """Four stacked convolutions followed by a pixel shuffle."""

            def __init__(self, upscale_factor, inplace=False):
                """Build the layers and initialise the convolution weights.

                Args:
                    upscale_factor: factor handed to ``nn.PixelShuffle``; the
                        last convolution emits ``upscale_factor ** 2``
                        channels.
                    inplace: forwarded to ``nn.ReLU``.
                """
                super(SuperResolutionNet, self).__init__()

                self.relu = nn.ReLU(inplace=inplace)
                self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
                self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
                self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
                self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
                self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

                self._initialize_weights()

            def forward(self, x):
                """Apply conv1..conv3 with ReLU, then conv4 and pixel shuffle."""
                x = self.relu(self.conv1(x))
                x = self.relu(self.conv2(x))
                x = self.relu(self.conv3(x))
                x = self.pixel_shuffle(self.conv4(x))
                return x

            def _initialize_weights(self):
                """Orthogonally initialise every convolution weight.

                Uses the in-place ``init.orthogonal_``; the un-suffixed
                ``init.orthogonal`` spelling is deprecated.
                """
                relu_gain = init.calculate_gain('relu')
                # The three convolutions that feed a ReLU use the ReLU gain.
                for conv in (self.conv1, self.conv2, self.conv3):
                    init.orthogonal_(conv.weight, relu_gain)
                # conv4 is not followed by a ReLU, so it keeps the default gain.
                init.orthogonal_(self.conv4.weight)

        torch_model = SuperResolutionNet(upscale_factor=3)

        fake_input = torch.randn(1, 1, 224, 224, requires_grad=True)

        # use ModelHelper to create a C2 net
        helper = ModelHelper(name="test_model")
        start = helper.Sigmoid(['the_input'])
        # Embed the ONNX-converted pytorch net inside it
        toutput, = PyTorchModule(helper, torch_model, (fake_input,), [start])
        output = helper.Sigmoid(toutput)

        workspace.RunNetOnce(helper.InitProto())
        # detach() gives a tensor without autograd tracking so that it can be
        # converted to a NumPy array (fake_input has requires_grad=True).
        workspace.FeedBlob('the_input', fake_input.detach().numpy())
        workspace.RunNetOnce(helper.Proto())
        c2_out = workspace.FetchBlob(str(output))

        # Reference result: the same sigmoid -> model -> sigmoid chain, run
        # directly in PyTorch.
        torch_out = torch.sigmoid(torch_model(torch.sigmoid(fake_input)))

        np.testing.assert_almost_equal(torch_out.detach().cpu().numpy(), c2_out, decimal=3)
|
|
|
|
|
|
# Script entry point: hand control to the unittest runner when this file is
# executed directly rather than imported.
if __name__ == '__main__':
    unittest.main()
|