Refactor multi-tensor RMSprop to use the functional API

ghstack-source-id: 4e483b98f3
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60964
This commit is contained in:
ilqar ramazanli 2021-06-30 16:10:31 +00:00
parent 360e89fdf9
commit b31031c562
2 changed files with 53 additions and 24 deletions

View File

@ -393,3 +393,45 @@ def rprop(params: List[Tensor],
for i in range(len(states)):
states[i]['prev'].copy_(grads[i]) # type: ignore[index]
def rmsprop(params: List[Tensor],
            grads: List[Tensor],
            states: List[dict],
            square_avg: List[Tensor],
            *,
            lr: float,
            alpha: float,
            eps: float,
            weight_decay: float,
            momentum: float,
            centered: bool):
    r"""Functional API that performs RMSprop algorithm computation.

    See :class:`~torch.optim.RMSprop` for details.

    Args:
        params: tensors to update; modified in place.
        grads: gradients, one per entry of ``params``; never modified.
        states: per-parameter state mappings. ``'grad_avg'`` is read (and
            updated in place) when ``centered`` is truthy, and
            ``'momentum_buffer'`` is read (and updated in place) when
            ``momentum > 0``.
        square_avg: running averages of the squared gradients; updated
            in place.
        lr: step size applied to the parameter update.
        alpha: decay factor of the running averages.
        eps: constant added to the denominator after the square root.
        weight_decay: coefficient of ``params`` added to the gradients;
            the step is skipped when it equals 0.
        momentum: momentum factor; the momentum buffers are only used
            when it is greater than 0.
        centered: when truthy, the squared running average of the
            gradient is subtracted from ``square_avg`` before the square
            root is taken.
    """
    if not params:
        # Nothing to update; also avoids handing empty tensor lists to the
        # torch._foreach_* kernels.
        return

    if weight_decay != 0:
        # Out-of-place on purpose: an in-place add would overwrite the
        # caller's gradient tensors with the weight-decayed values.
        grads = torch._foreach_add(grads, params, alpha=weight_decay)  # type: ignore[name-defined]

    # square_avg = alpha * square_avg + (1 - alpha) * grad ** 2
    torch._foreach_mul_(square_avg, alpha)
    torch._foreach_addcmul_(square_avg, grads, grads, value=1 - alpha)

    if centered:
        grad_avgs = [s['grad_avg'] for s in states]
        # grad_avg = alpha * grad_avg + (1 - alpha) * grad
        torch._foreach_mul_(grad_avgs, alpha)
        torch._foreach_add_(grad_avgs, grads, alpha=1 - alpha)
        # New tensors (square_avg - grad_avg ** 2), so the in-place sqrt
        # below does not touch the stored square_avg state.
        avg = torch._foreach_addcmul(square_avg, grad_avgs, grad_avgs, value=-1)
        torch._foreach_sqrt_(avg)
    else:
        avg = torch._foreach_sqrt(square_avg)
    torch._foreach_add_(avg, eps)

    if momentum > 0:
        buf = [s['momentum_buffer'] for s in states]
        # buf = momentum * buf + grad / avg
        torch._foreach_mul_(buf, momentum)
        torch._foreach_addcdiv_(buf, grads, avg)
        torch._foreach_add_(params, buf, alpha=-lr)  # type: ignore[name-defined]
    else:
        torch._foreach_addcdiv_(params, grads, avg, value=-lr)  # type: ignore[name-defined]

View File

@ -1,4 +1,5 @@
import torch
from . import _functional as F
from ..optimizer import Optimizer
from collections import defaultdict
@ -95,30 +96,16 @@ class RMSprop(Optimizer):
states.append(state)
square_avg.append(state['square_avg'])
if group['weight_decay'] != 0:
torch._foreach_add_(grads, params_with_grad, alpha=group['weight_decay'])
torch._foreach_mul_(square_avg, alpha)
torch._foreach_addcmul_(square_avg, grads, grads, value=1 - alpha)
if group['centered']:
grad_avgs = [s['grad_avg'] for s in states]
torch._foreach_mul_(grad_avgs, alpha)
torch._foreach_add_(grad_avgs, grads, alpha=1 - alpha)
avg = torch._foreach_addcmul(square_avg, grad_avgs, grad_avgs, value=-1)
torch._foreach_sqrt_(avg)
torch._foreach_add_(avg, group['eps'])
else:
avg = torch._foreach_sqrt(square_avg)
torch._foreach_add_(avg, group['eps'])
if group['momentum'] > 0:
buf = [s['momentum_buffer'] for s in states]
torch._foreach_mul_(buf, group['momentum'])
torch._foreach_addcdiv_(buf, grads, avg)
torch._foreach_add_(params_with_grad, buf, alpha=-group['lr'])
else:
torch._foreach_addcdiv_(params_with_grad, grads, avg, value=-group['lr'])
F.rmsprop(params_with_grad,
grads,
states,
square_avg,
lr=group['lr'],
alpha=group['alpha'],
eps=group['eps'],
weight_decay=group['weight_decay'],
momentum=group['momentum'],
centered=group['centered'])
return loss