From 17ecd1e9cde6531bb9da0b6742c0470924684fce Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Wed, 24 Jan 2024 10:14:27 -0800 Subject: [PATCH] Migrate test_complex_optimizer to OptimizerInfo (#118160) This PR does what it says and more. 1. We increase coverage by a LOT! Previously, complex was not tested for many many configs, including foreach + maximize at the same time. Or the fused impls. Or just random configs people forgot about. 2. I rearranged the maximize conditional and the _view_as_real to preserve list-ness. This is needed for _view_as_real to function properly, I did add a comment in the Files Changed. This new order also just...makes more aesthetic sense. 3. Note that LBFGS and SparseAdam are skipped--they don't support complex and now we know. Pull Request resolved: https://github.com/pytorch/pytorch/pull/118160 Approved by: https://github.com/mikaylagawarecki --- test/optim/test_optim.py | 156 ------------------- test/test_optim.py | 24 +++ torch/optim/adadelta.py | 6 +- torch/optim/adagrad.py | 6 +- torch/optim/adam.py | 6 +- torch/optim/adamax.py | 6 +- torch/optim/adamw.py | 6 +- torch/optim/asgd.py | 7 +- torch/optim/rmsprop.py | 18 +-- torch/testing/_internal/common_optimizers.py | 39 ++++- 10 files changed, 84 insertions(+), 190 deletions(-) diff --git a/test/optim/test_optim.py b/test/optim/test_optim.py index 3a1ac3dcd2d..744b3d0736f 100644 --- a/test/optim/test_optim.py +++ b/test/optim/test_optim.py @@ -270,19 +270,6 @@ class TestOptim(TestCase): constructor_accepts_foreach, ) - def _test_complex_optimizer(self, optimizer_constructor): - complex_param = torch.randn(5, 5, dtype=torch.complex64, requires_grad=True) - real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_() - complex_opt = optimizer_constructor(complex_param) - real_opt = optimizer_constructor(real_param) - - for _ in range(3): - complex_param.grad = torch.randn_like(complex_param) - real_param.grad = torch.view_as_real(complex_param.grad) - complex_opt.step() - real_opt.step() - - self.assertEqual(torch.view_as_real(complex_param), real_param) def _test_complex_2d(self, optimizer_constructor): a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True) @@ -398,40 +385,6 @@ class TestOptim(TestCase): multi_tensor=foreach, ) - def test_sgd_complex(self): - for foreach in (False, True): - self._test_complex_optimizer( - lambda param: SGD([param], lr=0.001, foreach=foreach) - ) - self._test_complex_optimizer( - lambda param: SGD([param], lr=0.001, momentum=1, foreach=foreach) - ) - self._test_complex_optimizer( - lambda param: SGD( - [param], lr=0.001, momentum=1, weight_decay=1, foreach=foreach - ) - ) - self._test_complex_optimizer( - lambda param: SGD( - [param], - lr=0.001, - nesterov=True, - momentum=1, - weight_decay=1, - foreach=foreach, - ) - ) - self._test_complex_optimizer( - lambda param: SGD( - [param], - lr=0.001, - momentum=1, - dampening=0.5, - weight_decay=1, - foreach=foreach, - ) - ) - def test_adam(self): self._test_basic_cases( @@ -603,15 +556,6 @@ class TestOptim(TestCase): ) - def test_adadelta_complex(self): - # Handles https://github.com/pytorch/pytorch/issues/110606 - self.rel_tol = 2e-2 - for foreach in (False, True): - self._test_complex_optimizer(lambda weight: Adadelta([weight], foreach=foreach)) - self._test_complex_optimizer(lambda weight: Adadelta([weight], rho=0.95, foreach=foreach)) - self._test_complex_optimizer( - lambda weight: Adadelta([weight], rho=0.95, weight_decay=1, foreach=foreach) - ) def test_nadam(self): self._test_basic_cases( @@ 
-640,28 +584,6 @@ class TestOptim(TestCase): ) - def test_nadam_complex(self): - for foreach in (False, True): - self._test_complex_optimizer( - lambda param: NAdam([param], lr=1e-1, foreach=foreach) - ) - self._test_complex_optimizer( - lambda param: NAdam( - [param], - lr=1e-1, - weight_decay=0.01, - foreach=foreach, - ) - ) - self._test_complex_optimizer( - lambda param: NAdam( - [param], - lr=1e-1, - momentum_decay=0.01, - foreach=foreach, - ) - ) - def test_adagrad(self): self._test_basic_cases( lambda weight, bias, maximize, foreach: Adagrad( @@ -705,19 +627,6 @@ class TestOptim(TestCase): multi_tensor=foreach, ) - def test_adagrad_complex(self): - for foreach in (False, True): - self._test_complex_optimizer( - lambda param: Adagrad([param], lr=1e-1, foreach=foreach) - ) - self._test_complex_optimizer( - lambda param: Adagrad( - [param], - lr=1e-1, - initial_accumulator_value=0.1, - foreach=foreach, - ) - ) def test_adamax(self): self._test_complex_2d(Adamax) @@ -748,29 +657,6 @@ class TestOptim(TestCase): ) - def test_radam_complex(self): - for foreach in (False, True): - self._test_complex_optimizer( - lambda param: RAdam([param], lr=1e-1, foreach=foreach) - ) - self._test_complex_optimizer( - lambda param: RAdam( - [param], - lr=1e-1, - weight_decay=0.01, - foreach=foreach, - ) - ) - self._test_complex_optimizer( - lambda param: RAdam( - [param], - lr=1e-1, - weight_decay=0.01, - decoupled_weight_decay=True, - foreach=foreach, - ) - ) - def test_rmsprop(self): for foreach in (False, True): self._test_complex_2d(lambda param: RMSprop(param, foreach=foreach)) @@ -783,40 +669,6 @@ class TestOptim(TestCase): self._test_complex_2d( lambda param: RMSprop(param, maximize=True, foreach=foreach) ) - self._test_complex_optimizer( - lambda param: RMSprop([param], foreach=foreach) - ) - self._test_complex_optimizer( - lambda param: RMSprop([param], centered=True, foreach=foreach) - ) - self._test_complex_optimizer( - lambda param: RMSprop([param], momentum=0.1, foreach=foreach) - ) - self._test_complex_optimizer( - lambda param: RMSprop([param], maximize=True, foreach=foreach) - ) - - - def test_asgd(self): - for foreach in (False, True): - # Ref: https://github.com/pytorch/pytorch/issues/84560 - # self._test_complex_2d(optimizer) - self._test_complex_optimizer( - lambda params: ASGD([params], foreach=foreach) - ) - self._test_complex_optimizer( - lambda params: ASGD([params], maximize=True, foreach=foreach) - ) - self._test_complex_optimizer( - lambda params: ASGD( - [params], maximize=True, weight_decay=0.1, foreach=foreach - ) - ) - self._test_complex_optimizer( - lambda params: ASGD( - [params], maximize=False, weight_decay=0.1, foreach=foreach - ) - ) @skipIfRocm @@ -824,14 +676,6 @@ class TestOptim(TestCase): def test_rprop(self): for foreach in (False, True): self._test_complex_2d(lambda param: Rprop(param, foreach=foreach)) - self._test_complex_optimizer( - lambda param: Rprop([param], lr=0.001, foreach=foreach) - ) - self._test_complex_optimizer( - lambda param: Rprop( - [param], lr=0.001, maximize=True, foreach=foreach - ) - ) def test_lbfgs_returns_consistent_type(self): diff --git a/test/test_optim.py b/test/test_optim.py index ea1beb89de0..73c08984d7e 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -128,6 +128,30 @@ class TestOptimRenewed(TestCase): self.assertLess(closure().item(), initial_value) + @skipMPS + @optims(optim_db, dtypes=[torch.complex64]) + def test_complex(self, device, dtype, optim_info): + optim_cls = optim_info.optim_cls + # Skip differentiable 
testing for now, see https://github.com/pytorch/pytorch/issues/116490 + all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",)) + for optim_input in all_optim_inputs: + complex_params = [torch.randn(2, 3, device=device, dtype=dtype, requires_grad=True) for _ in range(3)] + real_params = [torch.view_as_real(p).detach().clone().requires_grad_(True) for p in complex_params] + + complex_optimizer = optim_cls(complex_params, **optim_input.kwargs) + real_optimizer = optim_cls(real_params, **optim_input.kwargs) + + for _ in range(3): + for (c, r) in zip(complex_params, real_params): + c.grad = torch.randn_like(c) + r.grad = torch.view_as_real(c.grad) + complex_optimizer.step() + real_optimizer.step() + + for (c, r) in zip(complex_params, real_params): + self.assertEqual(torch.view_as_real(c), r) + + def _test_derived_optimizers(self, device, dtype, optim_info, flag, reduced_precision=False, assert_step_dtype=None): """ Given a flag 'fused' or 'foreach', test for parity of optimizer state diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py index 181008252d8..ac16c13101b 100644 --- a/torch/optim/adadelta.py +++ b/torch/optim/adadelta.py @@ -286,12 +286,12 @@ def _multi_tensor_adadelta( grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, square_avgs, acc_deltas]) for ((device_params, device_grads, device_square_avgs, device_acc_deltas), _) in grouped_tensors.values(): - if maximize: - device_grads = torch._foreach_neg(device_grads) - if has_complex: _view_as_real(device_params, device_grads, device_square_avgs, device_acc_deltas) + if maximize: + device_grads = torch._foreach_neg(device_grads) + if weight_decay != 0: # Re-use the intermediate memory (device_grads) already allocated for maximize if maximize: diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index 3d5ef79cc27..ce17e2d7ad3 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -344,13 +344,13 @@ def _multi_tensor_adagrad( ) continue - if maximize: - device_grads = torch._foreach_neg(device_grads) - # Handle complex parameters if has_complex: _view_as_real(device_params, device_grads, device_state_sums) + if maximize: + device_grads = torch._foreach_neg(device_grads) + # Update steps # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just diff --git a/torch/optim/adam.py b/torch/optim/adam.py index ee0223fee9d..15abef4d9af 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -489,9 +489,6 @@ def _multi_tensor_adam(params: List[Tensor], device_state_steps, ), _) in grouped_tensors.values(): - if maximize: - device_grads = torch._foreach_neg(device_grads) - # Handle complex parameters if has_complex: if amsgrad: @@ -499,6 +496,9 @@ def _multi_tensor_adam(params: List[Tensor], else: _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs) + if maximize: + device_grads = torch._foreach_neg(device_grads) + # Update steps # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over # and over. 
1 will then be wrapped into a Tensor over and over again, which is slower than if we just diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py index d914110ee8f..8f0ee6c44d0 100644 --- a/torch/optim/adamax.py +++ b/torch/optim/adamax.py @@ -330,12 +330,12 @@ def _multi_tensor_adamax( grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_infs, state_steps]) for ((grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs, grouped_state_steps), _) in grouped_tensors.values(): - if maximize: - grouped_grads = torch._foreach_neg(grouped_grads) - if has_complex: _view_as_real(grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs) + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) + # Update steps # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 071e879d57e..bbf5fe7fd77 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -521,15 +521,15 @@ def _multi_tensor_adamw( device_max_exp_avg_sqs, device_state_steps, ), _) in grouped_tensors.values(): - if maximize: - device_grads = torch._foreach_neg(device_grads) - if has_complex: if amsgrad: _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs, device_max_exp_avg_sqs) else: _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs) + if maximize: + device_grads = torch._foreach_neg(device_grads) + # Update steps # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index 5465d84350a..c65411aaa21 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -313,13 +313,12 @@ def _multi_tensor_asgd( grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, axs, mus, etas, state_steps]) for ((device, _), ((grouped_params, grouped_grads, grouped_axs, grouped_mus, grouped_etas, grouped_state_steps), _)) in grouped_tensors.items(): - if maximize: - grouped_grads = torch._foreach_neg(grouped_grads) - - grouped_grads = list(grouped_grads) if has_complex: _view_as_real(grouped_params, grouped_grads, grouped_axs) + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) + # Update steps # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over # and over. 
1 will then be wrapped into a Tensor over and over again, which is slower than if we just diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index bf7e0f737b9..62d28ae51a0 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -336,6 +336,14 @@ def _multi_tensor_rmsprop( grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, square_avgs, grad_avgs, momentum_buffer_list]) for (((grouped_params, grouped_grads, grouped_square_avgs, grouped_grad_avgs, grouped_momentum_buffer_list)), _) in grouped_tensors.values(): + if has_complex: + state_and_grads = [grouped_grads, grouped_square_avgs] + if momentum > 0: + state_and_grads.append(grouped_momentum_buffer_list) + if centered: + state_and_grads.append(grouped_grad_avgs) + _view_as_real(grouped_params, *state_and_grads) + if maximize: grouped_grads = torch._foreach_neg(grouped_grads) @@ -346,16 +354,6 @@ def _multi_tensor_rmsprop( else: grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay) - grouped_grads = list(grouped_grads) - - if has_complex: - state_and_grads = [grouped_grads, grouped_square_avgs] - if momentum > 0: - state_and_grads.append(grouped_momentum_buffer_list) - if centered: - state_and_grads.append(grouped_grad_avgs) - _view_as_real(grouped_params, *state_and_grads) - torch._foreach_mul_(grouped_square_avgs, alpha) torch._foreach_addcmul_(grouped_square_avgs, grouped_grads, grouped_grads, value=1 - alpha) diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 3534172150c..7b53b8cf38b 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -263,9 +263,7 @@ def get_error_inputs_for_all_optims(device, dtype): def optim_inputs_func_adadelta(device=None): return [ OptimizerInput(params=None, kwargs={}, desc="default"), - OptimizerInput( - params=None, kwargs={"lr": 0.01}, desc="non-default lr" - ), # TODO: Move out to testing in param_group? + OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"), OptimizerInput( params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" ), @@ -275,8 +273,8 @@ def optim_inputs_func_adadelta(device=None): desc="maximize", ), OptimizerInput( - params=None, kwargs={"rho": 0.95, "weight_decay": 0.1}, desc="rho" - ), # TODO: Move out to testing in param_group? 
+ params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho" + ), ] @@ -494,6 +492,7 @@ def optim_inputs_func_asgd(device=None): OptimizerInput(params=None, kwargs={}, desc="default"), OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"), OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"), + OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"), OptimizerInput( params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" ), @@ -545,6 +544,21 @@ def optim_error_inputs_func_lbfgs(device, dtype): def optim_inputs_func_nadam(device=None): cuda_supported_configs = [ OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.9, "momentum_decay": 6e-3, "capturable": True}, + desc="weight_decay, capturable", + ), + OptimizerInput( + params=None, + kwargs={ + "weight_decay": 0.9, + "momentum_decay": 6e-3, + "decoupled_weight_decay": True, + "capturable": True, + }, + desc="decoupled_weight_decay, capturable", + ), ] return [ OptimizerInput(params=None, kwargs={}, desc="default"), @@ -1106,6 +1120,11 @@ optim_db: List[OptimizerInfo] = [ "TestOptimRenewed", "test_forloop_goes_right_direction_multigpu", ), + DecorateInfo( + skipIfTorchDynamo("Mismatched _foreach_addcdiv_ types, see #118159"), + "TestOptimRenewed", + "test_complex", + ), DecorateInfo( skipIfTorchDynamo( "See https://github.com/pytorch/pytorch/issues/115607" @@ -1319,6 +1338,11 @@ optim_db: List[OptimizerInfo] = [ "TestOptimRenewed", "test_forloop_goes_right_direction_multigpu", ), + DecorateInfo( + unittest.skip("Missing complex support, see #118148"), + "TestOptimRenewed", + "test_complex", + ), ), ), OptimizerInfo( @@ -1747,6 +1771,11 @@ optim_db: List[OptimizerInfo] = [ "TestOptimRenewed", "test_deepcopy_copies_all_public_attrs", ), + DecorateInfo( + unittest.skip("Missing complex support, see #118153"), + "TestOptimRenewed", + "test_complex", + ), ), ), ]
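
A minimal sketch of the reordering described in point 2 of the summary: the private torch.optim.optimizer._view_as_real helper rewrites the containers it is handed in place, replacing each complex tensor with a real view, so it must see mutable lists. Converting to real views before the maximize negation preserves that list-ness, which is why the `grouped_grads = list(grouped_grads)` copies could be dropped in asgd.py and rmsprop.py. `view_as_real_` below is a simplified stand-in for the private helper, not its exact implementation.

import torch

# Simplified stand-in for torch.optim.optimizer._view_as_real (assumption: the
# real helper behaves equivalently for this illustration): replace complex
# tensors with real views *in place*, which requires mutable, indexable lists.
def view_as_real_(params, *state_and_grads):
    for i, p in enumerate(params):
        if torch.is_complex(p):
            params[i] = torch.view_as_real(p)
            for s in state_and_grads:
                s[i] = torch.view_as_real(s[i])

params = [torch.randn(2, 3, dtype=torch.complex64)]
grads = [torch.randn_like(p) for p in params]

# Order used after this PR: view-as-real first, on the original lists, then the
# maximize negation. torch._foreach_neg returns a fresh sequence of (now real)
# tensors, so no extra list() copy is needed before the complex handling.
view_as_real_(params, grads)
grads = torch._foreach_neg(grads)  # what the maximize branch does

assert all(not g.is_complex() for g in grads)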