from caffe2.python import core, workspace
from caffe2.python import test_util as tu
import caffe2.python.nomnigraph as ng
from caffe2.python.nomnigraph_transformations import transpose_network

import numpy as np
from hypothesis import given
import hypothesis.strategies as st


class TestNomnigraphTransformations(tu.TestCase):
    """Tests for graph edits made through the nomnigraph ``NNModule`` API.

    Each test builds a small Caffe2 ``Net``, edits its nomnigraph
    representation, converts it back with ``convertToCaffe2Proto`` and runs
    the resulting ``NetDef`` to check the numerical result.
    """

    def test_simple_replace(self):
        """Replace an ``FC`` node with a freshly created ``Add`` node.

        The expected output ``X + W`` checks that ``replaceNode`` carries the
        FC node's edges over to the new ``Add`` operator.
        """
        net = core.Net("name")
        net.FC(["X", "W"], ["Y"])
        nn = ng.NNModule(net)
        fc = nn.controlFlow[0]
        add = nn.createNode(
            core.CreateOperator("Add", ["X"], ["Y"], engine="CUDNN")
        )
        nn.replaceNode(fc, add)
        nn.deleteNode(fc)

        # Test it out
        new_netdef = nn.convertToCaffe2Proto()
        workspace.ResetWorkspace()
        workspace.FeedBlob("X", np.array([1, 2, 3]))
        workspace.FeedBlob("W", np.array([1, 2, 3]))
        workspace.RunNetOnce(new_netdef)
        out = workspace.FetchBlob("Y")
        expected_out = np.array([2, 4, 6])
        np.testing.assert_almost_equal(out, expected_out)

    def test_simple_rewire(self):
        """Swap the order of a Mul/Add pair by deleting and recreating edges."""
        net = core.Net("name")
        # Rewire this so that we get
        # c = Add(a, d)
        # e = Mul(c, b)
        #
        # if a = 1, b = 2, d = 3
        # we get 8: (1 + 3) * 2
        # as opposed to 7: 1 + (3 * 2)
        net.Mul(["a", "b"], ["c"])
        net.Add(["c", "d"], ["e"])
        nn = ng.NNModule(net)

        mul = nn.controlFlow[0]
        add = nn.controlFlow[1]

        a = mul.inputs[0]
        b = mul.inputs[1]
        c = mul.outputs[0]
        d = add.inputs[1]
        e = add.outputs[0]

        # Detach every edge of the original a*b -> c, c+d -> e chain ...
        nn.deleteEdge(a, mul)
        nn.deleteEdge(b, mul)
        nn.deleteEdge(mul, c)
        nn.deleteEdge(c, add)
        nn.deleteEdge(d, add)
        nn.deleteEdge(add, e)

        # ... and reconnect it as a+d -> c, c*b -> e.
        nn.createEdge(a, add)
        nn.createEdge(d, add)
        nn.createEdge(add, c)
        nn.createEdge(c, mul)
        nn.createEdge(b, mul)
        nn.createEdge(mul, e)

        # Test it out
        new_netdef = nn.convertToCaffe2Proto()
        workspace.ResetWorkspace()
        workspace.FeedBlob("a", np.array([1, 1, 1]))
        workspace.FeedBlob("b", np.array([2, 2, 2]))
        workspace.FeedBlob("d", np.array([3, 3, 3]))
        workspace.RunNetOnce(new_netdef)
        out = workspace.FetchBlob("e")
        expected_out = np.array([8, 8, 8])
        np.testing.assert_almost_equal(out, expected_out)

    @given(
        batch_size=st.integers(16, 20),
        channels=st.integers(1, 10),
        height=st.integers(10, 15),
        width=st.integers(10, 15),
        seed=st.integers(0, 65535),
        kernel=st.integers(3, 5),
    )
    def test_transpose_network(self, batch_size, channels, height, width, seed,
                               kernel):
        """Check ``transpose_network`` adds the expected ops and keeps outputs.

        Runs a small conv/flatten/concat net before and after the transform,
        checks the number of added operators and tensors, and compares
        ``c1``, ``c3`` and ``out`` between the two runs.
        """
        net = core.Net("net")
        net.Conv(["X", "w1", "b1"], ["c1"], stride=1, pad=0, kernel=kernel)
        net.Conv(["X", "w2", "b2"], ["c2"], stride=1, pad=0, kernel=kernel)
        # c1 and c2: batch_size, 2*channels, height - kernel + 1,
        # width - kernel + 1
        # NOTE: c2 is produced but not consumed by any later operator; it is
        # still a convolution output and counts toward the expected ops below.
        net.Conv(["c1", "w3", "b3"], ["c3"], stride=1, pad=0, kernel=kernel)
        net.Conv(["c1", "w4", "b4"], ["c4"], stride=1, pad=0, kernel=kernel)
        # c3 and c4: batch_size, 4*channels, height - 2*kernel + 2,
        # width - 2*kernel + 2  (w3/w4 have 4*channels output channels)
        net.Flatten(["c3"], "c3f")
        net.Flatten(["c4"], "c4f")
        net.Flatten(["X"], "Xf")
        net.Concat(
            ["c3f", "c4f", "Xf"], ["out", "split_info"], axis=1, add_axis=0
        )
        np.random.seed(seed)
        workspace.ResetWorkspace()
        tu.randBlobFloat32("X", batch_size, channels, height, width)
        tu.randBlobsFloat32(["w1", "w2"], 2 * channels, channels, kernel,
                            kernel)
        tu.randBlobsFloat32(["b1", "b2"], 2 * channels)
        tu.randBlobsFloat32(["w3", "w4"], 4 * channels, 2 * channels, kernel,
                            kernel)
        tu.randBlobsFloat32(["b3", "b4"], 4 * channels)
        all_inp_names = ["X", "w1", "w2", "b1", "b2", "w3", "w4", "b3", "b4"]
        # Snapshot the inputs so the transformed net can be fed identical data.
        all_input = workspace.FetchBlobs(all_inp_names)
        workspace.RunNetOnce(net)
        pre_transform_c1 = workspace.FetchBlob("c1")
        pre_transform_c3 = workspace.FetchBlob("c3")
        pre_transform_out = workspace.FetchBlob("out")

        nn = ng.NNModule(net)
        pre_transform_num_operators = len(nn.operators)
        pre_transform_num_tensors = len(nn.tensors)

        transpose_network(nn)
        new_netdef = nn.convertToCaffe2Proto()
        post_transform_num_operators = len(nn.operators)
        post_transform_num_tensors = len(nn.tensors)

        # The minimal number of additional operators and tensors is at least
        # one NCHW2NHWC operator and tensor for each channel-based input
        # tensor and a NHWC2NCHW operator and tensor for the output of each
        # convolution.
        # X, w1, w2, w3, w4 are channel-based inputs
        # c1, c2, c3, c4 are the outputs of convolutions
        # i.e. a total of 9.
        self.assertEqual(post_transform_num_operators,
                         pre_transform_num_operators + 9,
                         "expected 9 additional operators")
        self.assertEqual(post_transform_num_tensors,
                         pre_transform_num_tensors + 9,
                         "expected 9 additional tensors")

        workspace.ResetWorkspace()
        for name, val in zip(all_inp_names, all_input):
            workspace.FeedBlob(name, val)
        workspace.RunNetOnce(new_netdef)
        post_transform_c1 = workspace.FetchBlob("c1")
        post_transform_c3 = workspace.FetchBlob("c3")
        post_transform_out = workspace.FetchBlob("out")

        # decimal=1: loose tolerance for float32 accumulation differences
        # between the original and transposed layouts.
        np.testing.assert_almost_equal(post_transform_c1, pre_transform_c1, 1)
        np.testing.assert_almost_equal(post_transform_c3, pre_transform_c3, 1)
        np.testing.assert_almost_equal(post_transform_out, pre_transform_out,
                                       1)