Add node_name to DeviceOption

Summary: Allow for generalizing net transforms.

Reviewed By: Yangqing

Differential Revision: D5812140

fbshipit-source-id: e3f30acad362ae1f0614ee218d331b525710b88e
This commit is contained in:
Alisson Gusatti Azzolini 2017-09-13 15:49:55 -07:00 committed by Facebook Github Bot
parent 37af6566e1
commit 68f358452b
5 changed files with 38 additions and 7 deletions

View File

@ -28,6 +28,15 @@ NetBase::NetBase(
def->external_output().begin(),
def->external_output().end()),
name_(def->name()) {
// Check that node_name is empty for all ops
for (const OperatorDef& op : def->op()) {
if (op.has_device_option()) {
CAFFE_ENFORCE(
!op.device_option().has_node_name(),
"node_name must be empty for all operators at execution time.");
}
}
// Go through the operators and make sure that blobs are correctly made.
std::set<string> known_blobs(
external_input_.begin(), external_input_.end());

View File

@ -132,6 +132,9 @@ message DeviceOption {
optional int32 cuda_gpu_id = 2;
// [general] The random seed to start the device random number generator with.
optional uint32 random_seed = 3;
// [general] What node this op should execute on.
// Used for net transformation purposes. Must be empty at execution time.
optional string node_name = 4;
}
// Operator Definition.

View File

@ -4,11 +4,11 @@
#include <cerrno>
#include <fstream>
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#ifndef CAFFE2_USE_LITE_PROTO
#include "google/protobuf/text_format.h"
#include <google/protobuf/text_format.h>
#endif // !CAFFE2_USE_LITE_PROTO
#include "caffe2/core/logging.h"
@ -42,7 +42,8 @@ std::string DeviceTypeName(const int32_t& d) {
bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs) {
return (
lhs.device_type() == rhs.device_type() &&
lhs.cuda_gpu_id() == rhs.cuda_gpu_id());
lhs.cuda_gpu_id() == rhs.cuda_gpu_id() &&
lhs.node_name() == rhs.node_name());
}
bool ReadStringFromFile(const char* filename, string* str) {

View File

@ -1,9 +1,10 @@
#ifndef CAFFE2_UTILS_PROTO_UTILS_H_
#define CAFFE2_UTILS_PROTO_UTILS_H_
#include "google/protobuf/message_lite.h"
#ifndef CAFFE2_USE_LITE_PROTO
#include "google/protobuf/message.h"
#ifdef CAFFE2_USE_LITE_PROTO
#include <google/protobuf/message_lite.h>
#else // CAFFE2_USE_LITE_PROTO
#include <google/protobuf/message.h>
#endif // !CAFFE2_USE_LITE_PROTO
#include "caffe2/core/logging.h"

View File

@ -3,6 +3,23 @@
namespace caffe2 {
TEST(ProtoUtilsTest, IsSameDevice) {
DeviceOption a;
DeviceOption b;
EXPECT_TRUE(IsSameDevice(a, b));
a.set_node_name("my_node");
EXPECT_FALSE(IsSameDevice(a, b));
b.set_node_name("my_node");
EXPECT_TRUE(IsSameDevice(a, b));
b.set_cuda_gpu_id(2);
EXPECT_FALSE(IsSameDevice(a, b));
a.set_cuda_gpu_id(2);
EXPECT_TRUE(IsSameDevice(a, b));
a.set_device_type(DeviceType::CUDA);
b.set_device_type(DeviceType::CPU);
EXPECT_FALSE(IsSameDevice(a, b));
}
TEST(ProtoUtilsTest, SimpleReadWrite) {
string content("The quick brown fox jumps over the lazy dog.");
string name = std::tmpnam(nullptr);