pytorch/c10/core/Device.cpp
Bel H 30cb6ac53c Introduce mlc device (ML Compute device) to PyTorch's device list (#50634)
Summary:
Apple recently announced ML Compute, a new framework available in macOS Big Sur, which enables users to accelerate the training of neural networks on Mac hardware. This PR is the first in a series of PRs that will enable the integration with ML Compute. Most of the integration code will live in a separate subrepo named `mlc`.
The integration with `mlc` (ML Compute) will be very similar to that of xla. We rely on registering our ops through:

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
 m.impl_UNBOXED(<op_schema_name>, &customized_op_kernel)
 ...
}

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50634

Reviewed By: malfet

Differential Revision: D26614213

Pulled By: smessmer

fbshipit-source-id: 3b492b346c61cc3950ac880ac01a82fbdddbc07b
2021-02-24 22:39:11 -08:00

103 lines
3.1 KiB
C++

#include <c10/core/Device.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <algorithm>
#include <array>
#include <exception>
#include <limits>
#include <ostream>
#include <regex>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
// Check if compiler has working std::regex implementation
//
// Device(const std::string&) further down parses device strings with
// std::regex, so the build is stopped (static_assert in the #else branch)
// when none of the conditions below indicate a usable <regex>.
//
// Test below is adapted from https://stackoverflow.com/a/41186162
#if defined(_MSVC_LANG) && _MSVC_LANG >= 201103L
// Compiler has working regex. MSVC has erroneous __cplusplus, so its language
// level is read from _MSVC_LANG instead.
#elif __cplusplus >= 201103L && \
(!defined(__GLIBCXX__) || (__cplusplus >= 201402L) || \
(defined(_GLIBCXX_REGEX_DFS_QUANTIFIERS_LIMIT) || \
defined(_GLIBCXX_REGEX_STATE_LIMIT) || \
(defined(_GLIBCXX_RELEASE) && \
_GLIBCXX_RELEASE > 4)))
// Compiler has working regex. Any non-libstdc++ library in C++11 mode is
// accepted. For libstdc++, C++14 or any of the listed libstdc++ macros is
// accepted; per the linked answer those macros presumably only appear in
// libstdc++ releases whose <regex> works (NOTE(review): confirm).
#else
static_assert(false, "Compiler does not have proper regex support.");
#endif
namespace c10 {
namespace {

// Returns the DeviceType whose lower-case name (e.g. "cuda") equals
// `device_string` exactly. Throws c10::Error (via AT_ERROR) listing every
// accepted name when there is no match.
DeviceType parse_type(const std::string& device_string) {
  // Sized by its initializer list rather than by COMPILE_TIME_MAX_DEVICE_TYPES,
  // so the table holds no value-initialized padding entries (empty name,
  // default DeviceType) that the lookup could match or the error message
  // would have to skip.
  static const std::pair<std::string, DeviceType> types[] = {
      {"cpu", DeviceType::CPU},
      {"cuda", DeviceType::CUDA},
      {"xpu", DeviceType::XPU},
      {"mkldnn", DeviceType::MKLDNN},
      {"opengl", DeviceType::OPENGL},
      {"opencl", DeviceType::OPENCL},
      {"ideep", DeviceType::IDEEP},
      {"hip", DeviceType::HIP},
      {"fpga", DeviceType::FPGA},
      {"msnpu", DeviceType::MSNPU},
      {"xla", DeviceType::XLA},
      {"vulkan", DeviceType::Vulkan},
      {"mlc", DeviceType::MLC},
  };
  // Compile-time guard: there can never be more names than device types.
  static_assert(
      sizeof(types) / sizeof(types[0]) <=
          static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES),
      "parse_type lists more device names than there are device types");

  for (const auto& entry : types) {
    if (entry.first == device_string) {
      return entry.second;
    }
  }

  // Build the list of accepted names from the table itself so the message
  // cannot drift out of sync with what is actually accepted.
  std::string valid_names;
  for (const auto& entry : types) {
    if (!valid_names.empty()) {
      valid_names += ", ";
    }
    valid_names += entry.first;
  }
  AT_ERROR(
      "Expected one of ",
      valid_names,
      " device type at start of device string: ",
      device_string);
}

} // namespace

// Parses strings of the form "<type>" or "<type>:<index>" (e.g. "cpu",
// "cuda:1"). The index must be a non-negative decimal without leading zeros
// and must fit in DeviceIndex. Throws c10::Error on any malformed input.
Device::Device(const std::string& device_string) : Device(Type::CPU) {
  TORCH_CHECK(!device_string.empty(), "Device string must not be empty");
  // We assume gcc 5+, so we can use proper regex.
  static const std::regex regex("([a-zA-Z_]+)(?::([1-9]\\d*|0))?");
  std::smatch match;
  TORCH_CHECK(
      std::regex_match(device_string, match, regex),
      "Invalid device string: '", device_string, "'");
  type_ = parse_type(match[1].str());
  if (match[2].matched) {
    const std::string index_str = match[2].str();
    int parsed_index = -1;
    try {
      parsed_index = c10::stoi(index_str);
    } catch (const std::exception&) {
      // e.g. the digits exceed the range of int.
      AT_ERROR(
          "Could not parse device index '", index_str,
          "' in device string '", device_string, "'");
    }
    // The regex only admits digits, so parsed_index is non-negative here. It
    // must be range-checked before narrowing: assigning an int to DeviceIndex
    // directly would silently wrap large values onto a different, valid index.
    TORCH_CHECK(
        parsed_index <= std::numeric_limits<DeviceIndex>::max(),
        "Device index '", index_str, "' in device string '", device_string,
        "' is out of range (maximum is ",
        std::numeric_limits<DeviceIndex>::max(), ")");
    index_ = static_cast<DeviceIndex>(parsed_index);
  }
  validate();
}

// Returns the lower-case device type name, followed by ":<index>" when this
// device has an index (e.g. "cpu", "cuda:0").
std::string Device::str() const {
  std::string str = DeviceTypeName(type(), /* lower case */ true);
  if (has_index()) {
    str.push_back(':');
    str.append(to_string(index()));
  }
  return str;
}

// Writes device.str() to `stream`.
std::ostream& operator<<(std::ostream& stream, const Device& device) {
  stream << device.str();
  return stream;
}

} // namespace c10