mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
This PR fixes some clang-tidy warnings in distributed code. Pull Request resolved: https://github.com/pytorch/pytorch/pull/122884 Approved by: https://github.com/kwen2501
132 lines
3.5 KiB
C++
132 lines
3.5 KiB
C++
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
|
|
#include <utility>
|
|
|
|
namespace c10d {
|
|
|
|
PrefixStore::PrefixStore(std::string prefix, c10::intrusive_ptr<Store> store)
|
|
: prefix_(std::move(prefix)), store_(std::move(store)) {}
|
|
|
|
std::string PrefixStore::joinKey(const std::string& key) {
|
|
return prefix_ + "/" + key;
|
|
}
|
|
|
|
std::vector<std::string> PrefixStore::joinKeys(
|
|
const std::vector<std::string>& keys) {
|
|
std::vector<std::string> joinedKeys;
|
|
joinedKeys.reserve(keys.size());
|
|
for (const auto& key : keys) {
|
|
joinedKeys.emplace_back(joinKey(key));
|
|
}
|
|
return joinedKeys;
|
|
}
|
|
|
|
void PrefixStore::set(
|
|
const std::string& key,
|
|
const std::vector<uint8_t>& value) {
|
|
store_->set(joinKey(key), value);
|
|
}
|
|
|
|
std::vector<uint8_t> PrefixStore::compareSet(
|
|
const std::string& key,
|
|
const std::vector<uint8_t>& expectedValue,
|
|
const std::vector<uint8_t>& desiredValue) {
|
|
return store_->compareSet(joinKey(key), expectedValue, desiredValue);
|
|
}
|
|
|
|
std::vector<uint8_t> PrefixStore::get(const std::string& key) {
|
|
return store_->get(joinKey(key));
|
|
}
|
|
|
|
int64_t PrefixStore::add(const std::string& key, int64_t value) {
|
|
return store_->add(joinKey(key), value);
|
|
}
|
|
|
|
bool PrefixStore::deleteKey(const std::string& key) {
|
|
return store_->deleteKey(joinKey(key));
|
|
}
|
|
|
|
int64_t PrefixStore::getNumKeys() {
|
|
return store_->getNumKeys();
|
|
}
|
|
|
|
bool PrefixStore::check(const std::vector<std::string>& keys) {
|
|
auto joinedKeys = joinKeys(keys);
|
|
return store_->check(joinedKeys);
|
|
}
|
|
|
|
void PrefixStore::wait(const std::vector<std::string>& keys) {
|
|
auto joinedKeys = joinKeys(keys);
|
|
store_->wait(joinedKeys);
|
|
}
|
|
|
|
void PrefixStore::wait(
|
|
const std::vector<std::string>& keys,
|
|
const std::chrono::milliseconds& timeout) {
|
|
auto joinedKeys = joinKeys(keys);
|
|
store_->wait(joinedKeys, timeout);
|
|
}
|
|
|
|
const std::chrono::milliseconds& PrefixStore::getTimeout() const noexcept {
|
|
return store_->getTimeout();
|
|
}
|
|
|
|
void PrefixStore::setTimeout(const std::chrono::milliseconds& timeout) {
|
|
store_->setTimeout(timeout);
|
|
}
|
|
|
|
void PrefixStore::append(
|
|
const std::string& key,
|
|
const std::vector<uint8_t>& value) {
|
|
store_->append(joinKey(key), value);
|
|
}
|
|
|
|
std::vector<std::vector<uint8_t>> PrefixStore::multiGet(
|
|
const std::vector<std::string>& keys) {
|
|
std::vector<std::string> prefixed_keys;
|
|
prefixed_keys.reserve(keys.size());
|
|
for (auto& key : keys) {
|
|
prefixed_keys.push_back(joinKey(key));
|
|
}
|
|
return store_->multiGet(prefixed_keys);
|
|
}
|
|
|
|
void PrefixStore::multiSet(
|
|
const std::vector<std::string>& keys,
|
|
const std::vector<std::vector<uint8_t>>& values) {
|
|
std::vector<std::string> prefixed_keys;
|
|
prefixed_keys.reserve(keys.size());
|
|
for (auto& key : keys) {
|
|
prefixed_keys.push_back(joinKey(key));
|
|
}
|
|
store_->multiSet(prefixed_keys, values);
|
|
}
|
|
|
|
// Returns true if this store support append, multiGet and multiSet
|
|
bool PrefixStore::hasExtendedApi() const {
|
|
return store_->hasExtendedApi();
|
|
}
|
|
|
|
c10::intrusive_ptr<Store> PrefixStore::getUnderlyingStore() {
|
|
return store_;
|
|
}
|
|
|
|
c10::intrusive_ptr<Store> PrefixStore::getUnderlyingNonPrefixStore() {
|
|
c10::intrusive_ptr<Store> store = store_;
|
|
|
|
while (store) {
|
|
// Attempt to dynamically cast to PrefixStore
|
|
PrefixStore* asPrefixStore = dynamic_cast<PrefixStore*>(store.get());
|
|
if (asPrefixStore) {
|
|
store = asPrefixStore->getUnderlyingStore();
|
|
} else {
|
|
break; // We've reached a non-PrefixStore
|
|
}
|
|
}
|
|
|
|
TORCH_CHECK(
|
|
store != nullptr, "Underlying Non-PrefixStore shouldn't be null.");
|
|
return store;
|
|
}
|
|
|
|
} // namespace c10d
|