pytorch/caffe2/python/convert.py
Dmytro Dzhulgakov 1d3f650ce4 Revert D10098106: [pytorch][PR] [WIP] New version of PT1 model format
Differential Revision:
D10098106

Original commit changeset: 94ec7fc57c84

fbshipit-source-id: 38f729b0970618f38359797b806cbbcd865f4715
2018-10-02 00:43:40 -07:00

67 lines
2.5 KiB
Python

## @package convert
# Module caffe2.python.convert
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.proto import caffe2_pb2, torch_pb2
import caffe2.python._import_c_extension as C
def ArgumentToAttributeProto(arg):
    """Convert a Caffe2 argument into a ``torch_pb2.AttributeProto``.

    Args:
        arg: either an object exposing a callable ``SerializeToString``
            (presumably a ``caffe2_pb2.Argument`` message — confirm at call
            sites) or the already-serialized ``bytes`` of such a message.

    Returns:
        A new ``torch_pb2.AttributeProto`` parsed from the bytes returned by
        the C extension's ``argument_to_attribute_proto``.

    Raises:
        ValueError: if ``arg`` has no callable ``SerializeToString`` and is
            not ``bytes``.
    """
    is_message = (hasattr(arg, 'SerializeToString')
                  and callable(arg.SerializeToString))
    # Reject anything we cannot turn into a serialized payload up front.
    if not (is_message or isinstance(arg, bytes)):
        raise ValueError('No SerializeToString method is detected. '
                         'neither arg is bytes.\ntype is {}'.format(type(arg)))
    payload = arg.SerializeToString() if is_message else arg

    # The conversion itself happens in C++; we only move bytes across.
    result = torch_pb2.AttributeProto()
    result.ParseFromString(C.argument_to_attribute_proto(payload))
    return result
def AttributeProtoToArgument(attr):
    """Convert a ``torch_pb2`` attribute into a Caffe2 ``Argument``.

    Args:
        attr: either an object exposing a callable ``SerializeToString``
            (presumably a ``torch_pb2.AttributeProto`` message — confirm at
            call sites) or the already-serialized ``bytes`` of such a message.

    Returns:
        A new ``caffe2_pb2.Argument`` parsed from the bytes returned by the
        C extension's ``attribute_proto_to_argument``.

    Raises:
        ValueError: if ``attr`` has no callable ``SerializeToString`` and is
            not ``bytes``.
    """
    is_message = (hasattr(attr, 'SerializeToString')
                  and callable(attr.SerializeToString))
    # Reject anything we cannot turn into a serialized payload up front.
    if not (is_message or isinstance(attr, bytes)):
        raise ValueError('No SerializeToString method is detected. '
                         'neither attr is bytes.\ntype is {}'.format(type(attr)))
    payload = attr.SerializeToString() if is_message else attr

    # The conversion itself happens in C++; we only move bytes across.
    result = caffe2_pb2.Argument()
    result.ParseFromString(C.attribute_proto_to_argument(payload))
    return result
def OperatorDefToNodeProto(op_def):
    """Convert a Caffe2 operator definition into a ``torch_pb2.NodeProto``.

    Args:
        op_def: either an object exposing a callable ``SerializeToString``
            (presumably a ``caffe2_pb2.OperatorDef`` message — confirm at
            call sites) or the already-serialized ``bytes`` of such a message.

    Returns:
        A new ``torch_pb2.NodeProto`` parsed from the bytes returned by the
        C extension's ``operator_def_to_node_proto``.

    Raises:
        ValueError: if ``op_def`` has no callable ``SerializeToString`` and
            is not ``bytes``.
    """
    is_message = (hasattr(op_def, 'SerializeToString')
                  and callable(op_def.SerializeToString))
    # Reject anything we cannot turn into a serialized payload up front.
    if not (is_message or isinstance(op_def, bytes)):
        raise ValueError('No SerializeToString method is detected. '
                         'neither op_def is bytes.\ntype is {}'.format(type(op_def)))
    payload = op_def.SerializeToString() if is_message else op_def

    # The conversion itself happens in C++; we only move bytes across.
    result = torch_pb2.NodeProto()
    result.ParseFromString(C.operator_def_to_node_proto(payload))
    return result
def NodeProtoToOperatorDef(node_proto):
    """Convert a ``torch_pb2`` node into a Caffe2 ``OperatorDef``.

    Args:
        node_proto: either an object exposing a callable
            ``SerializeToString`` (presumably a ``torch_pb2.NodeProto``
            message — confirm at call sites) or the already-serialized
            ``bytes`` of such a message.

    Returns:
        A new ``caffe2_pb2.OperatorDef`` parsed from the bytes returned by
        the C extension's ``node_proto_to_operator_def``.

    Raises:
        ValueError: if ``node_proto`` has no callable ``SerializeToString``
            and is not ``bytes``.
    """
    is_message = (hasattr(node_proto, 'SerializeToString')
                  and callable(node_proto.SerializeToString))
    # Reject anything we cannot turn into a serialized payload up front.
    if not (is_message or isinstance(node_proto, bytes)):
        raise ValueError('No SerializeToString method is detected. '
                         'neither node_proto is bytes.\ntype is {}'.format(type(node_proto)))
    payload = node_proto.SerializeToString() if is_message else node_proto

    # The conversion itself happens in C++; we only move bytes across.
    result = caffe2_pb2.OperatorDef()
    result.ParseFromString(C.node_proto_to_operator_def(payload))
    return result