# Fix pickle error when saving CyclicLR state_dict (#110931)

## How to reproduce:
```py
import os
import tempfile

import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import CyclicLR

model = nn.Linear(100, 100)
opt = SGD(model.parameters(), lr=1.)
scheduler = CyclicLR(opt, base_lr=0.1, max_lr=0.2, scale_fn=lambda x: 0.99)

tmp = tempfile.NamedTemporaryFile(delete=False)
try:
    torch.save(scheduler.state_dict(), tmp.name)
    scheduler.load_state_dict(torch.load(tmp.name))
finally:
    tmp.close()
    os.unlink(tmp.name)
```
Error:
```
_pickle.PicklingError: Can't pickle <function <lambda> at 0x000001A51DF67600>: attribute lookup <lambda> on __main__ failed
```
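The failure is the usual CPython `pickle` limitation: functions are serialized by reference (module plus qualified name), and a lambda's `<lambda>` name cannot be looked up in its defining module, so any state dict that stores the lambda itself cannot be pickled. A minimal standalone illustration, independent of PyTorch:
```py
import pickle

scale = lambda x: 0.99

# pickle stores functions by module + qualified name; "<lambda>" cannot be
# resolved in __main__, so this raises _pickle.PicklingError.
pickle.dumps(scale)
```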
## Fix:
`scale_fn` is now written to the state dict only when it is a callable object, in which case its instance `__dict__` is stored; plain functions and lambdas are dropped from the state dict instead of being pickled, and the scheduler's existing `scale_fn` is kept when the state dict is loaded.
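As a quick sketch of the usage this enables (the `ConstantScale` helper below is illustrative, not part of the patch): passing a callable object instead of a lambda lets `CyclicLR` capture the callable's picklable `__dict__` in its state dict.
```py
import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import CyclicLR


class ConstantScale:
    """Callable object; its picklable __dict__ ends up in the state dict."""

    def __init__(self, factor=0.99):
        self.factor = factor

    def __call__(self, _):
        return self.factor


model = nn.Linear(100, 100)
opt = SGD(model.parameters(), lr=1.)
scheduler = CyclicLR(opt, base_lr=0.1, max_lr=0.2, scale_fn=ConstantScale())

state = scheduler.state_dict()      # "_scale_fn_ref" is dropped; "_scale_fn_custom" holds the __dict__
torch.save(state, "scheduler.pt")   # no PicklingError
```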

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110931
Approved by: https://github.com/janeyx99
Author: ancestor-mithril, committed by PyTorch MergeBot, 2023-11-22 11:38:31 +00:00
Commit: 2b72543f36 (parent: a0e3321f0c)
2 changed files with 42 additions and 3 deletions


Test changes, hunk `@@ -1530,10 +1530,39 @@` in `class TestLRScheduler(TestCase)`:

```py
    def test_cycle_lr_state_dict_picklable(self):
        adam_opt = Adam(self.net.parameters())

        # Case 1: Built-in mode
        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False)
        self.assertIsInstance(scheduler._scale_fn_ref, types.FunctionType)
        state = scheduler.state_dict()
        self.assertNotIn("_scale_fn_ref", state)
        self.assertIs(state["_scale_fn_custom"], None)
        pickle.dumps(state)

        # Case 2: Custom `scale_fn`, a function object
        def scale_fn(_):
            return 0.5

        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn)
        state = scheduler.state_dict()
        self.assertNotIn("_scale_fn_ref", state)
        self.assertIs(state["_scale_fn_custom"], None)
        pickle.dumps(state)

        # Case 3: Custom `scale_fn`, a callable class
        class ScaleFn:
            def __init__(self):
                self.x = 0.5

            def __call__(self, _):
                return self.x

        scale_fn = ScaleFn()
        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn)
        state = scheduler.state_dict()
        self.assertNotIn("_scale_fn_ref", state)
        self.assertEqual(state["_scale_fn_custom"], scale_fn.__dict__)
        pickle.dumps(state)

    def test_cycle_lr_scale_fn_restored_from_state_dict(self):
```


`CyclicLR` changes, hunk `@@ -1365,16 +1365,26 @@` in `class CyclicLR(LRScheduler)` (the updated `state_dict`/`load_state_dict` are shown):

```py
    def state_dict(self):
        state = super().state_dict()
        # We are dropping the `_scale_fn_ref` attribute because it is a
        # `weakref.WeakMethod` and can't be pickled.
        state.pop('_scale_fn_ref')
        fn = state.pop('_scale_fn_custom')
        state['_scale_fn_custom'] = None
        if fn is not None and not isinstance(fn, types.FunctionType):
            # The _scale_fn_custom will only be saved if it is a callable object
            # and not if it is a function or lambda.
            state['_scale_fn_custom'] = fn.__dict__.copy()

        return state

    def load_state_dict(self, state_dict):
        fn = state_dict.pop('_scale_fn_custom')
        super().load_state_dict(state_dict)
        if fn is not None:
            self._scale_fn_custom.__dict__.update(fn)
        self._init_scale_fn()


class CosineAnnealingWarmRestarts(LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
```
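For completeness, a hedged round-trip sketch of the patched behavior (the `ScaleFn` and `make_scheduler` names are illustrative): saving from one scheduler and loading into a freshly constructed one restores the callable's attributes from the saved `__dict__`.
```py
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import CyclicLR


class ScaleFn:
    def __init__(self, x=0.5):
        self.x = x

    def __call__(self, _):
        return self.x


def make_scheduler(x):
    opt = SGD(nn.Linear(10, 10).parameters(), lr=1.)
    return CyclicLR(opt, base_lr=0.1, max_lr=0.2, scale_fn=ScaleFn(x))


src = make_scheduler(x=0.25)
state = src.state_dict()                 # picklable; includes ScaleFn(0.25).__dict__

dst = make_scheduler(x=0.5)              # fresh scheduler with default callable state
dst.load_state_dict(state)
assert dst._scale_fn_custom.x == 0.25    # restored from the saved __dict__
```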