#include #include #include #include #include #include #include #include namespace c10 { namespace { DeviceType parse_type(const std::string& device_string) { static const std::array< std::pair, static_cast(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)> types = {{ {"cpu", DeviceType::CPU}, {"cuda", DeviceType::CUDA}, {"ipu", DeviceType::IPU}, {"xpu", DeviceType::XPU}, {"mkldnn", DeviceType::MKLDNN}, {"opengl", DeviceType::OPENGL}, {"opencl", DeviceType::OPENCL}, {"ideep", DeviceType::IDEEP}, {"hip", DeviceType::HIP}, {"ve", DeviceType::VE}, {"fpga", DeviceType::FPGA}, {"maia", DeviceType::MAIA}, {"xla", DeviceType::XLA}, {"lazy", DeviceType::Lazy}, {"vulkan", DeviceType::Vulkan}, {"mps", DeviceType::MPS}, {"meta", DeviceType::Meta}, {"hpu", DeviceType::HPU}, {"mtia", DeviceType::MTIA}, {"privateuseone", DeviceType::PrivateUse1}, }}; if (device_string == "mkldnn") { TORCH_WARN_ONCE( "'mkldnn' is no longer used as device type. So torch.device('mkldnn') will be " "deprecated and removed in the future. Please use other valid device types instead."); } if (device_string == get_privateuse1_backend()) { return DeviceType::PrivateUse1; } auto device = std::find_if( types.begin(), types.end(), [&device_string](const std::pair& p) { return p.first && p.first == device_string; }); if (device != types.end()) { return device->second; } std::vector device_names; for (const auto& it : types) { if (it.first) { device_names.push_back(it.first); } } TORCH_CHECK( false, "Expected one of ", c10::Join(", ", device_names), " device type at start of device string: ", device_string); } enum DeviceStringParsingState { START, INDEX_START, INDEX_REST, ERROR }; } // namespace Device::Device(const std::string& device_string) : Device(Type::CPU) { TORCH_CHECK(!device_string.empty(), "Device string must not be empty"); std::string device_name, device_index_str; DeviceStringParsingState pstate = DeviceStringParsingState::START; // The code below tries to match the string in the variable // device_string against the regular expression: // ([a-zA-Z_]+)(?::([1-9]\\d*|0))? for (size_t i = 0; pstate != DeviceStringParsingState::ERROR && i < device_string.size(); ++i) { const char ch = device_string.at(i); const unsigned char uch = static_cast(ch); switch (pstate) { case DeviceStringParsingState::START: if (ch != ':') { if (std::isalpha(uch) || ch == '_') { device_name.push_back(ch); } else { pstate = DeviceStringParsingState::ERROR; } } else { pstate = DeviceStringParsingState::INDEX_START; } break; case DeviceStringParsingState::INDEX_START: if (std::isdigit(uch)) { device_index_str.push_back(ch); pstate = DeviceStringParsingState::INDEX_REST; } else { pstate = DeviceStringParsingState::ERROR; } break; case DeviceStringParsingState::INDEX_REST: if (device_index_str.at(0) == '0') { pstate = DeviceStringParsingState::ERROR; break; } if (std::isdigit(uch)) { device_index_str.push_back(ch); } else { pstate = DeviceStringParsingState::ERROR; } break; case DeviceStringParsingState::ERROR: // Execution won't reach here. break; } } const bool has_error = device_name.empty() || pstate == DeviceStringParsingState::ERROR || (pstate == DeviceStringParsingState::INDEX_START && device_index_str.empty()); TORCH_CHECK(!has_error, "Invalid device string: '", device_string, "'"); try { if (!device_index_str.empty()) { index_ = static_cast(std::stoi(device_index_str)); } } catch (const std::exception&) { TORCH_CHECK( false, "Could not parse device index '", device_index_str, "' in device string '", device_string, "'"); } type_ = parse_type(device_name); validate(); } std::string Device::str() const { std::string str = DeviceTypeName(type(), /* lower case */ true); if (has_index()) { str.push_back(':'); str.append(std::to_string(index())); } return str; } std::ostream& operator<<(std::ostream& stream, const Device& device) { stream << device.str(); return stream; } } // namespace c10