from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import convert, workspace
from caffe2.proto import caffe2_pb2, torch_pb2
import unittest
import numpy as np


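# These tests exercise the caffe2.python.convert helpers that translate between
# Caffe2 protos (OperatorDef, Argument) and the Torch IR protos (NodeProto,
# AttributeProto) from torch_pb2, in both directions.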
class TestOperator(unittest.TestCase):
    def setUp(self):
        workspace.ResetWorkspace()

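    # The scalar fields (f, i, s) and repeated fields (floats, ints, strings) of a
    # Caffe2 Argument are expected to map onto the AttributeProto fields of the
    # same name.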
    def testArgument2AttributeProto(self):
        arg_f = caffe2_pb2.Argument()
        arg_f.name = "TestArgF"
        arg_f.f = 10.0
        attr_f = convert.ArgumentToAttributeProto(arg_f)
        self.assertEqual(attr_f.name, arg_f.name)
        self.assertEqual(attr_f.f, arg_f.f)

        arg_i = caffe2_pb2.Argument()
        arg_i.name = "TestArgI"
        arg_i.i = 100
        attr_i = convert.ArgumentToAttributeProto(arg_i)
        self.assertEqual(attr_i.name, arg_i.name)
        self.assertEqual(attr_i.i, arg_i.i)

        arg_s = caffe2_pb2.Argument()
        arg_s.name = "TestArgS"
        arg_s.s = "TestS".encode("utf-8")
        attr_s = convert.ArgumentToAttributeProto(arg_s)
        self.assertEqual(attr_s.name, arg_s.name)
        self.assertEqual(attr_s.s, arg_s.s)

        # TODO: test net arg

        arg_floats = caffe2_pb2.Argument()
        arg_floats.name = "TestArgFloats"
        arg_floats.floats.extend([10.0, 11.0, 12.0])
        attr_floats = convert.ArgumentToAttributeProto(arg_floats)
        self.assertEqual(attr_floats.name, arg_floats.name)
        self.assertEqual(attr_floats.floats, arg_floats.floats)

        arg_ints = caffe2_pb2.Argument()
        arg_ints.name = "TestArgInts"
        arg_ints.ints.extend([100, 101, 102])
        attr_ints = convert.ArgumentToAttributeProto(arg_ints)
        self.assertEqual(attr_ints.name, arg_ints.name)
        self.assertEqual(attr_ints.ints, arg_ints.ints)

        arg_strings = caffe2_pb2.Argument()
        arg_strings.name = "TestArgStrings"
        arg_strings.strings.extend([
            "TestStrings1".encode("utf-8"),
            "TestStrings2".encode("utf-8"),
        ])
        attr_strings = convert.ArgumentToAttributeProto(arg_strings)
        self.assertEqual(attr_strings.name, arg_strings.name)
        self.assertEqual(attr_strings.strings, arg_strings.strings)

        # TODO: test nets arg

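    # Reverse direction: an AttributeProto carries an explicit `type` enum, and the
    # converter is expected to read back the field selected by that type into the
    # corresponding Argument field.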
    def testAttributeProto2Argument(self):
        attr_f = torch_pb2.AttributeProto()
        attr_f.type = torch_pb2.AttributeProto.FLOAT
        attr_f.name = "TestAttrF"
        attr_f.f = 10.0
        arg_f = convert.AttributeProtoToArgument(attr_f)
        self.assertEqual(arg_f.name, attr_f.name)
        self.assertEqual(arg_f.f, attr_f.f)

        attr_i = torch_pb2.AttributeProto()
        attr_i.type = torch_pb2.AttributeProto.INT
        attr_i.name = "TestArgI"
        attr_i.i = 100
        arg_i = convert.AttributeProtoToArgument(attr_i)
        self.assertEqual(arg_i.name, attr_i.name)
        self.assertEqual(arg_i.i, attr_i.i)

        attr_s = torch_pb2.AttributeProto()
        attr_s.type = torch_pb2.AttributeProto.STRING
        attr_s.name = "TestArgS"
        attr_s.s = "TestS".encode("utf-8")
        arg_s = convert.AttributeProtoToArgument(attr_s)
        self.assertEqual(arg_s.name, attr_s.name)
        self.assertEqual(arg_s.s, attr_s.s)

        # TODO: test graph attribute

        attr_floats = torch_pb2.AttributeProto()
        attr_floats.type = torch_pb2.AttributeProto.FLOATS
        attr_floats.name = "TestAttrFloats"
        attr_floats.floats.extend([10.0, 11.0, 12.0])
        arg_floats = convert.AttributeProtoToArgument(attr_floats)
        self.assertEqual(arg_floats.name, attr_floats.name)
        self.assertEqual(arg_floats.floats, attr_floats.floats)

        attr_ints = torch_pb2.AttributeProto()
        attr_ints.type = torch_pb2.AttributeProto.INTS
        attr_ints.name = "TestArgInts"
        attr_ints.ints.extend([100, 101, 102])
        arg_ints = convert.AttributeProtoToArgument(attr_ints)
        self.assertEqual(arg_ints.name, attr_ints.name)
        self.assertEqual(arg_ints.ints, attr_ints.ints)

        attr_strings = torch_pb2.AttributeProto()
        attr_strings.type = torch_pb2.AttributeProto.STRINGS
        attr_strings.name = "TestArgStrings"
        attr_strings.strings.extend([
            "TestStrings1".encode("utf-8"),
            "TestStrings2".encode("utf-8"),
        ])
        arg_strings = convert.AttributeProtoToArgument(attr_strings)
        self.assertEqual(arg_strings.name, attr_strings.name)
        self.assertEqual(arg_strings.strings, attr_strings.strings)

        # TODO: test graphs attribute

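    # OperatorDef fields with no direct NodeProto counterpart (engine,
    # control_input, is_gradient_op) are expected to travel through
    # NodeProto.annotations, while debug_info maps to doc_string.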
    def testOperatorDef2NodeProto(self):
        op_def = caffe2_pb2.OperatorDef()
        op_def.input.extend(["A", "B", "C"])
        op_def.output.extend(["X", "Y"])
        op_def.name = "TestOpName"
        op_def.type = "TestOp"
        arg1 = caffe2_pb2.Argument()
        arg1.name = "TestArg1"
        arg1.i = 1
        arg2 = caffe2_pb2.Argument()
        arg2.name = "TestArg2"
        arg2.s = "TestInfo".encode("utf-8")
        op_def.arg.extend([arg1, arg2])
        op_def.device_option.CopyFrom(caffe2_pb2.DeviceOption())
        op_def.engine = "TestEngine".encode("utf-8")
        op_def.control_input.extend(["input1", "input2"])
        op_def.is_gradient_op = True
        op_def.debug_info = "TestDebugInfo"

        node = convert.OperatorDefToNodeProto(op_def)

        self.assertEqual(node.input, op_def.input)
        self.assertEqual(node.output, op_def.output)
        self.assertEqual(node.name, op_def.name)
        self.assertEqual(node.op_type, op_def.type)
        self.assertEqual(node.attribute[0].name, op_def.arg[0].name)
        self.assertEqual(node.attribute[1].name, op_def.arg[1].name)
        self.assertEqual(node.device_option, op_def.device_option)
        node_engine = [a.s.decode("utf-8") for a in node.annotations if a.name == "engine"][0]
        self.assertEqual(node_engine, op_def.engine)
        node_control_input = [a.strings for a in node.annotations if a.name == "control_input"][0]
        self.assertEqual(len(node_control_input), len(op_def.control_input))
        for x, y in zip(node_control_input, op_def.control_input):
            self.assertEqual(x.decode("utf-8"), y)
        self.assertEqual(node.doc_string, op_def.debug_info)
        node_is_gradient_op = [a.i for a in node.annotations if a.name == "is_gradient_op"][0]
        self.assertEqual(node_is_gradient_op, int(op_def.is_gradient_op))

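    # Reverse direction: engine, control_input, and is_gradient_op are read back
    # out of the annotations list, and doc_string maps back to debug_info.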
    def testNodeProto2OperatorDef(self):
        node = torch_pb2.NodeProto()
        node.input.extend(["A", "B", "C"])
        node.output.extend(["X", "Y"])
        node.name = "TestOpName"
        node.op_type = "TestOp"
        attr1 = torch_pb2.AttributeProto()
        attr1.name = "TestAttr1"
        attr1.type = torch_pb2.AttributeProto.STRING
        attr1.s = "TestInfo".encode("utf-8")
        attr2 = torch_pb2.AttributeProto()
        attr2.name = "TestAttr2"
        attr2.type = torch_pb2.AttributeProto.INT
        attr2.i = 10
        node.attribute.extend([attr1, attr2])
        node.device_option.CopyFrom(caffe2_pb2.DeviceOption())
        anno1 = torch_pb2.AttributeProto()
        anno1.name = "engine"
        anno1.type = torch_pb2.AttributeProto.STRING
        anno1.s = "TestEngine".encode("utf-8")
        anno2 = torch_pb2.AttributeProto()
        anno2.name = "control_input"
        anno2.type = torch_pb2.AttributeProto.STRINGS
        anno2.strings.extend(["input1".encode("utf-8"), "input2".encode("utf-8")])
        anno3 = torch_pb2.AttributeProto()
        anno3.name = "is_gradient_op"
        anno3.type = torch_pb2.AttributeProto.INT
        anno3.i = 1
        node.annotations.extend([anno1, anno2, anno3])
        node.doc_string = "TestDocString".encode("utf-8")

        op_def = convert.NodeProtoToOperatorDef(node)

        self.assertEqual(op_def.input, node.input)
        self.assertEqual(op_def.output, node.output)
        self.assertEqual(op_def.name, node.name)
        self.assertEqual(op_def.type, node.op_type)
        self.assertEqual(op_def.arg[0].name, node.attribute[0].name)
        self.assertEqual(op_def.arg[1].name, node.attribute[1].name)
        self.assertEqual(op_def.device_option, node.device_option)
        node_engine = [a.s for a in node.annotations if a.name == "engine"][0]
        self.assertEqual(op_def.engine, node_engine.decode("utf-8"))
        node_control_input = [a.strings for a in node.annotations if a.name == "control_input"][0]
        for x, y in zip(op_def.control_input, node_control_input):
            self.assertEqual(x, y.decode("utf-8"))
        self.assertEqual(op_def.debug_info, node.doc_string)
        node_is_gradient_op = [a.i for a in node.annotations if a.name == "is_gradient_op"][0]
        self.assertEqual(int(op_def.is_gradient_op), node_is_gradient_op)

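    # End-to-end: a NodeProto produced by the converter should be directly runnable
    # through the workspace (legacy_proto=False), and the result should match the
    # NumPy reference computation.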
    def testEnd2End(self):
        op_def = caffe2_pb2.OperatorDef()
        op_def.type = "Add"
        op_def.input.extend(["input1"])
        op_def.input.extend(["input2"])
        op_def.output.extend(["output1"])
        node = convert.OperatorDefToNodeProto(op_def)

        input1 = np.random.randn(1, 3, 1, 5).astype(np.float32)
        input2 = np.random.randn(2, 1, 4, 1).astype(np.float32)
        ref_output1 = input1 + input2
        workspace.FeedBlob("input1", input1)
        workspace.FeedBlob("input2", input2)
        self.assertEqual(workspace.RunOperatorOnce(node.SerializeToString(), legacy_proto=False), True)

        self.assertEqual(workspace.HasBlob("output1"), True)
        fetched_back = workspace.FetchBlob("output1")
        np.testing.assert_array_equal(fetched_back, ref_output1)

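    # Round trip: OperatorDef -> NodeProto -> OperatorDef should reproduce an
    # operator definition that still runs and computes the same result.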
    def testRoundTrip(self):
        op_def = caffe2_pb2.OperatorDef()
        op_def.type = "Add"
        op_def.input.extend(["input1"])
        op_def.input.extend(["input2"])
        op_def.output.extend(["output1"])
        node = convert.OperatorDefToNodeProto(op_def)
        new_op_def = convert.NodeProtoToOperatorDef(node)

        input1 = np.random.randn(1, 3, 1, 5).astype(np.float32)
        input2 = np.random.randn(2, 1, 4, 1).astype(np.float32)
        ref_output1 = input1 + input2
        workspace.FeedBlob("input1", input1)
        workspace.FeedBlob("input2", input2)
        self.assertEqual(workspace.RunOperatorOnce(new_op_def.SerializeToString()), True)

        self.assertEqual(workspace.HasBlob("output1"), True)
        fetched_back = workspace.FetchBlob("output1")
        np.testing.assert_array_equal(fetched_back, ref_output1)


if __name__ == '__main__':
    unittest.main()