Only make a shallow copy when loading optimizer state_dict (#106082)
The one thing we still deep copy is param_groups, which is much lighter weight. This should also save memory when loading from a checkpoint. The deepcopy was introduced in ecfcf39f30, but module.py only did a shallow copy at that point, so it did not actually bring parity. Incorporates an XLA fix, which is why I'm updating the pin to ca5eab87a7.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106082
Approved by: https://github.com/albanD, https://github.com/Skylion007
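To make the copy strategy concrete, here is a minimal sketch of the behavior this change introduces (a hypothetical stand-alone helper, not the real Optimizer.load_state_dict, which goes on to validate, re-key, and cast the loaded state):

from copy import deepcopy

def copy_strategy_sketch(state_dict):
    # Hypothetical helper illustrating only the copy step of this change.
    # Shallow copy of the top-level dict: the large per-parameter tensors
    # under state_dict["state"] are shared, not duplicated.
    state_dict = state_dict.copy()
    # param_groups is still deep-copied because the loader later writes into
    # saved_groups when updating state; it holds only hyperparameters
    # (lr, betas, weight_decay, ...), so the copy stays cheap.
    saved_groups = deepcopy(state_dict["param_groups"])
    return state_dict, saved_groups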
This commit is contained in:
parent
ceea08a986
commit
59d0dea90f
@@ -36,3 +36,5 @@ dd3a77bc965adf9fe8ba582ee13bb7f14c9661b0
 f70844bec783bfce43c950ccf180dc494e86f2bf
 # 2023-07-28 Apply UFMT to all non test/torch files
 e6ec0efaf87703c5f889cfc20b29be455885d58d
+# 2023-07-31 [optim][BE] split test file into logical parts: SWA, LR, optim
+a53cda1ddc15336dc1ff0ce1eff2a49cdc5f882e
.github/ci_commit_pins/xla.txt (vendored): 2 lines changed
@@ -1 +1 @@
-f5edcb2088195db71bcd36d0f8f1b6a5e663afd8
+ca5eab87a71f80cd3168630511d02549cc7d2516
@@ -238,8 +238,6 @@ class TestOptim(TestCase):
 optimizer_c.step(fn_c)
 self.assertEqual(weight, weight_c)
 self.assertEqual(bias, bias_c)
-# Make sure state dict wasn't modified
-self.assertEqual(state_dict, state_dict_c)
 # Make sure state dict is deterministic with equal but not identical parameters
 self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
 # Make sure repeated parameters have identical representation in state dict
@@ -301,7 +299,7 @@ class TestOptim(TestCase):
 state_dict_c = deepcopy(optimizer.state_dict())
 optimizer_cuda.load_state_dict(state_dict_c)

-# Make sure state dict wasn't modified
+# Make sure state_dict_c isn't modified by merely calling load_state_dict
 self.assertEqual(state_dict, state_dict_c)

 # Make sure that device of state['step'] is still CPU
@@ -312,7 +310,7 @@ class TestOptim(TestCase):
 for state in new_state_dict["state"].values():
 self.assertEqual(state["step"].device.type, "cpu")

-for _i in range(20):
+for _ in range(20):
 optimizer.step(fn)
 optimizer_cuda.step(fn_cuda)
 self.assertEqual(weight, weight_cuda)
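The test changes above boil down to the following usage-level sketch (illustrative names, Adam chosen arbitrarily, assuming a CUDA device is available; the real test exercises this across optimizers):

import torch
from copy import deepcopy

if torch.cuda.is_available():
    # Build a CPU optimizer with some state.
    w = torch.randn(4, requires_grad=True)
    opt = torch.optim.Adam([w], lr=1e-3)
    w.grad = torch.ones_like(w)
    opt.step()

    sd = opt.state_dict()
    sd_before = deepcopy(sd)

    # Load the same state into a CUDA optimizer.
    w_cuda = w.detach().to("cuda").requires_grad_()
    opt_cuda = torch.optim.Adam([w_cuda], lr=1e-3)
    opt_cuda.load_state_dict(sd)

    # load_state_dict should not modify the dict the caller passed in ...
    assert sd["param_groups"] == sd_before["param_groups"]
    # ... and the per-parameter 'step' entries stay on CPU.
    for s in opt_cuda.state_dict()["state"].values():
        assert s["step"].device.type == "cpu"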
@@ -712,8 +712,8 @@ class Optimizer:
 state_dict (dict): optimizer state. Should be an object returned
 from a call to :meth:`state_dict`.
 """
-# deepcopy, to be consistent with module API
-state_dict = deepcopy(state_dict)
+# shallow copy, to be consistent with module API
+state_dict = state_dict.copy()

 for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
 hook_result = pre_hook(self, state_dict)
@@ -722,7 +722,9 @@ class Optimizer:

 # Validate the state_dict
 groups = self.param_groups
-saved_groups = state_dict['param_groups']
+
+# Deepcopy as we write into saved_groups later to update state
+saved_groups = deepcopy(state_dict['param_groups'])

 if len(groups) != len(saved_groups):
 raise ValueError("loaded state dict has a different number of "
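For a sense of the memory impact described in the commit message, a hedged illustration of deep vs. shallow copying an optimizer state dict (example size and the use of Adam are arbitrary):

import torch
from copy import deepcopy

p = torch.randn(1000, 1000, requires_grad=True)
opt = torch.optim.Adam([p])
p.grad = torch.ones_like(p)
opt.step()                  # materializes exp_avg / exp_avg_sq state tensors

sd = opt.state_dict()
deep = deepcopy(sd)         # old behavior: every state tensor is cloned
shallow = sd.copy()         # new top-level behavior: state tensors are shared

assert shallow["state"] is sd["state"]
assert deep["state"][0]["exp_avg"].data_ptr() != sd["state"][0]["exp_avg"].data_ptr()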