Enable NCCL zero-copy (user buffer registration) for FSDP2 (#150564)

Recent versions of NCCL introduced support for "user buffer registration", i.e., allowing user-owned memory (such as regular PyTorch tensors) to be "registered" (pinned, page-locked, etc.) with the various interconnects (NVLink, InfiniBand, ...) in order to enable zero-copy transfers. This accelerates communication and reduces the resource footprint of NCCL's kernels, which in turn reduces contention.

This capability is already exposed in PyTorch through a custom allocator provided by the NCCL process group. DDP already uses it, via a memory pool that allows buffers to be cached and reused.
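
For orientation, here is a minimal sketch (not code from this PR) of the direct allocation path that FSDP2 now uses, i.e., the `supports_tensor_alloc` / `allocate_tensor` backend methods whose bindings appear in the diff below; DDP's memory-pool route instead goes through the backend's `mem_allocator` property. The sketch assumes `torch.distributed` has already been initialized with the NCCL backend, and note that NCCL only reports support when multicast is available:

```python
# Hedged sketch, not part of this PR: allocate a communication buffer straight
# from the process group's backend so that NCCL can register it (zero-copy).
# Assumes torch.distributed was initialized with the NCCL backend.
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group

device = torch.device("cuda", torch.cuda.current_device())
backend = _get_default_group()._get_backend(device)

if backend.supports_tensor_alloc(device):
    # Registered, zero-copy-capable memory (requires multicast support).
    buf = backend.allocate_tensor(1 << 20, dtype=torch.bfloat16, device=device)
else:
    # Fallback: a regular allocation from the CUDA caching allocator.
    buf = torch.empty((1 << 20,), dtype=torch.bfloat16, device=device)
```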

FSDP2 is also particularly well suited to leverage user buffer registration because the buffers it passes to NCCL are allocated by FSDP2 itself: it needs to (de)interleave the parameters to/from these private staging buffers anyway.

This PR adds an extra flag to FSDP2 that tells it to use the ProcessGroup allocator for these private buffers, thus allowing it to leverage NCCL zero-copy (when supported).
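
As a usage sketch (assuming the FSDP2 entry points under `torch.distributed.fsdp`; the toy model is a placeholder, and only `set_allocate_memory_from_process_group_for_comm` comes from this PR), the flag is enabled per FSDP module after sharding:

```python
# Hedged usage sketch of the new flag; the toy model stands in for real
# transformer blocks. Assumes torch.distributed is initialized with NCCL.
import torch.nn as nn
from torch.distributed.fsdp import FSDPModule, fully_shard

model = nn.Sequential(*[nn.Linear(16, 16) for _ in range(4)]).cuda()
for layer in model:
    fully_shard(layer)  # shard each block individually
fully_shard(model)      # then the root module

# Opt in to ProcessGroup-provided allocation for the all-gather /
# reduce-scatter staging buffers. If the backend cannot allocate registered
# memory (e.g., no multicast support), FSDP2 falls back to torch.empty.
for module in model.modules():
    if isinstance(module, FSDPModule):
        module.set_allocate_memory_from_process_group_for_comm(True)
```

The new test below verifies exactly this effect by grepping NCCL's debug log for its buffer-registration message.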

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150564
Approved by: https://github.com/kwen2501, https://github.com/weifengpy, https://github.com/syed-ahmed
Luca Wehrstedt 2025-06-17 08:36:13 +00:00 committed by PyTorch MergeBot
parent 11bb1ece50
commit 0a0023d984
8 changed files with 162 additions and 8 deletions

View File

@@ -325,6 +325,8 @@ test_python_smoke() {

test_h100_distributed() {
  # Distributed tests at H100
  time python test/run_test.py --include distributed/_composable/test_composability/test_pp_composability.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
  # This test requires multicast support
  time python test/run_test.py --include distributed/_composable/fsdp/test_fully_shard_comm.py -k TestFullyShardAllocFromPG $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
  assert_git_not_dirty
}

View File

@@ -3,6 +3,8 @@
import copy
import functools
import itertools
import os
import tempfile
from typing import Callable, Optional, Union

import torch
@@ -34,7 +36,10 @@ from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.experimental import implicit_replication
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_distributed import (
    requires_multicast_support,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import (
    check_sharded_parity,
    DoubleLinear,
@@ -1283,5 +1288,60 @@ class TestFullyShardUnshardMultiThread(FSDPTestMultiThread):
            self.assertEqual(ref_param, param)


class TestFullyShardAllocFromPG(FSDPTest):
    # The messages might change when we move to a different NCCL version.
    # Please update this test if it starts failing.
    MEMORY_REGISTER_RE = (
        "NCCL INFO register comm 0x[0-9a-f]+ buffer 0x[0-9a-f]+ size [0-9]+"
    )

    @classmethod
    def _run(cls, *args, **kwargs):
        cls.nccl_log_dir = tempfile.TemporaryDirectory()
        os.environ["NCCL_DEBUG"] = "INFO"
        os.environ["NCCL_DEBUG_SUBSYS"] = "INIT,ENV,REG"
        os.environ["NCCL_DEBUG_FILE"] = cls.nccl_log_dir.name + "/nccl_log"
        super()._run(*args, **kwargs)

    @skip_if_lt_x_gpu(2)
    # The NCCL PG refuses to allocate tensors if multicast is unavailable, see
    # https://github.com/pytorch/pytorch/blob/503362d019b3782581492af7767945dbd75ca1c9/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L5634
    @requires_multicast_support()
    def test_fully_shard_alloc_from_pg(self):
        torch.manual_seed(42)
        model_args = ModelArgs()
        model = Transformer(model_args)
        for module in model.modules():
            if isinstance(module, TransformerBlock):
                fully_shard(module)
        fully_shard(model)

        torch.manual_seed(42 + self.rank)
        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")

        loss = model(inp)
        loss.sum().backward()

        torch.distributed.barrier()
        torch.cuda.synchronize()

        with open(self.nccl_log_dir.name + "/nccl_log") as f:
            self.assertNotRegex(f.read(), self.MEMORY_REGISTER_RE)

        for module in model.modules():
            if isinstance(module, TransformerBlock):
                module.set_allocate_memory_from_process_group_for_comm(True)
        model.set_allocate_memory_from_process_group_for_comm(True)

        loss = model(inp)
        loss.sum().backward()

        torch.distributed.barrier()
        torch.cuda.synchronize()

        with open(self.nccl_log_dir.name + "/nccl_log") as f:
            self.assertRegex(f.read(), self.MEMORY_REGISTER_RE)


if __name__ == "__main__":
    run_tests()

View File

@@ -319,6 +319,8 @@ class Backend:
    def _set_sequence_number_for_group(self) -> None: ...
    def _set_default_timeout(self, timeout: timedelta) -> None: ...
    def get_error(self) -> ErrorType: ...
    def supports_tensor_alloc(self, device: torch.device) -> bool: ...
    def allocate_tensor(
        self,
        size: int,
        *,
        dtype: torch.dtype,
        device: torch.device,
    ) -> Tensor: ...
    @property
    def mem_allocator(self) -> Any: ...

View File

@@ -824,6 +824,8 @@ def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None:
                    KeywordArg("rank"),
                    KeywordArg("dtype"),
                    KeywordArg("device"),
                    KeywordArg("group_name_inner"),
                    KeywordArg("allocate_memory_from_process_group"),
                ),
                KeywordArg("item_idx"),
            ),
@@ -862,6 +864,8 @@ def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None:
                    kwargs["rank"],
                    kwargs["dtype"],
                    kwargs["device"],
                    kwargs["group_name_inner"],
                    kwargs["allocate_memory_from_process_group"],
                    kwargs["group_size"],
                    kwargs["group_name"],
                ],

View File

@@ -2891,6 +2891,27 @@ Arguments:
          "_end_coalescing",
          &::c10d::Backend::endCoalescing,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "supports_tensor_alloc",
          [](::c10d::Backend& self, c10::Device device) {
            return self.supportsTensorAlloc(device.index());
          },
          py::arg("device"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "allocate_tensor",
          [](::c10d::Backend& self,
             long size,
             c10::ScalarType dtype,
             c10::Device device) {
            return self.allocateTensor(
                size, at::TensorOptions().dtype(dtype).device(device));
          },
          py::arg("size"),
          py::kw_only(),
          py::arg("dtype"),
          py::arg("device"),
          py::call_guard<py::gil_scoped_release>())
      .def_property_readonly(
          "mem_allocator", &::c10d::Backend::getMemAllocator);

View File

@@ -4,7 +4,7 @@ from typing import Callable, cast, NamedTuple, Optional, Union
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import _get_device_handle
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.distributed_c10d import _resolve_process_group, ReduceOp
from torch.distributed.tensor import DTensor

from ._fsdp_common import (
@@ -29,6 +29,20 @@ class AllGatherResult(NamedTuple):
    all_gather_input_split_sizes: list[int]


def allocate_memory(
    size: int,
    dtype: torch.dtype,
    device: torch.device,
    group: dist.ProcessGroup,
    from_process_group: bool,
) -> torch.Tensor:
    if from_process_group:
        backend = group._get_backend(device)
        if backend.supports_tensor_alloc(device):
            return backend.allocate_tensor(size, dtype=dtype, device=device)
    return torch.empty((size,), dtype=dtype, device=device)


lib = torch.library.Library("fsdp", "FRAGMENT")  # noqa: TOR901

lib.define(
@@ -40,7 +54,9 @@ lib.define(
        SymInt world_size,
        SymInt rank,
        ScalarType dtype,
        Device device
        Device device,
        str group_name,
        bool allocate_memory_from_process_group
    ) -> (Tensor, Tensor)
    """
)
@@ -55,6 +71,8 @@ def all_gather_copy_in_meta(
    rank: int,
    dtype: torch.dtype,
    device: torch.device,
    group_name: str,
    allocate_memory_from_process_group: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    all_gather_output = torch.empty(
        (all_gather_input_numel * world_size,), dtype=dtype, device="meta"
@@ -79,9 +97,15 @@ def all_gather_copy_in_cuda(
    rank: int,
    dtype: torch.dtype,
    device: torch.device,
    group_name: str,
    allocate_memory_from_process_group: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    all_gather_output = torch.empty(
        (all_gather_input_numel * world_size,), dtype=dtype, device=device
    all_gather_output = allocate_memory(
        all_gather_input_numel * world_size,
        dtype=dtype,
        device=device,
        group=_resolve_process_group(group_name),
        from_process_group=allocate_memory_from_process_group,
    )
    all_gather_input = all_gather_output.narrow(
        0, all_gather_input_numel * rank, all_gather_input_numel
@@ -144,6 +168,7 @@ def foreach_all_gather(
    all_gather_copy_in_stream: torch.Stream,
    all_gather_stream: torch.Stream,
    device: torch.device,
    allocate_memory_from_process_group: bool = False,
) -> Optional[AllGatherResult]:
    world_size, rank = group.size(), group.rank()
    device_handle = _get_device_handle(device.type)
@@ -170,6 +195,8 @@ def foreach_all_gather(
            rank,
            dtype,
            device,
            group.group_name,
            allocate_memory_from_process_group,
        )
        del param_all_gather_inputs
    all_gather_stream.wait_stream(all_gather_copy_in_stream)
@@ -360,6 +387,7 @@ def foreach_reduce(
    all_reduce_grads: bool,
    partial_reduce_output: Optional[torch.Tensor],  # only used for HSDP
    all_reduce_hook: Optional[Callable[[torch.Tensor], None]],
    allocate_memory_from_process_group: bool = False,
) -> tuple[
    torch.Tensor,
    torch.Event,
@@ -398,8 +426,12 @@ def foreach_reduce(
    )
    reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes)
    reduce_scatter_output_numel = reduce_scatter_input_numel // world_size
    reduce_scatter_input = torch.empty(
        (reduce_scatter_input_numel,), dtype=reduce_dtype, device=device
    reduce_scatter_input = allocate_memory(
        reduce_scatter_input_numel,
        dtype=reduce_dtype,
        device=device,
        group=reduce_scatter_group,
        from_process_group=allocate_memory_from_process_group,
    )
    device_handle = _get_device_handle(device.type)
    foreach_reduce_scatter_copy_in(unsharded_grads, reduce_scatter_input, world_size)
@@ -410,7 +442,13 @@ def foreach_reduce(
    all_reduce_input = None
    all_reduce_event = None
    with device_handle.stream(reduce_scatter_stream):
        reduce_output = reduce_scatter_input.new_empty((reduce_scatter_output_numel,))
        reduce_output = allocate_memory(
            reduce_scatter_output_numel,
            dtype=reduce_dtype,
            device=device,
            group=reduce_scatter_group,
            from_process_group=allocate_memory_from_process_group,
        )
        _div_if_needed(reduce_scatter_input, predivide_factor)
        if reduce_scatter_reduce_op is None:
            if predivide_factor is None:

View File

@@ -187,6 +187,9 @@ class FSDPParamGroup:
        # Whether to unshard in backward: can be overridden by the user if the
        # parameters in this group are not needed for backward (e.g. embedding)
        self.unshard_in_backward: bool = True
        # Whether to (try to) use the ProcessGroup's allocate_tensor method for
        # the staging buffers for collective comms.
        self.allocate_memory_from_process_group = False

        # - CUDA events for stream synchronization
        # Holds the all-gather output buffer, sync objects, and metadata
@@ -276,6 +279,7 @@ class FSDPParamGroup:
            async_op,
            *self.comm_ctx.get_all_gather_streams(async_op, self._training_state),
            self.device,
            self.allocate_memory_from_process_group,
        )

    def wait_for_unshard(self):
@@ -461,6 +465,7 @@ class FSDPParamGroup:
            self.all_reduce_grads,
            self._partial_reduce_output,
            self._all_reduce_hook,
            self.allocate_memory_from_process_group,
        )
        self.comm_ctx.reduce_scatter_state = ReduceScatterState(
            reduce_scatter_input, reduce_scatter_event

View File

@@ -524,6 +524,22 @@ class FSDPModule:
        if (fsdp_param_group := state._fsdp_param_group) is not None:
            fsdp_param_group.unshard_in_backward = unshard_in_backward

    def set_allocate_memory_from_process_group_for_comm(self, enable: bool) -> None:
        """
        Sets whether the temporary staging buffers used to send and receive data
        over collective communications should be allocated using the custom
        optimized allocator provided by the ProcessGroup itself (if any). This
        might allow the ProcessGroup to be more efficient. For example, when
        using NCCL, this enables it to leverage zero-copy transfers over SHARP
        (for NVLink and/or InfiniBand).

        Args:
            enable (bool): Whether to turn on ProcessGroup allocation.
        """
        state = self._get_fsdp_state()
        if (fsdp_param_group := state._fsdp_param_group) is not None:
            fsdp_param_group.allocate_memory_from_process_group = enable

    def _set_unshard_async_op(self, async_op: bool):
        """
        Sets whether to use ``async_op=True`` or ``False`` for the pre-forward