mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
[FSDP][Collectives] skipping allgather when world size is 1 (#160135)
**Summary:** In its current state, FSDP collectives uses cuda synchronizations and communication ops regardless of what the world size is. However, now that replicate will use FSDP, there will be instances where group size = 1 and these synchronizations and ops will be used needlessly. I have updated fsdp_params group to skip the foreach_all_gather and foreach_all_gather_copy_out APIs when world_size = 1. I have created a test that uses CommDebugMode to verify that the all gather comm has been removed. I also edited an affected test which used 1-way FSDP by verifying and changing its assert statements for CommDebugMode. Below, I have included the link to the profile trace verifying these two APIs were skipped and two test commands. https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/anshulsi_f846ac3b-9467-4060-8e36-8cc3bc4449c3_devgpu263.prn2.facebook.com_652183.1753822140871934814.pt.trace.json Pull Request resolved: https://github.com/pytorch/pytorch/pull/160135 Approved by: https://github.com/weifengpy
This commit is contained in:
parent
b4596895b9
commit
c24ca7f4bf
|
|
@ -299,12 +299,20 @@ val.shape: {[node.meta["val"].shape for node in aliased_graph_inputs]},
|
|||
|
||||
def _reinplace_all_gather_with_optional_checks(self, fwd_fullgraph):
|
||||
def _run_with_checks(graph, orig_fn):
|
||||
if self.world_size > 1:
|
||||
self.assertGreater(
|
||||
_count_op_in_graph(
|
||||
graph, torch.ops._c10d_functional.all_gather_into_tensor.default
|
||||
),
|
||||
0,
|
||||
)
|
||||
elif self.world_size == 1:
|
||||
self.assertEqual(
|
||||
_count_op_in_graph(
|
||||
graph, torch.ops._c10d_functional.all_gather_into_tensor.default
|
||||
),
|
||||
0,
|
||||
)
|
||||
|
||||
orig_fn(graph)
|
||||
|
||||
|
|
@ -315,9 +323,19 @@ val.shape: {[node.meta["val"].shape for node in aliased_graph_inputs]},
|
|||
0,
|
||||
)
|
||||
|
||||
if self.world_size > 1:
|
||||
self.assertGreater(
|
||||
_count_op_in_graph(
|
||||
graph, torch.ops._c10d_functional.all_gather_into_tensor_out.default
|
||||
graph,
|
||||
torch.ops._c10d_functional.all_gather_into_tensor_out.default,
|
||||
),
|
||||
0,
|
||||
)
|
||||
else:
|
||||
self.assertEqual(
|
||||
_count_op_in_graph(
|
||||
graph,
|
||||
torch.ops._c10d_functional.all_gather_into_tensor_out.default,
|
||||
),
|
||||
0,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1467,5 +1467,70 @@ class TestFullyShardCustomForwardMethod(FSDPTest):
|
|||
check_sharded_parity(self, ref_model, model)
|
||||
|
||||
|
||||
class TestFullyShardWorldSize1(FSDPTest):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return 1
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_train_parity_single_worldsize1(self):
|
||||
"""
|
||||
Tests train parity with DDP for a single FSDP group when sharding
|
||||
parameters on dim-0.
|
||||
"""
|
||||
self.run_subtests(
|
||||
{
|
||||
"lin_shapes": [
|
||||
[(16, 15), (15, 8)],
|
||||
[(7, 15), (15, 3)],
|
||||
[(16, 17), (17, 8)],
|
||||
],
|
||||
"use_shard_placement_fn": [False],
|
||||
},
|
||||
self._test_train_parity_single_group,
|
||||
)
|
||||
|
||||
def _test_train_parity_single_group(
|
||||
self, lin_shapes: list[tuple[int, int]], use_shard_placement_fn: bool
|
||||
):
|
||||
torch.manual_seed(42)
|
||||
model = nn.Sequential(
|
||||
nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1])
|
||||
)
|
||||
ref_model = copy.deepcopy(model).to(device_type)
|
||||
replicate(ref_model, device_ids=[self.rank])
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
|
||||
def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
|
||||
return Shard(param.shape.index(max(param.shape)))
|
||||
|
||||
shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None
|
||||
fully_shard(model, shard_placement_fn=shard_placement_fn)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
torch.manual_seed(42 + self.rank + 1)
|
||||
inp = (torch.randn((4, lin_shapes[0][0]), device=device_type.type),)
|
||||
|
||||
for iter_idx in range(10):
|
||||
losses: list[torch.Tensor] = []
|
||||
|
||||
ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
losses.append(ref_model(*inp).sum())
|
||||
losses[-1].backward()
|
||||
ref_optim.step()
|
||||
|
||||
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
comm_mode = CommDebugMode()
|
||||
with comm_mode:
|
||||
losses.append(model(*inp).sum())
|
||||
losses[-1].backward()
|
||||
|
||||
# Before there was 1 all-gather and 1 reduce-scatter
|
||||
# Now therre is 1 reduce-scatter
|
||||
self.assertEqual(comm_mode.get_total_counts(), 1)
|
||||
optim.step()
|
||||
|
||||
self.assertEqual(losses[0], losses[1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -277,19 +277,19 @@ class TestFullyShard2DTraining(FSDPTest):
|
|||
loss = model(inp).sum()
|
||||
|
||||
fwd_comm_counts = fwd_comm_mode.get_comm_counts()
|
||||
self.assertEqual(len(fwd_comm_counts), 2)
|
||||
self.assertEqual(len(fwd_comm_counts), 1)
|
||||
self.assertEqual(fwd_comm_counts[funcol.all_reduce], num_mlps)
|
||||
self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
|
||||
self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], 0)
|
||||
ref_loss = ref_model(inp).sum()
|
||||
self.assertEqual(loss, ref_loss)
|
||||
|
||||
with CommDebugMode() as bwd_comm_mode:
|
||||
loss.backward()
|
||||
bwd_comm_counts = bwd_comm_mode.get_comm_counts()
|
||||
self.assertEqual(len(bwd_comm_counts), 3)
|
||||
self.assertEqual(len(bwd_comm_counts), 2)
|
||||
# First MLP's input gradient does not need to be all-reduced
|
||||
self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1)
|
||||
self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
|
||||
self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], 0)
|
||||
self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps)
|
||||
ref_loss.backward()
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from ._fsdp_common import (
|
|||
HSDPMeshInfo,
|
||||
TrainingState,
|
||||
)
|
||||
from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState
|
||||
from ._fsdp_param import alloc_storage, FSDPParam, ParamModuleInfo, ShardedState
|
||||
|
||||
|
||||
logger = logging.getLogger("torch.distributed.fsdp.fully_shard")
|
||||
|
|
@ -166,6 +166,7 @@ class FSDPParamGroup:
|
|||
self._module_to_pre_load_state_dict_hook_handle: _ModuleToHandleDict = {}
|
||||
self._all_reduce_hook: Optional[Callable[[torch.Tensor], None]] = None
|
||||
self._all_gather_comm: AllGather = DefaultAllGather()
|
||||
self._all_gather_output = torch.empty(0, device=self.device)
|
||||
self._reduce_scatter_comm: ReduceScatter = DefaultReduceScatter()
|
||||
# Optional stream to run the user-defined all-reduce hook in
|
||||
# Saved here and not in the comm. context because we allow the user to
|
||||
|
|
@ -310,6 +311,22 @@ class FSDPParamGroup:
|
|||
# used in the all-gather streams
|
||||
self._wait_all_gather_streams_on_event(self._reshard_after_forward_event)
|
||||
self._reshard_after_forward_event = None
|
||||
|
||||
world_size = self._all_gather_process_group.size()
|
||||
if world_size == 1:
|
||||
# can't skip due to early return in wait_for_unshard if
|
||||
# no self._all_gather_result
|
||||
self._all_gather_result = AllGatherResult(
|
||||
all_gather_output=self._all_gather_output,
|
||||
all_gather_event=self.device_handle.Event().record(),
|
||||
all_gather_work=None,
|
||||
param_all_gather_input_dtypes=[],
|
||||
param_all_gather_input_numels=[],
|
||||
all_gather_input_split_sizes=[],
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
with record_function(self._with_fqn("FSDP::all_gather")):
|
||||
self._all_gather_result = foreach_all_gather(
|
||||
self.fsdp_params,
|
||||
|
|
@ -336,18 +353,52 @@ class FSDPParamGroup:
|
|||
if prev_all_gather_state := self.comm_ctx.all_gather_state:
|
||||
self._wait_all_gather_streams_on_event(prev_all_gather_state.event)
|
||||
self.comm_ctx.all_gather_state = None # free the all-gather result
|
||||
world_size = self._all_gather_process_group.size()
|
||||
if world_size == 1:
|
||||
# directly initialize unsharded parameters from sharded parameters
|
||||
|
||||
for fsdp_param in self.fsdp_params:
|
||||
# Use all_gather_inputs which already handles conversion to param_dtype
|
||||
# This is consistent with the world_size > 1 path
|
||||
all_gather_input = fsdp_param.all_gather_inputs[0]
|
||||
|
||||
# Make sure the all_gather_outputs has proper storage size before using it
|
||||
# First ensure we have at least one tensor in all_gather_outputs
|
||||
fsdp_param.init_all_gather_outputs(
|
||||
[all_gather_input.numel()],
|
||||
[all_gather_input.dtype],
|
||||
world_size,
|
||||
self.device,
|
||||
force_recreate=False,
|
||||
)
|
||||
|
||||
tensor = fsdp_param.all_gather_outputs[0]
|
||||
alloc_storage(tensor)
|
||||
|
||||
# find alternative way to check if tensor.is_inference
|
||||
with torch.autograd._unsafe_preserve_version_counter(tensor):
|
||||
tensor.copy_(all_gather_input)
|
||||
|
||||
else:
|
||||
with record_function(self._with_fqn("FSDP::all_gather_copy_out")):
|
||||
foreach_all_gather_copy_out(
|
||||
self._all_gather_result,
|
||||
self.fsdp_params,
|
||||
self._all_gather_process_group,
|
||||
)
|
||||
|
||||
for fsdp_param in self.fsdp_params:
|
||||
fsdp_param.init_unsharded_param()
|
||||
|
||||
self._to_unsharded()
|
||||
all_gather_copy_out_event = self.device_handle.Event()
|
||||
all_gather_copy_out_event.record()
|
||||
if not async_op and self._training_state == TrainingState.FORWARD:
|
||||
|
||||
if (
|
||||
not async_op
|
||||
and self._training_state == TrainingState.FORWARD
|
||||
and world_size > 1
|
||||
):
|
||||
# Defer free to allow for overlap of this copy-out with next
|
||||
# all-gather collective
|
||||
self.comm_ctx.all_gather_state = AllGatherState(
|
||||
|
|
@ -355,6 +406,7 @@ class FSDPParamGroup:
|
|||
)
|
||||
else:
|
||||
self._wait_all_gather_streams_on_event(all_gather_copy_out_event)
|
||||
|
||||
self._all_gather_result = None # free unless saved in `all_gather_state`
|
||||
|
||||
def _wait_all_gather_streams_on_event(self, event: Optional[torch.Event]):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user