r"""Functional interface"""
|
|
import math
|
|
import torch
|
|
from torch import Tensor
|
|
from typing import List, Optional
|
|
|
|
# TODO: use foreach API in optim._functional to do all the computation
|
|
|
|
def _make_sparse(grad, grad_indices, values):
    size = grad.size()
    if grad_indices.numel() == 0 or values.numel() == 0:
        return torch.empty_like(grad)
    return torch.sparse_coo_tensor(grad_indices, values, size)


def adagrad(params: List[Tensor],
            grads: List[Tensor],
            state_sums: List[Tensor],
            state_steps: List[int],
            *,
            lr: float,
            weight_decay: float,
            lr_decay: float,
            eps: float):
    r"""Functional API that performs Adagrad algorithm computation.

    See :class:`~torch.optim.Adagrad` for details.
    """

    for (param, grad, state_sum, step) in zip(params, grads, state_sums, state_steps):
        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError("weight_decay option is not compatible with sparse gradients")
            grad = grad.add(param, alpha=weight_decay)

        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()
            size = grad.size()

            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
        else:
            is_complex = torch.is_complex(param)
            if is_complex:
                grad = torch.view_as_real(grad)
                state_sum = torch.view_as_real(state_sum)
                param = torch.view_as_real(param)
            state_sum.addcmul_(grad, grad, value=1)
            std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)
            if is_complex:
                param = torch.view_as_complex(param)
                state_sum = torch.view_as_complex(state_sum)

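# Illustrative usage sketch (not part of the upstream file): a single functional
# adagrad step on one hypothetical dense parameter, mirroring how
# torch.optim.Adagrad assembles the per-parameter lists before calling this
# function. The tensor names below are made up for the example.
#
#     >>> param = torch.randn(3)                 # stands in for a model weight
#     >>> grad = torch.randn(3)                  # stands in for param.grad
#     >>> state_sum = torch.zeros_like(param)    # running sum of squared grads
#     >>> adagrad([param], [grad], [state_sum], [1],
#     ...         lr=1e-2, weight_decay=0.0, lr_decay=0.0, eps=1e-10)
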
def adam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[int],
         *,
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float):
    r"""Functional API that performs Adam algorithm computation.

    See :class:`~torch.optim.Adam` for details.
    """

    for i, param in enumerate(params):

        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)

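# Illustrative usage sketch (not part of the upstream file): one Adam step for a
# single hypothetical tensor, with the state lists built the way
# torch.optim.Adam builds them. With amsgrad=False the max_exp_avg_sqs entry is
# never read; a placeholder keeps the lists aligned.
#
#     >>> param = torch.randn(3)
#     >>> grad = torch.randn(3)
#     >>> exp_avg = torch.zeros_like(param)
#     >>> exp_avg_sq = torch.zeros_like(param)
#     >>> max_exp_avg_sq = torch.zeros_like(param)
#     >>> adam([param], [grad], [exp_avg], [exp_avg_sq], [max_exp_avg_sq], [1],
#     ...      amsgrad=False, beta1=0.9, beta2=0.999, lr=1e-3,
#     ...      weight_decay=0.0, eps=1e-8)
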
def adamw(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          max_exp_avg_sqs: List[Tensor],
          state_steps: List[int],
          *,
          amsgrad: bool,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          eps: float):
    r"""Functional API that performs AdamW algorithm computation.

    See :class:`~torch.optim.AdamW` for details.
    """
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        # Perform stepweight decay
        param.mul_(1 - lr * weight_decay)

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)

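# Illustrative usage sketch (not part of the upstream file): adamw() differs
# from adam() above only in applying decoupled weight decay
# (param.mul_(1 - lr * weight_decay)) instead of adding weight_decay * param to
# the gradient. Hypothetical tensors:
#
#     >>> param = torch.randn(3)
#     >>> grad = torch.randn(3)
#     >>> exp_avg, exp_avg_sq = torch.zeros_like(param), torch.zeros_like(param)
#     >>> max_exp_avg_sq = torch.zeros_like(param)   # unused when amsgrad=False
#     >>> adamw([param], [grad], [exp_avg], [exp_avg_sq], [max_exp_avg_sq], [1],
#     ...       amsgrad=False, beta1=0.9, beta2=0.999, lr=1e-3,
#     ...       weight_decay=1e-2, eps=1e-8)
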
def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.
    """

    for i, param in enumerate(params):

        d_p = d_p_list[i]
        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        param.add_(d_p, alpha=-lr)

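# Illustrative usage sketch (not part of the upstream file): SGD with momentum
# on one hypothetical tensor. Passing None for the momentum buffer corresponds
# to the first step, where the buffer is created from the gradient and written
# back into the list.
#
#     >>> param = torch.randn(3)
#     >>> grad = torch.randn(3)
#     >>> bufs = [None]
#     >>> sgd([param], [grad], bufs,
#     ...     weight_decay=0.0, momentum=0.9, lr=1e-2,
#     ...     dampening=0.0, nesterov=False)
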
def adadelta(params: List[Tensor],
             grads: List[Tensor],
             square_avgs: List[Tensor],
             acc_deltas: List[Tensor],
             *,
             lr: float,
             rho: float,
             eps: float,
             weight_decay: float):
    r"""Functional API that performs Adadelta algorithm computation.

    See :class:`~torch.optim.Adadelta` for details.
    """

    for (param, grad, square_avg, acc_delta) in zip(params, grads, square_avgs, acc_deltas):
        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if torch.is_complex(param):
            square_avg = torch.view_as_real(square_avg)
            acc_delta = torch.view_as_real(acc_delta)
            grad = torch.view_as_real(grad)

        square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)
        std = square_avg.add(eps).sqrt_()
        delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad)
        acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)
        if torch.is_complex(param):
            delta = torch.view_as_complex(delta)
        param.add_(delta, alpha=-lr)

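# Illustrative usage sketch (not part of the upstream file): one Adadelta step
# on a hypothetical complex parameter, exercising the view_as_real handling
# above (complex values are treated as pairs of real numbers). The state
# tensors carry the same dtype as the parameter.
#
#     >>> param = torch.randn(3, dtype=torch.complex64)
#     >>> grad = torch.randn(3, dtype=torch.complex64)
#     >>> square_avg = torch.zeros_like(param)
#     >>> acc_delta = torch.zeros_like(param)
#     >>> adadelta([param], [grad], [square_avg], [acc_delta],
#     ...          lr=1.0, rho=0.9, eps=1e-6, weight_decay=0.0)
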
def rmsprop(params: List[Tensor],
            grads: List[Tensor],
            square_avgs: List[Tensor],
            grad_avgs: List[Tensor],
            momentum_buffer_list: List[Tensor],
            *,
            lr: float,
            alpha: float,
            eps: float,
            weight_decay: float,
            momentum: float,
            centered: bool):
    r"""Functional API that performs rmsprop algorithm computation.

    See :class:`~torch.optim.RMSprop` for details.
    """

    for i, param in enumerate(params):
        grad = grads[i]
        square_avg = square_avgs[i]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

        if centered:
            grad_avg = grad_avgs[i]
            grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
            avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(eps)
        else:
            avg = square_avg.sqrt().add_(eps)

        if momentum > 0:
            buf = momentum_buffer_list[i]
            buf.mul_(momentum).addcdiv_(grad, avg)
            param.add_(buf, alpha=-lr)
        else:
            param.addcdiv_(grad, avg, value=-lr)

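# Illustrative usage sketch (not part of the upstream file): a single RMSprop
# step without momentum or centering, so grad_avgs and momentum_buffer_list are
# never read; placeholder tensors are passed only to keep the lists aligned.
#
#     >>> param = torch.randn(3)
#     >>> grad = torch.randn(3)
#     >>> square_avg = torch.zeros_like(param)
#     >>> grad_avg = torch.zeros_like(param)       # unused when centered=False
#     >>> buf = torch.zeros_like(param)            # unused when momentum=0
#     >>> rmsprop([param], [grad], [square_avg], [grad_avg], [buf],
#     ...         lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0.0,
#     ...         momentum=0.0, centered=False)
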
def rprop(params: List[Tensor],
          grads: List[Tensor],
          prevs: List[Tensor],
          step_sizes: List[Tensor],
          *,
          step_size_min: float,
          step_size_max: float,
          etaminus: float,
          etaplus: float):
    r"""Functional API that performs rprop algorithm computation.

    See :class:`~torch.optim.Rprop` for details.
    """

    for i, param in enumerate(params):
        grad = grads[i]
        prev = prevs[i]
        step_size = step_sizes[i]

        sign = grad.mul(prev).sign()
        sign[sign.gt(0)] = etaplus
        sign[sign.lt(0)] = etaminus
        sign[sign.eq(0)] = 1

        # update stepsizes with step size updates
        step_size.mul_(sign).clamp_(step_size_min, step_size_max)

        # for dir<0, dfdx=0
        # for dir>=0 dfdx=dfdx
        grad = grad.clone(memory_format=torch.preserve_format)
        grad[sign.eq(etaminus)] = 0

        # update parameters
        param.addcmul_(grad.sign(), step_size, value=-1)

        prev.copy_(grad)

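# Illustrative usage sketch (not part of the upstream file): one Rprop step on
# hypothetical tensors. step_sizes is typically initialised from the
# optimizer's lr and prevs to zeros, so on the first call every sign product is
# zero and the step sizes are left unchanged.
#
#     >>> param = torch.randn(3)
#     >>> grad = torch.randn(3)
#     >>> prev = torch.zeros_like(param)
#     >>> step_size = torch.full_like(param, 1e-2)
#     >>> rprop([param], [grad], [prev], [step_size],
#     ...       step_size_min=1e-6, step_size_max=50.0,
#     ...       etaminus=0.5, etaplus=1.2)
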
def adamax(params: List[Tensor],
           grads: List[Tensor],
           exp_avgs: List[Tensor],
           exp_infs: List[Tensor],
           state_steps: List[int],
           *,
           eps: float,
           beta1: float,
           beta2: float,
           lr: float,
           weight_decay: float):
    r"""Functional API that performs adamax algorithm computation.

    See :class:`~torch.optim.Adamax` for details.
    """

    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_inf = exp_infs[i]
        step = state_steps[i]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Update biased first moment estimate.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        # Update the exponentially weighted infinity norm.
        norm_buf = torch.cat([
            exp_inf.mul_(beta2).unsqueeze(0),
            grad.abs().add_(eps).unsqueeze_(0)
        ], 0)
        torch.amax(norm_buf, 0, keepdim=False, out=exp_inf)

        bias_correction = 1 - beta1 ** step
        clr = lr / bias_correction

        param.addcdiv_(exp_avg, exp_inf, value=-clr)

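# Illustrative usage sketch (not part of the upstream file): one Adamax step on
# a hypothetical tensor; exp_inf tracks the exponentially weighted infinity
# norm of the gradients.
#
#     >>> param = torch.randn(3)
#     >>> grad = torch.randn(3)
#     >>> exp_avg = torch.zeros_like(param)
#     >>> exp_inf = torch.zeros_like(param)
#     >>> adamax([param], [grad], [exp_avg], [exp_inf], [1],
#     ...        eps=1e-8, beta1=0.9, beta2=0.999, lr=2e-3, weight_decay=0.0)
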
def asgd(params: List[Tensor],
         grads: List[Tensor],
         axs: List[Tensor],
         mus: List[float],
         etas: List[float],
         *,
         weight_decay: float,
         lambd: float):
    r"""Functional API that performs asgd algorithm computation.

    See :class:`~torch.optim.ASGD` for details.
    """

    for i, param in enumerate(params):
        grad = grads[i]
        mu = mus[i]
        ax = axs[i]
        eta = etas[i]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # decay term
        param.mul_(1 - lambd * eta)

        # update parameter
        param.add_(grad, alpha=-eta)

        # averaging
        if mu != 1:
            ax.add_(param.sub(ax).mul(mu))
        else:
            ax.copy_(param)

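# Illustrative usage sketch (not part of the upstream file): one ASGD step on a
# hypothetical tensor. mus and etas are per-parameter scalars maintained by
# torch.optim.ASGD; ax is the running parameter average. The values below are
# made up for the example.
#
#     >>> param = torch.randn(3)
#     >>> grad = torch.randn(3)
#     >>> ax = torch.zeros_like(param)
#     >>> asgd([param], [grad], [ax], [1.0], [1e-2],
#     ...      weight_decay=0.0, lambd=1e-4)
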
def nadam(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          mu_products: List[float],
          state_steps: List[int],
          *,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          momentum_decay: float,
          eps: float):
    r"""Functional API that performs NAdam algorithm computation.

    See :class:`~torch.optim.NAdam` for details.
    """

    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        mu_product = mu_products[i]
        step = state_steps[i]

        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # calculate the momentum cache \mu^{t} and \mu^{t+1}
        mu = beta1 * (1. - 0.5 * (0.96 ** (step * momentum_decay)))
        mu_next = beta1 * (1. - 0.5 * (0.96 ** ((step + 1) * momentum_decay)))
        mu_product = mu_product * mu
        # mu_product already includes mu, so the look-ahead product only needs mu_next
        mu_product_next = mu_product * mu_next

        # decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        denom = exp_avg_sq.div(bias_correction2).sqrt().add_(eps)
        param.addcdiv_(grad, denom, value=-lr * (1. - mu) / (1. - mu_product))
        param.addcdiv_(exp_avg, denom, value=-lr * mu_next / (1. - mu_product_next))

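# Illustrative usage sketch (not part of the upstream file): one NAdam step.
# mu_products carries the running product of the momentum cache values that
# torch.optim.NAdam keeps per parameter (1.0 before the first step).
#
#     >>> param = torch.randn(3)
#     >>> grad = torch.randn(3)
#     >>> exp_avg = torch.zeros_like(param)
#     >>> exp_avg_sq = torch.zeros_like(param)
#     >>> nadam([param], [grad], [exp_avg], [exp_avg_sq], [1.0], [1],
#     ...       beta1=0.9, beta2=0.999, lr=2e-3, weight_decay=0.0,
#     ...       momentum_decay=4e-3, eps=1e-8)
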
def radam(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          state_steps: List[int],
          *,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          eps: float):
    r"""Functional API that performs RAdam algorithm computation.

    See :class:`~torch.optim.RAdam` for details.
    """

    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # correcting bias for the first moving moment
        bias_corrected_exp_avg = exp_avg / bias_correction1

        # maximum length of the approximated SMA
        rho_inf = 2 / (1 - beta2) - 1
        # compute the length of the approximated SMA
        rho_t = rho_inf - 2 * step * (beta2 ** step) / bias_correction2

        if rho_t > 5.:
            # Compute the variance rectification term and update parameters accordingly
            rect = math.sqrt((rho_t - 4) * (rho_t - 2) * rho_inf / ((rho_inf - 4) * (rho_inf - 2) * rho_t))
            adaptive_lr = math.sqrt(bias_correction2) / exp_avg_sq.sqrt().add_(eps)

            param.add_(bias_corrected_exp_avg * lr * adaptive_lr * rect, alpha=-1.0)
        else:
            param.add_(bias_corrected_exp_avg * lr, alpha=-1.0)

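# Illustrative usage sketch (not part of the upstream file): one RAdam step on
# hypothetical tensors. At small step counts rho_t stays below the
# rectification threshold, so the update takes the unrectified branch above.
#
#     >>> param = torch.randn(3)
#     >>> grad = torch.randn(3)
#     >>> exp_avg = torch.zeros_like(param)
#     >>> exp_avg_sq = torch.zeros_like(param)
#     >>> radam([param], [grad], [exp_avg], [exp_avg_sq], [1],
#     ...       beta1=0.9, beta2=0.999, lr=1e-3, weight_decay=0.0, eps=1e-8)
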
def sparse_adam(params: List[Tensor],
                grads: List[Tensor],
                exp_avgs: List[Tensor],
                exp_avg_sqs: List[Tensor],
                state_steps: List[int],
                *,
                eps: float,
                beta1: float,
                beta2: float,
                lr: float):
    r"""Functional API that performs Sparse Adam algorithm computation.

    See :class:`~torch.optim.SparseAdam` for details.
    """
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad.coalesce()  # the update is non-linear so indices must be unique
        grad_indices = grad._indices()
        grad_values = grad._values()
        size = grad.size()

        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        def make_sparse(values):
            constructor = grad.new
            if grad_indices.dim() == 0 or values.dim() == 0:
                return constructor().resize_as_(grad)
            return constructor(grad_indices, values, size)

        # Decay the first and second moment running average coefficient
        #      old <- b * old + (1 - b) * new
        # <==> old += (1 - b) * (new - old)
        old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
        exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
        exp_avg.add_(make_sparse(exp_avg_update_values))
        old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values()
        exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
        exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))

        # Dense addition again is intended, avoiding another sparse_mask
        numer = exp_avg_update_values.add_(old_exp_avg_values)
        exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
        denom = exp_avg_sq_update_values.sqrt_().add_(eps)
        del exp_avg_update_values, exp_avg_sq_update_values

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        step_size = lr * math.sqrt(bias_correction2) / bias_correction1

        param.add_(make_sparse(-step_size * numer.div_(denom)))

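# Illustrative usage sketch (not part of the upstream file): one SparseAdam
# step on a hypothetical embedding-style parameter whose gradient is a sparse
# COO tensor; the moment estimates stay dense and are only updated at the
# entries that received gradient.
#
#     >>> param = torch.randn(4, 2)
#     >>> grad = torch.randn(4, 2).to_sparse()
#     >>> exp_avg = torch.zeros_like(param)
#     >>> exp_avg_sq = torch.zeros_like(param)
#     >>> sparse_adam([param], [grad], [exp_avg], [exp_avg_sq], [1],
#     ...             eps=1e-8, beta1=0.9, beta2=0.999, lr=1e-3)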