mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix casting bug in state_step for optimizers when loading state dict
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75214 Approved by: https://github.com/albanD
This commit is contained in:
parent
22d227fd29
commit
10bb0ffe69
|
|
@ -228,6 +228,12 @@ class TestOptim(TestCase):
|
||||||
# Make sure state dict wasn't modified
|
# Make sure state dict wasn't modified
|
||||||
self.assertEqual(state_dict, state_dict_c)
|
self.assertEqual(state_dict, state_dict_c)
|
||||||
|
|
||||||
|
# Make sure that device of state['step'] is still CPU
|
||||||
|
new_state_dict = optimizer_cuda.state_dict()
|
||||||
|
if 'step' in state_dict['state'][0] and torch.is_tensor(state_dict['state'][0]['step']):
|
||||||
|
for state in new_state_dict['state'].values():
|
||||||
|
self.assertEqual(state['step'].device.type, 'cpu')
|
||||||
|
|
||||||
for _i in range(20):
|
for _i in range(20):
|
||||||
optimizer.step(fn)
|
optimizer.step(fn)
|
||||||
optimizer_cuda.step(fn_cuda)
|
optimizer_cuda.step(fn_cuda)
|
||||||
|
|
|
||||||
|
|
@ -151,17 +151,19 @@ class Optimizer(object):
|
||||||
zip(chain.from_iterable((g['params'] for g in saved_groups)),
|
zip(chain.from_iterable((g['params'] for g in saved_groups)),
|
||||||
chain.from_iterable((g['params'] for g in groups)))}
|
chain.from_iterable((g['params'] for g in groups)))}
|
||||||
|
|
||||||
def cast(param, value):
|
def cast(param, value, key=None):
|
||||||
r"""Make a deep copy of value, casting all tensors to device of param."""
|
r"""Make a deep copy of value, casting all tensors to device of param."""
|
||||||
if isinstance(value, torch.Tensor):
|
if isinstance(value, torch.Tensor):
|
||||||
# Floating-point types are a bit special here. They are the only ones
|
# Floating-point types are a bit special here. They are the only ones
|
||||||
# that are assumed to always match the type of params.
|
# that are assumed to always match the type of params.
|
||||||
|
# Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
|
||||||
|
if (key != "step"):
|
||||||
if param.is_floating_point():
|
if param.is_floating_point():
|
||||||
value = value.to(param.dtype)
|
value = value.to(param.dtype)
|
||||||
value = value.to(param.device)
|
value = value.to(param.device)
|
||||||
return value
|
return value
|
||||||
elif isinstance(value, dict):
|
elif isinstance(value, dict):
|
||||||
return {k: cast(param, v) for k, v in value.items()}
|
return {k: cast(param, v, key=k) for k, v in value.items()}
|
||||||
elif isinstance(value, container_abcs.Iterable):
|
elif isinstance(value, container_abcs.Iterable):
|
||||||
return type(value)(cast(param, v) for v in value)
|
return type(value)(cast(param, v) for v in value)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user