Migrate test_complex_optimizer to OptimizerInfo (#118160)

This PR does what it says and more.

1. We increase coverage by a LOT! Previously, complex was not tested for many configs, including foreach + maximize at the same time, the fused impls, or other configs people simply forgot about.
2. I rearranged the maximize conditional and the _view_as_real call to preserve list-ness, which _view_as_real needs to function properly; I added a comment about this in the Files Changed (see the sketch right after this list). The new order also just makes more aesthetic sense.
3. Note that LBFGS and SparseAdam are skipped: they don't support complex, and now we know.
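
For item 2, a minimal sketch of the list-ness concern. The view_as_real_ helper below is a simplified stand-in written only for illustration; it is not the actual private torch.optim helper, though it mimics the in-place element replacement that _view_as_real relies on:

import torch

def view_as_real_(params, *tensorlists):
    # Stand-in: replace complex entries with real views, in place.
    # The item assignment only works while these are still plain lists.
    for i, p in enumerate(params):
        if torch.is_complex(p):
            params[i] = torch.view_as_real(p)
            for tl in tensorlists:
                tl[i] = torch.view_as_real(tl[i])

params = [torch.randn(2, dtype=torch.complex64)]
grads = [torch.randn(2, dtype=torch.complex64)]

# New order (this PR): view as real while grads is still a mutable list...
view_as_real_(params, grads)
# ...then negate for maximize. _foreach_neg hands back a new container of
# new tensors, which is fine because no further in-place list surgery is needed.
grads = torch._foreach_neg(grads)

# Old order: negating first rebinds grads to _foreach_neg's freshly allocated
# output, which does not support item assignment, so the tl[i] = ... replacement
# inside the helper breaks -- the "list-ness" item 2 refers to.

The real foreach implementations in the hunks below follow the same shape: _view_as_real(device_params, device_grads, ...) first, torch._foreach_neg(device_grads) after.
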

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118160
Approved by: https://github.com/mikaylagawarecki
Authored by Jane Xu on 2024-01-24 10:14:27 -08:00; committed by PyTorch MergeBot
parent 6978c3ddf3
commit 17ecd1e9cd
10 changed files with 84 additions and 190 deletions

View File

@ -270,19 +270,6 @@ class TestOptim(TestCase):
constructor_accepts_foreach,
)
def _test_complex_optimizer(self, optimizer_constructor):
complex_param = torch.randn(5, 5, dtype=torch.complex64, requires_grad=True)
real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_()
complex_opt = optimizer_constructor(complex_param)
real_opt = optimizer_constructor(real_param)
for _ in range(3):
complex_param.grad = torch.randn_like(complex_param)
real_param.grad = torch.view_as_real(complex_param.grad)
complex_opt.step()
real_opt.step()
self.assertEqual(torch.view_as_real(complex_param), real_param)
def _test_complex_2d(self, optimizer_constructor):
a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True)
@ -398,40 +385,6 @@ class TestOptim(TestCase):
multi_tensor=foreach,
)
def test_sgd_complex(self):
for foreach in (False, True):
self._test_complex_optimizer(
lambda param: SGD([param], lr=0.001, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: SGD([param], lr=0.001, momentum=1, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: SGD(
[param], lr=0.001, momentum=1, weight_decay=1, foreach=foreach
)
)
self._test_complex_optimizer(
lambda param: SGD(
[param],
lr=0.001,
nesterov=True,
momentum=1,
weight_decay=1,
foreach=foreach,
)
)
self._test_complex_optimizer(
lambda param: SGD(
[param],
lr=0.001,
momentum=1,
dampening=0.5,
weight_decay=1,
foreach=foreach,
)
)
def test_adam(self):
self._test_basic_cases(
@ -603,15 +556,6 @@ class TestOptim(TestCase):
)
def test_adadelta_complex(self):
# Handles https://github.com/pytorch/pytorch/issues/110606
self.rel_tol = 2e-2
for foreach in (False, True):
self._test_complex_optimizer(lambda weight: Adadelta([weight], foreach=foreach))
self._test_complex_optimizer(lambda weight: Adadelta([weight], rho=0.95, foreach=foreach))
self._test_complex_optimizer(
lambda weight: Adadelta([weight], rho=0.95, weight_decay=1, foreach=foreach)
)
def test_nadam(self):
self._test_basic_cases(
@ -640,28 +584,6 @@ class TestOptim(TestCase):
)
def test_nadam_complex(self):
for foreach in (False, True):
self._test_complex_optimizer(
lambda param: NAdam([param], lr=1e-1, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: NAdam(
[param],
lr=1e-1,
weight_decay=0.01,
foreach=foreach,
)
)
self._test_complex_optimizer(
lambda param: NAdam(
[param],
lr=1e-1,
momentum_decay=0.01,
foreach=foreach,
)
)
def test_adagrad(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adagrad(
@ -705,19 +627,6 @@ class TestOptim(TestCase):
multi_tensor=foreach,
)
def test_adagrad_complex(self):
for foreach in (False, True):
self._test_complex_optimizer(
lambda param: Adagrad([param], lr=1e-1, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: Adagrad(
[param],
lr=1e-1,
initial_accumulator_value=0.1,
foreach=foreach,
)
)
def test_adamax(self):
self._test_complex_2d(Adamax)
@ -748,29 +657,6 @@ class TestOptim(TestCase):
)
def test_radam_complex(self):
for foreach in (False, True):
self._test_complex_optimizer(
lambda param: RAdam([param], lr=1e-1, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: RAdam(
[param],
lr=1e-1,
weight_decay=0.01,
foreach=foreach,
)
)
self._test_complex_optimizer(
lambda param: RAdam(
[param],
lr=1e-1,
weight_decay=0.01,
decoupled_weight_decay=True,
foreach=foreach,
)
)
def test_rmsprop(self):
for foreach in (False, True):
self._test_complex_2d(lambda param: RMSprop(param, foreach=foreach))
@ -783,40 +669,6 @@ class TestOptim(TestCase):
self._test_complex_2d(
lambda param: RMSprop(param, maximize=True, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: RMSprop([param], foreach=foreach)
)
self._test_complex_optimizer(
lambda param: RMSprop([param], centered=True, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: RMSprop([param], momentum=0.1, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: RMSprop([param], maximize=True, foreach=foreach)
)
def test_asgd(self):
for foreach in (False, True):
# Ref: https://github.com/pytorch/pytorch/issues/84560
# self._test_complex_2d(optimizer)
self._test_complex_optimizer(
lambda params: ASGD([params], foreach=foreach)
)
self._test_complex_optimizer(
lambda params: ASGD([params], maximize=True, foreach=foreach)
)
self._test_complex_optimizer(
lambda params: ASGD(
[params], maximize=True, weight_decay=0.1, foreach=foreach
)
)
self._test_complex_optimizer(
lambda params: ASGD(
[params], maximize=False, weight_decay=0.1, foreach=foreach
)
)
@skipIfRocm
@ -824,14 +676,6 @@ class TestOptim(TestCase):
def test_rprop(self):
for foreach in (False, True):
self._test_complex_2d(lambda param: Rprop(param, foreach=foreach))
self._test_complex_optimizer(
lambda param: Rprop([param], lr=0.001, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: Rprop(
[param], lr=0.001, maximize=True, foreach=foreach
)
)
def test_lbfgs_returns_consistent_type(self):

View File

@ -128,6 +128,30 @@ class TestOptimRenewed(TestCase):
self.assertLess(closure().item(), initial_value)
@skipMPS
@optims(optim_db, dtypes=[torch.complex64])
def test_complex(self, device, dtype, optim_info):
optim_cls = optim_info.optim_cls
# Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
for optim_input in all_optim_inputs:
complex_params = [torch.randn(2, 3, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
real_params = [torch.view_as_real(p).detach().clone().requires_grad_(True) for p in complex_params]
complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
real_optimizer = optim_cls(real_params, **optim_input.kwargs)
for _ in range(3):
for (c, r) in zip(complex_params, real_params):
c.grad = torch.randn_like(c)
r.grad = torch.view_as_real(c.grad)
complex_optimizer.step()
real_optimizer.step()
for (c, r) in zip(complex_params, real_params):
self.assertEqual(torch.view_as_real(c), r)
def _test_derived_optimizers(self, device, dtype, optim_info, flag, reduced_precision=False, assert_step_dtype=None):
"""
Given a flag 'fused' or 'foreach', test for parity of optimizer state

View File

@ -286,12 +286,12 @@ def _multi_tensor_adadelta(
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, square_avgs, acc_deltas])
for ((device_params, device_grads, device_square_avgs, device_acc_deltas), _) in grouped_tensors.values():
if maximize:
device_grads = torch._foreach_neg(device_grads)
if has_complex:
_view_as_real(device_params, device_grads, device_square_avgs, device_acc_deltas)
if maximize:
device_grads = torch._foreach_neg(device_grads)
if weight_decay != 0:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:

View File

@ -344,13 +344,13 @@ def _multi_tensor_adagrad(
)
continue
if maximize:
device_grads = torch._foreach_neg(device_grads)
# Handle complex parameters
if has_complex:
_view_as_real(device_params, device_grads, device_state_sums)
if maximize:
device_grads = torch._foreach_neg(device_grads)
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just

View File

@ -489,9 +489,6 @@ def _multi_tensor_adam(params: List[Tensor],
device_state_steps,
), _) in grouped_tensors.values():
if maximize:
device_grads = torch._foreach_neg(device_grads)
# Handle complex parameters
if has_complex:
if amsgrad:
@ -499,6 +496,9 @@ def _multi_tensor_adam(params: List[Tensor],
else:
_view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)
if maximize:
device_grads = torch._foreach_neg(device_grads)
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just

View File

@ -330,12 +330,12 @@ def _multi_tensor_adamax(
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_infs, state_steps])
for ((grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs, grouped_state_steps), _) in grouped_tensors.values():
if maximize:
grouped_grads = torch._foreach_neg(grouped_grads)
if has_complex:
_view_as_real(grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs)
if maximize:
grouped_grads = torch._foreach_neg(grouped_grads)
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just

View File

@ -521,15 +521,15 @@ def _multi_tensor_adamw(
device_max_exp_avg_sqs,
device_state_steps,
), _) in grouped_tensors.values():
if maximize:
device_grads = torch._foreach_neg(device_grads)
if has_complex:
if amsgrad:
_view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs, device_max_exp_avg_sqs)
else:
_view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)
if maximize:
device_grads = torch._foreach_neg(device_grads)
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just

View File

@ -313,13 +313,12 @@ def _multi_tensor_asgd(
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, axs, mus, etas, state_steps])
for ((device, _), ((grouped_params, grouped_grads, grouped_axs, grouped_mus,
grouped_etas, grouped_state_steps), _)) in grouped_tensors.items():
if maximize:
grouped_grads = torch._foreach_neg(grouped_grads)
grouped_grads = list(grouped_grads)
if has_complex:
_view_as_real(grouped_params, grouped_grads, grouped_axs)
if maximize:
grouped_grads = torch._foreach_neg(grouped_grads)
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just

View File

@ -336,6 +336,14 @@ def _multi_tensor_rmsprop(
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, square_avgs, grad_avgs, momentum_buffer_list])
for (((grouped_params, grouped_grads, grouped_square_avgs, grouped_grad_avgs,
grouped_momentum_buffer_list)), _) in grouped_tensors.values():
if has_complex:
state_and_grads = [grouped_grads, grouped_square_avgs]
if momentum > 0:
state_and_grads.append(grouped_momentum_buffer_list)
if centered:
state_and_grads.append(grouped_grad_avgs)
_view_as_real(grouped_params, *state_and_grads)
if maximize:
grouped_grads = torch._foreach_neg(grouped_grads)
@ -346,16 +354,6 @@ def _multi_tensor_rmsprop(
else:
grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)
grouped_grads = list(grouped_grads)
if has_complex:
state_and_grads = [grouped_grads, grouped_square_avgs]
if momentum > 0:
state_and_grads.append(grouped_momentum_buffer_list)
if centered:
state_and_grads.append(grouped_grad_avgs)
_view_as_real(grouped_params, *state_and_grads)
torch._foreach_mul_(grouped_square_avgs, alpha)
torch._foreach_addcmul_(grouped_square_avgs, grouped_grads, grouped_grads, value=1 - alpha)

View File

@ -263,9 +263,7 @@ def get_error_inputs_for_all_optims(device, dtype):
def optim_inputs_func_adadelta(device=None):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(
params=None, kwargs={"lr": 0.01}, desc="non-default lr"
), # TODO: Move out to testing in param_group?
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
@ -275,8 +273,8 @@ def optim_inputs_func_adadelta(device=None):
desc="maximize",
),
OptimizerInput(
params=None, kwargs={"rho": 0.95, "weight_decay": 0.1}, desc="rho"
), # TODO: Move out to testing in param_group?
params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
),
]
@ -494,6 +492,7 @@ def optim_inputs_func_asgd(device=None):
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"),
OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
@ -545,6 +544,21 @@ def optim_error_inputs_func_lbfgs(device, dtype):
def optim_inputs_func_nadam(device=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.9, "momentum_decay": 6e-3, "capturable": True},
desc="weight_decay, capturable",
),
OptimizerInput(
params=None,
kwargs={
"weight_decay": 0.9,
"momentum_decay": 6e-3,
"decoupled_weight_decay": True,
"capturable": True,
},
desc="decoupled_weight_decay, capturable",
),
]
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
@ -1106,6 +1120,11 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
skipIfTorchDynamo("Mismatched _foreach_addcdiv_ types, see #118159"),
"TestOptimRenewed",
"test_complex",
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/115607"
@ -1319,6 +1338,11 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
unittest.skip("Missing complex support, see #118148"),
"TestOptimRenewed",
"test_complex",
),
),
),
OptimizerInfo(
@ -1747,6 +1771,11 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_deepcopy_copies_all_public_attrs",
),
DecorateInfo(
unittest.skip("Missing complex support, see #118153"),
"TestOptimRenewed",
"test_complex",
),
),
),
]