[FSDP][Collectives] skipping allgather when world size is 1 (#160135)

**Summary:** In its current state, FSDP collectives use CUDA synchronizations and communication ops regardless of the world size. Now that replicate will use FSDP, there will be cases where the group size is 1 and these synchronizations and ops run needlessly. I have updated the FSDP param group to skip the foreach_all_gather and foreach_all_gather_copy_out APIs when world_size == 1. I have added a test that uses CommDebugMode to verify that the all-gather comm is removed, and I also updated an affected test that used 1-way FSDP by adjusting its CommDebugMode assert statements. Below, I have included the link to the profile trace verifying that these two APIs were skipped, along with two test commands.

https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/anshulsi_f846ac3b-9467-4060-8e36-8cc3bc4449c3_devgpu263.prn2.facebook.com_652183.1753822140871934814.pt.trace.json
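
For anyone who wants to reproduce the check locally, here is a minimal, hypothetical sketch (not part of this PR's test code) of how CommDebugMode can confirm that no all-gather is issued for a world-size-1 FSDP group. The helper name, model shapes, and import paths are illustrative assumptions based on recent PyTorch; a single-rank process group is assumed to be initialized already.

# Minimal sketch: verify that a world-size-1 FSDP group issues no all-gather.
# Assumes a single-rank process group is already initialized (e.g. via
# init_process_group(rank=0, world_size=1)) and one GPU is available.
import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.debug import CommDebugMode

def assert_no_all_gather(model: nn.Module, inp: torch.Tensor) -> None:
    c10d_ops = torch.ops.c10d
    comm_mode = CommDebugMode()
    with comm_mode:
        # Run forward + backward under CommDebugMode so both the unshard
        # all-gather (if any) and the gradient reduce-scatter are counted
        model(inp).sum().backward()
    counts = comm_mode.get_comm_counts()
    # With world_size == 1, the unshard all-gather is skipped entirely;
    # only the gradient reduce-scatter should remain
    assert counts.get(c10d_ops._allgather_base_, 0) == 0

# Example usage (hypothetical shapes):
#   model = nn.Sequential(nn.Linear(16, 15), nn.ReLU(), nn.Linear(15, 8)).cuda()
#   fully_shard(model)
#   assert_no_all_gather(model, torch.randn(4, 16, device="cuda"))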

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160135
Approved by: https://github.com/weifengpy
Anshul Sinha 2025-08-12 10:06:12 -07:00 committed by PyTorch MergeBot
parent b4596895b9
commit c24ca7f4bf
4 changed files with 159 additions and 24 deletions


@@ -299,12 +299,20 @@ val.shape: {[node.meta["val"].shape for node in aliased_graph_inputs]},
    def _reinplace_all_gather_with_optional_checks(self, fwd_fullgraph):
        def _run_with_checks(graph, orig_fn):
            if self.world_size > 1:
                self.assertGreater(
                    _count_op_in_graph(
                        graph, torch.ops._c10d_functional.all_gather_into_tensor.default
                    ),
                    0,
                )
            elif self.world_size == 1:
                self.assertEqual(
                    _count_op_in_graph(
                        graph, torch.ops._c10d_functional.all_gather_into_tensor.default
                    ),
                    0,
                )
            orig_fn(graph)
@@ -315,9 +323,19 @@ val.shape: {[node.meta["val"].shape for node in aliased_graph_inputs]},
                    0,
                )
            if self.world_size > 1:
                self.assertGreater(
                    _count_op_in_graph(
-                       graph, torch.ops._c10d_functional.all_gather_into_tensor_out.default
+                       graph,
+                       torch.ops._c10d_functional.all_gather_into_tensor_out.default,
                    ),
                    0,
                )
            else:
                self.assertEqual(
                    _count_op_in_graph(
                        graph,
                        torch.ops._c10d_functional.all_gather_into_tensor_out.default,
                    ),
                    0,
                )


@@ -1467,5 +1467,70 @@ class TestFullyShardCustomForwardMethod(FSDPTest):
        check_sharded_parity(self, ref_model, model)


class TestFullyShardWorldSize1(FSDPTest):
    @property
    def world_size(self) -> int:
        return 1

    @skip_if_lt_x_gpu(1)
    def test_train_parity_single_worldsize1(self):
        """
        Tests train parity with DDP for a single FSDP group when sharding
        parameters on dim-0.
        """
        self.run_subtests(
            {
                "lin_shapes": [
                    [(16, 15), (15, 8)],
                    [(7, 15), (15, 3)],
                    [(16, 17), (17, 8)],
                ],
                "use_shard_placement_fn": [False],
            },
            self._test_train_parity_single_group,
        )

    def _test_train_parity_single_group(
        self, lin_shapes: list[tuple[int, int]], use_shard_placement_fn: bool
    ):
        torch.manual_seed(42)
        model = nn.Sequential(
            nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1])
        )
        ref_model = copy.deepcopy(model).to(device_type)
        replicate(ref_model, device_ids=[self.rank])
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)

        def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
            return Shard(param.shape.index(max(param.shape)))

        shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None
        fully_shard(model, shard_placement_fn=shard_placement_fn)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        torch.manual_seed(42 + self.rank + 1)
        inp = (torch.randn((4, lin_shapes[0][0]), device=device_type.type),)
        for iter_idx in range(10):
            losses: list[torch.Tensor] = []
            ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            losses.append(ref_model(*inp).sum())
            losses[-1].backward()
            ref_optim.step()
            optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            comm_mode = CommDebugMode()
            with comm_mode:
                losses.append(model(*inp).sum())
                losses[-1].backward()
            # Before, there was 1 all-gather and 1 reduce-scatter;
            # now there is 1 reduce-scatter
            self.assertEqual(comm_mode.get_total_counts(), 1)
            optim.step()
            self.assertEqual(losses[0], losses[1])


if __name__ == "__main__":
    run_tests()


@@ -277,19 +277,19 @@ class TestFullyShard2DTraining(FSDPTest):
            loss = model(inp).sum()
        fwd_comm_counts = fwd_comm_mode.get_comm_counts()
-       self.assertEqual(len(fwd_comm_counts), 2)
+       self.assertEqual(len(fwd_comm_counts), 1)
        self.assertEqual(fwd_comm_counts[funcol.all_reduce], num_mlps)
-       self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
+       self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], 0)
        ref_loss = ref_model(inp).sum()
        self.assertEqual(loss, ref_loss)

        with CommDebugMode() as bwd_comm_mode:
            loss.backward()
        bwd_comm_counts = bwd_comm_mode.get_comm_counts()
-       self.assertEqual(len(bwd_comm_counts), 3)
+       self.assertEqual(len(bwd_comm_counts), 2)
        # First MLP's input gradient does not need to be all-reduced
        self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1)
-       self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
+       self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], 0)
        self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps)
        ref_loss.backward()


@@ -32,7 +32,7 @@ from ._fsdp_common import (
    HSDPMeshInfo,
    TrainingState,
)
-from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState
+from ._fsdp_param import alloc_storage, FSDPParam, ParamModuleInfo, ShardedState

logger = logging.getLogger("torch.distributed.fsdp.fully_shard")
@@ -166,6 +166,7 @@ class FSDPParamGroup:
        self._module_to_pre_load_state_dict_hook_handle: _ModuleToHandleDict = {}
        self._all_reduce_hook: Optional[Callable[[torch.Tensor], None]] = None
        self._all_gather_comm: AllGather = DefaultAllGather()
        self._all_gather_output = torch.empty(0, device=self.device)
        self._reduce_scatter_comm: ReduceScatter = DefaultReduceScatter()
        # Optional stream to run the user-defined all-reduce hook in
        # Saved here and not in the comm. context because we allow the user to
@@ -310,6 +311,22 @@ class FSDPParamGroup:
            # used in the all-gather streams
            self._wait_all_gather_streams_on_event(self._reshard_after_forward_event)
            self._reshard_after_forward_event = None
        world_size = self._all_gather_process_group.size()
        if world_size == 1:
            # can't skip due to early return in wait_for_unshard if
            # no self._all_gather_result
            self._all_gather_result = AllGatherResult(
                all_gather_output=self._all_gather_output,
                all_gather_event=self.device_handle.Event().record(),
                all_gather_work=None,
                param_all_gather_input_dtypes=[],
                param_all_gather_input_numels=[],
                all_gather_input_split_sizes=[],
            )
            return
        with record_function(self._with_fqn("FSDP::all_gather")):
            self._all_gather_result = foreach_all_gather(
                self.fsdp_params,
@@ -336,18 +353,52 @@ class FSDPParamGroup:
        if prev_all_gather_state := self.comm_ctx.all_gather_state:
            self._wait_all_gather_streams_on_event(prev_all_gather_state.event)
            self.comm_ctx.all_gather_state = None  # free the all-gather result
        world_size = self._all_gather_process_group.size()
        if world_size == 1:
            # directly initialize unsharded parameters from sharded parameters
            for fsdp_param in self.fsdp_params:
                # Use all_gather_inputs which already handles conversion to param_dtype
                # This is consistent with the world_size > 1 path
                all_gather_input = fsdp_param.all_gather_inputs[0]
                # Make sure the all_gather_outputs has proper storage size before using it
                # First ensure we have at least one tensor in all_gather_outputs
                fsdp_param.init_all_gather_outputs(
                    [all_gather_input.numel()],
                    [all_gather_input.dtype],
                    world_size,
                    self.device,
                    force_recreate=False,
                )
                tensor = fsdp_param.all_gather_outputs[0]
                alloc_storage(tensor)
                # find alternative way to check if tensor.is_inference
                with torch.autograd._unsafe_preserve_version_counter(tensor):
                    tensor.copy_(all_gather_input)
        else:
            with record_function(self._with_fqn("FSDP::all_gather_copy_out")):
                foreach_all_gather_copy_out(
                    self._all_gather_result,
                    self.fsdp_params,
                    self._all_gather_process_group,
                )
        for fsdp_param in self.fsdp_params:
            fsdp_param.init_unsharded_param()
        self._to_unsharded()
        all_gather_copy_out_event = self.device_handle.Event()
        all_gather_copy_out_event.record()
-       if not async_op and self._training_state == TrainingState.FORWARD:
+       if (
+           not async_op
+           and self._training_state == TrainingState.FORWARD
+           and world_size > 1
+       ):
            # Defer free to allow for overlap of this copy-out with next
            # all-gather collective
            self.comm_ctx.all_gather_state = AllGatherState(
@@ -355,6 +406,7 @@ class FSDPParamGroup:
            )
        else:
            self._wait_all_gather_streams_on_event(all_gather_copy_out_event)
        self._all_gather_result = None  # free unless saved in `all_gather_state`

    def _wait_all_gather_streams_on_event(self, event: Optional[torch.Event]):