Introduce OptimizerInfos + use them to refactor out the error testing.

Why OptimizerInfos?
- They are a cleaner, easier way to test all configs of optimizers.
- They would plug in well with device type to auto-enable tests for devices like MPS and meta.
- They would allow for more granular testing. Currently, lots of functionality is tested in `_test_basic_cases`, and some of that should be broken down more.

What did I do for error testing?
- I moved some error cases out of `_test_basic_cases` into a new test_errors parametrized test.
- The new test has to live in TestOptimRenewed (bikeshedding welcome) because the parametrized tests need to take in device and dtype and hook in correctly, and not all tests in TestOptim do that.
- TestOptimRenewed is also migrating to the toplevel test/test_optim.py now, because importing TestOptimRenewed does not work (due to test instantiation, TestOptimRenewed gets replaced with TestOptimRenewedDevice for CPU, CUDA, and whatever other device).

Is there any change in test coverage?
- INCREASE: The error case where a single Parameter (vs. a container of them) is passed in has now expanded to all optims instead of only LBFGS.
- DECREASE: Not much. The only thing is we no longer test two error cases for both foreach=True AND foreach=False, which I think is redundant. (Highlighted in comments.)

Possible but not urgent next step: test ALL possible error cases by going through all the constructors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114178
Approved by: https://github.com/albanD
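
Roughly, the new error test consumes these inputs as sketched below. This is a simplified, hypothetical sketch, not the exact code in the PR; it assumes this file is importable as torch.testing._internal.common_optimizers, and the names test_errors and TestOptimRenewed come from the PR description above:

    from torch.testing._internal.common_optimizers import (
        optim_db,
        optims,
        OptimizerErrorEnum,
    )
    from torch.testing._internal.common_utils import TestCase

    class TestOptimRenewed(TestCase):
        @optims([info for info in optim_db if info.optim_error_inputs_func is not None])
        def test_errors(self, device, dtype, optim_info):
            for error_input in optim_info.optim_error_inputs_func(device=device, dtype=dtype):
                opt_input = error_input.optimizer_error_input
                if error_input.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR:
                    # Constructing the optimizer with these params/kwargs should raise
                    with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
                        optim_info.optim_cls(opt_input.params, **opt_input.kwargs)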
import functools
import itertools
from enum import Enum
from typing import Any, Dict, List, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Parameter
from torch.optim import (
    Adadelta,
    Adagrad,
    Adam,
    Adamax,
    AdamW,
    ASGD,
    LBFGS,
    NAdam,
    Optimizer,
    RAdam,
    RMSprop,
    Rprop,
    SGD,
    SparseAdam,
)
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_utils import (
    _TestParametrizer,
    set_single_threaded_if_parallel_tbb,
)


class OptimizerInput:
    """Contains args / kwargs to be passed to an optimizer constructor."""

    __slots__ = ["params", "kwargs", "desc"]

    def __init__(
        self,
        params: Union[List[Parameter], List[Tensor], Dict[Any, Any]],
        kwargs: Dict[str, Any],
        desc: str = "",
    ):
        # Here, params can be a list of Tensors OR param_groups as well
        self.params = params
        self.kwargs = kwargs
        self.desc = desc


class OptimizerErrorEnum(Enum):
    """Enumerates when an error is raised when testing optimizers."""

    CONSTRUCTION_ERROR = 0
    STEP_ERROR = 1


class ErrorOptimizerInput:
    """
    An OptimizerInput that will cause the optimizer to throw an error when constructed
    or when step is called (see error_on). Includes the type and regex of the
    resulting error.
    """

    __slots__ = ["optimizer_error_input", "error_on", "error_type", "error_regex"]

    def __init__(
        self,
        optimizer_error_input,
        *,
        error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
        error_type=RuntimeError,
        error_regex="",
    ):
        self.optimizer_error_input = optimizer_error_input
        self.error_on = error_on
        self.error_type = error_type
        self.error_regex = error_regex


class OptimizerInfo:
    """Optimizer information to be used in testing."""

    def __init__(
        self,
        optim_cls: Optimizer,  # Class object for the Optimizer under test
        *,
        # Function to generate optimizer inputs
        optim_inputs_func,
        # Implementation-specific kwargs the optimizer supports, e.g., fused, foreach, differentiable.
        # We consider capturable to be a base constructor flag since it is implemented across the board.
        supported_impls: Tuple[str, ...] = ("foreach", "differentiable"),
        # The devices on which the optim supports sparse tensors for params and grads, see SGD
        supports_sparse_on: Tuple[str, ...] = (),
        # The optim only supports one config: sparse grads w/ dense params, see SparseAdam
        only_supports_sparse_grads: bool = False,
        # Whether the optimizer.step() function requires a closure to be passed
        step_requires_closure: bool = False,
        # Whether the optimizer supports per-param options with parameter groups
        supports_param_groups: bool = True,
        # Whether the optimizer supports parameters on multiple devices
        supports_multiple_devices: bool = True,
        skips=(),  # Indicates which tests to skip
        decorators=None,  # Additional decorators to apply to generated tests
        optim_error_inputs_func=None,  # Function to generate optim inputs that error
    ):
        self.optim_cls = optim_cls
        self.optim_inputs_func = optim_inputs_func
        self.supported_impls = supported_impls
        self.supports_sparse_on = supports_sparse_on
        self.only_supports_sparse_grads = only_supports_sparse_grads
        self.step_requires_closure = step_requires_closure
        self.supports_param_groups = supports_param_groups
        self.supports_multiple_devices = supports_multiple_devices
        self.decorators = (
            *(decorators if decorators else []),
            *(skips if skips else []),
        )
        self.optim_error_inputs_func = optim_error_inputs_func

    def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
        result = [set_single_threaded_if_parallel_tbb]
        for decorator in self.decorators:
            if isinstance(decorator, DecorateInfo):
                if decorator.is_active(
                    test_class, test_name, device, dtype, param_kwargs
                ):
                    result.extend(decorator.decorators)
            else:
                result.append(decorator)
        return result

    @property
    def name(self):
        return self.optim_cls.__name__


class optims(_TestParametrizer):
    """Decorator for specifying a list of optimizers over which to run a test."""

    def __init__(self, optim_info_iterable, dtypes=None):
        self.optim_info_list = list(optim_info_iterable)

        # Optimizers aren't limited to a single dtype, as their parameters can have
        # different dtypes. We default to torch.float32, but other dtypes can be
        # specified through the dtypes arg.
        self.dtypes = dtypes if dtypes is not None else [torch.float32]

    def _parametrize_test(self, test, generic_cls, device_cls):
        if device_cls is None:
            raise RuntimeError(
                "The @optims decorator is only intended to be used in a device-specific "
                "context; use it with instantiate_device_type_tests() instead of "
                "instantiate_parametrized_tests()"
            )

        for optim_info, dtype in itertools.product(self.optim_info_list, self.dtypes):
            # Construct the test name; device / dtype parts are handled outside.
            # See [Note: device and dtype suffix placement]
            test_name = optim_info.name

            # Construct parameter kwargs to pass to the test.
            param_kwargs = {"optim_info": optim_info, "dtype": dtype}

            try:

                @functools.wraps(test)
                def test_wrapper(*args, **kwargs):
                    return test(*args, **kwargs)

                decorator_fn = functools.partial(
                    optim_info.get_decorators,
                    generic_cls.__name__,
                    test.__name__,
                    device_cls.device_type,
                    dtype,
                )

                yield (test_wrapper, test_name, param_kwargs, decorator_fn)
            except Exception as ex:
                # Provides an error message for debugging before rethrowing the exception
                print(f"Failed to instantiate {test_name} for optimizer {optim_info.name}!")
                raise ex

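
# A minimal usage sketch (hypothetical, not part of this file), assuming the test
# class is later instantiated per device via instantiate_device_type_tests() so
# that each generated test receives device and dtype alongside the optim_info:
#
#     class TestOptimRenewed(TestCase):
#         @optims(optim_db, dtypes=[torch.float32])
#         def test_foo(self, device, dtype, optim_info):
#             ...  # runs once per (optimizer, dtype) combo per device
#
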
def optim_error_inputs_func_adadelta(device, dtype):
    return [
        ErrorOptimizerInput(
            OptimizerInput(
                params=None,
                kwargs=dict(lr=1e-2, rho=1.1),
                desc="rho should be between 0 and 1",
            ),
            error_type=ValueError,
            error_regex="Invalid rho value: 1.1",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=Parameter(torch.randn(1, device=device, dtype=dtype)),
                kwargs={},
                desc="invalid param type",
            ),
            error_type=TypeError,
            error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
        ),
    ]


def optim_error_inputs_func_adagrad(device, dtype):
    return [
        ErrorOptimizerInput(
            OptimizerInput(
                params=None,
                kwargs=dict(lr=1e-2, lr_decay=-0.5),
                desc="lr_decay must be >= 0",
            ),
            error_type=ValueError,
            error_regex="Invalid lr_decay value: -0.5",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=Parameter(torch.randn(1, device=device, dtype=dtype)),
                kwargs={},
                desc="invalid param type",
            ),
            error_type=TypeError,
            error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
        ),
    ]


def optim_error_inputs_func_adam(device, dtype):
    return [
        ErrorOptimizerInput(
            OptimizerInput(
                params=None,
                kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
                desc="beta1 should be between 0 and 1",
            ),
            error_type=ValueError,
            error_regex="Invalid beta parameter at index 0: 1.0",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=None,
                kwargs=dict(lr=1e-2, weight_decay=-1),
                desc="weight_decay should be >= 0",
            ),
            error_type=ValueError,
            error_regex="Invalid weight_decay value: -1",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=None,
                kwargs=dict(lr=torch.tensor(0.001), foreach=True),
                desc="lr as a Tensor doesn't work with foreach and not capturable",
            ),
            error_type=ValueError,
            error_regex="lr as a Tensor is not supported for capturable=False and foreach=True",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=Parameter(torch.randn(1, device=device, dtype=dtype)),
                kwargs={},
                desc="invalid param type",
            ),
            error_type=TypeError,
            error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
        ),
    ]


def optim_error_inputs_func_adamax(device, dtype):
    return [
        ErrorOptimizerInput(
            OptimizerInput(
                params=None,
                kwargs=dict(lr=1e-2, betas=(0.0, 1.0)),
                desc="beta2 should be between 0 and 1",
            ),
            error_type=ValueError,
            error_regex="Invalid beta parameter at index 1: 1.0",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=Parameter(torch.randn(1, device=device, dtype=dtype)),
                kwargs={},
                desc="invalid param type",
            ),
            error_type=TypeError,
            error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
        ),
    ]


def optim_error_inputs_func_adamw(device, dtype):
    return [
        ErrorOptimizerInput(
            OptimizerInput(
                params=None,
                kwargs=dict(lr=1e-2, weight_decay=-1),
                desc="weight_decay should be >= 0",
            ),
            error_type=ValueError,
            error_regex="Invalid weight_decay value: -1",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=None,
                kwargs=dict(lr=torch.tensor(0.001), foreach=True),
                desc="lr as a Tensor doesn't work with foreach and not capturable",
            ),
            error_type=ValueError,
            error_regex="lr as a Tensor is not supported for capturable=False and foreach=True",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=Parameter(torch.randn(1, device=device, dtype=dtype)),
                kwargs={},
                desc="invalid param type",
            ),
            error_type=TypeError,
            error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
        ),
    ]


def optim_error_inputs_func_asgd(device, dtype):
    return [
        ErrorOptimizerInput(
            OptimizerInput(
                params=None,
                kwargs=dict(lr=1e-2, weight_decay=-0.5),
                desc="weight_decay should be >= 0",
            ),
            error_type=ValueError,
            error_regex="Invalid weight_decay value: -0.5",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=Parameter(torch.randn(1, device=device, dtype=dtype)),
                kwargs={},
                desc="invalid param type",
            ),
            error_type=TypeError,
            error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
        ),
    ]


def optim_error_inputs_func_lbfgs(device, dtype):
    return [
        ErrorOptimizerInput(
            OptimizerInput(
                params=Parameter(torch.randn(1, device=device, dtype=dtype)),
                kwargs={},
                desc="invalid param type",
            ),
            error_type=TypeError,
            error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
        ),
    ]


def optim_error_inputs_func_nadam(device, dtype):
    return [
        ErrorOptimizerInput(
            OptimizerInput(
                params=None,
                kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
                desc="beta1 should be between 0 and 1",
            ),
            error_type=ValueError,
            error_regex="Invalid beta parameter at index 0: 1.0",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=None,
                kwargs=dict(lr=1e-2, momentum_decay=-0.2),
                desc="momentum_decay should be >= 0",
            ),
            error_type=ValueError,
            error_regex="Invalid momentum_decay value: -0.2",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=Parameter(torch.randn(1, device=device, dtype=dtype)),
                kwargs={},
                desc="invalid param type",
            ),
            error_type=TypeError,
            error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
        ),
    ]


def optim_error_inputs_func_radam(device, dtype):
    return [
        ErrorOptimizerInput(
            OptimizerInput(
                params=None,
                kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
                desc="beta1 should be between 0 and 1",
            ),
            error_type=ValueError,
            error_regex="Invalid beta parameter at index 0: 1.0",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=None,
                kwargs=dict(lr=1e-2, weight_decay=-1),
                desc="weight_decay should be >= 0",
            ),
            error_type=ValueError,
            error_regex="Invalid weight_decay value: -1",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=Parameter(torch.randn(1, device=device, dtype=dtype)),
                kwargs={},
                desc="invalid param type",
            ),
            error_type=TypeError,
            error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
        ),
    ]


def optim_error_inputs_func_rmsprop(device, dtype):
    return [
        ErrorOptimizerInput(
            OptimizerInput(
                params=None,
                kwargs=dict(lr=1e-2, momentum=-1.0),
                desc="momentum should be >= 0",
            ),
            error_type=ValueError,
            error_regex="Invalid momentum value: -1.0",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=Parameter(torch.randn(1, device=device, dtype=dtype)),
                kwargs={},
                desc="invalid param type",
            ),
            error_type=TypeError,
            error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
        ),
    ]


def optim_error_inputs_func_rprop(device, dtype):
    return [
        ErrorOptimizerInput(
            OptimizerInput(
                params=None,
                kwargs=dict(lr=1e-2, etas=(1.0, 0.5)),
                desc="0 < eta1 < 1 < eta2",
            ),
            error_type=ValueError,
            error_regex="Invalid eta values: 1.0, 0.5",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=Parameter(torch.randn(1, device=device, dtype=dtype)),
                kwargs={},
                desc="invalid param type",
            ),
            error_type=TypeError,
            error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
        ),
    ]


def optim_error_inputs_func_sgd(device, dtype):
    return [
        ErrorOptimizerInput(
            OptimizerInput(
                params=None,
                kwargs=dict(lr=1e-2, momentum=-0.5),
                desc="momentum should be >= 0",
            ),
            error_type=ValueError,
            error_regex="Invalid momentum value: -0.5",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=Parameter(torch.randn(1, device=device, dtype=dtype)),
                kwargs={},
                desc="invalid param type",
            ),
            error_type=TypeError,
            error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
        ),
    ]


def optim_error_inputs_func_sparseadam(device, dtype):
    return [
        ErrorOptimizerInput(
            OptimizerInput(
                params=None,
                kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
                desc="beta1 should be between 0 and 1",
            ),
            error_type=ValueError,
            error_regex="Invalid beta parameter at index 0: 1.0",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=[
                    torch.zeros(3, layout=torch.sparse_coo, device=device, dtype=dtype)
                ],
                kwargs={},
                desc="dense params required",
            ),
            error_type=ValueError,
            error_regex="SparseAdam requires dense parameter tensors",
        ),
        ErrorOptimizerInput(
            OptimizerInput(
                params=[
                    {
                        "params": [
                            torch.zeros(
                                3, layout=torch.sparse_coo, device=device, dtype=dtype
                            )
                        ]
                    }
                ],
                kwargs={},
                desc="dense params required in param_groups",
            ),
            error_type=ValueError,
            error_regex="SparseAdam requires dense parameter tensors",
        ),
    ]


# Database of OptimizerInfo entries in alphabetical order.
optim_db: List[OptimizerInfo] = [
    OptimizerInfo(
        Adadelta,
        optim_inputs_func=None,
        optim_error_inputs_func=optim_error_inputs_func_adadelta,
        supported_impls=("foreach", "differentiable"),
    ),
    OptimizerInfo(
        Adagrad,
        optim_inputs_func=None,
        optim_error_inputs_func=optim_error_inputs_func_adagrad,
        supported_impls=("foreach", "differentiable"),
        supports_sparse_on=("cpu",),  # trailing comma needed: ("cpu") is just a str
    ),
    OptimizerInfo(
        Adam,
        optim_inputs_func=None,
        optim_error_inputs_func=optim_error_inputs_func_adam,
        supported_impls=("foreach", "differentiable", "fused"),
    ),
    OptimizerInfo(
        Adamax,
        optim_inputs_func=None,
        optim_error_inputs_func=optim_error_inputs_func_adamax,
        supported_impls=("foreach", "differentiable"),
    ),
    OptimizerInfo(
        AdamW,
        optim_inputs_func=None,
        optim_error_inputs_func=optim_error_inputs_func_adamw,
        supported_impls=("foreach", "differentiable", "fused"),
    ),
    OptimizerInfo(
        ASGD,
        optim_inputs_func=None,
        optim_error_inputs_func=optim_error_inputs_func_asgd,
        supported_impls=("foreach", "differentiable"),
    ),
    OptimizerInfo(
        LBFGS,
        optim_inputs_func=None,
        optim_error_inputs_func=optim_error_inputs_func_lbfgs,
        supported_impls=(),
        step_requires_closure=True,
        supports_param_groups=False,
        supports_multiple_devices=False,
    ),
    OptimizerInfo(
        NAdam,
        optim_inputs_func=None,
        optim_error_inputs_func=optim_error_inputs_func_nadam,
        supported_impls=("foreach", "differentiable"),
    ),
    OptimizerInfo(
        RAdam,
        optim_inputs_func=None,
        optim_error_inputs_func=optim_error_inputs_func_radam,
        supported_impls=("foreach", "differentiable"),
    ),
    OptimizerInfo(
        RMSprop,
        optim_inputs_func=None,
        optim_error_inputs_func=optim_error_inputs_func_rmsprop,
        supported_impls=("foreach", "differentiable"),
    ),
    OptimizerInfo(
        Rprop,
        optim_inputs_func=None,
        optim_error_inputs_func=optim_error_inputs_func_rprop,
        supported_impls=("foreach", "differentiable"),
    ),
    OptimizerInfo(
        SGD,
        optim_inputs_func=None,
        optim_error_inputs_func=optim_error_inputs_func_sgd,
        supported_impls=("foreach", "differentiable"),
        supports_sparse_on=("cpu", "cuda"),
    ),
    OptimizerInfo(
        SparseAdam,
        optim_inputs_func=None,
        optim_error_inputs_func=optim_error_inputs_func_sparseadam,
        supported_impls=(),
        only_supports_sparse_grads=True,
    ),
]
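
# All entries above use the default error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR.
# A hypothetical STEP_ERROR entry would be consumed differently: construction is
# expected to succeed, and the error is asserted when calling optimizer.step()
# instead. A sketch (assumed consumption pattern, not part of this file):
#
#     if error_input.error_on == OptimizerErrorEnum.STEP_ERROR:
#         optim = optim_info.optim_cls(opt_input.params, **opt_input.kwargs)
#         with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
#             optim.step()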