Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45221

This PR introduces a distributed functional optimizer, so that the distributed optimizer can reuse the functional optimizer APIs while maintaining its own state. This could enable a TorchScript-compatible functional optimizer when using the distributed optimizer, which helps get rid of the GIL and improves overall training performance, especially for distributed model-parallel training.

Test Plan: Imported from OSS
Reviewed By: ailzhang
Differential Revision: D23935256
Pulled By: wanchaol
fbshipit-source-id: 59b6d77ff4693ab24a6e1cbb6740bcf614cc624a
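As a rough editorial sketch of the pattern described above (not the class this PR actually adds), a wrapper object can own all per-parameter optimizer state itself and delegate the arithmetic to the stateless functional `adam` defined later in this file; a distributed optimizer can then hold such state per worker without a stateful `torch.optim.Adam` instance. The name `StatelessAdamWrapper` is hypothetical.

import torch

class StatelessAdamWrapper:
    """Hypothetical sketch: owns optimizer state, delegates math to the
    functional ``adam`` defined later in this file (assumed to be in scope)."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0.0, amsgrad=False):
        self.params = list(params)
        self.lr, self.betas, self.eps = lr, betas, eps
        self.weight_decay, self.amsgrad = weight_decay, amsgrad
        # Per-parameter state lives on the wrapper, not in the functional API.
        self.exp_avgs = [torch.zeros_like(p) for p in self.params]
        self.exp_avg_sqs = [torch.zeros_like(p) for p in self.params]
        self.max_exp_avg_sqs = [torch.zeros_like(p) for p in self.params]
        self.steps = [0] * len(self.params)

    @torch.no_grad()
    def step(self):
        params_with_grad, grads = [], []
        exp_avgs, exp_avg_sqs, max_exp_avg_sqs, steps = [], [], [], []
        for i, p in enumerate(self.params):
            if p.grad is None:
                continue
            self.steps[i] += 1  # the caller, not the kernel, tracks the step count
            params_with_grad.append(p)
            grads.append(p.grad)
            exp_avgs.append(self.exp_avgs[i])
            exp_avg_sqs.append(self.exp_avg_sqs[i])
            max_exp_avg_sqs.append(self.max_exp_avg_sqs[i])
            steps.append(self.steps[i])
        adam(params_with_grad, grads, exp_avgs, exp_avg_sqs,
             max_exp_avg_sqs, steps, self.amsgrad,
             self.betas[0], self.betas[1], self.lr,
             self.weight_decay, self.eps)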
r"""Functional interface"""
|
|
import math
|
|
import torch
|
|
from torch import Tensor
|
|
from typing import List
|
|
|
|
# TODO: use foreach API in optim.functional to do all the computation
|
|
|
|
def _make_sparse(grad, grad_indices, values):
|
|
size = grad.size()
|
|
if grad_indices.numel() == 0 or values.numel() == 0:
|
|
return torch.empty_like(grad)
|
|
return torch.sparse_coo_tensor(grad_indices, values, size)
|
|
|
|
|
|
def adagrad(params: List[Tensor],
            grads: List[Tensor],
            state_sums: List[Tensor],
            state_steps: List[int],
            lr: float,
            weight_decay: float,
            lr_decay: float,
            eps: float):
    r"""Functional API that performs Adagrad algorithm computation.

    See :class:`~torch.optim.Adagrad` for details.
    """

    for (param, grad, state_sum, step) in zip(params, grads, state_sums, state_steps):
        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError("weight_decay option is not compatible with sparse gradients")
            grad = grad.add(param, alpha=weight_decay)

        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()
            size = grad.size()

            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
        else:
            state_sum.addcmul_(grad, grad, value=1)
            std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)
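
# Editorial usage sketch (not part of the original module): the functional
# ``adagrad`` above is stateless, so the caller owns the parameter, the running
# sum of squared gradients, and the step count, and passes them in explicitly.
# ``_example_adagrad_step`` is a hypothetical name introduced for illustration.
def _example_adagrad_step():
    param = torch.randn(4, requires_grad=True)
    param.grad = torch.randn(4)          # stand-in for a real backward pass
    state_sum = torch.zeros(4)           # running sum of grad**2, owned by the caller
    step = 1                             # 1-based step count, owned by the caller
    with torch.no_grad():
        adagrad([param], [param.grad], [state_sum], [step],
                lr=0.1, weight_decay=0.0, lr_decay=0.0, eps=1e-10)
    return param
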
def adam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[int],
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float):
    r"""Functional API that performs Adam algorithm computation.

    See :class:`~torch.optim.Adam` for details.
    """

    for i, param in enumerate(params):

        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        # ``step`` counts the updates applied so far, including the current one;
        # the stateful optimizers increment it before calling, and it drives the
        # bias corrections below.
        step = state_steps[i]
        if amsgrad:
            max_exp_avg_sq = max_exp_avg_sqs[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)
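
# Editorial usage sketch (not part of the original module): driving the
# functional ``adam`` above for a few iterations. The caller increments the
# step count before each call, since the bias corrections use ``step``
# directly. ``_example_adam_steps`` is a hypothetical name.
def _example_adam_steps(num_iters=3):
    param = torch.randn(4, requires_grad=True)
    exp_avg = torch.zeros(4)         # first-moment state, owned by the caller
    exp_avg_sq = torch.zeros(4)      # second-moment state, owned by the caller
    max_exp_avg_sq = torch.zeros(4)  # only consulted when amsgrad=True
    step = 0
    for _ in range(num_iters):
        param.grad = torch.randn(4)  # stand-in for a real backward pass
        step += 1
        with torch.no_grad():
            adam([param], [param.grad], [exp_avg], [exp_avg_sq],
                 [max_exp_avg_sq], [step], amsgrad=True,
                 beta1=0.9, beta2=0.999, lr=1e-3,
                 weight_decay=0.0, eps=1e-8)
    return param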