Migrate test_complex_optimizer to OptimizerInfo (#118160)

This PR does what it says and more:
1. We increase coverage by a lot. Previously, complex was not tested for many configs, including foreach + maximize at the same time, the fused impls, or other configs people had forgotten about.
2. I rearranged the maximize conditional and the _view_as_real call to preserve list-ness, which _view_as_real needs in order to function properly; I added a comment about this in the Files Changed. The new order also reads better.
3. LBFGS and SparseAdam are skipped -- they don't support complex, and now we know.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118160
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
parent 6978c3ddf3
commit 17ecd1e9cd
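For context on point 2 of the message: _view_as_real rewrites the entries of the lists it is handed in place, while the maximize path builds fresh tensors with torch._foreach_neg, so the relative order of the two decides whether the real views land in the lists the rest of the loop actually reads. The sketch below illustrates that behavior; the helper here is a simplified stand-in for the torch.optim internals, written only for illustration.

import torch
from typing import List

# Simplified stand-in for the _view_as_real helper discussed in the commit
# message (not the actual torch.optim code): it rewrites list entries in place,
# so it must receive the exact lists that later _foreach_* calls will read.
def view_as_real_(params: List[torch.Tensor], *state_and_grads: List[torch.Tensor]) -> None:
    for i, p in enumerate(params):
        if torch.is_complex(p):
            params[i] = torch.view_as_real(p)
            for s in state_and_grads:
                s[i] = torch.view_as_real(s[i])

params = [torch.randn(3, dtype=torch.complex64)]
grads = [torch.randn(3, dtype=torch.complex64)]

negated = torch._foreach_neg(grads)   # returns new tensors; `grads` is untouched
view_as_real_(params, grads)          # rewrites the entries of `grads` in place
print(torch.is_complex(grads[0]))     # False: now a real view
print(torch.is_complex(negated[0]))   # True: the negated copies were never viewed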
@@ -270,19 +270,6 @@ class TestOptim(TestCase):
            constructor_accepts_foreach,
        )

    def _test_complex_optimizer(self, optimizer_constructor):
        complex_param = torch.randn(5, 5, dtype=torch.complex64, requires_grad=True)
        real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_()
        complex_opt = optimizer_constructor(complex_param)
        real_opt = optimizer_constructor(real_param)

        for _ in range(3):
            complex_param.grad = torch.randn_like(complex_param)
            real_param.grad = torch.view_as_real(complex_param.grad)
            complex_opt.step()
            real_opt.step()

        self.assertEqual(torch.view_as_real(complex_param), real_param)

    def _test_complex_2d(self, optimizer_constructor):
        a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True)

@@ -398,40 +385,6 @@ class TestOptim(TestCase):
            multi_tensor=foreach,
        )

    def test_sgd_complex(self):
        for foreach in (False, True):
            self._test_complex_optimizer(
                lambda param: SGD([param], lr=0.001, foreach=foreach)
            )
            self._test_complex_optimizer(
                lambda param: SGD([param], lr=0.001, momentum=1, foreach=foreach)
            )
            self._test_complex_optimizer(
                lambda param: SGD(
                    [param], lr=0.001, momentum=1, weight_decay=1, foreach=foreach
                )
            )
            self._test_complex_optimizer(
                lambda param: SGD(
                    [param],
                    lr=0.001,
                    nesterov=True,
                    momentum=1,
                    weight_decay=1,
                    foreach=foreach,
                )
            )
            self._test_complex_optimizer(
                lambda param: SGD(
                    [param],
                    lr=0.001,
                    momentum=1,
                    dampening=0.5,
                    weight_decay=1,
                    foreach=foreach,
                )
            )

    def test_adam(self):
        self._test_basic_cases(

@@ -603,15 +556,6 @@ class TestOptim(TestCase):
        )

    def test_adadelta_complex(self):
        # Handles https://github.com/pytorch/pytorch/issues/110606
        self.rel_tol = 2e-2
        for foreach in (False, True):
            self._test_complex_optimizer(lambda weight: Adadelta([weight], foreach=foreach))
            self._test_complex_optimizer(lambda weight: Adadelta([weight], rho=0.95, foreach=foreach))
            self._test_complex_optimizer(
                lambda weight: Adadelta([weight], rho=0.95, weight_decay=1, foreach=foreach)
            )

    def test_nadam(self):
        self._test_basic_cases(

@@ -640,28 +584,6 @@ class TestOptim(TestCase):
        )

    def test_nadam_complex(self):
        for foreach in (False, True):
            self._test_complex_optimizer(
                lambda param: NAdam([param], lr=1e-1, foreach=foreach)
            )
            self._test_complex_optimizer(
                lambda param: NAdam(
                    [param],
                    lr=1e-1,
                    weight_decay=0.01,
                    foreach=foreach,
                )
            )
            self._test_complex_optimizer(
                lambda param: NAdam(
                    [param],
                    lr=1e-1,
                    momentum_decay=0.01,
                    foreach=foreach,
                )
            )

    def test_adagrad(self):
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: Adagrad(

@@ -705,19 +627,6 @@ class TestOptim(TestCase):
            multi_tensor=foreach,
        )

    def test_adagrad_complex(self):
        for foreach in (False, True):
            self._test_complex_optimizer(
                lambda param: Adagrad([param], lr=1e-1, foreach=foreach)
            )
            self._test_complex_optimizer(
                lambda param: Adagrad(
                    [param],
                    lr=1e-1,
                    initial_accumulator_value=0.1,
                    foreach=foreach,
                )
            )

    def test_adamax(self):
        self._test_complex_2d(Adamax)

@@ -748,29 +657,6 @@ class TestOptim(TestCase):
        )

    def test_radam_complex(self):
        for foreach in (False, True):
            self._test_complex_optimizer(
                lambda param: RAdam([param], lr=1e-1, foreach=foreach)
            )
            self._test_complex_optimizer(
                lambda param: RAdam(
                    [param],
                    lr=1e-1,
                    weight_decay=0.01,
                    foreach=foreach,
                )
            )
            self._test_complex_optimizer(
                lambda param: RAdam(
                    [param],
                    lr=1e-1,
                    weight_decay=0.01,
                    decoupled_weight_decay=True,
                    foreach=foreach,
                )
            )

    def test_rmsprop(self):
        for foreach in (False, True):
            self._test_complex_2d(lambda param: RMSprop(param, foreach=foreach))

@@ -783,40 +669,6 @@ class TestOptim(TestCase):
            self._test_complex_2d(
                lambda param: RMSprop(param, maximize=True, foreach=foreach)
            )
            self._test_complex_optimizer(
                lambda param: RMSprop([param], foreach=foreach)
            )
            self._test_complex_optimizer(
                lambda param: RMSprop([param], centered=True, foreach=foreach)
            )
            self._test_complex_optimizer(
                lambda param: RMSprop([param], momentum=0.1, foreach=foreach)
            )
            self._test_complex_optimizer(
                lambda param: RMSprop([param], maximize=True, foreach=foreach)
            )

    def test_asgd(self):
        for foreach in (False, True):
            # Ref: https://github.com/pytorch/pytorch/issues/84560
            # self._test_complex_2d(optimizer)
            self._test_complex_optimizer(
                lambda params: ASGD([params], foreach=foreach)
            )
            self._test_complex_optimizer(
                lambda params: ASGD([params], maximize=True, foreach=foreach)
            )
            self._test_complex_optimizer(
                lambda params: ASGD(
                    [params], maximize=True, weight_decay=0.1, foreach=foreach
                )
            )
            self._test_complex_optimizer(
                lambda params: ASGD(
                    [params], maximize=False, weight_decay=0.1, foreach=foreach
                )
            )

    @skipIfRocm

@@ -824,14 +676,6 @@ class TestOptim(TestCase):
    def test_rprop(self):
        for foreach in (False, True):
            self._test_complex_2d(lambda param: Rprop(param, foreach=foreach))
            self._test_complex_optimizer(
                lambda param: Rprop([param], lr=0.001, foreach=foreach)
            )
            self._test_complex_optimizer(
                lambda param: Rprop(
                    [param], lr=0.001, maximize=True, foreach=foreach
                )
            )

    def test_lbfgs_returns_consistent_type(self):

@@ -128,6 +128,30 @@ class TestOptimRenewed(TestCase):
        self.assertLess(closure().item(), initial_value)

    @skipMPS
    @optims(optim_db, dtypes=[torch.complex64])
    def test_complex(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
        for optim_input in all_optim_inputs:
            complex_params = [torch.randn(2, 3, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
            real_params = [torch.view_as_real(p).detach().clone().requires_grad_(True) for p in complex_params]

            complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
            real_optimizer = optim_cls(real_params, **optim_input.kwargs)

            for _ in range(3):
                for (c, r) in zip(complex_params, real_params):
                    c.grad = torch.randn_like(c)
                    r.grad = torch.view_as_real(c.grad)
                complex_optimizer.step()
                real_optimizer.step()

            for (c, r) in zip(complex_params, real_params):
                self.assertEqual(torch.view_as_real(c), r)


    def _test_derived_optimizers(self, device, dtype, optim_info, flag, reduced_precision=False, assert_step_dtype=None):
        """
        Given a flag 'fused' or 'foreach', test for parity of optimizer state

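The parity property this new test checks can be reproduced outside the harness. Below is a minimal sketch, using Adam purely as an example of an optimizer with complex support; the test above runs the same check for every optimizer and config in optim_db.

import torch

torch.manual_seed(0)
complex_param = torch.randn(2, 3, dtype=torch.complex64, requires_grad=True)
# The real "shadow" parameter: the same values viewed as (..., 2) real/imag pairs.
real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_(True)

complex_opt = torch.optim.Adam([complex_param], lr=1e-2)
real_opt = torch.optim.Adam([real_param], lr=1e-2)

for _ in range(3):
    g = torch.randn_like(complex_param)
    complex_param.grad = g
    real_param.grad = torch.view_as_real(g)
    complex_opt.step()
    real_opt.step()

# Stepping on the complex tensor should match stepping on its real view.
torch.testing.assert_close(torch.view_as_real(complex_param), real_param)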
@@ -286,12 +286,12 @@ def _multi_tensor_adadelta(

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, square_avgs, acc_deltas])
    for ((device_params, device_grads, device_square_avgs, device_acc_deltas), _) in grouped_tensors.values():
        if maximize:
            device_grads = torch._foreach_neg(device_grads)

        if has_complex:
            _view_as_real(device_params, device_grads, device_square_avgs, device_acc_deltas)

        if maximize:
            device_grads = torch._foreach_neg(device_grads)

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:

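A quick aside on what the maximize branch being shuffled in these hunks actually does: with maximize=True the multi-tensor path negates the gradients via torch._foreach_neg, so the usual descent update climbs the objective instead. A small illustration with Adadelta (any optimizer exposing maximize behaves the same way):

import torch

p = torch.tensor([1.0], requires_grad=True)
opt = torch.optim.Adadelta([p], maximize=True)

(p ** 2).sum().backward()   # gradient of the quantity we want to *maximize*
opt.step()
print(p.item() > 1.0)       # True: p moved uphill on p**2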
@@ -344,13 +344,13 @@ def _multi_tensor_adagrad(
            )
            continue

        if maximize:
            device_grads = torch._foreach_neg(device_grads)

        # Handle complex parameters
        if has_complex:
            _view_as_real(device_params, device_grads, device_state_sums)

        if maximize:
            device_grads = torch._foreach_neg(device_grads)

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just

@@ -489,9 +489,6 @@ def _multi_tensor_adam(params: List[Tensor],
        device_state_steps,
    ), _) in grouped_tensors.values():

        if maximize:
            device_grads = torch._foreach_neg(device_grads)

        # Handle complex parameters
        if has_complex:
            if amsgrad:

@@ -499,6 +496,9 @@ def _multi_tensor_adam(params: List[Tensor],
            else:
                _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)

        if maximize:
            device_grads = torch._foreach_neg(device_grads)

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just

@@ -330,12 +330,12 @@ def _multi_tensor_adamax(

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_infs, state_steps])
    for ((grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs, grouped_state_steps), _) in grouped_tensors.values():
        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)

        if has_complex:
            _view_as_real(grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs)

        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just

@@ -521,15 +521,15 @@ def _multi_tensor_adamw(
        device_max_exp_avg_sqs,
        device_state_steps,
    ), _) in grouped_tensors.values():
        if maximize:
            device_grads = torch._foreach_neg(device_grads)

        if has_complex:
            if amsgrad:
                _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs, device_max_exp_avg_sqs)
            else:
                _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)

        if maximize:
            device_grads = torch._foreach_neg(device_grads)

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just

@@ -313,13 +313,12 @@ def _multi_tensor_asgd(
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, axs, mus, etas, state_steps])
    for ((device, _), ((grouped_params, grouped_grads, grouped_axs, grouped_mus,
                        grouped_etas, grouped_state_steps), _)) in grouped_tensors.items():
        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)

        grouped_grads = list(grouped_grads)
        if has_complex:
            _view_as_real(grouped_params, grouped_grads, grouped_axs)

        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just

@@ -336,6 +336,14 @@ def _multi_tensor_rmsprop(
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, square_avgs, grad_avgs, momentum_buffer_list])
    for (((grouped_params, grouped_grads, grouped_square_avgs, grouped_grad_avgs,
           grouped_momentum_buffer_list)), _) in grouped_tensors.values():
        if has_complex:
            state_and_grads = [grouped_grads, grouped_square_avgs]
            if momentum > 0:
                state_and_grads.append(grouped_momentum_buffer_list)
            if centered:
                state_and_grads.append(grouped_grad_avgs)
            _view_as_real(grouped_params, *state_and_grads)

        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)

@@ -346,16 +354,6 @@ def _multi_tensor_rmsprop(
        else:
            grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)

        grouped_grads = list(grouped_grads)

        if has_complex:
            state_and_grads = [grouped_grads, grouped_square_avgs]
            if momentum > 0:
                state_and_grads.append(grouped_momentum_buffer_list)
            if centered:
                state_and_grads.append(grouped_grad_avgs)
            _view_as_real(grouped_params, *state_and_grads)

        torch._foreach_mul_(grouped_square_avgs, alpha)
        torch._foreach_addcmul_(grouped_square_avgs, grouped_grads, grouped_grads, value=1 - alpha)

@@ -263,9 +263,7 @@ def get_error_inputs_for_all_optims(device, dtype):
def optim_inputs_func_adadelta(device=None):
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(
            params=None, kwargs={"lr": 0.01}, desc="non-default lr"
        ),  # TODO: Move out to testing in param_group?
        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
        ),

@@ -275,8 +273,8 @@ def optim_inputs_func_adadelta(device=None):
            desc="maximize",
        ),
        OptimizerInput(
            params=None, kwargs={"rho": 0.95, "weight_decay": 0.1}, desc="rho"
        ),  # TODO: Move out to testing in param_group?
            params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
        ),
    ]

@@ -494,6 +492,7 @@ def optim_inputs_func_asgd(device=None):
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
        OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"),
        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
        ),

@@ -545,6 +544,21 @@ def optim_error_inputs_func_lbfgs(device, dtype):
def optim_inputs_func_nadam(device=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.9, "momentum_decay": 6e-3, "capturable": True},
            desc="weight_decay, capturable",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "weight_decay": 0.9,
                "momentum_decay": 6e-3,
                "decoupled_weight_decay": True,
                "capturable": True,
            },
            desc="decoupled_weight_decay, capturable",
        ),
    ]
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),

@@ -1106,6 +1120,11 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Mismatched _foreach_addcdiv_ types, see #118159"),
                "TestOptimRenewed",
                "test_complex",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "See https://github.com/pytorch/pytorch/issues/115607"

@@ -1319,6 +1338,11 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                unittest.skip("Missing complex support, see #118148"),
                "TestOptimRenewed",
                "test_complex",
            ),
        ),
    ),
    OptimizerInfo(

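For readers new to the OptimizerInfo machinery: a DecorateInfo entry like the one above attaches a decorator to a single generated (test class, test name) pair. The effect is roughly what you would get by hand-writing the decorator on that test, as in this sketch (the class here is illustrative, not the real TestOptimRenewed):

import unittest

class TestOptimRenewedSketch(unittest.TestCase):
    # Roughly what the DecorateInfo above amounts to for this one test.
    @unittest.skip("Missing complex support, see #118148")
    def test_complex(self):
        ...

Until the linked issue is fixed, the skip keeps the new test_complex from failing for that optimizer while the other entries in optim_db still run it.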
@@ -1747,6 +1771,11 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
            DecorateInfo(
                unittest.skip("Missing complex support, see #118153"),
                "TestOptimRenewed",
                "test_complex",
            ),
        ),
    ),
]