pytorch/torch/distributed/optim/utils.py

# mypy: allow-untyped-defs
from torch import optim
from .functional_adadelta import _FunctionalAdadelta
from .functional_adagrad import _FunctionalAdagrad
from .functional_adam import _FunctionalAdam
from .functional_adamax import _FunctionalAdamax
from .functional_adamw import _FunctionalAdamW
from .functional_rmsprop import _FunctionalRMSprop
from .functional_rprop import _FunctionalRprop
from .functional_sgd import _FunctionalSGD


# Dict that maps a user-passed-in optimizer class to a functional optimizer
# class, if we have already defined one inside the distributed.optim package.
# This lets us hide the functional optimizer from the user while still
# providing the same API.
functional_optim_map = {
    optim.Adagrad: _FunctionalAdagrad,
    optim.Adam: _FunctionalAdam,
    optim.AdamW: _FunctionalAdamW,
    optim.SGD: _FunctionalSGD,
    optim.Adadelta: _FunctionalAdadelta,
    optim.RMSprop: _FunctionalRMSprop,
    optim.Rprop: _FunctionalRprop,
    optim.Adamax: _FunctionalAdamax,
}
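
# For illustration only (a minimal sketch, not executed by this module):
# looking up a supported optimizer class yields its functional counterpart,
# e.g. ``functional_optim_map[optim.Adam] is _FunctionalAdam`` evaluates to True.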


def register_functional_optim(key, optim):
    """
    Interface to insert a new functional optimizer into ``functional_optim_map``.

    ``fn_optim_key`` and ``fn_optimizer`` are user defined. The optimizer and key
    need not be of type :class:`torch.optim.Optimizer` (e.g. for custom optimizers).

    Example::
        >>> # import the new functional optimizer
        >>> # xdoctest: +SKIP
        >>> from xyz import fn_optimizer
        >>> from torch.distributed.optim.utils import register_functional_optim
        >>> fn_optim_key = "XYZ_optim"
        >>> register_functional_optim(fn_optim_key, fn_optimizer)
    """
    if key not in functional_optim_map:
        functional_optim_map[key] = optim
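
# NOTE: ``as_functional_optim`` below looks optimizers up by class, so a custom
# functional optimizer registered with its original ``torch.optim`` class as the
# key becomes reachable through ``as_functional_optim`` as well.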


def as_functional_optim(optim_cls: type, *args, **kwargs):
    try:
        functional_cls = functional_optim_map[optim_cls]
    except KeyError as e:
        raise ValueError(
            f"Optimizer {optim_cls} does not have a functional counterpart!"
        ) from e
    return _create_functional_optim(functional_cls, *args, **kwargs)
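
# Example usage (a minimal sketch; ``lr`` and ``momentum`` are the standard
# ``torch.optim.SGD`` constructor arguments, shown purely for illustration):
#
#     >>> fn_sgd = as_functional_optim(optim.SGD, lr=0.01, momentum=0.9)
#
# ``fn_sgd`` is a ``_FunctionalSGD`` built with an empty parameter list; the
# functional optimizers in this package are fed gradients explicitly through
# their step calls rather than reading ``.grad`` fields.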


def _create_functional_optim(functional_optim_cls: type, *args, **kwargs):
    # Construct the functional optimizer with an empty parameter list;
    # ``_allow_empty_param_list=True`` bypasses the empty-parameter-list
    # check in the functional optimizer constructors.
    return functional_optim_cls(
        [],
        *args,
        **kwargs,
        _allow_empty_param_list=True,
    )