mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add node_name to DeviceOption
Summary: Allow for generalizing net transforms. Reviewed By: Yangqing Differential Revision: D5812140 fbshipit-source-id: e3f30acad362ae1f0614ee218d331b525710b88e
This commit is contained in:
parent
37af6566e1
commit
68f358452b
|
|
@ -28,6 +28,15 @@ NetBase::NetBase(
|
|||
def->external_output().begin(),
|
||||
def->external_output().end()),
|
||||
name_(def->name()) {
|
||||
// Check that node_name is empty for all ops
|
||||
for (const OperatorDef& op : def->op()) {
|
||||
if (op.has_device_option()) {
|
||||
CAFFE_ENFORCE(
|
||||
!op.device_option().has_node_name(),
|
||||
"node_name must be empty for all operators at execution time.");
|
||||
}
|
||||
}
|
||||
|
||||
// Go through the operators and make sure that blobs are correctly made.
|
||||
std::set<string> known_blobs(
|
||||
external_input_.begin(), external_input_.end());
|
||||
|
|
|
|||
|
|
@ -132,6 +132,9 @@ message DeviceOption {
|
|||
optional int32 cuda_gpu_id = 2;
|
||||
// [general] The random seed to start the device random number generator with.
|
||||
optional uint32 random_seed = 3;
|
||||
// [general] What node this op should execute on.
|
||||
// Used for net transformation purposes. Must be empty at execution time.
|
||||
optional string node_name = 4;
|
||||
}
|
||||
|
||||
// Operator Definition.
|
||||
|
|
|
|||
|
|
@ -4,11 +4,11 @@
|
|||
#include <cerrno>
|
||||
#include <fstream>
|
||||
|
||||
#include "google/protobuf/io/coded_stream.h"
|
||||
#include "google/protobuf/io/zero_copy_stream_impl.h"
|
||||
#include <google/protobuf/io/coded_stream.h>
|
||||
#include <google/protobuf/io/zero_copy_stream_impl.h>
|
||||
|
||||
#ifndef CAFFE2_USE_LITE_PROTO
|
||||
#include "google/protobuf/text_format.h"
|
||||
#include <google/protobuf/text_format.h>
|
||||
#endif // !CAFFE2_USE_LITE_PROTO
|
||||
|
||||
#include "caffe2/core/logging.h"
|
||||
|
|
@ -42,7 +42,8 @@ std::string DeviceTypeName(const int32_t& d) {
|
|||
bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs) {
|
||||
return (
|
||||
lhs.device_type() == rhs.device_type() &&
|
||||
lhs.cuda_gpu_id() == rhs.cuda_gpu_id());
|
||||
lhs.cuda_gpu_id() == rhs.cuda_gpu_id() &&
|
||||
lhs.node_name() == rhs.node_name());
|
||||
}
|
||||
|
||||
bool ReadStringFromFile(const char* filename, string* str) {
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
#ifndef CAFFE2_UTILS_PROTO_UTILS_H_
|
||||
#define CAFFE2_UTILS_PROTO_UTILS_H_
|
||||
|
||||
#include "google/protobuf/message_lite.h"
|
||||
#ifndef CAFFE2_USE_LITE_PROTO
|
||||
#include "google/protobuf/message.h"
|
||||
#ifdef CAFFE2_USE_LITE_PROTO
|
||||
#include <google/protobuf/message_lite.h>
|
||||
#else // CAFFE2_USE_LITE_PROTO
|
||||
#include <google/protobuf/message.h>
|
||||
#endif // !CAFFE2_USE_LITE_PROTO
|
||||
|
||||
#include "caffe2/core/logging.h"
|
||||
|
|
|
|||
|
|
@ -3,6 +3,23 @@
|
|||
|
||||
namespace caffe2 {
|
||||
|
||||
TEST(ProtoUtilsTest, IsSameDevice) {
|
||||
DeviceOption a;
|
||||
DeviceOption b;
|
||||
EXPECT_TRUE(IsSameDevice(a, b));
|
||||
a.set_node_name("my_node");
|
||||
EXPECT_FALSE(IsSameDevice(a, b));
|
||||
b.set_node_name("my_node");
|
||||
EXPECT_TRUE(IsSameDevice(a, b));
|
||||
b.set_cuda_gpu_id(2);
|
||||
EXPECT_FALSE(IsSameDevice(a, b));
|
||||
a.set_cuda_gpu_id(2);
|
||||
EXPECT_TRUE(IsSameDevice(a, b));
|
||||
a.set_device_type(DeviceType::CUDA);
|
||||
b.set_device_type(DeviceType::CPU);
|
||||
EXPECT_FALSE(IsSameDevice(a, b));
|
||||
}
|
||||
|
||||
TEST(ProtoUtilsTest, SimpleReadWrite) {
|
||||
string content("The quick brown fox jumps over the lazy dog.");
|
||||
string name = std::tmpnam(nullptr);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user