[FSDP2] Update ignored_params docstring and add unit test (#149074)

Fixes https://github.com/pytorch/pytorch/issues/148242

`ignored_params` are not moved to the device in `fully_shard()`; update the docstring to say so.
Add a unit test, `test_move_states_to_device_ignored_param_device`, showing that ignored params are not moved during `fully_shard()` but are moved by a subsequent `model.cuda()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149074
Approved by: https://github.com/awgu
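
For context, a minimal usage sketch of the documented behavior (not part of the PR; the model, names, and import path are illustrative, and it assumes `torch.distributed` is already initialized, e.g. under `torchrun`, with a CUDA device available):

```python
# Hypothetical sketch, not from the PR: the two-layer model and names are illustrative.
# Assumes torch.distributed is already initialized (e.g. launched via torchrun)
# and at least one CUDA device is available.
import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # FSDP2 API

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))  # constructed on CPU

# Parameters passed via ignored_params are left alone by FSDP: they are not
# sharded, and fully_shard() does not move them to the compute device.
ignored = set(model[1].parameters())
fully_shard(model, ignored_params=ignored)
assert all(p.device.type == "cpu" for p in ignored)

# The caller is responsible for placing the ignored parameters, e.g. with a
# whole-model move such as model.cuda() / model.to("cuda").
model.to("cuda")
assert all(p.device.type == "cuda" for p in ignored)
```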
yifanmao 2025-03-15 00:23:06 +00:00 committed by PyTorch MergeBot
parent 09f7f62cfe
commit 7537b19c73
2 changed files with 16 additions and 2 deletions


@@ -64,6 +64,19 @@ class TestFullyShardDeviceTensor(FSDPTestMultiThread):
         for tensor in itertools.chain(model.parameters(), model.buffers()):
             self.assertEqual(tensor.device, cuda_device)

+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_move_states_to_device_ignored_param_device(self):
+        cpu_device = torch.device("cpu")
+        model = MLP(8, cpu_device, with_buffer=True)
+        ignored_params = [model.out_proj.weight, model.out_proj.bias]
+        fully_shard(model, ignored_params=set(ignored_params))
+        for tensor in ignored_params:
+            self.assertEqual(tensor.device, cpu_device)
+        cuda_device = torch.device("cuda", torch.cuda.current_device())
+        model.to(torch.device("cuda"))
+        for tensor in ignored_params:
+            self.assertEqual(tensor.device, cuda_device)
+

 class TestFullyShardDeviceDTensor(FSDPTestMultiThread):
     """Tests that DTensor parameters are moved to the expected device."""


@@ -176,8 +176,9 @@ def fully_shard(
         offload_policy (OffloadPolicy): This controls the offloading policy,
             which offers parameter/gradient/optimizer state offloading. See
             :class:`OffloadPolicy` and its subclasses for details.
-        ignored_params: Optional(Set[nn.Parameter]): The set of parameters that we
-            don't want to shard with FSDP.
+        ignored_params: Optional(Set[nn.Parameter]): The set of parameters to be
+            ignored by FSDP. They will not be sharded, nor moved to the device
+            during init, nor have their gradients reduced in backward.

     Returns:
         FSDPModule: The module with FSDP applied (in-place).
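
Because ignored parameters also skip gradient reduction, a user who still trains them while they stay replicated across ranks would have to synchronize their gradients manually. A hedged sketch (reusing the illustrative `ignored` set from the sketch above and assuming an initialized process group with a backend that supports averaging, e.g. NCCL):

```python
# Hypothetical sketch: ignored params receive only local gradients in backward,
# so average them across ranks by hand if they are replicated and being trained.
import torch.distributed as dist

for p in ignored:
    if p.grad is not None:
        dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
```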