# Owner(s): ["oncall: distributed"]

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.nn as nn

from torch.distributed.optim import _NamedOptimizer
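

# _NamedOptimizer wraps a regular optimizer but keys its state dict by fully
# qualified parameter name (e.g. "net1.0.weight") rather than by integer
# parameter index; the tests below exercise exactly that mapping.


# Run two forward/backward/step iterations for every (model, [optimizers])
# pair so that optimizer state (e.g. SGD momentum buffers) is populated
# before the tests inspect and compare state dicts.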
def _run_model_training(model_optim_lists):
    for _ in range(2):
        x = torch.rand(5, 8)
        for model_optim_list in model_optim_lists:
            model = model_optim_list[0]
            optim_list = model_optim_list[1]
            y = model(x)
            y.sum().backward()
            for optim in optim_list:
                optim.step()
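

# Deterministic four-block MLP. torch.manual_seed(0) in __init__ makes two
# freshly constructed instances identical, so a model trained with a plain
# optimizer can be compared parameter-for-parameter against a duplicate
# trained with _NamedOptimizer.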
class TestDummyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        return self.net4(self.net3(self.net2(self.net1(x))))


class NamedOptimizerTest(unittest.TestCase):
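    # Asserts that a plain optimizer's state-dict group and the corresponding
    # _NamedOptimizer group match (assert_equal=True) or deliberately differ
    # (assert_equal=False) on every key except "params".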
    def _compare_state_dict_group(self, group, named_group, assert_equal=True):
        for key, val in group.items():
            if key != "params":
                self.assertTrue(
                    key in named_group, f"{key} not in named optimizer state dict"
                )
                err_msg = (
                    f"{key} state not equal" if assert_equal else f"{key} state equal"
                )
                if isinstance(val, torch.Tensor):
                    fn = self.assertTrue if assert_equal else self.assertFalse
                    fn(torch.allclose(val, named_group[key]), err_msg)
                else:
                    fn = self.assertEqual if assert_equal else self.assertNotEqual
                    fn(val, named_group[key], err_msg)

    def _compare_param_groups(self, param_groups_1, param_groups_2):
        self.assertTrue(isinstance(param_groups_1, list))
        self.assertTrue(isinstance(param_groups_2, list))
        for groups in zip(param_groups_1, param_groups_2):
            self._compare_param_group(groups[0], groups[1])

    def _compare_param_group(self, group_1, group_2):
        self.assertTrue(isinstance(group_1, dict))
        self.assertTrue(isinstance(group_2, dict))
        for key, val in group_1.items():
            self.assertTrue(key in group_2)
            if key != "params":
                self.assertEqual(val, group_2[key])
            else:
                for tensors in zip(val, group_2[key]):
                    self.assertTrue(torch.allclose(tensors[0], tensors[1]))

    def test_state_dict(self):
        """Check that NamedOptimizer exposes the expected state dict
        interface."""
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim = torch.optim.SGD(
            m.parameters(),
            lr=1e-2,
            momentum=0.9,
        )

        named_optim = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.9,
        )
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

        sd = optim.state_dict()
        named_sd = named_optim.state_dict()
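
        # A plain optimizer keys "state" by integer parameter index
        # (registration order: net1.0 weight/bias, net2.0 weight/bias,
        # net3 weight/bias, net4.1 weight/bias), while _NamedOptimizer keys
        # the same entries by fully qualified parameter name.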
# Compare "state" in optim state dict
|
|
self._compare_state_dict_group(
|
|
sd["state"][0],
|
|
named_sd["state"]["net1.0.weight"],
|
|
assert_equal=True,
|
|
)
|
|
self._compare_state_dict_group(
|
|
sd["state"][3],
|
|
named_sd["state"]["net2.0.bias"],
|
|
assert_equal=True,
|
|
)
|
|
self._compare_state_dict_group(
|
|
sd["state"][4],
|
|
named_sd["state"]["net3.weight"],
|
|
assert_equal=True,
|
|
)
|
|
self._compare_state_dict_group(
|
|
sd["state"][7],
|
|
named_sd["state"]["net4.1.bias"],
|
|
assert_equal=True,
|
|
)
|
|
|
|

    def test_state_dict_multi_param_group(self):
        """Check that NamedOptimizer exposes the expected state dict
        interface when multiple param groups are specified."""
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim_1 = torch.optim.SGD(
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        optim_2 = torch.optim.Adam(
            [
                {"params": m.net2.parameters()},
                {"params": m.net4.parameters(), "lr": 1e-5},
            ]
        )

        named_optim_1 = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m_dup.net1.parameters()},
                {"params": m_dup.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        named_optim_2 = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.Adam,
            [
                {"params": m_dup.net2.parameters()},
                {"params": m_dup.net4.parameters(), "lr": 1e-5},
            ],
        )
        self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups)
        self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups)

        _run_model_training(
            [(m, [optim_1, optim_2]), (m_dup, [named_optim_1, named_optim_2])]
        )
        self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups)
        self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups)
        sd_1 = optim_1.state_dict()
        sd_2 = optim_2.state_dict()
        named_sd_1 = named_optim_1.state_dict()
        named_sd_2 = named_optim_2.state_dict()
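
        # Plain optimizers index "state" across their own param groups in
        # order: for sd_1 (SGD) indices 0-3 are net1.0's weight/bias then
        # net3's weight/bias; for sd_2 (Adam) they are net2.0's weight/bias
        # then net4.1's weight/bias. The named state dicts use parameter
        # names for the same entries.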
# Compare "state" in optim state dict
|
|
self._compare_state_dict_group(
|
|
sd_1["state"][0],
|
|
named_sd_1["state"]["net1.0.weight"],
|
|
assert_equal=True,
|
|
)
|
|
self._compare_state_dict_group(
|
|
sd_2["state"][1],
|
|
named_sd_2["state"]["net2.0.bias"],
|
|
assert_equal=True,
|
|
)
|
|
self._compare_state_dict_group(
|
|
sd_1["state"][2],
|
|
named_sd_1["state"]["net3.weight"],
|
|
assert_equal=True,
|
|
)
|
|
self._compare_state_dict_group(
|
|
sd_2["state"][3],
|
|
named_sd_2["state"]["net4.1.bias"],
|
|
assert_equal=True,
|
|
)
|
|
|
|
# Compare "param_groups" in optim state dict
|
|
self._compare_state_dict_group(
|
|
sd_1["param_groups"][0],
|
|
named_sd_1["param_groups"][0],
|
|
assert_equal=True,
|
|
)
|
|
self._compare_state_dict_group(
|
|
sd_2["param_groups"][1], named_sd_2["param_groups"][1], assert_equal=True
|
|
)
|
|
|
|

    def test_load_state_dict(self):
        """Check that NamedOptimizer's load_state_dict works as expected."""
        m = TestDummyModel()
        named_optim_1 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.9,
        )

        _run_model_training([(m, [named_optim_1])])
        state_dict_to_load = named_optim_1.state_dict()

        named_optim_2 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.6,
        )

        _run_model_training([(m, [named_optim_2])])
        state_dict_before_load = named_optim_2.state_dict()
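
        # The two optimizers were trained with different momentum (0.9 vs.
        # 0.6), so their states must diverge before the load.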
# Compare "state" in optim state dict
|
|
self._compare_state_dict_group(
|
|
state_dict_to_load["state"]["net1.0.weight"],
|
|
state_dict_before_load["state"]["net1.0.weight"],
|
|
assert_equal=False,
|
|
)
|
|
self._compare_state_dict_group(
|
|
state_dict_to_load["state"]["net2.0.bias"],
|
|
state_dict_before_load["state"]["net2.0.bias"],
|
|
assert_equal=False,
|
|
)
|
|
self._compare_state_dict_group(
|
|
state_dict_to_load["state"]["net3.weight"],
|
|
state_dict_before_load["state"]["net3.weight"],
|
|
assert_equal=False,
|
|
)
|
|
self._compare_state_dict_group(
|
|
state_dict_to_load["state"]["net4.1.bias"],
|
|
state_dict_before_load["state"]["net4.1.bias"],
|
|
assert_equal=False,
|
|
)
|
|
|
|

        named_optim_2.load_state_dict(state_dict_to_load)
        state_dict_after_load = named_optim_2.state_dict()
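
        # After load_state_dict, named_optim_2's state should match the
        # loaded snapshot exactly.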
# Compare "state" in optim state dict
|
|
self._compare_state_dict_group(
|
|
state_dict_to_load["state"]["net1.0.weight"],
|
|
state_dict_after_load["state"]["net1.0.weight"],
|
|
assert_equal=True,
|
|
)
|
|
self._compare_state_dict_group(
|
|
state_dict_to_load["state"]["net2.0.bias"],
|
|
state_dict_after_load["state"]["net2.0.bias"],
|
|
assert_equal=True,
|
|
)
|
|
self._compare_state_dict_group(
|
|
state_dict_to_load["state"]["net3.weight"],
|
|
state_dict_after_load["state"]["net3.weight"],
|
|
assert_equal=True,
|
|
)
|
|
self._compare_state_dict_group(
|
|
state_dict_to_load["state"]["net4.1.bias"],
|
|
state_dict_after_load["state"]["net4.1.bias"],
|
|
assert_equal=True,
|
|
)
|
|
|
|

    def test_load_state_dict_conditional_training(self):
        """Check that NamedOptimizer's load_state_dict works in the
        conditional training case, where only a subset of parameters was
        optimized."""
        m = TestDummyModel()
        named_optim_1 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        _run_model_training([(m, [named_optim_1])])
        state_dict_to_load = named_optim_1.state_dict()

        named_optim_2 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.6,
        )

        _run_model_training([(m, [named_optim_2])])
        named_optim_2.load_state_dict(state_dict_to_load)
        state_dict_after_load = named_optim_2.state_dict()
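
        # named_optim_1 only optimized net1 and net3, so the loaded state
        # dict carries state for just those parameters; compare only that
        # overlap.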
# Compare "state" in optim state dict
|
|
self._compare_state_dict_group(
|
|
state_dict_to_load["state"]["net1.0.weight"],
|
|
state_dict_after_load["state"]["net1.0.weight"],
|
|
assert_equal=True,
|
|
)
|
|
self._compare_state_dict_group(
|
|
state_dict_to_load["state"]["net3.weight"],
|
|
state_dict_after_load["state"]["net3.weight"],
|
|
assert_equal=True,
|
|
)
|
|
|
|

    def test_load_state_dict_error(self):
        m = TestDummyModel()
        named_optim_1 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.9,
        )

        _run_model_training([(m, [named_optim_1])])
        state_dict_to_load = named_optim_1.state_dict()

        named_optim_2 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.6,
        )
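
        # named_optim_2 has never taken an optimizer step, so its underlying
        # optimizer is still uninitialized and loading must fail.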
        err_msg = (
            "Expects the optim to be initialized before load but found not initialized"
        )
        with self.assertRaisesRegex(ValueError, err_msg):
            named_optim_2.load_state_dict(state_dict_to_load)

    def test_add_param_group(self):
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim = torch.optim.SGD(
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )
        named_optim = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m_dup.net1.parameters()},
                {"params": m_dup.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)
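
        # Adding a new param group mid-training must keep the two optimizers
        # in lockstep, whether the group is given as an iterable of
        # parameters (net2) or as a single tensor (net4[1].weight below).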
        optim.add_param_group({"params": m.net2.parameters(), "lr": 1e-5})
        named_optim.add_param_group({"params": m_dup.net2.parameters(), "lr": 1e-5})
        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

        optim.add_param_group({"params": m.net4[1].weight, "lr": 1e-3})
        named_optim.add_param_group({"params": m_dup.net4[1].weight, "lr": 1e-3})
        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

    def test_add_param_group_error(self):
        m = TestDummyModel()
        named_optim = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )
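
        # A free-standing tensor is not registered in the module, so
        # _NamedOptimizer has no name for it and must reject it.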
err_msg = "some parameters are not in the module"
|
|
with self.assertRaisesRegex(ValueError, err_msg):
|
|
named_optim.add_param_group({"params": [torch.ones(8, 1)], "lr": 1e-5})
|
|
|
|

    def test_init_state(self):
        m = TestDummyModel()
        named_optim = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )
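        # Before init_state() is called there are no gradients and the
        # optimizer state is empty.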
        named_sd = named_optim.state_dict()
        self.assertTrue(m.net1[0].weight.grad is None)
        self.assertTrue(len(named_sd["state"]) == 0)
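
        # init_state() materializes gradients and optimizer state without a
        # real backward pass; the created momentum buffers appear to be
        # zero-valued, and the assertions below only require that
        # torch.all(...) is falsy for them.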
        named_optim.init_state()
        named_sd = named_optim.state_dict()
        self.assertTrue(m.net1[0].weight.grad is not None)
        self.assertTrue("momentum_buffer" in named_sd["state"]["net1.0.weight"])
        self.assertFalse(
            torch.all(named_sd["state"]["net1.0.weight"]["momentum_buffer"]).item()
        )
        self.assertFalse(
            torch.all(named_sd["state"]["net1.0.bias"]["momentum_buffer"]).item()
        )
        self.assertTrue(m.net3.bias.grad is not None)
        self.assertTrue("momentum_buffer" in named_sd["state"]["net3.bias"])
        self.assertFalse(
            torch.all(named_sd["state"]["net3.bias"]["momentum_buffer"]).item()
        )
        self.assertFalse(
            torch.all(named_sd["state"]["net3.weight"]["momentum_buffer"]).item()
        )