# Owner(s): ["module: optimizer"]

import math
import unittest
import functools
import itertools
from copy import deepcopy

import torch
import torch.optim as optim
from torch.nn import Parameter
from torch.optim import Adam, SGD, Optimizer
from torch.optim.lr_scheduler import (
    StepLR,
    ConstantLR,
    LinearLR,
    ExponentialLR,
    ReduceLROnPlateau,
    PolynomialLR,
)
from torch.testing._internal.common_utils import (
    TestCase,
    load_tests,
    gradcheck,
    skipIfRocm,
    skipIfTorchDynamo
)

from torch._dynamo import disable as disable_dynamo

from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
from torch.testing._internal.common_device_type import largeTensorTest
from typing import Dict, Any, Tuple
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
from unittest.mock import patch

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests


def rosenbrock(tensor):
    assert tensor.size() == torch.Size([2]), f"Requires tensor with 2 scalars but got {tensor.size()}"
    x, y = tensor
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


def drosenbrock(tensor):
    assert tensor.size() == torch.Size([2]), f"Requires tensor with 2 scalars but got {tensor.size()}"
    x, y = tensor
    return torch.tensor((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))
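
# Reference note (added for clarity; not in the original file): drosenbrock is the
# analytic gradient of rosenbrock, f(x, y) = (1 - x)^2 + 100 * (y - x^2)^2:
#   df/dx = -400 * x * (y - x^2) - 2 * (1 - x)
#   df/dy =  200 * (y - x^2)
# A quick, hypothetical sanity check against autograd would be:
#
#     t = torch.tensor([1.5, 1.5], requires_grad=True)
#     rosenbrock(t).backward()
#     torch.testing.assert_close(t.grad, drosenbrock(t.detach()))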

@skipIfTorchDynamo("This is a TEMPORARY stopgap, see https://github.com/pytorch/pytorch/issues/103322")
class TestOptim(TestCase):
    exact_dtype = True

    def _test_rosenbrock_sparse(
        self,
        constructor,
        scheduler_constructors=None,
        sparse_only=False,
        maximize=False,
    ):
        if scheduler_constructors is None:
            scheduler_constructors = []
        # For rosenbrock tests, it is mandated that the param is a tensor with 2 numbers
        param_t = torch.tensor([1.5, 1.5])

        param = Parameter(param_t)
        optimizer = constructor([param])
        schedulers = []
        for scheduler_constructor in scheduler_constructors:
            schedulers.append(scheduler_constructor(optimizer))

        if not sparse_only:
            param_c = Parameter(param_t.clone())
            optimizer_c = constructor([param_c])

        solution = torch.tensor([1, 1])
        with torch.no_grad():
            initial_dist = param.dist(solution)

        def eval(param, sparse_grad, w):
            # Depending on w, provide only the x or y gradient
            optimizer.zero_grad()
            loss = rosenbrock(param)
            loss.backward()
            grad = drosenbrock(param)
            # NB: We torture test the optimizer by returning an
            # uncoalesced sparse tensor
            if w:
                i = torch.LongTensor([[0, 0]])
                x = grad[0]
                v = torch.tensor([x / 4.0, x - x / 4.0])
            else:
                i = torch.LongTensor([[1, 1]])
                y = grad[1]
                v = torch.tensor([y - y / 4.0, y / 4.0])
            x = torch.sparse_coo_tensor(i, v, (2,), dtype=v.dtype)
            with torch.no_grad():
                if sparse_grad:
                    param.grad = x
                else:
                    param.grad = x.to_dense()
            return loss
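
        # Illustration (added; not in the original file): an uncoalesced COO tensor may
        # repeat an index, and coalescing sums the duplicate values, so the sparse grad
        # built above still totals grad[0] (or grad[1]). A minimal, hypothetical example:
        #
        #     g = torch.sparse_coo_tensor(torch.tensor([[0, 0]]), torch.tensor([0.25, 0.75]), (2,))
        #     g.coalesce().values()  # tensor([1.])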

        for i in range(2000):
            # Do cyclic coordinate descent
            w = i % 2
            optimizer.step(functools.partial(eval, param, True, w))
            for scheduler in schedulers:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(rosenbrock(param))
                else:
                    scheduler.step()
            if not sparse_only:
                optimizer_c.step(functools.partial(eval, param_c, False, w))
                self.assertEqual(param, param_c)

        if not maximize:
            self.assertLessEqual(param.dist(solution), initial_dist)
        else:
            self.assertGreaterEqual(rosenbrock(param), rosenbrock(param_t))

    def _test_basic_cases_template(
        self,
        weight_tensor,
        bias_tensor,
        input_tensor,
        constructor,
        scheduler_constructors,
        constructor_accepts_maximize=True,
        constructor_accepts_foreach=False,
    ):
        maximize_options = {False, constructor_accepts_maximize}
        foreach_options = {False, constructor_accepts_foreach}
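        # Note (added for clarity; not in the original file): building these as sets means
        # the options collapse when a flag is unsupported, e.g. {False, False} == {False},
        # so the itertools.product loop below only exercises maximize/foreach values that
        # the constructor actually accepts.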

        four_arg_constructor = constructor
        if constructor_accepts_maximize and constructor_accepts_foreach:
            pass
        elif constructor_accepts_maximize:

            def four_arg_constructor(weight, bias, maximize, foreach):
                self.assertFalse(foreach)
                return constructor(weight, bias, maximize)

        elif constructor_accepts_foreach:

            def four_arg_constructor(weight, bias, maximize, foreach):
                self.assertFalse(maximize)
                return constructor(weight, bias, foreach)

        else:

            def four_arg_constructor(weight, bias, maximize, foreach):
                self.assertFalse(maximize or foreach)
                return constructor(weight, bias)

        for maximize, foreach in itertools.product(maximize_options, foreach_options):
            with torch.no_grad():
                weight = Parameter(weight_tensor.clone().detach())
                bias = Parameter(bias_tensor.clone().detach())
                input = input_tensor.clone().detach().requires_grad_()
            optimizer = four_arg_constructor(weight, bias, maximize, foreach)
            schedulers = []
            for scheduler_constructor in scheduler_constructors:
                schedulers.append(scheduler_constructor(optimizer))

            # to check if the optimizer can be printed as a string
            optimizer.__repr__()

            def fn():
                optimizer.zero_grad()
                y = weight.mv(input)
                if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
                    y = y.cuda(bias.get_device())
                loss = (y + bias).pow(2).sum()
                loss.backward()
                return loss

            initial_value = fn().item()
            for _ in range(200):
                optimizer.step(fn)
                for scheduler in schedulers:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        val_loss = fn()
                        scheduler.step(val_loss)
                    else:
                        scheduler.step()
            if maximize:
                self.assertGreater(fn().item(), initial_value)
            else:
                self.assertLess(fn().item(), initial_value)

    # Note: disable dynamo on this function
    # This allows us to continue running actual logic of the optimizer
    # tests in dynamo without tracing this test code which has a lot of unsupported
    # behavior
    @disable_dynamo(recursive=False)
    def _test_state_dict(self, weight, bias, input, constructor, atol=None, rtol=None):
        weight = Parameter(weight)
        bias = Parameter(bias)
        with torch.no_grad():
            input = input.clone().detach().requires_grad_()

        # Note: Disable dynamo on this function
        # This avoids a bug where input_cuda is not detected in the environment
        # because it currently is not defined in the local environment. Unable to repro
        # anywhere else however and this is test code that we don't need to spend
        # time getting dynamo to trace unless the issue repros in real models.
        @disable_dynamo(recursive=False)
        def fn_base(optimizer, weight, bias):
            optimizer.zero_grad()
            i = input_cuda if weight.is_cuda else input
            loss = (weight.mv(i) + bias).pow(2).sum()
            loss.backward()
            return loss

        optimizer = constructor(weight, bias)
        fn = functools.partial(fn_base, optimizer, weight, bias)

        # Prime the optimizer
        for _i in range(20):
            optimizer.step(fn)
        # Clone the weights and construct new optimizer for them
        with torch.no_grad():
            weight_c = Parameter(weight.clone().detach())
            bias_c = Parameter(bias.clone().detach())
        optimizer_c = constructor(weight_c, bias_c)
        fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
        # Load state dict
        state_dict = deepcopy(optimizer.state_dict())
        state_dict_c = deepcopy(optimizer.state_dict())
        optimizer_c.load_state_dict(state_dict_c)
        # Run both optimizers in parallel
        for _ in range(20):
            optimizer.step(fn)
            optimizer_c.step(fn_c)
            self.assertEqual(weight, weight_c)
            self.assertEqual(bias, bias_c)
        # Make sure state dict wasn't modified
        self.assertEqual(state_dict, state_dict_c)
        # Make sure state dict is deterministic with equal but not identical parameters
        self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
        # Make sure repeated parameters have identical representation in state dict
        optimizer_c.param_groups.extend(optimizer_c.param_groups)
        self.assertEqual(
            optimizer.state_dict()["param_groups"][-1],
            optimizer_c.state_dict()["param_groups"][-1],
        )

        # Make sure that optimizers that support maximize can load older models
        old_state_dict = deepcopy(optimizer.state_dict())
        state_dict_no_maximize = deepcopy(optimizer.state_dict())
        if "maximize" in state_dict_no_maximize["param_groups"][0]:
            for group in state_dict_no_maximize["param_groups"]:
                del group["maximize"]
            optimizer.load_state_dict(state_dict_no_maximize)
            # Make sure we can still step
            optimizer.step()
            # Undo these changes before proceeding!
            optimizer.load_state_dict(old_state_dict)
        # Make sure that optimizers that support foreach can load older models
        state_dict_no_foreach = deepcopy(optimizer.state_dict())
        if "foreach" in state_dict_no_foreach["param_groups"][0]:
            for group in state_dict_no_foreach["param_groups"]:
                del group["foreach"]
            optimizer.load_state_dict(state_dict_no_foreach)
            # Make sure we can still step
            optimizer.step()
            # Undo these changes before proceeding!
            optimizer.load_state_dict(old_state_dict)

        # Make sure that loading optimizers with step not wrapped in tensor can work
        state_dict = optimizer.state_dict()
        if "step" in state_dict["state"][0] and torch.is_tensor(
            state_dict["state"][0]["step"]
        ):
            for state in state_dict["state"].values():
                state["step"] = state["step"].item()
            optimizer.load_state_dict(state_dict)
            optimizer.step()

        # Check that state dict can be loaded even when we cast parameters
        # to a different type and move to a different device.
        if not torch.cuda.is_available():
            return

        with torch.no_grad():
            input_cuda = input.clone().detach().to(dtype=torch.float32, device="cuda")
            weight_cuda = Parameter(
                weight.clone().detach().to(dtype=torch.float32, device="cuda")
            )
            bias_cuda = Parameter(
                bias.clone().detach().to(dtype=torch.float32, device="cuda")
            )
        optimizer_cuda = constructor(weight_cuda, bias_cuda)
        fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)

        state_dict = deepcopy(optimizer.state_dict())
        state_dict_c = deepcopy(optimizer.state_dict())
        optimizer_cuda.load_state_dict(state_dict_c)

        # Make sure state dict wasn't modified
        self.assertEqual(state_dict, state_dict_c)

        # Make sure that device of state['step'] is still CPU
        new_state_dict = optimizer_cuda.state_dict()
        if "step" in state_dict["state"][0] and torch.is_tensor(
            state_dict["state"][0]["step"]
        ):
            for state in new_state_dict["state"].values():
                self.assertEqual(state["step"].device.type, "cpu")

        for _i in range(20):
            optimizer.step(fn)
            optimizer_cuda.step(fn_cuda)
            self.assertEqual(weight, weight_cuda)
            self.assertEqual(bias, bias_cuda, atol=atol, rtol=rtol)

        # validate deepcopy() copies all public attributes
        def getPublicAttr(obj):
            return {k for k in obj.__dict__ if not k.startswith("_")}

        self.assertEqual(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer)))

    def _test_basic_cases(
        self,
        constructor,
        scheduler_constructors=None,
        ignore_multidevice=False,
        constructor_accepts_maximize=False,
        constructor_accepts_foreach=False,
        atol=None,
        rtol=None,
    ):
        if scheduler_constructors is None:
            scheduler_constructors = []

        def make_two_arg_constructor(
            constructor, maximize: bool, foreach: bool
        ):
            if constructor_accepts_maximize and constructor_accepts_foreach:
                return lambda weight, bias: constructor(weight, bias, maximize, foreach)
            if constructor_accepts_maximize:
                return lambda weight, bias: constructor(weight, bias, maximize)
            if constructor_accepts_foreach:
                return lambda weight, bias: constructor(weight, bias, foreach)
            return constructor

        for maximize, foreach in itertools.product(
            {False, constructor_accepts_maximize},
            {False, constructor_accepts_foreach},
        ):
            self._test_state_dict(
                torch.randn(10, 5),
                torch.randn(10),
                torch.randn(5),
                make_two_arg_constructor(constructor, maximize, foreach),
                atol=atol,
                rtol=rtol,
            )
        self._test_basic_cases_template(
            torch.randn(10, 5),
            torch.randn(10),
            torch.randn(5),
            constructor,
            scheduler_constructors,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )
        # non-contiguous parameters
        self._test_basic_cases_template(
            torch.randn(10, 5, 2)[..., 0],
            torch.randn(10, 2)[..., 0],
            torch.randn(5),
            constructor,
            scheduler_constructors,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )
        # CUDA
        if not torch.cuda.is_available():
            return
        self._test_basic_cases_template(
            torch.randn(10, 5).cuda(),
            torch.randn(10).cuda(),
            torch.randn(5).cuda(),
            constructor,
            scheduler_constructors,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )
        # Multi-GPU
        if not torch.cuda.device_count() > 1 or ignore_multidevice:
            return
        self._test_basic_cases_template(
            torch.randn(10, 5).cuda(0),
            torch.randn(10).cuda(1),
            torch.randn(5).cuda(0),
            constructor,
            scheduler_constructors,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )

    def _test_complex_optimizer(self, optimizer_constructor):
        complex_param = torch.randn(5, 5, dtype=torch.complex64, requires_grad=True)
        real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_()
        complex_opt = optimizer_constructor(complex_param)
        real_opt = optimizer_constructor(real_param)

        for _ in range(3):
            complex_param.grad = torch.randn_like(complex_param)
            real_param.grad = torch.view_as_real(complex_param.grad)
            complex_opt.step()
            real_opt.step()

        self.assertEqual(torch.view_as_real(complex_param), real_param)
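
        # Clarifying note (added; not in the original file): view_as_real maps a complex
        # tensor of shape (5, 5) to a real tensor of shape (5, 5, 2) whose trailing dim
        # holds (real, imag) pairs, so both optimizers see the same values and should
        # trace identical trajectories. For example:
        #
        #     torch.view_as_real(torch.tensor([1 + 2j]))  # tensor([[1., 2.]])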

    def _test_complex_2d(self, optimizer_constructor):
        a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True)
        a1_real = a1.real.clone().detach()
        a1_imag = a1.imag.clone().detach()
        a1_real.requires_grad_()
        a1_imag.requires_grad_()
        optim1 = optimizer_constructor([a1])
        optim2 = optimizer_constructor([a1_real, a1_imag])

        for _ in range(10):
            optim1.zero_grad()
            optim2.zero_grad()
            a2 = torch.complex(a1_real, a1_imag)
            rosenbrock(a1).abs().backward()
            rosenbrock(a2).abs().backward()

            self.assertEqual(a1.grad.real, a1_real.grad)
            self.assertEqual(a1.grad.imag, a1_imag.grad)

            optim1.step()
            optim2.step()
            self.assertEqual(a1.real, a1_real)
            self.assertEqual(a1.imag, a1_imag)

    def _build_params_dict(self, weight, bias, **kwargs):
        return [{"params": [weight]}, dict(params=[bias], **kwargs)]

    def _build_params_dict_single(self, weight, bias, **kwargs):
        return [dict(params=bias, **kwargs)]
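
    # Illustrative note (added; not in the original file): these helpers exercise the
    # param-group form of the params argument, where per-group keys override the
    # optimizer-wide defaults. For instance, a hypothetical
    #     optim.SGD(self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3)
    # steps the weight with lr=1e-3 and the bias group with lr=1e-2.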
|
|
|
|
def test_sgd(self):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
self._build_params_dict_single(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
self._build_params_dict_single(weight, bias, lr=1e-2),
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
scheduler_constructors=[lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
scheduler_constructors=[
|
|
lambda opt: LinearLR(
|
|
opt, start_factor=0.4, end_factor=0.8, total_iters=4
|
|
)
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
scheduler_constructors=[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
scheduler_constructors=[lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
scheduler_constructors=[
|
|
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
|
|
lambda opt: LinearLR(
|
|
opt, start_factor=0.4, end_factor=0.6, total_iters=4
|
|
),
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
[
|
|
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
|
|
lambda opt: ReduceLROnPlateau(opt),
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
[
|
|
lambda opt: StepLR(opt, gamma=0.99, step_size=10),
|
|
lambda opt: ExponentialLR(opt, gamma=0.99),
|
|
lambda opt: ReduceLROnPlateau(opt),
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
momentum=0.5,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
momentum=0.5,
|
|
weight_decay=1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias],
|
|
nesterov=True,
|
|
lr=1e-3,
|
|
momentum=0.5,
|
|
weight_decay=1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "Invalid momentum value: -0.5"):
|
|
optim.SGD(None, lr=1e-2, momentum=-0.5)
|
|
|
|
def test_sgd_sparse(self):
|
|
for foreach in (False, True):
|
|
self._test_rosenbrock_sparse(
|
|
lambda params: optim.SGD(params, lr=4.8e-3, foreach=foreach)
|
|
)
|
|
self._test_rosenbrock_sparse(
|
|
lambda params: optim.SGD(params, lr=0.0048, foreach=foreach),
|
|
scheduler_constructors=[lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
|
|
)
|
|
|
|
def test_sgd_complex(self):
|
|
for foreach in (False, True):
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.SGD([param], lr=0.001, foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.SGD([param], lr=0.001, momentum=1, foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.SGD(
|
|
[param], lr=0.001, momentum=1, weight_decay=1, foreach=foreach
|
|
)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.SGD(
|
|
[param],
|
|
lr=0.001,
|
|
nesterov=True,
|
|
momentum=1,
|
|
weight_decay=1,
|
|
foreach=foreach,
|
|
)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.SGD(
|
|
[param],
|
|
lr=0.001,
|
|
momentum=1,
|
|
dampening=0.5,
|
|
weight_decay=1,
|
|
foreach=foreach,
|
|
)
|
|
)
|
|
|
|
def _test_derived_optimizers_varying_tensors(self, optimizer_with_kwargs, kwarg):
|
|
if not torch.cuda.is_available():
|
|
return
|
|
assert kwarg in ("foreach", "fused")
|
|
|
|
# Specifically test that inputting params of different dtypes and devices
|
|
# is handled equivalently on the foreach and fused implementations as the
|
|
# single tensor implementations. We need multiple GPUs (vs just a CPU and
|
|
# GPU) because fused adam only works on GPUs. (Thus we only run the tests
|
|
# that call into this helper when TEST_MULTIGPU.)
|
|
params = [
|
|
torch.rand(2, 3, dtype=torch.float64, device='cuda:0', requires_grad=True),
|
|
torch.rand(2, 3, dtype=torch.float32, device='cuda:0', requires_grad=True),
|
|
torch.rand(2, 3, dtype=torch.float16, device='cuda:0', requires_grad=True),
|
|
torch.rand(2, 3, dtype=torch.bfloat16, device='cuda:0', requires_grad=True),
|
|
torch.rand(2, 3, dtype=torch.float64, device='cuda:1', requires_grad=True),
|
|
torch.rand(2, 3, dtype=torch.float32, device='cuda:1', requires_grad=True),
|
|
torch.rand(2, 3, dtype=torch.float16, device='cuda:1', requires_grad=True),
|
|
torch.rand(2, 3, dtype=torch.bfloat16, device='cuda:1', requires_grad=True),
|
|
torch.randint(1024, (2, 3), dtype=torch.int64, device='cuda:1', requires_grad=False),
|
|
]
|
|
|
|
for p in params:
|
|
if p.requires_grad:
|
|
p.grad = torch.rand_like(p, device=p.device, dtype=p.dtype)
|
|
|
|
kIterations = 7 if kwarg == "foreach" else 1
|
|
for optimizer_constructor, kwargs in optimizer_with_kwargs:
|
|
res, state = [], []
|
|
for enabled in (False, True):
|
|
kwargs_clone = deepcopy(kwargs)
|
|
kwargs_clone[kwarg] = enabled
|
|
|
|
params_clone = []
|
|
for p in params:
|
|
p_clone = p.clone().detach()
|
|
if p.requires_grad:
|
|
p_clone.requires_grad = True
|
|
p_clone.grad = p.grad.clone().detach()
|
|
params_clone.append(p_clone)
|
|
|
|
optimizer = optimizer_constructor(params_clone, **kwargs_clone)
|
|
for _ in range(kIterations):
|
|
optimizer.step()
|
|
|
|
state.append(optimizer.state)
|
|
res.append(params_clone)
|
|
|
|
st_state = state[0]
|
|
mt_state = state[1]
|
|
for st_p, mt_p in zip(res[0], res[1]):
|
|
self.assertEqual(st_p, mt_p)
|
|
|
|
# check that optimizer states are the same
|
|
st_p_state = st_state[st_p]
|
|
mt_p_state = mt_state[mt_p]
|
|
|
|
for k in st_p_state:
|
|
actual = mt_p_state[k]
|
|
# If `torch.optim.Adam` is `__init__`ed with either `fused=True` or `capturable=True`,
|
|
# `step` Tensor is 1D while usually it's 0D.
|
|
if (
|
|
k == "step"
|
|
and isinstance(actual, torch.Tensor)
|
|
and actual.ndim == 1
|
|
):
|
|
actual = actual[0]
|
|
self.assertEqual(st_p_state[k], actual)
|
|
|
|
def _test_derived_optimizers(self, optimizer_pairs_with_flags, flag):
|
|
if not torch.cuda.is_available():
|
|
return
|
|
assert flag in ("foreach", "fused")
|
|
|
|
# why 7? iteration 7 is where we start to see differences for RAdam
|
|
# params interacting with the small eps value, because that's right
|
|
# after rho_t becomes greater than 5 in step 6.
|
|
kIterations = 7
|
|
device = "cuda"
|
|
for optimizer_constructor, params in optimizer_pairs_with_flags:
|
|
res, state = [], []
|
|
for flag_value in (False, True):
|
|
input = torch.tensor(
|
|
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=torch.float64, device=device
|
|
).reshape(3, 2)
|
|
|
|
torch.manual_seed(1)
|
|
model = torch.nn.Sequential(
|
|
torch.nn.Linear(2, 3),
|
|
torch.nn.Sigmoid(),
|
|
torch.nn.Linear(3, 1),
|
|
torch.nn.Sigmoid(),
|
|
)
|
|
model.to(dtype=torch.float64, device=device)
|
|
params_with_flags = deepcopy(params)
|
|
params_with_flags[flag] = flag_value
|
|
|
|
# foreach/fused optimizers should be tested with a param_groups['params'] with
|
|
# zero_size tensor as its last param.
|
|
# ref: https://github.com/pytorch/pytorch/issues/100701
|
|
empty_params = [torch.empty((), device=device, dtype=torch.float64)]
|
|
|
|
optimizer = optimizer_constructor(
|
|
list(model.parameters()) + empty_params, **params_with_flags
|
|
)
|
|
|
|
for i in range(kIterations):
|
|
optimizer.zero_grad()
|
|
output = model(input)
|
|
loss = output.sum()
|
|
loss.backward()
|
|
|
|
# Test that step behaves as expected (a no-op) when grads are set to None
|
|
if i == 0:
|
|
optimizer.zero_grad(set_to_none=True)
|
|
|
|
optimizer.step()
|
|
|
|
state.append(optimizer.state)
|
|
res.append(model.parameters())
|
|
|
|
st_state = state[0]
|
|
mt_state = state[1]
|
|
for st_p, mt_p in zip(res[0], res[1]):
|
|
self.assertEqual(st_p, mt_p)
|
|
|
|
# check that optimizer states are the same
|
|
st_p_state = st_state[st_p]
|
|
mt_p_state = mt_state[mt_p]
|
|
|
|
for k in st_p_state:
|
|
self.assertEqual(st_p_state[k], mt_p_state[k])
|
|
|
|
    def _test_foreach_memory(self, optimizer_pairs_with_flags):
        if not torch.cuda.is_available():
            return

        device = "cuda"
        nparams = 10
        for optimizer_constructor, kwargs in optimizer_pairs_with_flags:
            max_mems = []
            for flag_value in (False, True):
                kwargs_with_flags = deepcopy(kwargs)
                kwargs_with_flags['foreach'] = flag_value

                # The 128 is critical here! Our CUDACachingAllocator allocates in blocks of 512,
                # meaning any tensor that occupies <512 bytes of memory will allocate a whole
                # 512 bytes anyway. We use 128 (since datasize would be 4 bytes) so that param
                # is size 512 exactly, making our later calculations for intermediate_size easy.
                param = torch.rand(128, device=device)
                params = [torch.rand_like(param) for _ in range(nparams)]

                optimizer = optimizer_constructor(
                    params, **kwargs_with_flags
                )

                for p in params:
                    p.grad = torch.rand_like(p)

                optimizer.step()
                import gc
                gc.collect()
                torch.cuda.reset_peak_memory_stats()
                optimizer.step()
                gc.collect()
                max_mems.append(torch.cuda.max_memory_allocated())

            st_max_mem, mt_max_mem = max_mems
            intermediate_size = nparams * param.nelement() * param.element_size()
            nintermediates = 1  # we expect a budget of 1 intermediate most of the time
            if 'capturable' in kwargs_with_flags and kwargs_with_flags['capturable']:
                # with capturable in Adam, we have 2 extra intermediates for the bias_corrections
                nintermediates = 3

            self.assertLessEqual(mt_max_mem, st_max_mem + intermediate_size * nintermediates)
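
            # Worked numbers (added for clarity; not in the original file), assuming the
            # setup above: each param holds 128 float32 values = 128 * 4 B = 512 B, so
            # intermediate_size = 10 * 512 B = 5120 B. The foreach path is allowed at most
            # one extra intermediate of that size over the single-tensor peak (three when
            # capturable=True, which adds two bias-correction intermediates).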
|
|
|
|
@property
|
|
def _multi_tensor_optimizer_configs(self):
|
|
return [
|
|
(optim.Adam, dict(weight_decay=1.0, amsgrad=False)),
|
|
(optim.Adam, dict(weight_decay=0.0, amsgrad=True)),
|
|
(optim.Adam, dict(weight_decay=0.0, amsgrad=False, maximize=True)),
|
|
(optim.Adam, dict(weight_decay=1.0, amsgrad=True, maximize=True)),
|
|
(optim.Adam, dict(weight_decay=0.0, amsgrad=False, capturable=True, maximize=True)),
|
|
(optim.Adam, dict(weight_decay=1.0, amsgrad=True, capturable=True, maximize=True)),
|
|
(optim.AdamW, dict(weight_decay=1.0, amsgrad=False)),
|
|
(optim.AdamW, dict(weight_decay=0.0, amsgrad=True)),
|
|
(optim.AdamW, dict(weight_decay=1.0, amsgrad=True, maximize=True)),
|
|
(optim.AdamW, dict(weight_decay=0.0, amsgrad=False, maximize=True)),
|
|
(optim.AdamW, dict(weight_decay=1.0, amsgrad=True, capturable=True, maximize=True)),
|
|
(optim.AdamW, dict(weight_decay=0.0, amsgrad=False, capturable=True, maximize=True)),
|
|
(optim.NAdam, dict(weight_decay=0.0, momentum_decay=6e-3)),
|
|
(optim.NAdam, dict(weight_decay=1.0, momentum_decay=6e-3)),
|
|
(optim.NAdam, dict(weight_decay=0.0, momentum_decay=4e-3)),
|
|
(optim.NAdam, dict(weight_decay=0.01, momentum_decay=4e-3)),
|
|
(
|
|
optim.SGD,
|
|
dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True),
|
|
),
|
|
(
|
|
optim.SGD,
|
|
dict(lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False),
|
|
),
|
|
(optim.RAdam, dict(weight_decay=0, eps=1e-6)),
|
|
(optim.RAdam, dict(weight_decay=0)),
|
|
(optim.RAdam, dict(weight_decay=1, eps=1e-6)),
|
|
(optim.RAdam, dict(weight_decay=1)),
|
|
(optim.RMSprop, dict(weight_decay=1, momentum=1, centered=True)),
|
|
(optim.RMSprop, dict(weight_decay=1, momentum=0, centered=True)),
|
|
(optim.RMSprop, dict(weight_decay=1, momentum=1, centered=False)),
|
|
(optim.RMSprop, dict(weight_decay=0, momentum=1, centered=False)),
|
|
(optim.Rprop, dict(lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50))),
|
|
(optim.ASGD, dict(weight_decay=0)),
|
|
(optim.ASGD, dict(weight_decay=1)),
|
|
(optim.Adamax, dict(weight_decay=0)),
|
|
(optim.Adamax, dict(weight_decay=1)),
|
|
(optim.Adadelta, dict(weight_decay=0)),
|
|
(optim.Adadelta, dict(weight_decay=1)),
|
|
(optim.Adagrad, dict(weight_decay=0)),
|
|
(optim.Adagrad, dict(weight_decay=1)),
|
|
]
|
|
|
|
def test_multi_tensor_optimizers(self):
|
|
self._test_derived_optimizers(self._multi_tensor_optimizer_configs, "foreach")
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
|
|
def test_multi_tensor_optimizers_with_varying_tensors(self):
|
|
self._test_derived_optimizers_varying_tensors(self._multi_tensor_optimizer_configs, "foreach")
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Requires a GPU")
|
|
@largeTensorTest("72GB", "cuda")
|
|
def test_multi_tensor_optimizers_with_large_tensors(self):
|
|
for optimizer_ctor, optimizer_params in self._multi_tensor_optimizer_configs:
|
|
# note(crcrpar): H100 wasn't sufficient for Adamax, surprisingly
|
|
if optimizer_ctor == optim.Adamax:
|
|
continue
|
|
params = [torch.ones(2 ** 32, device="cuda", dtype=torch.float16)]
|
|
params[0].grad = torch.zeros_like(params[0])
|
|
optimizer = optimizer_ctor(params, foreach=True, **optimizer_params)
|
|
optimizer.step()
|
|
|
|
def test_peak_mem_multi_tensor_optimizers(self):
|
|
configs = [(o, d) for (o, d) in self._multi_tensor_optimizer_configs if o.__name__ == "Adam"]
|
|
self._test_foreach_memory(configs)
|
|
|
|
@property
|
|
def _fused_optimizer_configs(self):
|
|
return tuple(itertools.product(
|
|
(optim.Adam, optim.AdamW),
|
|
(
|
|
dict(weight_decay=1., amsgrad=False, capturable=True, maximize=True),
|
|
dict(weight_decay=1., amsgrad=False, maximize=True),
|
|
dict(weight_decay=1., amsgrad=True),
|
|
dict(weight_decay=0., amsgrad=False),
|
|
dict(weight_decay=0., amsgrad=True, capturable=True, maximize=True),
|
|
dict(weight_decay=0., amsgrad=True, maximize=True),
|
|
),
|
|
))
|
|
|
|
def test_fused_optimizers(self):
|
|
self._test_derived_optimizers(self._fused_optimizer_configs, "fused")
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
|
|
def test_fused_optimizers_with_varying_tensors(self):
|
|
self._test_derived_optimizers_varying_tensors(self._fused_optimizer_configs, "fused")
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Requires a GPU")
|
|
@largeTensorTest("64GB", "cuda")
|
|
def test_fused_optimizers_with_large_tensors(self):
|
|
for optimizer_ctor, optimizer_params in self._fused_optimizer_configs:
|
|
params = [torch.ones(2 ** 32, device="cuda", dtype=torch.float16)]
|
|
params[0].grad = torch.zeros_like(params[0])
|
|
optimizer = optimizer_ctor(params, fused=True, **optimizer_params)
|
|
optimizer.step()
|
|
|
|
def test_adam(self):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
amsgrad=True,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
weight_decay=0.1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
amsgrad=True,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[lambda opt: ExponentialLR(opt, gamma=0.9)],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
amsgrad=True,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[
|
|
lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
|
|
lambda opt: ExponentialLR(opt, gamma=0.9),
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
amsgrad=True,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[
|
|
lambda opt: ExponentialLR(opt, gamma=0.9),
|
|
lambda opt: ReduceLROnPlateau(opt),
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
amsgrad=True,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[
|
|
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
|
|
lambda opt: ReduceLROnPlateau(opt),
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[lambda opt: PolynomialLR(opt, total_iters=4, power=0.9)],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_complex_2d(optim.Adam)
|
|
self._test_complex_2d(functools.partial(optim.Adam, foreach=True))
|
|
self._test_complex_2d(functools.partial(optim.Adam, foreach=True, weight_decay=0.2))
|
|
|
|
with self.assertRaisesRegex(
|
|
ValueError, "Invalid beta parameter at index 0: 1.0"
|
|
):
|
|
optim.Adam(None, lr=1e-2, betas=(1.0, 0.0))
|
|
|
|
with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
|
|
optim.Adam(None, lr=1e-2, weight_decay=-1)
|
|
|
|
def test_adamw(self):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.AdamW(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.AdamW(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.AdamW(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
weight_decay=1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.AdamW(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
weight_decay=1,
|
|
amsgrad=True,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_complex_2d(optim.AdamW)
|
|
self._test_complex_2d(functools.partial(optim.AdamW, foreach=True))
|
|
with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
|
|
optim.AdamW(None, lr=1e-2, weight_decay=-1)
|
|
|
|
def test_sparse_adam(self):
|
|
self._test_rosenbrock_sparse(
|
|
lambda params: optim.SparseAdam(params, lr=4e-2), [], True
|
|
)
|
|
self._test_rosenbrock_sparse(
|
|
lambda params: optim.SparseAdam(params, lr=4e-2, maximize=True),
|
|
scheduler_constructors=[],
|
|
sparse_only=True,
|
|
maximize=True,
|
|
)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "Invalid beta parameter at index 0: 1.0"
|
|
):
|
|
optim.SparseAdam(None, lr=1e-2, betas=(1.0, 0.0))
|
|
with self.assertRaisesRegex(
|
|
ValueError, "SparseAdam requires dense parameter tensors"
|
|
):
|
|
optim.SparseAdam([torch.zeros(3, layout=torch.sparse_coo)])
|
|
with self.assertRaisesRegex(
|
|
ValueError, "SparseAdam requires dense parameter tensors"
|
|
):
|
|
optim.SparseAdam([{"params": [torch.zeros(3, layout=torch.sparse_coo)]}])
|
|
|
|
# ROCm precision is too low to pass this test
|
|
def test_adadelta(self):
|
|
# Handles https://github.com/pytorch/pytorch/issues/69698
|
|
self.rel_tol = 4e-3
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adadelta(
|
|
[weight, bias], maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adadelta(
|
|
self._build_params_dict(weight, bias, rho=0.95),
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adadelta(
|
|
self._build_params_dict(weight, bias, rho=0.95),
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[
|
|
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
|
|
lambda opt: ReduceLROnPlateau(opt),
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adadelta(
|
|
[weight, bias], weight_decay=1, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "Invalid rho value: 1.1"):
|
|
optim.Adadelta(None, lr=1e-2, rho=1.1)
|
|
|
|
def test_adadelta_complex(self):
|
|
# Handles https://github.com/pytorch/pytorch/issues/69698
|
|
self.rel_tol = 2e-2
|
|
for optimizer in [optim.Adadelta]:
|
|
self._test_complex_optimizer(lambda weight: optimizer([weight]))
|
|
self._test_complex_optimizer(lambda weight: optimizer([weight], rho=0.95))
|
|
self._test_complex_optimizer(
|
|
lambda weight: optimizer([weight], rho=0.95, weight_decay=1)
|
|
)
|
|
|
|
def test_nadam(self):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, foreach: optim.NAdam(
|
|
self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, foreach=foreach
|
|
),
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, foreach: optim.NAdam(
|
|
[weight, bias], lr=1e-3, foreach=foreach
|
|
),
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, foreach: optim.NAdam(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
weight_decay=0.1,
|
|
momentum_decay=6e-3,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, foreach: optim.NAdam(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
weight_decay=0.1,
|
|
momentum_decay=6e-3,
|
|
foreach=foreach,
|
|
),
|
|
[lambda opt: ExponentialLR(opt, gamma=0.9)],
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "Invalid beta parameter at index 0: 1.0"
|
|
):
|
|
optim.NAdam(None, lr=1e-2, betas=(1.0, 0.0))
|
|
with self.assertRaisesRegex(ValueError, "Invalid momentum_decay value: -0.2"):
|
|
optim.NAdam(None, lr=1e-2, momentum_decay=-0.2)
|
|
|
|
def test_adagrad(self):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adagrad(
|
|
[weight, bias], lr=1e-1, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adagrad(
|
|
[weight, bias],
|
|
lr=1e-1,
|
|
initial_accumulator_value=0.1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adagrad(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adagrad(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[lambda opt: ReduceLROnPlateau(opt)],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adagrad(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[
|
|
lambda opt: ReduceLROnPlateau(opt),
|
|
lambda opt: ExponentialLR(opt, gamma=0.99),
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "Invalid lr_decay value: -0.5"):
|
|
optim.Adagrad(None, lr=1e-2, lr_decay=-0.5)
|
|
|
|
def test_adagrad_sparse(self):
|
|
for foreach in (False, True):
|
|
self._test_rosenbrock_sparse(
|
|
lambda params: optim.Adagrad(params, lr=1e-1, foreach=foreach)
|
|
)
|
|
self._test_rosenbrock_sparse(
|
|
lambda params: optim.Adagrad(params, lr=0.1, foreach=foreach),
|
|
scheduler_constructors=[
|
|
lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500),
|
|
lambda opt: ReduceLROnPlateau(opt, threshold=1e-4),
|
|
],
|
|
)
|
|
|
|
def test_adagrad_complex(self):
|
|
for foreach in (False, True):
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.Adagrad([param], lr=1e-1, foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.Adagrad(
|
|
[param],
|
|
lr=1e-1,
|
|
initial_accumulator_value=0.1,
|
|
foreach=foreach,
|
|
)
|
|
)
|
|
|
|
def test_adamax(self):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adamax(
|
|
[weight, bias], lr=1e-1, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adamax(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adamax(
|
|
[weight, bias],
|
|
lr=1e-1,
|
|
weight_decay=1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_complex_2d(optim.Adamax)
|
|
self._test_complex_2d(functools.partial(optim.Adamax, foreach=True))
|
|
with self.assertRaisesRegex(
|
|
ValueError, "Invalid beta parameter at index 1: 1.0"
|
|
):
|
|
optim.Adamax(None, lr=1e-2, betas=(0.0, 1.0))
|
|
|
|
def test_radam(self):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, foreach: optim.RAdam(
|
|
[weight, bias], lr=1e-3, foreach=foreach
|
|
),
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, foreach: optim.RAdam(
|
|
self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, foreach=foreach
|
|
),
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, foreach: optim.RAdam(
|
|
[weight, bias], lr=1e-3, weight_decay=0.1, foreach=foreach
|
|
),
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, foreach: optim.RAdam(
|
|
[weight, bias], lr=1e-3, foreach=foreach
|
|
),
|
|
[
|
|
lambda opt: ExponentialLR(opt, gamma=0.9),
|
|
lambda opt: ReduceLROnPlateau(opt),
|
|
],
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "Invalid beta parameter at index 0: 1.0"
|
|
):
|
|
optim.RAdam(None, lr=1e-2, betas=(1.0, 0.0))
|
|
|
|
with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
|
|
optim.RAdam(None, lr=1e-2, weight_decay=-1)
|
|
|
|
def test_rmsprop(self):
|
|
for foreach in (False, True):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.RMSprop(
|
|
[weight, bias], lr=1e-2, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.RMSprop(
|
|
self._build_params_dict(weight, bias, lr=1e-3),
|
|
lr=1e-2,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.RMSprop(
|
|
self._build_params_dict(weight, bias, lr=1e-3),
|
|
lr=1e-2,
|
|
centered=True,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.RMSprop(
|
|
self._build_params_dict(weight, bias, lr=1e-3),
|
|
lr=1e-2,
|
|
centered=True,
|
|
momentum=0.1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.RMSprop(
|
|
self._build_params_dict(weight, bias, lr=1e-3),
|
|
lr=1e-2,
|
|
momentum=0.1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.RMSprop(
|
|
self._build_params_dict(weight, bias, lr=1e-3),
|
|
lr=1e-2,
|
|
momentum=0.1,
|
|
weight_decay=1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_complex_2d(lambda param: optim.RMSprop(param, foreach=foreach))
|
|
self._test_complex_2d(
|
|
lambda param: optim.RMSprop(param, centered=True, foreach=foreach)
|
|
)
|
|
self._test_complex_2d(
|
|
lambda param: optim.RMSprop(param, momentum=0.1, foreach=foreach)
|
|
)
|
|
self._test_complex_2d(
|
|
lambda param: optim.RMSprop(param, maximize=True, foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.RMSprop([param], foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.RMSprop([param], centered=True, foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.RMSprop([param], momentum=0.1, foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.RMSprop([param], maximize=True, foreach=foreach)
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"):
|
|
optim.RMSprop(None, lr=1e-2, momentum=-1.0, foreach=foreach)
|
|
|
|
def test_asgd(self):
|
|
for foreach in (False, True):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.ASGD(
|
|
[weight, bias], lr=1e-3, t0=100, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.ASGD(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
t0=100,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.ASGD(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
weight_decay=1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
# Ref: https://github.com/pytorch/pytorch/issues/84560
|
|
# self._test_complex_2d(optimizer)
|
|
self._test_complex_optimizer(
|
|
lambda params: optim.ASGD([params], foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda params: optim.ASGD([params], maximize=True, foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda params: optim.ASGD(
|
|
[params], maximize=True, weight_decay=0.9, foreach=foreach
|
|
)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda params: optim.ASGD(
|
|
[params], maximize=False, weight_decay=0.9, foreach=foreach
|
|
)
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -0.5"):
|
|
optim.ASGD(None, lr=1e-2, weight_decay=-0.5, foreach=foreach)
|
|
|
|
@skipIfRocm
|
|
@skipIfTorchDynamo()
|
|
def test_rprop(self):
|
|
is_cuda_sm86 = torch.cuda.is_available() and torch.cuda.get_device_capability(
|
|
0
|
|
) == (8, 6)
|
|
for foreach in (False, True):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Rprop(
|
|
[weight, bias], lr=2e-4, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Rprop(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=2e-4,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
atol=4e-5 if is_cuda_sm86 else None,
|
|
rtol=3e-5 if is_cuda_sm86 else None,
|
|
)
|
|
self._test_complex_2d(lambda param: optim.Rprop(param, foreach=foreach))
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.Rprop([param], lr=0.001, foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.Rprop(
|
|
[param], lr=0.001, maximize=True, foreach=foreach
|
|
)
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "Invalid eta values: 1.0, 0.5"):
|
|
optim.Rprop(None, lr=1e-2, etas=(1.0, 0.5), foreach=foreach)
|
|
|
|
def test_lbfgs(self):
|
|
self._test_basic_cases(
|
|
lambda weight, bias: optim.LBFGS([weight, bias]), ignore_multidevice=True
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias: optim.LBFGS(
|
|
[weight, bias], line_search_fn="strong_wolfe"
|
|
),
|
|
ignore_multidevice=True,
|
|
)
|
|
|
|
def test_lbfgs_returns_consistent_type(self):
|
|
params = [torch.randn(10, 5), torch.randn(10)]
|
|
opt1 = optim.LBFGS(params, 0.01, tolerance_grad=math.inf)
|
|
opt2 = optim.LBFGS(params, 0.01, tolerance_grad=-math.inf)
|
|
|
|
def closure():
|
|
return torch.tensor([10])
|
|
|
|
res1 = opt1.step(closure)
|
|
res2 = opt2.step(closure)
|
|
self.assertEqual(type(res1), type(res2))
|
|
|
|
def test_invalid_param_type(self):
|
|
self.assertRaisesRegex(
|
|
TypeError,
|
|
'params argument given to the optimizer should be an iterable of Tensors or dicts',
|
|
lambda: optim.LBFGS(Parameter(torch.randn(5, 5)))
|
|
)
|
|
|
|
def test_duplicate_params_in_one_param_group(self):
|
|
param = Parameter(torch.randn(1))
|
|
with self.assertWarnsOnceRegex(UserWarning, '.*a parameter group with duplicate parameters.*'):
|
|
optim.Adamax([param, param], lr=0.01)
|
|
|
|
def test_duplicate_params_across_param_groups(self):
|
|
param = Parameter(torch.randn(1))
|
|
self.assertRaisesRegex(
|
|
ValueError,
|
|
'some parameters appear in more than one parameter group',
|
|
lambda: optim.Adadelta([{'params': param}, {'params': param}])
|
|
)
|
|
|
|
def test_step_is_noop_when_params_have_no_grad(self):
|
|
params = [torch.randn(2, 3, requires_grad=False) for _ in range(2)]
|
|
old_params = [p.clone().detach() for p in params]
|
|
|
|
def closure():
|
|
return torch.tensor([1])
|
|
|
|
optimizer_list = [
|
|
optim.Adadelta,
|
|
optim.AdamW,
|
|
optim.Adam,
|
|
optim.RAdam,
|
|
optim.NAdam,
|
|
optim.Adagrad,
|
|
optim.Adamax,
|
|
optim.RMSprop,
|
|
optim.SGD,
|
|
optim.SparseAdam,
|
|
optim.ASGD,
|
|
optim.LBFGS
|
|
]
|
|
for optim_ctr in optimizer_list:
|
|
opt = optim_ctr(params, lr=0.1)
|
|
opt.step(closure)
|
|
self.assertEqual(old_params, params)
|
|
|
|
|
|
def test_step_is_noop_for_empty_grads(self):
|
|
optimizers = [
|
|
optim.Adadelta,
|
|
optim.AdamW,
|
|
optim.Adam,
|
|
optim.RAdam,
|
|
optim.NAdam,
|
|
optim.Adagrad,
|
|
optim.Adamax,
|
|
optim.RMSprop,
|
|
optim.SGD,
|
|
optim.SparseAdam,
|
|
optim.ASGD,
|
|
optim.LBFGS
|
|
]
|
|
param = torch.randn(5, 1, requires_grad=True)
|
|
old_param = param.clone().detach()
|
|
|
|
def closure():
|
|
return torch.tensor([1])
|
|
|
|
for optimizer in optimizers:
|
|
opt = optimizer([param], lr=1e-5)
|
|
param.grad = torch.zeros_like(param)
|
|
if optimizer is optim.SparseAdam:
|
|
# Intentionally construct a multidimensional empty v for the sparse grad
|
|
# Single dim v passes the test while multidim correctly repros the issue
|
|
# https://github.com/pytorch/pytorch/issues/82486
|
|
i = torch.empty(1, 0)
|
|
v = torch.empty(0, 1)
|
|
param.grad = torch.sparse_coo_tensor(i, v, (5, 1))
|
|
opt.step(closure)
|
|
self.assertEqual(old_param, param)
|
|
|
|
|
|
    def test_fused_optimizer_does_not_step_if_foundinf(self):
        if not torch.cuda.is_available():
            self.skipTest("CUDA is required.")

        from torch.optim import adam, adamw

        num_tensors = 5
        for functional_optim, amsgrad, no_grad_scale in itertools.product((adam.adam, adamw.adamw), (False, True), (False, True)):
            params, grads, exp_avgs, exp_avg_sqs = [
                [torch.ones((1,), device="cuda") for _ in range(num_tensors)] for _ in range(4)]
            prev_params = [t.clone().detach() for t in params]
            max_exp_avg_sqs = [torch.ones((1,), device="cuda") for _ in range(num_tensors)] if amsgrad else []
            state_steps = [torch.ones((), dtype=torch.float32, device="cuda") for _ in range(num_tensors)]
            grad_scale = None if no_grad_scale else torch.ones((1,), dtype=torch.float32, device="cuda")
            # A nonzero found_inf signals that an inf/nan was seen during unscaling, so the
            # fused kernel must skip both the parameter update and the step-count increment.
            found_inf = torch.ones((), dtype=torch.float32, device="cuda")

            functional_optim(
                params,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                foreach=False,
                capturable=False,
                fused=True,
                amsgrad=amsgrad,
                beta1=0.9,
                beta2=0.99,
                lr=1e-2,
                weight_decay=0.0,
                eps=1e-8,
                maximize=False,
                grad_scale=grad_scale,
                found_inf=found_inf,
            )

            self.assertEqual(
                state_steps,
                [
                    torch.ones((), dtype=torch.float32, device="cuda")
                    for _ in range(num_tensors)
                ],
            )
            self.assertEqual(params, prev_params)

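    # For context (a minimal sketch, not exercised by this test): in real mixed-precision
    # training the found_inf tensor is produced by torch.cuda.amp.GradScaler rather than
    # constructed by hand, roughly:
    #
    #     scaler = torch.cuda.amp.GradScaler()
    #     scaler.scale(loss).backward()
    #     scaler.step(optimizer)   # unscales grads; skips the step if an inf/nan was found
    #     scaler.update()
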
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required.")
    def test_fused_optimizer_load_state_dict(self):
        # NOTE: This SIMULATES a fused/capturable optimizer whose state has been moved to CPU, see issue 103256.
        # How do we get there? Users typically train CUDA models with fused optimizers, save
        # checkpoints, and then reload them on CPU (since CUDA memory is limited) with
        # torch.load(...map_location="cpu").
        # Since this is a unit test, it is more expedient to simulate what the state_dict
        # would look like, which is basically CPU tensors with the fused/capturable flag = True.
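        # A minimal sketch of the real-world flow being simulated (paths are illustrative):
        #
        #     torch.save(optimizer.state_dict(), "ckpt.pt")
        #     state_dict = torch.load("ckpt.pt", map_location="cpu")
        #     new_optimizer.load_state_dict(state_dict)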
        for optimC, kwarg in itertools.product((Adam, optim.AdamW), ("fused", "capturable")):
            input = torch.tensor([0.1, 0.2], dtype=torch.float32, device="cpu")
            optimizer = optimC([input])
            optimizer.zero_grad()
            input.grad = torch.rand_like(input)
            optimizer.step()
            optim_state_dict_cpu = deepcopy(optimizer.state_dict())
            optim_state_dict_cpu["param_groups"][0][kwarg] = True

            # load
            input_cuda = input.clone().detach().to(device="cuda")
            defaults = {kwarg: True}
            optimizer_cuda = optimC([input_cuda], **defaults)
            optimizer_cuda.load_state_dict(optim_state_dict_cpu)
            optimizer_cuda.zero_grad()
            input_cuda.grad = torch.rand_like(input_cuda)
            optimizer_cuda.step()

    @skipIfTorchDynamo()
    def test_post_hook(self):
        def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data += 2

        params = [torch.Tensor([1, 1])]
        opt = SGD(params, lr=0.001)
        data = 2
        hook_handle = opt.register_step_post_hook(post_hook)

        opt.step()
        opt.step()
        # check that the post hook fired on both steps
        self.assertEqual(data, 6)

        # remove the handle, take a step, and verify that the hook is no longer registered
        hook_handle.remove()

        opt.step()
        self.assertEqual(data, 6)

    @skipIfTorchDynamo()
    def test_pre_hook(self):
        def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data += 2

        params = [torch.Tensor([1, 1])]
        opt = SGD(params, lr=0.001)
        data = 5
        hook_handle = opt.register_step_pre_hook(pre_hook)

        opt.step()
        opt.step()
        # check that the pre hook fired on both steps
        self.assertEqual(data, 9)

        # remove the handle, take a step, and verify that the hook is no longer registered
        hook_handle.remove()

        opt.step()
        self.assertEqual(data, 9)

    @skipIfTorchDynamo()
    def test_pre_and_post_hook(self):
        def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data.append(0)

        def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data.append(5)

        def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data.append(1)

        def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data.append(2)

        params = [torch.Tensor([1, 1])]
        opt1 = SGD(params, lr=0.001)
        opt2 = Adam(params, lr=0.01)
        data = []

        # register global hooks to both optimizers
        global_pre_handle = register_optimizer_step_pre_hook(global_pre_hook)
        global_post_handle = register_optimizer_step_post_hook(global_post_hook)

        # register local hooks
        first_pre_handle = opt1.register_step_pre_hook(local_pre_hook)
        first_post_handle = opt1.register_step_post_hook(local_post_hook)
        second_pre_handle = opt2.register_step_pre_hook(local_pre_hook)
        second_post_handle = opt2.register_step_post_hook(local_post_hook)

        # Expected firing order per step: global pre (0), local pre (1), local post (2), global post (5)
        opt1.step()
        self.assertListEqual(data, [0, 1, 2, 5])
        opt2.step()
        self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5])
        opt1.step()
        self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])

        # remove all hooks
        global_pre_handle.remove()
        global_post_handle.remove()
        first_pre_handle.remove()
        first_post_handle.remove()
        second_pre_handle.remove()
        second_post_handle.remove()

        opt1.step()
        opt2.step()
        self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])

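    # For context (a minimal, hypothetical sketch not run here): global hooks are convenient
    # for cross-cutting concerns such as logging every step of every optimizer, e.g.
    #
    #     handle = register_optimizer_step_post_hook(
    #         lambda opt, args, kwargs: print(f"{type(opt).__name__} stepped"))
    #     ...
    #     handle.remove()
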
    def test_fused_optimizer_raises(self):
        if not torch.cuda.is_available():
            self.skipTest("Requires CUDA devices")
        for optimizer_ctor in (torch.optim.Adam, torch.optim.AdamW):
            with self.assertRaisesRegex(RuntimeError, "`fused` and `foreach` cannot be `True` together."):
                optimizer_ctor([torch.empty((), device="cuda")], foreach=True, fused=True)
            with self.assertRaisesRegex(RuntimeError, "`fused` does not support `differentiable`"):
                optimizer_ctor([torch.empty((), device="cuda")], differentiable=True, fused=True)


def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
    # `ignored` holds the values of `opt_differentiable_state`; we pass them separately so
    # that `gradcheck` tracks the state tensors as function inputs, because it cannot unpack
    # the values inside the `opt_differentiable_state` dict on its own.
    p = p.clone()
    p.grad = grad
    opt_differentiable_state = {
        k: v.clone() if isinstance(v, torch.Tensor) else v
        for k, v in opt_differentiable_state.items()
    }
    opt = opt_class([p], **kwargs)
    opt.state[p].update(opt_differentiable_state)
    opt.step()
    return (p,) + tuple(
        v
        for v in opt.state[p].values()
        if isinstance(v, torch.Tensor) and v.requires_grad
    )


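# For intuition (a minimal sketch, assuming the float64 setup used in the tests below):
#
#     p = torch.rand(10, dtype=torch.float64, requires_grad=True)
#     grad = torch.rand(10, dtype=torch.float64, requires_grad=True)
#     out = _diff_fn(p, grad, {}, torch.optim.SGD, {"lr": 0.9, "differentiable": True})
#     out[0].sum().backward()  # gradients flow back through the optimizer update to p and grad
#
# gradcheck does the same check numerically for every tensor input that requires grad.

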
@skipIfTorchDynamo("Differentiable optimizers not supported")
|
|
class TestDifferentiableOptimizer(TestCase):
|
|
|
|
def test_sgd(self):
|
|
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
|
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
|
mbuff = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
|
state = {"momentum_buffer": mbuff}
|
|
gradcheck(
|
|
_diff_fn,
|
|
(
|
|
p,
|
|
grad,
|
|
state,
|
|
torch.optim.SGD,
|
|
{"lr": 0.9, "differentiable": True},
|
|
*state.values(),
|
|
),
|
|
)
|
|
|
|
|
|
    def test_adam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.Adam,
                {"lr": 0.9, "differentiable": True, "amsgrad": True},
                *state.values(),
            ),
        )

    def test_rmsprop(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["step"] = 0
        state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["momentum_buffer"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        # Keep grad_avg small: large values can produce NaNs through the sqrt ops
        # in the centered update.
        state["grad_avg"] = 1e-2 * torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.RMSprop,
                {
                    "lr": 0.9,
                    "maximize": True,
                    "momentum": 0.9,
                    "differentiable": True,
                    "centered": True,
                    "weight_decay": 0.1,
                },
                *state.values(),
            ),
        )

    def test_adadelta(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["acc_delta"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.Adadelta,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adagrad(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["sum"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.Adagrad,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adamax(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_inf"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.Adamax,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    @skipIfTorchDynamo("The inplace mu update fails with dynamo, "
                       "since this is only happening when differentiable is enabled, skipping for now")
    def test_asgd(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step`, `eta`, and `mu` are not continuous variables (even though we define them as floats)
        # and so they shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["eta"] = torch.tensor(0.9, requires_grad=False, dtype=torch.float64)
        state["mu"] = torch.tensor(1.0, requires_grad=False, dtype=torch.float64)
        state["ax"] = torch.rand(10, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.ASGD,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

    def test_rprop(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["prev"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["step_size"] = torch.rand(10, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.Rprop,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adamw(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.AdamW,
                {"lr": 0.9, "differentiable": True, "amsgrad": True},
                *state.values(),
            ),
        )

    def test_nadam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["mu_product"] = torch.tensor(1.0, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.NAdam,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

    def test_radam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.RAdam,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
    def test_defaults_changed_to_foreach(self):
        from torch.optim import (adam, adamw, nadam, sgd, radam, rmsprop, rprop,
                                 asgd, adamax, adadelta, adagrad)
        multi_optims = ((optim.Adam, adam, "_multi_tensor_adam"),
                        (optim.AdamW, adamw, "_multi_tensor_adamw"),
                        (optim.NAdam, nadam, "_multi_tensor_nadam"),
                        (optim.SGD, sgd, "_multi_tensor_sgd"),
                        (optim.RAdam, radam, "_multi_tensor_radam"),
                        (optim.RMSprop, rmsprop, "_multi_tensor_rmsprop"),
                        (optim.Rprop, rprop, "_multi_tensor_rprop"),
                        (optim.ASGD, asgd, "_multi_tensor_asgd"),
                        (optim.Adamax, adamax, "_multi_tensor_adamax"),
                        (optim.Adadelta, adadelta, "_multi_tensor_adadelta"),
                        (optim.Adagrad, adagrad, "_multi_tensor_adagrad"),)

        model = torch.nn.Linear(5, 5)
        model.to(dtype=torch.float64, device="cuda")
        input = torch.rand(2, 5, dtype=torch.float64, device="cuda")

        for opt, mod, func in multi_optims:
            defaults = {}
            if opt == optim.SGD:
                # SGD has no default lr
                defaults["lr"] = 1e-2
            optimizer = opt(model.parameters(), **defaults)
            optimizer.zero_grad()
            output = model(input)
            loss = output.sum()
            loss.backward()
            # Patch the multi-tensor (foreach) implementation and confirm that a plain
            # step() on CUDA params dispatches to it by default.
            with patch.object(mod, func) as mocked_foreach_impl:
                optimizer.step()
            self.assertTrue(mocked_foreach_impl.called)
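
    # For context (a minimal sketch, not part of the test): the foreach path that the defaults
    # select here can also be requested explicitly, e.g.
    #
    #     optim.Adam(model.parameters(), foreach=True)
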

if __name__ == "__main__":
    print("These tests should be run through test/test_optim.py instead")