[Opt Overlap] Clean up code in _OptimizerHookState (#71620)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71620

Remove the from_functional_optim classmethod and make that construction path the default __init__, since it is now the only way _OptimizerHookState is built. Also, the create_functional_optim helper function no longer needs to be exposed from the package (it is kept private to utils as _create_functional_optim).
ghstack-source-id: 147577174
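
For context, a minimal sketch (not part of this commit) of how the hook state is now built. as_functional_optim and the new _OptimizerHookState(functional_optim, params) signature come from the diff below; the import path for _OptimizerHookState and the toy parameter list are assumptions.

import torch
from torch.optim import SGD
from torch.distributed.optim.utils import as_functional_optim
# Assumed internal location of _OptimizerHookState at the time of this commit.
from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
    _OptimizerHookState,
)

params = [torch.nn.Parameter(torch.randn(4, 4))]

# Old path (removed in this commit): a classmethod that bypassed __init__ via cls.__new__.
#   opt_hook_state = _OptimizerHookState.from_functional_optim(f_optim, params)

# New path: build the functional optimizer, then pass it straight to the
# (now only) constructor.
f_optim = as_functional_optim(SGD, lr=0.01)
opt_hook_state = _OptimizerHookState(f_optim, params)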

Test Plan: CI

Reviewed By: cbalioglu

Differential Revision: D33700593

fbshipit-source-id: ba089ce3bf66ccf8f71cffdd0f4d4bddc03e8b14
(cherry picked from commit a50b2caf0e)
Rohan Varma 2022-01-26 10:35:33 -08:00 committed by PyTorch MergeBot
parent 1c8fcc44cb
commit bdcdf94bdd
7 changed files with 9 additions and 35 deletions

View File

@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.optim import SGD, Adam, AdamW
 from torch.testing._internal.common_utils import TestCase, run_tests
-from torch.distributed.optim import functional_optim_map
+from torch.distributed.optim.utils import functional_optim_map
 class MyModule(torch.nn.Module):
     def __init__(self):

View File

@@ -59,7 +59,7 @@ class _OverlappedStandardOptimizer(OverlappedOptimizer):
     def __init__(self, optim_cls: Type, params, *optim_args, **optim_kwargs) -> None:
         super().__init__(optim_cls)
         f_optim = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs)
-        self._opt_hook_state = _OptimizerHookState.from_functional_optim(f_optim, params)
+        self._opt_hook_state = _OptimizerHookState(f_optim, params)
     def register_ddp(self, ddp_inst: DistributedDataParallel):
         # NOTE: using a custom communication hook and fused optimizer is not

View File

@@ -2,7 +2,6 @@ from typing import Any, Callable
 import torch
 import torch.distributed as dist
-from torch.distributed.optim import create_functional_optim
 _FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param"
@@ -14,32 +13,11 @@ class _OptimizerHookState(object):
     __slots__ = ["functional_optimizer", "params_to_optimize"]
-    def __init__(
-        self, functional_optim_cls, *functional_optim_args, params=None, **functional_optim_kwargs
-    ):
-        self.functional_optimizer = create_functional_optim(
-            functional_optim_cls,
-            *functional_optim_args,
-            **functional_optim_kwargs,
-        )
+    def __init__(self, functional_optim, params=None):
+        self.functional_optimizer = functional_optim
         self._check_valid_functional_optim()
         self._set_params_to_optimize(params)
-    @classmethod
-    def from_functional_optim(cls, functional_optim, params=None):
-        r"""
-        Create a `_OptimizerHookState`, which simply
-        holds a functional optimizer, directly from a
-        functional optimizer given by `functional_optim`.
-        Note that the `functional_optim` must implement
-        `step_param` to support per-parameter optimization.
-        """
-        opt_hook_state_inst = cls.__new__(cls)  # Does not call __init__
-        opt_hook_state_inst.functional_optimizer = functional_optim
-        opt_hook_state_inst._check_valid_functional_optim()
-        opt_hook_state_inst._set_params_to_optimize(params)
-        return opt_hook_state_inst
     def _set_params_to_optimize(self, params):
         if params is not None:
             self.params_to_optimize = set(params)
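
The constraint from the removed docstring still applies: _check_valid_functional_optim checks that the wrapped optimizer exposes the method named by _FUNCTIONAL_OPTIM_STEP_METHOD_NAME ("step_param"). A minimal, hypothetical sketch of that contract; _ToySGD below is illustrative only and not a real PyTorch class.

import torch

class _ToySGD:
    # Hypothetical stand-in for a functional optimizer such as _FunctionalSGD.
    def __init__(self, params, lr=0.01):
        self.lr = lr

    def step_param(self, param: torch.Tensor, grad: torch.Tensor) -> None:
        # Update a single parameter as soon as its gradient is ready, which is
        # what allows the optimizer step to overlap with DDP communication.
        with torch.no_grad():
            param.add_(grad, alpha=-self.lr)

p = torch.nn.Parameter(torch.ones(3))
opt = _ToySGD([p], lr=0.1)
opt.step_param(p, torch.full_like(p, 0.5))  # p becomes 1 - 0.1 * 0.5 = 0.95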

View File

@@ -16,11 +16,7 @@ from .functional_adadelta import _FunctionalAdadelta
 from .functional_rmsprop import _FunctionalRMSprop
 from .functional_rprop import _FunctionalRprop
 from .functional_adamax import _FunctionalAdamax
-from .utils import (
-    functional_optim_map,
-    create_functional_optim,
-    as_functional_optim,
-)
+from .utils import as_functional_optim
 # DistributedOptimizer imports torch.distributed.rpc names, so gate availability
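
Because the package __init__ above no longer re-exports functional_optim_map (or create_functional_optim), call sites outside the package, such as the test and ZeroRedundancyOptimizer hunks in this commit, import from the utils submodule directly. A small sketch of the resulting pattern, assuming SGD has an entry in functional_optim_map:

from torch.optim import SGD
from torch.distributed.optim.utils import as_functional_optim, functional_optim_map

functional_cls = functional_optim_map[SGD]   # eager optimizer class -> functional counterpart
f_optim = as_functional_optim(SGD, lr=0.01)  # internally: _create_functional_optim(functional_cls, lr=0.01)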

View File

@@ -7,7 +7,7 @@ import torch.jit as jit
 import torch.nn as nn
 from torch import Tensor
 from torch.distributed.rpc import RRef
-from torch.distributed.optim import functional_optim_map
+from .utils import functional_optim_map
 import torch.distributed.autograd as dist_autograd

View File

@@ -30,9 +30,9 @@ def as_functional_optim(optim_cls: Type, *args, **kwargs):
     except KeyError:
         raise ValueError(f"Optimizer {optim_cls} does not have a functional counterpart!")
-    return create_functional_optim(functional_cls, *args, **kwargs)
+    return _create_functional_optim(functional_cls, *args, **kwargs)
-def create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
+def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
     return functional_optim_cls(
         [],
         *args,

View File

@@ -15,7 +15,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Type
 import torch
 import torch.distributed as dist
 from torch.distributed.algorithms.join import Join, Joinable, JoinHook
-from torch.distributed.optim import functional_optim_map
+from torch.distributed.optim.utils import functional_optim_map
 from torch.optim import Optimizer
 __all__ = ["ZeroRedundancyOptimizer"]