Enable NCCL zero-copy (user buffer registration) for FSDP2 (#150564)
In recent versions, NCCL introduced support for "user buffer registration": user-owned memory (such as regular PyTorch tensors) can be "registered" (pinned, page-locked, etc.) with the various transports (NVLink, InfiniBand, ...) to enable zero-copy transfers, which accelerates communication and reduces the resource footprint of NCCL's kernels (and thus contention). PyTorch already exposes this through a custom allocator provided by the NCCL process group, and DDP already uses it via a memory pool that allows caching and reuse.

FSDP2 is particularly well suited to leverage user buffer registration because the buffers it passes to NCCL are allocated by FSDP2 itself: it needs to (de)interleave the parameters to/from these private buffers anyway. This PR adds an extra flag to FSDP2 that tells it to use the ProcessGroup's allocator for these private buffers, allowing it to leverage NCCL zero-copy transfers when supported.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150564
Approved by: https://github.com/kwen2501, https://github.com/weifengpy, https://github.com/syed-ahmed
This commit is contained in:
parent 11bb1ece50
commit 0a0023d984
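For illustration, a minimal usage sketch (not part of this commit) of the flag this PR adds. It assumes a multi-GPU job with an NCCL process group (e.g. launched via torchrun), one CUDA device per rank, and that `fully_shard` is importable from `torch.distributed.fsdp`; the toy model and sizes are placeholders. It mirrors the test added below: shard the blocks and the root, then opt each FSDP module into ProcessGroup-allocated staging buffers before running forward/backward.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard  # import path assumed; older builds expose it elsewhere

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Sequential(
    torch.nn.Linear(128, 128), torch.nn.Linear(128, 128)
).cuda()
for layer in model:
    fully_shard(layer)
fully_shard(model)

# New in this PR: allocate FSDP2's all-gather / reduce-scatter staging buffers
# through the ProcessGroup allocator so NCCL can register them (zero-copy).
for layer in model:
    layer.set_allocate_memory_from_process_group_for_comm(True)
model.set_allocate_memory_from_process_group_for_comm(True)

model(torch.randn(4, 128, device="cuda")).sum().backward()
```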
@@ -325,6 +325,8 @@ test_python_smoke() {
 test_h100_distributed() {
   # Distributed tests at H100
   time python test/run_test.py --include distributed/_composable/test_composability/test_pp_composability.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
+  # This test requires multicast support
+  time python test/run_test.py --include distributed/_composable/fsdp/test_fully_shard_comm.py -k TestFullyShardAllocFromPG $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
   assert_git_not_dirty
 }

@@ -3,6 +3,8 @@
 import copy
 import functools
 import itertools
+import os
+import tempfile
 from typing import Callable, Optional, Union

 import torch
@@ -34,7 +36,10 @@ from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
 from torch.distributed.tensor import DTensor
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.experimental import implicit_replication
-from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_distributed import (
+    requires_multicast_support,
+    skip_if_lt_x_gpu,
+)
 from torch.testing._internal.common_fsdp import (
     check_sharded_parity,
     DoubleLinear,
@@ -1283,5 +1288,60 @@ class TestFullyShardUnshardMultiThread(FSDPTestMultiThread):
             self.assertEqual(ref_param, param)


+class TestFullyShardAllocFromPG(FSDPTest):
+    # The messages might change when we move to a different NCCL version.
+    # Please update this test if it starts failing.
+    MEMORY_REGISTER_RE = (
+        "NCCL INFO register comm 0x[0-9a-f]+ buffer 0x[0-9a-f]+ size [0-9]+"
+    )
+
+    @classmethod
+    def _run(cls, *args, **kwargs):
+        cls.nccl_log_dir = tempfile.TemporaryDirectory()
+        os.environ["NCCL_DEBUG"] = "INFO"
+        os.environ["NCCL_DEBUG_SUBSYS"] = "INIT,ENV,REG"
+        os.environ["NCCL_DEBUG_FILE"] = cls.nccl_log_dir.name + "/nccl_log"
+        super()._run(*args, **kwargs)
+
+    @skip_if_lt_x_gpu(2)
+    # The NCCL PG refuses to allocate tensors if multicast is unavailable, see
+    # https://github.com/pytorch/pytorch/blob/503362d019b3782581492af7767945dbd75ca1c9/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L5634
+    @requires_multicast_support()
+    def test_fully_shard_alloc_from_pg(self):
+        torch.manual_seed(42)
+        model_args = ModelArgs()
+        model = Transformer(model_args)
+        for module in model.modules():
+            if isinstance(module, TransformerBlock):
+                fully_shard(module)
+        fully_shard(model)
+
+        torch.manual_seed(42 + self.rank)
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+
+        loss = model(inp)
+        loss.sum().backward()
+
+        torch.distributed.barrier()
+        torch.cuda.synchronize()
+
+        with open(self.nccl_log_dir.name + "/nccl_log") as f:
+            self.assertNotRegex(f.read(), self.MEMORY_REGISTER_RE)
+
+        for module in model.modules():
+            if isinstance(module, TransformerBlock):
+                module.set_allocate_memory_from_process_group_for_comm(True)
+        model.set_allocate_memory_from_process_group_for_comm(True)
+
+        loss = model(inp)
+        loss.sum().backward()
+
+        torch.distributed.barrier()
+        torch.cuda.synchronize()
+
+        with open(self.nccl_log_dir.name + "/nccl_log") as f:
+            self.assertRegex(f.read(), self.MEMORY_REGISTER_RE)
+
+
 if __name__ == "__main__":
     run_tests()
@@ -319,6 +319,14 @@ class Backend:
     def _set_sequence_number_for_group(self) -> None: ...
     def _set_default_timeout(self, timeout: timedelta) -> None: ...
     def get_error(self) -> ErrorType: ...
+    def supports_tensor_alloc(self, device: torch.device) -> bool: ...
+    def allocate_tensor(
+        self,
+        size: int,
+        *,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> Tensor: ...
     @property
     def mem_allocator(self) -> Any: ...

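For reference, a small sketch (mine, not part of the diff) of how these new `Backend` methods are meant to be used; it is essentially the same probe-then-fallback logic as the `allocate_memory` helper added later in this diff. The helper name is made up, and it assumes an initialized process group whose backend may or may not support tensor allocation.

```python
import torch
import torch.distributed as dist

def alloc_comm_buffer(
    group: dist.ProcessGroup, numel: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    backend = group._get_backend(device)
    if backend.supports_tensor_alloc(device):
        # Buffer comes from the backend's allocator (e.g. ProcessGroupNCCL),
        # so NCCL can register it for zero-copy transfers.
        return backend.allocate_tensor(numel, dtype=dtype, device=device)
    # Otherwise fall back to the regular caching allocator.
    return torch.empty((numel,), dtype=dtype, device=device)
```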
@@ -824,6 +824,8 @@ def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None:
             KeywordArg("rank"),
             KeywordArg("dtype"),
             KeywordArg("device"),
+            KeywordArg("group_name_inner"),
+            KeywordArg("allocate_memory_from_process_group"),
         ),
         KeywordArg("item_idx"),
     ),
@@ -862,6 +864,8 @@ def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None:
             kwargs["rank"],
             kwargs["dtype"],
             kwargs["device"],
+            kwargs["group_name_inner"],
+            kwargs["allocate_memory_from_process_group"],
             kwargs["group_size"],
             kwargs["group_name"],
         ],
@@ -2891,6 +2891,27 @@ Arguments:
           "_end_coalescing",
           &::c10d::Backend::endCoalescing,
           py::call_guard<py::gil_scoped_release>())
+      .def(
+          "supports_tensor_alloc",
+          [](::c10d::Backend& self, c10::Device device) {
+            return self.supportsTensorAlloc(device.index());
+          },
+          py::arg("device"),
+          py::call_guard<py::gil_scoped_release>())
+      .def(
+          "allocate_tensor",
+          [](::c10d::Backend& self,
+             long size,
+             c10::ScalarType dtype,
+             c10::Device device) {
+            return self.allocateTensor(
+                size, at::TensorOptions().dtype(dtype).device(device));
+          },
+          py::arg("size"),
+          py::kw_only(),
+          py::arg("dtype"),
+          py::arg("device"),
+          py::call_guard<py::gil_scoped_release>())
       .def_property_readonly(
           "mem_allocator", &::c10d::Backend::getMemAllocator);

@@ -4,7 +4,7 @@ from typing import Callable, cast, NamedTuple, Optional, Union
 import torch
 import torch.distributed as dist
 from torch.distributed.device_mesh import _get_device_handle
-from torch.distributed.distributed_c10d import ReduceOp
+from torch.distributed.distributed_c10d import _resolve_process_group, ReduceOp
 from torch.distributed.tensor import DTensor

 from ._fsdp_common import (
@@ -29,6 +29,20 @@ class AllGatherResult(NamedTuple):
     all_gather_input_split_sizes: list[int]


+def allocate_memory(
+    size: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    group: dist.ProcessGroup,
+    from_process_group: bool,
+) -> torch.Tensor:
+    if from_process_group:
+        backend = group._get_backend(device)
+        if backend.supports_tensor_alloc(device):
+            return backend.allocate_tensor(size, dtype=dtype, device=device)
+    return torch.empty((size,), dtype=dtype, device=device)
+
+
 lib = torch.library.Library("fsdp", "FRAGMENT")  # noqa: TOR901

 lib.define(
@@ -40,7 +54,9 @@ lib.define(
         SymInt world_size,
         SymInt rank,
         ScalarType dtype,
-        Device device
+        Device device,
+        str group_name,
+        bool allocate_memory_from_process_group
     ) -> (Tensor, Tensor)
     """
 )
@@ -55,6 +71,8 @@ def all_gather_copy_in_meta(
     rank: int,
     dtype: torch.dtype,
     device: torch.device,
+    group_name: str,
+    allocate_memory_from_process_group: bool,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     all_gather_output = torch.empty(
         (all_gather_input_numel * world_size,), dtype=dtype, device="meta"
@@ -79,9 +97,15 @@ def all_gather_copy_in_cuda(
     rank: int,
     dtype: torch.dtype,
     device: torch.device,
+    group_name: str,
+    allocate_memory_from_process_group: bool,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    all_gather_output = torch.empty(
-        (all_gather_input_numel * world_size,), dtype=dtype, device=device
+    all_gather_output = allocate_memory(
+        all_gather_input_numel * world_size,
+        dtype=dtype,
+        device=device,
+        group=_resolve_process_group(group_name),
+        from_process_group=allocate_memory_from_process_group,
     )
     all_gather_input = all_gather_output.narrow(
         0, all_gather_input_numel * rank, all_gather_input_numel
@@ -144,6 +168,7 @@ def foreach_all_gather(
     all_gather_copy_in_stream: torch.Stream,
     all_gather_stream: torch.Stream,
     device: torch.device,
+    allocate_memory_from_process_group: bool = False,
 ) -> Optional[AllGatherResult]:
     world_size, rank = group.size(), group.rank()
     device_handle = _get_device_handle(device.type)
@@ -170,6 +195,8 @@
             rank,
             dtype,
             device,
+            group.group_name,
+            allocate_memory_from_process_group,
         )
     del param_all_gather_inputs
     all_gather_stream.wait_stream(all_gather_copy_in_stream)
@@ -360,6 +387,7 @@ def foreach_reduce(
     all_reduce_grads: bool,
     partial_reduce_output: Optional[torch.Tensor],  # only used for HSDP
     all_reduce_hook: Optional[Callable[[torch.Tensor], None]],
+    allocate_memory_from_process_group: bool = False,
 ) -> tuple[
     torch.Tensor,
     torch.Event,
@@ -398,8 +426,12 @@ def foreach_reduce(
     )
     reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes)
     reduce_scatter_output_numel = reduce_scatter_input_numel // world_size
-    reduce_scatter_input = torch.empty(
-        (reduce_scatter_input_numel,), dtype=reduce_dtype, device=device
+    reduce_scatter_input = allocate_memory(
+        reduce_scatter_input_numel,
+        dtype=reduce_dtype,
+        device=device,
+        group=reduce_scatter_group,
+        from_process_group=allocate_memory_from_process_group,
     )
     device_handle = _get_device_handle(device.type)
     foreach_reduce_scatter_copy_in(unsharded_grads, reduce_scatter_input, world_size)
@@ -410,7 +442,13 @@ def foreach_reduce(
     all_reduce_input = None
     all_reduce_event = None
     with device_handle.stream(reduce_scatter_stream):
-        reduce_output = reduce_scatter_input.new_empty((reduce_scatter_output_numel,))
+        reduce_output = allocate_memory(
+            reduce_scatter_output_numel,
+            dtype=reduce_dtype,
+            device=device,
+            group=reduce_scatter_group,
+            from_process_group=allocate_memory_from_process_group,
+        )
         _div_if_needed(reduce_scatter_input, predivide_factor)
         if reduce_scatter_reduce_op is None:
             if predivide_factor is None:
@@ -187,6 +187,9 @@ class FSDPParamGroup:
         # Whether to unshard in backward: can be overridden by the user if the
         # parameters in this group are not needed for backward (e.g. embedding)
         self.unshard_in_backward: bool = True
+        # Whether to (try to) use the ProcessGroup's allocate_tensor method for
+        # the staging buffers for collective comms.
+        self.allocate_memory_from_process_group = False

         # - CUDA events for stream synchronization
         # Holds the all-gather output buffer, sync objects, and metadata
@@ -276,6 +279,7 @@ class FSDPParamGroup:
             async_op,
             *self.comm_ctx.get_all_gather_streams(async_op, self._training_state),
             self.device,
+            self.allocate_memory_from_process_group,
         )

     def wait_for_unshard(self):
@@ -461,6 +465,7 @@ class FSDPParamGroup:
                 self.all_reduce_grads,
                 self._partial_reduce_output,
                 self._all_reduce_hook,
+                self.allocate_memory_from_process_group,
             )
             self.comm_ctx.reduce_scatter_state = ReduceScatterState(
                 reduce_scatter_input, reduce_scatter_event
@@ -524,6 +524,22 @@ class FSDPModule:
         if (fsdp_param_group := state._fsdp_param_group) is not None:
             fsdp_param_group.unshard_in_backward = unshard_in_backward

+    def set_allocate_memory_from_process_group_for_comm(self, enable: bool) -> None:
+        """
+        Sets whether the temporary staging buffers used to send and receive data
+        over collective communications should be allocated using the custom
+        optimized allocator provided by the ProcessGroup itself (if any). This
+        might allow the ProcessGroup to be more efficient. For example, when
+        using NCCL, this enables it to leverage zero-copy transfers over SHARP
+        (for NVLink and/or InfiniBand).
+
+        Args:
+            enable (bool): Whether to turn on ProcessGroup allocation.
+        """
+        state = self._get_fsdp_state()
+        if (fsdp_param_group := state._fsdp_param_group) is not None:
+            fsdp_param_group.allocate_memory_from_process_group = enable
+
     def _set_unshard_async_op(self, async_op: bool):
         """
         Sets whether to use ``async_op=True`` or ``False`` for the pre-forward
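One usage note: per the test above, the NCCL process group refuses to allocate tensors when multicast is unavailable (see the ProcessGroupNCCL.cpp check linked there), which is why that test is gated on `requires_multicast_support`. Below is a hypothetical guard (names assumed, not part of this PR) that only opts a module in when the default group's backend reports allocation support; it assumes `dist.init_process_group` has already been called.

```python
import torch
import torch.distributed as dist

def maybe_enable_pg_alloc(fsdp_module, device: torch.device) -> None:
    # Probe the default process group's backend for tensor-allocation support
    # before asking FSDP2 to stage its comm buffers through it.
    backend = dist.group.WORLD._get_backend(device)
    if backend.supports_tensor_alloc(device):
        fsdp_module.set_allocate_memory_from_process_group_for_comm(True)
```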