pytorch/test/optim
emmettbicker 92d8965082 Adding support for differentiable lr, weight_decay, and betas in Adam/AdamW (#143726)
Third PR in a series of PRs to broaden differentiable optimizer support w/ @janeyx99 (sorry for pinging over the holidays! I just wanted to put this one out, but I am definitely not asking for review or anything like that right now).

This is also probably going to be my last PR before the holidays!

Note: This is a branch of #143710. I've never worked on a branch of a branch before and wasn't sure about the protocol, so I thought I'd just make the PR and wait until that one gets merged.

This adds support for differentiable lr, weight_decay, and betas to Adam and AdamW (but after refactoring AdamW into an Adam subclass, it's really just changing code in torch/optim/adam.py).

I had one main thing I was wondering about: Adam already has a `differentiable` flag built in, so I have code like this:
```py
if differentiable and isinstance(beta2, Tensor):
    if beta2.requires_grad:
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
    else:
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
else:
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```
which I could definitely simplify to just:
```py
if differentiable and isinstance(beta2, Tensor):
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
else:
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```

The simplified version would definitely be a little slower when the optimizer is differentiable but beta2 doesn't need a grad, but the code would also be a lot clearer, so I'm debating speed vs. future code usability.

Also the line in the above example:
```py
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
```
was concerning to me because it is considerably more expensive than `value=1 - beta2`, but I couldn't think of a better way to do it.
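
For reference, here is a minimal sketch of the simplified branch in isolation. It assumes real-valued float64 tensors, and the helper name `update_exp_avg_sq` is hypothetical (it is not a function in torch/optim/adam.py). It illustrates that the tensor path gives the same values as the `value=` path while letting autograd reach `beta2`:

```py
import torch
from torch import Tensor


def update_exp_avg_sq(exp_avg_sq, grad, beta2, differentiable):
    """Hypothetical helper: one second-moment update, returned as a new tensor."""
    # Work on a clone so the caller's state tensor is left untouched.
    out = exp_avg_sq.clone()
    if differentiable and isinstance(beta2, Tensor):
        # (1 - beta2) is applied as a tensor op, so autograd can track beta2.
        # This allocates one extra tensor for grad.conj().mul(...).
        out.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
    else:
        # Scalar path: (1 - beta2) is folded into addcmul_ via `value`.
        out.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
    return out


torch.manual_seed(0)
exp_avg_sq = torch.rand(4, dtype=torch.float64)
grad = torch.randn(4, dtype=torch.float64)

# Tensor beta2 that requires grad takes the differentiable path.
beta2 = torch.tensor(0.999, dtype=torch.float64, requires_grad=True)
diff_out = update_exp_avg_sq(exp_avg_sq, grad, beta2, differentiable=True)

# Plain float beta2 takes the scalar path.
plain_out = update_exp_avg_sq(exp_avg_sq, grad, 0.999, differentiable=False)

# Backpropagate through the update to get d(sum(out)) / d(beta2).
diff_out.sum().backward()

# Analytically, out = beta2 * v + (1 - beta2) * g * g,
# so d(sum(out)) / d(beta2) = sum(v - g * g) for real-valued g.
expected_grad = (exp_avg_sq - grad * grad).sum()
```

The extra cost mentioned above comes from materializing `grad.conj().mul(1 - beta2)` as a temporary tensor, whereas the scalar path passes the factor to `addcmul_` directly.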

Further work on #141832

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143726
Approved by: https://github.com/janeyx99
2024-12-30 01:11:57 +00:00
test_lrscheduler.py Fix unused Python variables in test/[e-z]* (#136964) 2024-12-18 23:02:30 +00:00
test_optim.py Adding support for differentiable lr, weight_decay, and betas in Adam/AdamW (#143726) 2024-12-30 01:11:57 +00:00
test_swa_utils.py Fix unused Python variables in test/[e-z]* (#136964) 2024-12-18 23:02:30 +00:00