mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Differential Revision: D10098106 Original commit changeset: 94ec7fc57c84 fbshipit-source-id: 38f729b0970618f38359797b806cbbcd865f4715
67 lines
2.5 KiB
Python
67 lines
2.5 KiB
Python
## @package workspace
|
|
# Module caffe2.python.workspace
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.proto import caffe2_pb2, torch_pb2
|
|
|
|
import caffe2.python._import_c_extension as C
|
|
|
|
|
|
def ArgumentToAttributeProto(arg):
    """Convert a caffe2 Argument into a torch_pb2.AttributeProto.

    The actual conversion is delegated to the C extension
    (``C.argument_to_attribute_proto``), which works on serialized bytes.

    Args:
        arg: either a protobuf message exposing a callable
            ``SerializeToString`` (presumably a ``caffe2_pb2.Argument``),
            or the already-serialized ``bytes`` of such a message.

    Returns:
        A ``torch_pb2.AttributeProto`` parsed from the C extension's output.

    Raises:
        ValueError: if ``arg`` has no callable ``SerializeToString`` and is
            not ``bytes``.
    """
    if hasattr(arg, 'SerializeToString') and callable(arg.SerializeToString):
        serialized_arg = arg.SerializeToString()
    elif isinstance(arg, bytes):
        # Already serialized; hand it to the C extension unchanged.
        serialized_arg = arg
    else:
        # ValueError (not TypeError) is kept so existing callers that catch
        # it keep working.
        raise ValueError('No SerializeToString method is detected and '
                         'arg is not bytes.\ntype is {}'.format(type(arg)))
    attr = torch_pb2.AttributeProto()
    attr.ParseFromString(C.argument_to_attribute_proto(serialized_arg))
    return attr
|
|
|
|
|
|
def AttributeProtoToArgument(attr):
    """Convert a torch AttributeProto into a caffe2_pb2.Argument.

    The actual conversion is delegated to the C extension
    (``C.attribute_proto_to_argument``), which works on serialized bytes.

    Args:
        attr: either a protobuf message exposing a callable
            ``SerializeToString`` (presumably a ``torch_pb2.AttributeProto``),
            or the already-serialized ``bytes`` of such a message.

    Returns:
        A ``caffe2_pb2.Argument`` parsed from the C extension's output.

    Raises:
        ValueError: if ``attr`` has no callable ``SerializeToString`` and is
            not ``bytes``.
    """
    if hasattr(attr, 'SerializeToString') and callable(attr.SerializeToString):
        serialized_attr = attr.SerializeToString()
    elif isinstance(attr, bytes):
        # Already serialized; hand it to the C extension unchanged.
        serialized_attr = attr
    else:
        # ValueError (not TypeError) is kept so existing callers that catch
        # it keep working.
        raise ValueError('No SerializeToString method is detected and '
                         'attr is not bytes.\ntype is {}'.format(type(attr)))
    arg = caffe2_pb2.Argument()
    arg.ParseFromString(C.attribute_proto_to_argument(serialized_attr))
    return arg
|
|
|
|
|
|
def OperatorDefToNodeProto(op_def):
    """Convert a caffe2 OperatorDef into a torch_pb2.NodeProto.

    The actual conversion is delegated to the C extension
    (``C.operator_def_to_node_proto``), which works on serialized bytes.

    Args:
        op_def: either a protobuf message exposing a callable
            ``SerializeToString`` (presumably a ``caffe2_pb2.OperatorDef``),
            or the already-serialized ``bytes`` of such a message.

    Returns:
        A ``torch_pb2.NodeProto`` parsed from the C extension's output.

    Raises:
        ValueError: if ``op_def`` has no callable ``SerializeToString`` and
            is not ``bytes``.
    """
    if hasattr(op_def, 'SerializeToString') and callable(op_def.SerializeToString):
        serialized_op_def = op_def.SerializeToString()
    elif isinstance(op_def, bytes):
        # Already serialized; hand it to the C extension unchanged.
        serialized_op_def = op_def
    else:
        # ValueError (not TypeError) is kept so existing callers that catch
        # it keep working.
        raise ValueError('No SerializeToString method is detected and '
                         'op_def is not bytes.\ntype is {}'.format(type(op_def)))
    node = torch_pb2.NodeProto()
    node.ParseFromString(C.operator_def_to_node_proto(serialized_op_def))
    return node
|
|
|
|
|
|
def NodeProtoToOperatorDef(node_proto):
    """Convert a torch NodeProto into a caffe2_pb2.OperatorDef.

    The actual conversion is delegated to the C extension
    (``C.node_proto_to_operator_def``), which works on serialized bytes.

    Args:
        node_proto: either a protobuf message exposing a callable
            ``SerializeToString`` (presumably a ``torch_pb2.NodeProto``),
            or the already-serialized ``bytes`` of such a message.

    Returns:
        A ``caffe2_pb2.OperatorDef`` parsed from the C extension's output.

    Raises:
        ValueError: if ``node_proto`` has no callable ``SerializeToString``
            and is not ``bytes``.
    """
    if hasattr(node_proto, 'SerializeToString') and callable(node_proto.SerializeToString):
        serialized_node_proto = node_proto.SerializeToString()
    elif isinstance(node_proto, bytes):
        # Already serialized; hand it to the C extension unchanged.
        serialized_node_proto = node_proto
    else:
        # ValueError (not TypeError) is kept so existing callers that catch
        # it keep working.
        raise ValueError('No SerializeToString method is detected and '
                         'node_proto is not bytes.\ntype is {}'.format(type(node_proto)))
    op_def = caffe2_pb2.OperatorDef()
    op_def.ParseFromString(C.node_proto_to_operator_def(serialized_node_proto))
    return op_def
|