[c10d] Cleanup split_group logic using the newly built splitGroup (#158488)

With https://github.com/pytorch/pytorch/pull/157716 merged, we can further clean up the Python side of the `split_group` API. We still need to keep some of the old global bookkeeping for backward compatibility; the rest of the logic now lives entirely in C++. As for the change introduced in https://github.com/pytorch/pytorch/pull/152175, we cleaned it up in https://github.com/pytorch/pytorch/pull/158790 (including internal changes), so it can now be safely removed.
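For reference, a minimal sketch of the user-facing flow after this cleanup; the 4-rank torchrun launch, NCCL backend, and rank partition below are illustrative assumptions, not part of this PR:

```python
# Minimal usage sketch (assumes a 4-rank job launched with torchrun and an
# NCCL build; split_group requires the default pg to be bound to a device).
import os
import torch
import torch.distributed as dist

local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group("nccl", device_id=torch.device(f"cuda:{local_rank}"))

# Every rank passes the full partition. A rank that appears in no sublist
# takes the no-color split path and gets None back.
subgroup = dist.split_group(split_ranks=[[0, 1], [2, 3]])
if subgroup is not None:
    t = torch.ones(1, device=f"cuda:{local_rank}")
    dist.all_reduce(t, group=subgroup)  # reduces only within the subgroup

dist.destroy_process_group()
```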

Differential Revision: [D78777152](https://our.internmc.facebook.com/intern/diff/D78777152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158488
Approved by: https://github.com/d4l3k
fduwjj 2025-07-28 11:35:36 -07:00 committed by PyTorch MergeBot
parent 775788f93b
commit 67e68e0785
11 changed files with 68 additions and 99 deletions


@@ -350,11 +350,13 @@ class ProcessGroup:
     ) -> None: ...
     def rank(self) -> int: ...
     def size(self) -> int: ...
+    def get_group_store(self) -> Store: ...
     def split_group(
         self,
         new_ranks: list[int],
         timeout: Optional[timedelta] = None,
-        pg_options: Optional[Backend.Options] = None,
+        opts: Optional[Backend.Options] = None,
+        group_name: Optional[str] = None,
         group_desc: Optional[str] = None,
     ) -> Optional[ProcessGroup]: ...
     def merge_remote_group(


@@ -391,6 +391,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
   }
   virtual c10::intrusive_ptr<Backend> split(
+      const c10::intrusive_ptr<Store>& store,
       const std::vector<int>& ranks,
       const c10::intrusive_ptr<Options>& opts) {
     TORCH_CHECK(


@@ -160,15 +160,16 @@ void ProcessGroup::release_resources() {
 c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
     const std::vector<int>& ranks,
-    const std::optional<std::chrono::milliseconds> timeout,
-    const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
+    const std::optional<std::chrono::milliseconds>& timeout,
+    const std::optional<c10::intrusive_ptr<Backend::Options>>& opts,
+    const std::optional<std::string>& name,
     const std::optional<std::string>& desc) {
   TORCH_CHECK(
       ranks.size() > 0,
       "Split ranks cannot be empty. Please provide a non-empty list of ranks to split the group.");
   TORCH_CHECK(
-      ranks.size() < static_cast<size_t>(size_),
-      "the split group's size should be less than the world_size set by init_process_group");
+      ranks.size() <= static_cast<size_t>(size_),
+      "the split group's size should be no larger than the world_size set by init_process_group");
   std::set<int> ranks_set(ranks.begin(), ranks.end());
   TORCH_CHECK(
       ranks_set.size() == ranks.size(),
@@ -176,9 +177,12 @@ c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
   std::vector<int> sorted_ranks = ranks;
   std::sort(sorted_ranks.begin(), sorted_ranks.end());
   c10::intrusive_ptr<ProcessGroup> newGroup;
-  // TODO: Figure out a better way for split group name.
-  std::string groupName =
-      c10::str(getGroupName(), ":split:", fmt::format("{}", sorted_ranks));
+  std::string groupName = name.has_value()
+      ? name.value()
+      : c10::str(getGroupName(), ":split:", fmt::format("{}", sorted_ranks));
+  c10::intrusive_ptr<Store> store = c10::static_intrusive_pointer_cast<Store>(
+      c10::make_intrusive<PrefixStore>(
+          fmt::format("{}/", groupName), store_->clone()));
   for (const auto& pair : deviceTypeToBackendType_) {
     c10::DeviceType deviceType = pair.first;
     BackendType backendType = pair.second;
@@ -189,7 +193,7 @@ c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
     backendOpts->group_name = groupName;
     backendOpts->timeout =
         timeout.has_value() ? timeout.value() : backendOpts->timeout;
-    auto splitBackend = parentBackend->split(sorted_ranks, backendOpts);
+    auto splitBackend = parentBackend->split(store, sorted_ranks, backendOpts);
     if (splitBackend == nullptr) {
       continue;
     }
@@ -204,7 +208,7 @@ c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
     if (!newGroup) {
       newGroup = c10::make_intrusive<ProcessGroup>(
-          store_->clone(), splitBackend->getRank(), splitBackend->getSize());
+          store, splitBackend->getRank(), splitBackend->getSize());
       newGroup->setDefaultBackend(backendType_);
       newGroup->setGroupName(groupName);
       newGroup->setGroupDesc(groupDesc);


@@ -967,6 +967,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     return bound_device_id_;
   }
+  c10::intrusive_ptr<c10d::Store> getStore() const {
+    return store_;
+  }
   void setBoundDeviceId(std::optional<at::Device> device) {
     if (device) {
       TORCH_CHECK(device->has_index(), "setBoundDeviceId must have an index");
@@ -978,8 +982,9 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
   // The current rank must be included in the list of new_ranks.
   virtual c10::intrusive_ptr<ProcessGroup> splitGroup(
       const std::vector<int>& ranks,
-      const std::optional<std::chrono::milliseconds> timeout,
-      const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
+      const std::optional<std::chrono::milliseconds>& timeout,
+      const std::optional<c10::intrusive_ptr<Backend::Options>>& opts,
+      const std::optional<std::string>& name,
       const std::optional<std::string>& groupDesc);
   // This creates a new subgroup using the specified ranks.


@@ -698,6 +698,7 @@ const std::vector<uint64_t>& ProcessGroupGloo::groupRanks() const {
 }
 c10::intrusive_ptr<Backend> ProcessGroupGloo::split(
+    const c10::intrusive_ptr<Store>& store,
     const std::vector<int>& ranks,
     const c10::intrusive_ptr<Backend::Options>& opts) {
   auto it = std::find(ranks.begin(), ranks.end(), rank_);
@@ -717,12 +718,8 @@ c10::intrusive_ptr<Backend> ProcessGroupGloo::split(
     globalRanksInGroup.emplace_back(groupRanks()[rank]);
   }
   glooOpts->global_ranks_in_group = std::move(globalRanksInGroup);
-  auto store = std::dynamic_pointer_cast<GlooStore>(store_);
-  TORCH_CHECK(
-      store != nullptr,
-      "store inside ProcessGroupGloo not a ProcessGroupGloo::GlooStore.");
   auto pg = c10::make_intrusive<ProcessGroupGloo>(
-      store->_getStore()->clone(), groupRank, ranks.size(), glooOpts);
+      store->clone(), groupRank, ranks.size(), glooOpts);
   return c10::static_intrusive_pointer_cast<Backend>(pg);
 }


@@ -264,6 +264,10 @@ class TORCH_API ProcessGroupGloo : public Backend {
     return std::string(GLOO_BACKEND_NAME);
   }
+  bool supportsSplitting() const override {
+    return true;
+  }
   // Helper functions to create a new device object.
   // They are static functions on this class to keep them logically
   // separate from the rest of the code base (e.g. torch/csrc/distributed).
@@ -309,6 +313,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
   }
   c10::intrusive_ptr<Backend> split(
+      const c10::intrusive_ptr<Store>& store,
       const std::vector<int>& ranks,
       const c10::intrusive_ptr<Backend::Options>& opts) override;


@@ -1255,6 +1255,7 @@ void ProcessGroupNCCL::enableCollectivesTiming() {
 }
 c10::intrusive_ptr<Backend> ProcessGroupNCCL::split(
+    const c10::intrusive_ptr<Store>& store,
     const std::vector<int>& ranks,
     const c10::intrusive_ptr<Backend::Options>& opts) {
   auto deviceIdx = guessDeviceId();
@@ -1288,7 +1289,7 @@ c10::intrusive_ptr<Backend> ProcessGroupNCCL::split(
   auto color = genNcclSplitColor(ranks);
   ncclOpts->split_color = color;
   auto pg = c10::make_intrusive<ProcessGroupNCCL>(
-      store_->clone(), groupRank, ranks.size(), ncclOpts);
+      store->clone(), groupRank, ranks.size(), ncclOpts);
   pg->eagerConnectSingleDevice(device);
   return c10::static_intrusive_pointer_cast<Backend>(pg);
 }


@@ -960,6 +960,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   void enableCollectivesTiming() override;
   c10::intrusive_ptr<Backend> split(
+      const c10::intrusive_ptr<Store>& store,
      const std::vector<int>& ranks,
      const c10::intrusive_ptr<Backend::Options>& opts) override;


@@ -153,8 +153,9 @@ class PyProcessGroup : public ProcessGroup {
   c10::intrusive_ptr<ProcessGroup> splitGroup(
       const std::vector<int>& ranks,
-      const std::optional<std::chrono::milliseconds> timeout,
-      const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
+      const std::optional<std::chrono::milliseconds>& timeout,
+      const std::optional<c10::intrusive_ptr<Backend::Options>>& opts,
+      const std::optional<std::string>& group_name,
       const std::optional<std::string>& group_desc) override {
     PYBIND11_OVERRIDE(
         c10::intrusive_ptr<ProcessGroup>, /* Return type */
@@ -163,6 +164,7 @@ class PyProcessGroup : public ProcessGroup {
         ranks,
         timeout,
         opts,
+        group_name,
         group_desc);
   }


@@ -2063,13 +2063,15 @@ communication mechanism.
       .def("rank", &::c10d::ProcessGroup::getRank, R"(Get the rank of this process group.)")
       .def("size", &::c10d::ProcessGroup::getSize, R"(Get the size of this process group.)")
       .def("name", &::c10d::ProcessGroup::getBackendName, R"(Get the name of this process group.)")
+      .def("get_group_store", &::c10d::ProcessGroup::getStore, R"(Get the store of this process group.)")
       .def(
           "split_group",
           &::c10d::ProcessGroup::splitGroup,
           py::arg("ranks"),
           py::arg("timeout") = std::nullopt,
           py::arg("opts") = std::nullopt,
-          py::arg("groupDesc") = std::nullopt,
+          py::arg("group_name") = std::nullopt,
+          py::arg("group_desc") = std::nullopt,
           py::call_guard<py::gil_scoped_release>())
       .def(
           "merge_remote_group",


@@ -5013,8 +5013,8 @@ def split_group(
     """
     # check inputs
-    if split_ranks is None:
-        raise ValueError("split_ranks cannot be None")
+    if split_ranks is None or len(split_ranks) == 0:
+        raise ValueError("split_ranks cannot be None or empty")
     global _world
     default_pg = _get_default_group()
@@ -5023,7 +5023,6 @@ def split_group(
         raise RuntimeError(
             "No device associated with the default pg, not safe to split any process groups"
         )
-    _default_backend, default_store = _world.pg_map[default_pg]
     global_rank = default_pg.rank()
     global_world_size = default_pg.size()
@@ -5054,11 +5053,8 @@ def split_group(
     )
     # set the group_desc before the color or no_color split
-    group_desc = (
-        f"{parent_pg.group_desc}:split:{parent_backend.comm_split_count()}"  # type: ignore[attr-defined]
-        if group_desc is None
-        else group_desc
-    )
+    if hasattr(parent_backend, "comm_split_count") and group_desc is None:
+        group_desc = f"{parent_pg.group_desc}:split:{parent_backend.comm_split_count()}"  # type: ignore[attr-defined]
     parent_backend_str, _ = _world.pg_map[parent_pg]
     # same type of backend as the parent process group
@@ -5076,8 +5072,9 @@ def split_group(
     _check_valid_timeout(timeout)
     # find my group of ranks and my group local rank in split_ranks
-    my_group = None
-    group_rank = -1
+    # for ranks which are not in any split PG, we just pass in the first split group
+    # and None will be returned
+    my_group = split_ranks[0]
     for split_group in split_ranks:
         if len(split_group) == 0:
@@ -5091,88 +5088,40 @@ def split_group(
         split_group = sorted(split_group)
         if parent_group_rank in split_group:
             my_group = split_group
-            group_rank = split_group.index(parent_group_rank)
             break
-    # if my rank does not belong to any sub group,
-    # no_color split should be called
-    if my_group is None or group_rank == -1:
-        parent_backend.perform_nocolor_split(device_id)  # type: ignore[attr-defined]
-        return None
     group_name = _process_group_name(my_group, use_hashed_name=False)
-    global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group]
-    prefix_store = PrefixStore(f"{group_name}/", default_store)
-    # We register the backend after initializing and timeout is set in pg_options.
-    pg: ProcessGroup = ProcessGroup(
-        prefix_store,
-        group_rank,
-        len(my_group),
+    split_pg = parent_pg.split_group(
+        my_group,
+        timeout=timeout,
+        opts=pg_options,
+        group_name=group_name,
+        group_desc=group_desc,
     )
-    pg.bound_device_id = device_id  # type: ignore[union-attr]
-    pg_options._timeout = timeout  # type: ignore[union-attr]
-    pg_options.split_from = parent_backend  # type: ignore[union-attr]
-    pg_options.split_color = _process_group_color(my_group)  # type: ignore[union-attr]
-    pg_options.global_ranks_in_group = global_ranks_in_my_group  # type: ignore[union-attr]
-    pg_options.group_name = group_name  # type: ignore[union-attr]
+    if split_pg is None:
+        return None
-    if parent_backend_str == Backend.NCCL:
-        backend_type = ProcessGroup.BackendType.NCCL
-        if not isinstance(pg_options, ProcessGroupNCCL.Options):
-            raise RuntimeError(
-                "Expected pg_options argument to be of type ProcessGroupNCCL.Options"
-            )
-        backend_class = ProcessGroupNCCL(
-            prefix_store, group_rank, len(my_group), pg_options
-        )
-    else:
-        assert parent_backend_str.upper() in Backend._plugins, (
-            f"Unknown c10d backend type {parent_backend_str.upper()}"
-        )
-        backend_plugin = Backend._plugins[parent_backend_str.upper()]
-        creator_fn = backend_plugin.creator_fn
-        extended_api = backend_plugin.extended_api
-        backend_type = ProcessGroup.BackendType.CUSTOM
-        if not extended_api:
-            backend_class = creator_fn(prefix_store, group_rank, len(my_group), timeout)
-        else:
-            dist_backend_opts = _DistributedBackendOptions()
-            dist_backend_opts.store = prefix_store
-            dist_backend_opts.group_rank = group_rank
-            dist_backend_opts.group_size = len(my_group)
-            backend_class = creator_fn(dist_backend_opts, pg_options)
-    pg._set_default_backend(backend_type)
-    backend_class._set_sequence_number_for_group()
-    pg._register_backend(torch.device("cuda"), backend_type, backend_class)
-    # set group_name and group_desc to backend
-    assert group_name is not None
-    assert group_desc is not None
-    pg._set_group_name(group_name)
-    pg._set_group_desc(group_desc)
-    # always eagerly initialize the backend in split_group
-    eager_backend = pg._get_backend(device_id)
-    eager_backend.eager_connect_single_device(device_id)
+    global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group]
+    split_pg.bound_device_id = device_id  # type: ignore[union-attr]
+    split_backend_class = split_pg._get_backend(torch.device("cuda"))
+    split_backend_class._set_sequence_number_for_group()
     # update global state
-    _world.pg_map[pg] = (backend, prefix_store)
-    _world.pg_names[pg] = group_name
-    _register_process_group(group_name, pg)
-    _world.pg_backend_config[pg] = str(backend_config)
+    _world.pg_map[split_pg] = (backend, split_pg.get_group_store())
+    _world.pg_names[split_pg] = group_name
+    _register_process_group(group_name, split_pg)
+    _world.pg_backend_config[split_pg] = str(backend_config)
     pg_tag = f"ptd:{group_name}"
-    _world.tags_to_pg.setdefault(pg_tag, []).append(pg)
-    _world.pg_to_tag[pg] = pg_tag
+    _world.tags_to_pg.setdefault(pg_tag, []).append(split_pg)
+    _world.pg_to_tag[split_pg] = pg_tag
     # Create the global rank to group rank mapping
-    _world.pg_group_ranks[pg] = {
+    _world.pg_group_ranks[split_pg] = {
         global_rank: group_rank
         for group_rank, global_rank in enumerate(global_ranks_in_my_group)
     }
-    return pg
+    return split_pg
@_time_logger