mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: This helps when developing scripts locally (i.e. when working outside of Flow): one doesn't have to rerun the script in order to catch an exception in the debugger or to add a print statement. (Flow does this kind of thing automatically.) Usage example: ``` if __name__ == '__main__': workspace.GlobalInit(['caffe2', '--caffe2_log_level=2']) from caffe2.python.utils import DebugMode DebugMode.enable() DebugMode.run(main) ``` Reviewed By: Yangqing Differential Revision: D4424096 fbshipit-source-id: 73f418c80f581820e70139df7e166981e4d8c55f
173 lines
5.0 KiB
Python
173 lines
5.0 KiB
Python
from caffe2.proto import caffe2_pb2
|
|
from google.protobuf.message import DecodeError, Message
|
|
from google.protobuf import text_format
|
|
import collections
|
|
import numpy as np
|
|
import sys
|
|
|
|
|
|
if sys.version_info[0] >= 3:
    # Running on Python 3, which dropped these Python 2 builtins. Alias them
    # so the rest of the module can keep referring to the old names.
    basestring = str
    long = int
|
|
|
|
|
|
def CaffeBlobToNumpyArray(blob):
    """Convert a Caffe blob into a float32 numpy array.

    A non-zero ``blob.num`` marks an old-style blob, whose shape is taken from
    the ``num``/``channels``/``height``/``width`` fields. Otherwise the blob is
    treated as new-style and the shape is read from ``blob.shape.dim``.
    """
    is_old_style = blob.num != 0
    flat = np.asarray(blob.data, dtype=np.float32)
    if not is_old_style:
        # New-style caffe blob: the shape is stored explicitly.
        return flat.reshape(blob.shape.dim)
    # Old-style caffe blob: fixed 4-D layout.
    return flat.reshape(blob.num, blob.channels, blob.height, blob.width)
|
|
|
|
|
|
def Caffe2TensorToNumpyArray(tensor):
    """Return ``tensor.float_data`` as a float32 array shaped by ``tensor.dims``."""
    flat = np.asarray(tensor.float_data, dtype=np.float32)
    return flat.reshape(tensor.dims)
|
|
|
|
|
|
def NumpyArrayToCaffe2Tensor(arr, name):
    """Build a FLOAT ``TensorProto`` called `name` from the numpy array `arr`.

    The array's shape is copied into ``dims`` and its flattened contents,
    cast to float, are copied into ``float_data``.
    """
    proto = caffe2_pb2.TensorProto()
    proto.data_type = caffe2_pb2.TensorProto.FLOAT
    proto.name = name
    proto.dims.extend(arr.shape)
    flat_values = list(arr.flatten().astype(float))
    proto.float_data.extend(flat_values)
    return proto
|
|
|
|
|
|
def MakeArgument(key, value):
    """Makes an argument based on the value type.

    Inputs:
      key: the argument name.
      value: one of
        - a float (stored in ``f``),
        - an int, long or bool (stored in ``i``),
        - a byte string or text string (stored in ``s``; text is encoded as
          UTF-8),
        - a protobuf Message (serialized into ``s``),
        - a numpy scalar (converted to the native python type first),
        - a numpy array (flattened to a list first),
        - an iterable whose items are all floats, all ints/bools, all
          strings, or all Messages (stored in ``floats``, ``ints`` or
          ``strings`` respectively).

    Outputs:
      argument: a caffe2_pb2.Argument named `key`.

    Throws:
      ValueError: if the value type is not one of the supported ones.
    """
    # The Iterable ABC lives in collections.abc on Python 3 (the
    # collections.Iterable alias was removed in 3.10); Python 2 only has
    # collections.Iterable.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    argument = caffe2_pb2.Argument()
    argument.name = key
    iterable = isinstance(value, Iterable)

    if isinstance(value, np.ndarray):
        value = value.flatten().tolist()
    elif isinstance(value, np.generic):
        # convert numpy scalar to native python type
        # (np.asscalar was removed from NumPy; .item() is the equivalent).
        value = value.item()

    if type(value) is float:
        argument.f = value
    elif type(value) is int or type(value) is bool or type(value) is long:
        # We make a relaxation that a boolean variable will also be stored as
        # int.
        argument.i = value
    elif isinstance(value, bytes):
        # Must be checked before the iterable branches: on Python 3 bytes is
        # not a basestring and iterating it yields ints, so it would
        # otherwise be stored as a list of ints.
        argument.s = value
    elif isinstance(value, basestring):
        argument.s = value.encode('utf-8')
    elif isinstance(value, Message):
        argument.s = value.SerializeToString()
    elif iterable:
        # Materialize once so that one-shot iterables (e.g. generators) are
        # not exhausted by the type probes below before being stored.
        items = list(value)
        # np.float64 is what the removed np.float_ alias referred to.
        if all(type(v) in (float, np.float64) for v in items):
            argument.floats.extend(items)
        elif all(type(v) in (int, bool, long, np.int_) for v in items):
            argument.ints.extend(items)
        elif all(isinstance(v, (bytes, basestring)) for v in items):
            argument.strings.extend([
                (v if isinstance(v, bytes) else v.encode('utf-8'))
                for v in items])
        elif all(isinstance(v, Message) for v in items):
            argument.strings.extend([v.SerializeToString() for v in items])
        else:
            raise ValueError(
                "Unknown argument type: key=%s value=%s, value type=%s" %
                (key, str(value), str(type(value)))
            )
    else:
        raise ValueError(
            "Unknown argument type: key=%s value=%s, value type=%s" %
            (key, str(value), str(type(value)))
        )
    return argument
|
|
|
|
|
|
def TryReadProtoWithClass(cls, s):
    """Reads a protobuffer with the given proto class.

    The content is first parsed as text format; if that fails with a
    text_format.ParseError it is parsed as binary instead.

    Inputs:
      cls: a protobuffer class.
      s: a string of either binary or text protobuffer content.

    Outputs:
      proto: the protobuffer of cls

    Throws:
      google.protobuf.message.DecodeError: if we cannot decode the message.
    """
    message = cls()
    try:
        text_format.Parse(s, message)
    except text_format.ParseError:
        # Not valid text format: fall back to the binary encoding.
        message.ParseFromString(s)
    return message
|
|
|
|
|
|
def GetContentFromProto(obj, function_map):
    """Apply the function registered for the exact class of `obj`.

    `function_map` maps classes to callables. The callable whose key is
    exactly ``type(obj)`` (subclasses do not match) is called with `obj` and
    its result returned. If no key matches, None is returned.
    """
    obj_type = type(obj)
    for candidate, extractor in function_map.items():
        if candidate is not obj_type:
            continue
        return extractor(obj)
    return None
|
|
|
|
|
|
def GetContentFromProtoString(s, function_map):
    """Decode `s` with the first class in `function_map` that fits.

    Each class in `function_map` is tried in turn: `s` is read with
    TryReadProtoWithClass and the associated function is applied to the
    resulting object. A DecodeError raised by either step moves on to the
    next class.

    Throws:
      google.protobuf.message.DecodeError: if no class in the map succeeds.
    """
    for proto_cls, extractor in function_map.items():
        try:
            return extractor(TryReadProtoWithClass(proto_cls, s))
        except DecodeError:
            # This class does not fit; try the next candidate.
            pass
    raise DecodeError("Cannot find a fit protobuffer class.")
|
|
|
|
|
|
def ConvertProtoToBinary(proto_class, filename, out_filename):
    """Convert a text file of the given protobuf class to binary.

    Inputs:
      proto_class: the protobuffer class used to read the file.
      filename: path of the input file.
      out_filename: path the serialized binary protobuffer is written to.
    """
    # Use a context manager so the input handle is closed deterministically.
    with open(filename) as fid:
        content = fid.read()
    proto = TryReadProtoWithClass(proto_class, content)
    # SerializeToString() returns bytes, so the output must be opened in
    # binary mode: text mode raises TypeError on Python 3 and would
    # translate newlines on Windows.
    with open(out_filename, 'wb') as fid:
        fid.write(proto.SerializeToString())
|
|
|
|
|
|
class DebugMode(object):
    '''
    This class allows to drop you into an interactive debugger
    if there is an unhandled exception in your python script

    Example of usage:

    def main():
        # your code here
        pass

    if __name__ == '__main__':
        from caffe2.python.utils import DebugMode
        DebugMode.enable()
        DebugMode.run(main)
    '''

    # Class-wide switch: when False, run() simply re-raises exceptions.
    _enabled = False

    @classmethod
    def enable(cls):
        """Turn on the post-mortem debugger for subsequent run() calls."""
        cls._enabled = True

    @classmethod
    def disable(cls):
        """Turn off the post-mortem debugger for subsequent run() calls."""
        cls._enabled = False

    @classmethod
    def run(cls, func):
        """Call `func()` and return its result.

        If `func` raises an Exception while debug mode is enabled, the
        exception is printed, a pdb post-mortem session is started and, once
        it ends, the process exits with status 1. If debug mode is disabled
        the exception is re-raised unchanged. KeyboardInterrupt always
        propagates.
        """
        try:
            return func()
        except KeyboardInterrupt:
            raise
        except Exception:
            if cls._enabled:
                import pdb

                print(
                    'Entering interactive debugger. Type "bt" to print '
                    'the full stacktrace. Type "help" to see command listing.')
                print(sys.exc_info()[1])
                # Blank separator line. A bare `print` is a no-op expression
                # on Python 3; print('') emits an empty line on both Python 2
                # and Python 3.
                print('')

                pdb.post_mortem()
                sys.exit(1)
            raise
|