Revert D30117838: [WIP] Gate DistributedOptimizers on RPC availability
Test Plan: revert-hammer
Differential Revision: D30117838 (3f09485d7e)
Original commit changeset: e6365a910a3d
fbshipit-source-id: f276b2b2bdf5f7bd27df473fca0eebaee9f7aef2
This commit is contained in:
parent e6a3154519
commit b45cf9b81b
@@ -233,6 +233,7 @@ WINDOWS_BLOCKLIST = [
    'distributed/pipeline/sync/test_stream',
    'distributed/pipeline/sync/test_transparency',
    'distributed/pipeline/sync/test_worker',
    'distributed/optim/test_zero_redundancy_optimizer',
    "distributed/elastic/agent/server/test/api_test",
    'distributed/elastic/multiprocessing/api_test',
]
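For context, WINDOWS_BLOCKLIST in the test runner is a plain list of test module paths that are skipped when the suite runs on Windows. The following is a minimal, illustrative sketch of how such a blocklist can be applied when selecting tests; the filter_for_platform helper is hypothetical and not part of this diff.

# Sketch only: applying a platform blocklist such as WINDOWS_BLOCKLIST
# when deciding which test modules to run.
import sys
from typing import List

WINDOWS_BLOCKLIST: List[str] = [
    'distributed/pipeline/sync/test_worker',
    'distributed/optim/test_zero_redundancy_optimizer',
    'distributed/elastic/multiprocessing/api_test',
]

def filter_for_platform(selected_tests: List[str]) -> List[str]:
    # On Windows, drop any test module that appears in the blocklist;
    # on other platforms, keep the selection unchanged.
    if sys.platform == "win32":
        return [t for t in selected_tests if t not in WINDOWS_BLOCKLIST]
    return selected_tests

print(filter_for_platform([
    "test_nn",
    "distributed/optim/test_zero_redundancy_optimizer",
]))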
@@ -5,34 +5,6 @@ optimizer locally on the workers where the parameters live. The distributed
optimizer can use any of the local optimizer :ref:`optimizer-algorithms` to
apply the gradients on each worker.
"""
import torch
from torch import optim

from .functional_adagrad import _FunctionalAdagrad
from .functional_adam import _FunctionalAdam
from .functional_adamw import _FunctionalAdamW
from .functional_sgd import _FunctionalSGD
from .functional_adadelta import _FunctionalAdadelta
from .functional_rmsprop import _FunctionalRMSprop
from .functional_rprop import _FunctionalRprop
from .functional_adamax import _FunctionalAdamax

# dict to map a user passed in optimizer_class to a functional
# optimizer class if we have already defined inside the
# distributed.optim package, this is so that we hide the
# functional optimizer to user and still provide the same API.
functional_optim_map = {
    optim.Adagrad: _FunctionalAdagrad,
    optim.Adam: _FunctionalAdam,
    optim.AdamW: _FunctionalAdamW,
    optim.SGD: _FunctionalSGD,
    optim.Adadelta: _FunctionalAdadelta,
    optim.RMSprop: _FunctionalRMSprop,
    optim.Rprop: _FunctionalRprop,
    optim.Adamax: _FunctionalAdamax,
}

if hasattr(torch._C, '_rpc_init'):
    from .optimizer import DistributedOptimizer

from .post_localSGD_optimizer import PostLocalSGDOptimizer
from .zero_redundancy_optimizer import ZeroRedundancyOptimizer
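This hunk strips the module-level functional_optim_map and the hasattr(torch._C, '_rpc_init') import gate out of the package __init__; the later hunks move the map back onto DistributedOptimizer. Below is a minimal sketch of the two ideas involved, assuming a PyTorch build with torch.distributed available; the map is trimmed to a single entry and resolve_optim_ctor is a hypothetical helper, not code from the diff.

# Sketch only: the substitution pattern behind functional_optim_map and the
# RPC-availability gate that this revert removes from the package __init__.
import torch
from torch import optim
from torch.distributed.optim.functional_sgd import _FunctionalSGD

# Map a user-facing torch.optim class to its TorchScript-friendly functional
# counterpart (one entry here; the real map covers eight optimizers).
functional_optim_map = {optim.SGD: _FunctionalSGD}

def resolve_optim_ctor(optimizer_class):
    # Substitute the functional implementation when one is registered,
    # otherwise fall back to the class the user passed in.
    return functional_optim_map.get(optimizer_class, optimizer_class)

# The reverted gate: only import the RPC-based DistributedOptimizer when the
# binary was built with RPC support (torch._C exposes _rpc_init in that case).
if hasattr(torch._C, "_rpc_init"):
    from torch.distributed.optim import DistributedOptimizer  # noqa: F401

print(resolve_optim_ctor(optim.SGD))      # -> _FunctionalSGD
print(resolve_optim_ctor(optim.Adagrad))  # -> optim.Adagrad, no substitution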
@@ -2,11 +2,19 @@ from typing import List, Optional
import logging

import torch.distributed.rpc as rpc
import torch.optim as optim
import torch.jit as jit
import torch.nn as nn
from torch import Tensor
from torch.distributed.rpc import RRef
from torch.distributed.optim import functional_optim_map
from .functional_adagrad import _FunctionalAdagrad
from .functional_adam import _FunctionalAdam
from .functional_adamw import _FunctionalAdamW
from .functional_sgd import _FunctionalSGD
from .functional_adadelta import _FunctionalAdadelta
from .functional_rmsprop import _FunctionalRMSprop
from .functional_rprop import _FunctionalRprop
from .functional_adamax import _FunctionalAdamax
import torch.distributed.autograd as dist_autograd
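The rpc and RRef imports above feed the constructor shown in the next hunk, which buckets parameter RRefs by the worker that owns them. A standalone sketch of that grouping step follows; FakeRRef is a stand-in so the example runs without an initialized RPC agent.

# Sketch: grouping remote parameter references by their owning worker, the
# same defaultdict pattern DistributedOptimizer.__init__ uses.
from collections import defaultdict

class FakeRRef:
    def __init__(self, owner_name, value):
        self._owner = owner_name
        self._value = value

    def owner(self):  # real RRefs return a WorkerInfo here
        return self._owner

params_rref = [FakeRRef("worker1", "w1.p0"), FakeRRef("worker2", "w2.p0"),
               FakeRRef("worker1", "w1.p1")]

per_worker_params_rref = defaultdict(list)
for param in params_rref:
    per_worker_params_rref[param.owner()].append(param)

# One optimizer instance would then be created remotely per owner worker.
for owner, params in per_worker_params_rref.items():
    print(owner, [p._value for p in params])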
@@ -185,13 +193,28 @@ class DistributedOptimizer:
    __ https://github.com/pytorch/tutorials/pull/1465
    """

    # dict to map a user passed in optimizer_class to a functional
    # optimizer class if we have already defined inside the
    # distributed.optim package, this is so that we hide the
    # functional optimizer to user and still provide the same API.
    functional_optim_map = {
        optim.Adagrad: _FunctionalAdagrad,
        optim.Adam: _FunctionalAdam,
        optim.AdamW: _FunctionalAdamW,
        optim.SGD: _FunctionalSGD,
        optim.Adadelta: _FunctionalAdadelta,
        optim.RMSprop: _FunctionalRMSprop,
        optim.Rprop: _FunctionalRprop,
        optim.Adamax: _FunctionalAdamax,
    }

    def __init__(self, optimizer_class, params_rref, *args, **kwargs):
        per_worker_params_rref = defaultdict(list)
        for param in params_rref:
            per_worker_params_rref[param.owner()].append(param)

        if optimizer_class in functional_optim_map and jit._state._enabled:
            optim_ctor = functional_optim_map.get(optimizer_class)
        if optimizer_class in DistributedOptimizer.functional_optim_map and jit._state._enabled:
            optim_ctor = DistributedOptimizer.functional_optim_map.get(optimizer_class)
        else:
            optim_ctor = optimizer_class
        self.is_functional_optim = (optim_ctor != optimizer_class)
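Reading the constructor body above: the adjacent pair of if lines is the before/after of this revert, switching the lookup from the removed module-level functional_optim_map to the class attribute DistributedOptimizer.functional_optim_map. A condensed sketch of the resulting selection logic follows; the jit._state._enabled check mirrors the diff, while MiniDistributedOptimizer and the single-entry map are stand-ins for the real class, assuming a build with torch.distributed.

# Sketch of the optimizer selection after the revert: use the class-level map
# when TorchScript is enabled and a functional variant is registered.
import torch.jit as jit
from torch import optim
from torch.distributed.optim.functional_sgd import _FunctionalSGD

class MiniDistributedOptimizer:
    # Class attribute, mirroring DistributedOptimizer.functional_optim_map
    # (trimmed to a single entry for the example).
    functional_optim_map = {optim.SGD: _FunctionalSGD}

    def __init__(self, optimizer_class):
        if optimizer_class in MiniDistributedOptimizer.functional_optim_map and jit._state._enabled:
            optim_ctor = MiniDistributedOptimizer.functional_optim_map.get(optimizer_class)
        else:
            optim_ctor = optimizer_class
        self.is_functional_optim = (optim_ctor != optimizer_class)
        self.optim_ctor = optim_ctor

print(MiniDistributedOptimizer(optim.SGD).is_functional_optim)      # True when scripting is enabled
print(MiniDistributedOptimizer(optim.Adagrad).is_functional_optim)  # False: no functional entry here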
@@ -14,7 +14,7 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, Union
import torch
import torch.distributed as dist
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.distributed.optim import functional_optim_map
from torch.distributed.optim import DistributedOptimizer
from torch.optim import Optimizer

__all__ = ["ZeroRedundancyOptimizer"]
@@ -309,6 +309,8 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):

    """

    functional_optim_map = DistributedOptimizer.functional_optim_map

    def __init__(
        self,
        params,
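The alias added above lets ZeroRedundancyOptimizer keep exposing functional_optim_map as its own class attribute while the single source of truth moves back onto DistributedOptimizer. A tiny sketch of that re-export pattern, with stand-in classes and dict contents:

# Sketch of the class-attribute aliasing used above: one class owns the
# mapping, another re-exports it so existing call sites keep working.
class Owner:
    functional_optim_map = {"sgd": "functional_sgd"}  # stand-in contents

class ReExporter:
    # Same dict object, not a copy; updates to Owner's map are visible here.
    functional_optim_map = Owner.functional_optim_map

assert ReExporter.functional_optim_map is Owner.functional_optim_map
print(ReExporter.functional_optim_map["sgd"])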
@@ -1335,6 +1337,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
            - if ``overlap_with_ddp=False`` and ``optimizer_class`` is a
              functional optimizer.
        """
        functional_optim_map = ZeroRedundancyOptimizer.functional_optim_map
        functional_optims = functional_optim_map.values()
        if not self._overlap_with_ddp:
            if optimizer_class in functional_optims:
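The last hunk rewires this validation helper to read the map through the class attribute before checking, when overlap_with_ddp=False, that the caller did not pass a functional optimizer directly. A hedged sketch of that check; the MiniZeRO class, helper name, and error message are illustrative and not taken from the diff, and it assumes a build with torch.distributed.

# Sketch of the validation step shown above: when not overlapping with DDP,
# reject optimizer classes that are themselves functional optimizers.
from torch import optim
from torch.distributed.optim.functional_sgd import _FunctionalSGD

class MiniZeRO:
    functional_optim_map = {optim.SGD: _FunctionalSGD}

    def __init__(self, optimizer_class, overlap_with_ddp=False):
        self._overlap_with_ddp = overlap_with_ddp
        self._verify_optimizer_class(optimizer_class)

    def _verify_optimizer_class(self, optimizer_class):
        functional_optim_map = MiniZeRO.functional_optim_map
        functional_optims = functional_optim_map.values()
        if not self._overlap_with_ddp:
            if optimizer_class in functional_optims:
                raise ValueError(
                    f"{optimizer_class} is a functional optimizer; pass the "
                    "corresponding torch.optim class when overlap_with_ddp=False"
                )

MiniZeRO(optim.SGD)              # fine: a regular torch.optim class
try:
    MiniZeRO(_FunctionalSGD)     # rejected in this sketch
except ValueError as e:
    print(e)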