pytorch/caffe2/operators/load_save_op_gpu.cc
Junjie Bai f54ab540af Rename cuda_gpu_id to device_id in DeviceOption (#12456)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12456

codemod with 'Yes to all'
codemod -d . --extensions h,cc,cpp,cu,py,proto,pbtxt,pb.txt,config cuda_gpu_id device_id

Overload TextFormat::ParseFromString to do string replace when parsing from protobuf format

Reviewed By: Yangqing

Differential Revision: D10240535

fbshipit-source-id: 5e6992bec961214be8dbe26f16f5794154a22b25
2018-10-09 15:54:04 -07:00

20 lines
632 B
C++

#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/load_save_op.h"
namespace caffe2 {

// GPU specialization of LoadOp::SetCurrentDevice.
//
// Rewrites the device placement stored in a deserialized blob so that its
// tensor (if any) is tagged as living on a CUDA device, using the device id
// reported by CaffeCudaGetDevice(). Any device fields that were saved with the
// blob are discarded first, so the recorded placement is fully replaced rather
// than partially overwritten.
//
// @param proto  Blob being loaded; modified in place. Blobs that do not carry
//               a tensor are left unchanged.
template <>
void LoadOp<CUDAContext>::SetCurrentDevice(BlobProto* proto) {
  // Nothing to relocate when the blob holds no serialized tensor.
  if (!proto->has_tensor()) {
    return;
  }

  auto* tensor_msg = proto->mutable_tensor();
  // Reset the whole device sub-message so no field from the saved placement
  // survives, then populate only the fields that describe the new placement.
  tensor_msg->clear_device_detail();
  auto* placement = tensor_msg->mutable_device_detail();
  placement->set_device_type(PROTO_CUDA);
  placement->set_device_id(CaffeCudaGetDevice());
}

// Register the CUDAContext instantiations of the serialization operators so
// they are available for CUDA device options.
REGISTER_CUDA_OPERATOR(Load, LoadOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(Save, SaveOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(Checkpoint, CheckpointOp<CUDAContext>);

}  // namespace caffe2