pytorch/torch/optim/adam.py
2016-11-07 22:50:56 +01:00

52 lines
1.9 KiB
Python

import math
from .optimizer import Optimizer
class Adam(Optimizer):
    """Implements the Adam optimization algorithm.

    For every parameter, Adam keeps an exponential moving average of the
    gradient (first moment) and of the squared gradient (second moment).
    It moves the parameter by the ratio of the two, with both averages
    bias-corrected for their zero initialization. See Kingma & Ba,
    "Adam: A Method for Stochastic Optimization" (2014).

    Arguments:
        params: parameters to optimize; passed unchanged to
            ``Optimizer.__init__`` together with ``defaults``. The base class
            is assumed to build ``self.param_groups`` from them (not visible
            here). Each parameter must expose ``.grad`` and ``.data`` tensors.
        lr (float): learning rate; must be >= 0 (default: 1e-2).
        betas (pair of floats): decay rates of the first- and second-moment
            running averages; each must lie in [0, 1) (default: (0.9, 0.999)).
        epsilon (float): term added to the denominator for numerical
            stability; must be >= 0 (default: 1e-8).
        weight_decay (float): L2 penalty coefficient. When non-zero,
            ``weight_decay * p.data`` is added to the gradient before the
            moment updates. Must be >= 0 (default: 0).

    Raises:
        ValueError: if any hyper-parameter is outside its valid range.
    """

    def __init__(self, params, lr=1e-2, betas=(0.9, 0.999), epsilon=1e-8,
                 weight_decay=0):
        # ``not 0.0 <= x`` (rather than ``x < 0``) also rejects NaN.
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= epsilon:
            raise ValueError("Invalid epsilon value: {}".format(epsilon))
        # beta == 1 would zero a bias-correction term in step(): beta1 divides
        # by zero, beta2 silently produces a zero step size.
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, epsilon=epsilon,
                        weight_decay=weight_decay)
        super(Adam, self).__init__(params, defaults)

    def step(self, forward_closure=None):
        """Perform a single Adam update on every parameter that has a gradient.

        Arguments:
            forward_closure (callable, optional): if given, it is passed to
                ``self._forward_backward`` (defined on the base class) before
                the update. Presumably that re-evaluates the model and
                refreshes the gradients; confirm against ``Optimizer``.

        Returns:
            The value returned by ``self._forward_backward(forward_closure)``,
            or ``None`` when no closure is given.
        """
        loss = None
        if forward_closure is not None:
            loss = self._forward_backward(forward_closure)

        for group in self.param_groups:
            # Betas are per group, so unpack them once per group.
            beta1, beta2 = group['betas']
            for p in group['params']:
                grad = p.grad
                if grad is None:
                    # No gradient for this parameter: leave it, and its
                    # optimizer state, untouched.
                    continue
                state = self.state[id(p)]

                # Lazy state initialization on this parameter's first update.
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = grad.new().resize_as_(grad).zero_()
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                if group['weight_decay'] != 0:
                    # Out-of-place add, so p.grad itself is not modified.
                    grad = grad.add(group['weight_decay'], p.data)

                # Decay the first and second moment running averages:
                #   m <- beta1 * m + (1 - beta1) * g
                #   v <- beta2 * v + (1 - beta2) * g * g
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                # The bias corrections are folded into the scalar step size
                # rather than applied to m and v, so epsilon is added to the
                # uncorrected sqrt(v). This is the "more efficient" ordering
                # described in Section 2 of the Adam paper.
                denom = exp_avg_sq.sqrt().add_(group['epsilon'])
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = (group['lr'] * math.sqrt(bias_correction2)
                             / bias_correction1)
                # p <- p - step_size * m / (sqrt(v) + epsilon)
                p.data.addcdiv_(-step_size, exp_avg, denom)

        return loss