pytorch/torch/distributed/optim/utils.py
joncrall 4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running, with the dashboards still failing. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed and 201 passed tests. The next commits will disable the failing tests. (Unfortunately I don't have a tool that will insert the `# xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)
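For reference, disabling a failing doctest amounts to dropping the directive into the docstring's example block, roughly like this (the function and module names below are made up purely for illustration):

```python
def frobnicate(x):
    """Toy function, used only to show where the directive goes.

    Example::
        >>> # xdoctest: +SKIP
        >>> from mymodule import frobnicate
        >>> frobnicate(1)
        2
    """
    return x + 1
```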

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00


from typing import Type
from torch import optim
from .functional_adagrad import _FunctionalAdagrad
from .functional_adam import _FunctionalAdam
from .functional_adamw import _FunctionalAdamW
from .functional_sgd import _FunctionalSGD
from .functional_adadelta import _FunctionalAdadelta
from .functional_rmsprop import _FunctionalRMSprop
from .functional_rprop import _FunctionalRprop
from .functional_adamax import _FunctionalAdamax

# Dict that maps a user-passed-in optimizer class to a functional
# optimizer class that we have already defined inside the
# distributed.optim package. This lets us hide the functional
# optimizer from the user while still providing the same API.
functional_optim_map = {
    optim.Adagrad: _FunctionalAdagrad,
    optim.Adam: _FunctionalAdam,
    optim.AdamW: _FunctionalAdamW,
    optim.SGD: _FunctionalSGD,
    optim.Adadelta: _FunctionalAdadelta,
    optim.RMSprop: _FunctionalRMSprop,
    optim.Rprop: _FunctionalRprop,
    optim.Adamax: _FunctionalAdamax,
}


def register_functional_optim(key, optim):
    """
    Interface to insert a new functional optimizer into ``functional_optim_map``.
    ``fn_optim_key`` and ``fn_optimizer`` are user defined. The optimizer and key
    need not be of :class:`torch.optim.Optimizer` (e.g. for custom optimizers).
    Example::
        >>> # import the new functional optimizer
        >>> # xdoctest: +SKIP
        >>> from xyz import fn_optimizer
        >>> from torch.distributed.optim.utils import register_functional_optim
        >>> fn_optim_key = "XYZ_optim"
        >>> register_functional_optim(fn_optim_key, fn_optimizer)
    """
    if key not in functional_optim_map:
        functional_optim_map[key] = optim


def as_functional_optim(optim_cls: Type, *args, **kwargs):
    # Look up the functional counterpart of the given optimizer class and
    # construct it with the provided arguments.
    try:
        functional_cls = functional_optim_map[optim_cls]
    except KeyError:
        raise ValueError(
            f"Optimizer {optim_cls} does not have a functional counterpart!"
        )
    return _create_functional_optim(functional_cls, *args, **kwargs)


def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
    # Functional optimizers are constructed with an empty parameter list;
    # parameters are handed to them later through their step calls.
    return functional_optim_cls(
        [],
        *args,
        **kwargs,
        _allow_empty_param_list=True,
    )
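

# Usage sketch (illustrative addition, not part of the upstream module): given a
# stock optimizer class, ``as_functional_optim`` returns the corresponding
# functional optimizer, constructed with an empty parameter list; the keyword
# arguments are forwarded to the functional optimizer's constructor.
#
#     >>> # xdoctest: +SKIP
#     >>> from torch import optim
#     >>> from torch.distributed.optim.utils import as_functional_optim
#     >>> fn_sgd = as_functional_optim(optim.SGD, lr=0.1)
#     >>> # fn_sgd is a _FunctionalSGD instance wrapping the SGD update rule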