Migrate the rest of state_dict testing to OptimizerInfo (#117186)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117186
Approved by: https://github.com/albanD
ghstack dependencies: #116509
Jane Xu 2024-01-12 11:17:30 -08:00 committed by PyTorch MergeBot
parent bcf1f312a0
commit c329eddcb9
3 changed files with 121 additions and 48 deletions
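For context, here is a minimal, hypothetical sketch of the OptimizerInfo-driven test pattern this commit migrates the remaining state_dict tests to: @optims parametrizes a single test body over every entry in optim_db, and each optim_info supplies the optimizer class plus valid constructor kwargs. The class and test names below (ExampleOptimInfoTest, test_state_dict_roundtrip) are invented for illustration, and the sketch assumes the helpers that appear in the diff (optims, optim_db, _get_optim_inputs_including_global_cliquey_kwargs) are importable from torch.testing._internal.common_optimizers; it is not the test added by this commit.

import torch
from torch.nn import Parameter
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_optimizers import (
    optim_db,
    optims,
    _get_optim_inputs_including_global_cliquey_kwargs,
)
from torch.testing._internal.common_utils import TestCase, run_tests


class ExampleOptimInfoTest(TestCase):  # hypothetical class, for illustration only
    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_roundtrip(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        params = [Parameter(torch.randn(2, 3, device=device, dtype=dtype))]
        for optim_input in all_optim_inputs:
            # Each OptimizerInfo entry supplies the optimizer class and valid kwargs,
            # so the test body needs no per-optimizer constructor plumbing.
            optimizer = optim_cls(params, **optim_input.kwargs)
            reloaded = optim_cls(params, **optim_input.kwargs)
            reloaded.load_state_dict(optimizer.state_dict())
            self.assertEqual(optimizer.state_dict(), reloaded.state_dict())


instantiate_device_type_tests(ExampleOptimInfoTest, globals())

if __name__ == "__main__":
    run_tests()

Per-optimizer exceptions are not encoded in the test body; as the third file in this diff shows, they live in optim_db as DecorateInfo skips attached to specific test names.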

View File

@@ -27,7 +27,6 @@ from torch.testing._internal.common_utils import (
    skipIfTorchDynamo
)
from torch._dynamo import disable as disable_dynamo
from torch.testing._internal.common_cuda import TEST_CUDA
from typing import Dict, Any, Tuple
@@ -216,41 +215,6 @@ class TestOptim(TestCase):
            else:
                self.assertLess(fn().item(), initial_value)

    # Note: disable dynamo on this function
    # This allows us to continue running the actual logic of the optimizer
    # tests in dynamo without tracing this test code, which has a lot of unsupported
    # behavior
    @disable_dynamo(recursive=False)
    def _test_state_dict(self, weight, bias, input, constructor, atol=None, rtol=None):
        weight = Parameter(weight)
        bias = Parameter(bias)
        with torch.no_grad():
            input = input.clone().detach().requires_grad_()

        # Note: Disable dynamo on this function
        # This avoids a bug where input_cuda is not detected in the environment
        # because it currently is not defined in the local environment. Unable to repro
        # anywhere else, however, and this is test code that we don't need to spend
        # time getting dynamo to trace unless the issue repros in real models.
        @disable_dynamo(recursive=False)
        def fn_base(optimizer, weight, bias):
            optimizer.zero_grad()
            loss = (weight.mv(input) + bias).pow(2).sum()
            loss.backward()
            return loss

        optimizer = constructor(weight, bias)
        fn = functools.partial(fn_base, optimizer, weight, bias)

        # Prime the optimizer
        for _i in range(20):
            optimizer.step(fn)

        # Validate that deepcopy() copies all public attributes
        def getPublicAttr(obj):
            return {k for k in obj.__dict__ if not k.startswith("_")}

        self.assertEqual(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer)))

    def _test_basic_cases(
        self,
@@ -276,18 +240,6 @@ class TestOptim(TestCase):
                return lambda weight, bias: constructor(weight, bias, foreach)
            return constructor

        for maximize, foreach in itertools.product(
            {False, constructor_accepts_maximize},
            {False, constructor_accepts_foreach},
        ):
            self._test_state_dict(
                torch.randn(10, 5),
                torch.randn(10),
                torch.randn(5),
                make_two_arg_constructor(constructor, maximize, foreach),
                atol=atol,
                rtol=rtol,
            )
            self._test_basic_cases_template(
                torch.randn(10, 5),
                torch.randn(10),
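As a side note on the @disable_dynamo(recursive=False) decorations in the removed helper above: torch._dynamo.disable with recursive=False skips dynamo tracing only for the decorated function's own frame, so the optimizer logic it calls can still be compiled. A brief illustrative sketch (run_step is a hypothetical wrapper, not part of this commit):

import torch
from torch._dynamo import disable as disable_dynamo


@disable_dynamo(recursive=False)
def run_step(optimizer, closure):
    # Dynamo skips tracing this wrapper's own Python logic (recursive=False),
    # but the optimizer.step call and anything it invokes can still be traced.
    return optimizer.step(closure)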

View File

@@ -584,6 +584,38 @@ class TestOptimRenewed(TestCase):
            self.assertEqual(optimizer.state_dict(), optimizer_cuda.state_dict())

    @optims(optim_db, dtypes=[torch.float32])
    def test_deepcopy_copies_all_public_attrs(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))

        params = [Parameter(torch.randn(2, 3, device=device, dtype=dtype)) for _ in range(2)]
        for p in params:
            p.grad = torch.rand_like(p)
            if optim_cls.__name__ == "SparseAdam":
                # SparseAdam requires sparse gradients. For this test, we convert the Tensor layout,
                # which we know does NOT represent the expected use case!
                p.grad = p.grad.to_sparse()

        # Needed for LBFGS
        def closure():
            return 1 if optim_cls.__name__ == "LBFGS" else None

        def getPublicAttrs(obj):
            return {k for k in obj.__dict__ if not k.startswith("_")}

        for optim_input in all_optim_inputs:
            optimizer = optim_cls(params, **optim_input.kwargs)

            # Make some state
            for _ in range(3):
                optimizer.step(closure)

            self.assertEqual(getPublicAttrs(optimizer), getPublicAttrs(deepcopy(optimizer)))


instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True)
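The new test above boils down to a simple invariant that can be checked standalone: after an optimizer has taken a few steps, deepcopy must reproduce every public attribute. A self-contained sketch of that check, using plain SGD on CPU for simplicity (this snippet is illustrative and not part of the diff):

import torch
from copy import deepcopy
from torch.nn import Parameter

# Two parameters with dense gradients already populated, mirroring the test setup.
params = [Parameter(torch.randn(2, 3)) for _ in range(2)]
for p in params:
    p.grad = torch.rand_like(p)

optimizer = torch.optim.SGD(params, lr=0.1)

# Take a few steps so the optimizer has been exercised before copying.
for _ in range(3):
    optimizer.step()


def get_public_attrs(obj):
    # Public attributes are the __dict__ keys that do not start with an underscore.
    return {k for k in obj.__dict__ if not k.startswith("_")}


# deepcopy() must carry over the same set of public attribute names.
assert get_public_attrs(optimizer) == get_public_attrs(deepcopy(optimizer))
print(sorted(get_public_attrs(optimizer)))  # for SGD: ['defaults', 'param_groups', 'state']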

View File

@@ -912,6 +912,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_with_cuda_params",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -949,6 +956,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_deterministic",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -985,6 +999,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_deterministic",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1026,6 +1047,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_deterministic",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1062,6 +1090,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_deterministic",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1107,6 +1142,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_deterministic",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1122,6 +1164,13 @@ optim_db: List[OptimizerInfo] = [
            DecorateInfo(
                skipIfMps, "TestOptimRenewed", "test_can_load_older_state_dict"
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1173,6 +1222,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_with_cuda_params",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1195,6 +1251,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_set_default_dtype_works_with_foreach",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1248,6 +1311,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_with_cuda_params",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1291,6 +1361,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_with_cuda_params",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1354,6 +1431,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_with_cuda_params",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1391,6 +1475,11 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_with_cuda_params",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
]