[optim] rmsprop: handle complex params as independent real params (#83860)

Ref: #65711
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83860
Approved by: https://github.com/albanD

parent 62d9f1559e
commit 09331c947c
@@ -331,7 +331,7 @@ class TestOptim(TestCase):
         optim1 = optimizer_constructor([a1])
         optim2 = optimizer_constructor([a1_real, a1_imag])
 
-        for i in range(10):
+        for _ in range(10):
             optim1.zero_grad()
             optim2.zero_grad()
             a2 = torch.complex(a1_real, a1_imag)
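Context for the hunk above: the helper steps one optimizer over a single complex parameter a1 and a second optimizer over its real and imaginary parts as two real parameters, then expects the two trajectories to match. A minimal, self-contained sketch of that equivalence check, assuming RMSprop and an abs-sum loss as stand-ins for the suite's own setup:

import torch

# One complex parameter vs. the same values split into two real parameters.
data = torch.randn(2, dtype=torch.complex64)
a1 = data.clone().requires_grad_()
a1_real = data.real.clone().requires_grad_()
a1_imag = data.imag.clone().requires_grad_()

optim1 = torch.optim.RMSprop([a1], lr=1e-2)
optim2 = torch.optim.RMSprop([a1_real, a1_imag], lr=1e-2)

for _ in range(10):
    optim1.zero_grad()
    optim2.zero_grad()
    a1.abs().sum().backward()
    a2 = torch.complex(a1_real, a1_imag)
    a2.abs().sum().backward()
    optim1.step()
    optim2.step()

# With complex params handled as independent real params, the two runs agree.
torch.testing.assert_close(a1, torch.complex(a1_real, a1_imag))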
@@ -871,6 +871,14 @@ class TestOptim(TestCase):
                     lr=1e-2, momentum=0.1, weight_decay=1, maximize=maximize),
                 constructor_accepts_maximize=True
             )
+            self._test_complex_2d(optimizer)
+            self._test_complex_2d(lambda param: optimizer(param, centered=True))
+            self._test_complex_2d(lambda param: optimizer(param, momentum=0.1))
+            self._test_complex_2d(lambda param: optimizer(param, maximize=True))
+            self._test_complex_optimizer(lambda param: optimizer([param]))
+            self._test_complex_optimizer(lambda param: optimizer([param], centered=True))
+            self._test_complex_optimizer(lambda param: optimizer([param], momentum=0.1))
+            self._test_complex_optimizer(lambda param: optimizer([param], maximize=True))
             with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"):
                 optimizer(None, lr=1e-2, momentum=-1.0)
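The added lines register complex-parameter coverage for RMSprop through the suite's _test_complex_2d and _test_complex_optimizer helpers, including the centered, momentum, and maximize variants. For illustration only, a hypothetical end-user snippet exercising the behavior these tests now cover:

import torch

# RMSprop on a complex parameter; the real and imaginary parts are updated
# as independent real parameters under the hood.
param = torch.randn(4, dtype=torch.complex64, requires_grad=True)
opt = torch.optim.RMSprop([param], lr=1e-2, momentum=0.1, centered=True)

for _ in range(5):
    opt.zero_grad()
    loss = (param * param.conj()).real.sum()  # real-valued loss of a complex param
    loss.backward()
    opt.step()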
@@ -236,10 +236,18 @@ def _single_tensor_rmsprop(params: List[Tensor],
         if weight_decay != 0:
             grad = grad.add(param, alpha=weight_decay)
 
+        is_complex_param = torch.is_complex(param)
+        if is_complex_param:
+            param = torch.view_as_real(param)
+            grad = torch.view_as_real(grad)
+            square_avg = torch.view_as_real(square_avg)
+
         square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
 
         if centered:
             grad_avg = grad_avgs[i]
+            if is_complex_param:
+                grad_avg = torch.view_as_real(grad_avg)
             grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
             avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_()
         else:
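Why updating through the views is enough: torch.view_as_real returns a real tensor with a trailing dimension of size 2 (real, imaginary) that shares storage with the complex tensor, so in-place ops on the view also mutate the original parameter and optimizer state. A small demonstration:

import torch

z = torch.tensor([1 + 2j, 3 - 4j], dtype=torch.complex64)
r = torch.view_as_real(z)   # shape (2, 2): [[1., 2.], [3., -4.]], same storage as z
r.mul_(0.5)                 # in-place update through the real view
print(z)                    # tensor([0.5000+1.0000j, 1.5000-2.0000j])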
@@ -252,6 +260,8 @@ def _single_tensor_rmsprop(params: List[Tensor],
 
         if momentum > 0:
             buf = momentum_buffer_list[i]
+            if is_complex_param:
+                buf = torch.view_as_real(buf)
             buf.mul_(momentum).addcdiv_(grad, avg)
             param.add_(buf, alpha=-lr)
         else:
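What "independent real params" means for this step: with grad and avg viewed as real, addcdiv_ divides the real component of the gradient by a running RMS built only from past real components, and the imaginary component by its own RMS; no complex magnitude or complex division is involved. A toy comparison of the two possible second-moment choices (made-up values, not code from this diff):

import torch

g = torch.tensor([3 + 4j], dtype=torch.complex64)
gr = torch.view_as_real(g)               # [[3., 4.]]

per_component = gr * gr                  # [[9., 16.]]: what the real-view treatment accumulates
shared_magnitude = (g * g.conj()).real   # [25.]: what a |grad|^2-based treatment would use
print(per_component, shared_magnitude)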
@@ -284,10 +294,18 @@ def _multi_tensor_rmsprop(params: List[Tensor],
     if weight_decay != 0:
         torch._foreach_add_(grads, params, alpha=weight_decay)
 
+    def _view_complex_as_real(tensor_list):
+        return [torch.view_as_real(t) if torch.is_complex(t) else t for t in tensor_list]
+
+    grads = _view_complex_as_real(grads)
+    params = _view_complex_as_real(params)
+    square_avgs = _view_complex_as_real(square_avgs)
+
     torch._foreach_mul_(square_avgs, alpha)
     torch._foreach_addcmul_(square_avgs, grads, grads, value=1 - alpha)
 
     if centered:
+        grad_avgs = _view_complex_as_real(grad_avgs)
         torch._foreach_mul_(grad_avgs, alpha)
         torch._foreach_add_(grad_avgs, grads, alpha=1 - alpha)
         avg = torch._foreach_addcmul(square_avgs, grad_avgs, grad_avgs, value=-1)
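The foreach path applies the same idea list-wise: _view_complex_as_real swaps every complex tensor for its real view, and since the torch._foreach_* ops mutate tensors in place, the underlying complex parameters and state are updated as well. A self-contained sketch of the pattern; the plain _foreach_add_ at the end is only a stand-in for the RMSprop arithmetic:

import torch

def _view_complex_as_real(tensor_list):
    return [torch.view_as_real(t) if torch.is_complex(t) else t for t in tensor_list]

# Mixed complex and real tensors, as an optimizer's param list may contain.
params = [torch.randn(3, dtype=torch.complex64), torch.randn(3)]
grads = [torch.randn(3, dtype=torch.complex64), torch.randn(3)]

params_r = _view_complex_as_real(params)
grads_r = _view_complex_as_real(grads)

# Any in-place foreach op on the views also updates the original tensors.
torch._foreach_add_(params_r, grads_r, alpha=-1e-2)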
@@ -298,6 +316,7 @@ def _multi_tensor_rmsprop(params: List[Tensor],
         torch._foreach_add_(avg, eps)
 
     if momentum > 0:
+        momentum_buffer_list = _view_complex_as_real(momentum_buffer_list)
         torch._foreach_mul_(momentum_buffer_list, momentum)
         torch._foreach_addcdiv_(momentum_buffer_list, grads, avg)
         torch._foreach_add_(params, momentum_buffer_list, alpha=-lr)