pytorch/caffe2/python/utils.py
Alexander Sidorov e0e124e617 Fix RNN scoping situation
Summary:
There is a long-standing scoping problem that was introduced in the original Python wrappers early in H1. Basically, each RNNCell implementation has to manually scope the outputs of each of its operators. If somebody forgets to do so, there can be weird bugs with layers, etc.

The approach is the following: the user has to explicitly specify the current scope when using the apply_over_sequence function (and others) if the function is going to be called several times (e.g. for stacking layers). This way we use the Caffe2 native scoping approach instead of inventing one extra API that people have to use (i.e. passing a scope name as an argument to the RNNCell constructor).
Closes https://github.com/caffe2/caffe2/pull/1681

Differential Revision: D6777536

Pulled By: salexspb

fbshipit-source-id: 73d860b8d4857589e04bdea5a6fcd3080d68427c
2018-02-07 17:35:29 -08:00

316 lines
11 KiB
Python

# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
## @package utils
# Module caffe2.python.utils
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import collections
import functools
import sys

from future.utils import viewitems
from google.protobuf import text_format
from google.protobuf.message import DecodeError, Message
import numpy as np
from six import integer_types, binary_type, text_type

from caffe2.proto import caffe2_pb2

try:
    from collections.abc import Iterable
except ImportError:  # Python 2: the ABCs live directly in `collections`.
    from collections import Iterable
def CaffeBlobToNumpyArray(blob):
    """Convert a Caffe blob into a float32 numpy array.

    A blob whose `num` field is non-zero is treated as an old style blob and
    reshaped to (num, channels, height, width); otherwise the blob is treated
    as a new style blob and reshaped to `blob.shape.dim`.
    """
    if blob.num == 0:
        # New style caffe blob: the shape is stored explicitly.
        target_shape = blob.shape.dim
    else:
        # Old style caffe blob: the shape is spread over four scalar fields.
        target_shape = (blob.num, blob.channels, blob.height, blob.width)
    flat_values = np.asarray(blob.data, dtype=np.float32)
    return flat_values.reshape(target_shape)
def Caffe2TensorToNumpyArray(tensor):
    """Convert a caffe2 TensorProto into a numpy array shaped by `tensor.dims`.

    Supported data types are FLOAT, DOUBLE, INT32, INT16, UINT16, INT8 and
    UINT8. All integer types are read from the `int32_data` field.

    Raises:
      RuntimeError: if `tensor.data_type` is not one of the supported types.
    """
    data_types = caffe2_pb2.TensorProto
    if tensor.data_type == data_types.FLOAT:
        values, dtype = tensor.float_data, np.float32
    elif tensor.data_type == data_types.DOUBLE:
        values, dtype = tensor.double_data, np.float64
    elif tensor.data_type == data_types.INT32:
        # pb.INT32=>int use int32_data. The builtin `int` is what the removed
        # `np.int` alias referred to, so the resulting dtype is unchanged.
        values, dtype = tensor.int32_data, int
    elif tensor.data_type == data_types.INT16:
        # pb.INT16=>np.int16 use int32_data
        values, dtype = tensor.int32_data, np.int16
    elif tensor.data_type == data_types.UINT16:
        # pb.UINT16=>np.uint16 use int32_data
        values, dtype = tensor.int32_data, np.uint16
    elif tensor.data_type == data_types.INT8:
        # pb.INT8=>np.int8 use int32_data
        values, dtype = tensor.int32_data, np.int8
    elif tensor.data_type == data_types.UINT8:
        # pb.UINT8=>np.uint8 use int32_data
        values, dtype = tensor.int32_data, np.uint8
    else:
        # TODO: complete the data type: bool, float16, byte, int64, string
        raise RuntimeError(
            "Tensor data type not supported yet: " + str(tensor.data_type))
    return np.asarray(values, dtype=dtype).reshape(tensor.dims)
def NumpyArrayToCaffe2Tensor(arr, name=None):
    """Convert a numpy array into a caffe2 TensorProto.

    Inputs:
      arr: the numpy array to convert. Supported dtypes are float32, float64,
          the default integer dtype, int32, int16, uint16, int8 and uint8.
      name: optional name stored in the proto; skipped when falsy.
    Outputs:
      tensor: a caffe2_pb2.TensorProto with dims, data_type and data filled.
    Throws:
      RuntimeError: if `arr.dtype` is not one of the supported dtypes.
    """
    tensor = caffe2_pb2.TensorProto()
    tensor.dims.extend(arr.shape)
    if name:
        tensor.name = name
    if arr.dtype == np.float32:
        tensor.data_type = caffe2_pb2.TensorProto.FLOAT
        tensor.float_data.extend(list(arr.flatten().astype(float)))
    elif arr.dtype == np.float64:
        tensor.data_type = caffe2_pb2.TensorProto.DOUBLE
        tensor.double_data.extend(list(arr.flatten().astype(np.float64)))
    elif arr.dtype == int or arr.dtype == np.int32:
        # The builtin `int` replaces the removed `np.int` alias (the alias
        # was the builtin itself, so the comparison and cast are unchanged).
        tensor.data_type = caffe2_pb2.TensorProto.INT32
        tensor.int32_data.extend(list(arr.flatten().astype(int)))
    elif arr.dtype == np.int16:
        tensor.data_type = caffe2_pb2.TensorProto.INT16
        # np.int16=>pb.INT16 use int32_data
        tensor.int32_data.extend(list(arr.flatten().astype(np.int16)))
    elif arr.dtype == np.uint16:
        tensor.data_type = caffe2_pb2.TensorProto.UINT16
        # np.uint16=>pb.UINT16 use int32_data
        tensor.int32_data.extend(list(arr.flatten().astype(np.uint16)))
    elif arr.dtype == np.int8:
        tensor.data_type = caffe2_pb2.TensorProto.INT8
        # np.int8=>pb.INT8 use int32_data
        tensor.int32_data.extend(list(arr.flatten().astype(np.int8)))
    elif arr.dtype == np.uint8:
        tensor.data_type = caffe2_pb2.TensorProto.UINT8
        # np.uint8=>pb.UINT8 use int32_data
        tensor.int32_data.extend(list(arr.flatten().astype(np.uint8)))
    else:
        # TODO: complete the data type: bool, float16, byte, int64, string
        raise RuntimeError(
            "Numpy data type not supported yet: " + str(arr.dtype))
    return tensor
def MakeArgument(key, value):
    """Makes a caffe2_pb2.Argument named `key` based on the value type.

    The field that is filled depends on the type of `value`:
      float                      -> f
      integer or bool            -> i
      binary string              -> s
      text string                -> s (utf-8 encoded)
      caffe2_pb2.NetDef          -> n
      other protobuf Message     -> s (serialized)
      iterable of floats         -> floats
      iterable of integers/bools -> ints
      iterable of strings        -> strings (text is utf-8 encoded)
      iterable of NetDefs        -> nets
      iterable of Messages       -> strings (serialized)
    Numpy arrays are flattened to lists and numpy scalars are converted to
    native python values before the rules above are applied.

    Throws:
      ValueError: if the type of `value` (or of its elements) is not handled.
    """
    argument = caffe2_pb2.Argument()
    argument.name = key
    # Determined before any conversion below, exactly like the checks that
    # follow expect.
    iterable = isinstance(value, Iterable)

    # Fast tracking common use case where a float32 array of tensor parameters
    # needs to be serialized.  The entire array is guaranteed to have the same
    # dtype, so no per-element checking necessary and no need to convert each
    # element separately.
    if isinstance(value, np.ndarray) and value.dtype.type is np.float32:
        argument.floats.extend(value.flatten().tolist())
        return argument

    if isinstance(value, np.ndarray):
        value = value.flatten().tolist()
    elif isinstance(value, np.generic):
        # convert numpy scalar to native python type; `.item()` is the
        # replacement for the removed `np.asscalar`.
        value = value.item()

    if type(value) is float:
        argument.f = value
    elif type(value) in integer_types or type(value) is bool:
        # We make a relaxation that a boolean variable will also be stored as
        # int.
        argument.i = value
    elif isinstance(value, binary_type):
        argument.s = value
    elif isinstance(value, text_type):
        argument.s = value.encode('utf-8')
    elif isinstance(value, caffe2_pb2.NetDef):
        argument.n.CopyFrom(value)
    elif isinstance(value, Message):
        argument.s = value.SerializeToString()
    # `np.float64` is the type the removed `np.float_` alias referred to.
    elif iterable and all(type(v) in [float, np.float64] for v in value):
        argument.floats.extend(
            v.item() if type(v) is np.float64 else v for v in value
        )
    elif iterable and all(
        type(v) in integer_types or type(v) in [bool, np.int_] for v in value
    ):
        argument.ints.extend(
            v.item() if type(v) is np.int_ else v for v in value
        )
    elif iterable and all(
        isinstance(v, binary_type) or isinstance(v, text_type) for v in value
    ):
        argument.strings.extend(
            v.encode('utf-8') if isinstance(v, text_type) else v
            for v in value
        )
    elif iterable and all(isinstance(v, caffe2_pb2.NetDef) for v in value):
        argument.nets.extend(value)
    elif iterable and all(isinstance(v, Message) for v in value):
        argument.strings.extend(v.SerializeToString() for v in value)
    else:
        if iterable:
            raise ValueError(
                "Unknown iterable argument type: key={} value={}, value "
                "type={}[{}]".format(
                    key, value, type(value), set(type(v) for v in value)
                )
            )
        else:
            raise ValueError(
                "Unknown argument type: key={} value={}, value type={}".format(
                    key, value, type(value)
                )
            )
    return argument
def TryReadProtoWithClass(cls, s):
    """Reads a protobuffer with the given proto class.

    The content is first parsed as a text-format protobuffer; if that fails
    it is parsed as a binary protobuffer instead.

    Inputs:
      cls: a protobuffer class.
      s: a string of either binary or text protobuffer content.
    Outputs:
      proto: the protobuffer of cls
    Throws:
      google.protobuf.message.DecodeError: if we cannot decode the message.
    """
    obj = cls()
    try:
        text_format.Parse(s, obj)
        return obj
    except (text_format.ParseError, UnicodeDecodeError):
        # Binary content is not necessarily valid text: besides a ParseError
        # the text parser may fail while decoding the bytes, and in both
        # cases the binary parser must still get its chance.
        obj.ParseFromString(s)
        return obj
def GetContentFromProto(obj, function_map):
    """Apply the function registered for the exact class of `obj`.

    `function_map` maps protobuffer classes to functions. The function whose
    class is exactly `type(obj)` (subclasses do not match) is called with
    `obj` and its result is returned. None is returned when no class matches.
    """
    obj_type = type(obj)
    for proto_cls, extractor in viewitems(function_map):
        if obj_type is proto_cls:
            return extractor(obj)
    return None
def GetContentFromProtoString(s, function_map):
    """Parse `s` with the first class of `function_map` that can decode it.

    Every (class, function) pair of `function_map` is tried in turn: `s` is
    read with TryReadProtoWithClass and the function is applied to the parsed
    protobuffer. A pair that ends in a DecodeError is skipped.

    Throws:
      google.protobuf.message.DecodeError: if no pair succeeded.
    """
    for proto_cls, extractor in viewitems(function_map):
        try:
            parsed = TryReadProtoWithClass(proto_cls, s)
            return extractor(parsed)
        except DecodeError:
            # This class does not fit; move on to the next candidate.
            pass
    raise DecodeError("Cannot find a fit protobuffer class.")
def ConvertProtoToBinary(proto_class, filename, out_filename):
    """Convert a text file of the given protobuf class to binary.

    Inputs:
      proto_class: the protobuffer class used to parse the input file.
      filename: path of the file to read.
      out_filename: path of the file that receives the serialized proto.
    """
    # Context managers make sure both file handles are closed.
    with open(filename) as in_file:
        proto = TryReadProtoWithClass(proto_class, in_file.read())
    # SerializeToString() produces bytes, so the output must be opened in
    # binary mode.
    with open(out_filename, 'wb') as fid:
        fid.write(proto.SerializeToString())
def GetGPUMemoryUsageStats():
    """Get GPU memory usage stats from CUDAContext. This requires flag
       --caffe2_gpu_memory_tracking to be enabled.

    Runs the "GetGPUMemoryUsage" operator once with a CUDA device option and
    returns a dict with rows 0 and 1 of its output blob ('total_by_gpu' and
    'max_by_gpu') together with their sums ('total' and 'max_total').
    """
    from caffe2.python import workspace, core

    stats_blob = "____mem____"
    usage_op = core.CreateOperator(
        "GetGPUMemoryUsage",
        [],
        [stats_blob],
        device_option=core.DeviceOption(caffe2_pb2.CUDA, 0),
    )
    workspace.RunOperatorOnce(usage_op)

    stats = workspace.FetchBlob(stats_blob)
    total_by_gpu = stats[0, :]
    max_by_gpu = stats[1, :]
    return {
        'total_by_gpu': total_by_gpu,
        'max_by_gpu': max_by_gpu,
        'total': np.sum(total_by_gpu),
        'max_total': np.sum(max_by_gpu)
    }
def ResetBlobs(blobs):
    """Run the "Free" operator once, on CPU, over the given blobs.

    Inputs:
      blobs: an iterable of blobs; each one is used both as an input and as
          an output of the operator.
    """
    from caffe2.python import workspace, core

    # Materialize once: a one-shot iterator would otherwise be exhausted by
    # the input list and leave the output list empty.
    blob_list = list(blobs)
    workspace.RunOperatorOnce(
        core.CreateOperator(
            "Free",
            blob_list,
            list(blob_list),
            device_option=core.DeviceOption(caffe2_pb2.CPU),
        ),
    )
class DebugMode(object):
    '''
    This class allows to drop you into an interactive debugger
    if there is an unhandled exception in your python script

    Example of usage:

    def main():
        # your code here
        pass

    if __name__ == '__main__':
        from caffe2.python.utils import DebugMode
        DebugMode.run(main)
    '''

    @classmethod
    def run(cls, func):
        '''
        Call `func()` and return its result.

        A KeyboardInterrupt is re-raised untouched. Any other Exception is
        printed, a pdb post-mortem session is started, and once the session
        is over the process exits with status 1.
        '''
        try:
            return func()
        except KeyboardInterrupt:
            raise
        except Exception:
            import pdb

            print(
                'Entering interactive debugger. Type "bt" to print '
                'the full stacktrace. Type "help" to see command listing.')
            print(sys.exc_info()[1])
            # Blank separator line; a bare `print` is only a name lookup
            # once print is a function, so it has to be called.
            print()

            pdb.post_mortem()
            # sys.exit raises SystemExit, so nothing after it can run.
            sys.exit(1)
def raiseIfNotEqual(a, b, msg):
    """Raise an Exception carrying `msg` and both values when `a != b`.

    Nothing happens (None is returned) when the comparison `a != b` is false.
    """
    differs = a != b
    if not differs:
        return
    details = "{}. {} != {}".format(msg, a, b)
    raise Exception(details)
def debug(f):
    '''
    Use this method to decorate your function with DebugMode's functionality

    The decorated function is executed through DebugMode.run and the value
    it returns is passed back to the caller.

    Example:

    @debug
    def test_foo(self):
        raise Exception("Bar")
    '''

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        def func():
            return f(*args, **kwargs)
        # Hand the result back; without the `return` every decorated
        # function would evaluate to None.
        return DebugMode.run(func)

    return wrapper