c10d/Store: add clone feature (#150966) (#150966) (#151045)

Summary:
This adds a new `clone()` method to Store which will return a new Store instance that can be used from a different thread.

This is intended to better support using a store from multiple threads, for example when ProcessGroupNCCL needs a store for error propagation.

Related issue: https://github.com/pytorch/pytorch/issues/150943

Approved by: https://github.com/fduwjj

Test Plan:
contbuild & OSS CI, see 205881ea4a

Test plan from GitHub:
```
pytest test/distributed/test_store.py -k PythonStore
pytest test/distributed/test_store.py -k clone
```

Differential Revision: D72789690

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151045
Approved by: https://github.com/XilunWu, https://github.com/fduwjj
This commit is contained in:
Tristan Rice 2025-04-11 04:00:23 +00:00 committed by PyTorch MergeBot
parent 75162aa7de
commit 8b5e717601
11 changed files with 69 additions and 0 deletions

View File

@ -172,6 +172,15 @@ class StoreTestBase:
def test_multi_get(self):
self._test_multi_get(self._create_store())
def test_clone(self):
a = self._create_store()
b = a.clone()
self.assertIsInstance(b, dist.Store)
a.set("foo", "bar")
self.assertEqual(b.get("foo"), b"bar")
# This is the number of keys used in test_set_get. Adding this as a class
# property instead of hardcoding in the test since some Store
# implementations will have differing number of keys. In the base case,
@ -628,6 +637,9 @@ class MyPythonStore(dist.Store):
val = self.store[key] = newValue
return val
def clone(self) -> "MyPythonStore":
    """Return this same instance as the clone; no new store object is created."""
    return self
class PythonStoreTest(TestCase):
def test_set_get(self):

View File

@ -297,6 +297,10 @@ FileStore::FileStore(std::string path, int numWorkers)
addHelper(refCountKey_, 1);
}
// Returns a separate FileStore object constructed from this store's path and
// worker count, handed back through the base Store pointer type.
c10::intrusive_ptr<Store> FileStore::clone() {
  c10::intrusive_ptr<Store> copy =
      c10::make_intrusive<FileStore>(path_, numWorkers_);
  return copy;
}
// NOLINTNEXTLINE(bugprone-exception-escape)
FileStore::~FileStore() {
// If the file does not exist - exit.

View File

@ -13,6 +13,8 @@ class TORCH_API FileStore : public Store {
public:
explicit FileStore(std::string path, int numWorkers);
c10::intrusive_ptr<Store> clone() override;
~FileStore() override;
void set(const std::string& key, const std::vector<uint8_t>& value) override;

View File

@ -9,6 +9,10 @@
namespace c10d {
// HashStore does not create a second object when cloned: the returned pointer
// is built from `this`, so it refers to this very instance.
// NOTE(review): unsafe_reclaim_from_nonowning presumably requires that the
// object is already owned by a live intrusive_ptr -- confirm against c10.
c10::intrusive_ptr<Store> HashStore::clone() {
  using StorePtr = c10::intrusive_ptr<Store>;
  StorePtr self = StorePtr::unsafe_reclaim_from_nonowning(this);
  return self;
}
void HashStore::set(const std::string& key, const std::vector<uint8_t>& data) {
std::unique_lock<std::mutex> lock(m_);
map_[key] = data;

View File

@ -10,6 +10,8 @@ namespace c10d {
class TORCH_API HashStore : public Store {
public:
c10::intrusive_ptr<Store> clone() override;
~HashStore() override = default;
void set(const std::string& key, const std::vector<uint8_t>& data) override;

View File

@ -6,6 +6,10 @@ namespace c10d {
// Builds a PrefixStore from a key prefix and the store it wraps. Both
// arguments are taken by value and moved into the corresponding members.
PrefixStore::PrefixStore(std::string prefix, c10::intrusive_ptr<Store> store)
: prefix_(std::move(prefix)), store_(std::move(store)) {}
// Clones the wrapped store first, then wraps that clone in a new PrefixStore
// carrying the same prefix.
c10::intrusive_ptr<Store> PrefixStore::clone() {
  auto clonedInner = store_->clone();
  return c10::make_intrusive<PrefixStore>(prefix_, std::move(clonedInner));
}
// Produces the namespaced key: this store's prefix, a "/" separator, then the
// caller's key.
std::string PrefixStore::joinKey(const std::string& key) {
  std::string joined(prefix_);
  joined.append("/").append(key);
  return joined;
}

View File

@ -8,6 +8,8 @@ class TORCH_API PrefixStore : public Store {
public:
explicit PrefixStore(std::string prefix, c10::intrusive_ptr<Store> store);
c10::intrusive_ptr<Store> clone() override;
using Store::set;
void set(const std::string& key, const std::vector<uint8_t>& value) override;

View File

@ -32,6 +32,10 @@ class TORCH_API Store : public torch::CustomClassHolder {
~Store() override = default;
// Clone a thread-safe copy of this store object that points to the same
// underlying store.
virtual c10::intrusive_ptr<Store> clone() = 0;
void set(const std::string& key, const std::string& value);
virtual void set(

View File

@ -350,6 +350,17 @@ TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts)
// Destructor is defaulted here, out of line in the source file.
TCPStore::~TCPStore() = default;
// Builds a new TCPStore for the same host and port as this one. The options
// mark the new instance as a non-server that does not wait for workers, and
// carry over this store's timeout and libuv setting.
c10::intrusive_ptr<Store> TCPStore::clone() {
  TCPStoreOptions cloneOpts;

  // The clone is never a server and does not wait for workers.
  cloneOpts.isServer = false;
  cloneOpts.waitWorkers = false;

  // Everything else mirrors this instance.
  cloneOpts.port = addr_.port;
  cloneOpts.timeout = timeout_;
  cloneOpts.useLibUV = usingLibUv_;

  c10::intrusive_ptr<Store> cloned =
      c10::make_intrusive<TCPStore>(addr_.host, cloneOpts);
  return cloned;
}
void TCPStore::waitForWorkers() {
STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__waitForWorkers);
if (!numWorkers_.has_value()) {

View File

@ -77,6 +77,8 @@ class TORCH_API TCPStore : public Store {
~TCPStore() override;
c10::intrusive_ptr<Store> clone() override;
void set(const std::string& key, const std::vector<uint8_t>& value) override;
std::vector<uint8_t> compareSet(

View File

@ -274,6 +274,10 @@ class PythonStore : public ::c10d::Store {
PYBIND11_OVERLOAD_PURE(void, ::c10d::Store, wait, keys, timeout);
}
// Trampoline for Store::clone(): hands the call to the Python-side `clone`
// override via pybind11's PYBIND11_OVERLOAD_PURE macro and returns its result
// as an intrusive_ptr<Store>.
// NOTE(review): as the "pure" variant, this presumably raises when the Python
// subclass does not define `clone` -- confirm against the pybind11 docs.
c10::intrusive_ptr<Store> clone() override {
PYBIND11_OVERLOAD_PURE(c10::intrusive_ptr<Store>, ::c10d::Store, clone);
}
// Note: this function manually calls the Python-side overload
// for this function instead of using the PYBIND11_OVERLOAD_XYZ
// macros. This is done so that we can call the Python-side
@ -1208,6 +1212,16 @@ and :class:`~torch.distributed.HashStore`).
)")
// Default constructor.
.def(py::init<>())
.def(
"clone",
&::c10d::Store::clone,
py::call_guard<py::gil_scoped_release>(),
R"(
Clones the store and returns a new object that points to the same underlying
store. The returned store can be used concurrently with the original object.
This is intended to provide a safe way to use a store from multiple threads by
cloning one store per thread.
)")
// Convert from std::string to std::vector<uint8>.
.def(
"set",
@ -3572,6 +3586,14 @@ such as `dist.all_reduce(tensor, async_op=True)`.
if (get("key3") != "15") {
TORCH_CHECK(false, "assertion failed");
}
auto cloned = store->clone();
store->set("foo", "bar");
auto ret = cloned->get("foo");
TORCH_CHECK(
std::string(ret.begin(), ret.end()) == "bar",
"checked clone behavior");
},
py::call_guard<py::gil_scoped_release>());