pytorch/caffe2/utils/proto_utils.h
Aapo Kyrola 453c60ce28 Threaded dependency-aware RNNExecutor (frontier/diagonal execution).
Summary:
This diff adds dependency-aware concurrent/parallel execution of operators in stepnets. For CPU, we use multi-threaded execution. For CUDA, we use multiple streams and cuda events for parallelism and dependency tracking.

Much of the diff is about computing the dependency graph, which was quite tricky because we also need to avoid write races between multiple operators running in multiple timesteps in parallel. Also, recurrent blobs "change name" when passing from one timestep to the next ("_prev"), so that needs to be handled as well.

This diff also restores the link-ops that I unlanded earlier.

The performance gain of this diff is very good for CPU (same perf as with static_dag, even better on forward-only). On CUDA, the gains are modest, at least with the sizes I was testing with.

Reviewed By: salexspb

Differential Revision: D5001637

fbshipit-source-id: 3d0a71593d73a9ff22f4c1a5c9abf2a4a0c633c8
2017-08-15 23:55:15 -07:00

283 lines
9.2 KiB
C++

#ifndef CAFFE2_UTILS_PROTO_UTILS_H_
#define CAFFE2_UTILS_PROTO_UTILS_H_
#include "google/protobuf/message_lite.h"
#ifndef CAFFE2_USE_LITE_PROTO
#include "google/protobuf/message.h"
#endif // !CAFFE2_USE_LITE_PROTO
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2.pb.h"
namespace caffe2 {
using std::string;
using ::google::protobuf::MessageLite;
// A wrapper function to return device name string for use in blob serialization
// / deserialization. This should have one to one correspondence with
// caffe2/proto/caffe2.proto: enum DeviceType.
//
// Note that we can't use DeviceType_Name, because that is only available in
// protobuf-full, and some platforms (like mobile) may want to use
// protobuf-lite instead.
//
// @param d  a DeviceType enum value, passed as its int32_t representation.
// @return   the human-readable device name (defined out of line).
std::string DeviceTypeName(const int32_t& d);
// Common interfaces that read a file's contents into a string, and write a
// string out to a file.
// NOTE(review): the bool result presumably reports success — confirm against
// the out-of-line definitions.
bool ReadStringFromFile(const char* filename, string* str);
bool WriteStringToFile(const string& str, const char* filename);
// Common interfaces that are supported by both lite and full protobuf: binary
// (wire-format) protobuf file I/O only needs the MessageLite interface.
//
// ReadProtoFromBinaryFile parses the contents of `filename` into `proto`.
// NOTE(review): the bool result presumably reports parse success — confirm
// against the out-of-line definition.
bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto);
// std::string convenience overload; forwards to the const char* version.
// The filename is taken by const reference so no string copy is made,
// matching the WriteProtoToBinaryFile overload below.
inline bool ReadProtoFromBinaryFile(const string& filename, MessageLite* proto) {
  return ReadProtoFromBinaryFile(filename.c_str(), proto);
}
// Writes `proto` to `filename` (defined out of line).
void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename);
// std::string convenience overload; forwards to the const char* version.
inline void WriteProtoToBinaryFile(
    const MessageLite& proto,
    const string& filename) {
  WriteProtoToBinaryFile(proto, filename.c_str());
}
#ifdef CAFFE2_USE_LITE_PROTO
// Lite-proto build. protobuf-lite has no DebugString and no text format, so
// the "debug string" here is simply the binary serialization of `proto`
// (not human-readable).
inline string ProtoDebugString(const MessageLite& proto) {
  return proto.SerializeAsString();
}
// Text format MessageLite wrappers: these functions do nothing except allow
// code written against the text-format API to compile in lite builds. Calling
// any of them hits LOG(FATAL) at runtime; text support requires the full
// protobuf library.
inline bool ReadProtoFromTextFile(
    const char* /*filename*/,
    MessageLite* /*proto*/) {
  LOG(FATAL) << "If you are running lite version, you should not be "
             << "calling any text-format protobuffers.";
  return false; // Just to suppress compiler warning.
}
// std::string convenience overload; taken by const reference to avoid a copy.
inline bool ReadProtoFromTextFile(const string& filename, MessageLite* proto) {
  return ReadProtoFromTextFile(filename.c_str(), proto);
}
inline void WriteProtoToTextFile(
    const MessageLite& /*proto*/,
    const char* /*filename*/) {
  LOG(FATAL) << "If you are running lite version, you should not be "
             << "calling any text-format protobuffers.";
}
// std::string convenience overload; forwards to the const char* stub.
inline void WriteProtoToTextFile(
    const MessageLite& proto,
    const string& filename) {
  WriteProtoToTextFile(proto, filename.c_str());
}
// Tries a binary parse first and falls back to text. In a lite build the text
// fallback is the LOG(FATAL) stub above, so a file that fails to parse as
// binary reaches LOG(FATAL) instead of making this function return false.
inline bool ReadProtoFromFile(const char* filename, MessageLite* proto) {
  return (ReadProtoFromBinaryFile(filename, proto) ||
          ReadProtoFromTextFile(filename, proto));
}
inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) {
  return ReadProtoFromFile(filename.c_str(), proto);
}
#else // CAFFE2_USE_LITE_PROTO
using ::google::protobuf::Message;
// Full-proto build: returns protobuf's ShortDebugString() rendering of
// `proto`.
inline string ProtoDebugString(const Message& proto) {
  return proto.ShortDebugString();
}
// Text-format protobuf I/O (requires protobuf-full); defined out of line.
bool ReadProtoFromTextFile(const char* filename, Message* proto);
// std::string convenience overload; taken by const reference to avoid a copy.
inline bool ReadProtoFromTextFile(const string& filename, Message* proto) {
  return ReadProtoFromTextFile(filename.c_str(), proto);
}
void WriteProtoToTextFile(const Message& proto, const char* filename);
// std::string convenience overload; forwards to the const char* version.
inline void WriteProtoToTextFile(const Message& proto, const string& filename) {
  WriteProtoToTextFile(proto, filename.c_str());
}
// Read Proto from a file, letting the code figure out if it is text or binary.
// A binary parse is attempted first; the text parse only runs when the binary
// parse returns false (short-circuit ||).
inline bool ReadProtoFromFile(const char* filename, Message* proto) {
  return (ReadProtoFromBinaryFile(filename, proto) ||
          ReadProtoFromTextFile(filename, proto));
}
inline bool ReadProtoFromFile(const string& filename, Message* proto) {
  return ReadProtoFromFile(filename.c_str(), proto);
}
#endif // CAFFE2_USE_LITE_PROTO
// Builds an OperatorDef with the given `type` and `name`.
//
// `inputs` and `outputs` may be any iterable whose elements convert to
// string, and `args` any iterable of Argument. Elements are appended in
// iteration order. `device_option` is copied into the def only when it has a
// device_type set, and `engine` only when it is non-empty; otherwise those
// fields are left unset.
template <
    class IterableInputs = std::initializer_list<string>,
    class IterableOutputs = std::initializer_list<string>,
    class IterableArgs = std::initializer_list<Argument>>
OperatorDef CreateOperatorDef(
    const string& type,
    const string& name,
    const IterableInputs& inputs,
    const IterableOutputs& outputs,
    const IterableArgs& args,
    const DeviceOption& device_option = DeviceOption(),
    const string& engine = "") {
  OperatorDef op_def;
  op_def.set_type(type);
  op_def.set_name(name);

  for (const string& input_blob : inputs) {
    op_def.add_input(input_blob);
  }
  for (const string& output_blob : outputs) {
    op_def.add_output(output_blob);
  }
  for (const Argument& source_arg : args) {
    Argument* dest_arg = op_def.add_arg();
    dest_arg->CopyFrom(source_arg);
  }

  // Leave optional fields unset unless the caller actually supplied them.
  if (device_option.has_device_type()) {
    op_def.mutable_device_option()->CopyFrom(device_option);
  }
  if (!engine.empty()) {
    op_def.set_engine(engine);
  }
  return op_def;
}
// Convenience overload of CreateOperatorDef for operators that take no
// arguments: forwards to the full version with an empty Argument list.
template <
    class IterableInputs = std::initializer_list<string>,
    class IterableOutputs = std::initializer_list<string>>
inline OperatorDef CreateOperatorDef(
    const string& type,
    const string& name,
    const IterableInputs& inputs,
    const IterableOutputs& outputs,
    const DeviceOption& device_option = DeviceOption(),
    const string& engine = "") {
  const std::vector<Argument> no_args;
  return CreateOperatorDef(
      type, name, inputs, outputs, no_args, device_option, engine);
}
// Membership checks on an operator's blob names (defined out of line).
// NOTE(review): presumably HasOutput returns true when `output` is among
// op's outputs, and HasInput likewise for inputs — confirm in the .cc.
bool HasOutput(const OperatorDef& op, const std::string& output);
bool HasInput(const OperatorDef& op, const std::string& input);
/**
* @brief A helper class to index into arguments.
*
* This helper helps us to more easily index into a set of arguments
* that are present in the operator. To save memory, the argument helper
* does not copy the operator def, so one would need to make sure that the
* lifetime of the OperatorDef object outlives that of the ArgumentHelper.
*/
class ArgumentHelper {
public:
template <typename Def>
static bool HasArgument(const Def& def, const string& name) {
return ArgumentHelper(def).HasArgument(name);
}
template <typename Def, typename T>
static T GetSingleArgument(
const Def& def,
const string& name,
const T& default_value) {
return ArgumentHelper(def).GetSingleArgument<T>(name, default_value);
}
template <typename Def, typename T>
static bool HasSingleArgumentOfType(const Def& def, const string& name) {
return ArgumentHelper(def).HasSingleArgumentOfType<T>(name);
}
template <typename Def, typename T>
static vector<T> GetRepeatedArgument(
const Def& def,
const string& name,
const std::vector<T>& default_value = std::vector<T>()) {
return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value);
}
template <typename Def, typename MessageType>
static MessageType GetMessageArgument(const Def& def, const string& name) {
return ArgumentHelper(def).GetMessageArgument<MessageType>(name);
}
template <typename Def, typename MessageType>
static vector<MessageType> GetRepeatedMessageArgument(
const Def& def,
const string& name) {
return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
}
explicit ArgumentHelper(const OperatorDef& def);
explicit ArgumentHelper(const NetDef& netdef);
bool HasArgument(const string& name) const;
template <typename T>
T GetSingleArgument(const string& name, const T& default_value) const;
template <typename T>
bool HasSingleArgumentOfType(const string& name) const;
template <typename T>
vector<T> GetRepeatedArgument(
const string& name,
const std::vector<T>& default_value = std::vector<T>()) const;
template <typename MessageType>
MessageType GetMessageArgument(const string& name) const {
CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
MessageType message;
if (arg_map_.at(name).has_s()) {
CAFFE_ENFORCE(
message.ParseFromString(arg_map_.at(name).s()),
"Faild to parse content from the string");
} else {
VLOG(1) << "Return empty message for parameter " << name;
}
return message;
}
template <typename MessageType>
vector<MessageType> GetRepeatedMessageArgument(const string& name) const {
CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
vector<MessageType> messages(arg_map_.at(name).strings_size());
for (int i = 0; i < messages.size(); ++i) {
CAFFE_ENFORCE(
messages[i].ParseFromString(arg_map_.at(name).strings(i)),
"Faild to parse content from the string");
}
return messages;
}
private:
CaffeMap<string, Argument> arg_map_;
};
// Returns the argument named `name` from `def` (defined out of line).
// NOTE(review): behavior when the argument is missing is not visible here —
// confirm in the .cc.
const Argument& GetArgument(const OperatorDef& def, const string& name);
// Reads a boolean flag argument named `name` from `def`; `def_value`
// (default false) is the fallback value.
// NOTE(review): presumably used when the argument is absent — confirm.
bool GetFlagArgument(
const OperatorDef& def,
const string& name,
bool def_value = false);
// Returns a mutable pointer to the argument `name` in `def`.
// NOTE(review): presumably appends a new argument when `create_if_missing`
// is true, and otherwise returns null or fails when the argument is absent —
// confirm in the .cc.
Argument* GetMutableArgument(
const string& name,
const bool create_if_missing,
OperatorDef* def);
// Builds an Argument named `name` holding `value`. Only declared here;
// presumably specialized out of line for each supported T.
template <typename T>
Argument MakeArgument(const string& name, const T& value);
template <typename T>
inline void AddArgument(const string& name, const T& value, OperatorDef* def) {
GetMutableArgument(name, true, def)->CopyFrom(MakeArgument(name, value));
}
} // namespace caffe2
#endif // CAFFE2_UTILS_PROTO_UTILS_H_