[PTD] Introduce tracing friendly collectives. (#93990)

This change adds torch.distributed.traceable_collectives.

This experimental API enables collectives to be fully traced by dynamo and FX.

See #93173 for the RFC.
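
For illustration, a minimal usage sketch adapted from the tests added in this PR (it assumes a process group covering ranks [0, 1], e.g. world size 2, has already been initialized; the module path torch.distributed._functional_collectives is experimental and may change):

    import torch
    import torch.distributed._functional_collectives as ft_c
    from functorch import make_fx

    # Eager: all_reduce returns an AsyncCollectiveTensor wrapper that waits on its
    # underlying c10d Work object the first time a downstream op consumes it.
    res = ft_c.all_reduce(torch.ones(4), "sum", [0, 1])

    # Tracing: the collective appears as aten::all_reduce followed by
    # aten::wait_tensor in the captured FX graph.
    def allred(t):
        return ft_c.all_reduce(t, "sum", group=[0, 1]) + 1

    graph = make_fx(allred)(torch.rand(4))
    print(graph.graph)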

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93990
Approved by: https://github.com/wconstab, https://github.com/wanchaol, https://github.com/H-Huang
Rodrigo Kumpera 2023-02-16 15:35:01 +00:00 committed by PyTorch MergeBot
parent d0fbed76c6
commit e22d791287
9 changed files with 663 additions and 6 deletions

View File

@ -0,0 +1,29 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#endif
namespace at {
namespace native {

// Dummy impl required by codegen infra, not used
at::Tensor all_reduce(at::Tensor const& self, const c10::string_view reduceOp, const c10::string_view tag, c10::ArrayRef<int64_t> ranks, int64_t group_size) {
  // This should never get called
  // Defer to python impls in torch/distributed/_functional_collectives.py and _meta_registrations.py
  TORCH_INTERNAL_ASSERT(false);
}

at::Tensor wait_tensor(at::Tensor const& self) {
  // This should never get called
  // Defer to python impls in torch/distributed/_functional_collectives.py and _meta_registrations.py
  TORCH_INTERNAL_ASSERT(false);
}

} // namespace native
} // namespace at

View File

@ -14670,3 +14670,18 @@
  dispatch:
    CUDA: _fused_adamw_kernel_cuda_
  autogen: _fused_adamw, _fused_adamw.out

# Collectives
- func: all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor
  # This should be changed to distributed but it requires changes all over the place to work
  python_module: nn
  dispatch:
    CompositeExplicitAutograd: all_reduce
  variants: function

- func: wait_tensor(Tensor self) -> Tensor
  # This should be changed to distributed but it requires changes all over the place to work
  python_module: nn
  dispatch:
    CompositeExplicitAutograd: wait_tensor
  variants: function

View File

@ -1231,6 +1231,7 @@ aten_native_source_non_codegen_list = [
"aten/src/ATen/native/Bucketization.cpp",
"aten/src/ATen/native/CPUBlas.cpp",
"aten/src/ATen/native/ChanelShuffle.cpp",
"aten/src/ATen/native/Collectives.cpp",
"aten/src/ATen/native/Col2Im.cpp",
"aten/src/ATen/native/PadNd.cpp",
"aten/src/ATen/native/Convolution.cpp",

View File

@ -0,0 +1,269 @@
# Owner(s): ["oncall: distributed"]

import sys
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c
import torch.distributed.distributed_c10d as c10d
import torch.distributed._tensor as dt

from functorch import make_fx

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import (
    MultiThreadedTestCase,
)
from torch.testing._internal.common_utils import (
    run_tests,
    TestCase
)


def new_subgroups(group_size: int, pg_tag=None):
    world_size = dist.get_world_size()
    subgroups = []
    cur_subgroup = None

    for subgroup_id in range(world_size // group_size):
        start_rank = subgroup_id * group_size
        end_rank = start_rank + group_size
        ranks_in_subgroup = list(range(start_rank, end_rank))
        subgroup = c10d._new_group_with_tag(
            ranks=ranks_in_subgroup,
            pg_tag=pg_tag,
        )
        subgroups.append(subgroup)

        rank = dist.get_rank()
        if rank in ranks_in_subgroup:
            cur_subgroup = subgroup

    return cur_subgroup, subgroups

class TestExpand(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_expand_1d_rank_list(self):
        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3], "bla")
        self.assertEqual("bla", tag)

    def test_expand_2d_rank_list(self):
        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]], "blu")
        self.assertEqual("blu", tag)

        with self.assertRaisesRegex(ValueError, "group sizes must be identical"):
            ft_c._expand_group([[0], [1, 2, 3]])

    def test_expand_process_group(self):
        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD)
        self.assertEqual(c10d._get_group_tag(dist.group.WORLD), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD, "bla")
        self.assertEqual("bla", tag)

        my_pg, others = new_subgroups(group_size=2)
        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual(c10d._get_group_tag(my_pg), tag)
        self.assertEqual(dist.get_process_group_ranks(my_pg), rankset)
        self.assertEqual(2, group_size)

        my_pg = None
        for i in range(dist.get_world_size()):
            group = c10d._new_group_with_tag([i], pg_tag="my_pg")
            if i == dist.get_rank():
                my_pg = group

        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual("my_pg", tag)
        self.assertEqual([dist.get_rank()], rankset)
        self.assertEqual(1, group_size)

        tag, rankset, group_size = ft_c._expand_group(my_pg, "bla")
        self.assertEqual("bla", tag)

    def test_expand_device_mesh(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

    def test_expand_device_mesh_tuple(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4).view(2, 2))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
        self.assertEqual([0, 2, 1, 3], rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group((mesh, 0))
        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
        self.assertEqual([0, 2, 1, 3], rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group((mesh, 1))
        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[1]), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(2, group_size)

class TestPgTag(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    """
    The behavior we want is as follows:

    - rankset + tag will always result in the same PG.
      Do we enforce this by failing creation of new PGs or returning existing ones?
      Return the existing one.
    - The default tag gives the existing behavior.
      This means we should create duplicates.
    - _expand_group on a default-tagged PG should always resolve to it.
      This means we can't depend on empty tag + rankset.
    """

    def test_pg_creation_with_tag(self):
        my_group, _ = new_subgroups(group_size=2, pg_tag="blu")
        my_group2, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(my_group, my_group2)

        my_group3, _ = new_subgroups(group_size=2, pg_tag="blu2")
        self.assertNotEqual(my_group, my_group3)

        my_group4, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group, my_group4)

        my_group5, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group4, my_group5)

    def test_pg_lookup_roundtrip(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="blu2")
        pg_notag0, _ = new_subgroups(group_size=2)
        pg_notag1, _ = new_subgroups(group_size=2)

        def roundtrip(pg):
            tag, rankset, _ = ft_c._expand_group(pg)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(pg_tag0, roundtrip(pg_tag0))
        self.assertEqual(pg_tag1, roundtrip(pg_tag1))
        self.assertEqual(pg_notag0, roundtrip(pg_notag0))
        self.assertEqual(pg_notag1, roundtrip(pg_notag1))

    def test_pg_lookup_with_tag(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="bla")
        pg_notag0, _ = new_subgroups(group_size=2)

        def roundtrip(pg, pg_tag):
            tag, rankset, _ = ft_c._expand_group(pg, pg_tag)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(pg_tag0, roundtrip(pg_tag1, "blu"))
        self.assertEqual(pg_tag0, roundtrip(pg_notag0, "blu"))
        # Cannot erase the tag of a PG
        self.assertEqual(pg_tag0, roundtrip(pg_tag0, ""))

    def test_find_or_create_pg(self):
        pg = c10d._find_or_create_pg_by_ranks_and_tag("blu", [0, 1, 2, 3], 2)
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(pg, pg_tag0)

    def test_find_root_pg(self):
        pg = c10d._find_pg_by_ranks_and_tag("", [0, 1, 2, 3])
        self.assertEqual(dist.group.WORLD, pg)

class TestTraceableCollectives(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_all_reduce_eager(self):
        tensor = torch.ones([4])
        mesh = dt.DeviceMesh("cpu", torch.arange(4))

        res = ft_c.all_reduce(tensor, "sum", mesh)
        self.assertEqual(res, torch.tensor([4, 4, 4, 4], dtype=torch.float))

        mesh = dt.DeviceMesh("cpu", torch.arange(4).view(2, 2))
        res2 = ft_c.all_reduce(tensor, "sum", (mesh, 1))
        self.assertEqual(res2, torch.tensor([2, 2, 2, 2], dtype=torch.float))


class TestMetaCollectives(TestCase):
    def test_all_reduce(self):
        x = torch.rand((2, 3, 4), device="meta")
        out = ft_c.all_reduce(x, "sum", [1])
        self.assertEqual(x.size(), out.size())


class TestGradCollectives(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_all_reduce(self):
        x = torch.rand([4], requires_grad=True)
        y = torch.rand([4], requires_grad=True)
        out = ft_c.all_reduce(x, "sum", [0, 1])
        (out + y).sum().backward()
        self.assertIsNone(x.grad)


class TestMakeFx(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_all_reduce_tracing(self):
        def allred(input):
            return ft_c.all_reduce(input, "sum", group=[0, 1]) + 1

        graph = make_fx(allred)(torch.rand(4))
        nodes = list(graph.graph.nodes)
        self.assertEqual("aten::all_reduce", nodes[1].target.name())
        self.assertEqual("aten::wait_tensor", nodes[2].target.name())


if __name__ == "__main__":
    run_tests()

View File

@ -575,6 +575,7 @@ aten::affine_grid_generator
aten::affine_grid_generator.out
aten::alias_copy
aten::alias_copy.out
aten::all_reduce
aten::allclose
aten::aminmax
aten::aminmax.out
@ -1339,6 +1340,7 @@ aten::view_copy
aten::view_copy.dtype
aten::view_copy.dtype_out
aten::view_copy.out
aten::wait_tensor
aten::zeros.names
aten::zeros.names_out
aten::zeros.out

View File

@ -2702,4 +2702,14 @@ def activate_meta():
            _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)


@register_meta(aten.all_reduce)
def all_reduce_meta(self, reduceOp, tag, rankset, stride):
    return torch.empty_like(self)


@register_meta(aten.wait_tensor)
def wait_tensor_meta(self):
    return torch.empty_like(self)


activate_meta()

View File

@ -0,0 +1,237 @@
from typing import Any, Tuple, Union, List, cast
import weakref
import warnings

import torch
import torch.distributed as dist
from torch._C import _disabled_torch_function_impl
from torch.utils._pytree import tree_map
import torch.distributed.distributed_c10d as c10d

"""
New traceable, functional collectives.
RFC: https://github.com/pytorch/pytorch/issues/93173

compiler: trace these ops with plain-old-data schemas, then choose how to lower them.
eager: execute these 'functional' ops which in eager return AsyncCollectiveTensor subclasses,
       automatically calling .wait() on underlying/hidden async 'work' obj only when fed to
       a downstream op.

Issues:
* Where should these ops live? Couldn't `import torch` if putting these ops in existing torch.distributed files.
* Proper support for eager requires inplace ops. We should explore having it as an option for the API.
"""
"""
Functional collectives are asynchronous only and we perform implicit stream synchronization
on behalf of the user.
We use AsyncCollectiveTensor to wrap the result tensor of a collective and it lets us witness
first usage of the tensor and insert cross stream sync at the right place.
The above are the easy bits, the hard one is how we match the Work object returned by
c10d and the tensor AsyncCollectiveTensor wraps. We alloc the tensor inside the collective
op implementation (see ``clone()`` call in ``_all_reduce``) and then it's handled by the
dispatcher which might call other implementations that are allowed to change the returned
tensor - even return a tensor with a different shape (see ``torch.vmap``).
This means the caller of our ops receives a Tensor that is not guaranteed to be the same
allocated by our implementations and that makes pairing The AsyncTensor to the original
tensor a lot harder. This pairing is needed so we can lookup the Work object to use.
Originally, we tried WeakKeyDictionary to map from Tensor to Work, but because Tensor's
identity is not stable across dispatch, the op caller would end up with a different Tensor
instance that would not match any in the dictionary.
With Tensor identity out of the question, we decided use the tensor data pointer, which
should be stable across all the Tensor changes done during dispatch.
We have a dictionary of tensor::data_ptr -> Work that we insert right after we call into c10d.
We use this dictionary when AsyncCollectiveTensor is used to invoke Work::wait()
Finally, we setup a finalizer against the tensor wrapper to observe it getting collected so we
can clean up stale entries in the dictionary.
To eliminate the possiblity of races we have a global version counter that is used by the finalizer.
As a wise man said once: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo)
"""

data_ptr_to_work = dict()
work_version = 0


def _register_tensor_work(tensor, work):
    global data_ptr_to_work
    global work_version
    data_ptr_to_work[tensor.data_ptr()] = (work_version, work)
    work_version += 1


def _clear_tensor(data_ptr, version):
    global data_ptr_to_work
    version_and_work = data_ptr_to_work.get(data_ptr)
    if version_and_work is not None and version_and_work[0] == version:
        del data_ptr_to_work[data_ptr]


def _register_wrapper_tensor(tensor_wrapper, tensor):
    global data_ptr_to_work
    version, _ = data_ptr_to_work.get(tensor.data_ptr(), (None, None))
    if version is None:
        warnings.warn("Trying to register finalizers to AsyncCollectiveTensor but the inner tensor is already gone")
    else:
        weakref.finalize(tensor_wrapper, _clear_tensor, tensor.data_ptr(), version)


def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor:
    global data_ptr_to_work
    data_ptr = tensor.data_ptr()
    version_and_work = data_ptr_to_work.get(data_ptr)
    if version_and_work is not None:
        version_and_work[1].wait()
        _clear_tensor(data_ptr, version_and_work[0])
    return tensor

class AsyncCollectiveTensor(torch.Tensor):
    r"""
    A Tensor subclass that is only used in eager mode, to hold a 'work' object
    and then wait on it before invoking a real op.

    Usage, from inside functional collective:

    def functional_collective(input):
        input = input.clone()
        mutated_input, work = c10d.{inplace_collective}(input)
        return AsyncCollectiveTensor(mutated_input, work)
    """
    _tensor: torch.Tensor

    __torch_function__ = _disabled_torch_function_impl

    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        t = tensor
        r = torch.Tensor._make_subclass(cls, t, require_grad=t.requires_grad)
        r._tensor = tensor  # type: ignore[attr-defined]
        return r

    def __repr__(self):
        return f"AsyncCollectiveTensor({self._tensor})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e: Any):
            if isinstance(e, AsyncCollectiveTensor):
                return wait_tensor(e._tensor)
            return e

        unwrapped_args = tree_map(unwrap, args)
        unwrapped_kwargs = tree_map(unwrap, kwargs)

        out = func(*unwrapped_args, **unwrapped_kwargs)
        return out

def _str_to_reduce_op(reduceOp: str) -> dist.ReduceOp:
    reduceOp = reduceOp.upper()
    op = dist.ReduceOp.RedOpType.__members__.get(reduceOp)
    if op is None:
        raise ValueError(f"Invalid reduce operation {reduceOp}")
    return cast(dist.ReduceOp, op)


# TODO assert if ranks has duplicated entries
def _all_reduce(self, reduceOp, tag, ranks, group_size):
    op = _str_to_reduce_op(reduceOp)
    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
    assert group is not None

    inplace_tensor = self.clone()
    work = dist.all_reduce(inplace_tensor, op=op, group=group, async_op=True)
    _register_tensor_work(inplace_tensor, work)

    return inplace_tensor


c10_lib_cpu = torch.library.Library("aten", "IMPL", "CPU")
c10_lib_cuda = torch.library.Library("aten", "IMPL", "CUDA")

c10_lib_cpu.impl("all_reduce", _all_reduce)
c10_lib_cuda.impl("all_reduce", _all_reduce)

c10_lib_cpu.impl("wait_tensor", _wait_tensor)
c10_lib_cuda.impl("wait_tensor", _wait_tensor)

RANK_TYPES = Union[List[int], List[List[int]], dist.ProcessGroup, "dist._tensor.DeviceMesh", Tuple["dist._tensor.DeviceMesh", int]]


def _expand_group(group: RANK_TYPES, tag: str = "") -> Tuple[str, List[int], int]:
    # Cannot import on the top level to avoid circular imports
    import torch.distributed._tensor as dt

    rankset: List[int]
    if isinstance(group, list):
        if isinstance(group[0], list):
            nested_list = cast(List[List[int]], group)
            rankset = []
            group_size = -1
            for rs in nested_list:
                rankset.extend(rs)
                if group_size != -1 and group_size != len(rs):
                    raise ValueError(f"group sizes must be identical found {group_size} and {len(rs)}")
                group_size = len(rs)
        else:
            rankset = cast(List[int], group)
            group_size = len(rankset)
    elif isinstance(group, dist.ProcessGroup):
        rankset = dist.get_process_group_ranks(group)
        group_size = len(rankset)
        tag = tag or c10d._get_group_tag(group)
    elif isinstance(group, dt.DeviceMesh):
        rankset = group.mesh.flatten().tolist()
        group_size = group.mesh.size(0)
        rankset = group.mesh.swapdims(-1, 0).reshape(-1, group_size).flatten().tolist()
        tag = tag or c10d._get_group_tag(group.get_dim_groups()[0])
    elif isinstance(group, tuple):
        if len(group) == 2 and isinstance(group[0], dt.DeviceMesh) and isinstance(group[1], int):
            dmesh = group[0]
            dim = group[1]
            group_size = dmesh.mesh.size(dim)
            rankset = dmesh.mesh.swapdims(-1, dim).reshape(-1, group_size).flatten().tolist()
            tag = tag or c10d._get_group_tag(dmesh.get_dim_groups()[dim])
        else:
            raise ValueError("Invalid tuple for group, must be (DeviceMesh, int)")
    else:
        raise ValueError("Invalid type for group, must be one of List, ProcessGroup, DeviceMesh or (DeviceMesh, int).")

    return (tag, rankset, group_size)

def wait_tensor(tensor):
    """
    Wait on a tensor returned by the collectives ops.

    Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
    """
    return torch._C._nn.wait_tensor(tensor)  # type: ignore[attr-defined]


def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result.

    The input tensor is left unmodified.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do an SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    tag, rankset, group_size = _expand_group(group, tag)
    tensor = torch._C._nn.all_reduce(self, reduceOp, tag, rankset, group_size)  # type: ignore[attr-defined]
    res = AsyncCollectiveTensor(tensor)
    _register_wrapper_tensor(res, tensor)
    return res

View File

@ -10,7 +10,7 @@ import time
import warnings
from collections import namedtuple
from datetime import timedelta
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union, List
import torch
from torch._C._distributed_c10d import (
@ -298,6 +298,8 @@ _pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}
# For a pg, it is a map from ProcessGroup to BackendConfig
_pg_backend_config: Dict[ProcessGroup, str] = {}
_group_count = 0
_tags_to_pg: Dict[str, List[ProcessGroup]] = {}
_pg_to_tag: Dict[ProcessGroup, str] = {}

class _World:
    """
@ -380,6 +382,15 @@ class _World:
        global _group_count
        _group_count = value

    @property
    def tags_to_pg(self) -> Dict[str, List[ProcessGroup]]:
        global _tags_to_pg
        return _tags_to_pg

    @property
    def pg_to_tag(self) -> Dict[ProcessGroup, str]:
        global _pg_to_tag
        return _pg_to_tag


_world = _World()
"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
@ -900,7 +911,7 @@ def init_process_group(
            store,
            pg_options=pg_options,
            group_name=group_name,
            timeout=timeout,
            timeout=timeout
        )
        _update_default_pg(default_pg)
@ -929,6 +940,7 @@ def _new_process_group_helper(
    pg_options=None,
    group_name=None,
    timeout=default_pg_timeout,
    pg_tag=None
):
    """
    Create a new distributed process group.
@ -956,6 +968,12 @@ def _new_process_group_helper(
"Expected timeout argument to be of type" "datetime.timedelta"
)
if pg_tag not in [None, ""]:
# creating with the same tag and rank set results in the same underlying PG
existing_group = _find_pg_by_ranks_and_tag(pg_tag, global_ranks_in_group)
if existing_group:
return existing_group
# The list of group ranks is empty if we're creating the default group.
is_default_group = len(global_ranks_in_group) == 0
@ -1084,8 +1102,16 @@ def _new_process_group_helper(
    _world.pg_map[pg] = (backend, prefix_store)
    _world.pg_names[pg] = group_name
    _world.pg_backend_config[pg] = str(backend_config)
    return pg
    # "" is the default tag for user PGs
    if pg_tag in [None, ""]:
        pg_tag = f"ptd:{group_name}"
        _world.tags_to_pg.setdefault("", []).append(pg)
    else:
        pg_tag = f"user:{pg_tag}"
        _world.tags_to_pg.setdefault(pg_tag, []).append(pg)

    _world.pg_to_tag[pg] = pg_tag
    return pg


def destroy_process_group(group: Optional[ProcessGroup] = None):
    """
@ -3460,7 +3486,15 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=N
    Returns:
        A handle of distributed group that can be given to collective calls.
    """
    return _new_group_with_tag(ranks, timeout, backend, pg_options)


def _new_group_with_tag(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=None, pg_tag=None):
    """
    This is a variant of ``new_group`` that exposes tag creation.

    :: N.B. The mechanism is experimental and tied to the functional collectives effort, see
    ``torch.distributed._functional_collectives`` for reference on how to use it.
    """
    global _world

    default_pg = _get_default_group()
@ -3510,6 +3544,7 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=N
        default_store,
        pg_options=pg_options,
        timeout=timeout,
        pg_tag=pg_tag
    )

    # Create the global rank to group rank mapping
@ -3767,3 +3802,53 @@ def new_subgroups_by_enumeration(
logger.info("Rank {} is assigned to subgroup {}".format(rank, ranks))
return cur_subgroup, subgroups
def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> ProcessGroup:
if len(tag) > 0 and not tag.startswith("ptd:") and not tag.startswith("user:"):
tag = f"user:{tag}"
for group in _world.tags_to_pg.get(tag, []):
if group.size() != len(ranks):
continue
group_ranks = get_process_group_ranks(group)
good = all(r in group_ranks for r in ranks)
if good:
return group
return None
def _find_or_create_pg_by_ranks_and_tag(tag: str, ranks: List[int], stride: int) -> ProcessGroup:
assert len(ranks) % stride == 0, f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
my_rank = get_rank()
my_ranks = None
if stride == len(ranks):
my_ranks = ranks.copy()
assert my_rank in my_ranks, "rankset doesn't include the current node"
else:
for i in range(0, len(ranks), stride):
rank_set = ranks[i : i + stride]
if my_rank in rank_set:
my_ranks = rank_set
assert my_ranks is not None, "rankset doesn't include the current node"
my_ranks.sort()
pg = _find_pg_by_ranks_and_tag(tag, my_ranks)
if pg is not None:
return pg
if tag == "":
raise ValueError("Cannot automatically create PG with empty tag")
# TODO copy settings and timeout from default PG
return _new_group_with_tag(my_ranks, pg_tag=tag)
def _get_group_tag(pg: ProcessGroup) -> str:
"""
Returns the tag associated with ``pg``.
"""
tag = _world.pg_to_tag[pg]
if tag.startswith("user:"):
tag = tag[5:]
return tag

View File

@ -1,7 +1,7 @@
import sys
import threading
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
@ -297,14 +297,15 @@ class WorldData:
    pg_group_ranks: Dict[dist.ProcessGroup, Dict[int, int]]
    pg_backend_config: Dict[dist.ProcessGroup, str]
    group_count: int
    tags_to_pg: Dict[str, List[dist.ProcessGroup]]
    pg_to_tag: Dict[dist.ProcessGroup, str]


class ThreadLocalWorld:
    _world = threading.local()

    def _get_world(self) -> WorldData:
        if not hasattr(ThreadLocalWorld._world, "world"):
            ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0)
            ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {})
        return ThreadLocalWorld._world.world

    @property
@ -339,6 +340,14 @@ class ThreadLocalWorld:
    def group_count(self, value):
        self._get_world().group_count = value

    @property
    def tags_to_pg(self):
        return self._get_world().tags_to_pg

    @property
    def pg_to_tag(self):
        return self._get_world().pg_to_tag


_old_pg_world = None