mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71620
Remove from_functional_optim and make it the default constructor since
that is the only way _OptimizerHookState is now being built. Also, no longer
need to expose create_functional_optim helper function
ghstack-source-id: 147577174
Test Plan: CI
Reviewed By: cbalioglu
Differential Revision: D33700593
fbshipit-source-id: ba089ce3bf66ccf8f71cffdd0f4d4bddc03e8b14
(cherry picked from commit a50b2caf0e)
42 lines
1.5 KiB
Python
from typing import Type
|
|
from torch import optim
|
|
from .functional_adagrad import _FunctionalAdagrad
|
|
from .functional_adam import _FunctionalAdam
|
|
from .functional_adamw import _FunctionalAdamW
|
|
from .functional_sgd import _FunctionalSGD
|
|
from .functional_adadelta import _FunctionalAdadelta
|
|
from .functional_rmsprop import _FunctionalRMSprop
|
|
from .functional_rprop import _FunctionalRprop
|
|
from .functional_adamax import _FunctionalAdamax
|
|
|
|
# Private registry translating a user-supplied ``torch.optim`` optimizer
# class into the matching functional optimizer defined inside the
# ``distributed.optim`` package. Callers keep passing the familiar
# ``torch.optim`` classes; the functional variant is swapped in behind
# the scenes so the public API stays unchanged.
functional_optim_map = dict(
    [
        (optim.Adagrad, _FunctionalAdagrad),
        (optim.Adam, _FunctionalAdam),
        (optim.AdamW, _FunctionalAdamW),
        (optim.SGD, _FunctionalSGD),
        (optim.Adadelta, _FunctionalAdadelta),
        (optim.RMSprop, _FunctionalRMSprop),
        (optim.Rprop, _FunctionalRprop),
        (optim.Adamax, _FunctionalAdamax),
    ]
)
|
|
|
|
def as_functional_optim(optim_cls: Type, *args, **kwargs):
    """Build the functional counterpart of a ``torch.optim`` optimizer class.

    The returned functional optimizer is constructed with an empty
    parameter list (see ``_create_functional_optim``).

    Args:
        optim_cls: A ``torch.optim`` optimizer class, e.g. ``optim.Adam``.
        *args: Positional arguments forwarded to the functional optimizer's
            constructor (after the initial empty parameter list).
        **kwargs: Keyword arguments forwarded to the functional optimizer's
            constructor.

    Raises:
        ValueError: If ``optim_cls`` has no functional counterpart
            registered in ``functional_optim_map``.
    """
    try:
        functional_cls = functional_optim_map[optim_cls]
    except KeyError as e:
        # Chain explicitly so the traceback shows the failed lookup as the
        # direct cause instead of the noisier implicit "During handling of
        # the above exception" context.
        raise ValueError(
            f"Optimizer {optim_cls} does not have a functional counterpart!"
        ) from e

    return _create_functional_optim(functional_cls, *args, **kwargs)
|
|
|
|
def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
|
|
return functional_optim_cls(
|
|
[],
|
|
*args,
|
|
**kwargs,
|
|
_allow_empty_param_list=True,
|
|
)
|