Refactor multi-tensor Rprop to use the functional API

ghstack-source-id: 0be0fb02a0
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60963
ilqar ramazanli 2021-06-30 16:08:32 +00:00
parent 8098c3822b
commit 360e89fdf9
2 changed files with 55 additions and 32 deletions
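
For reference, the update rule that the new functional kernel below vectorizes with torch._foreach_* ops looks roughly like this for a single parameter. This is a minimal illustrative sketch, not part of the change; the helper name is made up and the defaults are the ones documented for torch.optim.Rprop.

import torch

def rprop_single_tensor(param, grad, prev_grad, step_size,
                        step_size_min=1e-6, step_size_max=50.0,
                        etaminus=0.5, etaplus=1.2):
    # the sign of grad * prev_grad says whether each element kept its direction
    sign = (grad * prev_grad).sign()
    # grow the per-element step size where the sign agreed, shrink it where it flipped
    factor = torch.where(sign > 0, torch.full_like(sign, etaplus),
                         torch.where(sign < 0, torch.full_like(sign, etaminus),
                                     torch.ones_like(sign)))
    step_size.mul_(factor).clamp_(step_size_min, step_size_max)
    # where the sign flipped, suppress this step's update (dfdx = 0)
    grad = torch.where(sign < 0, torch.zeros_like(grad), grad)
    # move against the sign of the (masked) gradient by the adapted step size
    param.addcmul_(grad.sign(), step_size, value=-1)
    return grad  # stored as prev_grad for the next iteration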

Changed file 1 of 2: the shared _functional helpers module

@@ -354,3 +354,42 @@ def sgd(params: List[Tensor],
         # foreach APIs dont support sparse
         for i in range(len(params)):
             params[i].add_(grads[i], alpha=-lr)
+
+
+def rprop(params: List[Tensor],
+          grads: List[Tensor],
+          states: List[Tensor],
+          step_sizes: List[int],
+          *,
+          step_size_max: float,
+          step_size_min: float,
+          etaminus: float,
+          etaplus: float):
+    r"""Functional API that performs Rprop algorithm computation.
+
+    See :class:`~torch.optim.Rprop` for details.
+    """
+    signs = torch._foreach_mul(grads, [s['prev'] for s in states])  # type: ignore[misc, index]
+    signs = [s.sign() for s in signs]
+    for sign in signs:
+        sign[sign.gt(0)] = etaplus
+        sign[sign.lt(0)] = etaminus
+        sign[sign.eq(0)] = 1
+
+    # update stepsizes with step size updates
+    torch._foreach_mul_(step_sizes, signs)  # type: ignore[arg-type]
+    for step_size in step_sizes:
+        step_size.clamp_(step_size_min, step_size_max)  # type: ignore[attr-defined]
+
+    # for dir<0, dfdx=0
+    # for dir>=0 dfdx=dfdx
+    for i in range(len(grads)):
+        grads[i] = grads[i].clone(memory_format=torch.preserve_format)
+        grads[i][signs[i].eq(etaminus)] = 0
+
+    # update parameters
+    grad_signs = [grad.sign() for grad in grads]
+    torch._foreach_addcmul_(params, grad_signs, step_sizes, value=-1)  # type: ignore[name-defined, arg-type]
+
+    for i in range(len(states)):
+        states[i]['prev'].copy_(grads[i])  # type: ignore[index]
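
A hedged sketch of driving the kernel above directly, mirroring the per-parameter state the optimizer keeps ('prev' for the previous gradient and 'step_size' for the adapted step sizes, as seen in the call site in the second file). The tensors and initial values here are illustrative, and rprop is assumed to be in scope (the optimizer reaches it as F.rprop).

import torch

param = torch.randn(4)
grad = torch.randn(4)
state = {'prev': torch.zeros_like(param),             # previous gradient
         'step_size': torch.full_like(param, 0.01)}   # per-element step sizes

# mutates param, state['step_size'] and state['prev'] in place
rprop([param], [grad], [state], [state['step_size']],
      step_size_max=50.0, step_size_min=1e-6,
      etaminus=0.5, etaplus=1.2)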

Changed file 2 of 2: the Rprop optimizer module

@@ -1,4 +1,5 @@
 import torch
+from . import _functional as F
 from ..optimizer import Optimizer
 from collections import defaultdict
@@ -38,16 +39,15 @@ class Rprop(Optimizer):
             with torch.enable_grad():
                 loss = closure()
 
-        grads = []
-        states = []
-        params_with_grad = []
-        step_sizes = []
-
         for group in self.param_groups:
-            for p in group['params']:
-                etaminus, etaplus = group['etas']
-                step_size_min, step_size_max = group['step_sizes']
-
+            grads = []
+            states = []
+            params_with_grad = []
+            step_sizes = []
+
+            etaminus, etaplus = group['etas']
+            step_size_min, step_size_max = group['step_sizes']
+            for p in group['params']:
                 if p.grad is not None:
                     if p.grad.is_sparse:
                         raise RuntimeError('RMSprop does not support sparse gradients')
@@ -67,30 +67,14 @@ class Rprop(Optimizer):
                     states.append(state)
                     step_sizes.append(state['step_size'])
 
-            signs = torch._foreach_mul(grads, [s['prev'] for s in states])
-            signs = [s.sign() for s in signs]
-            for sign in signs:
-                sign[sign.gt(0)] = etaplus
-                sign[sign.lt(0)] = etaminus
-                sign[sign.eq(0)] = 1
-
-            # update stepsizes with step size updates
-            torch._foreach_mul_(step_sizes, signs)
-            for step_size in step_sizes:
-                step_size.clamp_(step_size_min, step_size_max)
-
-            # for dir<0, dfdx=0
-            # for dir>=0 dfdx=dfdx
-            for i in range(len(grads)):
-                grads[i] = grads[i].clone(memory_format=torch.preserve_format)
-                grads[i][signs[i].eq(etaminus)] = 0
-
-            # update parameters
-            grad_signs = [grad.sign() for grad in grads]
-            torch._foreach_addcmul_(params_with_grad, grad_signs, step_sizes, value=-1)
-
-            for i in range(len(states)):
-                states[i]['prev'].copy_(grads[i])
+            F.rprop(params_with_grad,
+                    grads,
+                    states,
+                    step_sizes,
+                    step_size_max=step_size_max,
+                    step_size_min=step_size_min,
+                    etaminus=etaminus,
+                    etaplus=etaplus)
 
         return loss
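
The optimizer's observable behavior is unchanged by the refactor; only the update math moves into F.rprop. A minimal usage sketch follows; the import path of this multi-tensor variant is an assumption, since file paths are not shown in the excerpt.

import torch
from torch.optim._multi_tensor import Rprop  # assumed location of this variant

model = torch.nn.Linear(10, 1)
opt = Rprop(model.parameters(), lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-6, 50.0))

for _ in range(5):
    opt.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    opt.step()  # step() now delegates the per-group update to F.rprop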