[Opt Overlap] Clean up code in _OptimizerHookState (#71620)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71620
Remove from_functional_optim and make it the default constructor, since
that is now the only way _OptimizerHookState is built. Also, the
create_functional_optim helper function no longer needs to be exposed.
ghstack-source-id: 147577174
Test Plan: CI
Reviewed By: cbalioglu
Differential Revision: D33700593
fbshipit-source-id: ba089ce3bf66ccf8f71cffdd0f4d4bddc03e8b14
(cherry picked from commit a50b2caf0e)
parent 1c8fcc44cb
commit bdcdf94bdd
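For illustration only (not part of this commit): a minimal sketch of the call-site change described in the summary, using the private torch.distributed helpers that appear in the hunks below. Parameter values and variable names are placeholders.

# Illustrative sketch only. as_functional_optim and _OptimizerHookState are
# private torch.distributed helpers shown in the diff below; lr/params here
# are placeholders.
from torch.optim import SGD
from torch.distributed.optim import as_functional_optim

f_optim = as_functional_optim(SGD, lr=0.01)  # build the functional counterpart of SGD

# Before this change:
#     hook_state = _OptimizerHookState.from_functional_optim(f_optim, params)
# After this change, the default constructor takes the functional optimizer directly:
#     hook_state = _OptimizerHookState(f_optim, params)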
@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.optim import SGD, Adam, AdamW
 from torch.testing._internal.common_utils import TestCase, run_tests
-from torch.distributed.optim import functional_optim_map
+from torch.distributed.optim.utils import functional_optim_map

 class MyModule(torch.nn.Module):
     def __init__(self):
@@ -59,7 +59,7 @@ class _OverlappedStandardOptimizer(OverlappedOptimizer):
     def __init__(self, optim_cls: Type, params, *optim_args, **optim_kwargs) -> None:
         super().__init__(optim_cls)
         f_optim = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs)
-        self._opt_hook_state = _OptimizerHookState.from_functional_optim(f_optim, params)
+        self._opt_hook_state = _OptimizerHookState(f_optim, params)

     def register_ddp(self, ddp_inst: DistributedDataParallel):
         # NOTE: using a custom communication hook and fused optimizer is not
@@ -2,7 +2,6 @@ from typing import Any, Callable

 import torch
 import torch.distributed as dist
-from torch.distributed.optim import create_functional_optim

 _FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param"

@@ -14,32 +13,11 @@ class _OptimizerHookState(object):
     __slots__ = ["functional_optimizer", "params_to_optimize"]

-    def __init__(
-        self, functional_optim_cls, *functional_optim_args, params=None, **functional_optim_kwargs
-    ):
-        self.functional_optimizer = create_functional_optim(
-            functional_optim_cls,
-            *functional_optim_args,
-            **functional_optim_kwargs,
-        )
+    def __init__(self, functional_optim, params=None):
+        self.functional_optimizer = functional_optim
         self._check_valid_functional_optim()
         self._set_params_to_optimize(params)

-    @classmethod
-    def from_functional_optim(cls, functional_optim, params=None):
-        r"""
-        Create a `_OptimizerHookState`, which simply
-        holds a functional optimizer, directly from a
-        functional optimizer given by `functional_optim`.
-        Note that the `functional_optim` must implement
-        `step_param` to support per-parameter optimization.
-        """
-        opt_hook_state_inst = cls.__new__(cls)  # Does not call __init__
-        opt_hook_state_inst.functional_optimizer = functional_optim
-        opt_hook_state_inst._check_valid_functional_optim()
-        opt_hook_state_inst._set_params_to_optimize(params)
-        return opt_hook_state_inst
-
     def _set_params_to_optimize(self, params):
         if params is not None:
             self.params_to_optimize = set(params)
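The requirement from the removed docstring carries over to the new constructor: whatever object is passed in must implement step_param (see _FUNCTIONAL_OPTIM_STEP_METHOD_NAME above). Illustrative sketch, not part of this diff; ToyFunctionalSGD is a hypothetical stand-in for a functional optimizer.

import torch

# Hypothetical functional-style optimizer satisfying the step_param contract:
# it updates one parameter at a time instead of stepping over parameter groups.
class ToyFunctionalSGD:
    def __init__(self, lr: float = 0.01):
        self.lr = lr

    def step_param(self, param: torch.Tensor, grad: torch.Tensor) -> None:
        if grad is None:
            return
        with torch.no_grad():
            param.add_(grad, alpha=-self.lr)

# Driven per parameter, the way the DDP optimizer-overlap hook applies it:
model = torch.nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()
opt = ToyFunctionalSGD(lr=0.1)
for p in model.parameters():
    opt.step_param(p, p.grad)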
@@ -16,11 +16,7 @@ from .functional_adadelta import _FunctionalAdadelta
 from .functional_rmsprop import _FunctionalRMSprop
 from .functional_rprop import _FunctionalRprop
 from .functional_adamax import _FunctionalAdamax
-from .utils import (
-    functional_optim_map,
-    create_functional_optim,
-    as_functional_optim,
-)
+from .utils import as_functional_optim


 # DistributedOptimizer imports torch.distributed.rpc names, so gate availability
@@ -7,7 +7,7 @@ import torch.jit as jit
 import torch.nn as nn
 from torch import Tensor
 from torch.distributed.rpc import RRef
-from torch.distributed.optim import functional_optim_map
+from .utils import functional_optim_map
 import torch.distributed.autograd as dist_autograd

@@ -30,9 +30,9 @@ def as_functional_optim(optim_cls: Type, *args, **kwargs):
     except KeyError:
         raise ValueError(f"Optimizer {optim_cls} does not have a functional counterpart!")

-    return create_functional_optim(functional_cls, *args, **kwargs)
+    return _create_functional_optim(functional_cls, *args, **kwargs)

-def create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
+def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
     return functional_optim_cls(
         [],
         *args,
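A short usage note on the rename above: as_functional_optim remains the supported entry point, while _create_functional_optim becomes an internal detail. Sketch only, not part of this diff; the exact functional class returned (e.g. _FunctionalAdam) depends on the PyTorch version.

# Sketch: map a torch.optim class to its functional counterpart via the
# public helper; the optimizer is built with an empty parameter list and is
# fed parameters one at a time through step_param.
import torch
from torch.distributed.optim import as_functional_optim

functional_adam = as_functional_optim(torch.optim.Adam, lr=1e-3)
print(type(functional_adam).__name__)  # e.g. "_FunctionalAdam" (version-dependent)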
@@ -15,7 +15,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Type
 import torch
 import torch.distributed as dist
 from torch.distributed.algorithms.join import Join, Joinable, JoinHook
-from torch.distributed.optim import functional_optim_map
+from torch.distributed.optim.utils import functional_optim_map
 from torch.optim import Optimizer

 __all__ = ["ZeroRedundancyOptimizer"]