Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44715

We already provide a nice and intuitive optimizer API in Python. But in the context of large-scale distributed training (e.g. Distributed Model Parallel), users often want multithreaded rather than multiprocess training, since it gives better resource utilization and efficiency. This PR introduces the functional optimizer concept (similar in spirit to `nn.functional`): the optimizer is split into two parts, 1. optimizer state management and 2. optimizer computation. The computation part is exposed as a separate functional API available to internal and OSS developers; callers of the functional API maintain their own state and call the functional API directly. The end-user API stays the same, while the functional API is TorchScript friendly and can be used by the distributed optimizer to speed up training without the GIL.

Test Plan: Imported from OSS

Reviewed By: ailzhang

Differential Revision: D23935258

Pulled By: wanchaol

fbshipit-source-id: d2a5228439edb3bc64f7771af2bb9e891847136a
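To make the split concrete, the snippet below is a minimal, hypothetical sketch (not part of the PR) of driving the computation part directly while owning the state yourself. It assumes the functional module is importable as `torch.optim.functional` (the file below imports it as `from . import functional as F`) and that `F.adagrad` keeps the positional signature used in `Adagrad.step` below.

import torch
from torch.optim import functional as F  # assumed import path for the functional module

params = [torch.randn(4, requires_grad=True)]

# Caller-owned state, mirroring what the Adagrad class keeps in self.state:
# a running sum of squared gradients and a step counter per parameter.
state_sums = [torch.zeros_like(p) for p in params]
state_steps = [0]

loss = (params[0] ** 2).sum()
loss.backward()

state_steps[0] += 1
with torch.no_grad():  # the class wraps step() in @torch.no_grad() for the same reason
    F.adagrad(params,
              [p.grad for p in params],
              state_sums,
              state_steps,
              1e-2,    # lr
              0.0,     # weight_decay
              0.0,     # lr_decay
              1e-10)   # eps

Because the state lives with the caller rather than in a Python optimizer object, the computation can be scripted with TorchScript and, per the summary above, driven by the distributed optimizer without being serialized on the GIL.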
93 lines
3.5 KiB
Python
import torch
from . import functional as F
from .optimizer import Optimizer


class Adagrad(Optimizer):
    """Implements Adagrad algorithm.

    It has been proposed in `Adaptive Subgradient Methods for Online Learning
    and Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        lr_decay (float, optional): learning rate decay (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-10)

    .. _Adaptive Subgradient Methods for Online Learning and Stochastic
        Optimization: http://jmlr.org/papers/v12/duchi11a.html
    """

    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= lr_decay:
            raise ValueError("Invalid lr_decay value: {}".format(lr_decay))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= initial_accumulator_value:
            raise ValueError("Invalid initial_accumulator_value value: {}".format(initial_accumulator_value))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))

        defaults = dict(lr=lr, lr_decay=lr_decay, eps=eps, weight_decay=weight_decay,
                        initial_accumulator_value=initial_accumulator_value)
        super(Adagrad, self).__init__(params, defaults)

        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state['sum'] = torch.full_like(p, initial_accumulator_value, memory_format=torch.preserve_format)

    def share_memory(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['sum'].share_memory_()

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            state_sums = []
            state_steps = []

            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    grads.append(p.grad)
                    state = self.state[p]
                    state_sums.append(state['sum'])
                    # update the steps for each param group update
                    state['step'] += 1
                    # record the step after step update
                    state_steps.append(state['step'])

            F.adagrad(params_with_grad,
                      grads,
                      state_sums,
                      state_steps,
                      group['lr'],
                      group['weight_decay'],
                      group['lr_decay'],
                      group['eps'])

        return loss
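For comparison, the end-user API is unchanged: state management stays inside the optimizer and `step()` hands the per-group tensors to `F.adagrad`. The following is a generic usage sketch (not part of the file above) showing the closure form of `step()` described in its docstring.

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adagrad(model.parameters(), lr=1e-2, lr_decay=0, weight_decay=0, eps=1e-10)

x, y = torch.randn(32, 10), torch.randn(32, 1)

def closure():
    # Re-evaluate the model and return the loss, as step() expects.
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = optimizer.step(closure)  # the closure runs under torch.enable_grad() inside step()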