Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71604
Implement two helper functions:
- as_functional_optim, which takes a torch.optim class type and constructor
arguments and creates the corresponding functional optimizer.
- create_functional_optim, which takes the functional optimizer class type
and constructs it. Note that as_functional_optim calls into
create_functional_optim.
The former will be used in future PRs, as described in
https://github.com/pytorch/pytorch/issues/67570, to create a functional
optimizer from a traditional one. The latter is used by
_OptimizerHookState to create a functional optimizer.
Both new helper functions are covered by unit tests.
ghstack-source-id: 147577170
Test Plan: CI
Reviewed By: cbalioglu
Differential Revision: D33688995
fbshipit-source-id: 8b2daafd1b914efa90877cc4313aa9a428546fc1
(cherry picked from commit 42fdae2991)
from typing import Type

from torch import optim

from .functional_adagrad import _FunctionalAdagrad
from .functional_adam import _FunctionalAdam
from .functional_adamw import _FunctionalAdamW
from .functional_sgd import _FunctionalSGD
from .functional_adadelta import _FunctionalAdadelta
from .functional_rmsprop import _FunctionalRMSprop
from .functional_rprop import _FunctionalRprop
from .functional_adamax import _FunctionalAdamax


# Dict that maps a user-passed optimizer class to the functional optimizer
# class already defined inside the distributed.optim package. This hides the
# functional optimizer from the user while providing the same API.
functional_optim_map = {
    optim.Adagrad: _FunctionalAdagrad,
    optim.Adam: _FunctionalAdam,
    optim.AdamW: _FunctionalAdamW,
    optim.SGD: _FunctionalSGD,
    optim.Adadelta: _FunctionalAdadelta,
    optim.RMSprop: _FunctionalRMSprop,
    optim.Rprop: _FunctionalRprop,
    optim.Adamax: _FunctionalAdamax,
}


def as_functional_optim(optim_cls: Type, *args, **kwargs):
    try:
        functional_cls = functional_optim_map[optim_cls]
    except KeyError:
        raise ValueError(
            f"Optimizer {optim_cls} does not have a functional counterpart!"
        )
    return create_functional_optim(functional_cls, *args, **kwargs)


def create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
    return functional_optim_cls(
        [],
        *args,
        **kwargs,
        _allow_empty_param_list=True,
    )
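
A minimal usage sketch of the two helpers (not part of the committed file), assuming the module is importable as torch.distributed.optim.utils and that _FunctionalSGD is importable from torch.distributed.optim.functional_sgd; both import paths are illustrative assumptions.

# Usage sketch: construct a functional optimizer either from a traditional
# torch.optim class or directly from a functional optimizer class.
# The import paths below are assumptions for illustration.
from torch import optim
from torch.distributed.optim.utils import (
    as_functional_optim,
    create_functional_optim,
)
from torch.distributed.optim.functional_sgd import _FunctionalSGD

# Maps optim.SGD to _FunctionalSGD via functional_optim_map and constructs it
# with an empty parameter list (_allow_empty_param_list=True is supplied by
# create_functional_optim).
functional_sgd = as_functional_optim(optim.SGD, lr=0.01)

# Equivalent construction when the functional class is already known, which is
# how the commit summary describes _OptimizerHookState using it.
functional_sgd_direct = create_functional_optim(_FunctionalSGD, lr=0.01)

# An optimizer without a functional counterpart raises ValueError.
try:
    as_functional_optim(optim.LBFGS, lr=1.0)
except ValueError as err:
    print(err)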