mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix bug in FSDP wrapped module with zero argument (#147771)
Fixes https://github.com/pytorch/pytorch/issues/147531 Pull Request resolved: https://github.com/pytorch/pytorch/pull/147771 Approved by: https://github.com/awgu
This commit is contained in:
parent
8de6fe8c0b
commit
12112fd198
|
|
@ -395,6 +395,19 @@ class TestFSDPWrap(FSDPTest):
|
|||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_zero_argument(self):
|
||||
class ZeroArguModel(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.a = torch.tensor([1.0])
|
||||
|
||||
def forward(self):
|
||||
return self.a
|
||||
|
||||
model = FSDP(ZeroArguModel())
|
||||
self.assertEqual(model(), torch.tensor([1.0]))
|
||||
|
||||
|
||||
class TestAutoWrap(TestCase):
|
||||
def setUp(self) -> None:
|
||||
|
|
|
|||
|
|
@ -596,8 +596,8 @@ def _root_pre_forward(
|
|||
args_tuple, kwargs_tuple = _to_kwargs(
|
||||
args, kwargs, state.compute_device, False
|
||||
)
|
||||
args = args_tuple[0]
|
||||
kwargs = kwargs_tuple[0]
|
||||
args = args_tuple[0] if args_tuple else tuple()
|
||||
kwargs = kwargs_tuple[0] if kwargs_tuple else {}
|
||||
|
||||
return _root_cast_forward_input(state, module, args, kwargs)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user