pytorch/torch/distributed/optim/utils.py
Rohan Varma bdcdf94bdd [Opt Overlap] Clean up code in _OptimizerHookState (#71620)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71620

Remove from_functional_optim and make it the default constructor, since
that is now the only way _OptimizerHookState is built. Also, the
create_functional_optim helper function no longer needs to be exposed.
ghstack-source-id: 147577174

Test Plan: CI

Reviewed By: cbalioglu

Differential Revision: D33700593

fbshipit-source-id: ba089ce3bf66ccf8f71cffdd0f4d4bddc03e8b14
(cherry picked from commit a50b2caf0e)
2022-01-26 19:33:49 +00:00

42 lines
1.5 KiB
Python

from typing import Type
from torch import optim
from .functional_adagrad import _FunctionalAdagrad
from .functional_adam import _FunctionalAdam
from .functional_adamw import _FunctionalAdamW
from .functional_sgd import _FunctionalSGD
from .functional_adadelta import _FunctionalAdadelta
from .functional_rmsprop import _FunctionalRMSprop
from .functional_rprop import _FunctionalRprop
from .functional_adamax import _FunctionalAdamax
# Lookup table from a user-supplied ``torch.optim`` optimizer class to the
# functional optimizer class defined in the distributed.optim package.
# Callers keep passing the regular optimizer class; the functional
# implementation stays hidden behind the same API.
functional_optim_map = {
    optim.Adagrad: _FunctionalAdagrad,
    optim.Adam: _FunctionalAdam,
    optim.AdamW: _FunctionalAdamW,
    optim.SGD: _FunctionalSGD,
    optim.Adadelta: _FunctionalAdadelta,
    optim.RMSprop: _FunctionalRMSprop,
    optim.Rprop: _FunctionalRprop,
    optim.Adamax: _FunctionalAdamax,
}
def as_functional_optim(optim_cls: Type, *args, **kwargs):
    """Instantiate the functional counterpart of ``optim_cls``.

    The functional class is looked up in ``functional_optim_map`` and then
    constructed through ``_create_functional_optim``, with ``*args`` and
    ``**kwargs`` forwarded unchanged.

    Args:
        optim_cls: Optimizer class used as the lookup key.
        *args: Positional arguments forwarded to the functional optimizer.
        **kwargs: Keyword arguments forwarded to the functional optimizer.

    Returns:
        The object returned by ``_create_functional_optim``.

    Raises:
        ValueError: If ``optim_cls`` has no entry in ``functional_optim_map``.
    """
    try:
        functional_cls = functional_optim_map[optim_cls]
    except KeyError as err:
        # Chain explicitly so the traceback reads "direct cause" instead of
        # "another exception occurred during handling".
        raise ValueError(
            f"Optimizer {optim_cls} does not have a functional counterpart!"
        ) from err

    return _create_functional_optim(functional_cls, *args, **kwargs)
def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
    """Construct ``functional_optim_cls`` with an empty parameter list.

    The class is called with an empty list as its first positional
    argument, followed by ``*args`` and ``**kwargs``, plus the keyword
    argument ``_allow_empty_param_list=True``.

    Returns:
        Whatever ``functional_optim_cls(...)`` returns.
    """
    no_params = []
    instance = functional_optim_cls(
        no_params, *args, **kwargs, _allow_empty_param_list=True
    )
    return instance