[NamedOptimizer][2/N] Prepare the enablement of state_dict for FSDP (#91147)
1. Add param_group check logic and unit test.
2. Remove unnecessary check for conditional param update.
3. Return the param_group from the inner optimizer so that when param_group is None or not all params are specified, we still return the expected result.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91147
Approved by: https://github.com/fegin
This commit is contained in:
parent
c248f2f379
commit
c7e7ea92e2
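For context, a minimal sketch of the `_NamedOptimizer` usage that the changes below exercise. It mirrors the new tests in the diff rather than adding anything beyond them; the two-layer model, the import path, and the printed key names are assumptions for illustration (the tests use their own `TestDummyModel` helper).

import torch
import torch.nn as nn
from torch.distributed.optim import _NamedOptimizer  # assumed import path

# Hypothetical stand-in for the test suite's TestDummyModel.
model = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 2))

# Change 1: param groups are validated at construction; each group must be a
# dict with a "params" entry holding tensors, otherwise construction fails.
named_optim = _NamedOptimizer(
    model.named_parameters(),
    torch.optim.SGD,
    param_groups=[
        {"params": model[0].parameters()},
        {"params": model[1].parameters(), "lr": 1e-5},
    ],
    lr=1e-2,
    momentum=0.9,
)

model(torch.rand(5, 8)).sum().backward()
named_optim.step()

# Change 3: param_groups mirror the wrapped optimizer's groups, and the state
# dict is keyed by parameter names rather than integer indices.
print(named_optim.param_groups)
print(list(named_optim.state_dict()["state"].keys()))  # e.g. ['0.weight', '0.bias', ...]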
@@ -44,11 +44,83 @@ class NamedOptimizerTest(unittest.TestCase):
                fn = self.assertEqual if assert_equal else self.assertNotEqual
                fn(val, named_group[key], err_msg)

    def _compare_param_groups(self, param_groups_1, param_groups_2):
        self.assertTrue(isinstance(param_groups_1, list))
        self.assertTrue(isinstance(param_groups_2, list))
        for groups in zip(param_groups_1, param_groups_2):
            self._compare_param_group(groups[0], groups[1])

    def _compare_param_group(self, group_1, group_2):
        self.assertTrue(isinstance(group_1, dict))
        self.assertTrue(isinstance(group_2, dict))
        for key, val in group_1.items():
            self.assertTrue(key in group_2)
            if key != "params":
                self.assertEqual(val, group_2[key])
            else:
                for tensors in zip(val, group_2[key]):
                    self.assertTrue(torch.allclose(tensors[0], tensors[1]))

    def test_state_dict(self):
        """Check that NamedOptimizer exposes the expected state dict
        interface."""
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim = torch.optim.SGD(
            m.parameters(),
            lr=1e-2,
            momentum=0.9,
        )

        named_optim = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.9,
        )
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

        for _ in range(2):
            x = torch.rand(5, 8)
            y = m(x)
            y.sum().backward()
            optim.step()

            y = m_dup(x)
            y.sum().backward()
            named_optim.step()

        self._compare_param_groups(optim.param_groups, named_optim.param_groups)
        sd = optim.state_dict()
        named_sd = named_optim.state_dict()

        # Compare "state" in optim state dict
        self._compare_state_dict_group(
            sd["state"][0],
            named_sd["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][3],
            named_sd["state"]["net2.0.bias"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][4],
            named_sd["state"]["net3.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][7],
            named_sd["state"]["net4.1.bias"],
            assert_equal=True,
        )

    def test_state_dict_multi_param_group(self):
        """Check that NamedOptimizer exposes the expected state dict
        interface when multiple param groups are specified."""
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim_1 = torch.optim.SGD(
            [
                {"params": m.net1.parameters()},
@@ -84,7 +156,10 @@ class NamedOptimizerTest(unittest.TestCase):
                {"params": m_dup.net4.parameters(), "lr": 1e-5},
            ],
        )
        for i in range(2):
        self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups)
        self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups)

        for _ in range(2):
            x = torch.rand(5, 8)
            y = m(x)
            y.sum().backward()
@@ -96,12 +171,15 @@ class NamedOptimizerTest(unittest.TestCase):
            named_optim_1.step()
            named_optim_2.step()

        self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups)
        self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups)
        sd_1 = optim_1.state_dict()
        sd_2 = optim_2.state_dict()
        named_sd_1 = named_optim_1.state_dict()
        named_sd_2 = named_optim_2.state_dict()

        # Compare "state" in optim state dict
        print(list(named_sd_1["state"].keys()))
        self._compare_state_dict_group(
            sd_1["state"][0],
            named_sd_1["state"]["net1.0.weight"],
@@ -65,6 +65,7 @@ class _NamedOptimizer(optim.Optimizer):
    ) -> None:
        torch._C._log_api_usage_once("torch.distributed.optim._NamedOptimizer")
        self.param_groups: Collection[Mapping[str, Any]] = param_groups  # type: ignore[assignment]
        self._param_groups_check()
        self.named_parameters = dict(named_parameters)
        params_for_optimizer = (
            self.named_parameters.values() if param_groups is None else param_groups
@@ -74,7 +75,6 @@ class _NamedOptimizer(optim.Optimizer):
            *args,
            **kwargs,
        )
        # TODO: Add param_groups validations and unit tests.
        if param_groups is None:
            self.ordered_param_keys = list(self.named_parameters.keys())
        else:
@@ -92,6 +92,25 @@ class _NamedOptimizer(optim.Optimizer):
                        )
                    ordered_param_keys.append(param_to_key[param])
            self.ordered_param_keys = ordered_param_keys
        # Update param_groups from optimizer.
        self.param_groups = self._optimizer.param_groups

    def _param_groups_check(self):
        if self.param_groups is not None:
            for param_group in self.param_groups:
                assert isinstance(param_group, dict), "param group must be a dict"
                assert "params" in param_group, "param group must contain key params"
                params = param_group["params"]
                if isinstance(params, torch.Tensor):
                    params = [params]
                params = list(params)
                for param in params:
                    if not isinstance(param, torch.Tensor):
                        raise TypeError(
                            "optimizer can only optimize Tensors, "
                            "but one of the params is " + torch.typename(param)
                        )
                param_group["params"] = params

    def state_dict(self) -> Dict[str, Any]:
        """
@@ -170,10 +189,9 @@ class _NamedOptimizer(optim.Optimizer):
                f"Expects equal length as {len(new_state)} in `state_dict` state length but found {len(state)}."
            )
        for idx, param_key in enumerate(self.ordered_param_keys):
            # When the conditional training is performed, not all parameters are updated in the optim.
            if param_key not in state.keys():
                raise ValueError(
                    f"Expect {param_key} as a parameter in `state_dict` state but not found."
                )
                continue
            if len(state[param_key]) != len(new_state[idx]):
                raise ValueError(
                    f"Expects equal length as {len(new_state[idx])} for parameter {param_key} but found: {len(state[param_key])}"