"""Unit tests for the pycaffe2 ``workspace`` module."""

import unittest

import numpy as np

from caffe2.proto import caffe2_pb2
from pycaffe2 import core, workspace


class TestWorkspace(unittest.TestCase):
    """Exercise single-workspace operations: run, reset, fetch and feed."""

    def setUp(self):
        # Every test gets a one-operator net that fills "testblob" with 1.0
        # and starts from a freshly reset workspace.
        self.net = core.Net("test-net")
        self.net.ConstantFill(
            [], "testblob", shape=[1, 2, 3, 4], value=1.0)
        workspace.ResetWorkspace()

    def testRootFolder(self):
        """ResetWorkspace() without arguments uses "." as the root folder."""
        self.assertEqual(workspace.ResetWorkspace(), True)
        self.assertEqual(workspace.RootFolder(), ".")
        self.assertEqual(
            workspace.ResetWorkspace("/tmp/caffe-workspace-test"), True)
        self.assertEqual(
            workspace.RootFolder(), "/tmp/caffe-workspace-test")

    def testWorkspaceHasBlobWithNonexistingName(self):
        """HasBlob() reports False for a name that was never created."""
        self.assertEqual(workspace.HasBlob("non-existing"), False)

    def testRunOperatorOnce(self):
        """Running the single ConstantFill op creates exactly "testblob"."""
        serialized_op = self.net.Proto().op[0].SerializeToString()
        self.assertEqual(workspace.RunOperatorOnce(serialized_op), True)
        self.assertEqual(workspace.HasBlob("testblob"), True)
        blobs = workspace.Blobs()
        self.assertEqual(len(blobs), 1)
        self.assertEqual(blobs[0], "testblob")

    def testRunNetOnce(self):
        """Running the whole net creates "testblob"."""
        self.assertEqual(
            workspace.RunNetOnce(self.net.Proto().SerializeToString()), True)
        self.assertEqual(workspace.HasBlob("testblob"), True)

    def testRunPlan(self):
        """Running a plan with one step over the net creates "testblob"."""
        plan = core.Plan("test-plan")
        plan.AddNets([self.net])
        plan.AddStep(core.ExecutionStep("test-step", self.net))
        self.assertEqual(
            workspace.RunPlan(plan.Proto().SerializeToString()), True)
        self.assertEqual(workspace.HasBlob("testblob"), True)

    def testResetWorkspace(self):
        """ResetWorkspace() removes blobs created by an earlier run."""
        self.assertEqual(
            workspace.RunNetOnce(self.net.Proto().SerializeToString()), True)
        self.assertEqual(workspace.HasBlob("testblob"), True)
        self.assertEqual(workspace.ResetWorkspace(), True)
        self.assertEqual(workspace.HasBlob("testblob"), False)

    def testFetchFeedBlob(self):
        """A fetched blob can be modified, fed back and fetched again."""
        self.assertEqual(
            workspace.RunNetOnce(self.net.Proto().SerializeToString()), True)
        fetched = workspace.FetchBlob("testblob")
        # Check that the fetched content matches what ConstantFill wrote.
        self.assertEqual(fetched.shape, (1, 2, 3, 4))
        np.testing.assert_array_equal(fetched, 1.0)
        # Overwrite the values, feed them back and confirm the round trip.
        fetched[:] = 2.0
        self.assertEqual(workspace.FeedBlob("testblob", fetched), True)
        fetched_again = workspace.FetchBlob("testblob")
        self.assertEqual(fetched_again.shape, (1, 2, 3, 4))
        np.testing.assert_array_equal(fetched_again, 2.0)


class TestMultiWorkspaces(unittest.TestCase):
    """Exercise creating and switching between named workspaces."""

    def setUp(self):
        workspace.SwitchWorkspace("default")
        workspace.ResetWorkspace()

    def testCreateWorkspace(self):
        """Blobs are per-workspace; switching to an unknown name raises."""
        workspaces = workspace.Workspaces()
        self.assertEqual(len(workspaces), 1)
        self.assertEqual(workspaces[0], "default")

        # Populate the "default" workspace with "testblob".
        net = core.Net("test-net")
        net.ConstantFill([], "testblob", shape=[1, 2, 3, 4], value=1.0)
        self.assertEqual(
            workspace.RunNetOnce(net.Proto().SerializeToString()), True)
        self.assertEqual(workspace.HasBlob("testblob"), True)

        # The second argument asks SwitchWorkspace to create "test"; the
        # new workspace must not see the blob from "default".
        self.assertEqual(workspace.SwitchWorkspace("test", True), True)
        self.assertEqual(workspace.HasBlob("testblob"), False)
        self.assertEqual(workspace.SwitchWorkspace("default"), True)
        self.assertEqual(workspace.HasBlob("testblob"), True)

        # Switching to a workspace that does not exist (without asking for
        # it to be created) should raise an error.
        with self.assertRaises(RuntimeError):
            workspace.SwitchWorkspace("non-existing")

        # The order returned by Workspaces() is not relied upon here.
        self.assertEqual(sorted(workspace.Workspaces()), ["default", "test"])


if __name__ == '__main__':
    unittest.main()