c10d/ProcessGroup: cleanup abort and shutdown (#148798)

This adds `abort` and `shutdown` methods to the `Backend` and `ProcessGroup` objects. This simplifies the logic in `distributed_c10d.py` by giving every PG a default no-op implementation.

This will be useful for torchft and for upcoming versions of NCCL that handle abort correctly. Currently, `torchft` has to call the internal `_abort` method on the PGNCCL object directly; with this change it can simply call `.abort()` and have it work for any PG implementation.
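
As a rough sketch of what this enables (the `DummyPG` class below is hypothetical, modeled on the `DummyProcessGroup`/`DummyAttrProcessGroup` tests in this PR): a pure-Python `ProcessGroup` subclass inherits no-op `abort()`/`shutdown()` by default and only overrides them if it has something to clean up, and callers can invoke them without knowing the concrete backend.

```python
# Minimal sketch, assuming this change is applied; DummyPG is a made-up example.
import torch.distributed as dist


class DummyPG(dist.ProcessGroup):
    def __init__(self, rank: int, world_size: int):
        super().__init__(rank, world_size)
        self._aborted = False

    def getBackendName(self) -> str:
        return "dummy"

    # Overriding is optional; the base-class abort()/shutdown() are no-ops.
    def abort(self) -> None:
        self._aborted = True


pg = DummyPG(0, 1)
pg.shutdown()  # inherited default: no-op
pg.abort()     # dispatches to the override above
assert pg._aborted
```

`destroy_process_group` and `_abort_process_group` now route through these same methods (see the `distributed_c10d.py` hunks below) instead of reaching into NCCL-specific helpers.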

Test plan:

```
pytest distributed/test_backends.py distributed/test_c10d_common.py distributed/test_c10d_pypg.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148798
Approved by: https://github.com/kwen2501

Author: Tristan Rice
Date: 2025-03-08 18:33:14 +00:00
Committed by: PyTorch MergeBot
Parent: 9841f0ddcf
Commit: 7ffadff286

10 changed files with 111 additions and 59 deletions


@@ -1559,6 +1559,11 @@ class DummyWork(dist._Work):
class DummyProcessGroup(dist.ProcessGroup):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._aborted = False
        self._shutdown = False

    def getBackendName(self):
        return "Dummy"
@@ -1622,6 +1627,12 @@ class DummyProcessGroup(dist.ProcessGroup):
        return DummyWork()

    def abort(self) -> None:
        self._aborted = True

    def shutdown(self) -> None:
        self._shutdown = True


class PythonProcessGroupExtensionTest(MultiProcessTestCase):
    def setUp(self):
@@ -1794,6 +1805,36 @@ class PythonProcessGroupExtensionTest(MultiProcessTestCase):
        # intentionally not calling into `destroy_process_group` as not all
        # user applications would explicitly that.

    def test_shutdown(self) -> None:
        dist.Backend.register_backend(
            "dummy", PythonProcessGroupExtensionTest.create_dummy
        )

        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "6789"
        dist.init_process_group("dummy", rank=self.rank, world_size=self.world_size)

        pg = c10d._get_default_group()
        dist.destroy_process_group()
        self.assertTrue(pg._shutdown)

    def test_abort(self) -> None:
        dist.Backend.register_backend(
            "dummy", PythonProcessGroupExtensionTest.create_dummy
        )

        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "6789"
        dist.init_process_group("dummy", rank=self.rank, world_size=self.world_size)

        pg = c10d._get_default_group()
        c10d._abort_process_group()
        self.assertTrue(pg._aborted)


instantiate_parametrized_tests(CommonDistributedDataParallelTest)


@@ -191,6 +191,12 @@ class TestPyProcessGroup(TestCase):
        pg._set_group_desc("desc")
        self.assertEqual(pg.group_desc, "py:desc")

    def test_abort_shutdown(self) -> None:
        # verify these are no-ops
        pg = DummyAttrProcessGroup(0, 1)
        pg.abort()
        pg.shutdown()


if __name__ == "__main__":
    run_tests()


@@ -299,6 +299,8 @@ class Backend:
    def options(self) -> Options: ...
    def rank(self) -> int: ...
    def size(self) -> int: ...
    def abort(self) -> None: ...
    def shutdown(self) -> None: ...
    def eager_connect_single_device(self, device: torch.device | None) -> None: ...
    def _set_sequence_number_for_group(self) -> None: ...
    def _set_default_timeout(self, timeout: timedelta) -> None: ...
@@ -324,6 +326,8 @@ class ProcessGroup:
    ) -> None: ...
    def rank(self) -> int: ...
    def size(self) -> int: ...
    def abort(self) -> None: ...
    def shutdown(self) -> None: ...
    @overload
    def broadcast(
        self,
@@ -600,7 +604,6 @@ class ProcessGroupNCCL(Backend):
    def _group_start(self) -> None: ...
    def _group_end(self) -> None: ...
    def _set_default_timeout(self, timeout) -> None: ...
    def _shutdown(self) -> None: ...
    def perform_nocolor_split(self, device: torch.device) -> None: ...
    def register_mem_pool(self, pool: torch.cuda.MemPool) -> None: ...
    def deregister_mem_pool(self, pool: torch.cuda.MemPool) -> None: ...


@@ -436,6 +436,14 @@ class TORCH_API Backend : public torch::CustomClassHolder {
    return false;
  }

  // Aborts all pending operations and connections in the backend if the backend
  // supports it.
  virtual void abort() {}

  // Shutdown the backend if the backend supports it. This should be used for
  // normal shutdown.
  virtual void shutdown() {}

 protected:
  // Implementations of this interface need to call this to setup
  // appropriate logging etc.


@@ -874,6 +874,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
    getDefaultBackend()->waitForPendingWorks();
  }

  virtual void shutdown() {
    for (auto& backend : backendTypeToBackend_) {
      backend.second->shutdown();
    }
  }

  virtual void abort() {
    for (auto& backend : backendTypeToBackend_) {
      backend.second->abort();
    }
  }

  bool hasHooks() const {
    // `getDefaultBackend` will throw today if the backend is set to `undefined`
    // (in case of `init_process_group(nothing)`)


@@ -144,7 +144,7 @@ class TORCH_API ProcessGroupMPI : public Backend {
  ~ProcessGroupMPI() override;

  // Abort the MPI program, needs to be called when exception is detected
  void abort();
  void abort() override;

  const std::string getBackendName() const override {
    return std::string(MPI_BACKEND_NAME);


@@ -761,11 +761,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  c10::intrusive_ptr<intra_node_comm::IntraNodeComm> initIntraNodeComm();

  // Destroy (shutdown) this backend -- normal exit.
  void shutdown();
  void shutdown() override;

  // Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
  // instead of relying on ProcessGroupNCCL destructor.
  void abort();
  void abort() override;

  void eagerConnectSingleDevice(at::Device device) override;


@@ -111,6 +111,14 @@ class PyProcessGroup : public ProcessGroup {
    );
  }

  void abort() override {
    PYBIND11_OVERRIDE(
        void, /* Return type */
        ProcessGroup, /* Parent class */
        abort, /* Name of function in C++ */
    );
  }

  const std::string& getGroupName() const override {
    PYBIND11_OVERRIDE(
        const std::string&, /* Return type */


@@ -1961,6 +1961,16 @@ communication mechanism.
      .def("rank", &::c10d::ProcessGroup::getRank, R"(Get the rank of this process group.)")
      .def("size", &::c10d::ProcessGroup::getSize, R"(Get the size of this process group.)")
      .def("name", &::c10d::ProcessGroup::getBackendName, R"(Get the name of this process group.)")
      .def(
          "abort",
          &::c10d::ProcessGroup::abort,
          py::call_guard<py::gil_scoped_release>(),
          "abort all operations and connections if supported by the backend")
      .def(
          "shutdown",
          &::c10d::ProcessGroup::shutdown,
          py::call_guard<py::gil_scoped_release>(),
          "shutdown the process group")
      .def("_id", &::c10d::ProcessGroup::getID)
      .def(
          "_backend_id",
@@ -2478,6 +2488,16 @@ Arguments:
      .def("rank", &::c10d::Backend::getRank)
      .def("size", &::c10d::Backend::getSize)
      .def("name", &::c10d::Backend::getBackendName)
      .def(
          "abort",
          &::c10d::Backend::abort,
          py::call_guard<py::gil_scoped_release>(),
          "abort all operations and connections if supported by the backend")
      .def(
          "shutdown",
          &::c10d::Backend::shutdown,
          py::call_guard<py::gil_scoped_release>(),
          "shutdown the backend")
      .def_property_readonly(
          "supports_splitting",
          &::c10d::Backend::supportsSplitting,
@@ -2972,12 +2992,6 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
          py::arg("size"),
          py::arg("timeout") = ::c10d::kProcessGroupNCCLDefaultTimeout,
          R"(Create a new ProcessGroupNCCL instance.)")
      .def(
          "_shutdown",
          [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) {
            return self->shutdown();
          },
          py::call_guard<py::gil_scoped_release>())
      .def("_group_start", &::c10d::ProcessGroupNCCL::groupStart)
      .def("_group_end", &::c10d::ProcessGroupNCCL::groupEnd)
      .def(
@@ -3025,11 +3039,6 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
      .def(
          "deregister_mem_pool",
          &::c10d::ProcessGroupNCCL::deregisterMemPool)
      .def(
          "abort",
          &::c10d::ProcessGroupNCCL::abort,
          py::call_guard<py::gil_scoped_release>(),
          R"(Abort the process group.)")
      .def(
          "_is_initialized",
          &::c10d::ProcessGroupNCCL::isInitialized,


@@ -1798,36 +1798,6 @@ def _get_split_source(pg):
    return split_from


def _shutdown_backend(pg):
    """
    Try to shut down the backend of a process group.
    Currently, only ProcessGroupNCCL backend is supported.
    No op for other backends.
    """
    backend = None
    try:
        backend = pg._get_backend(torch.device("cuda"))
    except RuntimeError:
        pass
    if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
        # explicitly call shutdown to ensure that NCCL resources are released
        backend._shutdown()


def _abort_backend(pg: ProcessGroup):
    """
    Abort the backend of a process group.
    Currently, only ProcessGroupNCCL backend is supported.
    No op for other backends.
    """
    try:
        backend = pg._get_backend(torch.device("cuda"))
    except RuntimeError:
        backend = None
    if isinstance(backend, ProcessGroupNCCL):
        backend.abort()


def _new_process_group_helper(
    group_size,
    group_rank,
@@ -2162,7 +2132,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
        for pg_to_shutdown in sorted(
            _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True
        ):
            _shutdown_backend(pg_to_shutdown)
            pg_to_shutdown.shutdown()

        _update_default_pg(None)
        _world.pg_map.clear()
@@ -2184,7 +2154,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
        # process group is in good state, we aren't dealing with failures.
        _world.group_count = 0
    else:
        _shutdown_backend(pg)
        pg.shutdown()

        del _world.pg_map[pg]
        del _world.pg_names[pg]
        del _world.pg_group_ranks[pg]
@@ -2240,24 +2210,19 @@ def _abort_process_group(group: Optional[ProcessGroup] = None):
    except RuntimeError:
        backend = None

    if not isinstance(backend, ProcessGroupNCCL):
        logger.warning(
            "`abort_process_group` currently only has implementation for ProcessGroupNCCL; "
            "however, no NCCL backend is found. This call will be a no-op."
        )
        return

    if group == GroupMember.WORLD:
    if group is None or group == GroupMember.WORLD:
        # Abort all backends within a ncclGroupStart|End semantic.
        # This ensures that different NCCL communicators' abort calls won't
        # deadlock each other.
        # For details, please see: https://github.com/pytorch/pytorch/issues/119797
        backend._group_start()
        if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
            backend._group_start()
        for pg_to_abort in sorted(
            _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True
        ):
            _abort_backend(pg_to_abort)
        backend._group_end()
            pg_to_abort.abort()
        if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
            backend._group_end()

        _update_default_pg(None)
        _world.pg_map.clear()
@@ -2279,7 +2244,7 @@ def _abort_process_group(group: Optional[ProcessGroup] = None):
        # process group is in good state, we aren't dealing with failures.
        _world.group_count = 0
    else:
        _abort_backend(pg)
        pg.abort()

        del _world.pg_map[pg]
        del _world.pg_names[pg]
        del _world.pg_group_ranks[pg]