Fix bug in FSDP wrapped module with zero argument (#147771)

Fixes https://github.com/pytorch/pytorch/issues/147531

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147771
Approved by: https://github.com/awgu
Authored by mori360 on 2025-02-26 01:40:53 +00:00; committed by PyTorch MergeBot
parent 8de6fe8c0b
commit 12112fd198
2 changed files with 15 additions and 2 deletions
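
For context, the bug is hit by calling an FSDP-wrapped module whose forward() takes no arguments at all: before this fix, _root_pre_forward indexed into an empty args tuple. Below is a standalone sketch of the repro, mirroring the new test in the first hunk; the single-process NCCL setup is illustrative boilerplate, not part of the PR, and it assumes at least one GPU.

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class ZeroArguModel(nn.Module):
    # forward takes no positional or keyword arguments
    def __init__(self) -> None:
        super().__init__()
        self.a = torch.tensor([1.0])

    def forward(self):
        return self.a


if __name__ == "__main__":
    # Single-process process group; address/port values are placeholders.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=0, world_size=1)
    torch.cuda.set_device(0)

    model = FSDP(ZeroArguModel())
    # Before this fix: IndexError in _root_pre_forward (empty args_tuple).
    # With this fix: prints tensor([1.]).
    print(model())

    dist.destroy_process_group()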


@@ -395,6 +395,19 @@ class TestFSDPWrap(FSDPTest):
         loss.backward()
         optim.step()
 
+    @skip_if_lt_x_gpu(1)
+    def test_zero_argument(self):
+        class ZeroArguModel(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.a = torch.tensor([1.0])
+
+            def forward(self):
+                return self.a
+
+        model = FSDP(ZeroArguModel())
+        self.assertEqual(model(), torch.tensor([1.0]))
+
 
 class TestAutoWrap(TestCase):
     def setUp(self) -> None:
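
Assuming this test lands in test/distributed/fsdp/test_wrap.py (the file that defines TestFSDPWrap and TestAutoWrap), it can be run on a machine with at least one GPU via, for example, pytest test/distributed/fsdp/test_wrap.py -k test_zero_argument.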


@@ -596,8 +596,8 @@ def _root_pre_forward(
     args_tuple, kwargs_tuple = _to_kwargs(
         args, kwargs, state.compute_device, False
     )
-    args = args_tuple[0]
-    kwargs = kwargs_tuple[0]
+    args = args_tuple[0] if args_tuple else tuple()
+    kwargs = kwargs_tuple[0] if kwargs_tuple else {}
     return _root_cast_forward_input(state, module, args, kwargs)
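
The change guards against a zero-argument forward: with no positional or keyword inputs, the per-device containers returned by _to_kwargs come back empty, so the old unconditional [0] indexing raised IndexError. A minimal, FSDP-free sketch of the same logic follows; fake_to_kwargs is an illustrative stand-in, not the real torch.distributed.utils._to_kwargs.

# fake_to_kwargs stands in for torch.distributed.utils._to_kwargs: it returns
# one args tuple and one kwargs dict per target device, and empty containers
# when there is nothing to move.
def fake_to_kwargs(args, kwargs):
    if not args and not kwargs:
        return (), ()
    return (args,), (kwargs,)


args_tuple, kwargs_tuple = fake_to_kwargs((), {})

# Old code: args_tuple[0] raised "IndexError: tuple index out of range".
# Patched code: fall back to empty containers, as in the hunk above.
args = args_tuple[0] if args_tuple else tuple()
kwargs = kwargs_tuple[0] if kwargs_tuple else {}
assert args == () and kwargs == {}

Falling back to an empty tuple and dict lets the subsequent _root_cast_forward_input call proceed unchanged while allowing argument-less modules to pass through.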