[FSDP2] Update ignored_params docstring and add unit test (#149074)

Fixes https://github.com/pytorch/pytorch/issues/148242

`ignored_params` are not moved to the device in `fully_shard()`; update the docstring accordingly.
Add unit test `test_move_states_to_device_ignored_param_device` to show that `ignored_params` are not moved during `fully_shard()`, but are moved after an explicit `model.to("cuda")` (or `model.cuda()`).
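For illustration, a minimal sketch of the behavior this covers (not the test itself; it assumes a CUDA build, a single-rank NCCL process group, and the illustrative two-layer model below rather than the test's `MLP` helper):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard

# Single-rank process group so fully_shard() can build its default mesh
# (illustrative setup; adapt the init_method/port to your environment).
dist.init_process_group("nccl", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
torch.cuda.set_device(0)

# Build the model on CPU and mark the second linear's parameters as ignored.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
ignored_params = set(model[1].parameters())
fully_shard(model, ignored_params=ignored_params)

# Managed states are moved to the current CUDA device by fully_shard(),
# but ignored parameters stay where they were created (CPU here).
for p in ignored_params:
    assert p.device == torch.device("cpu")

# An explicit move (model.to("cuda") / model.cuda()) does relocate them.
model.to("cuda")
for p in ignored_params:
    assert p.device.type == "cuda"

dist.destroy_process_group()
```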

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149074
Approved by: https://github.com/awgu
yifanmao 2025-03-15 00:23:06 +00:00 committed by PyTorch MergeBot
parent 09f7f62cfe
commit 7537b19c73
2 changed files with 16 additions and 2 deletions


@@ -64,6 +64,19 @@ class TestFullyShardDeviceTensor(FSDPTestMultiThread):
        for tensor in itertools.chain(model.parameters(), model.buffers()):
            self.assertEqual(tensor.device, cuda_device)

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_move_states_to_device_ignored_param_device(self):
        cpu_device = torch.device("cpu")
        model = MLP(8, cpu_device, with_buffer=True)
        ignored_params = [model.out_proj.weight, model.out_proj.bias]
        fully_shard(model, ignored_params=set(ignored_params))
        for tensor in ignored_params:
            self.assertEqual(tensor.device, cpu_device)
        cuda_device = torch.device("cuda", torch.cuda.current_device())
        model.to(torch.device("cuda"))
        for tensor in ignored_params:
            self.assertEqual(tensor.device, cuda_device)


class TestFullyShardDeviceDTensor(FSDPTestMultiThread):
    """Tests that DTensor parameters are moved to the expected device."""


@@ -176,8 +176,9 @@ def fully_shard(
        offload_policy (OffloadPolicy): This controls the offloading policy,
            which offers parameter/gradient/optimizer state offloading. See
            :class:`OffloadPolicy` and its subclasses for details.
        ignored_params: Optional(Set[nn.Parameter]): The set of parameters that we
            don't want to shard with FSDP.
        ignored_params (Optional[Set[nn.Parameter]]): The set of parameters to be
            ignored by FSDP. They will not be sharded, nor moved to the device
            during init, nor have their gradients reduced in backward.

    Returns:
        FSDPModule: The module with FSDP applied (in-place).
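
For what the new docstring text means in practice, a small hedged check continuing the sketch from the PR description above ("not sharded" means the parameters are not converted into `DTensor`s; the identity-based lookup here is just to avoid elementwise tensor comparison):

```python
from torch.distributed.tensor import DTensor

# Compare by object identity rather than tensor equality.
ignored_ids = {id(p) for p in ignored_params}
for name, p in model.named_parameters():
    if id(p) in ignored_ids:
        # Ignored: left as a plain nn.Parameter, untouched at init.
        assert not isinstance(p, DTensor)
    else:
        # Managed: sharded by fully_shard() into a DTensor.
        assert isinstance(p, DTensor)
```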