"""
|
|
:mod:`torch.distributed.optim` exposes DistributedOptimizer, which takes a list
|
|
of remote parameters (:class:`~torch.distributed.rpc.RRef`) and runs the
|
|
optimizer locally on the workers where the parameters live. The distributed
|
|
optimizer can use any of the local optimizer :ref:`optimizer-algorithms` to
|
|
apply the gradients on each worker.
|
|
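
A minimal usage sketch is shown below. It assumes RPC has already been
initialized (e.g. via ``rpc.init_rpc``) and that a peer worker named
``"worker1"`` exists; the worker name, tensors, and learning rate are
illustrative only::

    import torch
    import torch.distributed.autograd as dist_autograd
    import torch.distributed.rpc as rpc
    from torch import optim
    from torch.distributed.optim import DistributedOptimizer

    with dist_autograd.context() as context_id:
        # Forward pass: run ops on the remote worker and combine the results.
        rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
        rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
        loss = rref1.to_here() + rref2.to_here()

        # Backward pass through the distributed autograd engine.
        dist_autograd.backward(context_id, [loss.sum()])

        # The optimizer runs locally on whichever worker owns each RRef.
        dist_optim = DistributedOptimizer(
            optim.SGD,
            [rref1, rref2],
            lr=0.05,
        )
        dist_optim.step(context_id)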
"""
import torch
from torch import optim

from .apply_optimizer_in_backward import (
    _apply_optimizer_in_backward,
    _get_in_backward_optimizers,
)
from .functional_adadelta import _FunctionalAdadelta
from .functional_adagrad import _FunctionalAdagrad
from .functional_adam import _FunctionalAdam
from .functional_adamax import _FunctionalAdamax
from .functional_adamw import _FunctionalAdamW
from .functional_rmsprop import _FunctionalRMSprop
from .functional_rprop import _FunctionalRprop
from .functional_sgd import _FunctionalSGD
from .named_optimizer import _NamedOptimizer
from .utils import as_functional_optim


# DistributedOptimizer imports torch.distributed.rpc names, so gate availability
# based on RPC being available.
if hasattr(torch._C, "_rpc_init"):
    from .optimizer import DistributedOptimizer

from .post_localSGD_optimizer import PostLocalSGDOptimizer
from .zero_redundancy_optimizer import ZeroRedundancyOptimizer