pytorch/torch/optim/_functional.py
Ilqar Ramazanli 5ed6e4429e To fix variance computation for complex Adam (#62946)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/59998

As discussed in the issue, the variance term of the Adam optimizer is currently not computed correctly in the complex domain. As stated in the "Generalization to complex numbers" section of https://en.wikipedia.org/wiki/Variance, the variance of a complex random variable X is computed as E[(X - mu)(X - mu)*], where mu = E[X] and * denotes the complex conjugate.

However, the current Adam implementation computes E[(X - mu)(X - mu)], which does not yield the correct variance; in particular, it returns a complex number. Variance is defined to be a real number even when the underlying random variable is complex.

We fix this issue here and test that the resulting variance is indeed a real number.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62946

Reviewed By: albanD

Differential Revision: D30196038

Pulled By: iramazanli

fbshipit-source-id: ab0a6f31658aeb56bdcb211ff86eaa29f3f0d718
2021-08-09 17:54:43 -07:00

510 lines
16 KiB
Python

r"""Functional interface"""
import math
import torch
from torch import Tensor
from typing import List, Optional
# TODO: use foreach API in optim._functional to do all the computation
def _make_sparse(grad, grad_indices, values):
size = grad.size()
if grad_indices.numel() == 0 or values.numel() == 0:
return torch.empty_like(grad)
return torch.sparse_coo_tensor(grad_indices, values, size)
def adagrad(params: List[Tensor],
            grads: List[Tensor],
            state_sums: List[Tensor],
            state_steps: List[int],
            *,
            lr: float,
            weight_decay: float,
            lr_decay: float,
            eps: float):
    r"""Functional API that performs Adagrad algorithm computation.

    See :class:`~torch.optim.Adagrad` for details.

    Args:
        params: tensors to update; each one is modified in place.
        grads: gradient for each entry of ``params`` (dense or sparse).
        state_sums: running sum of squared gradients per parameter; updated
            in place.
        state_steps: step counter per parameter, used to decay ``lr``.
        lr: base learning rate.
        weight_decay: coefficient of the ``param`` term added to the gradient.
        lr_decay: learning-rate decay applied as
            ``lr / (1 + (step - 1) * lr_decay)``.
        eps: constant added to the square root of the accumulated sum.

    Raises:
        RuntimeError: if ``weight_decay`` is non-zero and a gradient is sparse.
    """
    for (param, grad, state_sum, step) in zip(params, grads, state_sums, state_steps):
        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError("weight_decay option is not compatible with sparse gradients")
            grad = grad.add(param, alpha=weight_decay)

        # Learning rate after decay for this step.
        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()

            # Accumulate squared gradient values only at the touched indices.
            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            # Read the accumulated sums back at those same indices.
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
        else:
            state_sum.addcmul_(grad, grad, value=1)
            std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)
def adam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[int],
         *,
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float):
    r"""Functional API that performs Adam algorithm computation.

    See :class:`~torch.optim.Adam` for details.

    Every tensor in ``params``, ``exp_avgs`` and ``exp_avg_sqs`` is updated
    in place. ``max_exp_avg_sqs`` is read and updated only when ``amsgrad``
    is true.
    """
    for idx, param in enumerate(params):
        step = state_steps[idx]
        first_moment = exp_avgs[idx]
        second_moment = exp_avg_sqs[idx]
        grad = grads[idx]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Exponential moving averages of the gradient and of grad * conj(grad).
        first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)
        second_moment.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

        if amsgrad:
            # Keep the element-wise maximum of all second moments seen so far
            # and normalize with it instead of the current estimate.
            running_max = max_exp_avg_sqs[idx]
            torch.maximum(running_max, second_moment, out=running_max)
            normalizer = running_max
        else:
            normalizer = second_moment

        bias_correction2 = 1 - beta2 ** step
        denom = (normalizer.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        bias_correction1 = 1 - beta1 ** step
        step_size = lr / bias_correction1
        param.addcdiv_(first_moment, denom, value=-step_size)
def adamw(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          max_exp_avg_sqs: List[Tensor],
          state_steps: List[int],
          *,
          amsgrad: bool,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          eps: float):
    r"""Functional API that performs AdamW algorithm computation.

    See :class:`~torch.optim.AdamW` for details.

    Args:
        params: tensors to update; each one is modified in place.
        grads: gradient for each entry of ``params``.
        exp_avgs: running average of the gradient per parameter (in place).
        exp_avg_sqs: running average of ``grad * conj(grad)`` per parameter
            (in place).
        max_exp_avg_sqs: running element-wise maximum of ``exp_avg_sqs``;
            only read and updated when ``amsgrad`` is true.
        state_steps: step counter per parameter, used for bias correction.
        amsgrad: normalize with the running maximum of the second moment.
        beta1: decay rate of the first-moment average.
        beta2: decay rate of the second-moment average.
        lr: learning rate.
        weight_decay: decoupled weight-decay coefficient.
        eps: constant added to the denominator.
    """
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        # Perform stepweight decay (decoupled from the gradient).
        param.mul_(1 - lr * weight_decay)

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        # Decay the first and second moment running average coefficient.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        # Multiply by the conjugate so the second moment is |grad|^2: this
        # keeps it real-valued for complex gradients and is identical to
        # grad * grad for real gradients.
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1
        param.addcdiv_(exp_avg, denom, value=-step_size)
def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.

    Parameters are updated in place. When ``momentum`` is non-zero, a
    ``None`` entry in ``momentum_buffer_list`` is replaced by a detached
    clone of the current update; existing buffers are updated in place.
    """
    for idx, param in enumerate(params):
        update = d_p_list[idx]
        if weight_decay != 0:
            update = update.add(param, alpha=weight_decay)

        if momentum != 0:
            velocity = momentum_buffer_list[idx]
            if velocity is not None:
                velocity.mul_(momentum).add_(update, alpha=1 - dampening)
            else:
                # First step for this parameter: seed the buffer with the update.
                velocity = torch.clone(update).detach()
                momentum_buffer_list[idx] = velocity

            update = update.add(velocity, alpha=momentum) if nesterov else velocity

        param.add_(update, alpha=-lr)
def adadelta(params: List[Tensor],
             grads: List[Tensor],
             square_avgs: List[Tensor],
             acc_deltas: List[Tensor],
             *,
             lr: float,
             rho: float,
             eps: float,
             weight_decay: float):
    r"""Functional API that performs Adadelta algorithm computation.

    See :class:`~torch.optim.Adadelta` for details.

    ``params``, ``square_avgs`` and ``acc_deltas`` are all updated in place.
    """
    for param, grad, square_avg, acc_delta in zip(params, grads, square_avgs, acc_deltas):
        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Running average of squared gradients.
        square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)

        rms_grad = square_avg.add(eps).sqrt_()
        rms_delta = acc_delta.add(eps).sqrt_()
        delta = rms_delta.div_(rms_grad).mul_(grad)

        param.add_(delta, alpha=-lr)
        # Running average of squared updates, fed into the next step.
        acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)
def rmsprop(params: List[Tensor],
            grads: List[Tensor],
            square_avgs: List[Tensor],
            grad_avgs: List[Tensor],
            momentum_buffer_list: List[Tensor],
            *,
            lr: float,
            alpha: float,
            eps: float,
            weight_decay: float,
            momentum: float,
            centered: bool):
    r"""Functional API that performs rmsprop algorithm computation.

    See :class:`~torch.optim.RMSprop` for details.

    ``params`` and ``square_avgs`` are updated in place. ``grad_avgs`` is
    read and updated only when ``centered`` is true, and
    ``momentum_buffer_list`` only when ``momentum`` is greater than zero.
    """
    for idx, param in enumerate(params):
        grad = grads[idx]
        square_avg = square_avgs[idx]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Running average of squared gradients.
        square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

        if not centered:
            avg = square_avg.sqrt().add_(eps)
        else:
            # Subtract the squared running mean of the gradient before the root.
            grad_avg = grad_avgs[idx]
            grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
            avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(eps)

        if momentum > 0:
            buf = momentum_buffer_list[idx]
            buf.mul_(momentum).addcdiv_(grad, avg)
            param.add_(buf, alpha=-lr)
            continue

        param.addcdiv_(grad, avg, value=-lr)
def rprop(params: List[Tensor],
          grads: List[Tensor],
          prevs: List[Tensor],
          step_sizes: List[Tensor],
          *,
          step_size_min: float,
          step_size_max: float,
          etaminus: float,
          etaplus: float):
    r"""Functional API that performs rprop algorithm computation.

    See :class:`~torch.optim.Rprop` for details.

    ``params``, ``prevs`` and ``step_sizes`` are updated in place. ``grads``
    is not modified; a clone is zeroed where the step size was scaled by
    ``etaminus``.
    """
    for idx, param in enumerate(params):
        grad = grads[idx]
        prev = prevs[idx]
        step_size = step_sizes[idx]

        # Per-element multiplier chosen from the sign of grad * prev. The
        # three masks are applied one after another, each on the current
        # contents of ``factor``.
        factor = grad.mul(prev).sign()
        factor[factor > 0] = etaplus
        factor[factor < 0] = etaminus
        factor[factor == 0] = 1

        # update stepsizes with step size updates
        step_size.mul_(factor).clamp_(step_size_min, step_size_max)

        # Zero the gradient where the step size was scaled by etaminus, so
        # those elements are not moved and ``prev`` stores 0 for them.
        grad = grad.clone(memory_format=torch.preserve_format)
        grad[factor == etaminus] = 0

        # update parameters
        param.addcmul_(grad.sign(), step_size, value=-1)
        prev.copy_(grad)
def adamax(params: List[Tensor],
           grads: List[Tensor],
           exp_avgs: List[Tensor],
           exp_infs: List[Tensor],
           state_steps: List[int],
           *,
           eps: float,
           beta1: float,
           beta2: float,
           lr: float,
           weight_decay: float):
    r"""Functional API that performs adamax algorithm computation.

    See :class:`~torch.optim.Adamax` for details.

    ``params``, ``exp_avgs`` and ``exp_infs`` are updated in place.
    """
    for idx, param in enumerate(params):
        grad = grads[idx]
        exp_avg = exp_avgs[idx]
        exp_inf = exp_infs[idx]
        step = state_steps[idx]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Update biased first moment estimate.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

        # Update the exponentially weighted infinity norm: element-wise max of
        # the decayed previous norm and |grad| + eps, written back to exp_inf.
        decayed_norm = exp_inf.mul_(beta2).unsqueeze(0)
        grad_magnitude = grad.abs().add_(eps).unsqueeze_(0)
        stacked = torch.cat([decayed_norm, grad_magnitude], 0)
        torch.amax(stacked, 0, keepdim=False, out=exp_inf)

        clr = lr / (1 - beta1 ** step)
        param.addcdiv_(exp_avg, exp_inf, value=-clr)
def asgd(params: List[Tensor],
         grads: List[Tensor],
         axs: List[Tensor],
         mus: List[float],
         etas: List[float],
         *,
         weight_decay: float,
         lambd: float):
    r"""Functional API that performs asgd algorithm computation.

    See :class:`~torch.optim.ASGD` for details.

    ``params`` and ``axs`` are updated in place; ``mus`` and ``etas`` are
    only read.
    """
    for idx, param in enumerate(params):
        grad = grads[idx]
        mu = mus[idx]
        ax = axs[idx]
        eta = etas[idx]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # decay term
        param.mul_(1 - lambd * eta)

        # update parameter
        param.add_(grad, alpha=-eta)

        # averaging: with mu == 1 the average simply tracks the parameter
        if mu == 1:
            ax.copy_(param)
        else:
            ax.add_(param.sub(ax).mul(mu))
def nadam(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          mu_products: List[float],
          state_steps: List[int],
          *,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          momentum_decay: float,
          eps: float):
    r"""Functional API that performs NAdam algorithm computation.

    See :class:`~torch.optim.NAdam` for details.

    Args:
        params: tensors to update; each one is modified in place.
        grads: gradient for each entry of ``params``.
        exp_avgs: running average of the gradient per parameter (in place).
        exp_avg_sqs: running average of the squared gradient per parameter
            (in place).
        mu_products: per-parameter product of the momentum coefficients of
            the earlier steps. It is only read here; the updated product is
            not written back to the list.
        state_steps: step counter per parameter.
        beta1: base momentum coefficient / first-moment decay rate.
        beta2: second-moment decay rate.
        lr: learning rate.
        weight_decay: coefficient of the ``param`` term added to the gradient.
        momentum_decay: exponent scale of the momentum schedule.
        eps: constant added to the denominator.
    """
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        mu_product = mu_products[i]
        step = state_steps[i]

        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # calculate the momentum cache \mu^{t} and \mu^{t+1}
        mu = beta1 * (1. - 0.5 * (0.96 ** (step * momentum_decay)))
        mu_next = beta1 * (1. - 0.5 * (0.96 ** ((step + 1) * momentum_decay)))

        # Product of momentum coefficients up to and including this step.
        mu_product = mu_product * mu
        # Product up to step t+1. ``mu_product`` already contains ``mu``, so
        # only ``mu_next`` is multiplied in here.
        mu_product_next = mu_product * mu_next

        # decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        denom = exp_avg_sq.div(bias_correction2).sqrt().add_(eps)
        # Gradient term, then the look-ahead momentum term.
        param.addcdiv_(grad, denom, value=-lr * (1. - mu) / (1. - mu_product))
        param.addcdiv_(exp_avg, denom, value=-lr * mu_next / (1. - mu_product_next))
def radam(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          state_steps: List[int],
          *,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          eps: float):
    r"""Functional API that performs RAdam algorithm computation.

    See :class:`~torch.optim.RAdam` for details.

    ``params``, ``exp_avgs`` and ``exp_avg_sqs`` are updated in place. The
    adaptive, variance-rectified update is applied only when the approximated
    SMA length ``rho_t`` exceeds 5; otherwise the bias-corrected first moment
    scaled by ``lr`` is used on its own.
    """
    for idx, param in enumerate(params):
        grad = grads[idx]
        exp_avg = exp_avgs[idx]
        exp_avg_sq = exp_avg_sqs[idx]
        step = state_steps[idx]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # correcting bias for the first moving moment
        debiased_avg = exp_avg / bias_correction1

        # maximum length of the approximated SMA
        rho_inf = 2 / (1 - beta2) - 1
        # compute the length of the approximated SMA
        rho_t = rho_inf - 2 * step * (beta2 ** step) / bias_correction2

        if not rho_t > 5.:
            # Unrectified step for this parameter.
            param.add_(debiased_avg * lr, alpha=-1.0)
            continue

        # Compute the variance rectification term and update parameters accordingly
        rect_numer = (rho_t - 4) * (rho_t - 2) * rho_inf
        rect_denom = (rho_inf - 4) * (rho_inf - 2) * rho_t
        rect = math.sqrt(rect_numer / rect_denom)
        adaptive_lr = math.sqrt(bias_correction2) / exp_avg_sq.sqrt().add_(eps)
        param.add_(debiased_avg * lr * adaptive_lr * rect, alpha=-1.0)
def sparse_adam(params: List[Tensor],
                grads: List[Tensor],
                exp_avgs: List[Tensor],
                exp_avg_sqs: List[Tensor],
                state_steps: List[int],
                *,
                eps: float,
                beta1: float,
                beta2: float,
                lr: float):
    r"""Functional API that performs Sparse Adam algorithm computation.

    See :class:`~torch.optim.SparseAdam` for details.

    Each gradient is coalesced and only the entries at its indices are
    updated in ``params``, ``exp_avgs`` and ``exp_avg_sqs`` (all in place).
    """
    for idx, param in enumerate(params):
        grad = grads[idx].coalesce()  # the update is non-linear so indices must be unique
        grad_indices = grad._indices()
        grad_values = grad._values()
        size = grad.size()

        exp_avg = exp_avgs[idx]
        exp_avg_sq = exp_avg_sqs[idx]
        step = state_steps[idx]

        def make_sparse(values):
            # Wrap ``values`` in a sparse tensor laid out like ``grad``.
            constructor = grad.new
            if grad_indices.dim() == 0 or values.dim() == 0:
                return constructor().resize_as_(grad)
            return constructor(grad_indices, values, size)

        # Decay the first and second moment running average coefficient
        #      old <- b * old + (1 - b) * new
        # <==> old += (1 - b) * (new - old)
        prev_avg = exp_avg.sparse_mask(grad)._values()
        avg_delta = grad_values.sub(prev_avg).mul_(1 - beta1)
        exp_avg.add_(make_sparse(avg_delta))

        prev_avg_sq = exp_avg_sq.sparse_mask(grad)._values()
        avg_sq_delta = grad_values.pow(2).sub_(prev_avg_sq).mul_(1 - beta2)
        exp_avg_sq.add_(make_sparse(avg_sq_delta))

        # Dense addition again is intended, avoiding another sparse_mask.
        # Adding the old values back onto the deltas (in place) rebuilds the
        # new moment values at the gradient's indices.
        numer = avg_delta.add_(prev_avg)
        avg_sq_delta.add_(prev_avg_sq)
        denom = avg_sq_delta.sqrt_().add_(eps)
        del avg_delta, avg_sq_delta

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        step_size = lr * math.sqrt(bias_correction2) / bias_correction1

        param.add_(make_sparse(-step_size * numer.div_(denom)))