From c329eddcb9e9ddb59726097eda0d1f2bbbd29114 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Fri, 12 Jan 2024 11:17:30 -0800 Subject: [PATCH] Migrate the rest of state_dict testing to OptimizerInfo (#117186) Pull Request resolved: https://github.com/pytorch/pytorch/pull/117186 Approved by: https://github.com/albanD ghstack dependencies: #116509 --- test/optim/test_optim.py | 48 ----------- test/test_optim.py | 32 +++++++ torch/testing/_internal/common_optimizers.py | 89 ++++++++++++++++++++ 3 files changed, 121 insertions(+), 48 deletions(-) diff --git a/test/optim/test_optim.py b/test/optim/test_optim.py index bd1e75a0f06..d0bcad889b2 100644 --- a/test/optim/test_optim.py +++ b/test/optim/test_optim.py @@ -27,7 +27,6 @@ from torch.testing._internal.common_utils import ( skipIfTorchDynamo ) -from torch._dynamo import disable as disable_dynamo from torch.testing._internal.common_cuda import TEST_CUDA from typing import Dict, Any, Tuple @@ -216,41 +215,6 @@ class TestOptim(TestCase): else: self.assertLess(fn().item(), initial_value) - # Note: disable dynamo on this function - # This allows us to continue running actual logic of the optimizer - # tests in dynamo without tracing this test code which has a lot of unsupported - # behavior - @disable_dynamo(recursive=False) - def _test_state_dict(self, weight, bias, input, constructor, atol=None, rtol=None): - weight = Parameter(weight) - bias = Parameter(bias) - with torch.no_grad(): - input = input.clone().detach().requires_grad_() - - # Note: Disable dynamo on this function - # This avoids a bug where input_cuda is not detected in the environment - # because it currently is not defined in the local environmet. Unable to repro - # anywhere else however and this is test code that we don't need to spend - # time getting dynamo to trace unless the issue repros in real models. 
- @disable_dynamo(recursive=False) - def fn_base(optimizer, weight, bias): - optimizer.zero_grad() - loss = (weight.mv(input) + bias).pow(2).sum() - loss.backward() - return loss - - optimizer = constructor(weight, bias) - fn = functools.partial(fn_base, optimizer, weight, bias) - - # Prime the optimizer - for _i in range(20): - optimizer.step(fn) - - # validate deepcopy() copies all public attributes - def getPublicAttr(obj): - return {k for k in obj.__dict__ if not k.startswith("_")} - - self.assertEqual(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer))) def _test_basic_cases( self, @@ -276,18 +240,6 @@ class TestOptim(TestCase): return lambda weight, bias: constructor(weight, bias, foreach) return constructor - for maximize, foreach in itertools.product( - {False, constructor_accepts_maximize}, - {False, constructor_accepts_foreach}, - ): - self._test_state_dict( - torch.randn(10, 5), - torch.randn(10), - torch.randn(5), - make_two_arg_constructor(constructor, maximize, foreach), - atol=atol, - rtol=rtol, - ) self._test_basic_cases_template( torch.randn(10, 5), torch.randn(10), diff --git a/test/test_optim.py b/test/test_optim.py index 0438b0db318..78538735492 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -584,6 +584,38 @@ class TestOptimRenewed(TestCase): self.assertEqual(optimizer.state_dict(), optimizer_cuda.state_dict()) + @optims(optim_db, dtypes=[torch.float32]) + def test_deepcopy_copies_all_public_attrs(self, device, dtype, optim_info): + optim_cls = optim_info.optim_cls + + # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490 + all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",)) + + params = [Parameter(torch.randn(2, 3, device=device, dtype=dtype)) for _ in range(2)] + for p in params: + p.grad = torch.rand_like(p) + if optim_cls.__name__ == "SparseAdam": + # SparseAdam requires sparse gradients. For this test, we convert the Tensor layout, + # which we know does NOT represent the expected use case! 
+ p.grad = p.grad.to_sparse() + + # Needed for LBFGS + def closure(): + return 1 if optim_cls.__name__ == "LBFGS" else None + + def getPublicAttrs(obj): + return {k for k in obj.__dict__ if not k.startswith("_")} + + for optim_input in all_optim_inputs: + optimizer = optim_cls(params, **optim_input.kwargs) + + # Make some state + for _ in range(3): + optimizer.step(closure) + + self.assertEqual(getPublicAttrs(optimizer), getPublicAttrs(deepcopy(optimizer))) + + instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True) diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index e90950ad99c..f687ec610ea 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -912,6 +912,13 @@ optim_db: List[OptimizerInfo] = [ "TestOptimRenewed", "test_state_dict_with_cuda_params", ), + DecorateInfo( + skipIfTorchDynamo( + "fails, https://github.com/pytorch/pytorch/issues/117165" + ), + "TestOptimRenewed", + "test_deepcopy_copies_all_public_attrs", + ), ), ), OptimizerInfo( @@ -949,6 +956,13 @@ optim_db: List[OptimizerInfo] = [ "TestOptimRenewed", "test_state_dict_deterministic", ), + DecorateInfo( + skipIfTorchDynamo( + "fails, https://github.com/pytorch/pytorch/issues/117165" + ), + "TestOptimRenewed", + "test_deepcopy_copies_all_public_attrs", + ), ), ), OptimizerInfo( @@ -985,6 +999,13 @@ optim_db: List[OptimizerInfo] = [ "TestOptimRenewed", "test_state_dict_deterministic", ), + DecorateInfo( + skipIfTorchDynamo( + "fails, https://github.com/pytorch/pytorch/issues/117165" + ), + "TestOptimRenewed", + "test_deepcopy_copies_all_public_attrs", + ), ), ), OptimizerInfo( @@ -1026,6 +1047,13 @@ optim_db: List[OptimizerInfo] = [ "TestOptimRenewed", "test_state_dict_deterministic", ), + DecorateInfo( + skipIfTorchDynamo( + "fails, https://github.com/pytorch/pytorch/issues/117165" + ), + "TestOptimRenewed", + "test_deepcopy_copies_all_public_attrs", + ), ), ), OptimizerInfo( @@ -1062,6 +1090,13 @@ optim_db: List[OptimizerInfo] = [ "TestOptimRenewed", "test_state_dict_deterministic", ), + DecorateInfo( + skipIfTorchDynamo( + "fails, https://github.com/pytorch/pytorch/issues/117165" + ), + "TestOptimRenewed", + "test_deepcopy_copies_all_public_attrs", + ), ), ), OptimizerInfo( @@ -1107,6 +1142,13 @@ optim_db: List[OptimizerInfo] = [ "TestOptimRenewed", "test_state_dict_deterministic", ), + DecorateInfo( + skipIfTorchDynamo( + "fails, https://github.com/pytorch/pytorch/issues/117165" + ), + "TestOptimRenewed", + "test_deepcopy_copies_all_public_attrs", + ), ), ), OptimizerInfo( @@ -1122,6 +1164,13 @@ optim_db: List[OptimizerInfo] = [ DecorateInfo( skipIfMps, "TestOptimRenewed", "test_can_load_older_state_dict" ), + DecorateInfo( + skipIfTorchDynamo( + "fails, https://github.com/pytorch/pytorch/issues/117165" + ), + "TestOptimRenewed", + "test_deepcopy_copies_all_public_attrs", + ), ), ), OptimizerInfo( @@ -1173,6 +1222,13 @@ optim_db: List[OptimizerInfo] = [ "TestOptimRenewed", "test_state_dict_with_cuda_params", ), + DecorateInfo( + skipIfTorchDynamo( + "fails, https://github.com/pytorch/pytorch/issues/117165" + ), + "TestOptimRenewed", + "test_deepcopy_copies_all_public_attrs", + ), ), ), OptimizerInfo( @@ -1195,6 +1251,13 @@ optim_db: List[OptimizerInfo] = [ "TestOptimRenewed", "test_set_default_dtype_works_with_foreach", ), + DecorateInfo( + skipIfTorchDynamo( + "fails, https://github.com/pytorch/pytorch/issues/117165" + ), + "TestOptimRenewed", + "test_deepcopy_copies_all_public_attrs", + 
), ), ), OptimizerInfo( @@ -1248,6 +1311,13 @@ optim_db: List[OptimizerInfo] = [ "TestOptimRenewed", "test_state_dict_with_cuda_params", ), + DecorateInfo( + skipIfTorchDynamo( + "fails, https://github.com/pytorch/pytorch/issues/117165" + ), + "TestOptimRenewed", + "test_deepcopy_copies_all_public_attrs", + ), ), ), OptimizerInfo( @@ -1291,6 +1361,13 @@ optim_db: List[OptimizerInfo] = [ "TestOptimRenewed", "test_state_dict_with_cuda_params", ), + DecorateInfo( + skipIfTorchDynamo( + "fails, https://github.com/pytorch/pytorch/issues/117165" + ), + "TestOptimRenewed", + "test_deepcopy_copies_all_public_attrs", + ), ), ), OptimizerInfo( @@ -1354,6 +1431,13 @@ optim_db: List[OptimizerInfo] = [ "TestOptimRenewed", "test_state_dict_with_cuda_params", ), + DecorateInfo( + skipIfTorchDynamo( + "fails, https://github.com/pytorch/pytorch/issues/117165" + ), + "TestOptimRenewed", + "test_deepcopy_copies_all_public_attrs", + ), ), ), OptimizerInfo( @@ -1391,6 +1475,11 @@ optim_db: List[OptimizerInfo] = [ "TestOptimRenewed", "test_state_dict_with_cuda_params", ), + DecorateInfo( + skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"), + "TestOptimRenewed", + "test_deepcopy_copies_all_public_attrs", + ), ), ), ]
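
Note for context (not part of the patch): the new test_deepcopy_copies_all_public_attrs added in test/test_optim.py checks that deepcopy(optimizer) preserves every public attribute of the optimizer after some state has been built up. A minimal standalone sketch of that invariant, using torch.optim.SGD purely as an example optimizer rather than the per-optimizer inputs from optim_db:

    # Illustrative only; the real test iterates over optim_db via @optims
    # and the kwargs produced by _get_optim_inputs_including_global_cliquey_kwargs.
    from copy import deepcopy

    import torch
    from torch.nn import Parameter

    params = [Parameter(torch.randn(2, 3)) for _ in range(2)]
    for p in params:
        p.grad = torch.rand_like(p)

    optimizer = torch.optim.SGD(params, lr=0.1)
    for _ in range(3):
        optimizer.step()  # build up some optimizer state

    def public_attrs(obj):
        # Public attributes are those not prefixed with an underscore.
        return {k for k in obj.__dict__ if not k.startswith("_")}

    assert public_attrs(optimizer) == public_attrs(deepcopy(optimizer))

The migrated test runs this same check for every optimizer in optim_db, and the DecorateInfo entries added in common_optimizers.py skip it under TorchDynamo per https://github.com/pytorch/pytorch/issues/117165.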