pytorch/torch/optim/functional.py
Wanchao Liang 0444c372e1 [optimizer] introduce optimizer functional API, refactor Adagrad (#44715)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44715

We have provided a nice and intuitive optimizer API in Python. But in the context of large-scale distributed training (e.g. Distributed Model Parallel), users often want multithreaded rather than multiprocess training, as it provides better resource utilization and efficiency.

This PR introduces a functional optimizer concept (similar in spirit to `nn.functional`): we split the optimizer into two parts, 1. optimizer state management and 2. optimizer computation. The computation part is exposed as a separate functional API that both internal and OSS developers can use; callers of the functional API maintain their own state and invoke the computation directly. The end-user API stays the same, while the functional API is TorchScript friendly and can be used by the distributed optimizer to speed up training without being blocked by the GIL.
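A minimal sketch of that split, assuming the `adagrad` function added in this PR and an import path of `torch.optim.functional`; the wrapper class and its attribute names are illustrative only, not part of this PR:

import torch
from torch.optim import functional as OF  # import path assumed from this PR's file location

class CallerManagedAdagrad:
    # Illustrative wrapper: the caller owns the optimizer state (part 1)
    # and delegates the math to the functional API (part 2).
    def __init__(self, params, lr=0.01, weight_decay=0.0, lr_decay=0.0, eps=1e-10):
        self.params = list(params)
        self.lr, self.weight_decay, self.lr_decay, self.eps = lr, weight_decay, lr_decay, eps
        self.state_sums = [torch.zeros_like(p) for p in self.params]   # per-param sum of squared grads
        self.state_steps = [0 for _ in self.params]                    # per-param step counts

    def step(self):
        grads = [p.grad for p in self.params]  # assumes every param already has a .grad
        self.state_steps = [s + 1 for s in self.state_steps]
        # Stateless computation: everything it needs is passed in explicitly.
        OF.adagrad(self.params, grads, self.state_sums, self.state_steps,
                   lr=self.lr, weight_decay=self.weight_decay,
                   lr_decay=self.lr_decay, eps=self.eps)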

Test Plan: Imported from OSS

Reviewed By: ailzhang

Differential Revision: D23935258

Pulled By: wanchaol

fbshipit-source-id: d2a5228439edb3bc64f7771af2bb9e891847136a
2020-09-25 17:10:26 -07:00


r"""Functional interface"""
from torch import Tensor
from typing import List
# TODO: use foreach API in optim.functional to do all the computation
def adagrad(params: List[Tensor],
            grads: List[Tensor],
            state_sums: List[Tensor],
            state_steps: List[int],
            lr: float,
            weight_decay: float,
            lr_decay: float,
            eps: float):
    r"""Functional API that performs Adagrad algorithm computation.

    See :class:`~torch.optim.Adagrad` for details.
    """
    for (param, grad, state_sum, step) in zip(params, grads, state_sums, state_steps):
        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError("weight_decay option is not compatible with sparse gradients")
            grad = grad.add(param, alpha=weight_decay)

        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()
            size = grad.size()

            def make_sparse(values):
                constructor = grad.new
                if grad_indices.dim() == 0 or values.dim() == 0:
                    return constructor().resize_as_(grad)
                return constructor(grad_indices, values, size)
            state_sum.add_(make_sparse(grad_values.pow(2)))
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(make_sparse(grad_values / std_values), alpha=-clr)
        else:
            state_sum.addcmul_(grad, grad, value=1)
            std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)
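
A minimal usage sketch (not part of the module) showing that a single dense step reproduces the textbook Adagrad update; the tensor values, hyperparameters, and import path below are assumed for illustration:

import torch
from torch.optim.functional import adagrad  # import path assumed from this file's location

param = torch.tensor([1.0, 2.0])
grad = torch.tensor([0.5, -0.5])
state_sum = torch.zeros(2)   # running sum of squared gradients, owned by the caller
step = 1                     # step count after this update

adagrad([param], [grad], [state_sum], [step],
        lr=0.1, weight_decay=0.0, lr_decay=0.0, eps=1e-10)

# The dense branch above is equivalent to:
#   state_sum += grad * grad
#   param     -= 0.1 * grad / (state_sum.sqrt() + 1e-10)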