[c10d] Cleanup split_group logic using the newly built splitGroup (#158488)

With https://github.com/pytorch/pytorch/pull/157716 merged, we want to further clean up the code on the Python side for the `split_group` API. We do need to keep some of the old global bookkeeping for backward compatibility (BC). The rest of the logic now lives entirely in C++. Regarding the change brought in by https://github.com/pytorch/pytorch/pull/152175, we cleaned it up in https://github.com/pytorch/pytorch/pull/158790 (including internal changes) so that we can safely remove it.

Differential Revision: [D78777152](https://our.internmc.facebook.com/intern/diff/D78777152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158488
Approved by: https://github.com/d4l3k
This commit is contained in:
fduwjj 2025-07-28 11:35:36 -07:00 committed by PyTorch MergeBot
parent 775788f93b
commit 67e68e0785
11 changed files with 68 additions and 99 deletions

View File

@ -350,11 +350,13 @@ class ProcessGroup:
) -> None: ... ) -> None: ...
def rank(self) -> int: ... def rank(self) -> int: ...
def size(self) -> int: ... def size(self) -> int: ...
def get_group_store(self) -> Store: ...
def split_group( def split_group(
self, self,
new_ranks: list[int], new_ranks: list[int],
timeout: Optional[timedelta] = None, timeout: Optional[timedelta] = None,
pg_options: Optional[Backend.Options] = None, opts: Optional[Backend.Options] = None,
group_name: Optional[str] = None,
group_desc: Optional[str] = None, group_desc: Optional[str] = None,
) -> Optional[ProcessGroup]: ... ) -> Optional[ProcessGroup]: ...
def merge_remote_group( def merge_remote_group(

View File

@ -391,6 +391,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
} }
virtual c10::intrusive_ptr<Backend> split( virtual c10::intrusive_ptr<Backend> split(
const c10::intrusive_ptr<Store>& store,
const std::vector<int>& ranks, const std::vector<int>& ranks,
const c10::intrusive_ptr<Options>& opts) { const c10::intrusive_ptr<Options>& opts) {
TORCH_CHECK( TORCH_CHECK(

View File

@ -160,15 +160,16 @@ void ProcessGroup::release_resources() {
c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup( c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
const std::vector<int>& ranks, const std::vector<int>& ranks,
const std::optional<std::chrono::milliseconds> timeout, const std::optional<std::chrono::milliseconds>& timeout,
const std::optional<c10::intrusive_ptr<Backend::Options>> opts, const std::optional<c10::intrusive_ptr<Backend::Options>>& opts,
const std::optional<std::string>& name,
const std::optional<std::string>& desc) { const std::optional<std::string>& desc) {
TORCH_CHECK( TORCH_CHECK(
ranks.size() > 0, ranks.size() > 0,
"Split ranks cannot be empty. Please provide a non-empty list of ranks to split the group."); "Split ranks cannot be empty. Please provide a non-empty list of ranks to split the group.");
TORCH_CHECK( TORCH_CHECK(
ranks.size() < static_cast<size_t>(size_), ranks.size() <= static_cast<size_t>(size_),
"the split group's size should be less than the world_size set by init_process_group"); "the split group's size should be no larger than the world_size set by init_process_group");
std::set<int> ranks_set(ranks.begin(), ranks.end()); std::set<int> ranks_set(ranks.begin(), ranks.end());
TORCH_CHECK( TORCH_CHECK(
ranks_set.size() == ranks.size(), ranks_set.size() == ranks.size(),
@ -176,9 +177,12 @@ c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
std::vector<int> sorted_ranks = ranks; std::vector<int> sorted_ranks = ranks;
std::sort(sorted_ranks.begin(), sorted_ranks.end()); std::sort(sorted_ranks.begin(), sorted_ranks.end());
c10::intrusive_ptr<ProcessGroup> newGroup; c10::intrusive_ptr<ProcessGroup> newGroup;
// TODO: Figure out a better way for split group name. std::string groupName = name.has_value()
std::string groupName = ? name.value()
c10::str(getGroupName(), ":split:", fmt::format("{}", sorted_ranks)); : c10::str(getGroupName(), ":split:", fmt::format("{}", sorted_ranks));
c10::intrusive_ptr<Store> store = c10::static_intrusive_pointer_cast<Store>(
c10::make_intrusive<PrefixStore>(
fmt::format("{}/", groupName), store_->clone()));
for (const auto& pair : deviceTypeToBackendType_) { for (const auto& pair : deviceTypeToBackendType_) {
c10::DeviceType deviceType = pair.first; c10::DeviceType deviceType = pair.first;
BackendType backendType = pair.second; BackendType backendType = pair.second;
@ -189,7 +193,7 @@ c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
backendOpts->group_name = groupName; backendOpts->group_name = groupName;
backendOpts->timeout = backendOpts->timeout =
timeout.has_value() ? timeout.value() : backendOpts->timeout; timeout.has_value() ? timeout.value() : backendOpts->timeout;
auto splitBackend = parentBackend->split(sorted_ranks, backendOpts); auto splitBackend = parentBackend->split(store, sorted_ranks, backendOpts);
if (splitBackend == nullptr) { if (splitBackend == nullptr) {
continue; continue;
} }
@ -204,7 +208,7 @@ c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
if (!newGroup) { if (!newGroup) {
newGroup = c10::make_intrusive<ProcessGroup>( newGroup = c10::make_intrusive<ProcessGroup>(
store_->clone(), splitBackend->getRank(), splitBackend->getSize()); store, splitBackend->getRank(), splitBackend->getSize());
newGroup->setDefaultBackend(backendType_); newGroup->setDefaultBackend(backendType_);
newGroup->setGroupName(groupName); newGroup->setGroupName(groupName);
newGroup->setGroupDesc(groupDesc); newGroup->setGroupDesc(groupDesc);

View File

@ -967,6 +967,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
return bound_device_id_; return bound_device_id_;
} }
c10::intrusive_ptr<c10d::Store> getStore() const {
return store_;
}
void setBoundDeviceId(std::optional<at::Device> device) { void setBoundDeviceId(std::optional<at::Device> device) {
if (device) { if (device) {
TORCH_CHECK(device->has_index(), "setBoundDeviceId must have an index"); TORCH_CHECK(device->has_index(), "setBoundDeviceId must have an index");
@ -978,8 +982,9 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
// The current rank must be included in the list of new_ranks. // The current rank must be included in the list of new_ranks.
virtual c10::intrusive_ptr<ProcessGroup> splitGroup( virtual c10::intrusive_ptr<ProcessGroup> splitGroup(
const std::vector<int>& ranks, const std::vector<int>& ranks,
const std::optional<std::chrono::milliseconds> timeout, const std::optional<std::chrono::milliseconds>& timeout,
const std::optional<c10::intrusive_ptr<Backend::Options>> opts, const std::optional<c10::intrusive_ptr<Backend::Options>>& opts,
const std::optional<std::string>& name,
const std::optional<std::string>& groupDesc); const std::optional<std::string>& groupDesc);
// This creates a new subgroup using the specified ranks. // This creates a new subgroup using the specified ranks.

View File

@ -698,6 +698,7 @@ const std::vector<uint64_t>& ProcessGroupGloo::groupRanks() const {
} }
c10::intrusive_ptr<Backend> ProcessGroupGloo::split( c10::intrusive_ptr<Backend> ProcessGroupGloo::split(
const c10::intrusive_ptr<Store>& store,
const std::vector<int>& ranks, const std::vector<int>& ranks,
const c10::intrusive_ptr<Backend::Options>& opts) { const c10::intrusive_ptr<Backend::Options>& opts) {
auto it = std::find(ranks.begin(), ranks.end(), rank_); auto it = std::find(ranks.begin(), ranks.end(), rank_);
@ -717,12 +718,8 @@ c10::intrusive_ptr<Backend> ProcessGroupGloo::split(
globalRanksInGroup.emplace_back(groupRanks()[rank]); globalRanksInGroup.emplace_back(groupRanks()[rank]);
} }
glooOpts->global_ranks_in_group = std::move(globalRanksInGroup); glooOpts->global_ranks_in_group = std::move(globalRanksInGroup);
auto store = std::dynamic_pointer_cast<GlooStore>(store_);
TORCH_CHECK(
store != nullptr,
"store inside ProcessGroupGloo not a ProcessGroupGloo::GlooStore.");
auto pg = c10::make_intrusive<ProcessGroupGloo>( auto pg = c10::make_intrusive<ProcessGroupGloo>(
store->_getStore()->clone(), groupRank, ranks.size(), glooOpts); store->clone(), groupRank, ranks.size(), glooOpts);
return c10::static_intrusive_pointer_cast<Backend>(pg); return c10::static_intrusive_pointer_cast<Backend>(pg);
} }

View File

@ -264,6 +264,10 @@ class TORCH_API ProcessGroupGloo : public Backend {
return std::string(GLOO_BACKEND_NAME); return std::string(GLOO_BACKEND_NAME);
} }
bool supportsSplitting() const override {
return true;
}
// Helper functions to create a new device object. // Helper functions to create a new device object.
// They are static functions on this class to keep them logically // They are static functions on this class to keep them logically
// separate from the rest of the code base (e.g. torch/csrc/distributed). // separate from the rest of the code base (e.g. torch/csrc/distributed).
@ -309,6 +313,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
} }
c10::intrusive_ptr<Backend> split( c10::intrusive_ptr<Backend> split(
const c10::intrusive_ptr<Store>& store,
const std::vector<int>& ranks, const std::vector<int>& ranks,
const c10::intrusive_ptr<Backend::Options>& opts) override; const c10::intrusive_ptr<Backend::Options>& opts) override;

View File

@ -1255,6 +1255,7 @@ void ProcessGroupNCCL::enableCollectivesTiming() {
} }
c10::intrusive_ptr<Backend> ProcessGroupNCCL::split( c10::intrusive_ptr<Backend> ProcessGroupNCCL::split(
const c10::intrusive_ptr<Store>& store,
const std::vector<int>& ranks, const std::vector<int>& ranks,
const c10::intrusive_ptr<Backend::Options>& opts) { const c10::intrusive_ptr<Backend::Options>& opts) {
auto deviceIdx = guessDeviceId(); auto deviceIdx = guessDeviceId();
@ -1288,7 +1289,7 @@ c10::intrusive_ptr<Backend> ProcessGroupNCCL::split(
auto color = genNcclSplitColor(ranks); auto color = genNcclSplitColor(ranks);
ncclOpts->split_color = color; ncclOpts->split_color = color;
auto pg = c10::make_intrusive<ProcessGroupNCCL>( auto pg = c10::make_intrusive<ProcessGroupNCCL>(
store_->clone(), groupRank, ranks.size(), ncclOpts); store->clone(), groupRank, ranks.size(), ncclOpts);
pg->eagerConnectSingleDevice(device); pg->eagerConnectSingleDevice(device);
return c10::static_intrusive_pointer_cast<Backend>(pg); return c10::static_intrusive_pointer_cast<Backend>(pg);
} }

View File

@ -960,6 +960,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
void enableCollectivesTiming() override; void enableCollectivesTiming() override;
c10::intrusive_ptr<Backend> split( c10::intrusive_ptr<Backend> split(
const c10::intrusive_ptr<Store>& store,
const std::vector<int>& ranks, const std::vector<int>& ranks,
const c10::intrusive_ptr<Backend::Options>& opts) override; const c10::intrusive_ptr<Backend::Options>& opts) override;

View File

@ -153,8 +153,9 @@ class PyProcessGroup : public ProcessGroup {
c10::intrusive_ptr<ProcessGroup> splitGroup( c10::intrusive_ptr<ProcessGroup> splitGroup(
const std::vector<int>& ranks, const std::vector<int>& ranks,
const std::optional<std::chrono::milliseconds> timeout, const std::optional<std::chrono::milliseconds>& timeout,
const std::optional<c10::intrusive_ptr<Backend::Options>> opts, const std::optional<c10::intrusive_ptr<Backend::Options>>& opts,
const std::optional<std::string>& group_name,
const std::optional<std::string>& group_desc) override { const std::optional<std::string>& group_desc) override {
PYBIND11_OVERRIDE( PYBIND11_OVERRIDE(
c10::intrusive_ptr<ProcessGroup>, /* Return type */ c10::intrusive_ptr<ProcessGroup>, /* Return type */
@ -163,6 +164,7 @@ class PyProcessGroup : public ProcessGroup {
ranks, ranks,
timeout, timeout,
opts, opts,
group_name,
group_desc); group_desc);
} }

View File

@ -2063,13 +2063,15 @@ communication mechanism.
.def("rank", &::c10d::ProcessGroup::getRank, R"(Get the rank of this process group.)") .def("rank", &::c10d::ProcessGroup::getRank, R"(Get the rank of this process group.)")
.def("size", &::c10d::ProcessGroup::getSize, R"(Get the size of this process group.)") .def("size", &::c10d::ProcessGroup::getSize, R"(Get the size of this process group.)")
.def("name", &::c10d::ProcessGroup::getBackendName, R"(Get the name of this process group.)") .def("name", &::c10d::ProcessGroup::getBackendName, R"(Get the name of this process group.)")
.def("get_group_store", &::c10d::ProcessGroup::getStore, R"(Get the store of this process group.)")
.def( .def(
"split_group", "split_group",
&::c10d::ProcessGroup::splitGroup, &::c10d::ProcessGroup::splitGroup,
py::arg("ranks"), py::arg("ranks"),
py::arg("timeout") = std::nullopt, py::arg("timeout") = std::nullopt,
py::arg("opts") = std::nullopt, py::arg("opts") = std::nullopt,
py::arg("groupDesc") = std::nullopt, py::arg("group_name") = std::nullopt,
py::arg("group_desc") = std::nullopt,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def( .def(
"merge_remote_group", "merge_remote_group",

View File

@ -5013,8 +5013,8 @@ def split_group(
""" """
# check inputs # check inputs
if split_ranks is None: if split_ranks is None or len(split_ranks) == 0:
raise ValueError("split_ranks cannot be None") raise ValueError("split_ranks cannot be None or empty")
global _world global _world
default_pg = _get_default_group() default_pg = _get_default_group()
@ -5023,7 +5023,6 @@ def split_group(
raise RuntimeError( raise RuntimeError(
"No device associated with the default pg, not safe to split any process groups" "No device associated with the default pg, not safe to split any process groups"
) )
_default_backend, default_store = _world.pg_map[default_pg]
global_rank = default_pg.rank() global_rank = default_pg.rank()
global_world_size = default_pg.size() global_world_size = default_pg.size()
@ -5054,11 +5053,8 @@ def split_group(
) )
# set the group_desc before the color or no_color split # set the group_desc before the color or no_color split
group_desc = ( if hasattr(parent_backend, "comm_split_count") and group_desc is None:
f"{parent_pg.group_desc}:split:{parent_backend.comm_split_count()}" # type: ignore[attr-defined] group_desc = f"{parent_pg.group_desc}:split:{parent_backend.comm_split_count()}" # type: ignore[attr-defined]
if group_desc is None
else group_desc
)
parent_backend_str, _ = _world.pg_map[parent_pg] parent_backend_str, _ = _world.pg_map[parent_pg]
# same type of backend as the parent process group # same type of backend as the parent process group
@ -5076,8 +5072,9 @@ def split_group(
_check_valid_timeout(timeout) _check_valid_timeout(timeout)
# find my group of ranks and my group local rank in split_ranks # find my group of ranks and my group local rank in split_ranks
my_group = None # for ranks which are not in any split PGs, we just pass in this the first split group
group_rank = -1 # and None will be returned.
my_group = split_ranks[0]
for split_group in split_ranks: for split_group in split_ranks:
if len(split_group) == 0: if len(split_group) == 0:
@ -5091,88 +5088,40 @@ def split_group(
split_group = sorted(split_group) split_group = sorted(split_group)
if parent_group_rank in split_group: if parent_group_rank in split_group:
my_group = split_group my_group = split_group
group_rank = split_group.index(parent_group_rank)
break break
# if my rank does not belong to any sub group,
# no_color split should be called
if my_group is None or group_rank == -1:
parent_backend.perform_nocolor_split(device_id) # type: ignore[attr-defined]
return None
group_name = _process_group_name(my_group, use_hashed_name=False) group_name = _process_group_name(my_group, use_hashed_name=False)
global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group] split_pg = parent_pg.split_group(
my_group,
prefix_store = PrefixStore(f"{group_name}/", default_store) timeout=timeout,
# We register the backend after initializing and timeout is set in pg_options. opts=pg_options,
pg: ProcessGroup = ProcessGroup( group_name=group_name,
prefix_store, group_desc=group_desc,
group_rank,
len(my_group),
) )
pg.bound_device_id = device_id # type: ignore[union-attr] if split_pg is None:
pg_options._timeout = timeout # type: ignore[union-attr] return None
pg_options.split_from = parent_backend # type: ignore[union-attr]
pg_options.split_color = _process_group_color(my_group) # type: ignore[union-attr]
pg_options.global_ranks_in_group = global_ranks_in_my_group # type: ignore[union-attr]
pg_options.group_name = group_name # type: ignore[union-attr]
if parent_backend_str == Backend.NCCL: global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group]
backend_type = ProcessGroup.BackendType.NCCL split_pg.bound_device_id = device_id # type: ignore[union-attr]
if not isinstance(pg_options, ProcessGroupNCCL.Options): split_backend_class = split_pg._get_backend(torch.device("cuda"))
raise RuntimeError( split_backend_class._set_sequence_number_for_group()
"Expected pg_options argument to be of type ProcessGroupNCCL.Options"
)
backend_class = ProcessGroupNCCL(
prefix_store, group_rank, len(my_group), pg_options
)
else:
assert parent_backend_str.upper() in Backend._plugins, (
f"Unknown c10d backend type {parent_backend_str.upper()}"
)
backend_plugin = Backend._plugins[parent_backend_str.upper()]
creator_fn = backend_plugin.creator_fn
extended_api = backend_plugin.extended_api
backend_type = ProcessGroup.BackendType.CUSTOM
if not extended_api:
backend_class = creator_fn(prefix_store, group_rank, len(my_group), timeout)
else:
dist_backend_opts = _DistributedBackendOptions()
dist_backend_opts.store = prefix_store
dist_backend_opts.group_rank = group_rank
dist_backend_opts.group_size = len(my_group)
backend_class = creator_fn(dist_backend_opts, pg_options)
pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group()
pg._register_backend(torch.device("cuda"), backend_type, backend_class)
# set group_name and group_desc to backend
assert group_name is not None
assert group_desc is not None
pg._set_group_name(group_name)
pg._set_group_desc(group_desc)
# always eagerly initialize the backend in split_group
eager_backend = pg._get_backend(device_id)
eager_backend.eager_connect_single_device(device_id)
# update global state # update global state
_world.pg_map[pg] = (backend, prefix_store) _world.pg_map[split_pg] = (backend, split_pg.get_group_store())
_world.pg_names[pg] = group_name _world.pg_names[split_pg] = group_name
_register_process_group(group_name, pg) _register_process_group(group_name, split_pg)
_world.pg_backend_config[pg] = str(backend_config) _world.pg_backend_config[split_pg] = str(backend_config)
pg_tag = f"ptd:{group_name}" pg_tag = f"ptd:{group_name}"
_world.tags_to_pg.setdefault(pg_tag, []).append(pg) _world.tags_to_pg.setdefault(pg_tag, []).append(split_pg)
_world.pg_to_tag[pg] = pg_tag _world.pg_to_tag[split_pg] = pg_tag
# Create the global rank to group rank mapping # Create the global rank to group rank mapping
_world.pg_group_ranks[pg] = { _world.pg_group_ranks[split_pg] = {
global_rank: group_rank global_rank: group_rank
for group_rank, global_rank in enumerate(global_ranks_in_my_group) for group_rank, global_rank in enumerate(global_ranks_in_my_group)
} }
return pg return split_pg
@_time_logger @_time_logger