[c10d] Cleanup split_group logic using the newly built splitGroup (#158488)
With https://github.com/pytorch/pytorch/pull/157716 merged, we want to further clean up the Python-side code for the `split_group` API. We still need to keep some of the old global bookkeeping for backward compatibility, but the rest of the logic now lives entirely in C++. As for the change introduced in https://github.com/pytorch/pytorch/pull/152175, we cleaned it up in https://github.com/pytorch/pytorch/pull/158790 (including internal changes) so that it can be safely removed.

Differential Revision: [D78777152](https://our.internmc.facebook.com/intern/diff/D78777152)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158488
Approved by: https://github.com/d4l3k
This commit is contained in:
parent 775788f93b
commit 67e68e0785
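For context, here is a minimal sketch of how the public API touched by this PR is used. It assumes a 4-rank job launched with torchrun, an NCCL default group, and one GPU per rank; apart from split_group itself (edited in distributed_c10d.py below), everything is standard torch.distributed usage.

import torch
import torch.distributed as dist

dist.init_process_group("nccl")  # e.g. under torchrun --nproc-per-node=4
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Every rank passes the full partition; each rank receives the subgroup
# containing it, or None if it belongs to no subgroup.
subgroup = dist.split_group(split_ranks=[[0, 1], [2, 3]])

t = torch.ones(1, device="cuda")
if subgroup is not None:
    dist.all_reduce(t, group=subgroup)  # reduces within the 2-rank subgroup only

dist.destroy_process_group()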
torch/_C/_distributed_c10d.pyi
@@ -350,11 +350,13 @@ class ProcessGroup:
     ) -> None: ...
     def rank(self) -> int: ...
     def size(self) -> int: ...
+    def get_group_store(self) -> Store: ...
     def split_group(
         self,
         new_ranks: list[int],
         timeout: Optional[timedelta] = None,
-        pg_options: Optional[Backend.Options] = None,
+        opts: Optional[Backend.Options] = None,
+        group_name: Optional[str] = None,
         group_desc: Optional[str] = None,
     ) -> Optional[ProcessGroup]: ...
     def merge_remote_group(
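Per the stub above, the low-level binding now takes the renamed opts argument plus a new optional group_name, and the resulting group exposes its store via get_group_store(). A hedged sketch of a direct call (the setup around it is illustrative; _get_default_group is a private helper):

from torch.distributed.distributed_c10d import _get_default_group

parent_pg = _get_default_group()  # a ProcessGroup instance

# pg_options was renamed to opts; group_name is new and optional.
new_pg = parent_pg.split_group(
    [0, 1],
    timeout=None,
    opts=None,
    group_name="my_subgroup",  # a default name is derived when omitted
    group_desc="first half of the world",
)

if new_pg is not None:  # ranks outside [0, 1] get None back
    store = new_pg.get_group_store()  # the subgroup's name-scoped store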
torch/csrc/distributed/c10d/Backend.hpp
@@ -391,6 +391,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
   }
 
   virtual c10::intrusive_ptr<Backend> split(
+      const c10::intrusive_ptr<Store>& store,
       const std::vector<int>& ranks,
       const c10::intrusive_ptr<Options>& opts) {
     TORCH_CHECK(
torch/csrc/distributed/c10d/ProcessGroup.cpp
@@ -160,15 +160,16 @@ void ProcessGroup::release_resources() {
 
 c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
     const std::vector<int>& ranks,
-    const std::optional<std::chrono::milliseconds> timeout,
-    const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
+    const std::optional<std::chrono::milliseconds>& timeout,
+    const std::optional<c10::intrusive_ptr<Backend::Options>>& opts,
+    const std::optional<std::string>& name,
     const std::optional<std::string>& desc) {
   TORCH_CHECK(
       ranks.size() > 0,
       "Split ranks cannot be empty. Please provide a non-empty list of ranks to split the group.");
   TORCH_CHECK(
-      ranks.size() < static_cast<size_t>(size_),
-      "the split group's size should be less than the world_size set by init_process_group");
+      ranks.size() <= static_cast<size_t>(size_),
+      "the split group's size should be no larger than the world_size set by init_process_group");
   std::set<int> ranks_set(ranks.begin(), ranks.end());
   TORCH_CHECK(
       ranks_set.size() == ranks.size(),
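The three TORCH_CHECKs above, restated as a Python sketch (a hypothetical helper, not part of this PR) to make the accepted inputs explicit:

def validate_split_ranks(ranks: list[int], world_size: int) -> None:
    # Non-empty: an empty split is rejected outright.
    if len(ranks) == 0:
        raise ValueError("Split ranks cannot be empty.")
    # Relaxed by this PR: the subgroup may now equal the parent's size (<=),
    # where previously it had to be strictly smaller (<).
    if len(ranks) > world_size:
        raise ValueError("split size must be no larger than the world size")
    # No duplicate ranks (the C++ code compares a std::set against the input).
    if len(set(ranks)) != len(ranks):
        raise ValueError("Split ranks must be unique.")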
@@ -176,9 +177,12 @@ c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
   std::vector<int> sorted_ranks = ranks;
   std::sort(sorted_ranks.begin(), sorted_ranks.end());
   c10::intrusive_ptr<ProcessGroup> newGroup;
-  // TODO: Figure out a better way for split group name.
-  std::string groupName =
-      c10::str(getGroupName(), ":split:", fmt::format("{}", sorted_ranks));
+  std::string groupName = name.has_value()
+      ? name.value()
+      : c10::str(getGroupName(), ":split:", fmt::format("{}", sorted_ranks));
+  c10::intrusive_ptr<Store> store = c10::static_intrusive_pointer_cast<Store>(
+      c10::make_intrusive<PrefixStore>(
+          fmt::format("{}/", groupName), store_->clone()));
   for (const auto& pair : deviceTypeToBackendType_) {
     c10::DeviceType deviceType = pair.first;
     BackendType backendType = pair.second;
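Two things happen in the hunk above: the group name falls back to a derived value when the caller did not supply one, and the split now builds a single name-scoped store that every backend of the new group shares. Roughly, in Python terms (an illustrative sketch; PrefixStore is the real c10d class, while the clone() the C++ side performs on the parent store is elided):

import torch.distributed as dist

def default_split_name(parent_name: str, ranks: list[int]) -> str:
    # fmt::format("{}", sorted_ranks) renders the vector like "[2, 3]", so
    # splitting ranks 2 and 3 out of group "0" yields the name "0:split:[2, 3]".
    return f"{parent_name}:split:{sorted(ranks)}"

def scoped_store(group_name: str, parent_store: dist.Store) -> dist.Store:
    # Keys written by the subgroup live under "<group_name>/" and therefore
    # cannot collide with the parent group's keys.
    return dist.PrefixStore(f"{group_name}/", parent_store)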
@@ -189,7 +193,7 @@ c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
       backendOpts->group_name = groupName;
       backendOpts->timeout =
           timeout.has_value() ? timeout.value() : backendOpts->timeout;
-      auto splitBackend = parentBackend->split(sorted_ranks, backendOpts);
+      auto splitBackend = parentBackend->split(store, sorted_ranks, backendOpts);
       if (splitBackend == nullptr) {
         continue;
       }
@@ -204,7 +208,7 @@ c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
 
       if (!newGroup) {
         newGroup = c10::make_intrusive<ProcessGroup>(
-            store_->clone(), splitBackend->getRank(), splitBackend->getSize());
+            store, splitBackend->getRank(), splitBackend->getSize());
         newGroup->setDefaultBackend(backendType_);
         newGroup->setGroupName(groupName);
         newGroup->setGroupDesc(groupDesc);
torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -967,6 +967,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     return bound_device_id_;
   }
 
+  c10::intrusive_ptr<c10d::Store> getStore() const {
+    return store_;
+  }
+
   void setBoundDeviceId(std::optional<at::Device> device) {
     if (device) {
       TORCH_CHECK(device->has_index(), "setBoundDeviceId must have an index");
@@ -978,8 +982,9 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
   // The current rank must be included in the list of new_ranks.
   virtual c10::intrusive_ptr<ProcessGroup> splitGroup(
       const std::vector<int>& ranks,
-      const std::optional<std::chrono::milliseconds> timeout,
-      const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
+      const std::optional<std::chrono::milliseconds>& timeout,
+      const std::optional<c10::intrusive_ptr<Backend::Options>>& opts,
+      const std::optional<std::string>& name,
       const std::optional<std::string>& groupDesc);
 
   // This creates a new subgroup using the specified ranks.
torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
@@ -698,6 +698,7 @@ const std::vector<uint64_t>& ProcessGroupGloo::groupRanks() const {
 }
 
 c10::intrusive_ptr<Backend> ProcessGroupGloo::split(
+    const c10::intrusive_ptr<Store>& store,
     const std::vector<int>& ranks,
     const c10::intrusive_ptr<Backend::Options>& opts) {
   auto it = std::find(ranks.begin(), ranks.end(), rank_);
@@ -717,12 +718,8 @@ c10::intrusive_ptr<Backend> ProcessGroupGloo::split(
     globalRanksInGroup.emplace_back(groupRanks()[rank]);
   }
   glooOpts->global_ranks_in_group = std::move(globalRanksInGroup);
-  auto store = std::dynamic_pointer_cast<GlooStore>(store_);
-  TORCH_CHECK(
-      store != nullptr,
-      "store inside ProcessGroupGloo not a ProcessGroupGloo::GlooStore.");
   auto pg = c10::make_intrusive<ProcessGroupGloo>(
-      store->_getStore()->clone(), groupRank, ranks.size(), glooOpts);
+      store->clone(), groupRank, ranks.size(), glooOpts);
   return c10::static_intrusive_pointer_cast<Backend>(pg);
 }
 
torch/csrc/distributed/c10d/ProcessGroupGloo.hpp
@@ -264,6 +264,10 @@ class TORCH_API ProcessGroupGloo : public Backend {
     return std::string(GLOO_BACKEND_NAME);
   }
 
+  bool supportsSplitting() const override {
+    return true;
+  }
+
   // Helper functions to create a new device object.
   // They are static functions on this class to keep them logically
   // separate from the rest of the code base (e.g. torch/csrc/distributed).
@@ -309,6 +313,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
   }
 
   c10::intrusive_ptr<Backend> split(
+      const c10::intrusive_ptr<Store>& store,
       const std::vector<int>& ranks,
       const c10::intrusive_ptr<Backend::Options>& opts) override;
 
torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1255,6 +1255,7 @@ void ProcessGroupNCCL::enableCollectivesTiming() {
 }
 
 c10::intrusive_ptr<Backend> ProcessGroupNCCL::split(
+    const c10::intrusive_ptr<Store>& store,
     const std::vector<int>& ranks,
     const c10::intrusive_ptr<Backend::Options>& opts) {
   auto deviceIdx = guessDeviceId();
@@ -1288,7 +1289,7 @@ c10::intrusive_ptr<Backend> ProcessGroupNCCL::split(
   auto color = genNcclSplitColor(ranks);
   ncclOpts->split_color = color;
   auto pg = c10::make_intrusive<ProcessGroupNCCL>(
-      store_->clone(), groupRank, ranks.size(), ncclOpts);
+      store->clone(), groupRank, ranks.size(), ncclOpts);
   pg->eagerConnectSingleDevice(device);
   return c10::static_intrusive_pointer_cast<Backend>(pg);
 }
torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -960,6 +960,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   void enableCollectivesTiming() override;
 
   c10::intrusive_ptr<Backend> split(
+      const c10::intrusive_ptr<Store>& store,
       const std::vector<int>& ranks,
       const c10::intrusive_ptr<Backend::Options>& opts) override;
 
torch/csrc/distributed/c10d/PyProcessGroup.hpp
@@ -153,8 +153,9 @@ class PyProcessGroup : public ProcessGroup {
 
   c10::intrusive_ptr<ProcessGroup> splitGroup(
       const std::vector<int>& ranks,
-      const std::optional<std::chrono::milliseconds> timeout,
-      const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
+      const std::optional<std::chrono::milliseconds>& timeout,
+      const std::optional<c10::intrusive_ptr<Backend::Options>>& opts,
+      const std::optional<std::string>& group_name,
       const std::optional<std::string>& group_desc) override {
     PYBIND11_OVERRIDE(
         c10::intrusive_ptr<ProcessGroup>, /* Return type */
@@ -163,6 +164,7 @@ class PyProcessGroup : public ProcessGroup {
         ranks,
         timeout,
         opts,
+        group_name,
         group_desc);
   }
 
torch/csrc/distributed/c10d/init.cpp
@@ -2063,13 +2063,15 @@ communication mechanism.
       .def("rank", &::c10d::ProcessGroup::getRank, R"(Get the rank of this process group.)")
       .def("size", &::c10d::ProcessGroup::getSize, R"(Get the size of this process group.)")
       .def("name", &::c10d::ProcessGroup::getBackendName, R"(Get the name of this process group.)")
+      .def("get_group_store", &::c10d::ProcessGroup::getStore, R"(Get the store of this process group.)")
       .def(
           "split_group",
           &::c10d::ProcessGroup::splitGroup,
           py::arg("ranks"),
           py::arg("timeout") = std::nullopt,
           py::arg("opts") = std::nullopt,
-          py::arg("groupDesc") = std::nullopt,
+          py::arg("group_name") = std::nullopt,
+          py::arg("group_desc") = std::nullopt,
           py::call_guard<py::gil_scoped_release>())
       .def(
           "merge_remote_group",
torch/distributed/distributed_c10d.py
@@ -5013,8 +5013,8 @@ def split_group(
 
     """
    # check inputs
-    if split_ranks is None:
-        raise ValueError("split_ranks cannot be None")
+    if split_ranks is None or len(split_ranks) == 0:
+        raise ValueError("split_ranks cannot be None or empty")
 
     global _world
     default_pg = _get_default_group()
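With the tightened check above, an empty partition now fails fast in Python instead of reaching C++. Sketch:

import torch.distributed as dist

try:
    dist.split_group(split_ranks=[])  # previously only None was rejected here
except ValueError as e:
    print(e)  # "split_ranks cannot be None or empty"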
@@ -5023,7 +5023,6 @@ def split_group(
         raise RuntimeError(
             "No device associated with the default pg, not safe to split any process groups"
         )
-    _default_backend, default_store = _world.pg_map[default_pg]
     global_rank = default_pg.rank()
     global_world_size = default_pg.size()
 
@@ -5054,11 +5053,8 @@ def split_group(
     )
 
     # set the group_desc before the color or no_color split
-    group_desc = (
-        f"{parent_pg.group_desc}:split:{parent_backend.comm_split_count()}"  # type: ignore[attr-defined]
-        if group_desc is None
-        else group_desc
-    )
+    if hasattr(parent_backend, "comm_split_count") and group_desc is None:
+        group_desc = f"{parent_pg.group_desc}:split:{parent_backend.comm_split_count()}"  # type: ignore[attr-defined]
 
     parent_backend_str, _ = _world.pg_map[parent_pg]
     # same type of backend as the parent process group
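The description fallback above now fires only when the backend actually exposes a split counter. A worked value under hypothetical inputs (the fake backend stands in for an NCCL parent_backend that has been split once):

class _FakeNCCLBackend:  # hypothetical stand-in for parent_backend
    def comm_split_count(self) -> int:
        return 1

parent_backend = _FakeNCCLBackend()
parent_desc = "default_pg"  # hypothetical parent_pg.group_desc
group_desc = None

if hasattr(parent_backend, "comm_split_count") and group_desc is None:
    group_desc = f"{parent_desc}:split:{parent_backend.comm_split_count()}"

assert group_desc == "default_pg:split:1"
# A backend without comm_split_count simply leaves group_desc as None.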
@@ -5076,8 +5072,9 @@ def split_group(
     _check_valid_timeout(timeout)
 
     # find my group of ranks and my group local rank in split_ranks
-    my_group = None
-    group_rank = -1
+    # for ranks which are not in any split PGs, we just pass in the first split group
+    # and None will be returned.
+    my_group = split_ranks[0]
 
     for split_group in split_ranks:
         if len(split_group) == 0:
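How my_group resolves after this change, in isolation (a standalone sketch of the loop above, with hypothetical inputs): each rank picks the subgroup containing it, and a rank that appears in no subgroup keeps the first subgroup as a placeholder so the collective split call still executes, with None returned to that rank.

def resolve_my_group(split_ranks: list[list[int]], parent_group_rank: int) -> list[int]:
    my_group = split_ranks[0]  # placeholder for ranks in no subgroup
    for group in split_ranks:
        group = sorted(group)
        if parent_group_rank in group:
            my_group = group
            break
    return my_group

assert resolve_my_group([[0, 1], [2]], 1) == [0, 1]
assert resolve_my_group([[0, 1], [2]], 2) == [2]
assert resolve_my_group([[0, 1], [2]], 3) == [0, 1]  # placeholder; this rank gets None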
@@ -5091,88 +5088,40 @@ def split_group(
         split_group = sorted(split_group)
         if parent_group_rank in split_group:
             my_group = split_group
-            group_rank = split_group.index(parent_group_rank)
             break
-    # if my rank does not belong to any sub group,
-    # no_color split should be called
-    if my_group is None or group_rank == -1:
-        parent_backend.perform_nocolor_split(device_id)  # type: ignore[attr-defined]
-        return None
 
     group_name = _process_group_name(my_group, use_hashed_name=False)
-    global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group]
-
-    prefix_store = PrefixStore(f"{group_name}/", default_store)
-    # We register the backend after initializing and timeout is set in pg_options.
-    pg: ProcessGroup = ProcessGroup(
-        prefix_store,
-        group_rank,
-        len(my_group),
-    )
-    pg.bound_device_id = device_id  # type: ignore[union-attr]
-    pg_options._timeout = timeout  # type: ignore[union-attr]
-    pg_options.split_from = parent_backend  # type: ignore[union-attr]
-    pg_options.split_color = _process_group_color(my_group)  # type: ignore[union-attr]
-    pg_options.global_ranks_in_group = global_ranks_in_my_group  # type: ignore[union-attr]
-    pg_options.group_name = group_name  # type: ignore[union-attr]
+    split_pg = parent_pg.split_group(
+        my_group,
+        timeout=timeout,
+        opts=pg_options,
+        group_name=group_name,
+        group_desc=group_desc,
+    )
+    if split_pg is None:
+        return None
 
-    if parent_backend_str == Backend.NCCL:
-        backend_type = ProcessGroup.BackendType.NCCL
-        if not isinstance(pg_options, ProcessGroupNCCL.Options):
-            raise RuntimeError(
-                "Expected pg_options argument to be of type ProcessGroupNCCL.Options"
-            )
-        backend_class = ProcessGroupNCCL(
-            prefix_store, group_rank, len(my_group), pg_options
-        )
-    else:
-        assert parent_backend_str.upper() in Backend._plugins, (
-            f"Unknown c10d backend type {parent_backend_str.upper()}"
-        )
-        backend_plugin = Backend._plugins[parent_backend_str.upper()]
-        creator_fn = backend_plugin.creator_fn
-        extended_api = backend_plugin.extended_api
-        backend_type = ProcessGroup.BackendType.CUSTOM
-        if not extended_api:
-            backend_class = creator_fn(prefix_store, group_rank, len(my_group), timeout)
-        else:
-            dist_backend_opts = _DistributedBackendOptions()
-            dist_backend_opts.store = prefix_store
-            dist_backend_opts.group_rank = group_rank
-            dist_backend_opts.group_size = len(my_group)
-            backend_class = creator_fn(dist_backend_opts, pg_options)
-
-    pg._set_default_backend(backend_type)
-    backend_class._set_sequence_number_for_group()
-
-    pg._register_backend(torch.device("cuda"), backend_type, backend_class)
-
-    # set group_name and group_desc to backend
-    assert group_name is not None
-    assert group_desc is not None
-    pg._set_group_name(group_name)
-    pg._set_group_desc(group_desc)
-
-    # always eagerly initialize the backend in split_group
-    eager_backend = pg._get_backend(device_id)
-    eager_backend.eager_connect_single_device(device_id)
+    global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group]
+    split_pg.bound_device_id = device_id  # type: ignore[union-attr]
+    split_backend_class = split_pg._get_backend(torch.device("cuda"))
+    split_backend_class._set_sequence_number_for_group()
 
     # update global state
-    _world.pg_map[pg] = (backend, prefix_store)
-    _world.pg_names[pg] = group_name
-    _register_process_group(group_name, pg)
-    _world.pg_backend_config[pg] = str(backend_config)
+    _world.pg_map[split_pg] = (backend, split_pg.get_group_store())
+    _world.pg_names[split_pg] = group_name
+    _register_process_group(group_name, split_pg)
+    _world.pg_backend_config[split_pg] = str(backend_config)
     pg_tag = f"ptd:{group_name}"
-    _world.tags_to_pg.setdefault(pg_tag, []).append(pg)
-    _world.pg_to_tag[pg] = pg_tag
+    _world.tags_to_pg.setdefault(pg_tag, []).append(split_pg)
+    _world.pg_to_tag[split_pg] = pg_tag
 
     # Create the global rank to group rank mapping
-    _world.pg_group_ranks[pg] = {
+    _world.pg_group_ranks[split_pg] = {
         global_rank: group_rank
         for group_rank, global_rank in enumerate(global_ranks_in_my_group)
     }
 
-    return pg
+    return split_pg
 
 
 @_time_logger
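Net effect of the rewrite: all construction happens in C++ via parent_pg.split_group, and Python only records the result in _world. A sketch of inspecting that bookkeeping after a split (these are private internals, shown for debugging purposes only):

import torch.distributed as dist
from torch.distributed.distributed_c10d import _world

subgroup = dist.split_group(split_ranks=[[0, 1], [2, 3]])
if subgroup is not None:
    print(_world.pg_names[subgroup])        # the derived or supplied group name
    print(_world.pg_group_ranks[subgroup])  # global rank -> group rank mapping
    print(subgroup.get_group_store())       # same store recorded in _world.pg_map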