pytorch/test/optim/test_optim.py
emmettbicker 92d8965082 Adding support for differentiable lr, weight_decay, and betas in Adam/AdamW (#143726)
Third PR in a series to broaden differentiable optimizer support with @janeyx99 (sorry for pinging over the holidays! I just wanted to get this one out, but I'm definitely not asking for review right now)

This will probably also be my last PR before the holidays!

Note: This is a branch off #143710 -- I've never worked on a branch of a branch before, so I wasn't sure about the protocol; I figured I'd just make the PR and wait until that one gets merged.

This PR adds support for differentiable lr, weight_decay, and betas to Adam and AdamW (though after refactoring AdamW into an Adam subclass, it really only changes code in torch/optim/adam.py).
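As a rough illustration of what this enables, here is a minimal sketch of the intended usage (my own example mirroring the test helpers added below, not code taken from this PR): hyperparameters are passed as `requires_grad` Tensors together with `differentiable=True`, and a loss computed after `step()` can then be backpropagated into them.
```py
import torch
from torch.optim import Adam

# Clone so the in-place parameter update happens on a non-leaf tensor,
# the same trick the test helpers below use.
params = torch.rand(10, dtype=torch.float64, requires_grad=True).clone()
lr = torch.tensor(0.001, dtype=torch.float64, requires_grad=True)
opt = Adam([params], lr=lr, differentiable=True)

loss = params.sum()
loss.backward(inputs=(params,), create_graph=True)
opt.step()

# The post-step parameters depend on lr, so a "meta" loss can backprop into it.
params.sum().backward(inputs=(lr,))
print(lr.grad is not None)  # expected: True
```
The new tests below exercise exactly this pattern through `gradcheck`.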

The main thing I was wondering about: Adam already has a `differentiable` flag built in, so I have code like this
```py
if differentiable and isinstance(beta2, Tensor):
    if beta2.requires_grad:
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
    else:
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
else:
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```
which I could definitely simplify to just
```py
if differentiable and isinstance(beta2, Tensor):
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
else:
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```

The simplified version would definitely be a little slower in the case where `differentiable=True` but beta2 doesn't actually need a gradient, but the code would also be a lot clearer, so I'm weighing speed against future code usability.

Also, the line in the above example:
```py
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
```
concerned me because it is considerably more expensive than using `value=1 - beta2`, but I couldn't think of a better way to do it.
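For context, here is a minimal standalone sketch (my own, not from the PR) of why the extra multiply is needed at all: `addcmul_` only accepts a plain Python number for `value`, so the fast path is invisible to autograd, while multiplying by `1 - beta2` keeps the term inside the graph:
```py
import torch

beta2 = torch.tensor(0.999, dtype=torch.float64, requires_grad=True)
grad = torch.rand(4, dtype=torch.float64)
exp_avg_sq = torch.rand(4, dtype=torch.float64)

# Differentiable path: 1 - beta2 stays a Tensor, so autograd tracks it.
tracked = exp_avg_sq.mul(beta2).addcmul(grad, grad.conj().mul(1 - beta2))
tracked.sum().backward()
print(beta2.grad is not None)  # True

# Fast path: `value` must be a Python number, so it can't carry a gradient.
fast = exp_avg_sq.mul(beta2.item()).addcmul(grad, grad.conj(), value=1 - beta2.item())
```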

Further work on #141832

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143726
Approved by: https://github.com/janeyx99
2024-12-30 01:11:57 +00:00

# Owner(s): ["module: optimizer"]
from __future__ import annotations
from typing import Any
import torch
from torch import nn, Tensor
from torch.optim import (
    Adadelta,
    Adagrad,
    Adam,
    Adamax,
    AdamW,
    ASGD,
    NAdam,
    Optimizer,
    RAdam,
    RMSprop,
    Rprop,
    SGD,
)
from torch.testing._internal.common_utils import (
    gradcheck,
    load_tests,
    skipIfTorchDynamo,
    TestCase,
)
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
    # Ignored is the list of values in `opt_differentiable_state`, we do this
    # for `gradcheck` to correctly track the state tensors as function inputs
    # because otherwise it can't unpack the values in the `opt_differentiable_state`
    # dict
    p = p.clone()
    p.grad = grad
    opt_differentiable_state = {
        k: v.clone() if isinstance(v, torch.Tensor) else v
        for k, v in opt_differentiable_state.items()
    }
    opt = opt_class([p], **kwargs)
    opt.state[p].update(opt_differentiable_state)
    opt.step()
    return (p,) + tuple(
        v
        for v in opt.state[p].values()
        if isinstance(v, torch.Tensor) and v.requires_grad
    )
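# Helper for the differentiable-hyperparameter tests below: it runs two optimizer
# steps on a tiny regression problem and returns a "meta" loss together with the
# differentiable state and hyperparameters, so gradcheck can verify the gradients
# flowing into Tensor-valued lr, weight_decay, and betas.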
def _multistep_backprop_diff_hyperparams_fn(
    params: Tensor,
    grad: Tensor,
    opt_differentiable_state: dict[str, Any],
    opt_class: type[Optimizer],
    kwargs: dict[str, Any],
    *ignored: Any,
) -> tuple[Tensor, ...]:
    assert (
        kwargs["differentiable"] is True
    ), "Only call this test function when differentiable=True"
    params = params.clone()
    params.grad = grad
    opt_differentiable_state = {
        k: v.clone() if isinstance(v, torch.Tensor) else v
        for k, v in opt_differentiable_state.items()
    }
    # Copy so the kwargs.update calls below don't overwrite the caller's kwargs values
    kwargs = kwargs.copy()
    # Have to pass in beta1 and beta2 separately
    # so they're passed in as Tensors (not a tuple) and recognized by gradcheck
    if "beta1" in kwargs or "beta2" in kwargs:
        # Prevent just one beta kwarg from being passed in
        assert (
            "beta1" in kwargs and "beta2" in kwargs
        ), "Both betas should be defined in kwargs"
        kwargs.update({"betas": (kwargs.pop("beta1"), kwargs.pop("beta2"))})
    kwargs.update(
        {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
    )
    differentiable_kwargs = [
        v for v in kwargs.values() if isinstance(v, torch.Tensor) and v.requires_grad
    ] + (list(kwargs["betas"]) if "betas" in kwargs else [])
    criterion = nn.MSELoss()
    optimizer = opt_class([params], **kwargs)
    optimizer.state[params].update(opt_differentiable_state)
    # Simple x, y pair
    x = torch.tensor([1.0], dtype=torch.float64)
    y = torch.tensor([2.0], dtype=torch.float64)
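    # Run two optimization steps; the loss from the final iteration is computed
    # from parameters that were already updated using the Tensor hyperparameters,
    # so it serves as the meta-loss that is backpropagated into them below.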
    for _ in range(2):
        loss = criterion(x * torch.sum(params), y)
        loss.backward(
            inputs=(params,),
            create_graph=True,
        )
        optimizer.step()
        optimizer.zero_grad()
    meta_loss = loss
    meta_loss.backward(inputs=(*differentiable_kwargs,), create_graph=True)
    # Extra check to make sure the test properly computed a gradient for all kwargs
    for kwarg in differentiable_kwargs:
        assert kwarg.grad is not None
    return (
        (meta_loss,)
        + tuple(
            v
            for v in optimizer.state[params].values()
            if isinstance(v, torch.Tensor) and v.requires_grad
        )
        + tuple(differentiable_kwargs)
    )
@skipIfTorchDynamo("Differentiable optimizers not supported")
class TestDifferentiableOptimizer(TestCase):
    def test_sgd(self):
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        mbuff = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state = {"momentum_buffer": mbuff}
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                SGD,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )
    def test_adam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adam,
                {"lr": 0.9, "differentiable": True, "amsgrad": True},
                *state.values(),
            ),
        )
    def test_rmsprop(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["step"] = torch.zeros((), dtype=torch.float64)
        state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["momentum_buffer"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        # This can cause issues with large values and nan due to sqrt ops
        state["grad_avg"] = 1e-2 * torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                RMSprop,
                {
                    "lr": 0.9,
                    "maximize": True,
                    "momentum": 0.9,
                    "differentiable": True,
                    "centered": True,
                    "weight_decay": 0.1,
                },
                *state.values(),
            ),
        )
    def test_adadelta(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["acc_delta"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adadelta,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )
    def test_adagrad(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["sum"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adagrad,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )
    def test_adamax(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_inf"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adamax,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )
    @skipIfTorchDynamo(
        "The inplace mu update fails with dynamo, "
        "since this is only happening when differentiable is enabled, skipping for now"
    )
    def test_asgd(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` `eta` & `mu` are not continuous variables (even though we define them as floats)
        # and so they shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["eta"] = torch.tensor(0.9, requires_grad=False, dtype=torch.float64)
        state["mu"] = torch.tensor(1.0, requires_grad=False, dtype=torch.float64)
        state["ax"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                ASGD,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )
    def test_rprop(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["prev"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["step_size"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Rprop,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )
    def test_adamw(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                AdamW,
                {"lr": 0.9, "differentiable": True, "amsgrad": True},
                *state.values(),
            ),
        )
    def test_nadam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["mu_product"] = torch.tensor(1.0, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                NAdam,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                NAdam,
                {"lr": 0.9, "decoupled_weight_decay": True, "differentiable": True},
                *state.values(),
            ),
        )
    def test_radam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                RAdam,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                RAdam,
                {
                    "lr": 0.9,
                    "weight_decay": 0.1,
                    "decoupled_weight_decay": True,
                    "differentiable": True,
                },
                *state.values(),
            ),
        )
    def test_adam_differentiable_lr(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
        state = {}
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        kwargs: dict[str, Any] = {"lr": lr, "differentiable": True}
        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                Adam,
                kwargs,  # includes lr
                *state.values(),
                *kwargs.values(),
            ),
        )
    def test_adam_differentiable_weight_decay(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
        state = {}
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        kwargs: dict[str, Any] = {"weight_decay": weight_decay, "differentiable": True}
        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                Adam,
                kwargs,  # includes weight_decay
                *state.values(),
                *kwargs.values(),
            ),
        )
    def test_adam_differentiable_betas(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        lr = torch.tensor([0.001], requires_grad=True, dtype=torch.float64)
        betas = (
            torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
            torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
        )
        state = {}
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        # Have to pass in beta1 and beta2 separately
        # so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
        # The helper recombines them via kwargs.update({"betas": (beta1, beta2)})
        kwargs: dict[str, Any] = {
            "beta1": betas[0],
            "beta2": betas[1],
            "lr": lr,
            "differentiable": True,
        }
        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                Adam,
                kwargs,  # includes lr and betas
                *state.values(),
                *kwargs.values(),
            ),
        )
    def test_adam_differentiable_all_hyperparams(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
        weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
        betas = (
            torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
            torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
        )
        state = {}
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        # Have to pass in beta1 and beta2 separately
        # so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
        # The helper recombines them via kwargs.update({"betas": (beta1, beta2)})
        kwargs: dict[str, Any] = {
            "lr": lr,
            "weight_decay": weight_decay,
            "beta1": betas[0],
            "beta2": betas[1],
            "differentiable": True,
        }
        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                Adam,
                kwargs,  # includes lr, weight_decay, and betas
                *state.values(),
                *kwargs.values(),
            ),
        )
    def test_adamw_differentiable_lr(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
        state = {}
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        kwargs: dict[str, Any] = {"lr": lr, "differentiable": True}
        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                AdamW,
                kwargs,  # includes lr
                *state.values(),
                *kwargs.values(),
            ),
        )
    def test_adamw_differentiable_weight_decay(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
        state = {}
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        kwargs: dict[str, Any] = {"weight_decay": weight_decay, "differentiable": True}
        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                AdamW,
                kwargs,  # includes weight_decay
                *state.values(),
                *kwargs.values(),
            ),
        )
    def test_adamw_differentiable_betas(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        betas = (
            torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
            torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
        )
        state = {}
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        # Have to pass in beta1 and beta2 separately
        # so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
        # The helper recombines them via kwargs.update({"betas": (beta1, beta2)})
        kwargs: dict[str, Any] = {
            "beta1": betas[0],
            "beta2": betas[1],
            "differentiable": True,
        }
        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                AdamW,
                kwargs,  # includes betas
                *state.values(),
                *kwargs.values(),
            ),
        )
    def test_adamw_differentiable_all_hyperparams(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
        weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
        betas = (
            torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
            torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
        )
        state = {}
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        # Have to pass in beta1 and beta2 separately
        # so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
        # The helper recombines them via kwargs.update({"betas": (beta1, beta2)})
        kwargs: dict[str, Any] = {
            "lr": lr,
            "weight_decay": weight_decay,
            "beta1": betas[0],
            "beta2": betas[1],
            "differentiable": True,
        }
        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                AdamW,
                kwargs,  # includes lr, weight_decay, and betas
                *state.values(),
                *kwargs.values(),
            ),
        )
    def test_differentiable_lr(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
        mbuff = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        state = {"momentum_buffer": mbuff}
        kwargs: dict[str, Any] = {"lr": lr, "differentiable": True}
        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                SGD,
                kwargs,  # includes lr
                *state.values(),
                *kwargs.values(),
            ),
        )
    def test_differentiable_weight_decay(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        weight_decay = torch.tensor(0.9, requires_grad=True, dtype=torch.float64)
        mbuff = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        state = {"momentum_buffer": mbuff}
        kwargs: dict[str, Any] = {"weight_decay": weight_decay, "differentiable": True}
        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                SGD,
                kwargs,  # includes weight_decay
                *state.values(),
                *kwargs.values(),
            ),
        )
    def test_differentiable_weight_decay_and_lr(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        weight_decay = torch.tensor(0.9, requires_grad=True, dtype=torch.float64)
        lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
        mbuff = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        state = {"momentum_buffer": mbuff}
        kwargs: dict[str, Any] = {
            "weight_decay": weight_decay,
            "lr": lr,
            "differentiable": True,
        }
        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                SGD,
                kwargs,  # includes lr & weight_decay
                *state.values(),
                *kwargs.values(),
            ),
        )
if __name__ == "__main__":
print("These tests should be run through test/test_optim.py instead")