#include #include #include #include #include namespace c10d { void HashStore::set(const std::string& key, const std::vector& data) { std::unique_lock lock(m_); map_[key] = data; cv_.notify_all(); } std::vector HashStore::compareSet( const std::string& key, const std::vector& expectedValue, const std::vector& desiredValue) { std::unique_lock lock(m_); auto it = map_.find(key); if ((it == map_.end() && expectedValue.empty()) || (it != map_.end() && it->second == expectedValue)) { // if the key does not exist and currentValue arg is empty or // the key does exist and current value is what is expected, then set it map_[key] = desiredValue; cv_.notify_all(); return desiredValue; } else if (it == map_.end()) { // if the key does not exist return expectedValue; } // key exists but current value is not expected return it->second; } std::vector HashStore::get(const std::string& key) { std::unique_lock lock(m_); auto it = map_.find(key); if (it != map_.end()) { return it->second; } // Slow path: wait up to any timeout_. auto pred = [&]() { return map_.find(key) != map_.end(); }; if (timeout_ == kNoTimeout) { cv_.wait(lock, pred); } else { if (!cv_.wait_for(lock, timeout_, pred)) { C10_THROW_ERROR(DistStoreError, "Wait timeout"); } } return map_[key]; } void HashStore::wait( const std::vector& keys, const std::chrono::milliseconds& timeout) { const auto end = std::chrono::steady_clock::now() + timeout; auto pred = [&]() { auto done = true; for (const auto& key : keys) { if (map_.find(key) == map_.end()) { done = false; break; } } return done; }; std::unique_lock lock(m_); if (timeout == kNoTimeout) { cv_.wait(lock, pred); } else { if (!cv_.wait_until(lock, end, pred)) { C10_THROW_ERROR(DistStoreError, "Wait timeout"); } } } int64_t HashStore::add(const std::string& key, int64_t i) { std::unique_lock lock(m_); const auto& value = map_[key]; int64_t ti = i; if (!value.empty()) { auto buf = reinterpret_cast(value.data()); auto len = value.size(); ti += std::stoll(std::string(buf, len)); } auto str = std::to_string(ti); const uint8_t* strB = reinterpret_cast(str.c_str()); map_[key] = std::vector(strB, strB + str.size()); return ti; } int64_t HashStore::getNumKeys() { std::unique_lock lock(m_); return static_cast(map_.size()); } bool HashStore::deleteKey(const std::string& key) { std::unique_lock lock(m_); auto numDeleted = map_.erase(key); return (numDeleted == 1); } bool HashStore::check(const std::vector& keys) { std::unique_lock lock(m_); for (const auto& key : keys) { if (map_.find(key) == map_.end()) { return false; } } return true; } void HashStore::append( const std::string& key, const std::vector& value) { std::unique_lock lock(m_); auto it = map_.find(key); if (it == map_.end()) { map_[key] = value; } else { it->second.insert(it->second.end(), value.begin(), value.end()); } cv_.notify_all(); } std::vector> HashStore::multiGet( const std::vector& keys) { std::unique_lock lock(m_); auto deadline = std::chrono::steady_clock::now() + timeout_; std::vector> res; res.reserve(keys.size()); for (auto& key : keys) { auto it = map_.find(key); if (it != map_.end()) { res.emplace_back(it->second); } else { auto pred = [&]() { return map_.find(key) != map_.end(); }; if (timeout_ == kNoTimeout) { cv_.wait(lock, pred); } else { if (!cv_.wait_until(lock, deadline, pred)) { C10_THROW_ERROR(DistStoreError, "Wait timeout"); } } res.emplace_back(map_[key]); } } return res; } void HashStore::multiSet( const std::vector& keys, const std::vector>& values) { std::unique_lock lock(m_); for (auto i : ::c10::irange(keys.size())) { map_[keys[i]] = values[i]; } cv_.notify_all(); } bool HashStore::hasExtendedApi() const { return true; } } // namespace c10d