Cast tensors when loading optimizer state dicts (#3658)
This commit is contained in:
parent 51ca3a1a48
commit af9fd35d82
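What this change enables, roughly: an optimizer state dict saved with parameters of one dtype/device can be loaded into an optimizer whose parameters live on another, with the state tensors cast and moved to match. A minimal sketch using current torch.nn / torch.optim APIs rather than the Variable-era API in the diff; the model and optimizer below are illustrative, not taken from the commit:

    import torch
    from torch import nn, optim

    # Train briefly on CPU in double precision and save the optimizer state.
    model = nn.Linear(5, 3).double()
    opt = optim.Adam(model.parameters(), lr=1e-3)
    model(torch.randn(10, 5, dtype=torch.float64)).sum().backward()
    opt.step()
    saved = opt.state_dict()  # exp_avg / exp_avg_sq buffers are CPU doubles here

    # Rebuild the model in float on the GPU and load the same state dict.
    if torch.cuda.is_available():
        model_cuda = nn.Linear(5, 3).float().cuda()
        opt_cuda = optim.Adam(model_cuda.parameters(), lr=1e-3)
        # State tensors are cast/moved to match the new parameters.
        opt_cuda.load_state_dict(saved)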
test/test_optim.py
@@ -133,7 +133,8 @@ class TestOptim(TestCase):
         def fn_base(optimizer, weight, bias):
             optimizer.zero_grad()
-            loss = (weight.mv(input) + bias).pow(2).sum()
+            i = input_cuda if weight.is_cuda else input
+            loss = (weight.mv(i) + bias).pow(2).sum()
             loss.backward()
             return loss
 
 
@@ -161,6 +162,29 @@ class TestOptim(TestCase):
         # Make sure state dict wasn't modified
         self.assertEqual(state_dict, state_dict_c)
 
+        # Check that state dict can be loaded even when we cast parameters
+        # to a different type and move to a different device.
+        if not torch.cuda.is_available():
+            return
+
+        input_cuda = Variable(input.data.float().cuda())
+        weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
+        bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
+        optimizer_cuda = constructor(weight_cuda, bias_cuda)
+        fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)
+
+        state_dict = deepcopy(optimizer.state_dict())
+        state_dict_c = deepcopy(optimizer.state_dict())
+        optimizer_cuda.load_state_dict(state_dict_c)
+        # Make sure state dict wasn't modified
+        self.assertEqual(state_dict, state_dict_c)
+
+        for i in range(20):
+            optimizer.step(fn)
+            optimizer_cuda.step(fn_cuda)
+            self.assertEqual(weight, weight_cuda)
+            self.assertEqual(bias, bias_cuda)
+
     def _test_basic_cases(self, constructor, ignore_multidevice=False):
         self._test_state_dict(
             torch.randn(10, 5),
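The test drives both optimizers through optimizer.step(closure), where the closure re-evaluates the loss and gradients; functools.partial simply binds the optimizer and its parameters to fn_base. A rough standalone sketch of the same pattern with current APIs (the names here are illustrative):

    import functools
    import torch

    weight = torch.randn(3, 5, requires_grad=True)
    bias = torch.randn(3, requires_grad=True)
    inp = torch.randn(5)

    def fn_base(optimizer, weight, bias):
        # Re-evaluate the loss from scratch; optimizers that call the closure
        # internally (e.g. LBFGS) rely on this.
        optimizer.zero_grad()
        loss = (weight.mv(inp) + bias).pow(2).sum()
        loss.backward()
        return loss

    opt = torch.optim.SGD([weight, bias], lr=1e-3)
    closure = functools.partial(fn_base, opt, weight, bias)
    for _ in range(5):
        opt.step(closure)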
torch/optim/lbfgs.py
@@ -93,7 +93,9 @@ class LBFGS(Optimizer):
         line_search_fn = group['line_search_fn']
         history_size = group['history_size']
 
-        state = self.state['global_state']
+        # NOTE: LBFGS has only global state, but we register it as state for
+        # the first param, because this helps with casting in load_state_dict
+        state = self.state[self._params[0]]
         state.setdefault('func_evals', 0)
         state.setdefault('n_iter', 0)
 
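Why keying the single global LBFGS state to the first parameter helps: load_state_dict casts state tensors against the parameter they are attached to, so state stored under an arbitrary string key like 'global_state' has no parameter to cast against and would be copied as-is. A rough check with the current torch.optim.LBFGS (the exact state keys vary by version):

    import torch

    params = [torch.randn(4, requires_grad=True)]
    opt = torch.optim.LBFGS(params)

    def closure():
        opt.zero_grad()
        loss = params[0].pow(2).sum()
        loss.backward()
        return loss

    opt.step(closure)
    sd = opt.state_dict()
    # The single global LBFGS state (history buffers, 'n_iter', 'func_evals', ...)
    # sits under the index of the first parameter, so load_state_dict has a
    # parameter to cast its tensors against.
    print(list(sd['state'].keys()))       # -> [0]
    print(sorted(sd['state'][0].keys()))  # history buffers and counters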
torch/optim/optimizer.py
@@ -1,4 +1,4 @@
-from collections import defaultdict
+from collections import defaultdict, Iterable
 
 import torch
 from copy import deepcopy
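A side note on the new import (not part of the diff): collections.Iterable is an alias that was deprecated in Python 3.3 and removed in Python 3.10, so a present-day port of this code would import the ABC from collections.abc instead, e.g.:

    # Portability note: newer Python only exposes the ABC in collections.abc.
    try:
        from collections.abc import Iterable  # Python 3.3+
    except ImportError:
        from collections import Iterable      # legacy Python 2 / early 3.x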
@@ -96,8 +96,33 @@ class Optimizer(object):
         id_map = {old_id: p for old_id, p in
                   zip(chain(*(g['params'] for g in saved_groups)),
                       chain(*(g['params'] for g in groups)))}
-        state = defaultdict(
-            dict, {id_map.get(k, k): v for k, v in state_dict['state'].items()})
+
+        def cast(param, value):
+            """Make a deep copy of value, casting all tensors to device of param."""
+            if torch.is_tensor(value):
+                # Floating-point types are a bit special here. They are the only ones
+                # that are assumed to always match the type of params.
+                if any(tp in type(param.data).__name__ for tp in {'Half', 'Float', 'Double'}):
+                    value = value.type_as(param.data)
+                value = value.cuda(param.get_device()) if param.is_cuda else value.cpu()
+                return value
+            elif isinstance(value, dict):
+                return {k: cast(param, v) for k, v in value.items()}
+            elif isinstance(value, Iterable):
+                return type(value)(cast(param, v) for v in value)
+            else:
+                return value
+
+        # Copy state assigned to params (and cast tensors to appropriate types).
+        # State that is not assigned to params is copied as is (needed for
+        # backward compatibility).
+        state = defaultdict(dict)
+        for k, v in state_dict['state'].items():
+            if k in id_map:
+                param = id_map[k]
+                state[param] = cast(param, v)
+            else:
+                state[k] = v
 
         # Update parameter groups, setting their 'params' value
         def update_group(group, new_group):
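The cast helper walks state values recursively: tensors are moved to the owning parameter's device and, when the parameter is a floating-point type, cast to that type; dicts and other iterables are rebuilt element by element; everything else (plain ints, floats) passes through untouched. A loose standalone sketch of the same idea against current tensor APIs; cast_like and the sample state are illustrative, and the branch conditions are simplified relative to the diff (it checks the value's dtype rather than the param's type name):

    import torch
    from collections.abc import Iterable

    def cast_like(param, value):
        """Recursively move tensors in `value` to param's device; floating-point
        tensors are also cast to param's dtype. Illustrative only."""
        if torch.is_tensor(value):
            # Simplification vs. the diff: decide based on the value's dtype.
            if value.is_floating_point():
                value = value.to(dtype=param.dtype)
            return value.to(device=param.device)
        if isinstance(value, dict):
            return {k: cast_like(param, v) for k, v in value.items()}
        if isinstance(value, str):
            return value  # strings are Iterable but should not be rebuilt element-wise
        if isinstance(value, Iterable):
            return type(value)(cast_like(param, v) for v in value)
        return value

    # Example: Adam-style state for one parameter, saved as CPU doubles.
    param = torch.zeros(3, dtype=torch.float32)
    state = {'step': torch.tensor(10), 'exp_avg': torch.zeros(3, dtype=torch.float64)}
    moved = cast_like(param, state)
    print(moved['exp_avg'].dtype, moved['step'].dtype)  # torch.float32 torch.int64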