mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
c10d/ProcessGroup: cleanup abort and shutdown (#148798)
This adds `abort` and `shutdown` to `Backend` and `ProcessGroup` objects. This simplifies the logic in `distributed_c10d.py` by providing a default no-op implementation for all PGs. This will be useful for torchft and for upcoming versions of NCCL, which will handle abort correctly. Currently, `torchft` has to call the internal `_abort` method on the PGNCCL object directly, but with this change we can simply call `.abort()` and have it work for any PG implementation. Test plan: ``` pytest distributed/test_backends.py distributed/test_c10d_common.py distributed/test_c10d_pypg.py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/148798 Approved by: https://github.com/kwen2501
This commit is contained in:
parent
9841f0ddcf
commit
7ffadff286
|
|
@ -1559,6 +1559,11 @@ class DummyWork(dist._Work):
|
|||
|
||||
|
||||
class DummyProcessGroup(dist.ProcessGroup):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._aborted = False
|
||||
self._shutdown = False
|
||||
|
||||
def getBackendName(self):
|
||||
return "Dummy"
|
||||
|
||||
|
|
@ -1622,6 +1627,12 @@ class DummyProcessGroup(dist.ProcessGroup):
|
|||
|
||||
return DummyWork()
|
||||
|
||||
def abort(self) -> None:
|
||||
self._aborted = True
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self._shutdown = True
|
||||
|
||||
|
||||
class PythonProcessGroupExtensionTest(MultiProcessTestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -1794,6 +1805,36 @@ class PythonProcessGroupExtensionTest(MultiProcessTestCase):
|
|||
# intentionally not calling into `destroy_process_group` as not all
|
||||
# user applications would explicitly do that.
|
||||
|
||||
def test_shutdown(self) -> None:
|
||||
dist.Backend.register_backend(
|
||||
"dummy", PythonProcessGroupExtensionTest.create_dummy
|
||||
)
|
||||
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "6789"
|
||||
dist.init_process_group("dummy", rank=self.rank, world_size=self.world_size)
|
||||
|
||||
pg = c10d._get_default_group()
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
self.assertTrue(pg._shutdown)
|
||||
|
||||
def test_abort(self) -> None:
|
||||
dist.Backend.register_backend(
|
||||
"dummy", PythonProcessGroupExtensionTest.create_dummy
|
||||
)
|
||||
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "6789"
|
||||
dist.init_process_group("dummy", rank=self.rank, world_size=self.world_size)
|
||||
|
||||
pg = c10d._get_default_group()
|
||||
|
||||
c10d._abort_process_group()
|
||||
|
||||
self.assertTrue(pg._aborted)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(CommonDistributedDataParallelTest)
|
||||
|
||||
|
|
|
|||
|
|
@ -191,6 +191,12 @@ class TestPyProcessGroup(TestCase):
|
|||
pg._set_group_desc("desc")
|
||||
self.assertEqual(pg.group_desc, "py:desc")
|
||||
|
||||
def test_abort_shutdown(self) -> None:
|
||||
# verify these are no-ops
|
||||
pg = DummyAttrProcessGroup(0, 1)
|
||||
pg.abort()
|
||||
pg.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -299,6 +299,8 @@ class Backend:
|
|||
def options(self) -> Options: ...
|
||||
def rank(self) -> int: ...
|
||||
def size(self) -> int: ...
|
||||
def abort(self) -> None: ...
|
||||
def shutdown(self) -> None: ...
|
||||
def eager_connect_single_device(self, device: torch.device | None) -> None: ...
|
||||
def _set_sequence_number_for_group(self) -> None: ...
|
||||
def _set_default_timeout(self, timeout: timedelta) -> None: ...
|
||||
|
|
@ -324,6 +326,8 @@ class ProcessGroup:
|
|||
) -> None: ...
|
||||
def rank(self) -> int: ...
|
||||
def size(self) -> int: ...
|
||||
def abort(self) -> None: ...
|
||||
def shutdown(self) -> None: ...
|
||||
@overload
|
||||
def broadcast(
|
||||
self,
|
||||
|
|
@ -600,7 +604,6 @@ class ProcessGroupNCCL(Backend):
|
|||
def _group_start(self) -> None: ...
|
||||
def _group_end(self) -> None: ...
|
||||
def _set_default_timeout(self, timeout) -> None: ...
|
||||
def _shutdown(self) -> None: ...
|
||||
def perform_nocolor_split(self, device: torch.device) -> None: ...
|
||||
def register_mem_pool(self, pool: torch.cuda.MemPool) -> None: ...
|
||||
def deregister_mem_pool(self, pool: torch.cuda.MemPool) -> None: ...
|
||||
|
|
|
|||
|
|
@ -436,6 +436,14 @@ class TORCH_API Backend : public torch::CustomClassHolder {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Aborts all pending operations and connections in the backend if the backend
|
||||
// supports it.
|
||||
virtual void abort() {}
|
||||
|
||||
// Shutdown the backend if the backend supports it. This should be used for
|
||||
// normal shutdown.
|
||||
virtual void shutdown() {}
|
||||
|
||||
protected:
|
||||
// Implementations of this interface need to call this to setup
|
||||
// appropriate logging etc.
|
||||
|
|
|
|||
|
|
@ -874,6 +874,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
|||
getDefaultBackend()->waitForPendingWorks();
|
||||
}
|
||||
|
||||
virtual void shutdown() {
|
||||
for (auto& backend : backendTypeToBackend_) {
|
||||
backend.second->shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
virtual void abort() {
|
||||
for (auto& backend : backendTypeToBackend_) {
|
||||
backend.second->abort();
|
||||
}
|
||||
}
|
||||
|
||||
bool hasHooks() const {
|
||||
// `getDefaultBackend` will throw today if the backend is set to `undefined`
|
||||
// (in case of `init_process_group(nothing)`)
|
||||
|
|
|
|||
|
|
@ -144,7 +144,7 @@ class TORCH_API ProcessGroupMPI : public Backend {
|
|||
~ProcessGroupMPI() override;
|
||||
|
||||
// Abort the MPI program, needs to be called when exception is detected
|
||||
void abort();
|
||||
void abort() override;
|
||||
|
||||
const std::string getBackendName() const override {
|
||||
return std::string(MPI_BACKEND_NAME);
|
||||
|
|
|
|||
|
|
@ -761,11 +761,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
|||
c10::intrusive_ptr<intra_node_comm::IntraNodeComm> initIntraNodeComm();
|
||||
|
||||
// Destroy (shutdown) this backend -- normal exit.
|
||||
void shutdown();
|
||||
void shutdown() override;
|
||||
|
||||
// Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
|
||||
// instead of relying on ProcessGroupNCCL destructor.
|
||||
void abort();
|
||||
void abort() override;
|
||||
|
||||
void eagerConnectSingleDevice(at::Device device) override;
|
||||
|
||||
|
|
|
|||
|
|
@ -111,6 +111,14 @@ class PyProcessGroup : public ProcessGroup {
|
|||
);
|
||||
}
|
||||
|
||||
void abort() override {
|
||||
PYBIND11_OVERRIDE(
|
||||
void, /* Return type */
|
||||
ProcessGroup, /* Parent class */
|
||||
abort, /* Name of function in C++ */
|
||||
);
|
||||
}
|
||||
|
||||
const std::string& getGroupName() const override {
|
||||
PYBIND11_OVERRIDE(
|
||||
const std::string&, /* Return type */
|
||||
|
|
|
|||
|
|
@ -1961,6 +1961,16 @@ communication mechanism.
|
|||
.def("rank", &::c10d::ProcessGroup::getRank, R"(Get the rank of this process group.)")
|
||||
.def("size", &::c10d::ProcessGroup::getSize, R"(Get the size of this process group.)")
|
||||
.def("name", &::c10d::ProcessGroup::getBackendName, R"(Get the name of this process group.)")
|
||||
.def(
|
||||
"abort",
|
||||
&::c10d::ProcessGroup::abort,
|
||||
py::call_guard<py::gil_scoped_release>(),
|
||||
"abort all operations and connections if supported by the backend")
|
||||
.def(
|
||||
"shutdown",
|
||||
&::c10d::ProcessGroup::shutdown,
|
||||
py::call_guard<py::gil_scoped_release>(),
|
||||
"shutdown the process group")
|
||||
.def("_id", &::c10d::ProcessGroup::getID)
|
||||
.def(
|
||||
"_backend_id",
|
||||
|
|
@ -2478,6 +2488,16 @@ Arguments:
|
|||
.def("rank", &::c10d::Backend::getRank)
|
||||
.def("size", &::c10d::Backend::getSize)
|
||||
.def("name", &::c10d::Backend::getBackendName)
|
||||
.def(
|
||||
"abort",
|
||||
&::c10d::Backend::abort,
|
||||
py::call_guard<py::gil_scoped_release>(),
|
||||
"abort all operations and connections if supported by the backend")
|
||||
.def(
|
||||
"shutdown",
|
||||
&::c10d::Backend::shutdown,
|
||||
py::call_guard<py::gil_scoped_release>(),
|
||||
"shutdown the backend")
|
||||
.def_property_readonly(
|
||||
"supports_splitting",
|
||||
&::c10d::Backend::supportsSplitting,
|
||||
|
|
@ -2972,12 +2992,6 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
|
|||
py::arg("size"),
|
||||
py::arg("timeout") = ::c10d::kProcessGroupNCCLDefaultTimeout,
|
||||
R"(Create a new ProcessGroupNCCL instance.)")
|
||||
.def(
|
||||
"_shutdown",
|
||||
[](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) {
|
||||
return self->shutdown();
|
||||
},
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("_group_start", &::c10d::ProcessGroupNCCL::groupStart)
|
||||
.def("_group_end", &::c10d::ProcessGroupNCCL::groupEnd)
|
||||
.def(
|
||||
|
|
@ -3025,11 +3039,6 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
|
|||
.def(
|
||||
"deregister_mem_pool",
|
||||
&::c10d::ProcessGroupNCCL::deregisterMemPool)
|
||||
.def(
|
||||
"abort",
|
||||
&::c10d::ProcessGroupNCCL::abort,
|
||||
py::call_guard<py::gil_scoped_release>(),
|
||||
R"(Abort the process group.)")
|
||||
.def(
|
||||
"_is_initialized",
|
||||
&::c10d::ProcessGroupNCCL::isInitialized,
|
||||
|
|
|
|||
|
|
@ -1798,36 +1798,6 @@ def _get_split_source(pg):
|
|||
return split_from
|
||||
|
||||
|
||||
def _shutdown_backend(pg):
|
||||
"""
|
||||
Try to shut down the backend of a process group.
|
||||
Currently, only ProcessGroupNCCL backend is supported.
|
||||
No op for other backends.
|
||||
"""
|
||||
backend = None
|
||||
try:
|
||||
backend = pg._get_backend(torch.device("cuda"))
|
||||
except RuntimeError:
|
||||
pass
|
||||
if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
|
||||
# explicitly call shutdown to ensure that NCCL resources are released
|
||||
backend._shutdown()
|
||||
|
||||
|
||||
def _abort_backend(pg: ProcessGroup):
|
||||
"""
|
||||
Abort the backend of a process group.
|
||||
Currently, only ProcessGroupNCCL backend is supported.
|
||||
No op for other backends.
|
||||
"""
|
||||
try:
|
||||
backend = pg._get_backend(torch.device("cuda"))
|
||||
except RuntimeError:
|
||||
backend = None
|
||||
if isinstance(backend, ProcessGroupNCCL):
|
||||
backend.abort()
|
||||
|
||||
|
||||
def _new_process_group_helper(
|
||||
group_size,
|
||||
group_rank,
|
||||
|
|
@ -2162,7 +2132,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
|
|||
for pg_to_shutdown in sorted(
|
||||
_world.pg_names, key=lambda x: _world.pg_names[x], reverse=True
|
||||
):
|
||||
_shutdown_backend(pg_to_shutdown)
|
||||
pg_to_shutdown.shutdown()
|
||||
|
||||
_update_default_pg(None)
|
||||
_world.pg_map.clear()
|
||||
|
|
@ -2184,7 +2154,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
|
|||
# process group is in good state, we aren't dealing with failures.
|
||||
_world.group_count = 0
|
||||
else:
|
||||
_shutdown_backend(pg)
|
||||
pg.shutdown()
|
||||
del _world.pg_map[pg]
|
||||
del _world.pg_names[pg]
|
||||
del _world.pg_group_ranks[pg]
|
||||
|
|
@ -2240,24 +2210,19 @@ def _abort_process_group(group: Optional[ProcessGroup] = None):
|
|||
except RuntimeError:
|
||||
backend = None
|
||||
|
||||
if not isinstance(backend, ProcessGroupNCCL):
|
||||
logger.warning(
|
||||
"`abort_process_group` currently only has implementation for ProcessGroupNCCL; "
|
||||
"however, no NCCL backend is found. This call will be a no-op."
|
||||
)
|
||||
return
|
||||
|
||||
if group == GroupMember.WORLD:
|
||||
if group is None or group == GroupMember.WORLD:
|
||||
# Abort all backends within a ncclGroupStart|End semantic.
|
||||
# This ensures that different NCCL communicators' abort calls won't
|
||||
# deadlock each other.
|
||||
# For details, please see: https://github.com/pytorch/pytorch/issues/119797
|
||||
backend._group_start()
|
||||
if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
|
||||
backend._group_start()
|
||||
for pg_to_abort in sorted(
|
||||
_world.pg_names, key=lambda x: _world.pg_names[x], reverse=True
|
||||
):
|
||||
_abort_backend(pg_to_abort)
|
||||
backend._group_end()
|
||||
pg_to_abort.abort()
|
||||
if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
|
||||
backend._group_end()
|
||||
|
||||
_update_default_pg(None)
|
||||
_world.pg_map.clear()
|
||||
|
|
@ -2279,7 +2244,7 @@ def _abort_process_group(group: Optional[ProcessGroup] = None):
|
|||
# process group is in good state, we aren't dealing with failures.
|
||||
_world.group_count = 0
|
||||
else:
|
||||
_abort_backend(pg)
|
||||
pg.abort()
|
||||
del _world.pg_map[pg]
|
||||
del _world.pg_names[pg]
|
||||
del _world.pg_group_ranks[pg]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user