from caffe2.proto import caffe2_pb2
from google.protobuf.message import DecodeError, Message
from google.protobuf import text_format
import collections
import functools
import numpy as np
import sys

try:
    # The container ABCs live in collections.abc; the old aliases on the
    # top-level collections module were removed in Python 3.10.
    from collections.abc import Iterable
except ImportError:  # Python 2
    from collections import Iterable

if sys.version_info > (3,):
    # This is python 3. We will define a few stuff that we used.
    basestring = str
    long = int


def CaffeBlobToNumpyArray(blob):
    """Converts a Caffe blob proto into a float32 numpy array.

    If ``blob.num`` is non-zero the blob is treated as an old-style blob and
    reshaped to (num, channels, height, width); otherwise it is reshaped to
    ``blob.shape.dim``.
    """
    data = np.asarray(blob.data, dtype=np.float32)
    if blob.num != 0:
        # old style caffe blob.
        return data.reshape(blob.num, blob.channels, blob.height, blob.width)
    # new style caffe blob.
    return data.reshape(blob.shape.dim)


def Caffe2TensorToNumpyArray(tensor):
    """Converts a caffe2 TensorProto into a numpy array shaped ``tensor.dims``.

    Supported data types:
        FLOAT  -> float32 (read from ``float_data``)
        DOUBLE -> float64 (read from ``double_data``)
        INT32  -> platform default int (read from ``int32_data``)

    Raises:
        RuntimeError: if ``tensor.data_type`` is not one of the above.
    """
    if tensor.data_type == caffe2_pb2.TensorProto.FLOAT:
        return np.asarray(
            tensor.float_data, dtype=np.float32).reshape(tensor.dims)
    elif tensor.data_type == caffe2_pb2.TensorProto.DOUBLE:
        return np.asarray(
            tensor.double_data, dtype=np.float64).reshape(tensor.dims)
    elif tensor.data_type == caffe2_pb2.TensorProto.INT32:
        # INT32 payloads are stored in int32_data (the field that
        # NumpyArrayToCaffe2Tensor writes), not double_data. `int` keeps the
        # dtype the original `np.int` produced; np.int no longer exists in
        # current numpy releases.
        return np.asarray(
            tensor.int32_data, dtype=int).reshape(tensor.dims)
    else:
        # TODO: complete the data type.
        raise RuntimeError(
            "Tensor data type not supported yet: " + str(tensor.data_type))


def NumpyArrayToCaffe2Tensor(arr, name=None):
    """Converts a numpy array into a caffe2 TensorProto.

    Args:
        arr: numpy array of dtype float32, float64, int32 or the platform
            default integer dtype (``np.dtype(int)``).
        name: optional tensor name; left unset when falsy.

    Returns:
        A caffe2_pb2.TensorProto whose ``dims`` equal ``arr.shape``.

    Raises:
        RuntimeError: if ``arr.dtype`` is not supported.
    """
    tensor = caffe2_pb2.TensorProto()
    tensor.dims.extend(arr.shape)
    if name:
        tensor.name = name
    # ravel().tolist() yields native Python scalars for the repeated fields.
    flat = arr.ravel()
    if arr.dtype == np.float32:
        tensor.data_type = caffe2_pb2.TensorProto.FLOAT
        tensor.float_data.extend(flat.tolist())
    elif arr.dtype == np.float64:
        tensor.data_type = caffe2_pb2.TensorProto.DOUBLE
        tensor.double_data.extend(flat.tolist())
    elif arr.dtype == np.int32 or arr.dtype == np.dtype(int):
        # np.dtype(int) is what the original `arr.dtype == np.int` matched;
        # int32 arrays are accepted as well since they fit the field exactly.
        tensor.data_type = caffe2_pb2.TensorProto.INT32
        tensor.int32_data.extend(flat.tolist())
    else:
        # TODO: complete the data type.
        raise RuntimeError(
            "Numpy data type not supported yet: " + str(arr.dtype))
    return tensor


def MakeArgument(key, value):
    """Makes an argument based on the value type.

    Mapping from ``value`` to the populated caffe2_pb2.Argument field:
        float                        -> f
        int / bool / long            -> i   (bools are stored as ints)
        str / bytes                  -> s   (text is UTF-8 encoded)
        protobuf Message             -> s   (serialized)
        iterable of floats           -> floats
        iterable of ints/bools       -> ints
        iterable of str / bytes      -> strings (text is UTF-8 encoded)
        iterable of Messages         -> strings (serialized)
    numpy arrays are flattened to lists and numpy scalars are converted to
    native Python scalars first.

    Raises:
        ValueError: if the value matches none of the cases above.
    """
    argument = caffe2_pb2.Argument()
    argument.name = key
    iterable = isinstance(value, Iterable)

    if isinstance(value, np.ndarray):
        value = value.flatten().tolist()
    elif isinstance(value, np.generic):
        # convert numpy scalar to native python type (np.asscalar is gone
        # from recent numpy; .item() is its replacement).
        value = value.item()

    if type(value) is float:
        argument.f = value
    elif type(value) in (int, bool, long):
        # We make a relaxation that a boolean variable will also be stored as
        # int.
        argument.i = value
    elif isinstance(value, basestring):
        argument.s = (value if type(value) is bytes
                      else value.encode('utf-8'))
    elif isinstance(value, Message):
        argument.s = value.SerializeToString()
    elif iterable and all(type(v) in (float, np.float64) for v in value):
        argument.floats.extend(value)
    elif iterable and all(type(v) in (int, bool, long, np.int_) for v in value):
        argument.ints.extend(value)
    elif iterable and all(isinstance(v, basestring) for v in value):
        argument.strings.extend([
            (v if type(v) is bytes else v.encode('utf-8')) for v in value])
    elif iterable and all(isinstance(v, Message) for v in value):
        argument.strings.extend([v.SerializeToString() for v in value])
    else:
        raise ValueError(
            "Unknown argument type: key=%s value=%s, value type=%s" %
            (key, str(value), str(type(value)))
        )
    return argument


def TryReadProtoWithClass(cls, s):
    """Reads a protobuffer with the given proto class.

    Inputs:
      cls: a protobuffer class.
      s: a string of either binary or text protobuffer content.

    Outputs:
      proto: the protobuffer of cls

    Throws:
      google.protobuf.message.DecodeError: if we cannot decode the message.
    """
    obj = cls()
    try:
        text_format.Parse(s, obj)
        return obj
    except (text_format.ParseError, UnicodeDecodeError):
        # Not valid text format. Binary input that is not valid UTF-8 may
        # surface as UnicodeDecodeError from the text parser, so that also
        # falls through to binary parsing. ParseFromString clears obj before
        # parsing, so a partial text parse does not leak into the result.
        obj.ParseFromString(s)
        return obj


def GetContentFromProto(obj, function_map):
    """Gets a specific field from a protocol buffer that matches the given class

    Calls the function mapped to ``type(obj)`` (exact type match, no
    subclasses) and returns its result; returns None if no class matches.
    """
    for cls, func in function_map.items():
        if type(obj) is cls:
            return func(obj)


def GetContentFromProtoString(s, function_map):
    """Parses ``s`` with the first class in ``function_map`` that accepts it.

    Returns the result of the mapped function applied to the parsed proto.

    Raises:
        DecodeError: if ``s`` cannot be decoded as any of the classes.
    """
    for cls, func in function_map.items():
        try:
            obj = TryReadProtoWithClass(cls, s)
        except DecodeError:
            continue
        # Called outside the try so a DecodeError raised by func itself is
        # not mistaken for "wrong proto class".
        return func(obj)
    raise DecodeError("Cannot find a fit protobuffer class.")


def ConvertProtoToBinary(proto_class, filename, out_filename):
    """Convert a text file of the given protobuf class to binary."""
    with open(filename) as fin:
        proto = TryReadProtoWithClass(proto_class, fin.read())
    # SerializeToString returns bytes, so the output must be opened in
    # binary mode (text mode raises TypeError on Python 3).
    with open(out_filename, 'wb') as fout:
        fout.write(proto.SerializeToString())


class DebugMode(object):
    '''
    This class allows to drop you into an interactive debugger
    if there is an unhandled exception in your python script

    Example of usage:

    def main():
        # your code here
        pass

    if __name__ == '__main__':
        from caffe2.python.utils import DebugMode
        DebugMode.run(main)
    '''

    @classmethod
    def run(cls, func):
        """Runs func(); on an uncaught Exception, opens pdb post-mortem.

        KeyboardInterrupt is re-raised untouched. After the debugger session
        ends the process exits with status 1.
        """
        try:
            return func()
        except KeyboardInterrupt:
            raise
        except Exception:
            import pdb

            print(
                'Entering interactive debugger. Type "bt" to print '
                'the full stacktrace. Type "help" to see command listing.')
            print(sys.exc_info()[1])
            # Blank separator line; print('') works both as the Python 2
            # print statement and the Python 3 print function.
            print('')
            pdb.post_mortem()
            # sys.exit raises SystemExit, so nothing after this runs.
            sys.exit(1)


def debug(f):
    '''
    Use this method to decorate your function with DebugMode's functionality

    Example:

    @debug
    def test_foo(self):
        raise Exception("Bar")

    '''

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        def func():
            return f(*args, **kwargs)

        # Propagate the decorated function's return value to the caller.
        return DebugMode.run(func)

    return wrapper