[NamedOptimizer][2/N] Prepare the enablement of state_dict for FSDP (#91147)

1. Add param_groups validation logic (_param_groups_check) and unit tests.
2. Remove the unnecessary check in load_state_dict so that conditional training (where not all parameters receive optimizer state) no longer raises; missing parameter keys are now skipped.
3. Return param_groups from the inner optimizer so that when param_groups is None or not all params are specified, we still return the expected result (see the usage sketch below).
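
A minimal usage sketch of the behavior described above (not part of this PR; _NamedOptimizer is a private API, and the import path and toy model here are assumptions based on the diff below):

import torch
from torch import nn
from torch.distributed.optim import _NamedOptimizer  # import path assumed from the API-usage log string in the diff

model = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 2))

# Parameters are passed as (name, tensor) pairs; the inner optimizer class and
# its kwargs follow, mirroring the tests in this commit.
named_optim = _NamedOptimizer(
    model.named_parameters(),
    torch.optim.SGD,
    lr=1e-2,
    momentum=0.9,
)

model(torch.rand(5, 8)).sum().backward()
named_optim.step()

# Optimizer state is keyed by parameter name (e.g. "0.weight"), not by integer index.
print(list(named_optim.state_dict()["state"].keys()))

# Item 3: param_groups mirror the inner optimizer's, so optimizer defaults are
# filled in even though no explicit param_groups were passed.
print(named_optim.param_groups)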

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91147
Approved by: https://github.com/fegin
Authored by fduwjj on 2022-12-20 19:56:29 +00:00; committed by PyTorch MergeBot
parent c248f2f379
commit c7e7ea92e2
2 changed files with 101 additions and 5 deletions


@@ -44,11 +44,83 @@ class NamedOptimizerTest(unittest.TestCase):
                fn = self.assertEqual if assert_equal else self.assertNotEqual
                fn(val, named_group[key], err_msg)

    def _compare_param_groups(self, param_groups_1, param_groups_2):
        self.assertTrue(isinstance(param_groups_1, list))
        self.assertTrue(isinstance(param_groups_2, list))
        for groups in zip(param_groups_1, param_groups_2):
            self._compare_param_group(groups[0], groups[1])

    def _compare_param_group(self, group_1, group_2):
        self.assertTrue(isinstance(group_1, dict))
        self.assertTrue(isinstance(group_2, dict))
        for key, val in group_1.items():
            self.assertTrue(key in group_2)
            if key != "params":
                self.assertEqual(val, group_2[key])
            else:
                for tensors in zip(val, group_2[key]):
                    self.assertTrue(torch.allclose(tensors[0], tensors[1]))

    def test_state_dict(self):
        """Check that NamedOptimizer exposes the expected state dict
        interface."""
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim = torch.optim.SGD(
            m.parameters(),
            lr=1e-2,
            momentum=0.9,
        )
        named_optim = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.9,
        )
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

        for _ in range(2):
            x = torch.rand(5, 8)
            y = m(x)
            y.sum().backward()
            optim.step()

            y = m_dup(x)
            y.sum().backward()
            named_optim.step()

        self._compare_param_groups(optim.param_groups, named_optim.param_groups)
        sd = optim.state_dict()
        named_sd = named_optim.state_dict()

        # Compare "state" in optim state dict
        self._compare_state_dict_group(
            sd["state"][0],
            named_sd["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][3],
            named_sd["state"]["net2.0.bias"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][4],
            named_sd["state"]["net3.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][7],
            named_sd["state"]["net4.1.bias"],
            assert_equal=True,
        )
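
TestDummyModel itself is not part of this diff; a hypothetical reconstruction, consistent with the parameter names and flat indices asserted above (net1.0.weight at index 0, net2.0.bias at 3, net3.weight at 4, net4.1.bias at 7) and with the (5, 8) input used in the tests, might look like this (the actual helper may use different layer sizes):

import torch

class TestDummyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Two Sequential blocks whose Linear sits at index 0 (net1.0.*, net2.0.*) ...
        self.net1 = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU())
        self.net2 = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU())
        # ... a bare Linear (net3.weight / net3.bias) ...
        self.net3 = torch.nn.Linear(32, 64)
        # ... and a Sequential whose Linear sits at index 1 (net4.1.weight / net4.1.bias).
        self.net4 = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Linear(64, 8))

    def forward(self, x):
        return self.net4(self.net3(self.net2(self.net1(x))))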
    def test_state_dict_multi_param_group(self):
        """Check that NamedOptimizer exposes the expected state dict
        interface when multiple param groups are specified."""
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim_1 = torch.optim.SGD(
            [
                {"params": m.net1.parameters()},
@@ -84,7 +156,10 @@ class NamedOptimizerTest(unittest.TestCase):
                {"params": m_dup.net4.parameters(), "lr": 1e-5},
            ],
        )
-       for i in range(2):
        self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups)
        self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups)

        for _ in range(2):
            x = torch.rand(5, 8)
            y = m(x)
            y.sum().backward()
@@ -96,12 +171,15 @@ class NamedOptimizerTest(unittest.TestCase):
            named_optim_1.step()
            named_optim_2.step()

        self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups)
        self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups)
        sd_1 = optim_1.state_dict()
        sd_2 = optim_2.state_dict()
        named_sd_1 = named_optim_1.state_dict()
        named_sd_2 = named_optim_2.state_dict()

        # Compare "state" in optim state dict
        print(list(named_sd_1["state"].keys()))
        self._compare_state_dict_group(
            sd_1["state"][0],
            named_sd_1["state"]["net1.0.weight"],


@@ -65,6 +65,7 @@ class _NamedOptimizer(optim.Optimizer):
    ) -> None:
        torch._C._log_api_usage_once("torch.distributed.optim._NamedOptimizer")
        self.param_groups: Collection[Mapping[str, Any]] = param_groups  # type: ignore[assignment]
        self._param_groups_check()
        self.named_parameters = dict(named_parameters)
        params_for_optimizer = (
            self.named_parameters.values() if param_groups is None else param_groups
@@ -74,7 +75,6 @@ class _NamedOptimizer(optim.Optimizer):
            *args,
            **kwargs,
        )
-       # TODO: Add param_groups validations and unit tests.
        if param_groups is None:
            self.ordered_param_keys = list(self.named_parameters.keys())
        else:
@@ -92,6 +92,25 @@ class _NamedOptimizer(optim.Optimizer):
                        )
                    ordered_param_keys.append(param_to_key[param])
            self.ordered_param_keys = ordered_param_keys
        # Update param_groups from optimizer.
        self.param_groups = self._optimizer.param_groups

    def _param_groups_check(self):
        if self.param_groups is not None:
            for param_group in self.param_groups:
                assert isinstance(param_group, dict), "param group must be a dict"
                assert "params" in param_group, "param group must contain key params"
                params = param_group["params"]
                if isinstance(params, torch.Tensor):
                    params = [params]
                params = list(params)
                for param in params:
                    if not isinstance(param, torch.Tensor):
                        raise TypeError(
                            "optimizer can only optimize Tensors, "
                            "but one of the params is " + torch.typename(param)
                        )
                param_group["params"] = params
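
A hedged illustration of what this new check enforces (not from the PR; the import path is assumed, and param_groups is passed positionally here to match the call shape in the multi-group test above):

import torch
from torch.distributed.optim import _NamedOptimizer  # private API, path assumed

lin = torch.nn.Linear(4, 4)

# A param group must be a dict with a "params" key; a bare Tensor is accepted
# and normalized to a one-element list by _param_groups_check.
ok = _NamedOptimizer(
    lin.named_parameters(),
    torch.optim.SGD,
    [{"params": lin.weight, "lr": 1e-3}],
    lr=1e-2,
)

# Anything in "params" that is not a Tensor now fails fast at construction, e.g.:
#   TypeError: optimizer can only optimize Tensors, but one of the params is str
try:
    _NamedOptimizer(
        lin.named_parameters(),
        torch.optim.SGD,
        [{"params": ["not_a_tensor"]}],
        lr=1e-2,
    )
except TypeError as err:
    print(err)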
    def state_dict(self) -> Dict[str, Any]:
        """
@@ -170,10 +189,9 @@ class _NamedOptimizer(optim.Optimizer):
                f"Expects equal length as {len(new_state)} in `state_dict` state length but found {len(state)}."
            )
        for idx, param_key in enumerate(self.ordered_param_keys):
            # When conditional training is performed, not all parameters are updated in the optimizer.
            if param_key not in state.keys():
-               raise ValueError(
-                   f"Expect {param_key} as a parameter in `state_dict` state but not found."
-               )
                continue
            if len(state[param_key]) != len(new_state[idx]):
                raise ValueError(
                    f"Expects equal length as {len(new_state[idx])} for parameter {param_key} but found: {len(state[param_key])}"
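
A sketch of the case this relaxed check is meant to cover (hedged, not from the PR): with conditional training only some parameters accumulate optimizer state, so their keys are absent from a saved state dict, and loading now skips them instead of raising.

import torch
from torch.distributed.optim import _NamedOptimizer  # private API, path assumed

lin = torch.nn.Linear(4, 4)
named_optim = _NamedOptimizer(
    lin.named_parameters(), torch.optim.SGD, lr=1e-2, momentum=0.9
)

# Only the weight receives a gradient, so only "weight" accumulates momentum state.
lin.weight.grad = torch.ones_like(lin.weight)
named_optim.step()

saved = named_optim.state_dict()    # saved["state"] contains "weight" but not "bias"
named_optim.load_state_dict(saved)  # previously raised for the missing "bias" key; now it is skipped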