#include #include #include #include namespace c10d { c10::intrusive_ptr HashStore::clone() { return c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this); } void HashStore::set(const std::string& key, const std::vector& data) { std::unique_lock lock(m_); map_[key] = data; cv_.notify_all(); } std::vector HashStore::compareSet( const std::string& key, const std::vector& expectedValue, const std::vector& desiredValue) { std::unique_lock lock(m_); auto it = map_.find(key); if ((it == map_.end() && expectedValue.empty()) || (it != map_.end() && it->second == expectedValue)) { // if the key does not exist and currentValue arg is empty or // the key does exist and current value is what is expected, then set it map_[key] = desiredValue; cv_.notify_all(); return desiredValue; } else if (it == map_.end()) { // if the key does not exist return expectedValue; } // key exists but current value is not expected return it->second; } std::vector HashStore::get(const std::string& key) { std::unique_lock lock(m_); auto it = map_.find(key); if (it != map_.end()) { return it->second; } // Slow path: wait up to any timeout_. auto pred = [&]() { return map_.find(key) != map_.end(); }; if (timeout_ == kNoTimeout) { cv_.wait(lock, pred); } else { if (!cv_.wait_for(lock, timeout_, pred)) { C10_THROW_ERROR(DistStoreError, "Wait timeout"); } } return map_[key]; } void HashStore::wait( const std::vector& keys, const std::chrono::milliseconds& timeout) { std::unique_lock lock(m_); waitLocked(lock, keys, timeout); } void HashStore::waitLocked( std::unique_lock& lock, const std::vector& keys, const std::chrono::milliseconds& timeout) { const auto end = std::chrono::steady_clock::now() + timeout; auto pred = [&]() { return checkLocked(lock, keys); }; if (timeout == kNoTimeout) { cv_.wait(lock, pred); } else { if (!cv_.wait_until(lock, end, pred)) { C10_THROW_ERROR(DistStoreError, "Wait timeout"); } } } int64_t HashStore::add(const std::string& key, int64_t i) { std::unique_lock lock(m_); const auto& value = map_[key]; int64_t ti = i; if (!value.empty()) { auto buf = reinterpret_cast(value.data()); auto len = value.size(); ti += std::stoll(std::string(buf, len)); } auto str = std::to_string(ti); const uint8_t* strB = reinterpret_cast(str.c_str()); map_[key] = std::vector(strB, strB + str.size()); return ti; } int64_t HashStore::getNumKeys() { std::unique_lock lock(m_); return static_cast(map_.size()); } bool HashStore::deleteKey(const std::string& key) { std::unique_lock lock(m_); auto numDeleted = map_.erase(key); return (numDeleted == 1); } bool HashStore::check(const std::vector& keys) { std::unique_lock lock(m_); return checkLocked(lock, keys); } bool HashStore::checkLocked( const std::unique_lock& lock, const std::vector& keys) { for (const auto& key : keys) { auto foundKV = map_.find(key) != map_.end(); auto foundQueue = queues_.find(key) != queues_.end() && !queues_[key].empty(); if (!foundKV && !foundQueue) { return false; } } return true; } void HashStore::append( const std::string& key, const std::vector& value) { std::unique_lock lock(m_); auto it = map_.find(key); if (it == map_.end()) { map_[key] = value; } else { it->second.insert(it->second.end(), value.begin(), value.end()); } cv_.notify_all(); } std::vector> HashStore::multiGet( const std::vector& keys) { std::unique_lock lock(m_); auto deadline = std::chrono::steady_clock::now() + timeout_; std::vector> res; res.reserve(keys.size()); for (auto& key : keys) { auto it = map_.find(key); if (it != map_.end()) { res.emplace_back(it->second); } else { auto pred = [&]() { return map_.find(key) != map_.end(); }; if (timeout_ == kNoTimeout) { cv_.wait(lock, pred); } else { if (!cv_.wait_until(lock, deadline, pred)) { C10_THROW_ERROR(DistStoreError, "Wait timeout"); } } res.emplace_back(map_[key]); } } return res; } void HashStore::multiSet( const std::vector& keys, const std::vector>& values) { std::unique_lock lock(m_); for (auto i : ::c10::irange(keys.size())) { map_[keys[i]] = values[i]; } cv_.notify_all(); } bool HashStore::hasExtendedApi() const { return true; } void HashStore::queuePush( const std::string& key, const std::vector& value) { std::unique_lock lock(m_); queues_[key].push_back(value); cv_.notify_one(); } std::vector HashStore::queuePop(const std::string& key, bool block) { std::unique_lock lock(m_); if (block) { waitLocked(lock, {key}, timeout_); } auto& queue = queues_[key]; TORCH_CHECK_WITH(DistQueueEmptyError, !queue.empty(), "queue is empty"); auto val = queue.front(); queue.pop_front(); return val; } int64_t HashStore::queueLen(const std::string& key) { std::unique_lock lock(m_); auto it = queues_.find(key); if (it == queues_.end()) { return 0; } return static_cast(it->second.size()); } } // namespace c10d