pytorch/torch/csrc/distributed/c10d/PrefixStore.cpp
fduwjj 180425df9b [c10d] Add a recursive method to get the inner most store (#117074)
In c10d PG initialization, we wrap TCPStore in multiple layers of PrefixStore, each of which adds its own prefix to the key.

One example is:
"default_pg/0//cuda//timeout_dump"
When initializing the default PG, no store is passed in, so we first add the prefix "default_pg" to the TCPStore returned from rendezvous:

bdeaaad70c/torch/distributed/distributed_c10d.py (L1240)

We then add the pg_name (aka 0) bdeaaad70c/torch/distributed/distributed_c10d.py (L1376) and the device (aka cuda) bdeaaad70c/torch/distributed/distributed_c10d.py (L1387) to the prefix. Then, when we call store_->set("timeout_dump"), the actual key written into TCPStore is "default_pg/0//cuda//timeout_dump".

For sub-PGs, things get even more interesting: we put the store wrapped with the default PG name into a cache:
bdeaaad70c/torch/distributed/distributed_c10d.py (L1517)

When each sub-PG is created, its PG name is appended right after the cached store's prefix. Example keys are:
'default_pg/0//10//cuda//timeout_dump', 'default_pg/0//12//cuda//timeout_dump', 'default_pg/0//38//cuda//timeout_dump', 'default_pg/0//39//cuda//timeout_dump' (10, 12, 38, and 39 are the PG names of the sub-PGs created).
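The layering in the two paragraphs above can be reproduced with a minimal sketch. It assumes, as the doubled slashes in the keys suggest, that the PG-name and device layers use prefixes with a trailing slash such as "0/" and "cuda/". `composeKey` is a hypothetical helper, not part of c10d; it only applies the `prefix + "/" + key` rule from `PrefixStore::joinKey` once per layer, outermost layer first.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper (not part of c10d): returns the raw key that reaches
// the innermost store when the outermost of several nested PrefixStores is
// asked to set `key`. `prefixes` lists the layers from innermost (wrapping
// TCPStore directly) to outermost.
std::string composeKey(const std::vector<std::string>& prefixes,
                       std::string key) {
  // The outermost layer applies its prefix first, so walk the layers from the
  // outside in; each step is PrefixStore::joinKey's rule: prefix + "/" + key.
  for (auto it = prefixes.rbegin(); it != prefixes.rend(); ++it) {
    key = *it + "/" + key;
  }
  return key;
}
```

For example, `composeKey({"default_pg", "0/", "cuda/"}, "timeout_dump")` yields "default_pg/0//cuda//timeout_dump", and inserting a sub-PG layer, `composeKey({"default_pg", "0/", "10/", "cuda/"}, "timeout_dump")`, yields 'default_pg/0//10//cuda//timeout_dump'. Because the sub-PG name sits in the middle of the key, no two PGs share a raw key for the same logical name.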

The numbers in the names are so high because every sub-PG creation requires all ranks to call the API together, and the global variable used for PG names is incremented monotonically:

bdeaaad70c/torch/distributed/distributed_c10d.py (L3666)
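A simplified sketch of why the names have gaps, assuming a single counter that every rank advances once per group-creation call, in the same order, whether or not it joins the group. `GroupNameCounter` and `namesSeenByRank` are hypothetical names used only for illustration; the real counter lives in the Python code linked above.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical model of the global PG-name counter: every rank calls
// nextName() once per group-creation call, in the same order, so a given
// group receives the same name on every rank.
class GroupNameCounter {
 public:
  std::string nextName() { return std::to_string(count_++); }

 private:
  int count_ = 0;
};

// Simulates `numCalls` sub-PG creations after the default PG and returns the
// names of the groups this rank is a member of. `memberOf` holds 1-based
// call indices.
std::vector<std::string> namesSeenByRank(int numCalls,
                                         const std::set<int>& memberOf) {
  GroupNameCounter counter;
  counter.nextName();  // The default PG takes "0".
  std::vector<std::string> seen;
  for (int call = 1; call <= numCalls; ++call) {
    // The counter advances on every call, member or not.
    std::string name = counter.nextName();
    if (memberOf.count(call) > 0) {
      seen.push_back(name);
    }
  }
  return seen;
}
```

Under this model, a rank that belongs only to the groups created by the 10th, 12th, 38th, and 39th calls sees the names "10", "12", "38", and "39", matching the keys listed earlier.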

The same thing happens when PG names are generated by hashing.

This has a potential issue: each sub-PG has its own ProcessGroupNCCL instance, and if we want to set something global to notify all sub-PGs (and all ranks), the added prefixes cause bugs. For example, if sub-PG 1 sets a value in TCPStore with the key ('default_pg/0//1//cuda//timeout_dump') while the default PG instances check TCPStore with the key ('default_pg/0//cuda//timeout_dump'), the default PG instances will never receive the signal. So in this PR, we add a new API to PrefixStore that returns the innermost non-PrefixStore, which can be used for both set and check. The next PR will make changes in the NCCL watchdog.
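The following sketch reproduces both the bug and the fix under simplifying assumptions: `sketch::Store`, `sketch::InMemoryStore`, `sketch::PrefixStore`, and `sketch::Demo` are hypothetical stand-ins for the c10d classes (only `set` and `check` on a single key, `std::shared_ptr` instead of `c10::intrusive_ptr`, and `assert` instead of `TORCH_CHECK`), and the "0/", "1/", and "cuda/" prefixes assume the trailing-slash layout implied by the keys above. Its `getUnderlyingNonPrefixStore` follows the same unwrapping idea as the method of that name in this file.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>

namespace sketch {

// Hypothetical, reduced stand-in for c10d::Store.
struct Store {
  virtual ~Store() = default;
  virtual void set(const std::string& key, const std::string& value) = 0;
  virtual bool check(const std::string& key) = 0;
};

// Plays the role of the single TCPStore shared by every PG.
struct InMemoryStore : Store {
  std::map<std::string, std::string> data;
  void set(const std::string& key, const std::string& value) override {
    data[key] = value;
  }
  bool check(const std::string& key) override {
    return data.count(key) > 0;
  }
};

// Prepends "<prefix>/" to every key, like PrefixStore::joinKey.
class PrefixStore : public Store {
 public:
  PrefixStore(std::string prefix, std::shared_ptr<Store> store)
      : prefix_(std::move(prefix)), store_(std::move(store)) {}

  void set(const std::string& key, const std::string& value) override {
    store_->set(prefix_ + "/" + key, value);
  }
  bool check(const std::string& key) override {
    return store_->check(prefix_ + "/" + key);
  }

  std::shared_ptr<Store> getUnderlyingStore() const {
    return store_;
  }

  // Peels off PrefixStore layers until it reaches a store that is not one.
  std::shared_ptr<Store> getUnderlyingNonPrefixStore() const {
    std::shared_ptr<Store> store = store_;
    while (auto* asPrefixStore = dynamic_cast<PrefixStore*>(store.get())) {
      store = asPrefixStore->getUnderlyingStore();
    }
    assert(store != nullptr);  // Stands in for TORCH_CHECK.
    return store;
  }

 private:
  std::string prefix_;
  std::shared_ptr<Store> store_;
};

// Default PG and sub-PG "1" built over one shared store:
// "default_pg" -> "0/" (cached) -> ["1/" ->] "cuda/".
struct Demo {
  std::shared_ptr<InMemoryStore> tcp = std::make_shared<InMemoryStore>();
  // The cached default store that sub-PGs are built on.
  std::shared_ptr<Store> cached = std::make_shared<PrefixStore>(
      "0/", std::make_shared<PrefixStore>("default_pg", tcp));
  std::shared_ptr<PrefixStore> defaultPg =
      std::make_shared<PrefixStore>("cuda/", cached);
  std::shared_ptr<PrefixStore> subPg = std::make_shared<PrefixStore>(
      "cuda/", std::make_shared<PrefixStore>("1/", cached));
};

}  // namespace sketch
```

When sub-PG 1 calls `set("timeout_dump", ...)` through its own prefixed store, the raw key is 'default_pg/0//1//cuda//timeout_dump', so a default-PG `check("timeout_dump")` looks for 'default_pg/0//cuda//timeout_dump' and fails. When both sides first call `getUnderlyingNonPrefixStore()`, they read and write the same raw key "timeout_dump" on the shared store.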

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117074
Approved by: https://github.com/wconstab, https://github.com/H-Huang
2024-01-10 20:22:55 +00:00


#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <utility>

namespace c10d {

PrefixStore::PrefixStore(std::string prefix, c10::intrusive_ptr<Store> store)
    : prefix_(std::move(prefix)), store_(std::move(store)) {}

// Namespaces a key as "<prefix>/<key>" before it reaches the wrapped store.
std::string PrefixStore::joinKey(const std::string& key) {
  return prefix_ + "/" + key;
}

std::vector<std::string> PrefixStore::joinKeys(
    const std::vector<std::string>& keys) {
  std::vector<std::string> joinedKeys;
  joinedKeys.reserve(keys.size());
  for (const auto& key : keys) {
    joinedKeys.emplace_back(joinKey(key));
  }
  return joinedKeys;
}

void PrefixStore::set(
    const std::string& key,
    const std::vector<uint8_t>& value) {
  store_->set(joinKey(key), value);
}

std::vector<uint8_t> PrefixStore::compareSet(
    const std::string& key,
    const std::vector<uint8_t>& expectedValue,
    const std::vector<uint8_t>& desiredValue) {
  return store_->compareSet(joinKey(key), expectedValue, desiredValue);
}

std::vector<uint8_t> PrefixStore::get(const std::string& key) {
  return store_->get(joinKey(key));
}

int64_t PrefixStore::add(const std::string& key, int64_t value) {
  return store_->add(joinKey(key), value);
}

bool PrefixStore::deleteKey(const std::string& key) {
  return store_->deleteKey(joinKey(key));
}

// Forwards to the wrapped store, so the count covers every key in that
// store, not only the keys under this prefix.
int64_t PrefixStore::getNumKeys() {
  return store_->getNumKeys();
}

bool PrefixStore::check(const std::vector<std::string>& keys) {
  auto joinedKeys = joinKeys(keys);
  return store_->check(joinedKeys);
}

void PrefixStore::wait(const std::vector<std::string>& keys) {
  auto joinedKeys = joinKeys(keys);
  store_->wait(joinedKeys);
}

void PrefixStore::wait(
    const std::vector<std::string>& keys,
    const std::chrono::milliseconds& timeout) {
  auto joinedKeys = joinKeys(keys);
  store_->wait(joinedKeys, timeout);
}

// Timeouts are not per-prefix: these read and modify the wrapped store's
// timeout.
const std::chrono::milliseconds& PrefixStore::getTimeout() const noexcept {
  return store_->getTimeout();
}

void PrefixStore::setTimeout(const std::chrono::milliseconds& timeout) {
  store_->setTimeout(timeout);
}

void PrefixStore::append(
    const std::string& key,
    const std::vector<uint8_t>& value) {
  store_->append(joinKey(key), value);
}

std::vector<std::vector<uint8_t>> PrefixStore::multiGet(
    const std::vector<std::string>& keys) {
  return store_->multiGet(joinKeys(keys));
}

void PrefixStore::multiSet(
    const std::vector<std::string>& keys,
    const std::vector<std::vector<uint8_t>>& values) {
  store_->multiSet(joinKeys(keys), values);
}

// Returns true if this store supports append, multiGet and multiSet.
bool PrefixStore::hasExtendedApi() const {
  return store_->hasExtendedApi();
}

// Returns the store wrapped one level down (which may itself be a
// PrefixStore).
c10::intrusive_ptr<Store> PrefixStore::getUnderlyingStore() {
  return store_;
}

// Unwraps nested PrefixStores and returns the innermost store that is not a
// PrefixStore (e.g. the TCPStore returned from rendezvous).
c10::intrusive_ptr<Store> PrefixStore::getUnderlyingNonPrefixStore() {
  c10::intrusive_ptr<Store> store = store_;

  while (store) {
    // Attempt to dynamically cast to PrefixStore
    PrefixStore* asPrefixStore = dynamic_cast<PrefixStore*>(store.get());
    if (asPrefixStore) {
      store = asPrefixStore->getUnderlyingStore();
    } else {
      break; // We've reached a non-PrefixStore
    }
  }

  TORCH_CHECK(
      store != nullptr, "Underlying Non-PrefixStore shouldn't be null.");
  return store;
}

} // namespace c10d