diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py
index d29ccebf864..00dcf4308dd 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_training.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -616,11 +616,10 @@ class TestFullyShardGradientAccumulation(FSDPTest):
         return min(2, torch.cuda.device_count())

     @skip_if_lt_x_gpu(2)
-    def test_set_requires_gradient_sync(self):
+    def test_gradient_accumulation(self):
         """
-        Tests the ``set_requires_gradient_sync`` API to exercise gradient
-        accumulation without gradient reduction. This test includes mixing with
-        gradient accumulation *with* gradient reduction.
+        Tests gradient accumulation with/without gradient reduction and
+        with/without resharding after backward.
         """
         self.run_subtests(
             {
@@ -629,15 +628,22 @@ class TestFullyShardGradientAccumulation(FSDPTest):
                 # "root_only": disable reduce-scatter for root's linear only
                 # "some_mlps": disable reduce-scatter for some MLPs
                 "mode": ["all", "root_only", "some_mlps"],
+                "reshard_after_backward": [False, True],
             },
-            self._test_set_requires_gradient_sync,
+            self._test_gradient_accumulation,
         )

-    def _test_set_requires_gradient_sync(
+    def _test_gradient_accumulation(
         self,
         reshard_after_forward: Union[bool, int],
         mode: str,
+        reshard_after_backward: bool,
     ):
+        if not reshard_after_backward and (
+            reshard_after_forward is not False or mode == "some_mlps"
+        ):
+            return  # skip since not common
+
         torch.manual_seed(42)
         local_batch_size, lin_dim, num_mlps, num_microbatches = (2, 32, 3, 3)
         global_batch_size = local_batch_size * self.world_size
@@ -658,9 +664,17 @@ class TestFullyShardGradientAccumulation(FSDPTest):
         fully_shard_fn(model)  # root gets the 1st linear
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)
+
+        orig_all_gather = dist.all_gather_into_tensor
+        all_gather_count = 0
         orig_reduce_scatter = dist.reduce_scatter_tensor
         reduce_scatter_count = 0

+        def all_gather_with_count(*args, **kwargs):
+            nonlocal all_gather_count
+            all_gather_count += 1
+            return orig_all_gather(*args, **kwargs)
+
         def reduce_scatter_with_count(*args, **kwargs):
             nonlocal reduce_scatter_count
             reduce_scatter_count += 1
@@ -668,18 +682,29 @@

         torch.manual_seed(1)  # same on all ranks
         for iter_idx in range(5):
-            with patch_reduce_scatter(reduce_scatter_with_count):
+            with patch_all_gather(all_gather_with_count), patch_reduce_scatter(
+                reduce_scatter_with_count
+            ):
                 for microbatch_idx in range(num_microbatches):
                     is_last_microbatch = microbatch_idx == num_microbatches - 1
                     if mode == "all":
                         model.set_requires_gradient_sync(is_last_microbatch)
+                        if not reshard_after_backward:
+                            model.set_reshard_after_backward(is_last_microbatch)
                     elif mode == "some_mlps":
                         for mlp in model[1 : 1 + num_mlps_to_disable_reduce_scatter]:
                             mlp.set_requires_gradient_sync(is_last_microbatch)
+                            if not reshard_after_backward:
+                                mlp.set_reshard_after_backward(is_last_microbatch)
                     elif mode == "root_only":
                         model.set_requires_gradient_sync(
                             is_last_microbatch, recurse=False
                         )
+                        if not reshard_after_backward:
+                            model.set_reshard_after_backward(
+                                is_last_microbatch, recurse=False
+                            )
+
                     global_inp = torch.rand((global_batch_size, lin_dim), device="cuda")
                     local_inp = global_inp[
                         self.rank
@@ -695,6 +720,7 @@ class TestFullyShardGradientAccumulation(FSDPTest):
                         losses[-1].backward()
                     dist.all_reduce(losses[1])  # partial -> replicated
                     self.assertEqual(losses[0], losses[1])
+
             # Expect one reduce-scatter per MLP plus one for the root's linear
             # on the last microbatch
             expected_reduce_scatter_count = num_mlps + 1
@@ -709,6 +735,32 @@ class TestFullyShardGradientAccumulation(FSDPTest):
                 expected_reduce_scatter_count += (num_mlps) * (num_microbatches - 1)
             self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count)
             reduce_scatter_count = 0
+
+            # Expect one all-gather per MLP plus one for the root's linear in
+            # the first microbatch's forward
+            expected_all_gather_count = num_mlps + 1
+            if reshard_after_forward is not False:  # `True` or `2`
+                # Add the number of MLPs without the +1 for the backward
+                # all-gathers since the root does not reshard after forward
+                expected_all_gather_count += num_mlps
+                # Multiply by the number of microbatches since these
+                # all-gathers run every microbatch
+                expected_all_gather_count *= num_microbatches
+            elif reshard_after_backward:  # `reshard_after_forward=False`
+                expected_all_gather_count *= num_microbatches
+            elif mode == "all":  # `reshard_after_forward/backward=False`
+                # Only reshard parameters after the last microbatch's backward,
+                # so there should not be any more all-gathers
+                pass
+            elif mode == "root_only":  # `reshard_after_forward/backward=False`
+                # The MLPs should still contribute all-gathers in each
+                # microbatch forward
+                expected_all_gather_count += num_mlps * (num_microbatches - 1)
+            self.assertEqual(all_gather_count, expected_all_gather_count)
+            all_gather_count = 0
+
+            # Average the ref model's gradients over the world size to match
+            # data parallel semantics
             for param in ref_model.parameters():
                 if param.grad is not None:
                     param.grad.div_(self.world_size)
@@ -722,11 +774,16 @@ class TestFullyShardGradientAccumulation(FSDPTest):
     @skip_if_lt_x_gpu(2)
     def test_1f1b_microbatching(self):
         self.run_subtests(
-            {"use_explicit_unshard": [False, True]},
+            {
+                "use_explicit_unshard": [False, True],
+                "reshard_after_backward": [False, True],
+            },
             self._test_1f1b_microbatching,
         )

-    def _test_1f1b_microbatching(self, use_explicit_unshard: bool):
+    def _test_1f1b_microbatching(
+        self, use_explicit_unshard: bool, reshard_after_backward: bool
+    ):
         torch.manual_seed(42)
         model_args = ModelArgs(dropout_p=0.0)
         model = Transformer(model_args)
@@ -764,6 +821,8 @@ class TestFullyShardGradientAccumulation(FSDPTest):
             is_last_microbatch = inp_idx == num_microbatches - 1
             model.set_requires_gradient_sync(is_last_microbatch)
             model.set_is_last_backward(is_last_microbatch)
+            if not reshard_after_backward:
+                model.set_reshard_after_backward(is_last_microbatch)
             losses.append(model(inp).sum())
             losses[-1].backward()
             ref_losses.append(ref_model(inp).sum())
diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py
index 7a7addb40a1..ab8dfe1aa1b 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py
@@ -128,6 +128,9 @@ class FSDPParamGroup:
         # `self.reduce_grads` is true, in which case setting this to false
         # means reduce-scatter but no all-reduce
         self.all_reduce_grads: bool = True
+        # Whether to reshard parameters after backward (only useful for
+        # gradient accumulation)
+        self.reshard_after_backward: bool = True

         # - CUDA events for stream synchronization
         # Holds the all-gather output buffer, sync objects, and metadata
@@ -309,7 +312,8 @@ class FSDPParamGroup:
         self._training_state = TrainingState.POST_BACKWARD
         with torch.profiler.record_function("FSDP::post_backward_reshard"):
             if not self.reduce_grads:
-                self.reshard()
+                if self.reshard_after_backward:
+                    self.reshard()
                 return
             # Save the autograd-computed gradients before resharding to only
             # access the unsharded parameters when their data is present
@@ -320,7 +324,8 @@
                     fsdp_params_with_grad.append(fsdp_param)
                     unsharded_grads.append(fsdp_param.unsharded_grad_data)
                     fsdp_param.unsharded_param.grad = None
-            self.reshard()
+            if self.reshard_after_backward:
+                self.reshard()
         if len(fsdp_params_with_grad) == 0:
             return
         with torch.profiler.record_function("FSDP::post_backward_reduce"):
diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py
index a4909bbff75..06af8c2f90b 100644
--- a/torch/distributed/_composable/fsdp/fully_shard.py
+++ b/torch/distributed/_composable/fsdp/fully_shard.py
@@ -220,7 +220,9 @@ class FSDP:
                     fsdp_param_group.reduce_grads = requires_gradient_sync
                     fsdp_param_group.all_reduce_grads = requires_gradient_sync

-    def set_requires_all_reduce(self, requires_all_reduce: bool, recurse: bool = True):
+    def set_requires_all_reduce(
+        self, requires_all_reduce: bool, recurse: bool = True
+    ) -> None:
         """
         Sets if the module should all-reduce gradients. This can be used to
         implement gradient accumulation with only reduce-scatter but not
@@ -237,6 +239,28 @@ class FSDP:
                 if fsdp_param_group := state._fsdp_param_group:
                     fsdp_param_group.all_reduce_grads = requires_all_reduce

+    def set_reshard_after_backward(
+        self, reshard_after_backward: bool, recurse: bool = True
+    ) -> None:
+        """
+        Sets if the module should reshard parameters after backward. This can
+        be used during gradient accumulation to trade off higher memory for
+        reduced communication.
+
+        Args:
+            reshard_after_backward (bool): Whether to reshard parameters after
+                backward.
+            recurse (bool): Whether to set for all submodules or just the
+                passed-in module.
+        """
+        self_module = cast(nn.Module, self)
+        modules = list(self_module.modules()) if recurse else [self_module]
+        for module in modules:
+            if isinstance(module, FSDP):
+                state = module._get_fsdp_state()
+                if fsdp_param_group := state._fsdp_param_group:
+                    fsdp_param_group.reshard_after_backward = reshard_after_backward
+
     def _get_fsdp_state(self) -> FSDPState:
         if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None:
             raise AssertionError(f"No FSDP state found on {self}")
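
Usage note (not part of the patch): a minimal sketch of how the new set_reshard_after_backward API composes with the existing set_requires_gradient_sync during gradient accumulation, mirroring the pattern in the test above. The build_model() helper, the microbatches iterable, and the import path are assumptions for illustration, not APIs introduced by this diff.

    # Minimal gradient-accumulation sketch; `build_model()` and `microbatches`
    # are placeholders supplied by the caller, and the import path is assumed.
    import torch
    from torch.distributed._composable.fsdp import fully_shard

    model = build_model()
    fully_shard(model)  # shard parameters in place; adds the FSDP methods
    optim = torch.optim.Adam(model.parameters(), lr=1e-2)

    num_microbatches = len(microbatches)
    for microbatch_idx, inp in enumerate(microbatches):
        is_last_microbatch = microbatch_idx == num_microbatches - 1
        # Reduce-scatter gradients and reshard parameters only after the last
        # microbatch's backward; earlier backwards keep parameters unsharded
        # and accumulate unsharded gradients (more memory, less communication).
        model.set_requires_gradient_sync(is_last_microbatch)
        model.set_reshard_after_backward(is_last_microbatch)
        model(inp).sum().backward()

    optim.step()
    optim.zero_grad()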