XCCL changes for DDP (#155497)

Add XCCL documentation for DDP

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155497
Approved by: https://github.com/guangyey, https://github.com/AlannaBurke

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Tanima Dey 2025-07-03 05:18:08 +00:00 committed by PyTorch MergeBot
parent 382598ef87
commit 4ce6e6ec88

@@ -347,20 +347,33 @@ class DistributedDataParallel(Module, Joinable):
To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn
``N`` processes, ensuring that each process exclusively works on a single
GPU from 0 to N-1. This can be done by either setting
-``CUDA_VISIBLE_DEVICES`` for every process or by calling:
+``CUDA_VISIBLE_DEVICES`` for every process or by calling the following API for GPUs,
>>> # xdoctest: +SKIP("undefined variables")
>>> torch.cuda.set_device(i)
or by calling the unified API for :ref:`accelerator<accelerators>`,
>>> # xdoctest: +SKIP("undefined variables")
>>> torch.accelerator.set_device_index(i)
where i is from 0 to N-1. In each process, you should refer to the following
to construct this module:
>>> # xdoctest: +SKIP("undefined variables")
>>> if torch.accelerator.is_available():
>>>     device_type = torch.accelerator.current_accelerator().type
>>>     vendor_backend = torch.distributed.get_default_backend_for_device(device_type)
>>>
>>> torch.distributed.init_process_group(
->>>     backend='nccl', world_size=N, init_method='...'
+>>>     backend=vendor_backend, world_size=N, init_method='...'
>>> )
>>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)
Or you can use the latest API for initialization:
>>> torch.distributed.init_process_group(device_id=i)
In order to spawn multiple processes per node, you can use either
``torch.distributed.launch`` or ``torch.multiprocessing.spawn``.
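
For reference, here is a minimal single-node sketch (not part of the diff) that puts the pieces above together with ``torch.multiprocessing.spawn``; the linear model, the hard-coded master port, and the training-loop placeholder are illustrative assumptions, not anything prescribed by this change:

# Sketch only: single-node DDP setup driven by torch.multiprocessing.spawn.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def worker(i, world_size):
    # Bind this process to device i using the unified accelerator API.
    torch.accelerator.set_device_index(i)
    device_type = torch.accelerator.current_accelerator().type

    # Pick the vendor backend for this device type (e.g. nccl for CUDA,
    # xccl for XPU) and join the process group as rank i.
    backend = dist.get_default_backend_for_device(device_type)
    dist.init_process_group(backend=backend, rank=i, world_size=world_size)

    device = torch.device(device_type, i)
    model = nn.Linear(10, 10).to(device)  # placeholder model
    ddp_model = DistributedDataParallel(model, device_ids=[i], output_device=i)

    # ... run forward/backward/optimizer steps on ddp_model here ...

    dist.destroy_process_group()


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")  # placeholder port
    world_size = torch.accelerator.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size)

Because the backend is looked up from the current accelerator rather than hard-coded, the same script can run unchanged on CUDA or XPU hosts, which is the point of the docstring change above.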