mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
dist2: add support for passing custom configs directly to PG (#158147)
This is intended to make it easier to have backend specific "hints" that can be provided by the user to hint about certain options.

```py
import torch.distributed._dist2 as dist2

pg = dist2.new_group(backend="my_custom_backend", device=..., timeout=..., foo=1234, bar="1234")
pg.allreduce(...)
```

Test plan:

```
pytest test/distributed/test_dist2.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158147
Approved by: https://github.com/fduwjj
This commit is contained in:
parent
7cf31b4a42
commit
b7def5ff1c
|
|
@ -28,10 +28,14 @@ class ProcessGroupTest(TestCase):
|
|||
os.environ["MASTER_PORT"] = "29500"
|
||||
|
||||
pg1 = dist2.new_group(
|
||||
backend="gloo", timeout=timedelta(seconds=60), device="cpu", pg_options=None
|
||||
backend="gloo",
|
||||
timeout=timedelta(seconds=60),
|
||||
device="cpu",
|
||||
)
|
||||
pg2 = dist2.new_group(
|
||||
backend="gloo", timeout=timedelta(seconds=60), device="cpu", pg_options=None
|
||||
backend="gloo",
|
||||
timeout=timedelta(seconds=60),
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
self.assertIsNone(dist2.current_process_group())
|
||||
|
|
@ -227,7 +231,6 @@ class ProcessGroupGlooTest(Dist2MultiProcessTestCase):
|
|||
backend="gloo",
|
||||
timeout=timedelta(seconds=60),
|
||||
device=self.device,
|
||||
pg_options=None,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -242,15 +245,10 @@ class ProcessGroupNCCLTest(Dist2MultiProcessTestCase):
|
|||
|
||||
self.device = torch.device("cuda", self.rank)
|
||||
|
||||
from torch.distributed import ProcessGroupNCCL
|
||||
|
||||
opts = ProcessGroupNCCL.Options()
|
||||
|
||||
return dist2.new_group(
|
||||
backend="nccl",
|
||||
timeout=timedelta(seconds=60),
|
||||
device=self.device,
|
||||
pg_options=opts,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ import torch
|
|||
from torch._C._distributed_c10d import (
|
||||
_current_process_group,
|
||||
_set_process_group,
|
||||
Backend,
|
||||
ProcessGroup,
|
||||
ReduceOp,
|
||||
Store,
|
||||
|
|
@ -47,7 +46,7 @@ class ProcessGroupFactory(Protocol):
|
|||
world_size: int,
|
||||
timeout: timedelta,
|
||||
device: torch.device,
|
||||
pg_options: Backend.Options,
|
||||
**kwargs: object,
|
||||
) -> ProcessGroup: ...
|
||||
|
||||
|
||||
|
|
@ -71,11 +70,11 @@ def _gloo_factory(
|
|||
world_size: int,
|
||||
timeout: timedelta,
|
||||
device: torch.device,
|
||||
pg_options: Backend.Options,
|
||||
**kwargs: object,
|
||||
) -> ProcessGroup:
|
||||
from torch.distributed import ProcessGroupGloo
|
||||
|
||||
assert pg_options is None, "Gloo backend does not support options"
|
||||
assert len(kwargs) == 0, "Gloo backend received unexpected kwargs"
|
||||
|
||||
backend_class = ProcessGroupGloo(store, rank, world_size, timeout)
|
||||
backend_class._set_sequence_number_for_group()
|
||||
|
|
@ -101,15 +100,18 @@ def _nccl_factory(
|
|||
world_size: int,
|
||||
timeout: timedelta,
|
||||
device: torch.device,
|
||||
pg_options: Backend.Options,
|
||||
**kwargs: object,
|
||||
) -> ProcessGroup:
|
||||
from torch.distributed import ProcessGroupNCCL
|
||||
|
||||
assert isinstance(pg_options, ProcessGroupNCCL.Options)
|
||||
opts = ProcessGroupNCCL.Options()
|
||||
opts._timeout = timeout
|
||||
for k, v in kwargs.items():
|
||||
if not hasattr(opts, k):
|
||||
raise KeyError(f"Unknown option {k}")
|
||||
setattr(opts, k, v)
|
||||
|
||||
pg_options._timeout = timeout
|
||||
|
||||
backend_class = ProcessGroupNCCL(store, rank, world_size, pg_options)
|
||||
backend_class = ProcessGroupNCCL(store, rank, world_size, opts)
|
||||
backend_class._set_sequence_number_for_group()
|
||||
backend_class.eager_connect_single_device(device)
|
||||
|
||||
|
|
@ -128,7 +130,7 @@ def new_group(
|
|||
backend: str,
|
||||
timeout: timedelta,
|
||||
device: Union[str, torch.device],
|
||||
pg_options: Backend.Options,
|
||||
**kwargs: object,
|
||||
) -> ProcessGroup:
|
||||
"""
|
||||
Create a new process group with the given backend and options. This group is
|
||||
|
|
@ -139,7 +141,8 @@ def new_group(
|
|||
backend: The backend to use for the process group.
|
||||
timeout: The timeout for collective operations.
|
||||
device: The device to use for the process group.
|
||||
pg_options: The options to use for the process group.
|
||||
**kwargs: All remaining arguments are passed to the backend constructor.
|
||||
See the backend specific documentation for details.
|
||||
|
||||
Returns:
|
||||
A new process group.
|
||||
|
|
@ -152,7 +155,7 @@ def new_group(
|
|||
store, rank, world_size = next(iter(rendezvous("env://")))
|
||||
store.set_timeout(timeout)
|
||||
|
||||
return _BACKENDS[backend](store, rank, world_size, timeout, device, pg_options)
|
||||
return _BACKENDS[backend](store, rank, world_size, timeout, device, **kwargs)
|
||||
|
||||
|
||||
def current_process_group() -> ProcessGroup:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user