mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
This adds a new `clone()` method to Store which will return a new Store instance that can be used from a different thread.
This is intended to better support using a store from multiple threads, for example when ProcessGroupNCCL needs a store for error propagation.
Related issue: https://github.com/pytorch/pytorch/issues/150943
Approved by: https://github.com/fduwjj
Test Plan:
contbuild & OSS CI, see 205881ea4a
Test plan from GitHub:
```
pytest test/distributed/test_store.py -k PythonStore
pytest test/distributed/test_store.py -k clone
```
Differential Revision: D72789690
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151045
Approved by: https://github.com/XilunWu, https://github.com/fduwjj
This commit is contained in:
parent
75162aa7de
commit
8b5e717601
|
|
@ -172,6 +172,15 @@ class StoreTestBase:
|
|||
def test_multi_get(self):
|
||||
self._test_multi_get(self._create_store())
|
||||
|
||||
def test_clone(self):
    # A clone must be a dist.Store, and a value written through the store it
    # was cloned from must be readable through the clone.
    original = self._create_store()
    duplicate = original.clone()
    self.assertIsInstance(duplicate, dist.Store)

    # Write via the original handle, read back via the clone. get() yields
    # bytes, hence the b"bar" expectation for the str value "bar".
    original.set("foo", "bar")
    fetched = duplicate.get("foo")
    self.assertEqual(fetched, b"bar")
|
||||
|
||||
# This is the number of keys used in test_set_get. Adding this as a class
|
||||
# property instead of hardcoding in the test since some Store
|
||||
# implementations will have differing number of keys. In the base case,
|
||||
|
|
@ -628,6 +637,9 @@ class MyPythonStore(dist.Store):
|
|||
val = self.store[key] = newValue
|
||||
return val
|
||||
|
||||
# Returns this same instance rather than building a copy, so the "clone"
# shares self.store with the object it was cloned from.
def clone(self) -> "MyPythonStore":
|
||||
return self
|
||||
|
||||
|
||||
class PythonStoreTest(TestCase):
|
||||
def test_set_get(self):
|
||||
|
|
|
|||
|
|
@ -297,6 +297,10 @@ FileStore::FileStore(std::string path, int numWorkers)
|
|||
addHelper(refCountKey_, 1);
|
||||
}
|
||||
|
||||
// Creates a separate FileStore object constructed from this instance's path_
// and numWorkers_, so the caller receives its own handle instead of sharing
// this one. The clone goes through the regular FileStore constructor.
c10::intrusive_ptr<Store> FileStore::clone() {
  c10::intrusive_ptr<Store> duplicate =
      c10::make_intrusive<FileStore>(path_, numWorkers_);
  return duplicate;
}
|
||||
|
||||
// NOLINTNEXTLINE(bugprone-exception-escape)
|
||||
FileStore::~FileStore() {
|
||||
// If the file does not exist - exit.
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ class TORCH_API FileStore : public Store {
|
|||
public:
|
||||
explicit FileStore(std::string path, int numWorkers);
|
||||
|
||||
// Returns a newly constructed FileStore built from this store's path and
// worker count.
c10::intrusive_ptr<Store> clone() override;
|
||||
|
||||
~FileStore() override;
|
||||
|
||||
void set(const std::string& key, const std::vector<uint8_t>& value) override;
|
||||
|
|
|
|||
|
|
@ -9,6 +9,10 @@
|
|||
|
||||
namespace c10d {
|
||||
|
||||
// Hands back another owning pointer to this very HashStore instead of
// constructing a new object, so the "clone" and the original are the same
// instance.
// NOTE(review): unsafe_reclaim_from_nonowning presumably requires that this
// object is already alive and owned by an intrusive_ptr -- confirm that
// HashStore is never stack-allocated by callers of clone().
c10::intrusive_ptr<Store> HashStore::clone() {
  using StorePtr = c10::intrusive_ptr<Store>;
  return StorePtr::unsafe_reclaim_from_nonowning(this);
}
|
||||
|
||||
void HashStore::set(const std::string& key, const std::vector<uint8_t>& data) {
|
||||
std::unique_lock<std::mutex> lock(m_);
|
||||
map_[key] = data;
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ namespace c10d {
|
|||
|
||||
class TORCH_API HashStore : public Store {
|
||||
public:
|
||||
// Returns a new owning pointer to this same HashStore object (no new store is
// constructed).
c10::intrusive_ptr<Store> clone() override;
|
||||
|
||||
~HashStore() override = default;
|
||||
|
||||
void set(const std::string& key, const std::vector<uint8_t>& data) override;
|
||||
|
|
|
|||
|
|
@ -6,6 +6,10 @@ namespace c10d {
|
|||
// Wraps `store` under `prefix`. Both arguments are taken by value and moved
// into prefix_ / store_, so the caller's copies are not referenced afterwards.
PrefixStore::PrefixStore(std::string prefix, c10::intrusive_ptr<Store> store)
|
||||
: prefix_(std::move(prefix)), store_(std::move(store)) {}
|
||||
|
||||
// Clones the wrapped store first, then wraps that clone in a new PrefixStore
// carrying the same prefix_. The result therefore does not share the inner
// store handle with this object unless the inner store's clone() returns
// itself.
c10::intrusive_ptr<Store> PrefixStore::clone() {
  auto innerClone = store_->clone();
  return c10::make_intrusive<PrefixStore>(prefix_, std::move(innerClone));
}
|
||||
|
||||
std::string PrefixStore::joinKey(const std::string& key) {
|
||||
return prefix_ + "/" + key;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ class TORCH_API PrefixStore : public Store {
|
|||
public:
|
||||
explicit PrefixStore(std::string prefix, c10::intrusive_ptr<Store> store);
|
||||
|
||||
// Returns a new PrefixStore with the same prefix that wraps a clone of the
// underlying store.
c10::intrusive_ptr<Store> clone() override;
|
||||
|
||||
using Store::set;
|
||||
void set(const std::string& key, const std::vector<uint8_t>& value) override;
|
||||
|
||||
|
|
|
|||
|
|
@ -32,6 +32,10 @@ class TORCH_API Store : public torch::CustomClassHolder {
|
|||
|
||||
~Store() override = default;
|
||||
|
||||
// Clone a thread-safe copy of this store object that points to the same
|
||||
// underlying store.
|
||||
virtual c10::intrusive_ptr<Store> clone() = 0;
|
||||
|
||||
void set(const std::string& key, const std::string& value);
|
||||
|
||||
virtual void set(
|
||||
|
|
|
|||
|
|
@ -350,6 +350,17 @@ TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts)
|
|||
|
||||
TCPStore::~TCPStore() = default;
|
||||
|
||||
// Creates a new TCPStore that targets the same host and port as this one.
// Only the fields set below are carried over or overridden; every other
// TCPStoreOptions field keeps its default value.
c10::intrusive_ptr<Store> TCPStore::clone() {
  TCPStoreOptions cloneOpts;

  // Settings copied from this instance.
  cloneOpts.port = addr_.port;
  cloneOpts.timeout = timeout_;
  cloneOpts.useLibUV = usingLibUv_;

  // The clone is never the server and does not wait for workers.
  cloneOpts.isServer = false;
  cloneOpts.waitWorkers = false;

  c10::intrusive_ptr<Store> cloned =
      c10::make_intrusive<TCPStore>(addr_.host, cloneOpts);
  return cloned;
}
|
||||
|
||||
void TCPStore::waitForWorkers() {
|
||||
STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__waitForWorkers);
|
||||
if (!numWorkers_.has_value()) {
|
||||
|
|
|
|||
|
|
@ -77,6 +77,8 @@ class TORCH_API TCPStore : public Store {
|
|||
|
||||
~TCPStore() override;
|
||||
|
||||
// Returns a new non-server TCPStore that uses this store's host, port,
// timeout and libuv setting.
c10::intrusive_ptr<Store> clone() override;
|
||||
|
||||
void set(const std::string& key, const std::vector<uint8_t>& value) override;
|
||||
|
||||
std::vector<uint8_t> compareSet(
|
||||
|
|
|
|||
|
|
@ -274,6 +274,10 @@ class PythonStore : public ::c10d::Store {
|
|||
PYBIND11_OVERLOAD_PURE(void, ::c10d::Store, wait, keys, timeout);
|
||||
}
|
||||
|
||||
// Trampoline for Store::clone(): forwards the call to the Python-side `clone`
// override and returns its result as a c10::intrusive_ptr<Store>.
// NOTE(review): this is the "pure" macro variant, so a Python subclass that
// does not define clone() presumably gets a runtime error -- confirm against
// the pybind11 documentation.
c10::intrusive_ptr<Store> clone() override {
|
||||
PYBIND11_OVERLOAD_PURE(c10::intrusive_ptr<Store>, ::c10d::Store, clone);
|
||||
}
|
||||
|
||||
// Note: this function manually calls the Python-side overload
|
||||
// for this function instead of using the PYBIND11_OVERLOAD_XYZ
|
||||
// macros. This is done so that we can call the Python-side
|
||||
|
|
@ -1208,6 +1212,16 @@ and :class:`~torch.distributed.HashStore`).
|
|||
)")
|
||||
// Default constructor.
|
||||
.def(py::init<>())
|
||||
.def(
|
||||
"clone",
|
||||
&::c10d::Store::clone,
|
||||
py::call_guard<py::gil_scoped_release>(),
|
||||
R"(
|
||||
Clones the store and returns a new object that points to the same underlying
|
||||
store. The returned store can be used concurrently with the original object.
|
||||
This is intended to provide a safe way to use a store from multiple threads by
|
||||
cloning one store per thread.
|
||||
)")
|
||||
// Convert from std::string to std::vector<uint8>.
|
||||
.def(
|
||||
"set",
|
||||
|
|
@ -3572,6 +3586,14 @@ such as `dist.all_reduce(tensor, async_op=True)`.
|
|||
if (get("key3") != "15") {
|
||||
TORCH_CHECK(false, "assertion failed");
|
||||
}
|
||||
|
||||
auto cloned = store->clone();
|
||||
store->set("foo", "bar");
|
||||
|
||||
auto ret = cloned->get("foo");
|
||||
TORCH_CHECK(
|
||||
std::string(ret.begin(), ret.end()) == "bar",
|
||||
"checked clone behavior");
|
||||
},
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user