Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Refactor multi-tensor Rprop to the functional API
ghstack-source-id: 0be0fb02a0
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60963
This commit is contained in:
parent 8098c3822b
commit 360e89fdf9
@@ -354,3 +354,42 @@ def sgd(params: List[Tensor],
     # foreach APIs dont support sparse
     for i in range(len(params)):
         params[i].add_(grads[i], alpha=-lr)
+
+
+def rprop(params: List[Tensor],
+          grads: List[Tensor],
+          states: List[Tensor],
+          step_sizes: List[int],
+          *,
+          step_size_max: float,
+          step_size_min: float,
+          etaminus: float,
+          etaplus: float):
+    r"""Functional API that performs Rprop algorithm computation.
+    See :class:`~torch.optim.Rprop` for details.
+    """
+
+    signs = torch._foreach_mul(grads, [s['prev'] for s in states])  # type: ignore[misc, index]
+    signs = [s.sign() for s in signs]
+    for sign in signs:
+        sign[sign.gt(0)] = etaplus
+        sign[sign.lt(0)] = etaminus
+        sign[sign.eq(0)] = 1
+
+    # update stepsizes with step size updates
+    torch._foreach_mul_(step_sizes, signs)  # type: ignore[arg-type]
+    for step_size in step_sizes:
+        step_size.clamp_(step_size_min, step_size_max)  # type: ignore[attr-defined]
+
+    # for dir<0, dfdx=0
+    # for dir>=0 dfdx=dfdx
+    for i in range(len(grads)):
+        grads[i] = grads[i].clone(memory_format=torch.preserve_format)
+        grads[i][signs[i].eq(etaminus)] = 0
+
+    # update parameters
+    grad_signs = [grad.sign() for grad in grads]
+    torch._foreach_addcmul_(params, grad_signs, step_sizes, value=-1)  # type: ignore[name-defined, arg-type]
+
+    for i in range(len(states)):
+        states[i]['prev'].copy_(grads[i])  # type: ignore[index]
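For orientation, the block below is a per-tensor sketch of the update that the new foreach-based rprop performs, written with only public tensor ops. The helper name rprop_single and its default hyperparameters (taken from torch.optim.Rprop's documented defaults) are illustrative and not part of this commit.

import torch

def rprop_single(param, grad, prev_grad, step_size, *,
                 step_size_min=1e-6, step_size_max=50.0,
                 etaminus=0.5, etaplus=1.2):
    # Sign agreement between the current and previous gradient picks the factor:
    # etaplus where the sign held, etaminus where it flipped, 1.0 where either is zero.
    sign = grad.mul(prev_grad).sign()
    factor = torch.where(sign.gt(0), torch.full_like(sign, etaplus),
             torch.where(sign.lt(0), torch.full_like(sign, etaminus),
                         torch.ones_like(sign)))

    # Grow/shrink the per-element step size and clamp it (the list-wide
    # _foreach_mul_ / clamp_ pair in the diff does the same thing in bulk).
    step_size.mul_(factor).clamp_(step_size_min, step_size_max)

    # Where the sign flipped (dir < 0), skip this step: dfdx = 0.
    grad = grad.clone()
    grad[factor.eq(etaminus)] = 0

    # Step against the gradient sign by the adapted step size and remember
    # the (masked) gradient for the next call.
    param.addcmul_(grad.sign(), step_size, value=-1)
    prev_grad.copy_(grad)

# Toy call: one Rprop update on a single 4-element parameter.
p = torch.randn(4)
rprop_single(p, torch.randn(4), torch.zeros(4), torch.full((4,), 1e-2))

The rprop function added above carries out the same arithmetic, but batches it across the whole parameter list with torch._foreach_mul, torch._foreach_mul_ and torch._foreach_addcmul_.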
@@ -1,4 +1,5 @@
 import torch
+from . import _functional as F
 from ..optimizer import Optimizer
 from collections import defaultdict
 
@@ -38,16 +39,15 @@ class Rprop(Optimizer):
             with torch.enable_grad():
                 loss = closure()
 
-        grads = []
-        states = []
-        params_with_grad = []
-        step_sizes = []
-
         for group in self.param_groups:
-            for p in group['params']:
-                etaminus, etaplus = group['etas']
-                step_size_min, step_size_max = group['step_sizes']
+            grads = []
+            states = []
+            params_with_grad = []
+            step_sizes = []
+            etaminus, etaplus = group['etas']
+            step_size_min, step_size_max = group['step_sizes']
 
+            for p in group['params']:
                 if p.grad is not None:
                     if p.grad.is_sparse:
                         raise RuntimeError('RMSprop does not support sparse gradients')
@@ -67,30 +67,14 @@ class Rprop(Optimizer):
                     states.append(state)
                     step_sizes.append(state['step_size'])
 
-            signs = torch._foreach_mul(grads, [s['prev'] for s in states])
-            signs = [s.sign() for s in signs]
-            for sign in signs:
-                sign[sign.gt(0)] = etaplus
-                sign[sign.lt(0)] = etaminus
-                sign[sign.eq(0)] = 1
-
-            # update stepsizes with step size updates
-            torch._foreach_mul_(step_sizes, signs)
-            for step_size in step_sizes:
-                step_size.clamp_(step_size_min, step_size_max)
-
-            # for dir<0, dfdx=0
-            # for dir>=0 dfdx=dfdx
-            for i in range(len(grads)):
-                grads[i] = grads[i].clone(memory_format=torch.preserve_format)
-                grads[i][signs[i].eq(etaminus)] = 0
-
-            # update parameters
-            grad_signs = [grad.sign() for grad in grads]
-            torch._foreach_addcmul_(params_with_grad, grad_signs, step_sizes, value=-1)
-
-            for i in range(len(states)):
-                states[i]['prev'].copy_(grads[i])
+            F.rprop(params_with_grad,
+                    grads,
+                    states,
+                    step_sizes,
+                    step_size_max=step_size_max,
+                    step_size_min=step_size_min,
+                    etaminus=etaminus,
+                    etaplus=etaplus)
 
         return loss
 
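The refactor leaves the optimizer's calling convention untouched: step() still gathers per-group tensors and now delegates the math to F.rprop. Below is a minimal, hypothetical training-step sketch using the public torch.optim.Rprop constructor (same signature as the multi-tensor class edited here); the model, data and loss are placeholders.

import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.Rprop(model.parameters(), lr=1e-2,
                        etas=(0.5, 1.2), step_sizes=(1e-6, 50))

x, target = torch.randn(8, 4), torch.randn(8, 2)

opt.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), target)
loss.backward()
opt.step()  # one optimization step; in the multi-tensor class above this now dispatches to F.rprop per param group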