mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: There is a long-standing scoping problem that was introduced in the original Python wrappers early in H1. Basically, each RNNCell implementation has to manually scope the outputs of each of its operators. If somebody forgets, there can be weird bugs with layers etc. The approach is the following: the user has to explicitly specify the current scope when using the apply_over_sequence function and others if the function is going to be called several times (e.g. for stacking layers). This way we use the Caffe2 native scoping approach instead of inventing one extra API people have to use (i.e. passing a scope name as an argument to the RNNCell constructor). Closes https://github.com/caffe2/caffe2/pull/1681 Differential Revision: D6777536 Pulled By: salexspb fbshipit-source-id: 73d860b8d4857589e04bdea5a6fcd3080d68427c
316 lines
11 KiB
Python
316 lines
11 KiB
Python
# Copyright (c) 2016-present, Facebook, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
##############################################################################
|
|
|
|
## @package utils
|
|
# Module caffe2.python.utils
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
from caffe2.proto import caffe2_pb2
|
|
from future.utils import viewitems
|
|
from google.protobuf.message import DecodeError, Message
|
|
from google.protobuf import text_format
|
|
import sys
|
|
import collections
|
|
import functools
|
|
import numpy as np
|
|
from six import integer_types, binary_type, text_type
|
|
|
|
|
|
def CaffeBlobToNumpyArray(blob):
    """Convert a Caffe blob proto into a float32 numpy array.

    A blob whose ``num`` field is non-zero is treated as an old-style blob
    and shaped as (num, channels, height, width). Otherwise the target shape
    is read from ``blob.shape.dim``.
    """
    if blob.num != 0:
        # old style caffe blob.
        target_shape = (blob.num, blob.channels, blob.height, blob.width)
    else:
        # new style caffe blob.
        target_shape = blob.shape.dim
    return np.asarray(blob.data, dtype=np.float32).reshape(target_shape)
|
|
|
|
|
|
def Caffe2TensorToNumpyArray(tensor):
    """Convert a caffe2 TensorProto into a numpy array shaped by its dims.

    Inputs:
      tensor: a caffe2_pb2.TensorProto whose data_type is one of FLOAT,
          DOUBLE, INT32, INT16, UINT16, INT8 or UINT8.

    Outputs:
      A numpy array built from the matching repeated field and reshaped to
      tensor.dims.

    Throws:
      RuntimeError: if the tensor's data_type is not one of the above.
    """
    TensorProto = caffe2_pb2.TensorProto
    # data_type -> (repeated field holding the values, dtype of the result).
    # All integer types narrower than 32 bits are stored in int32_data.
    converters = {
        TensorProto.FLOAT: ("float_data", np.float32),
        TensorProto.DOUBLE: ("double_data", np.float64),
        # The builtin int is what the removed alias np.int referred to, so
        # the resulting dtype is unchanged while newer NumPy keeps working.
        TensorProto.INT32: ("int32_data", int),
        TensorProto.INT16: ("int32_data", np.int16),
        TensorProto.UINT16: ("int32_data", np.uint16),
        TensorProto.INT8: ("int32_data", np.int8),
        TensorProto.UINT8: ("int32_data", np.uint8),
    }
    converter = converters.get(tensor.data_type)
    if converter is None:
        # TODO: complete the data type: bool, float16, byte, int64, string
        raise RuntimeError(
            "Tensor data type not supported yet: " + str(tensor.data_type))
    field_name, dtype = converter
    return np.asarray(
        getattr(tensor, field_name), dtype=dtype).reshape(tensor.dims)
|
|
|
|
|
|
def NumpyArrayToCaffe2Tensor(arr, name=None):
    """Serialize a numpy array into a caffe2 TensorProto.

    Inputs:
      arr: a numpy array of dtype float32, float64, the default integer
          type, int32, int16, uint16, int8 or uint8.
      name: optional tensor name; only set on the proto when truthy.

    Outputs:
      A caffe2_pb2.TensorProto with dims, data_type and the flattened data
      filled in.

    Throws:
      RuntimeError: if the array's dtype is not one of the above.
    """
    tensor = caffe2_pb2.TensorProto()
    tensor.dims.extend(arr.shape)
    if name:
        tensor.name = name
    if arr.dtype == np.float32:
        tensor.data_type = caffe2_pb2.TensorProto.FLOAT
        tensor.float_data.extend(list(arr.flatten().astype(float)))
    elif arr.dtype == np.float64:
        tensor.data_type = caffe2_pb2.TensorProto.DOUBLE
        tensor.double_data.extend(list(arr.flatten().astype(np.float64)))
    # The builtin int replaces the alias np.int, which newer NumPy no longer
    # provides; both name the same default integer dtype.
    elif arr.dtype == int or arr.dtype == np.int32:
        tensor.data_type = caffe2_pb2.TensorProto.INT32
        tensor.int32_data.extend(list(arr.flatten().astype(int)))
    elif arr.dtype == np.int16:
        tensor.data_type = caffe2_pb2.TensorProto.INT16
        # np.int16=>pb.INT16 use int32_data
        tensor.int32_data.extend(list(arr.flatten().astype(np.int16)))
    elif arr.dtype == np.uint16:
        tensor.data_type = caffe2_pb2.TensorProto.UINT16
        # np.uint16=>pb.UINT16 use int32_data
        tensor.int32_data.extend(list(arr.flatten().astype(np.uint16)))
    elif arr.dtype == np.int8:
        tensor.data_type = caffe2_pb2.TensorProto.INT8
        # np.int8=>pb.INT8 use int32_data
        tensor.int32_data.extend(list(arr.flatten().astype(np.int8)))
    elif arr.dtype == np.uint8:
        tensor.data_type = caffe2_pb2.TensorProto.UINT8
        # np.uint8=>pb.UINT8 use int32_data
        tensor.int32_data.extend(list(arr.flatten().astype(np.uint8)))
    else:
        # TODO: complete the data type: bool, float16, byte, int64, string
        raise RuntimeError(
            "Numpy data type not supported yet: " + str(arr.dtype))
    return tensor
|
|
|
|
|
|
def MakeArgument(key, value):
    """Makes an argument based on the value type.

    Inputs:
      key: the name stored on the resulting Argument.
      value: a float, integer, bool, byte or text string, NetDef, any other
          protobuf Message, a numpy scalar or array, or an iterable whose
          elements are all of one of those families.

    Outputs:
      A caffe2_pb2.Argument with the field matching the value type filled in.

    Throws:
      ValueError: if the value (or the element types of an iterable value)
          does not match any supported type.
    """
    # Imported locally so the block stays self-contained: the ABC aliases in
    # ``collections`` are gone in Python 3.10+, while Python 2 only has them
    # there.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    argument = caffe2_pb2.Argument()
    argument.name = key
    iterable = isinstance(value, Iterable)

    # Fast tracking common use case where a float32 array of tensor parameters
    # needs to be serialized. The entire array is guaranteed to have the same
    # dtype, so no per-element checking necessary and no need to convert each
    # element separately.
    if isinstance(value, np.ndarray) and value.dtype.type is np.float32:
        argument.floats.extend(value.flatten().tolist())
        return argument

    if isinstance(value, np.ndarray):
        value = value.flatten().tolist()
    elif isinstance(value, np.generic):
        # convert numpy scalar to native python type; .item() is the
        # replacement for np.asscalar, which newer NumPy no longer provides.
        value = value.item()

    if type(value) is float:
        argument.f = value
    elif type(value) in integer_types or type(value) is bool:
        # We make a relaxation that a boolean variable will also be stored as
        # int.
        argument.i = value
    elif isinstance(value, binary_type):
        argument.s = value
    elif isinstance(value, text_type):
        argument.s = value.encode('utf-8')
    elif isinstance(value, caffe2_pb2.NetDef):
        argument.n.CopyFrom(value)
    elif isinstance(value, Message):
        argument.s = value.SerializeToString()
    # np.float64 is the type the alias np.float_ referred to; the alias is
    # not available in NumPy 2.
    elif iterable and all(type(v) in [float, np.float64] for v in value):
        argument.floats.extend(
            v.item() if type(v) is np.float64 else v for v in value
        )
    elif iterable and all(
        type(v) in integer_types or type(v) in [bool, np.int_] for v in value
    ):
        argument.ints.extend(
            v.item() if type(v) is np.int_ else v for v in value
        )
    elif iterable and all(
        isinstance(v, binary_type) or isinstance(v, text_type) for v in value
    ):
        argument.strings.extend(
            v.encode('utf-8') if isinstance(v, text_type) else v
            for v in value
        )
    elif iterable and all(isinstance(v, caffe2_pb2.NetDef) for v in value):
        argument.nets.extend(value)
    elif iterable and all(isinstance(v, Message) for v in value):
        argument.strings.extend(v.SerializeToString() for v in value)
    else:
        if iterable:
            raise ValueError(
                "Unknown iterable argument type: key={} value={}, value "
                "type={}[{}]".format(
                    key, value, type(value), set(type(v) for v in value)
                )
            )
        else:
            raise ValueError(
                "Unknown argument type: key={} value={}, value type={}".format(
                    key, value, type(value)
                )
            )
    return argument
|
|
|
|
|
|
def TryReadProtoWithClass(cls, s):
    """Reads a protobuffer with the given proto class.

    The content is first parsed as text format; if that fails, it is parsed
    as a binary serialization instead.

    Inputs:
      cls: a protobuffer class.
      s: a string of either binary or text protobuffer content.

    Outputs:
      proto: the protobuffer of cls

    Throws:
      google.protobuf.message.DecodeError: if we cannot decode the message.
    """
    obj = cls()
    try:
        text_format.Parse(s, obj)
    except (text_format.ParseError, UnicodeDecodeError):
        # Binary content handed over as bytes can fail while being decoded
        # as text, before the text parser gets to report a ParseError. Both
        # failures mean "not text format", so fall back to binary parsing.
        obj.ParseFromString(s)
    return obj
|
|
|
|
|
|
def GetContentFromProto(obj, function_map):
    """Apply the function registered for obj's exact class and return its result.

    function_map maps classes to callables. The first entry whose class is
    exactly type(obj) is called with obj. When no entry matches, None is
    returned.
    """
    for proto_cls, extractor in viewitems(function_map):
        if type(obj) is not proto_cls:
            continue
        return extractor(obj)
    return None
|
|
|
|
|
|
def GetContentFromProtoString(s, function_map):
    """Decode s with the first class in function_map that can read it.

    Each class in function_map is tried in turn via TryReadProtoWithClass;
    the function mapped to the first class that does not raise DecodeError
    is applied to the decoded proto and its result is returned.

    Throws:
      google.protobuf.message.DecodeError: if no class in function_map
          yields a result.
    """
    for proto_cls, extractor in viewitems(function_map):
        try:
            return extractor(TryReadProtoWithClass(proto_cls, s))
        except DecodeError:
            # This candidate did not work out; move on to the next one.
            continue
    raise DecodeError("Cannot find a fit protobuffer class.")
|
|
|
|
|
|
def ConvertProtoToBinary(proto_class, filename, out_filename):
    """Convert a text file of the given protobuf class to binary.

    Inputs:
      proto_class: the protobuffer class used to parse the input.
      filename: path of the file to read.
      out_filename: path the binary serialization is written to.
    """
    # Read inside a context manager so the input handle is closed promptly.
    with open(filename) as fin:
        proto = TryReadProtoWithClass(proto_class, fin.read())
    # SerializeToString() returns bytes, so the output must be opened in
    # binary mode; text mode rejects bytes on Python 3.
    with open(out_filename, 'wb') as fid:
        fid.write(proto.SerializeToString())
|
|
|
|
|
|
def GetGPUMemoryUsageStats():
    """Get GPU memory usage stats from CUDAContext. This requires flag
       --caffe2_gpu_memory_tracking to be enabled

    Runs the "GetGPUMemoryUsage" operator once on CUDA device 0 and fetches
    its output blob. Row 0 of the blob is reported as the per-GPU totals and
    row 1 as the per-GPU maxima, together with the sum of each row.
    """
    from caffe2.python import workspace, core

    stats_blob = "____mem____"
    gpu_device = core.DeviceOption(caffe2_pb2.CUDA, 0)
    usage_op = core.CreateOperator(
        "GetGPUMemoryUsage",
        [],
        [stats_blob],
        device_option=gpu_device,
    )
    workspace.RunOperatorOnce(usage_op)

    usage = workspace.FetchBlob(stats_blob)
    stats = {}
    stats['total_by_gpu'] = usage[0, :]
    stats['max_by_gpu'] = usage[1, :]
    stats['total'] = np.sum(usage[0, :])
    stats['max_total'] = np.sum(usage[1, :])
    return stats
|
|
|
|
|
|
def ResetBlobs(blobs):
    """Run the "Free" operator once on the given blobs, on the CPU device.

    Inputs:
      blobs: an iterable of blobs; the same blobs are passed as both the
          inputs and the outputs of the operator.
    """
    from caffe2.python import workspace, core

    # Materialize once: a one-shot iterable (e.g. a generator) would be
    # exhausted by the first list() call and leave the outputs empty.
    blob_list = list(blobs)
    workspace.RunOperatorOnce(
        core.CreateOperator(
            "Free",
            blob_list,
            list(blob_list),
            device_option=core.DeviceOption(caffe2_pb2.CPU),
        ),
    )
|
|
|
|
|
|
class DebugMode(object):
    '''
    This class allows to drop you into an interactive debugger
    if there is an unhandled exception in your python script

    Example of usage:

    def main():
        # your code here
        pass

    if __name__ == '__main__':
        from caffe2.python.utils import DebugMode
        DebugMode.run(main)
    '''

    @classmethod
    def run(cls, func):
        '''
        Call func() and return its result.

        KeyboardInterrupt is re-raised untouched. Any other Exception is
        printed, a pdb post-mortem session is started, and the process then
        exits with status 1.
        '''
        try:
            return func()
        except KeyboardInterrupt:
            raise
        except Exception:
            import pdb

            print(
                'Entering interactive debugger. Type "bt" to print '
                'the full stacktrace. Type "help" to see command listing.')
            print(sys.exc_info()[1])
            # Blank separator line before the debugger prompt. A bare
            # ``print`` is a no-op expression once print is a function.
            print()

            pdb.post_mortem()
            sys.exit(1)
|
|
|
|
def raiseIfNotEqual(a, b, msg):
    """Raise an Exception carrying msg and both values when a != b."""
    if not (a != b):
        return
    details = "{}. {} != {}".format(msg, a, b)
    raise Exception(details)
|
|
|
|
def debug(f):
    '''
    Use this method to decorate your function with DebugMode's functionality

    The decorated function is executed through DebugMode.run and its return
    value is passed back to the caller.

    Example:

    @debug
    def test_foo(self):
        raise Exception("Bar")

    '''

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        def func():
            return f(*args, **kwargs)
        # Propagate the result; otherwise every decorated function would
        # appear to return None.
        return DebugMode.run(func)

    return wrapper
|