dist2: add support for passing custom configs directly to PG (#158147)

This is intended to make it easier for users to provide backend-specific "hints": any extra keyword arguments passed to `new_group` are forwarded directly to the backend's options.

```py
import torch.distributed._dist2 as dist2

pg = dist2.new_group(backend="my_custom_backend", device=..., timeout=..., foo=1234, bar="1234")
pg.allreduce(...)
```

Test plan:

```
pytest test/distributed/test_dist2.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158147
Approved by: https://github.com/fduwjj
Commit: b7def5ff1c (parent 7cf31b4a42)
Author: Tristan Rice
Date: 2025-07-15 00:02:50 +00:00
Committed by: PyTorch MergeBot
2 changed files with 21 additions and 20 deletions

test/distributed/test_dist2.py

```diff
@@ -28,10 +28,14 @@ class ProcessGroupTest(TestCase):
         os.environ["MASTER_PORT"] = "29500"
 
         pg1 = dist2.new_group(
-            backend="gloo", timeout=timedelta(seconds=60), device="cpu", pg_options=None
+            backend="gloo",
+            timeout=timedelta(seconds=60),
+            device="cpu",
         )
         pg2 = dist2.new_group(
-            backend="gloo", timeout=timedelta(seconds=60), device="cpu", pg_options=None
+            backend="gloo",
+            timeout=timedelta(seconds=60),
+            device="cpu",
         )
 
         self.assertIsNone(dist2.current_process_group())
@@ -227,7 +231,6 @@ class ProcessGroupGlooTest(Dist2MultiProcessTestCase):
             backend="gloo",
             timeout=timedelta(seconds=60),
             device=self.device,
-            pg_options=None,
         )
@@ -242,15 +245,10 @@ class ProcessGroupNCCLTest(Dist2MultiProcessTestCase):
         self.device = torch.device("cuda", self.rank)
 
-        from torch.distributed import ProcessGroupNCCL
-
-        opts = ProcessGroupNCCL.Options()
-
         return dist2.new_group(
             backend="nccl",
             timeout=timedelta(seconds=60),
             device=self.device,
-            pg_options=opts,
         )
```

torch/distributed/_dist2.py

```diff
@@ -16,7 +16,6 @@ import torch
 from torch._C._distributed_c10d import (
     _current_process_group,
     _set_process_group,
-    Backend,
     ProcessGroup,
     ReduceOp,
     Store,
@@ -47,7 +46,7 @@ class ProcessGroupFactory(Protocol):
         world_size: int,
         timeout: timedelta,
         device: torch.device,
-        pg_options: Backend.Options,
+        **kwargs: object,
     ) -> ProcessGroup: ...
@@ -71,11 +70,11 @@ def _gloo_factory(
     world_size: int,
     timeout: timedelta,
     device: torch.device,
-    pg_options: Backend.Options,
+    **kwargs: object,
 ) -> ProcessGroup:
     from torch.distributed import ProcessGroupGloo
 
-    assert pg_options is None, "Gloo backend does not support options"
+    assert len(kwargs) == 0, "Gloo backend received unexpected kwargs"
 
     backend_class = ProcessGroupGloo(store, rank, world_size, timeout)
     backend_class._set_sequence_number_for_group()
```
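Since `ProcessGroupGloo` is constructed from the store, ranks, and timeout alone, the Gloo factory now asserts that no extra kwargs slipped through. A minimal single-rank sketch of the resulting behavior (the env:// rendezvous values are illustrative):

```py
import os
from datetime import timedelta

import torch.distributed._dist2 as dist2

# Illustrative single-rank env:// rendezvous settings.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"

# Fine: no extra kwargs reach _gloo_factory.
pg = dist2.new_group(backend="gloo", timeout=timedelta(seconds=60), device="cpu")

# Any extra kwarg trips the `len(kwargs) == 0` assertion above.
try:
    dist2.new_group(backend="gloo", timeout=timedelta(seconds=60), device="cpu", foo=1)
except AssertionError as err:
    print(err)  # Gloo backend received unexpected kwargs
```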
```diff
@@ -101,15 +100,18 @@ def _nccl_factory(
     world_size: int,
     timeout: timedelta,
     device: torch.device,
-    pg_options: Backend.Options,
+    **kwargs: object,
 ) -> ProcessGroup:
     from torch.distributed import ProcessGroupNCCL
 
-    assert isinstance(pg_options, ProcessGroupNCCL.Options)
+    opts = ProcessGroupNCCL.Options()
+    opts._timeout = timeout
+    for k, v in kwargs.items():
+        if not hasattr(opts, k):
+            raise KeyError(f"Unknown option {k}")
+        setattr(opts, k, v)
 
-    pg_options._timeout = timeout
-    backend_class = ProcessGroupNCCL(store, rank, world_size, pg_options)
+    backend_class = ProcessGroupNCCL(store, rank, world_size, opts)
     backend_class._set_sequence_number_for_group()
     backend_class.eager_connect_single_device(device)
```
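On the NCCL side, each keyword argument must name an existing attribute on `ProcessGroupNCCL.Options`; anything else raises `KeyError`. A hedged sketch, assuming `is_high_priority_stream` is an available `Options` attribute in your build:

```py
from datetime import timedelta

import torch
import torch.distributed._dist2 as dist2

# Each kwarg is checked with hasattr() against ProcessGroupNCCL.Options
# and applied with setattr() before the group is constructed.
pg = dist2.new_group(
    backend="nccl",
    timeout=timedelta(seconds=60),
    device=torch.device("cuda", 0),
    is_high_priority_stream=True,  # assumed Options attribute
)

# A name that doesn't exist on Options fails fast:
#   dist2.new_group(..., not_an_option=1)  -> KeyError: Unknown option not_an_option
```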
```diff
@@ -128,7 +130,7 @@ def new_group(
     backend: str,
     timeout: timedelta,
     device: Union[str, torch.device],
-    pg_options: Backend.Options,
+    **kwargs: object,
 ) -> ProcessGroup:
     """
     Create a new process group with the given backend and options. This group is
@@ -139,7 +141,8 @@ def new_group(
         backend: The backend to use for the process group.
         timeout: The timeout for collective operations.
         device: The device to use for the process group.
-        pg_options: The options to use for the process group.
+        **kwargs: All remaining arguments are passed to the backend constructor.
+            See the backend specific documentation for details.
 
     Returns:
         A new process group.
@@ -152,7 +155,7 @@ def new_group(
     store, rank, world_size = next(iter(rendezvous("env://")))
     store.set_timeout(timeout)
 
-    return _BACKENDS[backend](store, rank, world_size, timeout, device, pg_options)
+    return _BACKENDS[backend](store, rank, world_size, timeout, device, **kwargs)
 
 
 def current_process_group() -> ProcessGroup:
```
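
Tying it together, the `my_custom_backend` call from the commit message works by adding a factory that matches the updated `ProcessGroupFactory` protocol to the `_BACKENDS` mapping that `new_group` dispatches through. A hypothetical sketch; the direct `_BACKENDS` assignment and the delegation to the gloo factory are illustrative, not part of this diff:

```py
from datetime import timedelta

import torch
import torch.distributed._dist2 as dist2
from torch._C._distributed_c10d import ProcessGroup, Store


def my_custom_factory(
    store: Store,
    rank: int,
    world_size: int,
    timeout: timedelta,
    device: torch.device,
    **kwargs: object,
) -> ProcessGroup:
    # Backend-specific "hints" arrive as plain kwargs, e.g. foo=1234, bar="1234".
    foo = kwargs.pop("foo", None)
    bar = kwargs.pop("bar", None)
    # Delegate construction to an existing factory (illustrative only); the
    # gloo factory asserts that no unconsumed kwargs remain.
    return dist2._BACKENDS["gloo"](store, rank, world_size, timeout, device, **kwargs)


# Register under the name new_group() will dispatch on.
dist2._BACKENDS["my_custom_backend"] = my_custom_factory

# Assumes the same env:// rendezvous variables as the earlier sketch.
pg = dist2.new_group(
    backend="my_custom_backend",
    timeout=timedelta(seconds=60),
    device="cpu",
    foo=1234,
    bar="1234",
)
```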