# Solving pickle error when saving CyclicLR state_dict (#110931)
## How to reproduce:
```py
import os
import tempfile
import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import CyclicLR
model = nn.Linear(100, 100)
opt = SGD(model.parameters(), lr=1.)
scheduler = CyclicLR(opt, base_lr=0.1, max_lr=0.2, scale_fn=lambda x: 0.99)
tmp = tempfile.NamedTemporaryFile(delete=False)
try:
    torch.save(scheduler.state_dict(), tmp.name)
    scheduler.load_state_dict(torch.load(tmp.name))
finally:
    tmp.close()
    os.unlink(tmp.name)
```
Error:
```
_pickle.PicklingError: Can't pickle <function <lambda> at 0x000001A51DF67600>: attribute lookup <lambda> on __main__ failed
```
## Fix:
The lambda passed as `scale_fn` is stored on the scheduler and therefore ends up in the state dict, where `pickle` cannot serialize it. The fix saves `scale_fn` to the state dict only if it is a callable object (its `__dict__` is stored), and not if it is a plain function or lambda.
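
For illustration, a minimal sketch of the distinction the fix relies on; the `ExpScale` class and the loop are hypothetical examples, not code from the PR:

```py
import types

class ExpScale:
    """A callable object: its state lives in __dict__ and pickles cleanly."""
    def __init__(self, gamma=0.99):
        self.gamma = gamma

    def __call__(self, x):
        return self.gamma ** x

for fn in (lambda x: 0.99, ExpScale()):
    if isinstance(fn, types.FunctionType):
        # Plain functions and lambdas are not saved; pickling them can fail
        # (e.g. a lambda defined in __main__).
        saved = None
    else:
        # Callable objects: only their picklable attributes are saved.
        saved = fn.__dict__.copy()
    print(type(fn).__name__, "->", saved)
```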
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110931
Approved by: https://github.com/janeyx99
Parent: a0e3321f0c
Commit: 2b72543f36
@@ -1530,10 +1530,39 @@ class TestLRScheduler(TestCase):

```py
    def test_cycle_lr_state_dict_picklable(self):
        adam_opt = Adam(self.net.parameters())

        # Case 1: Built-in mode
        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False)
        self.assertIsInstance(scheduler._scale_fn_ref, types.FunctionType)
        state = scheduler.state_dict()
        self.assertNotIn("_scale_fn_ref", state)
        self.assertIs(state["_scale_fn_custom"], None)
        pickle.dumps(state)

        # Case 2: Custom `scale_fn`, a function object
        def scale_fn(_):
            return 0.5

        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn)
        state = scheduler.state_dict()
        self.assertNotIn("_scale_fn_ref", state)
        self.assertIs(state["_scale_fn_custom"], None)
        pickle.dumps(state)

        # Case 3: Custom `scale_fn`, a callable class
        class ScaleFn:
            def __init__(self):
                self.x = 0.5

            def __call__(self, _):
                return self.x

        scale_fn = ScaleFn()

        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn)
        state = scheduler.state_dict()
        self.assertNotIn("_scale_fn_ref", state)
        self.assertEqual(state["_scale_fn_custom"], scale_fn.__dict__)
        pickle.dumps(state)

    def test_cycle_lr_scale_fn_restored_from_state_dict(self):
```
@@ -1365,16 +1365,26 @@ class CyclicLR(LRScheduler):

```py
    def state_dict(self):
        state = super().state_dict()
        # We are dropping the `_scale_fn_ref` attribute because it is a
        # `weakref.WeakMethod` and can't be pickled.
        state.pop('_scale_fn_ref')
        fn = state.pop('_scale_fn_custom')
        state['_scale_fn_custom'] = None
        if fn is not None and not isinstance(fn, types.FunctionType):
            # The _scale_fn_custom will only be saved if it is a callable object
            # and not if it is a function or lambda.
            state['_scale_fn_custom'] = fn.__dict__.copy()

        return state

    def load_state_dict(self, state_dict):
        fn = state_dict.pop('_scale_fn_custom')
        super().load_state_dict(state_dict)
        if fn is not None:
            self._scale_fn_custom.__dict__.update(fn)
        self._init_scale_fn()


class CosineAnnealingWarmRestarts(LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
```
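
As a usage sketch of the post-fix behavior shown in the hunk above (the `ExpScale` callable is a hypothetical example, not part of the PR): a scheduler whose `scale_fn` is a callable object can round-trip through `state_dict()`/`load_state_dict()`, with the callable's attributes restored from the saved `__dict__`:

```py
import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import CyclicLR

class ExpScale:
    def __init__(self, gamma=0.99):
        self.gamma = gamma

    def __call__(self, x):
        return self.gamma ** x

model = nn.Linear(10, 10)
opt = SGD(model.parameters(), lr=1.0)
scheduler = CyclicLR(opt, base_lr=0.1, max_lr=0.2, scale_fn=ExpScale(), cycle_momentum=False)

# The state dict now holds only the callable's attributes, which pickle cleanly.
state = scheduler.state_dict()

# A fresh scheduler is constructed with the same kind of scale_fn;
# load_state_dict() then restores the callable's attributes from the saved dict.
new_scheduler = CyclicLR(opt, base_lr=0.1, max_lr=0.2, scale_fn=ExpScale(gamma=0.5), cycle_momentum=False)
new_scheduler.load_state_dict(state)
print(new_scheduler._scale_fn_custom.gamma)  # 0.99, restored from the saved state
```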