Cast tensors when loading optimizer state dicts (#3658)

This commit is contained in:
Adam Paszke 2017-11-28 15:56:39 +01:00 committed by Edward Z. Yang
parent 51ca3a1a48
commit af9fd35d82
3 changed files with 56 additions and 5 deletions

View File

@ -133,7 +133,8 @@ class TestOptim(TestCase):
def fn_base(optimizer, weight, bias):
optimizer.zero_grad()
loss = (weight.mv(input) + bias).pow(2).sum()
i = input_cuda if weight.is_cuda else input
loss = (weight.mv(i) + bias).pow(2).sum()
loss.backward()
return loss
@ -161,6 +162,29 @@ class TestOptim(TestCase):
# Make sure state dict wasn't modified
self.assertEqual(state_dict, state_dict_c)
# Check that state dict can be loaded even when we cast parameters
# to a different type and move to a different device.
if not torch.cuda.is_available():
return
input_cuda = Variable(input.data.float().cuda())
weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
optimizer_cuda = constructor(weight_cuda, bias_cuda)
fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)
state_dict = deepcopy(optimizer.state_dict())
state_dict_c = deepcopy(optimizer.state_dict())
optimizer_cuda.load_state_dict(state_dict_c)
# Make sure state dict wasn't modified
self.assertEqual(state_dict, state_dict_c)
for i in range(20):
optimizer.step(fn)
optimizer_cuda.step(fn_cuda)
self.assertEqual(weight, weight_cuda)
self.assertEqual(bias, bias_cuda)
def _test_basic_cases(self, constructor, ignore_multidevice=False):
self._test_state_dict(
torch.randn(10, 5),

View File

@ -93,7 +93,9 @@ class LBFGS(Optimizer):
line_search_fn = group['line_search_fn']
history_size = group['history_size']
state = self.state['global_state']
# NOTE: LBFGS has only global state, but we register it as state for
# the first param, because this helps with casting in load_state_dict
state = self.state[self._params[0]]
state.setdefault('func_evals', 0)
state.setdefault('n_iter', 0)

View File

@ -1,4 +1,4 @@
from collections import defaultdict
from collections import defaultdict, Iterable
import torch
from copy import deepcopy
@ -96,8 +96,33 @@ class Optimizer(object):
id_map = {old_id: p for old_id, p in
zip(chain(*(g['params'] for g in saved_groups)),
chain(*(g['params'] for g in groups)))}
state = defaultdict(
dict, {id_map.get(k, k): v for k, v in state_dict['state'].items()})
def cast(param, value):
"""Make a deep copy of value, casting all tensors to device of param."""
if torch.is_tensor(value):
# Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params.
if any(tp in type(param.data).__name__ for tp in {'Half', 'Float', 'Double'}):
value = value.type_as(param.data)
value = value.cuda(param.get_device()) if param.is_cuda else value.cpu()
return value
elif isinstance(value, dict):
return {k: cast(param, v) for k, v in value.items()}
elif isinstance(value, Iterable):
return type(value)(cast(param, v) for v in value)
else:
return value
# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = defaultdict(dict)
for k, v in state_dict['state'].items():
if k in id_map:
param = id_map[k]
state[param] = cast(param, v)
else:
state[k] = v
# Update parameter groups, setting their 'params' value
def update_group(group, new_group):