gloo: use shared Stores (#150230)

Summary:
X-link: https://github.com/facebookincubator/gloo/pull/423

This modifies `connectFullMesh` to take in a shared_ptr<IStore> instead of a reference. This is an API breaking change but fairly easy to work around.

To have backwards compatibility in PyTorch during the commit phase we add a new ifdef `GLOO_SHARED_STORE` which can provide backwards compatibility until we update the pinned Gloo version in pytorch OSS repo.

This also adds a new `wait_get` method to `IStore` which will allow us to do a more efficient operation in PyTorch TCPStore. PyTorch's `Store::get` automatically waits so we want to make sure we can avoid waiting twice to reduce network traffic.

This change will land simultaneously in PyTorch and Gloo repos.

Test Plan:
```
buck2 test //gloo/... //caffe2/caffe2/contrib/gloo:
```

Differential Revision: D72084111

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150230
Approved by: https://github.com/fduwjj
This commit is contained in:
Tristan Rice 2025-04-01 23:37:25 +00:00 committed by PyTorch MergeBot
parent 4934a83347
commit 6aea4d90fb
2 changed files with 19 additions and 4 deletions

View File

@ -785,10 +785,25 @@ ProcessGroupGloo::ProcessGroupGloo(
contexts_.reserve(options_->devices.size());
for (const auto i : c10::irange(options_->devices.size())) {
auto context = std::make_shared<::gloo::rendezvous::Context>(rank_, size_);
auto store = ::gloo::rendezvous::PrefixStore(std::to_string(i), *store_);
#ifdef GLOO_SHARED_STORE
auto underlyingStore = store_;
#else
auto& underlyingStore = *store_;
#endif
auto store = std::make_shared<::gloo::rendezvous::PrefixStore>(
std::to_string(i), underlyingStore);
#ifdef GLOO_SHARED_STORE
auto connectStore = store;
#else
auto& connectStore = *store;
#endif
context->setTimeout(options_->timeout);
try {
context->connectFullMesh(store, options_->devices[i]);
context->connectFullMesh(connectStore, options_->devices[i]);
} catch (const std::runtime_error& e) {
auto err = e.what();
// TORCH_CHECK to print the cpp stacktrace.

View File

@ -367,7 +367,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
void enableCollectivesTiming() override;
const std::unique_ptr<::gloo::rendezvous::Store>& _getStore() const {
const std::shared_ptr<::gloo::rendezvous::Store>& _getStore() const {
return store_;
}
@ -393,7 +393,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
}
protected:
std::unique_ptr<::gloo::rendezvous::Store> store_;
std::shared_ptr<::gloo::rendezvous::Store> store_;
const c10::intrusive_ptr<Options> options_;
// Every Gloo context represents a set of connections to its peers.