From b31031c56200533b99e25c4e21ec745f3f925352 Mon Sep 17 00:00:00 2001 From: ilqar ramazanli Date: Wed, 30 Jun 2021 16:10:31 +0000 Subject: [PATCH] To refactor multi tensor RMSprop to functional API ghstack-source-id: 4e483b98f34f8101ec9aa68026f8670243580885 Pull Request resolved: https://github.com/pytorch/pytorch/pull/60964 --- torch/optim/_multi_tensor/_functional.py | 42 ++++++++++++++++++++++++ torch/optim/_multi_tensor/rmsprop.py | 35 +++++++------------- 2 files changed, 53 insertions(+), 24 deletions(-) diff --git a/torch/optim/_multi_tensor/_functional.py b/torch/optim/_multi_tensor/_functional.py index 144bcd84122..82312278693 100644 --- a/torch/optim/_multi_tensor/_functional.py +++ b/torch/optim/_multi_tensor/_functional.py @@ -393,3 +393,45 @@ def rprop(params: List[Tensor], for i in range(len(states)): states[i]['prev'].copy_(grads[i])    # type: ignore[index] + + + +def rmsprop(params: List[Tensor], + grads: List[Tensor], + states: List[Tensor], + square_avg: List[Tensor], + *, + lr: float, + alpha: float, + eps: float, + weight_decay: float, + momentum: float, + centered: bool): + r"""Functional API that performs RMSprop algorithm computation. + + See :class:`~torch.optim.RMSprop` for details. 
+ """ + + if weight_decay != 0: + torch._foreach_add_(grads, params, alpha=weight_decay) # type: ignore[name-defined] + + torch._foreach_mul_(square_avg, alpha) + torch._foreach_addcmul_(square_avg, grads, grads, value=1 - alpha) + + if centered: + grad_avgs = [s['grad_avg'] for s in states] # type: ignore[index] + torch._foreach_mul_(grad_avgs, alpha) + torch._foreach_add_(grad_avgs, grads, alpha=1 - alpha) + avg = torch._foreach_addcmul(square_avg, grad_avgs, grad_avgs, value=-1) + torch._foreach_sqrt_(avg) + torch._foreach_add_(avg, eps) + else: + avg = torch._foreach_sqrt(square_avg) + torch._foreach_add_(avg, eps) + + if momentum > 0: + buf = [s['momentum_buffer'] for s in states] # type: ignore[index] + torch._foreach_mul_(buf, momentum) + torch._foreach_addcdiv_(buf, grads, avg) + torch._foreach_add_(params, buf, alpha=-lr) # type: ignore[name-defined] + else: + torch._foreach_addcdiv_(params, grads, avg, value=-lr) # type: ignore[name-defined] diff --git a/torch/optim/_multi_tensor/rmsprop.py b/torch/optim/_multi_tensor/rmsprop.py index ac918307e7c..7cd62133579 100644 --- a/torch/optim/_multi_tensor/rmsprop.py +++ b/torch/optim/_multi_tensor/rmsprop.py @@ -1,4 +1,5 @@ import torch +from . 
import _functional as F from ..optimizer import Optimizer from collections import defaultdict @@ -95,30 +96,16 @@ class RMSprop(Optimizer): states.append(state) square_avg.append(state['square_avg']) - if group['weight_decay'] != 0: - torch._foreach_add_(grads, params_with_grad, alpha=group['weight_decay']) - - torch._foreach_mul_(square_avg, alpha) - torch._foreach_addcmul_(square_avg, grads, grads, value=1 - alpha) - - if group['centered']: - grad_avgs = [s['grad_avg'] for s in states] - torch._foreach_mul_(grad_avgs, alpha) - torch._foreach_add_(grad_avgs, grads, alpha=1 - alpha) - avg = torch._foreach_addcmul(square_avg, grad_avgs, grad_avgs, value=-1) - torch._foreach_sqrt_(avg) - torch._foreach_add_(avg, group['eps']) - else: - avg = torch._foreach_sqrt(square_avg) - torch._foreach_add_(avg, group['eps']) - - if group['momentum'] > 0: - buf = [s['momentum_buffer'] for s in states] - torch._foreach_mul_(buf, group['momentum']) - torch._foreach_addcdiv_(buf, grads, avg) - torch._foreach_add_(params_with_grad, buf, alpha=-group['lr']) - else: - torch._foreach_addcdiv_(params_with_grad, grads, avg, value=-group['lr']) + F.rmsprop(params_with_grad, + grads, + states, + square_avg, + lr=group['lr'], + alpha=group['alpha'], + eps=group['eps'], + weight_decay=group['weight_decay'], + momentum=group['momentum'], + centered=group['centered']) return loss