# Owner(s): ["module: optimizer"]

from copy import deepcopy

import torch
from optim.test_optim import TestOptim, TestDifferentiableOptimizer  # noqa: F401
from optim.test_lrscheduler import TestLRScheduler  # noqa: F401
from optim.test_swa_utils import TestSWAUtils  # noqa: F401
from torch.testing._internal.common_optimizers import optim_db, optims, OptimizerErrorEnum
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCPU, skipMPS
from torch.testing._internal.common_utils import markDynamoStrictTest, run_tests, TestCase


@markDynamoStrictTest
class TestOptimRenewed(TestCase):
    @onlyCPU
    @optims([optim for optim in optim_db if optim.optim_error_inputs_func is not None])
    def test_errors(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        error_inputs = optim_info.optim_error_inputs_func(device=device, dtype=dtype)

        for error_input in error_inputs:
            optim_input = error_input.optimizer_error_input
            params, kwargs = optim_input.params, optim_input.kwargs
            if error_input.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR:
                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
                    optim_cls(params, **kwargs)
            elif error_input.error_on == OptimizerErrorEnum.STEP_ERROR:
                optim = optim_cls(params, **kwargs)
                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
                    optim.step()
            else:
                raise NotImplementedError(f"Unknown error type {error_input.error_on}")

    def _test_derived_optimizers(self, device, dtype, optim_info, flag):
        assert flag in ("foreach", "fused")

        # Why 7? Iteration 7 is where we start to see differences for RAdam
        # params interacting with the small eps value, because that's right
        # after rho_t becomes greater than 5 in step 6.
        kIterations = 7

        optim_inputs = optim_info.optim_inputs_func()
        optim_cls = optim_info.optim_cls
        for optim_input in optim_inputs:
            updated_params, state = [], []
            kwargs = deepcopy(optim_input.kwargs)
            if kwargs.get("capturable", False) and (
                str(device) == "cpu" or optim_cls.__name__ == "ASGD"
            ):
                # capturable is not supported on CPU nor in single-tensor ASGD
                continue

            # Run once with the flag disabled (forloop reference) and once enabled,
            # then compare the resulting params and optimizer state below.
            for flag_value in (False, True):
                kwargs[flag] = flag_value
                input = torch.tensor(
                    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=device
                ).reshape(3, 2)

                torch.manual_seed(1)
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Sigmoid(),
                    torch.nn.Linear(3, 1),
                    torch.nn.Sigmoid(),
                )
                model.to(dtype=dtype, device=device)

                # foreach/fused optimizers should be tested with a
                # zero-size tensor as their last param.
                # ref: https://github.com/pytorch/pytorch/issues/100701
                empty_param = torch.empty((), device=device, dtype=dtype, requires_grad=True)
                empty_param.grad = torch.rand_like(empty_param)
                params = list(model.parameters()) + [empty_param]

                optimizer = optim_cls(params, **kwargs)

                for i in range(kIterations):
                    optimizer.zero_grad()

                    # Test that step behaves as expected (a no-op) when grads are set to None
                    if i != 3:
                        output = model(input)
                        loss = output.sum()
                        loss.backward()

                    optimizer.step()

                state.append(optimizer.state)
                updated_params.append(model.parameters())

            og_state, new_state = state
            for og_p, new_p in zip(updated_params[0], updated_params[1]):
                self.assertEqual(og_p, new_p)

                # check that optimizer states are the same
                og_p_state = og_state[og_p]
                new_p_state = new_state[new_p]
                for k in og_p_state:
                    self.assertEqual(og_p_state[k], new_p_state[k])

    @skipMPS  # MPS doesn't support torch.float64, see https://github.com/pytorch/pytorch/issues/115350
    @optims([optim for optim in optim_db if "foreach" in optim.supported_impls], dtypes=[torch.float64])
    def test_foreach_matches_forloop(self, device, dtype, optim_info):
        self._test_derived_optimizers(device, dtype, optim_info, "foreach")

    @onlyCPU
    @optims(optim_db)
    def test_optim_infos_do_not_specify_global_cliquey_kwargs(self, device, dtype, optim_info):
        global_cliquey_flags = ["foreach", "fused", "differentiable"]
        for optim_input in optim_info.optim_inputs_func():
            self.assertFalse(any(f for f in global_cliquey_flags if f in optim_input.kwargs))


instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True)


if __name__ == "__main__":
    run_tests()