mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
gloo: use shared Stores (#150230)
Summary: X-link: https://github.com/facebookincubator/gloo/pull/423 This modifies `connectFullMesh` to take in a shared_ptr<IStore> instead of a reference. This is an API breaking change but fairly easy to work around. To have backwards compatibility in PyTorch during the commit phase we add a new ifdef `GLOO_SHARED_STORE` which can provide backwards compatibility until we update the pinned Gloo version in pytorch OSS repo. This also adds a new `wait_get` method to `IStore` which will allow us to do a more efficient operation in PyTorch TCPStore. PyTorch's `Store::get` automatically waits so we want to make sure we can avoid waiting twice to reduce network traffic. This change will land simultaneously in PyTorch and Gloo repos. Test Plan: ``` buck2 test //gloo/... //caffe2/caffe2/contrib/gloo: ``` Differential Revision: D72084111 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150230 Approved by: https://github.com/fduwjj
This commit is contained in:
parent
4934a83347
commit
6aea4d90fb
|
|
@ -785,10 +785,25 @@ ProcessGroupGloo::ProcessGroupGloo(
|
||||||
contexts_.reserve(options_->devices.size());
|
contexts_.reserve(options_->devices.size());
|
||||||
for (const auto i : c10::irange(options_->devices.size())) {
|
for (const auto i : c10::irange(options_->devices.size())) {
|
||||||
auto context = std::make_shared<::gloo::rendezvous::Context>(rank_, size_);
|
auto context = std::make_shared<::gloo::rendezvous::Context>(rank_, size_);
|
||||||
auto store = ::gloo::rendezvous::PrefixStore(std::to_string(i), *store_);
|
|
||||||
|
#ifdef GLOO_SHARED_STORE
|
||||||
|
auto underlyingStore = store_;
|
||||||
|
#else
|
||||||
|
auto& underlyingStore = *store_;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
auto store = std::make_shared<::gloo::rendezvous::PrefixStore>(
|
||||||
|
std::to_string(i), underlyingStore);
|
||||||
|
|
||||||
|
#ifdef GLOO_SHARED_STORE
|
||||||
|
auto connectStore = store;
|
||||||
|
#else
|
||||||
|
auto& connectStore = *store;
|
||||||
|
#endif
|
||||||
|
|
||||||
context->setTimeout(options_->timeout);
|
context->setTimeout(options_->timeout);
|
||||||
try {
|
try {
|
||||||
context->connectFullMesh(store, options_->devices[i]);
|
context->connectFullMesh(connectStore, options_->devices[i]);
|
||||||
} catch (const std::runtime_error& e) {
|
} catch (const std::runtime_error& e) {
|
||||||
auto err = e.what();
|
auto err = e.what();
|
||||||
// TORCH_CHECK to print the cpp stacktrace.
|
// TORCH_CHECK to print the cpp stacktrace.
|
||||||
|
|
|
||||||
|
|
@ -367,7 +367,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
|
||||||
|
|
||||||
void enableCollectivesTiming() override;
|
void enableCollectivesTiming() override;
|
||||||
|
|
||||||
const std::unique_ptr<::gloo::rendezvous::Store>& _getStore() const {
|
const std::shared_ptr<::gloo::rendezvous::Store>& _getStore() const {
|
||||||
return store_;
|
return store_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -393,7 +393,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::unique_ptr<::gloo::rendezvous::Store> store_;
|
std::shared_ptr<::gloo::rendezvous::Store> store_;
|
||||||
const c10::intrusive_ptr<Options> options_;
|
const c10::intrusive_ptr<Options> options_;
|
||||||
|
|
||||||
// Every Gloo context represents a set of connections to its peers.
|
// Every Gloo context represents a set of connections to its peers.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user