Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[PTD] Introduce tracing friendly collectives. (#93990)

This change adds torch.distributed.traceable_collectives. This experimental API enables collectives to be fully traced by dynamo and FX. See #93173 for the RFC.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93990
Approved by: https://github.com/wconstab, https://github.com/wanchaol, https://github.com/H-Huang

parent d0fbed76c6
commit e22d791287
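For orientation, here is a minimal eager sketch of the API this commit introduces, adapted from the new test file test/distributed/test_functional_api.py further down. It assumes this patch is applied; the single-rank gloo setup and the address below are illustrative only and not part of the change.

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c

# Single-rank setup purely for illustration; the address/port is arbitrary.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

t = torch.ones(4)
# The group argument may be a rank list, a ProcessGroup, a DeviceMesh, or a
# (DeviceMesh, dim) tuple; a plain rank list keeps this sketch to one rank.
res = ft_c.all_reduce(t, "sum", [0])

# all_reduce is functional: `t` is left unmodified and an AsyncCollectiveTensor
# is returned; the implicit wait happens on the first downstream use of `res`.
print(res + 0)

dist.destroy_process_group()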
aten/src/ATen/native/Collectives.cpp (new file, 29 lines)
@@ -0,0 +1,29 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#endif

namespace at {
namespace native {

// Dummy impl required by codegen infra, not used
at::Tensor all_reduce(at::Tensor const& self, const c10::string_view reduceOp, const c10::string_view tag, c10::ArrayRef<int64_t> ranks, int64_t group_size) {
  // This should never get called
  // Defer to python impls in torch/distributed/_functional_collectives.py and _meta_registrations.py
  TORCH_INTERNAL_ASSERT(false);
}

at::Tensor wait_tensor(at::Tensor const& self) {
  // This should never get called
  // Defer to python impls in torch/distributed/_functional_collectives.py and _meta_registrations.py
  TORCH_INTERNAL_ASSERT(false);
}

} // namespace native
} // namespace at
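These kernels are deliberately unreachable: the schemas are declared in native_functions.yaml (next hunk) and the real CPU/CUDA behavior is registered from Python. A condensed sketch of that registration mechanism, assuming the schemas exist; the actual registrations live in torch/distributed/_functional_collectives.py, shown later in this diff.

import torch

# Hypothetical, condensed version of what _functional_collectives.py does:
# install Python callables as the CPU kernels for the codegen-declared ops.
def _py_all_reduce(self, reduceOp, tag, ranks, group_size):
    # The real implementation clones `self`, issues an async c10d all_reduce,
    # and records the returned Work object keyed by the clone's data_ptr.
    return self.clone()

def _py_wait_tensor(tensor):
    # The real implementation waits on the Work object paired with `tensor`.
    return tensor

_lib = torch.library.Library("aten", "IMPL", "CPU")
_lib.impl("all_reduce", _py_all_reduce)
_lib.impl("wait_tensor", _py_wait_tensor)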
aten/src/ATen/native/native_functions.yaml
@@ -14670,3 +14670,18 @@
  dispatch:
    CUDA: _fused_adamw_kernel_cuda_
  autogen: _fused_adamw, _fused_adamw.out

# Collectives
- func: all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor
  # This should be changed to distributed but it requires changes all over the place to work
  python_module: nn
  dispatch:
    CompositeExplicitAutograd: all_reduce
  variants: function

- func: wait_tensor(Tensor self) -> Tensor
  # This should be changed to distributed but it requires changes all over the place to work
  python_module: nn
  dispatch:
    CompositeExplicitAutograd: wait_tensor
  variants: function
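Because of `python_module: nn`, the generated Python bindings for these schemas land under `torch._C._nn` (the yaml comment notes that a distributed module would be the better home). Together with the meta kernels registered further down in this diff, the ops can be exercised shape-only, with no process group; a small sketch, assuming the patch is applied:

import torch

# Shape-only check on the meta device: the registered meta kernels just return
# torch.empty_like(self), so no communication or process group is needed.
x = torch.empty(2, 3, device="meta")
out = torch._C._nn.all_reduce(x, "sum", "", [0, 1], 2)
out = torch._C._nn.wait_tensor(out)
assert out.shape == x.shape and out.device.type == "meta"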
@@ -1231,6 +1231,7 @@ aten_native_source_non_codegen_list = [
    "aten/src/ATen/native/Bucketization.cpp",
    "aten/src/ATen/native/CPUBlas.cpp",
    "aten/src/ATen/native/ChanelShuffle.cpp",
    "aten/src/ATen/native/Collectives.cpp",
    "aten/src/ATen/native/Col2Im.cpp",
    "aten/src/ATen/native/PadNd.cpp",
    "aten/src/ATen/native/Convolution.cpp",
test/distributed/test_functional_api.py (new file, 269 lines)
@@ -0,0 +1,269 @@
# Owner(s): ["oncall: distributed"]

import sys
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c
import torch.distributed.distributed_c10d as c10d
import torch.distributed._tensor as dt

from functorch import make_fx

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import (
    MultiThreadedTestCase,
)
from torch.testing._internal.common_utils import (
    run_tests,
    TestCase
)


def new_subgroups(group_size: int, pg_tag=None):
    world_size = dist.get_world_size()
    subgroups = []
    cur_subgroup = None

    for subgroup_id in range(world_size // group_size):
        start_rank = subgroup_id * group_size
        end_rank = start_rank + group_size
        ranks_in_subgroup = list(range(start_rank, end_rank))
        subgroup = c10d._new_group_with_tag(
            ranks=ranks_in_subgroup,
            pg_tag=pg_tag,
        )
        subgroups.append(subgroup)

        rank = dist.get_rank()
        if rank in ranks_in_subgroup:
            cur_subgroup = subgroup

    return cur_subgroup, subgroups


class TestExpand(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_expand_1d_rank_list(self):
        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3], "bla")
        self.assertEqual("bla", tag)

    def test_expand_2d_rank_list(self):
        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]], "blu")
        self.assertEqual("blu", tag)

        with self.assertRaisesRegex(ValueError, "group sizes must be identical"):
            ft_c._expand_group([[0], [1, 2, 3]])

    def test_expand_process_group(self):
        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD)
        self.assertEqual(c10d._get_group_tag(dist.group.WORLD), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD, "bla")
        self.assertEqual("bla", tag)

        my_pg, others = new_subgroups(group_size=2)
        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual(c10d._get_group_tag(my_pg), tag)
        self.assertEqual(dist.get_process_group_ranks(my_pg), rankset)
        self.assertEqual(2, group_size)

        my_pg = None
        for i in range(dist.get_world_size()):
            group = c10d._new_group_with_tag([i], pg_tag="my_pg")
            if i == dist.get_rank():
                my_pg = group
        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual("my_pg", tag)
        self.assertEqual([dist.get_rank()], rankset)
        self.assertEqual(1, group_size)

        tag, rankset, group_size = ft_c._expand_group(my_pg, "bla")
        self.assertEqual("bla", tag)

    def test_expand_device_mesh(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

    def test_expand_device_mesh_tuple(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4).view(2, 2))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
        self.assertEqual([0, 2, 1, 3], rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group((mesh, 0))
        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
        self.assertEqual([0, 2, 1, 3], rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group((mesh, 1))
        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[1]), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(2, group_size)


class TestPgTag(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

"""
|
||||
The behavior we want is as follow:
|
||||
|
||||
- rankset+tag will always result in the same PG.
|
||||
Do we enforce this by failing creation of new PGs or returning existing ones?
|
||||
Return existing one.
|
||||
|
||||
- default tag gives existing behavior.
|
||||
This means we should create duplicates.
|
||||
- _expand_group on _default-tagged pg should always resolve to it
|
||||
This mean we can't depend on empty tag + rankset.
|
||||
"""
|
||||
    def test_pg_creation_with_tag(self):
        my_group, _ = new_subgroups(group_size=2, pg_tag="blu")
        my_group2, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(my_group, my_group2)

        my_group3, _ = new_subgroups(group_size=2, pg_tag="blu2")
        self.assertNotEqual(my_group, my_group3)

        my_group4, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group, my_group4)

        my_group5, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group4, my_group5)

    def test_pg_lookup_roundtrip(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="blu2")
        pg_notag0, _ = new_subgroups(group_size=2)
        pg_notag1, _ = new_subgroups(group_size=2)

        def roundtrip(pg):
            tag, rankset, _ = ft_c._expand_group(pg)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(pg_tag0, roundtrip(pg_tag0))
        self.assertEqual(pg_tag1, roundtrip(pg_tag1))
        self.assertEqual(pg_notag0, roundtrip(pg_notag0))
        self.assertEqual(pg_notag1, roundtrip(pg_notag1))

    def test_pg_lookup_with_tag(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="bla")
        pg_notag0, _ = new_subgroups(group_size=2)

        def roundtrip(pg, pg_tag):
            tag, rankset, _ = ft_c._expand_group(pg, pg_tag)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(pg_tag0, roundtrip(pg_tag1, "blu"))
        self.assertEqual(pg_tag0, roundtrip(pg_notag0, "blu"))
        # Cannot erase the tag of a PG
        self.assertEqual(pg_tag0, roundtrip(pg_tag0, ""))

    def test_find_or_create_pg(self):
        pg = c10d._find_or_create_pg_by_ranks_and_tag("blu", [0, 1, 2, 3], 2)
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(pg, pg_tag0)

    def test_find_root_pg(self):
        pg = c10d._find_pg_by_ranks_and_tag("", [0, 1, 2, 3])
        self.assertEqual(dist.group.WORLD, pg)


class TestTraceableCollectives(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_all_reduce_eager(self):
        tensor = torch.ones([4])
        mesh = dt.DeviceMesh("cpu", torch.arange(4))

        res = ft_c.all_reduce(tensor, "sum", mesh)
        self.assertEqual(res, torch.tensor([4, 4, 4, 4], dtype=torch.float))

        mesh = dt.DeviceMesh("cpu", torch.arange(4).view(2, 2))
        res2 = ft_c.all_reduce(tensor, "sum", (mesh, 1))
        self.assertEqual(res2, torch.tensor([2, 2, 2, 2], dtype=torch.float))


class TestMetaCollectives(TestCase):
    def test_all_reduce(self):
        x = torch.rand((2, 3, 4), device="meta")
        out = ft_c.all_reduce(x, "sum", [1])
        self.assertEqual(x.size(), out.size())


class TestGradCollectives(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_all_reduce(self):
        x = torch.rand([4], requires_grad=True)
        y = torch.rand([4], requires_grad=True)
        out = ft_c.all_reduce(x, "sum", [0, 1])
        (out + y).sum().backward()
        self.assertIsNone(x.grad)


class TestMakeFx(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_all_reduce_tracing(self):
        def allred(input):
            return ft_c.all_reduce(input, "sum", group=[0, 1]) + 1

        graph = make_fx(allred)(torch.rand(4))
        nodes = list(graph.graph.nodes)

        self.assertEqual("aten::all_reduce", nodes[1].target.name())
        self.assertEqual("aten::wait_tensor", nodes[2].target.name())


if __name__ == "__main__":
    run_tests()
@@ -575,6 +575,7 @@ aten::affine_grid_generator
aten::affine_grid_generator.out
aten::alias_copy
aten::alias_copy.out
aten::all_reduce
aten::allclose
aten::aminmax
aten::aminmax.out

@@ -1339,6 +1340,7 @@ aten::view_copy
aten::view_copy.dtype
aten::view_copy.dtype_out
aten::view_copy.out
aten::wait_tensor
aten::zeros.names
aten::zeros.names_out
aten::zeros.out
torch/_meta_registrations.py
@@ -2702,4 +2702,14 @@ def activate_meta():
        _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)


@register_meta(aten.all_reduce)
def all_reduce_meta(self, reduceOp, tag, rankset, stride):
    return torch.empty_like(self)


@register_meta(aten.wait_tensor)
def wait_tensor_meta(self):
    return torch.empty_like(self)


activate_meta()
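These meta kernels are what let the collective stay a node in a traced graph: the tracer only needs output shapes, not communication. A sketch of that flow, adapted from TestMakeFx in the new test file; it assumes this patch and traces on the meta device so no ranks need to be spawned (the test itself runs under MultiThreadedTestCase with two ranks and real tensors).

import torch
import torch.distributed._functional_collectives as ft_c
from functorch import make_fx

def allred(x):
    return ft_c.all_reduce(x, "sum", group=[0, 1]) + 1

# Tracing on the meta device: the meta kernels above supply output shapes,
# so no process group is required and aten::all_reduce / aten::wait_tensor
# appear as ordinary nodes in the resulting FX graph.
g = make_fx(allred)(torch.empty(4, device="meta"))
print([node.target for node in g.graph.nodes])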
torch/distributed/_functional_collectives.py (new file, 237 lines)
@@ -0,0 +1,237 @@
from typing import Any, Tuple, Union, List, cast

import weakref
import warnings

import torch
import torch.distributed as dist

from torch._C import _disabled_torch_function_impl
from torch.utils._pytree import tree_map

import torch.distributed.distributed_c10d as c10d
"""
New traceable, functional collectives.
RFC: https://github.com/pytorch/pytorch/issues/93173

compiler: trace these ops with plain-old-data schemas, then choose how to lower them.
eager: execute these 'functional' ops, which in eager return AsyncCollectiveTensor subclasses,
       automatically calling .wait() on the underlying/hidden async 'work' obj only when fed to
       a downstream op.

Issues:
* Where should these ops live? We couldn't `import torch` when putting these ops in existing torch.distributed files.
* Proper support for eager requires inplace ops. We should explore having it as an option for the API.
"""

"""
Functional collectives are asynchronous only, and we perform implicit stream synchronization
on behalf of the user.

We use AsyncCollectiveTensor to wrap the result tensor of a collective; it lets us witness
the first usage of the tensor and insert the cross-stream sync at the right place.

The above are the easy bits; the hard one is how we match the Work object returned by
c10d and the tensor AsyncCollectiveTensor wraps. We alloc the tensor inside the collective
op implementation (see the ``clone()`` call in ``_all_reduce``) and then it's handled by the
dispatcher, which might call other implementations that are allowed to change the returned
tensor - even return a tensor with a different shape (see ``torch.vmap``).

This means the caller of our ops receives a Tensor that is not guaranteed to be the same one
allocated by our implementations, and that makes pairing the AsyncCollectiveTensor to the original
tensor a lot harder. This pairing is needed so we can look up the Work object to use.

Originally, we tried a WeakKeyDictionary to map from Tensor to Work, but because a Tensor's
identity is not stable across dispatch, the op caller would end up with a different Tensor
instance that would not match any in the dictionary.

With Tensor identity out of the question, we decided to use the tensor data pointer, which
should be stable across all the Tensor changes done during dispatch.

We have a dictionary of tensor::data_ptr -> Work that we insert right after we call into c10d.

We use this dictionary when AsyncCollectiveTensor is used to invoke Work::wait().

Finally, we set up a finalizer against the tensor wrapper to observe it getting collected, so we
can clean up stale entries in the dictionary.

To eliminate the possibility of races, we have a global version counter that is used by the finalizer.

As a wise man once said: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo)

"""
data_ptr_to_work = dict()
work_version = 0

def _register_tensor_work(tensor, work):
    global data_ptr_to_work
    global work_version
    data_ptr_to_work[tensor.data_ptr()] = (work_version, work)
    work_version += 1

def _clear_tensor(data_ptr, version):
    global data_ptr_to_work
    version_and_work = data_ptr_to_work.get(data_ptr)

    if version_and_work is not None and version_and_work[0] == version:
        del data_ptr_to_work[data_ptr]

def _register_wrapper_tensor(tensor_wrapper, tensor):
    global data_ptr_to_work
    version, _ = data_ptr_to_work.get(tensor.data_ptr(), (None, None))
    if version is None:
        warnings.warn("Trying to register finalizers to AsyncCollectiveTensor but the inner tensor is already gone")
    else:
        weakref.finalize(tensor_wrapper, _clear_tensor, tensor.data_ptr(), version)

def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor:
    global data_ptr_to_work
    data_ptr = tensor.data_ptr()
    version_and_work = data_ptr_to_work.get(data_ptr)
    if version_and_work is not None:
        version_and_work[1].wait()
        _clear_tensor(data_ptr, version_and_work[0])
    return tensor


class AsyncCollectiveTensor(torch.Tensor):
    r"""
    A Tensor subclass that is only used in eager mode, to hold a 'work' object
    and then wait on it before invoking a real op.

    Usage, from inside functional collective:
    def functional_collective(input):
        input = input.clone()
        mutated_input, work = c10d.{inplace_collective}(input)
        return AsyncCollectiveTensor(mutated_input, work)
    """
    _tensor: torch.Tensor

    __torch_function__ = _disabled_torch_function_impl

    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        t = tensor
        r = torch.Tensor._make_subclass(cls, t, require_grad=t.requires_grad)
        r._tensor = tensor  # type: ignore[attr-defined]
        return r

    def __repr__(self):
        return f"AsyncCollectiveTensor({self._tensor})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e: Any):
            if isinstance(e, AsyncCollectiveTensor):
                return wait_tensor(e._tensor)
            return e

        unwrapped_args = tree_map(unwrap, args)
        unwrapped_kwargs = tree_map(unwrap, kwargs)

        out = func(*unwrapped_args, **unwrapped_kwargs)
        return out

def _str_to_reduce_op(reduceOp: str) -> dist.ReduceOp:
    reduceOp = reduceOp.upper()
    op = dist.ReduceOp.RedOpType.__members__.get(reduceOp)
    if op is None:
        raise ValueError(f"Invalid reduce operation {reduceOp}")
    return cast(dist.ReduceOp, op)

# TODO assert if ranks has duplicated entries
def _all_reduce(self, reduceOp, tag, ranks, group_size):
    op = _str_to_reduce_op(reduceOp)
    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
    assert group is not None

    inplace_tensor = self.clone()
    work = dist.all_reduce(inplace_tensor, op=op, group=group, async_op=True)
    _register_tensor_work(inplace_tensor, work)

    return inplace_tensor

c10_lib_cpu = torch.library.Library("aten", "IMPL", "CPU")
c10_lib_cuda = torch.library.Library("aten", "IMPL", "CUDA")

c10_lib_cpu.impl("all_reduce", _all_reduce)
c10_lib_cuda.impl("all_reduce", _all_reduce)

c10_lib_cpu.impl("wait_tensor", _wait_tensor)
c10_lib_cuda.impl("wait_tensor", _wait_tensor)


RANK_TYPES = Union[List[int], List[List[int]], dist.ProcessGroup, "dist._tensor.DeviceMesh", Tuple["dist._tensor.DeviceMesh", int]]

def _expand_group(group: RANK_TYPES, tag: str = "") -> Tuple[str, List[int], int]:
    # Cannot import on the top level to avoid circular imports
    import torch.distributed._tensor as dt
    rankset: List[int]
    if isinstance(group, list):
        if isinstance(group[0], list):
            nested_list = cast(List[List[int]], group)
            rankset = []
            group_size = -1
            for rs in nested_list:
                rankset.extend(rs)
                if group_size != -1 and group_size != len(rs):
                    raise ValueError(f"group sizes must be identical found {group_size} and {len(rs)}")
                group_size = len(rs)
        else:
            rankset = cast(List[int], group)
            group_size = len(rankset)
    elif isinstance(group, dist.ProcessGroup):
        rankset = dist.get_process_group_ranks(group)
        group_size = len(rankset)
        tag = tag or c10d._get_group_tag(group)
    elif isinstance(group, dt.DeviceMesh):
        rankset = group.mesh.flatten().tolist()
        group_size = group.mesh.size(0)
        rankset = group.mesh.swapdims(-1, 0).reshape(-1, group_size).flatten().tolist()
        tag = tag or c10d._get_group_tag(group.get_dim_groups()[0])
    elif isinstance(group, tuple):
        if len(group) == 2 and isinstance(group[0], dt.DeviceMesh) and isinstance(group[1], int):
            dmesh = group[0]
            dim = group[1]
            group_size = dmesh.mesh.size(dim)
            rankset = dmesh.mesh.swapdims(-1, dim).reshape(-1, group_size).flatten().tolist()
            tag = tag or c10d._get_group_tag(dmesh.get_dim_groups()[dim])
        else:
            raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
    else:
        raise ValueError("Invalid type for group, must be one of List, Processgroup, DeviceMesh or (DeviceMesh, int).")

    return (tag, rankset, group_size)


def wait_tensor(tensor):
    """
    Wait on a tensor returned by the collectives ops.

    Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
    """
    return torch._C._nn.wait_tensor(tensor)  # type: ignore[attr-defined]


def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result.

    The input tensor is left unmodified.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: will perform a collective using the ranks and tag of the PG.
        DeviceMesh: do an SPMD collective over all ranks of the mesh.
        (DeviceMesh, int): do an MPMD collective over one dimension of the DeviceMesh.

    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    tag, rankset, group_size = _expand_group(group, tag)
    tensor = torch._C._nn.all_reduce(self, reduceOp, tag, rankset, group_size)  # type: ignore[attr-defined]
    res = AsyncCollectiveTensor(tensor)
    _register_wrapper_tensor(res, tensor)
    return res
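The module docstring above motivates the global version counter used by the finalizer: once a tensor is freed, a new tensor can be allocated at the same data_ptr, and a stale finalizer keyed on the address alone would delete the new entry. A toy model of that bookkeeping, using a plain dict and strings instead of real tensors and Work objects, just to illustrate the race the version guard prevents:

# Simplified, hypothetical model of data_ptr_to_work: entries are keyed by
# address and stamped with a version so stale finalizers become no-ops.
table = {}
version_counter = 0

def register(ptr, work):
    global version_counter
    table[ptr] = (version_counter, work)
    version_counter += 1

def finalize(ptr, version):
    # Mirrors _clear_tensor: only delete if the entry still belongs to us.
    entry = table.get(ptr)
    if entry is not None and entry[0] == version:
        del table[ptr]

register(0x1000, "work_A")   # tensor A registered with version 0
finalize(0x1000, 0)          # A's wrapper collected: entry removed
register(0x1000, "work_B")   # B reuses the same address, gets version 1
finalize(0x1000, 0)          # late, stale finalizer for A: a no-op
assert table[0x1000] == (1, "work_B")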
torch/distributed/distributed_c10d.py
@@ -10,7 +10,7 @@ import time
import warnings
from collections import namedtuple
from datetime import timedelta
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union, List

import torch
from torch._C._distributed_c10d import (

@@ -298,6 +298,8 @@ _pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}
# For a pg, it is a map from ProcessGroup to BackendConfig
_pg_backend_config: Dict[ProcessGroup, str] = {}
_group_count = 0
_tags_to_pg: Dict[str, List[ProcessGroup]] = {}
_pg_to_tag: Dict[ProcessGroup, str] = {}

class _World:
    """

@@ -380,6 +382,15 @@ class _World:
        global _group_count
        _group_count = value

    @property
    def tags_to_pg(self) -> Dict[str, List[ProcessGroup]]:
        global _tags_to_pg
        return _tags_to_pg

    @property
    def pg_to_tag(self) -> Dict[ProcessGroup, str]:
        global _pg_to_tag
        return _pg_to_tag

_world = _World()
"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""

@@ -900,7 +911,7 @@ def init_process_group(
            store,
            pg_options=pg_options,
            group_name=group_name,
-           timeout=timeout,
+           timeout=timeout
        )
        _update_default_pg(default_pg)

@@ -929,6 +940,7 @@ def _new_process_group_helper(
    pg_options=None,
    group_name=None,
    timeout=default_pg_timeout,
    pg_tag=None
):
    """
    Create a new distributed process group.

@@ -956,6 +968,12 @@ def _new_process_group_helper(
            "Expected timeout argument to be of type" "datetime.timedelta"
        )

    if pg_tag not in [None, ""]:
        # creating with the same tag and rank set results in the same underlying PG
        existing_group = _find_pg_by_ranks_and_tag(pg_tag, global_ranks_in_group)
        if existing_group:
            return existing_group

    # The list of group ranks is empty if we're creating the default group.
    is_default_group = len(global_ranks_in_group) == 0

@@ -1084,8 +1102,16 @@ def _new_process_group_helper(
    _world.pg_map[pg] = (backend, prefix_store)
    _world.pg_names[pg] = group_name
    _world.pg_backend_config[pg] = str(backend_config)
-   return pg
+   # "" is the default tag for user PGs
+   if pg_tag in [None, ""]:
+       pg_tag = f"ptd:{group_name}"
+       _world.tags_to_pg.setdefault("", []).append(pg)
+   else:
+       pg_tag = f"user:{pg_tag}"
+
+   _world.tags_to_pg.setdefault(pg_tag, []).append(pg)
+   _world.pg_to_tag[pg] = pg_tag
+   return pg

def destroy_process_group(group: Optional[ProcessGroup] = None):
    """

@@ -3460,7 +3486,15 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=N
    Returns:
        A handle of distributed group that can be given to collective calls.
    """
    return _new_group_with_tag(ranks, timeout, backend, pg_options)

def _new_group_with_tag(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=None, pg_tag=None):
    """
    This is a variant of ``new_group`` that exposes tag creation.

    :: N.B. The mechanism is experimental and tied to the functional collectives effort, see
    ``torch.distributed._functional_collectives`` for reference on how to use it.
    """
    global _world

    default_pg = _get_default_group()

@@ -3510,6 +3544,7 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=N
        default_store,
        pg_options=pg_options,
        timeout=timeout,
        pg_tag=pg_tag
    )

    # Create the global rank to group rank mapping

@@ -3767,3 +3802,53 @@ def new_subgroups_by_enumeration(
            logger.info("Rank {} is assigned to subgroup {}".format(rank, ranks))

    return cur_subgroup, subgroups


def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> ProcessGroup:
    if len(tag) > 0 and not tag.startswith("ptd:") and not tag.startswith("user:"):
        tag = f"user:{tag}"

    for group in _world.tags_to_pg.get(tag, []):
        if group.size() != len(ranks):
            continue

        group_ranks = get_process_group_ranks(group)
        good = all(r in group_ranks for r in ranks)
        if good:
            return group
    return None

def _find_or_create_pg_by_ranks_and_tag(tag: str, ranks: List[int], stride: int) -> ProcessGroup:
    assert len(ranks) % stride == 0, f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"

    my_rank = get_rank()
    my_ranks = None

    if stride == len(ranks):
        my_ranks = ranks.copy()
        assert my_rank in my_ranks, "rankset doesn't include the current node"
    else:
        for i in range(0, len(ranks), stride):
            rank_set = ranks[i : i + stride]
            if my_rank in rank_set:
                my_ranks = rank_set
        assert my_ranks is not None, "rankset doesn't include the current node"

    my_ranks.sort()

    pg = _find_pg_by_ranks_and_tag(tag, my_ranks)
    if pg is not None:
        return pg
    if tag == "":
        raise ValueError("Cannot automatically create PG with empty tag")
    # TODO copy settings and timeout from default PG
    return _new_group_with_tag(my_ranks, pg_tag=tag)

def _get_group_tag(pg: ProcessGroup) -> str:
    """
    Returns the tag associated with ``pg``.
    """
    tag = _world.pg_to_tag[pg]
    if tag.startswith("user:"):
        tag = tag[5:]
    return tag
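As a quick illustration of the tag plumbing added above, a sketch assuming this patch and a single-rank gloo setup: a group created with a tag can be recovered later from (tag, ranks) alone, which is exactly what the functional all_reduce implementation relies on (compare test_pg_lookup_roundtrip in the new test file).

import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d

# Single-rank setup purely for illustration; the address/port is arbitrary.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

pg = c10d._new_group_with_tag([0], pg_tag="blu")
assert c10d._get_group_tag(pg) == "blu"
assert c10d._find_pg_by_ranks_and_tag("blu", [0]) == pg
# An empty tag resolves to the default (WORLD) group for the full rank set.
assert c10d._find_pg_by_ranks_and_tag("", [0]) == dist.group.WORLD

dist.destroy_process_group()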
@@ -1,7 +1,7 @@
import sys
import threading
from dataclasses import dataclass
-from typing import Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple

import torch
import torch.distributed as dist

@@ -297,14 +297,15 @@ class WorldData:
    pg_group_ranks: Dict[dist.ProcessGroup, Dict[int, int]]
    pg_backend_config: Dict[dist.ProcessGroup, str]
    group_count: int
    tags_to_pg: Dict[str, List[dist.ProcessGroup]]
    pg_to_tag: Dict[dist.ProcessGroup, str]

class ThreadLocalWorld:
    _world = threading.local()

    def _get_world(self) -> WorldData:
        if not hasattr(ThreadLocalWorld._world, "world"):
-           ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0)
+           ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {})
        return ThreadLocalWorld._world.world

    @property

@@ -339,6 +340,14 @@ class ThreadLocalWorld:
    def group_count(self, value):
        self._get_world().group_count = value

    @property
    def tags_to_pg(self):
        return self._get_world().tags_to_pg

    @property
    def pg_to_tag(self):
        return self._get_world().pg_to_tag


_old_pg_world = None