Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[FSDP2] Update ignored_params docstring and add unit test (#149074)
Fixes https://github.com/pytorch/pytorch/issues/148242. `ignored_params` are not moved to the device during `fully_shard()`, so the docstring is updated to say so. Adds the unit test `test_move_states_to_device_ignored_param_device` to show that `ignored_params` are not moved during `fully_shard()` but are moved after `model.cuda()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149074
Approved by: https://github.com/awgu
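A minimal single-rank sketch (not part of this commit) of the behavior the message describes: sharded parameters are moved to the GPU by `fully_shard()`, while ignored ones stay on CPU until the user moves them. The import path, the toy `nn.Sequential` model, and the single-process NCCL setup are illustrative assumptions, not the PR's test code.

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # FSDP2; this path assumes a recent PyTorch

def main():
    # Single-rank process group so fully_shard can build its one-GPU mesh.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=0, world_size=1)
    torch.cuda.set_device(0)

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))  # built on CPU
    ignored = {model[1].weight, model[1].bias}

    fully_shard(model, ignored_params=ignored)
    print(model[0].weight.device)   # cuda:0 -- sharded params were moved during init
    print(model[1].weight.device)   # cpu    -- ignored params were not

    model.to(torch.device("cuda"))  # the user moves ignored params explicitly
    print(model[1].weight.device)   # cuda:0

    dist.destroy_process_group()

if __name__ == "__main__":
    main()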
This commit is contained in:
parent 09f7f62cfe
commit 7537b19c73
@@ -64,6 +64,19 @@ class TestFullyShardDeviceTensor(FSDPTestMultiThread):
         for tensor in itertools.chain(model.parameters(), model.buffers()):
             self.assertEqual(tensor.device, cuda_device)
 
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_move_states_to_device_ignored_param_device(self):
+        cpu_device = torch.device("cpu")
+        model = MLP(8, cpu_device, with_buffer=True)
+        ignored_params = [model.out_proj.weight, model.out_proj.bias]
+        fully_shard(model, ignored_params=set(ignored_params))
+        for tensor in ignored_params:
+            self.assertEqual(tensor.device, cpu_device)
+        cuda_device = torch.device("cuda", torch.cuda.current_device())
+        model.to(torch.device("cuda"))
+        for tensor in ignored_params:
+            self.assertEqual(tensor.device, cuda_device)
+
 
 class TestFullyShardDeviceDTensor(FSDPTestMultiThread):
     """Tests that DTensor parameters are moved to the expected device."""
@@ -176,8 +176,9 @@ def fully_shard(
         offload_policy (OffloadPolicy): This controls the offloading policy,
             which offers parameter/gradient/optimizer state offloading. See
             :class:`OffloadPolicy` and its subclasses for details.
-        ignored_params: Optional(Set[nn.Parameter]): The set of parameters that we
-            don't want to shard with FSDP.
+        ignored_params: Optional(Set[nn.Parameter]): The set of parameters to be
+            ignored by FSDP. They will not be sharded, nor moved to the device
+            during init, nor have their gradients reduced in backward.
 
     Returns:
         FSDPModule: The module with FSDP applied (in-place).
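The updated docstring also notes that gradients of ignored parameters are not reduced in backward, so a user who keeps such parameters replicated across ranks has to reduce their gradients manually. A hedged sketch of that pattern (the helper name and the averaging choice are assumptions, not part of the PR):

import torch.distributed as dist

def reduce_ignored_grads(ignored_params, world_size):
    # FSDP2 skips ignored_params entirely, so their .grad is purely local;
    # all-reduce and average to mimic DDP-style replication if that is intended.
    for p in ignored_params:
        if p.grad is not None:
            dist.all_reduce(p.grad)
            p.grad.div_(world_size)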