Migrate the rest of state_dict testing to OptimizerInfo (#117186)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117186
Approved by: https://github.com/albanD
ghstack dependencies: #116509
This commit is contained in:
parent bcf1f312a0
commit c329eddcb9
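For readers new to the harness, the migration target looks roughly like the sketch below: one test body parametrized over every optimizer in the database, instead of a hand-written copy per optimizer class. This is a minimal sketch, assuming the `@optims` decorator, `optim_db`, and `OptimizerInfo.optim_inputs_func` from `torch.testing._internal.common_optimizers` behave as used in the diff below; `TestOptimSketch` and `test_state_dict_roundtrip` are hypothetical names, not code from this PR.

# --- illustrative sketch, not part of this commit ---
import torch
from torch.nn import Parameter
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_optimizers import optim_db, optims
from torch.testing._internal.common_utils import TestCase, run_tests

class TestOptimSketch(TestCase):
    # The decorated body is instantiated once per OptimizerInfo entry in optim_db,
    # replacing a hand-rolled per-optimizer test.
    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_roundtrip(self, device, dtype, optim_info):
        params = [Parameter(torch.randn(2, 3, device=device, dtype=dtype))]
        for optim_input in optim_info.optim_inputs_func(device=device):
            optimizer = optim_info.optim_cls(params, **optim_input.kwargs)
            # Reloading an optimizer's own state_dict should round-trip cleanly.
            optimizer.load_state_dict(optimizer.state_dict())

instantiate_device_type_tests(TestOptimSketch, globals())

if __name__ == "__main__":
    run_tests()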
@@ -27,7 +27,6 @@ from torch.testing._internal.common_utils import (
    skipIfTorchDynamo
)
from torch._dynamo import disable as disable_dynamo
from torch.testing._internal.common_cuda import TEST_CUDA
from typing import Dict, Any, Tuple
@@ -216,41 +215,6 @@ class TestOptim(TestCase):
        else:
            self.assertLess(fn().item(), initial_value)

    # Note: disable dynamo on this function.
    # This allows us to continue running the actual logic of the optimizer
    # tests in dynamo, without tracing this test code, which has a lot of
    # unsupported behavior.
    @disable_dynamo(recursive=False)
    def _test_state_dict(self, weight, bias, input, constructor, atol=None, rtol=None):
        weight = Parameter(weight)
        bias = Parameter(bias)
        with torch.no_grad():
            input = input.clone().detach().requires_grad_()

        # Note: disable dynamo on this function.
        # This avoids a bug where input_cuda is not detected in the environment
        # because it currently is not defined in the local environment. We have been
        # unable to repro this anywhere else, and this is test code that we don't need
        # to spend time getting dynamo to trace unless the issue repros in real models.
        @disable_dynamo(recursive=False)
        def fn_base(optimizer, weight, bias):
            optimizer.zero_grad()
            loss = (weight.mv(input) + bias).pow(2).sum()
            loss.backward()
            return loss

        optimizer = constructor(weight, bias)
        fn = functools.partial(fn_base, optimizer, weight, bias)

        # Prime the optimizer
        for _i in range(20):
            optimizer.step(fn)

        # Validate that deepcopy() copies all public attributes
        def getPublicAttr(obj):
            return {k for k in obj.__dict__ if not k.startswith("_")}

        self.assertEqual(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer)))

    def _test_basic_cases(
        self,
@@ -276,18 +240,6 @@ class TestOptim(TestCase):
                return lambda weight, bias: constructor(weight, bias, foreach)
            return constructor

        for maximize, foreach in itertools.product(
            {False, constructor_accepts_maximize},
            {False, constructor_accepts_foreach},
        ):
            self._test_state_dict(
                torch.randn(10, 5),
                torch.randn(10),
                torch.randn(5),
                make_two_arg_constructor(constructor, maximize, foreach),
                atol=atol,
                rtol=rtol,
            )
        self._test_basic_cases_template(
            torch.randn(10, 5),
            torch.randn(10),
@@ -584,6 +584,38 @@ class TestOptimRenewed(TestCase):
        self.assertEqual(optimizer.state_dict(), optimizer_cuda.state_dict())

    @optims(optim_db, dtypes=[torch.float32])
    def test_deepcopy_copies_all_public_attrs(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))

        params = [Parameter(torch.randn(2, 3, device=device, dtype=dtype)) for _ in range(2)]
        for p in params:
            p.grad = torch.rand_like(p)
            if optim_cls.__name__ == "SparseAdam":
                # SparseAdam requires sparse gradients. For this test, we convert the Tensor layout,
                # which we know does NOT represent the expected use case!
                p.grad = p.grad.to_sparse()

        # Needed for LBFGS
        def closure():
            return 1 if optim_cls.__name__ == "LBFGS" else None

        def getPublicAttrs(obj):
            return {k for k in obj.__dict__ if not k.startswith("_")}

        for optim_input in all_optim_inputs:
            optimizer = optim_cls(params, **optim_input.kwargs)

            # Make some state
            for _ in range(3):
                optimizer.step(closure)

            self.assertEqual(getPublicAttrs(optimizer), getPublicAttrs(deepcopy(optimizer)))


instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True)
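The invariant the new test checks can be reproduced outside the harness with any stock optimizer. A minimal standalone sketch follows, using `torch.optim.SGD` purely for illustration; `public_attrs` is a hypothetical helper mirroring the test's `getPublicAttrs`, and none of this is code from the PR.

# --- illustrative sketch, not part of this commit ---
from copy import deepcopy

import torch

param = torch.nn.Parameter(torch.randn(2, 3))
param.grad = torch.rand_like(param)

opt = torch.optim.SGD([param], lr=0.1, momentum=0.9)
opt.step()  # populate momentum buffers so the optimizer has non-trivial state

def public_attrs(obj):
    return {k for k in obj.__dict__ if not k.startswith("_")}

# deepcopy must preserve every public attribute name on the instance,
# e.g. defaults, state, and param_groups.
assert public_attrs(opt) == public_attrs(deepcopy(opt))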
@@ -912,6 +912,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_with_cuda_params",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
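This DecorateInfo block recurs nearly verbatim for each OptimizerInfo entry below, because skips are registered per optimizer rather than globally. A rough sketch of the mechanism it implements: conditionally wrapping one named test of one test class with a skip decorator. It assumes the real skipIfTorchDynamo consults the PYTORCH_TEST_WITH_DYNAMO environment variable; `skip_if_dynamo` and `MyTests` are hypothetical stand-ins, not PyTorch code.

# --- illustrative sketch, not part of this commit ---
import os
import unittest

# Hypothetical stand-in for torch.testing._internal.common_utils.skipIfTorchDynamo.
def skip_if_dynamo(reason):
    return unittest.skipIf(os.environ.get("PYTORCH_TEST_WITH_DYNAMO") == "1", reason)

class MyTests(unittest.TestCase):
    # DecorateInfo applies a decorator like this to one named test of one class;
    # attaching it to an OptimizerInfo scopes the skip to that optimizer alone.
    @skip_if_dynamo("fails, https://github.com/pytorch/pytorch/issues/117165")
    def test_deepcopy_copies_all_public_attrs(self):
        self.assertTrue(True)

if __name__ == "__main__":
    unittest.main()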
@@ -949,6 +956,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_deterministic",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -985,6 +999,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_deterministic",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1026,6 +1047,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_deterministic",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1062,6 +1090,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_deterministic",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1107,6 +1142,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_deterministic",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1122,6 +1164,13 @@ optim_db: List[OptimizerInfo] = [
            DecorateInfo(
                skipIfMps, "TestOptimRenewed", "test_can_load_older_state_dict"
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1173,6 +1222,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_with_cuda_params",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1195,6 +1251,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_set_default_dtype_works_with_foreach",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1248,6 +1311,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_with_cuda_params",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1291,6 +1361,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_with_cuda_params",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1354,6 +1431,13 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_with_cuda_params",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "fails, https://github.com/pytorch/pytorch/issues/117165"
                ),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1391,6 +1475,11 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_with_cuda_params",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
]