dist2: add support for passing custom configs directly to PG (#158147)
This makes it easier for users to provide backend-specific "hints" that configure certain backend options.

```py
import torch.distributed._dist2 as dist2

pg = dist2.new_group(
    backend="my_custom_backend",
    device=...,
    timeout=...,
    foo=1234,
    bar="1234",
)
pg.allreduce(...)
```

Test plan:

```
pytest test/distributed/test_dist2.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158147
Approved by: https://github.com/fduwjj
parent 7cf31b4a42
commit b7def5ff1c
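A minimal sketch of the new calling convention, assuming a CUDA/NCCL setup with the usual `env://` rendezvous variables set; `is_high_priority_stream` is used here only as an assumed example of a `ProcessGroupNCCL.Options` field and is not part of this PR:

```py
from datetime import timedelta

import torch
import torch.distributed._dist2 as dist2

pg = dist2.new_group(
    backend="nccl",
    timeout=timedelta(seconds=60),
    device=torch.device("cuda", 0),
    # Backend-specific hint: forwarded as a kwarg and applied to
    # ProcessGroupNCCL.Options via setattr in the NCCL factory.
    # `is_high_priority_stream` is assumed to be a valid Options field.
    is_high_priority_stream=True,
)

# Unknown options are rejected by the factory, e.g.:
#   dist2.new_group(backend="nccl", ..., not_a_real_option=1)  -> KeyError
```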
test/distributed/test_dist2.py

```diff
@@ -28,10 +28,14 @@ class ProcessGroupTest(TestCase):
         os.environ["MASTER_PORT"] = "29500"

         pg1 = dist2.new_group(
-            backend="gloo", timeout=timedelta(seconds=60), device="cpu", pg_options=None
+            backend="gloo",
+            timeout=timedelta(seconds=60),
+            device="cpu",
         )
         pg2 = dist2.new_group(
-            backend="gloo", timeout=timedelta(seconds=60), device="cpu", pg_options=None
+            backend="gloo",
+            timeout=timedelta(seconds=60),
+            device="cpu",
         )

         self.assertIsNone(dist2.current_process_group())
```
```diff
@@ -227,7 +231,6 @@ class ProcessGroupGlooTest(Dist2MultiProcessTestCase):
             backend="gloo",
             timeout=timedelta(seconds=60),
             device=self.device,
-            pg_options=None,
         )
```
```diff
@@ -242,15 +245,10 @@ class ProcessGroupNCCLTest(Dist2MultiProcessTestCase):
         self.device = torch.device("cuda", self.rank)

-        from torch.distributed import ProcessGroupNCCL
-
-        opts = ProcessGroupNCCL.Options()
-
         return dist2.new_group(
             backend="nccl",
             timeout=timedelta(seconds=60),
             device=self.device,
-            pg_options=opts,
         )
```
torch/distributed/_dist2.py

```diff
@@ -16,7 +16,6 @@ import torch
 from torch._C._distributed_c10d import (
     _current_process_group,
     _set_process_group,
-    Backend,
     ProcessGroup,
     ReduceOp,
     Store,
```
```diff
@@ -47,7 +46,7 @@ class ProcessGroupFactory(Protocol):
         world_size: int,
         timeout: timedelta,
         device: torch.device,
-        pg_options: Backend.Options,
+        **kwargs: object,
     ) -> ProcessGroup: ...
```
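With the protocol updated, a custom backend only needs a factory matching this signature that consumes whatever hints it understands. A minimal sketch; the factory body is a placeholder, and registering by writing into the module's internal `_BACKENDS` table (it appears in the dispatch hunk further down) is an illustrative assumption rather than a documented API:

```py
from datetime import timedelta

import torch
import torch.distributed._dist2 as dist2
from torch._C._distributed_c10d import ProcessGroup, Store


def my_backend_factory(
    store: Store,
    rank: int,
    world_size: int,
    timeout: timedelta,
    device: torch.device,
    **kwargs: object,
) -> ProcessGroup:
    # Backend-specific hints arrive here untouched, e.g. foo=1234, bar="1234"
    # from the commit message example above. Consume the known ones, reject the rest.
    foo = kwargs.pop("foo", None)  # hypothetical hint
    bar = kwargs.pop("bar", None)  # hypothetical hint
    if kwargs:
        raise KeyError(f"Unknown options: {sorted(kwargs)}")
    raise NotImplementedError(
        f"construct the ProcessGroup for rank {rank}/{world_size} "
        f"on {device} with foo={foo!r}, bar={bar!r}"
    )


# Illustrative registration via the module's internal _BACKENDS table;
# a public registration helper, if one exists, would be preferable.
dist2._BACKENDS["my_custom_backend"] = my_backend_factory
```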
```diff
@@ -71,11 +70,11 @@ def _gloo_factory(
     world_size: int,
     timeout: timedelta,
     device: torch.device,
-    pg_options: Backend.Options,
+    **kwargs: object,
 ) -> ProcessGroup:
     from torch.distributed import ProcessGroupGloo

-    assert pg_options is None, "Gloo backend does not support options"
+    assert len(kwargs) == 0, "Gloo backend received unexpected kwargs"

     backend_class = ProcessGroupGloo(store, rank, world_size, timeout)
     backend_class._set_sequence_number_for_group()
```
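With this change the Gloo backend rejects any extra kwargs up front. A small illustration, assuming a single-process `env://` rendezvous; the assertion message is taken from the hunk above and `some_hint` is a hypothetical option name:

```py
import os
from datetime import timedelta

import torch.distributed._dist2 as dist2

# Single-process env:// rendezvous, mirroring the test setup in the diff above.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

# Gloo accepts no backend-specific hints, so an extra kwarg is expected to
# fail with: AssertionError("Gloo backend received unexpected kwargs").
try:
    dist2.new_group(
        backend="gloo",
        timeout=timedelta(seconds=60),
        device="cpu",
        some_hint=True,  # hypothetical option, used only to trigger the check
    )
except AssertionError as exc:
    print(exc)
```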
```diff
@@ -101,15 +100,18 @@ def _nccl_factory(
     world_size: int,
     timeout: timedelta,
     device: torch.device,
-    pg_options: Backend.Options,
+    **kwargs: object,
 ) -> ProcessGroup:
     from torch.distributed import ProcessGroupNCCL

-    assert isinstance(pg_options, ProcessGroupNCCL.Options)
+    opts = ProcessGroupNCCL.Options()
+    opts._timeout = timeout
+    for k, v in kwargs.items():
+        if not hasattr(opts, k):
+            raise KeyError(f"Unknown option {k}")
+        setattr(opts, k, v)

-    pg_options._timeout = timeout
-    backend_class = ProcessGroupNCCL(store, rank, world_size, pg_options)
+    backend_class = ProcessGroupNCCL(store, rank, world_size, opts)
     backend_class._set_sequence_number_for_group()
     backend_class.eager_connect_single_device(device)
```
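The kwargs-to-Options mapping above is plain `hasattr`/`setattr` reflection. A standalone sketch of the same pattern; `_Opts` here is a stand-in for `ProcessGroupNCCL.Options` so the snippet runs without CUDA, and `is_high_priority_stream` is an assumed example field:

```py
class _Opts:
    # Stand-in for ProcessGroupNCCL.Options: a plain attribute bag.
    _timeout = None
    is_high_priority_stream = False  # assumed example field


def apply_hints(opts: _Opts, **kwargs: object) -> _Opts:
    # Mirrors the loop in _nccl_factory: reject unknown names, set known ones.
    for k, v in kwargs.items():
        if not hasattr(opts, k):
            raise KeyError(f"Unknown option {k}")
        setattr(opts, k, v)
    return opts


opts = apply_hints(_Opts(), is_high_priority_stream=True)
print(opts.is_high_priority_stream)  # True
# apply_hints(_Opts(), typo_option=1) raises KeyError("Unknown option typo_option")
```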
```diff
@@ -128,7 +130,7 @@ def new_group(
     backend: str,
     timeout: timedelta,
     device: Union[str, torch.device],
-    pg_options: Backend.Options,
+    **kwargs: object,
 ) -> ProcessGroup:
     """
     Create a new process group with the given backend and options. This group is
```
```diff
@@ -139,7 +141,8 @@ def new_group(
         backend: The backend to use for the process group.
         timeout: The timeout for collective operations.
         device: The device to use for the process group.
-        pg_options: The options to use for the process group.
+        **kwargs: All remaining arguments are passed to the backend constructor.
+            See the backend specific documentation for details.

     Returns:
         A new process group.
```
```diff
@@ -152,7 +155,7 @@ def new_group(
     store, rank, world_size = next(iter(rendezvous("env://")))
     store.set_timeout(timeout)

-    return _BACKENDS[backend](store, rank, world_size, timeout, device, pg_options)
+    return _BACKENDS[backend](store, rank, world_size, timeout, device, **kwargs)


 def current_process_group() -> ProcessGroup:
```