Fix bug in FSDP wrapped module with zero argument (#147771)
Fixes https://github.com/pytorch/pytorch/issues/147531
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147771
Approved by: https://github.com/awgu
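For context: before this fix, calling an FSDP-wrapped module whose forward() takes no inputs raised an IndexError in _root_pre_forward, because there was nothing to move to the compute device. A minimal sketch of that failure mode, assuming the private helper torch.distributed.utils._to_kwargs keeps its current behavior of returning empty tuples when both args and kwargs are empty (private API, may change across versions):

import torch
from torch.distributed.utils import _to_kwargs

# A zero-argument forward() means there is nothing to move to the
# compute device, so _to_kwargs returns two empty tuples.
args_tuple, kwargs_tuple = _to_kwargs((), {}, torch.device("cpu"), False)
print(args_tuple, kwargs_tuple)  # () ()

try:
    args = args_tuple[0]  # the pre-fix, unconditional indexing
except IndexError as err:
    print(f"IndexError: {err}")  # tuple index out of range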
This commit is contained in:
parent 8de6fe8c0b
commit 12112fd198
@@ -395,6 +395,19 @@ class TestFSDPWrap(FSDPTest):
             loss.backward()
             optim.step()
 
+    @skip_if_lt_x_gpu(1)
+    def test_zero_argument(self):
+        class ZeroArguModel(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.a = torch.tensor([1.0])
+
+            def forward(self):
+                return self.a
+
+        model = FSDP(ZeroArguModel())
+        self.assertEqual(model(), torch.tensor([1.0]))
+
 
 class TestAutoWrap(TestCase):
     def setUp(self) -> None:
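The new regression test presumably lives in test/distributed/fsdp/test_wrap.py, where TestFSDPWrap is defined; with one GPU available it should be runnable on its own, along the lines of (exact invocation depends on the test runner):

python test/distributed/fsdp/test_wrap.py -k test_zero_argument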
@@ -596,8 +596,8 @@ def _root_pre_forward(
         args_tuple, kwargs_tuple = _to_kwargs(
             args, kwargs, state.compute_device, False
         )
-        args = args_tuple[0]
-        kwargs = kwargs_tuple[0]
+        args = args_tuple[0] if args_tuple else tuple()
+        kwargs = kwargs_tuple[0] if kwargs_tuple else {}
 
         return _root_cast_forward_input(state, module, args, kwargs)
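A standalone sketch of the guard pattern applied above, with a hypothetical scatter_inputs helper standing in for _to_kwargs so the logic can be exercised without any distributed setup (the helper name and root_pre_forward signature here are illustrative, not the real FSDP internals):

# Hypothetical stand-in for _to_kwargs: one (args, kwargs) pair per
# device, or empty tuples when the call carried no inputs at all.
def scatter_inputs(args: tuple, kwargs: dict) -> tuple[tuple, tuple]:
    if not args and not kwargs:
        return (), ()
    return (args,), (kwargs,)

def root_pre_forward(args: tuple, kwargs: dict) -> tuple[tuple, dict]:
    args_tuple, kwargs_tuple = scatter_inputs(args, kwargs)
    # The fix: fall back to empty containers instead of indexing
    # unconditionally, which raised IndexError for zero-argument forwards.
    args = args_tuple[0] if args_tuple else tuple()
    kwargs = kwargs_tuple[0] if kwargs_tuple else {}
    return args, kwargs

assert root_pre_forward((), {}) == ((), {})  # zero-argument case now passes
assert root_pre_forward((1,), {"k": 2}) == ((1,), {"k": 2})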