[FSDP2] Added set_reshard_after_backward (#124319)

This PR adds a `set_reshard_after_backward` method to allow disabling resharding after backward. `set_reshard_after_backward(False)` can be combined with `reshard_after_forward=False` to implement "ZeRO-1", where parameters are all-gathered only on the first microbatch's forward and gradients are reduce-scattered only on the last microbatch's backward.

```
for microbatch_idx, microbatch in enumerate(dataloader):
    is_last_microbatch = microbatch_idx == num_microbatches - 1
    model.set_requires_gradient_sync(is_last_microbatch)
    model.set_reshard_after_backward(is_last_microbatch)
    model.set_is_last_backward(is_last_microbatch)
    microbatch_fwd_bwd(model, microbatch, microbatch_idx)
```
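
In the snippet above, `num_microbatches`, `dataloader`, and `microbatch_fwd_bwd` are placeholders. Below is a minimal, hedged sketch of the same pattern; the toy model, data, and helper are illustrative assumptions (not part of this PR), and it presumes the process group is already initialized (e.g. under `torchrun`) with one CUDA device set per rank:

```
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # FSDP2 location as of this PR

# Assumed setup: dist.init_process_group() and torch.cuda.set_device() were already called.
model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32)).cuda()
fully_shard(model, reshard_after_forward=False)  # ZeRO-1 style: keep params unsharded after forward
optim = torch.optim.Adam(model.parameters(), lr=1e-2)

num_microbatches = 3
dataloader = [torch.randn(8, 32, device="cuda") for _ in range(num_microbatches)]

def microbatch_fwd_bwd(model, microbatch, microbatch_idx):
    # Hypothetical helper: one forward/backward per microbatch; gradients
    # accumulate locally until sync is re-enabled on the last microbatch.
    loss = model(microbatch).sum()
    loss.backward()

for microbatch_idx, microbatch in enumerate(dataloader):
    is_last_microbatch = microbatch_idx == num_microbatches - 1
    model.set_requires_gradient_sync(is_last_microbatch)
    model.set_reshard_after_backward(is_last_microbatch)
    model.set_is_last_backward(is_last_microbatch)
    microbatch_fwd_bwd(model, microbatch, microbatch_idx)

optim.step()
optim.zero_grad()
```

With this configuration there is a single all-gather on the first microbatch's forward and a single reduce-scatter on the last microbatch's backward, at the cost of holding unsharded parameters for the whole accumulation window.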

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124319
Approved by: https://github.com/weifengpy
Authored by Andrew Gu on 2024-04-19 08:06:40 -07:00; committed by PyTorch MergeBot
parent 10b9d4d19c
commit 1900f79b72
3 changed files with 100 additions and 12 deletions


```
@@ -616,11 +616,10 @@ class TestFullyShardGradientAccumulation(FSDPTest):
         return min(2, torch.cuda.device_count())
 
     @skip_if_lt_x_gpu(2)
-    def test_set_requires_gradient_sync(self):
+    def test_gradient_accumulation(self):
         """
-        Tests the ``set_requires_gradient_sync`` API to exercise gradient
-        accumulation without gradient reduction. This test includes mixing with
-        gradient accumulation *with* gradient reduction.
+        Tests gradient accumulation with/without gradient reduction and
+        with/without resharding after backward.
         """
         self.run_subtests(
             {
@@ -629,15 +628,22 @@ class TestFullyShardGradientAccumulation(FSDPTest):
                 # "root_only": disable reduce-scatter for root's linear only
                 # "some_mlps": disable reduce-scatter for some MLPs
                 "mode": ["all", "root_only", "some_mlps"],
+                "reshard_after_backward": [False, True],
             },
-            self._test_set_requires_gradient_sync,
+            self._test_gradient_accumulation,
         )
 
-    def _test_set_requires_gradient_sync(
+    def _test_gradient_accumulation(
         self,
         reshard_after_forward: Union[bool, int],
         mode: str,
+        reshard_after_backward: bool,
     ):
+        if not reshard_after_backward and (
+            reshard_after_forward is not False or mode == "some_mlps"
+        ):
+            return  # skip since not common
         torch.manual_seed(42)
         local_batch_size, lin_dim, num_mlps, num_microbatches = (2, 32, 3, 3)
         global_batch_size = local_batch_size * self.world_size
@@ -658,9 +664,17 @@ class TestFullyShardGradientAccumulation(FSDPTest):
         fully_shard_fn(model)  # root gets the 1st linear
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)
+        orig_all_gather = dist.all_gather_into_tensor
+        all_gather_count = 0
         orig_reduce_scatter = dist.reduce_scatter_tensor
         reduce_scatter_count = 0
 
+        def all_gather_with_count(*args, **kwargs):
+            nonlocal all_gather_count
+            all_gather_count += 1
+            return orig_all_gather(*args, **kwargs)
+
         def reduce_scatter_with_count(*args, **kwargs):
             nonlocal reduce_scatter_count
             reduce_scatter_count += 1
@@ -668,18 +682,29 @@ class TestFullyShardGradientAccumulation(FSDPTest):
         torch.manual_seed(1)  # same on all ranks
         for iter_idx in range(5):
-            with patch_reduce_scatter(reduce_scatter_with_count):
+            with patch_all_gather(all_gather_with_count), patch_reduce_scatter(
+                reduce_scatter_with_count
+            ):
                 for microbatch_idx in range(num_microbatches):
                     is_last_microbatch = microbatch_idx == num_microbatches - 1
                     if mode == "all":
                         model.set_requires_gradient_sync(is_last_microbatch)
+                        if not reshard_after_backward:
+                            model.set_reshard_after_backward(is_last_microbatch)
                     elif mode == "some_mlps":
                         for mlp in model[1 : 1 + num_mlps_to_disable_reduce_scatter]:
                             mlp.set_requires_gradient_sync(is_last_microbatch)
+                            if not reshard_after_backward:
+                                mlp.set_reshard_after_backward(is_last_microbatch)
                     elif mode == "root_only":
                         model.set_requires_gradient_sync(
                             is_last_microbatch, recurse=False
                         )
+                        if not reshard_after_backward:
+                            model.set_reshard_after_backward(
+                                is_last_microbatch, recurse=False
+                            )
                     global_inp = torch.rand((global_batch_size, lin_dim), device="cuda")
                     local_inp = global_inp[
                         self.rank
@@ -695,6 +720,7 @@ class TestFullyShardGradientAccumulation(FSDPTest):
                         losses[-1].backward()
                     dist.all_reduce(losses[1])  # partial -> replicated
                     self.assertEqual(losses[0], losses[1])
+
             # Expect one reduce-scatter per MLP plus one for the root's linear
             # on the last microbatch
             expected_reduce_scatter_count = num_mlps + 1
@@ -709,6 +735,32 @@ class TestFullyShardGradientAccumulation(FSDPTest):
                 expected_reduce_scatter_count += (num_mlps) * (num_microbatches - 1)
             self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count)
             reduce_scatter_count = 0
+            # Expect one all-gather per MLP plus one for the root's linear in
+            # the first microbatch's forward
+            expected_all_gather_count = num_mlps + 1
+            if reshard_after_forward is not False:  # `True` or `2`
+                # Add the number of MLPs without the +1 for the backward
+                # all-gathers since the root does not reshard after forward
+                expected_all_gather_count += num_mlps
+                # Multiply by the number of microbatches since these
+                # all-gathers run every microbatch
+                expected_all_gather_count *= num_microbatches
+            elif reshard_after_backward:  # `reshard_after_forward=False`
+                expected_all_gather_count *= num_microbatches
+            elif mode == "all":  # `reshard_after_forward/backward=False`
+                # Only reshard parameters after the last microbatch's backward,
+                # so there should not be any more all-gathers
+                pass
+            elif mode == "root_only":  # `reshard_after_forward/backward=False`
+                # The MLPs should still contribute all-gathers in each
+                # microbatch forward
+                expected_all_gather_count += num_mlps * (num_microbatches - 1)
+            self.assertEqual(all_gather_count, expected_all_gather_count)
+            all_gather_count = 0
+
+            # Average the ref model's gradients over the world size to match
+            # data parallel semantics
             for param in ref_model.parameters():
                 if param.grad is not None:
                     param.grad.div_(self.world_size)
@@ -722,11 +774,16 @@ class TestFullyShardGradientAccumulation(FSDPTest):
     @skip_if_lt_x_gpu(2)
     def test_1f1b_microbatching(self):
         self.run_subtests(
-            {"use_explicit_unshard": [False, True]},
+            {
+                "use_explicit_unshard": [False, True],
+                "reshard_after_backward": [False, True],
+            },
             self._test_1f1b_microbatching,
         )
 
-    def _test_1f1b_microbatching(self, use_explicit_unshard: bool):
+    def _test_1f1b_microbatching(
+        self, use_explicit_unshard: bool, reshard_after_backward: bool
+    ):
         torch.manual_seed(42)
         model_args = ModelArgs(dropout_p=0.0)
         model = Transformer(model_args)
@@ -764,6 +821,8 @@ class TestFullyShardGradientAccumulation(FSDPTest):
             is_last_microbatch = inp_idx == num_microbatches - 1
             model.set_requires_gradient_sync(is_last_microbatch)
             model.set_is_last_backward(is_last_microbatch)
+            if not reshard_after_backward:
+                model.set_reshard_after_backward(is_last_microbatch)
             losses.append(model(inp).sum())
             losses[-1].backward()
             ref_losses.append(ref_model(inp).sum())
```


```
@@ -128,6 +128,9 @@ class FSDPParamGroup:
         # `self.reduce_grads` is true, in which case setting this to false
         # means reduce-scatter but no all-reduce
         self.all_reduce_grads: bool = True
+        # Whether to reshard parameters after backward (only useful for
+        # gradient accumulation)
+        self.reshard_after_backward: bool = True
         # - CUDA events for stream synchronization
         # Holds the all-gather output buffer, sync objects, and metadata
@@ -309,7 +312,8 @@ class FSDPParamGroup:
         self._training_state = TrainingState.POST_BACKWARD
         with torch.profiler.record_function("FSDP::post_backward_reshard"):
             if not self.reduce_grads:
-                self.reshard()
+                if self.reshard_after_backward:
+                    self.reshard()
                 return
             # Save the autograd-computed gradients before resharding to only
             # access the unsharded parameters when their data is present
@@ -320,7 +324,8 @@ class FSDPParamGroup:
                     fsdp_params_with_grad.append(fsdp_param)
                     unsharded_grads.append(fsdp_param.unsharded_grad_data)
                     fsdp_param.unsharded_param.grad = None
-            self.reshard()
+            if self.reshard_after_backward:
+                self.reshard()
         if len(fsdp_params_with_grad) == 0:
             return
         with torch.profiler.record_function("FSDP::post_backward_reduce"):
```


```
@@ -220,7 +220,9 @@ class FSDP:
                     fsdp_param_group.reduce_grads = requires_gradient_sync
                     fsdp_param_group.all_reduce_grads = requires_gradient_sync
 
-    def set_requires_all_reduce(self, requires_all_reduce: bool, recurse: bool = True):
+    def set_requires_all_reduce(
+        self, requires_all_reduce: bool, recurse: bool = True
+    ) -> None:
         """
         Sets if the module should all-reduce gradients. This can be used to
         implement gradient accumulation with only reduce-scatter but not
@@ -237,6 +239,28 @@ class FSDP:
                 if fsdp_param_group := state._fsdp_param_group:
                     fsdp_param_group.all_reduce_grads = requires_all_reduce
 
+    def set_reshard_after_backward(
+        self, reshard_after_backward: bool, recurse: bool = True
+    ) -> None:
+        """
+        Sets if the module should reshard parameters after backward. This can
+        be used during gradient accumulation to trade off higher memory for
+        reduced communication.
+
+        Args:
+            reshard_after_backward (bool): Whether to reshard parameters after
+                backward.
+            recurse (bool): Whether to set for all submodules or just the
+                passed-in module.
+        """
+        self_module = cast(nn.Module, self)
+        modules = list(self_module.modules()) if recurse else [self_module]
+        for module in modules:
+            if isinstance(module, FSDP):
+                state = module._get_fsdp_state()
+                if fsdp_param_group := state._fsdp_param_group:
+                    fsdp_param_group.reshard_after_backward = reshard_after_backward
+
     def _get_fsdp_state(self) -> FSDPState:
         if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None:
             raise AssertionError(f"No FSDP state found on {self}")
```