[FSDP2] Added set_reshard_after_backward (#124319)
This PR adds a `set_reshard_after_backward` method to allow disabling resharding after backward. `reshard_after_backward=False` can be used together with `reshard_after_forward=False` to implement "ZeRO-1", where parameters are all-gathered only in the first microbatch's forward and gradients are reduce-scattered only in the last microbatch's backward.
```
for microbatch_idx, microbatch in enumerate(dataloader):
    is_last_microbatch = microbatch_idx == num_microbatches - 1
    model.set_requires_gradient_sync(is_last_microbatch)
    model.set_reshard_after_backward(is_last_microbatch)
    model.set_is_last_backward(is_last_microbatch)
    microbatch_fwd_bwd(model, microbatch, microbatch_idx)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124319
Approved by: https://github.com/weifengpy
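Below is a minimal, self-contained sketch of the pattern from the snippet above, with the forward/backward inlined in place of the `microbatch_fwd_bwd` helper. The model, optimizer, data, and the `torch.distributed._composable.fsdp` import path (the prototype location at the time of this PR) are illustrative assumptions, not part of the change; it assumes a `torchrun` launch with NCCL and one CUDA device per rank.

```python
# Hypothetical sketch (not from the PR): ZeRO-1-style gradient accumulation with
# FSDP2, where parameters stay unsharded and gradients are only reduce-scattered
# on the last microbatch. Launch with: torchrun --nproc_per_node=2 this_file.py
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # assumed prototype path

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(4)]).cuda()
for layer in model:
    fully_shard(layer, reshard_after_forward=False)  # keep params unsharded after forward
fully_shard(model, reshard_after_forward=False)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

num_microbatches = 4
microbatches = [torch.randn(8, 1024, device="cuda") for _ in range(num_microbatches)]

for microbatch_idx, microbatch in enumerate(microbatches):
    is_last_microbatch = microbatch_idx == num_microbatches - 1
    # Reduce-scatter (and all-reduce, if applicable) only on the last microbatch
    model.set_requires_gradient_sync(is_last_microbatch)
    # Keep parameters unsharded after backward until the last microbatch
    model.set_reshard_after_backward(is_last_microbatch)
    # Run FSDP's end-of-backward logic only on the last microbatch
    model.set_is_last_backward(is_last_microbatch)
    loss = model(microbatch).sum()
    loss.backward()

optim.step()
optim.zero_grad()
dist.destroy_process_group()
```

With this configuration, only the first microbatch's forward pays all-gather cost and only the last microbatch's backward pays reduce-scatter cost, at the price of keeping unsharded parameters (and unreduced gradients) resident between microbatches.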
This commit is contained in:
parent 10b9d4d19c
commit 1900f79b72
@@ -616,11 +616,10 @@ class TestFullyShardGradientAccumulation(FSDPTest):
         return min(2, torch.cuda.device_count())

     @skip_if_lt_x_gpu(2)
-    def test_set_requires_gradient_sync(self):
+    def test_gradient_accumulation(self):
         """
-        Tests the ``set_requires_gradient_sync`` API to exercise gradient
-        accumulation without gradient reduction. This test includes mixing with
-        gradient accumulation *with* gradient reduction.
+        Tests gradient accumulation with/without gradient reduction and
+        with/without resharding after backward.
         """
         self.run_subtests(
             {
@@ -629,15 +628,22 @@ class TestFullyShardGradientAccumulation(FSDPTest):
                 # "root_only": disable reduce-scatter for root's linear only
                 # "some_mlps": disable reduce-scatter for some MLPs
                 "mode": ["all", "root_only", "some_mlps"],
+                "reshard_after_backward": [False, True],
             },
-            self._test_set_requires_gradient_sync,
+            self._test_gradient_accumulation,
         )

-    def _test_set_requires_gradient_sync(
+    def _test_gradient_accumulation(
         self,
         reshard_after_forward: Union[bool, int],
         mode: str,
+        reshard_after_backward: bool,
     ):
+        if not reshard_after_backward and (
+            reshard_after_forward is not False or mode == "some_mlps"
+        ):
+            return  # skip since not common
         torch.manual_seed(42)
         local_batch_size, lin_dim, num_mlps, num_microbatches = (2, 32, 3, 3)
         global_batch_size = local_batch_size * self.world_size
@@ -658,9 +664,17 @@ class TestFullyShardGradientAccumulation(FSDPTest):
         fully_shard_fn(model)  # root gets the 1st linear
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)

+        orig_all_gather = dist.all_gather_into_tensor
+        all_gather_count = 0
         orig_reduce_scatter = dist.reduce_scatter_tensor
         reduce_scatter_count = 0

+        def all_gather_with_count(*args, **kwargs):
+            nonlocal all_gather_count
+            all_gather_count += 1
+            return orig_all_gather(*args, **kwargs)
+
         def reduce_scatter_with_count(*args, **kwargs):
             nonlocal reduce_scatter_count
             reduce_scatter_count += 1
@@ -668,18 +682,29 @@ class TestFullyShardGradientAccumulation(FSDPTest):

         torch.manual_seed(1)  # same on all ranks
         for iter_idx in range(5):
-            with patch_reduce_scatter(reduce_scatter_with_count):
+            with patch_all_gather(all_gather_with_count), patch_reduce_scatter(
+                reduce_scatter_with_count
+            ):
                 for microbatch_idx in range(num_microbatches):
                     is_last_microbatch = microbatch_idx == num_microbatches - 1
                     if mode == "all":
                         model.set_requires_gradient_sync(is_last_microbatch)
+                        if not reshard_after_backward:
+                            model.set_reshard_after_backward(is_last_microbatch)
                     elif mode == "some_mlps":
                         for mlp in model[1 : 1 + num_mlps_to_disable_reduce_scatter]:
                             mlp.set_requires_gradient_sync(is_last_microbatch)
+                            if not reshard_after_backward:
+                                mlp.set_reshard_after_backward(is_last_microbatch)
                     elif mode == "root_only":
                         model.set_requires_gradient_sync(
                             is_last_microbatch, recurse=False
                         )
+                        if not reshard_after_backward:
+                            model.set_reshard_after_backward(
+                                is_last_microbatch, recurse=False
+                            )

                     global_inp = torch.rand((global_batch_size, lin_dim), device="cuda")
                     local_inp = global_inp[
                         self.rank
@@ -695,6 +720,7 @@ class TestFullyShardGradientAccumulation(FSDPTest):
                     losses[-1].backward()
                     dist.all_reduce(losses[1])  # partial -> replicated
                     self.assertEqual(losses[0], losses[1])

             # Expect one reduce-scatter per MLP plus one for the root's linear
             # on the last microbatch
             expected_reduce_scatter_count = num_mlps + 1
@@ -709,6 +735,32 @@ class TestFullyShardGradientAccumulation(FSDPTest):
                 expected_reduce_scatter_count += (num_mlps) * (num_microbatches - 1)
             self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count)
             reduce_scatter_count = 0

+            # Expect one all-gather per MLP plus one for the root's linear in
+            # the first microbatch's forward
+            expected_all_gather_count = num_mlps + 1
+            if reshard_after_forward is not False:  # `True` or `2`
+                # Add the number of MLPs without the +1 for the backward
+                # all-gathers since the root does not reshard after forward
+                expected_all_gather_count += num_mlps
+                # Multiply by the number of microbatches since these
+                # all-gathers run every microbatch
+                expected_all_gather_count *= num_microbatches
+            elif reshard_after_backward:  # `reshard_after_forward=False`
+                expected_all_gather_count *= num_microbatches
+            elif mode == "all":  # `reshard_after_forward/backward=False`
+                # Only reshard parameters after the last microbatch's backward,
+                # so there should not be any more all-gathers
+                pass
+            elif mode == "root_only":  # `reshard_after_forward/backward=False`
+                # The MLPs should still contribute all-gathers in each
+                # microbatch forward
+                expected_all_gather_count += num_mlps * (num_microbatches - 1)
+            self.assertEqual(all_gather_count, expected_all_gather_count)
+            all_gather_count = 0
+
+            # Average the ref model's gradients over the world size to match
+            # data parallel semantics
             for param in ref_model.parameters():
                 if param.grad is not None:
                     param.grad.div_(self.world_size)
@@ -722,11 +774,16 @@ class TestFullyShardGradientAccumulation(FSDPTest):
     @skip_if_lt_x_gpu(2)
     def test_1f1b_microbatching(self):
         self.run_subtests(
-            {"use_explicit_unshard": [False, True]},
+            {
+                "use_explicit_unshard": [False, True],
+                "reshard_after_backward": [False, True],
+            },
             self._test_1f1b_microbatching,
         )

-    def _test_1f1b_microbatching(self, use_explicit_unshard: bool):
+    def _test_1f1b_microbatching(
+        self, use_explicit_unshard: bool, reshard_after_backward: bool
+    ):
         torch.manual_seed(42)
         model_args = ModelArgs(dropout_p=0.0)
         model = Transformer(model_args)
@@ -764,6 +821,8 @@ class TestFullyShardGradientAccumulation(FSDPTest):
             is_last_microbatch = inp_idx == num_microbatches - 1
             model.set_requires_gradient_sync(is_last_microbatch)
             model.set_is_last_backward(is_last_microbatch)
+            if not reshard_after_backward:
+                model.set_reshard_after_backward(is_last_microbatch)
             losses.append(model(inp).sum())
             losses[-1].backward()
             ref_losses.append(ref_model(inp).sum())
@@ -128,6 +128,9 @@ class FSDPParamGroup:
         # `self.reduce_grads` is true, in which case setting this to false
         # means reduce-scatter but no all-reduce
         self.all_reduce_grads: bool = True
+        # Whether to reshard parameters after backward (only useful for
+        # gradient accumulation)
+        self.reshard_after_backward: bool = True

         # - CUDA events for stream synchronization
         # Holds the all-gather output buffer, sync objects, and metadata
@@ -309,7 +312,8 @@ class FSDPParamGroup:
         self._training_state = TrainingState.POST_BACKWARD
         with torch.profiler.record_function("FSDP::post_backward_reshard"):
             if not self.reduce_grads:
-                self.reshard()
+                if self.reshard_after_backward:
+                    self.reshard()
                 return
             # Save the autograd-computed gradients before resharding to only
             # access the unsharded parameters when their data is present
@@ -320,7 +324,8 @@ class FSDPParamGroup:
                     fsdp_params_with_grad.append(fsdp_param)
                     unsharded_grads.append(fsdp_param.unsharded_grad_data)
                     fsdp_param.unsharded_param.grad = None
-            self.reshard()
+            if self.reshard_after_backward:
+                self.reshard()
         if len(fsdp_params_with_grad) == 0:
             return
         with torch.profiler.record_function("FSDP::post_backward_reduce"):
@@ -220,7 +220,9 @@ class FSDP:
                     fsdp_param_group.reduce_grads = requires_gradient_sync
                     fsdp_param_group.all_reduce_grads = requires_gradient_sync

-    def set_requires_all_reduce(self, requires_all_reduce: bool, recurse: bool = True):
+    def set_requires_all_reduce(
+        self, requires_all_reduce: bool, recurse: bool = True
+    ) -> None:
         """
         Sets if the module should all-reduce gradients. This can be used to
         implement gradient accumulation with only reduce-scatter but not
@@ -237,6 +239,28 @@ class FSDP:
                 if fsdp_param_group := state._fsdp_param_group:
                     fsdp_param_group.all_reduce_grads = requires_all_reduce

+    def set_reshard_after_backward(
+        self, reshard_after_backward: bool, recurse: bool = True
+    ) -> None:
+        """
+        Sets if the module should reshard parameters after backward. This can
+        be used during gradient accumulation to trade off higher memory for
+        reduced communication.
+
+        Args:
+            reshard_after_backward (bool): Whether to reshard parameters after
+                backward.
+            recurse (bool): Whether to set for all submodules or just the
+                passed-in module.
+        """
+        self_module = cast(nn.Module, self)
+        modules = list(self_module.modules()) if recurse else [self_module]
+        for module in modules:
+            if isinstance(module, FSDP):
+                state = module._get_fsdp_state()
+                if fsdp_param_group := state._fsdp_param_group:
+                    fsdp_param_group.reshard_after_backward = reshard_after_backward
+
     def _get_fsdp_state(self) -> FSDPState:
         if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None:
             raise AssertionError(f"No FSDP state found on {self}")
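To make the `recurse` flag of the new `set_reshard_after_backward` concrete, here is a short hypothetical sketch in the spirit of the test's `root_only` mode; the model wiring and the `torch.distributed._composable.fsdp` import path are assumptions for illustration, not part of the PR.

```python
# Hypothetical sketch (not from the PR): scoping set_reshard_after_backward.
# Launch with torchrun so a process group can be initialized.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # assumed prototype path

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64), nn.Linear(64, 64)).cuda()
for submodule in model[1:]:
    fully_shard(submodule)
fully_shard(model)  # the root's parameter group owns the first linear

# recurse=False limits the setting to the parameters the root manages directly
# (the first linear here); the individually sharded linears keep resharding
# after backward. recurse=True (the default) would apply to every FSDP module.
model.set_reshard_after_backward(False, recurse=False)
```

This mirrors how the updated test only changes the root's behavior in its `root_only` mode while leaving the submodules on the default.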