mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
To refactor multi tensor RMSprop to functional API
ghstack-source-id: 4e483b98f3
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60964
This commit is contained in:
parent
360e89fdf9
commit
b31031c562
|
|
@ -393,3 +393,45 @@ def rprop(params: List[Tensor],
|
|||
|
||||
for i in range(len(states)):
|
||||
states[i]['prev'].copy_(grads[i]) # type: ignore[index]
|
||||
|
||||
|
||||
def rmsprop(params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
states: List[Tensor],
|
||||
square_avg: List[Tensor],
|
||||
*,
|
||||
lr: float,
|
||||
alpha: float,
|
||||
eps: float,
|
||||
weight_decay: float,
|
||||
momentum: float,
|
||||
centered: float):
|
||||
r"""Functional API that performs RMSprop algorithm computation.
|
||||
|
||||
See :class:`~torch.optim.RMSprop` for details.
|
||||
"""
|
||||
|
||||
if weight_decay != 0:
|
||||
torch._foreach_add_(grads, params, alpha=weight_decay) # type: ignore[name-defined]
|
||||
|
||||
torch._foreach_mul_(square_avg, alpha)
|
||||
torch._foreach_addcmul_(square_avg, grads, grads, value=1 - alpha)
|
||||
|
||||
if centered:
|
||||
grad_avgs = [s['grad_avg'] for s in states] # type: ignore[index]
|
||||
torch._foreach_mul_(grad_avgs, alpha)
|
||||
torch._foreach_add_(grad_avgs, grads, alpha=1 - alpha)
|
||||
avg = torch._foreach_addcmul(square_avg, grad_avgs, grad_avgs, value=-1)
|
||||
torch._foreach_sqrt_(avg)
|
||||
torch._foreach_add_(avg, eps)
|
||||
else:
|
||||
avg = torch._foreach_sqrt(square_avg)
|
||||
torch._foreach_add_(avg, eps)
|
||||
|
||||
if momentum > 0:
|
||||
buf = [s['momentum_buffer'] for s in states] # type: ignore[index]
|
||||
torch._foreach_mul_(buf, momentum)
|
||||
torch._foreach_addcdiv_(buf, grads, avg)
|
||||
torch._foreach_add_(params, buf, alpha=-lr) # type: ignore[name-defined]
|
||||
else:
|
||||
torch._foreach_addcdiv_(params, grads, avg, value=-lr) # type: ignore[name-defined]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
from . import _functional as F
|
||||
from ..optimizer import Optimizer
|
||||
from collections import defaultdict
|
||||
|
||||
|
|
@ -95,30 +96,16 @@ class RMSprop(Optimizer):
|
|||
states.append(state)
|
||||
square_avg.append(state['square_avg'])
|
||||
|
||||
if group['weight_decay'] != 0:
|
||||
torch._foreach_add_(grads, params_with_grad, alpha=group['weight_decay'])
|
||||
|
||||
torch._foreach_mul_(square_avg, alpha)
|
||||
torch._foreach_addcmul_(square_avg, grads, grads, value=1 - alpha)
|
||||
|
||||
if group['centered']:
|
||||
grad_avgs = [s['grad_avg'] for s in states]
|
||||
torch._foreach_mul_(grad_avgs, alpha)
|
||||
torch._foreach_add_(grad_avgs, grads, alpha=1 - alpha)
|
||||
avg = torch._foreach_addcmul(square_avg, grad_avgs, grad_avgs, value=-1)
|
||||
torch._foreach_sqrt_(avg)
|
||||
torch._foreach_add_(avg, group['eps'])
|
||||
else:
|
||||
avg = torch._foreach_sqrt(square_avg)
|
||||
torch._foreach_add_(avg, group['eps'])
|
||||
|
||||
if group['momentum'] > 0:
|
||||
buf = [s['momentum_buffer'] for s in states]
|
||||
torch._foreach_mul_(buf, group['momentum'])
|
||||
torch._foreach_addcdiv_(buf, grads, avg)
|
||||
torch._foreach_add_(params_with_grad, buf, alpha=-group['lr'])
|
||||
else:
|
||||
torch._foreach_addcdiv_(params_with_grad, grads, avg, value=-group['lr'])
|
||||
F.rmsprop(params_with_grad,
|
||||
grads,
|
||||
states,
|
||||
square_avg,
|
||||
lr=group['lr'],
|
||||
alpha=group['alpha'],
|
||||
eps=group['eps'],
|
||||
weight_decay=group['weight_decay'],
|
||||
momentum=group['momentum'],
|
||||
centered=group['centered'])
|
||||
|
||||
return loss
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user