To refactor multi tensor Rprop to functional API
ghstack-source-id: 0be0fb02a0
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60963
This commit is contained in:
parent 8098c3822b
commit 360e89fdf9
@@ -354,3 +354,42 @@ def sgd(params: List[Tensor],
         # foreach APIs dont support sparse
         for i in range(len(params)):
             params[i].add_(grads[i], alpha=-lr)
+
+
+def rprop(params: List[Tensor],
+          grads: List[Tensor],
+          states: List[Tensor],
+          step_sizes: List[int],
+          *,
+          step_size_max: float,
+          step_size_min: float,
+          etaminus: float,
+          etaplus: float):
+    r"""Functional API that performs Rprop algorithm computation.
+    See :class:`~torch.optim.Rprop` for details.
+    """
+
+    signs = torch._foreach_mul(grads, [s['prev'] for s in states])  # type: ignore[misc, index]
+    signs = [s.sign() for s in signs]
+    for sign in signs:
+        sign[sign.gt(0)] = etaplus
+        sign[sign.lt(0)] = etaminus
+        sign[sign.eq(0)] = 1
+
+    # update stepsizes with step size updates
+    torch._foreach_mul_(step_sizes, signs)  # type: ignore[arg-type]
+    for step_size in step_sizes:
+        step_size.clamp_(step_size_min, step_size_max)  # type: ignore[attr-defined]
+
+    # for dir<0, dfdx=0
+    # for dir>=0 dfdx=dfdx
+    for i in range(len(grads)):
+        grads[i] = grads[i].clone(memory_format=torch.preserve_format)
+        grads[i][signs[i].eq(etaminus)] = 0
+
+    # update parameters
+    grad_signs = [grad.sign() for grad in grads]
+    torch._foreach_addcmul_(params, grad_signs, step_sizes, value=-1)  # type: ignore[name-defined, arg-type]
+
+    for i in range(len(states)):
+        states[i]['prev'].copy_(grads[i])  # type: ignore[index]
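For readers tracing the new functional above: the torch._foreach_* calls are a batched form of the classic Rprop rule. The sketch below is not part of the commit; it replays the same update on a single invented parameter tensor with ordinary tensor ops, so the sign and step-size bookkeeping is easier to follow.

import torch

# Hypothetical toy data, not from the PR: one parameter, its current gradient,
# the previous gradient ('prev' in the optimizer state) and a per-element step size.
param = torch.tensor([0.5, -1.0, 2.0])
grad = torch.tensor([0.1, -0.2, 0.0])
prev = torch.tensor([0.2, 0.3, -0.1])
step_size = torch.full_like(param, 0.01)
etaplus, etaminus = 1.2, 0.5
step_size_min, step_size_max = 1e-6, 50.0

# The sign of grad * prev says whether the gradient kept its direction.
sign = (grad * prev).sign()
factor = torch.where(sign > 0, torch.tensor(etaplus),
                     torch.where(sign < 0, torch.tensor(etaminus), torch.tensor(1.0)))

# Grow or shrink the per-element step size, then clamp it.
step_size.mul_(factor).clamp_(step_size_min, step_size_max)

# Where the sign flipped (factor == etaminus), skip this element's update by zeroing its gradient.
grad = torch.where(factor == etaminus, torch.zeros_like(grad), grad)

# Move each element against the sign of its gradient by its own step size,
# then remember the (possibly zeroed) gradient for the next iteration.
param.addcmul_(grad.sign(), step_size, value=-1)
prev = grad.clone()

print(param, step_size)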
@@ -1,4 +1,5 @@
 import torch
+from . import _functional as F
 from ..optimizer import Optimizer
 from collections import defaultdict
 
@@ -38,16 +39,15 @@ class Rprop(Optimizer):
             with torch.enable_grad():
                 loss = closure()
 
-        grads = []
-        states = []
-        params_with_grad = []
-        step_sizes = []
-
         for group in self.param_groups:
-            for p in group['params']:
-                etaminus, etaplus = group['etas']
-                step_size_min, step_size_max = group['step_sizes']
+            grads = []
+            states = []
+            params_with_grad = []
+            step_sizes = []
+            etaminus, etaplus = group['etas']
+            step_size_min, step_size_max = group['step_sizes']
 
+            for p in group['params']:
                 if p.grad is not None:
                     if p.grad.is_sparse:
                         raise RuntimeError('RMSprop does not support sparse gradients')
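The hunk above moves the bookkeeping lists and the group['etas'] / group['step_sizes'] reads to the top of the per-group loop, so each parameter group is processed with its own hyperparameters. A minimal sketch of what such groups look like, using the public torch.optim.Rprop constructor (the parameter shapes and eta values here are invented for illustration; the multi-tensor variant touched by this PR takes the same arguments):

import torch
from torch import optim

# Two hypothetical parameter groups with different Rprop hyperparameters.
w1 = torch.nn.Parameter(torch.randn(3))
w2 = torch.nn.Parameter(torch.randn(3))

optimizer = optim.Rprop([
    {'params': [w1]},                      # uses the defaults given below
    {'params': [w2], 'etas': (0.4, 1.1)},  # overrides etas for this group
], etas=(0.5, 1.2), step_sizes=(1e-6, 50))

# Each entry in optimizer.param_groups carries its own 'etas' and 'step_sizes',
# which is what the refactored step() now reads once per group.
for group in optimizer.param_groups:
    print(group['etas'], group['step_sizes'])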
@@ -67,30 +67,14 @@ class Rprop(Optimizer):
                     states.append(state)
                     step_sizes.append(state['step_size'])
 
-        signs = torch._foreach_mul(grads, [s['prev'] for s in states])
-        signs = [s.sign() for s in signs]
-        for sign in signs:
-            sign[sign.gt(0)] = etaplus
-            sign[sign.lt(0)] = etaminus
-            sign[sign.eq(0)] = 1
-
-        # update stepsizes with step size updates
-        torch._foreach_mul_(step_sizes, signs)
-        for step_size in step_sizes:
-            step_size.clamp_(step_size_min, step_size_max)
-
-        # for dir<0, dfdx=0
-        # for dir>=0 dfdx=dfdx
-        for i in range(len(grads)):
-            grads[i] = grads[i].clone(memory_format=torch.preserve_format)
-            grads[i][signs[i].eq(etaminus)] = 0
-
-        # update parameters
-        grad_signs = [grad.sign() for grad in grads]
-        torch._foreach_addcmul_(params_with_grad, grad_signs, step_sizes, value=-1)
-
-        for i in range(len(states)):
-            states[i]['prev'].copy_(grads[i])
+            F.rprop(params_with_grad,
+                    grads,
+                    states,
+                    step_sizes,
+                    step_size_max=step_size_max,
+                    step_size_min=step_size_min,
+                    etaminus=etaminus,
+                    etaplus=etaplus)
 
         return loss
 
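Since this hunk only moves the update math into F.rprop without changing it, a quick way to sanity-check the refactor is a smoke run of a few optimization steps. The snippet below is a hedged sketch using the public torch.optim.Rprop on an invented quadratic objective; the multi-tensor variant touched by this PR is expected to trace the same trajectory, though that comparison is not shown here.

import torch

# Invented toy objective: drive a single parameter vector toward zero.
param = torch.nn.Parameter(torch.tensor([1.0, -2.0, 3.0]))
optimizer = torch.optim.Rprop([param], lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-6, 50))

for step in range(5):
    optimizer.zero_grad()
    loss = (param ** 2).sum()
    loss.backward()
    optimizer.step()
    print(f"step {step}: loss={loss.item():.4f}")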