mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
To refactor multi tensor RMSprop to functional API
ghstack-source-id: 4e483b98f3
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60964
This commit is contained in:
parent
360e89fdf9
commit
b31031c562
|
|
@ -393,3 +393,45 @@ def rprop(params: List[Tensor],
|
||||||
|
|
||||||
for i in range(len(states)):
|
for i in range(len(states)):
|
||||||
states[i]['prev'].copy_(grads[i]) # type: ignore[index]
|
states[i]['prev'].copy_(grads[i]) # type: ignore[index]
|
||||||
|
|
||||||
|
|
||||||
|
def rmsprop(params: List[Tensor],
            grads: List[Tensor],
            states: List[Tensor],
            square_avg: List[Tensor],
            *,
            lr: float,
            alpha: float,
            eps: float,
            weight_decay: float,
            momentum: float,
            centered: bool):
    r"""Functional API that performs RMSprop algorithm computation.

    All updates are applied in place through ``torch._foreach_*`` ops; the
    function returns ``None``.

    See :class:`~torch.optim.RMSprop` for details.

    Args:
        params: tensors to update in place.
        grads: gradients, one per entry of ``params``. Modified in place
            when ``weight_decay`` is non-zero.
        states: per-parameter state, one per entry of ``params``. Each entry
            is indexed with string keys (``'grad_avg'`` when ``centered`` is
            set, ``'momentum_buffer'`` when ``momentum > 0``), so despite
            the annotation it is presumably a state dict -- confirm against
            the caller.
        square_avg: running average of squared gradients, updated in place.
        lr: step size applied to the update.
        alpha: decay factor for the running averages.
        eps: constant added to the denominator after the square root.
        weight_decay: coefficient of ``params`` added to ``grads``; ``0``
            disables it.
        momentum: momentum factor; values ``<= 0`` disable the buffer.
        centered: if true, subtract the squared running mean of the gradient
            from ``square_avg`` before taking the square root.
    """
    # Nothing to update; also avoids calling foreach ops on empty lists.
    if not params:
        return

    if weight_decay != 0:
        # grads <- grads + weight_decay * params (mutates the caller's grads).
        torch._foreach_add_(grads, params, alpha=weight_decay)  # type: ignore[name-defined]

    # square_avg <- alpha * square_avg + (1 - alpha) * grads ** 2
    torch._foreach_mul_(square_avg, alpha)
    torch._foreach_addcmul_(square_avg, grads, grads, value=1 - alpha)

    if centered:
        grad_avgs = [s['grad_avg'] for s in states]  # type: ignore[index]
        # grad_avg <- alpha * grad_avg + (1 - alpha) * grads
        torch._foreach_mul_(grad_avgs, alpha)
        torch._foreach_add_(grad_avgs, grads, alpha=1 - alpha)
        # Out-of-place so square_avg itself keeps the uncentered average.
        avg = torch._foreach_addcmul(square_avg, grad_avgs, grad_avgs, value=-1)
        torch._foreach_sqrt_(avg)
    else:
        avg = torch._foreach_sqrt(square_avg)
    # eps is added after the square root in both variants.
    torch._foreach_add_(avg, eps)

    if momentum > 0:
        buf = [s['momentum_buffer'] for s in states]  # type: ignore[index]
        # buf <- momentum * buf + grads / avg ; params <- params - lr * buf
        torch._foreach_mul_(buf, momentum)
        torch._foreach_addcdiv_(buf, grads, avg)
        torch._foreach_add_(params, buf, alpha=-lr)  # type: ignore[name-defined]
    else:
        # params <- params - lr * grads / avg
        torch._foreach_addcdiv_(params, grads, avg, value=-lr)  # type: ignore[name-defined]
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
|
from . import _functional as F
|
||||||
from ..optimizer import Optimizer
|
from ..optimizer import Optimizer
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
@ -95,30 +96,16 @@ class RMSprop(Optimizer):
|
||||||
states.append(state)
|
states.append(state)
|
||||||
square_avg.append(state['square_avg'])
|
square_avg.append(state['square_avg'])
|
||||||
|
|
||||||
if group['weight_decay'] != 0:
|
F.rmsprop(params_with_grad,
|
||||||
torch._foreach_add_(grads, params_with_grad, alpha=group['weight_decay'])
|
grads,
|
||||||
|
states,
|
||||||
torch._foreach_mul_(square_avg, alpha)
|
square_avg,
|
||||||
torch._foreach_addcmul_(square_avg, grads, grads, value=1 - alpha)
|
lr=group['lr'],
|
||||||
|
alpha=group['alpha'],
|
||||||
if group['centered']:
|
eps=group['eps'],
|
||||||
grad_avgs = [s['grad_avg'] for s in states]
|
weight_decay=group['weight_decay'],
|
||||||
torch._foreach_mul_(grad_avgs, alpha)
|
momentum=group['momentum'],
|
||||||
torch._foreach_add_(grad_avgs, grads, alpha=1 - alpha)
|
centered=group['centered'])
|
||||||
avg = torch._foreach_addcmul(square_avg, grad_avgs, grad_avgs, value=-1)
|
|
||||||
torch._foreach_sqrt_(avg)
|
|
||||||
torch._foreach_add_(avg, group['eps'])
|
|
||||||
else:
|
|
||||||
avg = torch._foreach_sqrt(square_avg)
|
|
||||||
torch._foreach_add_(avg, group['eps'])
|
|
||||||
|
|
||||||
if group['momentum'] > 0:
|
|
||||||
buf = [s['momentum_buffer'] for s in states]
|
|
||||||
torch._foreach_mul_(buf, group['momentum'])
|
|
||||||
torch._foreach_addcdiv_(buf, grads, avg)
|
|
||||||
torch._foreach_add_(params_with_grad, buf, alpha=-group['lr'])
|
|
||||||
else:
|
|
||||||
torch._foreach_addcdiv_(params_with_grad, grads, avg, value=-group['lr'])
|
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user