dist2: add support for passing custom configs directly to PG (#158147)

This is intended to make it easier for users to provide backend-specific "hints" for certain options when creating a process group.

```py
import torch.distributed._dist2 as dist2

pg = dist2.new_group(backend="my_custom_backend", device=..., timeout=..., foo=1234, bar="1234")
pg.allreduce(...)
```
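
For a built-in backend such as NCCL, these kwargs map onto fields of `ProcessGroupNCCL.Options`. Below is a minimal sketch, assuming a working `env://` rendezvous (`MASTER_ADDR`, `MASTER_PORT`, `RANK`, `WORLD_SIZE` set) and an available CUDA device; `is_high_priority_stream` is an existing `Options` field, used here only for illustration:

```py
from datetime import timedelta

import torch
import torch.distributed._dist2 as dist2

# new_group rendezvouses via "env://", so MASTER_ADDR/MASTER_PORT/RANK/
# WORLD_SIZE must be set in the environment before this runs.
pg = dist2.new_group(
    backend="nccl",
    timeout=timedelta(seconds=60),
    device=torch.device("cuda", 0),
    is_high_priority_stream=True,  # forwarded onto ProcessGroupNCCL.Options
)
pg.allreduce([torch.ones(8, device="cuda")]).wait()
```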

Test plan:

```
pytest test/distributed/test_dist2.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158147
Approved by: https://github.com/fduwjj
Tristan Rice, 2025-07-15 00:02:50 +00:00, committed by PyTorch MergeBot
parent 7cf31b4a42
commit b7def5ff1c

2 changed files with 21 additions and 20 deletions

test/distributed/test_dist2.py:

```diff
@@ -28,10 +28,14 @@ class ProcessGroupTest(TestCase):
         os.environ["MASTER_PORT"] = "29500"
         pg1 = dist2.new_group(
-            backend="gloo", timeout=timedelta(seconds=60), device="cpu", pg_options=None
+            backend="gloo",
+            timeout=timedelta(seconds=60),
+            device="cpu",
         )
         pg2 = dist2.new_group(
-            backend="gloo", timeout=timedelta(seconds=60), device="cpu", pg_options=None
+            backend="gloo",
+            timeout=timedelta(seconds=60),
+            device="cpu",
         )

         self.assertIsNone(dist2.current_process_group())
@@ -227,7 +231,6 @@ class ProcessGroupGlooTest(Dist2MultiProcessTestCase):
             backend="gloo",
             timeout=timedelta(seconds=60),
             device=self.device,
-            pg_options=None,
         )
@@ -242,15 +245,10 @@ class ProcessGroupNCCLTest(Dist2MultiProcessTestCase):
         self.device = torch.device("cuda", self.rank)

-        from torch.distributed import ProcessGroupNCCL
-
-        opts = ProcessGroupNCCL.Options()
-
         return dist2.new_group(
             backend="nccl",
             timeout=timedelta(seconds=60),
             device=self.device,
-            pg_options=opts,
         )
```

torch/distributed/_dist2.py:

```diff
@@ -16,7 +16,6 @@ import torch
 from torch._C._distributed_c10d import (
     _current_process_group,
     _set_process_group,
-    Backend,
     ProcessGroup,
     ReduceOp,
     Store,
@@ -47,7 +46,7 @@ class ProcessGroupFactory(Protocol):
         world_size: int,
         timeout: timedelta,
         device: torch.device,
-        pg_options: Backend.Options,
+        **kwargs: object,
     ) -> ProcessGroup: ...
@@ -71,11 +70,11 @@ def _gloo_factory(
     world_size: int,
     timeout: timedelta,
     device: torch.device,
-    pg_options: Backend.Options,
+    **kwargs: object,
 ) -> ProcessGroup:
     from torch.distributed import ProcessGroupGloo

-    assert pg_options is None, "Gloo backend does not support options"
+    assert len(kwargs) == 0, "Gloo backend received unexpected kwargs"

     backend_class = ProcessGroupGloo(store, rank, world_size, timeout)
     backend_class._set_sequence_number_for_group()
@@ -101,15 +100,18 @@ def _nccl_factory(
     world_size: int,
     timeout: timedelta,
     device: torch.device,
-    pg_options: Backend.Options,
+    **kwargs: object,
 ) -> ProcessGroup:
     from torch.distributed import ProcessGroupNCCL

-    assert isinstance(pg_options, ProcessGroupNCCL.Options)
-
-    pg_options._timeout = timeout
-    backend_class = ProcessGroupNCCL(store, rank, world_size, pg_options)
+    opts = ProcessGroupNCCL.Options()
+    opts._timeout = timeout
+    for k, v in kwargs.items():
+        if not hasattr(opts, k):
+            raise KeyError(f"Unknown option {k}")
+        setattr(opts, k, v)
+
+    backend_class = ProcessGroupNCCL(store, rank, world_size, opts)
     backend_class._set_sequence_number_for_group()
     backend_class.eager_connect_single_device(device)
```
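
The loop in `_nccl_factory` validates option names eagerly: a misspelled hint raises `KeyError` rather than being silently dropped. A standalone illustration of the same pattern, assuming a CUDA/NCCL build of PyTorch:

```py
from torch.distributed import ProcessGroupNCCL

# Mirror of the factory's validation: hints become attributes on Options,
# and unknown names fail fast.
opts = ProcessGroupNCCL.Options()
for k, v in {"is_high_priority_stream": True}.items():
    if not hasattr(opts, k):
        raise KeyError(f"Unknown option {k}")
    setattr(opts, k, v)
assert opts.is_high_priority_stream
```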
```diff
@@ -128,7 +130,7 @@ def new_group(
     backend: str,
     timeout: timedelta,
     device: Union[str, torch.device],
-    pg_options: Backend.Options,
+    **kwargs: object,
 ) -> ProcessGroup:
     """
     Create a new process group with the given backend and options. This group is
@@ -139,7 +141,8 @@ def new_group(
         backend: The backend to use for the process group.
         timeout: The timeout for collective operations.
         device: The device to use for the process group.
-        pg_options: The options to use for the process group.
+        **kwargs: All remaining arguments are passed to the backend constructor.
+            See the backend specific documentation for details.

     Returns:
         A new process group.
@@ -152,7 +155,7 @@ def new_group(
     store, rank, world_size = next(iter(rendezvous("env://")))
     store.set_timeout(timeout)

-    return _BACKENDS[backend](store, rank, world_size, timeout, device, pg_options)
+    return _BACKENDS[backend](store, rank, world_size, timeout, device, **kwargs)


 def current_process_group() -> ProcessGroup:
```
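
Since `ProcessGroupFactory` now takes `**kwargs`, a custom backend (as in the usage example at the top) would accept and validate its own hints. A hedged sketch: the `_BACKENDS` dict registration is inferred from the dispatch in `new_group` and may not be the supported registration path, and the factory name and its `foo`/`bar` options are hypothetical:

```py
from datetime import timedelta

import torch
from torch._C._distributed_c10d import ProcessGroup, Store


def _my_custom_factory(
    store: Store,
    rank: int,
    world_size: int,
    timeout: timedelta,
    device: torch.device,
    **kwargs: object,
) -> ProcessGroup:
    # Hypothetical hints for this backend; reject anything else, mirroring
    # the NCCL factory's fail-fast validation.
    allowed = {"foo", "bar"}
    for k in kwargs:
        if k not in allowed:
            raise KeyError(f"Unknown option {k}")
    raise NotImplementedError("construct and return the backend's ProcessGroup here")
```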