mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15316 This starts cleaning up the files in c10 according to the module structure we decided on. Move to c10/util: - Half.h, Half-inl.h, Half.cpp, bitcasts.h Move to c10/core: - Device.h, Device.cpp - DeviceType.h, DeviceType.cpp i-am-not-moving-c2-to-c10 Reviewed By: dzhulgakov Differential Revision: D13498493 fbshipit-source-id: dfcf1c490474a12ab950c72ca686b8ad86428f63
143 lines
4.4 KiB
C++
143 lines
4.4 KiB
C++
#pragma once
|
|
#include <c10/core/Device.h>
|
|
#include <c10/util/Exception.h>
|
|
#include <caffe2/proto/caffe2.pb.h>
|
|
|
|
namespace caffe2 {
|
|
|
|
using DeviceType = at::DeviceType;
|
|
constexpr DeviceType CPU = DeviceType::CPU;
|
|
constexpr DeviceType CUDA = DeviceType::CUDA;
|
|
constexpr DeviceType OPENGL = DeviceType::OPENGL;
|
|
constexpr DeviceType OPENCL = DeviceType::OPENCL;
|
|
constexpr DeviceType MKLDNN = DeviceType::MKLDNN;
|
|
constexpr DeviceType IDEEP = DeviceType::IDEEP;
|
|
constexpr DeviceType HIP = DeviceType::HIP;
|
|
constexpr DeviceType COMPILE_TIME_MAX_DEVICE_TYPES =
|
|
DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES;
|
|
constexpr DeviceType ONLY_FOR_TEST = DeviceType::ONLY_FOR_TEST;
|
|
|
|
inline CAFFE2_API DeviceType ProtoToType(const caffe2::DeviceTypeProto p) {
|
|
switch (p) {
|
|
case caffe2::PROTO_CPU:
|
|
return DeviceType::CPU;
|
|
case caffe2::PROTO_CUDA:
|
|
return DeviceType::CUDA;
|
|
case caffe2::PROTO_OPENGL:
|
|
return DeviceType::OPENGL;
|
|
case caffe2::PROTO_OPENCL:
|
|
return DeviceType::OPENCL;
|
|
case caffe2::PROTO_MKLDNN:
|
|
return DeviceType::MKLDNN;
|
|
case caffe2::PROTO_IDEEP:
|
|
return DeviceType::IDEEP;
|
|
case caffe2::PROTO_HIP:
|
|
return DeviceType::HIP;
|
|
case caffe2::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES:
|
|
return DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES;
|
|
case caffe2::PROTO_ONLY_FOR_TEST:
|
|
return DeviceType::ONLY_FOR_TEST;
|
|
default:
|
|
AT_ERROR(
|
|
"Unknown device:",
|
|
static_cast<int32_t>(p),
|
|
". If you have recently updated the caffe2.proto file to add a new "
|
|
"device type, did you forget to update the ProtoToType() and TypeToProto"
|
|
"function to reflect such recent changes?");
|
|
}
|
|
}
|
|
|
|
inline CAFFE2_API DeviceType ProtoToType(int p) {
|
|
return ProtoToType(static_cast<caffe2::DeviceTypeProto>(p));
|
|
}
|
|
|
|
inline CAFFE2_API DeviceTypeProto TypeToProto(const DeviceType& t) {
|
|
switch (t) {
|
|
case DeviceType::CPU:
|
|
return caffe2::PROTO_CPU;
|
|
case DeviceType::CUDA:
|
|
return caffe2::PROTO_CUDA;
|
|
case DeviceType::OPENGL:
|
|
return caffe2::PROTO_OPENGL;
|
|
case DeviceType::OPENCL:
|
|
return caffe2::PROTO_OPENCL;
|
|
case DeviceType::MKLDNN:
|
|
return caffe2::PROTO_MKLDNN;
|
|
case DeviceType::IDEEP:
|
|
return caffe2::PROTO_IDEEP;
|
|
case DeviceType::HIP:
|
|
return caffe2::PROTO_HIP;
|
|
case DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES:
|
|
return caffe2::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES;
|
|
case DeviceType::ONLY_FOR_TEST:
|
|
return caffe2::PROTO_ONLY_FOR_TEST;
|
|
default:
|
|
AT_ERROR(
|
|
"Unknown device:",
|
|
static_cast<int32_t>(t),
|
|
". If you have recently updated the caffe2.proto file to add a new "
|
|
"device type, did you forget to update the ProtoToType() and TypeToProto"
|
|
"function to reflect such recent changes?");
|
|
}
|
|
}
|
|
|
|
inline CAFFE2_API caffe2::DeviceOption DeviceToOption(
|
|
const at::Device& device) {
|
|
caffe2::DeviceOption option;
|
|
auto type = device.type();
|
|
option.set_device_type(TypeToProto(type));
|
|
|
|
switch (type) {
|
|
case DeviceType::CPU:
|
|
if (device.index() != -1) {
|
|
option.set_numa_node_id(device.index());
|
|
}
|
|
break;
|
|
case DeviceType::CUDA:
|
|
case DeviceType::HIP:
|
|
option.set_device_id(device.index());
|
|
break;
|
|
case DeviceType::OPENGL:
|
|
case DeviceType::OPENCL:
|
|
case DeviceType::MKLDNN:
|
|
case DeviceType::IDEEP:
|
|
case DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES:
|
|
case DeviceType::ONLY_FOR_TEST:
|
|
break;
|
|
default:
|
|
AT_ERROR(
|
|
"Unknown device:",
|
|
static_cast<int32_t>(type),
|
|
". If you have recently updated the caffe2.proto file to add a new "
|
|
"device type, did you forget to update the ProtoToType() and TypeToProto"
|
|
"function to reflect such recent changes?");
|
|
}
|
|
return option;
|
|
}
|
|
|
|
inline CAFFE2_API at::Device OptionToDevice(const caffe2::DeviceOption option) {
|
|
auto type = option.device_type();
|
|
int32_t id = -1;
|
|
switch (type) {
|
|
case caffe2::PROTO_CPU:
|
|
if (option.has_numa_node_id()) {
|
|
id = option.numa_node_id();
|
|
}
|
|
break;
|
|
case caffe2::PROTO_CUDA:
|
|
case caffe2::PROTO_HIP:
|
|
id = option.device_id();
|
|
break;
|
|
}
|
|
return at::Device(ProtoToType(type), id);
|
|
}
|
|
|
|
inline void ExtractDeviceOption(
|
|
DeviceOption* device_option,
|
|
const at::Device& device) {
|
|
AT_ASSERT(device_option);
|
|
device_option->CopyFrom(DeviceToOption(device));
|
|
}
|
|
|
|
} // namespace caffe2
|