mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This is a new version of #15648 based on the latest master branch. Unlike the previous PR, where I fixed many of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here: I'm simply going to integrate xdoctest and then mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR. In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed and 201 passed tests. The next commits will disable those tests. (Unfortunately, I don't have a tool that will insert the `# xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.) Fixes https://github.com/pytorch/pytorch/issues/71105 @ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797 Approved by: https://github.com/ezyang
59 lines
2.1 KiB
Python
59 lines
2.1 KiB
Python
from typing import Type
|
|
from torch import optim
|
|
from .functional_adagrad import _FunctionalAdagrad
|
|
from .functional_adam import _FunctionalAdam
|
|
from .functional_adamw import _FunctionalAdamW
|
|
from .functional_sgd import _FunctionalSGD
|
|
from .functional_adadelta import _FunctionalAdadelta
|
|
from .functional_rmsprop import _FunctionalRMSprop
|
|
from .functional_rprop import _FunctionalRprop
|
|
from .functional_adamax import _FunctionalAdamax
|
|
|
|
# Lookup table from a user-supplied ``torch.optim`` optimizer class to the
# functional optimizer class already defined in the ``distributed.optim``
# package. Keeping this mapping internal hides the functional optimizers from
# users while still exposing the same API.
# Entries can be added at runtime via ``register_functional_optim``.
functional_optim_map = {
    optim.Adagrad: _FunctionalAdagrad,
    optim.Adam: _FunctionalAdam,
    optim.AdamW: _FunctionalAdamW,
    optim.SGD: _FunctionalSGD,
    optim.Adadelta: _FunctionalAdadelta,
    optim.RMSprop: _FunctionalRMSprop,
    optim.Rprop: _FunctionalRprop,
    optim.Adamax: _FunctionalAdamax,
}
|
|
|
|
|
|
def register_functional_optim(key, optim):
    """
    Register a new functional optimizer in ``functional_optim_map``.

    ``key`` and ``optim`` are user defined. Neither needs to be a
    :class:`torch.optim.Optimizer` (e.g. for custom optimizers).

    If ``key`` is already present in ``functional_optim_map``, the existing
    entry is kept and this call does nothing; registrations are never
    overwritten.

    Args:
        key: Lookup key under which ``optim`` is stored. This is the value
            later passed as ``optim_cls`` to :func:`as_functional_optim`.
        optim: The functional optimizer to register. It is later called as
            ``optim([], *args, **kwargs, _allow_empty_param_list=True)``
            by :func:`as_functional_optim`.

    Example::

        >>> # import the new functional optimizer
        >>> # xdoctest: +SKIP
        >>> from xyz import fn_optimizer
        >>> from torch.distributed.optim.utils import register_functional_optim
        >>> fn_optim_key = "XYZ_optim"
        >>> register_functional_optim(fn_optim_key, fn_optimizer)
    """
    # NOTE: the ``optim`` parameter shadows the module-level ``torch.optim``
    # import inside this function; the name is kept so existing keyword
    # callers (``optim=...``) keep working.
    if key not in functional_optim_map:
        functional_optim_map[key] = optim
|
|
|
|
def as_functional_optim(optim_cls: Type, *args, **kwargs):
    """
    Instantiate the functional counterpart of ``optim_cls``.

    Looks up ``optim_cls`` in ``functional_optim_map`` and constructs the
    registered functional optimizer through ``_create_functional_optim``,
    forwarding ``*args`` and ``**kwargs``.

    Args:
        optim_cls: Key to look up, e.g. ``torch.optim.Adam``.
        *args: Positional arguments forwarded to the functional optimizer.
        **kwargs: Keyword arguments forwarded to the functional optimizer.

    Returns:
        An instance of the functional optimizer registered for ``optim_cls``.

    Raises:
        ValueError: If ``optim_cls`` has no registered functional counterpart.
    """
    try:
        functional_cls = functional_optim_map[optim_cls]
    except KeyError as err:
        # Chain the lookup failure explicitly so the traceback reports it as
        # the direct cause instead of "another exception occurred during
        # handling of the above exception".
        raise ValueError(
            f"Optimizer {optim_cls} does not have a functional counterpart!"
        ) from err

    return _create_functional_optim(functional_cls, *args, **kwargs)
|
|
|
|
def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
|
|
return functional_optim_cls(
|
|
[],
|
|
*args,
|
|
**kwargs,
|
|
_allow_empty_param_list=True,
|
|
)
|