mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Third PR in a series of PRs to broaden differentiable optimizer support w/ @janeyx99 (sorry for pinging over the holidays! I just wanted to put this one out, but I am definitely not asking for review or anything like that rn). This is also probably going to be my last PR before the holidays!

Note: this is a branch of #143710. I've never worked on a branch of a branch before, so I wasn't sure about the protocol; I thought I'd just make the PR and wait until that one gets merged.

This adds support for differentiable `lr`, `weight_decay`, and `betas` to Adam and AdamW (but after refactoring AdamW into an Adam subclass, it's really just changing code in `torch/optim/adam.py`).

I had one main thing I was wondering about, which is that Adam already has a `differentiable` flag built in, so I have code like this:

```py
if differentiable and isinstance(beta2, Tensor):
    if beta2.requires_grad:
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
    else:
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
else:
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```

That I could definitely simplify to just:

```py
if differentiable and isinstance(beta2, Tensor):
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
else:
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```

The simplified version would be a little slower when `differentiable` is set but `beta2` doesn't need a grad. The code would also be a lot clearer, though, so I'm debating speed vs. future code usability.

Also, this line from the example above:

```py
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
```

was concerning to me because it is considerably more expensive than `value=1 - beta2`, but I couldn't think of a better way to do it.

Further work on #141832

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143726

Approved by: https://github.com/janeyx99

| | | |
|---|---|---|
| .. | ||
| test_lrscheduler.py | ||
| test_optim.py | ||
| test_swa_utils.py | ||
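
To make the trade-off in the commit message above concrete, here is a minimal sketch of the two second-moment update paths. It is not the code in `torch/optim/adam.py`: the helper name `update_exp_avg_sq` is hypothetical, and it uses the out-of-place `mul`/`addcmul` instead of the in-place `mul_`/`addcmul_` from the PR so that it can run on plain leaf tensors. The point is only that folding `1 - beta2` into an operand keeps `beta2` in the autograd graph, at the cost of one extra elementwise multiply.

```python
import torch


def update_exp_avg_sq(exp_avg_sq, grad, beta2):
    """Hypothetical out-of-place version of the update discussed above."""
    if isinstance(beta2, torch.Tensor):
        # Tensor beta2: (1 - beta2) is multiplied into an operand, so it is
        # part of the autograd graph. This costs an extra elementwise mul.
        return exp_avg_sq.mul(beta2).addcmul(grad, grad.conj().mul(1 - beta2))
    # Python-number beta2: (1 - beta2) is passed through the `value` scalar.
    return exp_avg_sq.mul(beta2).addcmul(grad, grad.conj(), value=1 - beta2)


exp_avg_sq = torch.tensor([0.5, 2.0])
grad = torch.tensor([1.0, -3.0])
beta2 = torch.tensor(0.999, requires_grad=True)

new_exp_avg_sq = update_exp_avg_sq(exp_avg_sq, grad, beta2)
new_exp_avg_sq.sum().backward()

# Analytically, d(sum)/d(beta2) = sum(exp_avg_sq - grad * grad) = -7.5
# for these inputs, which is what beta2.grad should hold.
print(beta2.grad)
```

Both branches compute the same values for a given `beta2`; only the tensor branch lets a gradient reach `beta2`, which is why the PR keeps the cheaper `value=` form wherever no gradient is needed.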