Fixes #102428

Also improves hook registration type hints:

```python
from typing import Any, Dict, Tuple

from torch import nn
from torch.optim import Adam, Adagrad, Optimizer

linear = nn.Linear(2, 2)
optimizer = Adam(linear.parameters(), lr=0.001)

def pre_hook_fn_return_none(optimizer: Adam, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def pre_hook_fn_return_modified(
    optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    return inputs, kwargs

def hook_fn(optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def hook_fn_other_optimizer(optimizer: Adagrad, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

optimizer.register_step_post_hook(hook_fn)  # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_none)  # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_modified)  # OK
optimizer.register_step_post_hook(hook_fn_other_optimizer)  # Parameter 1: type "Adam" cannot be assigned to type "Adagrad"
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102593
Approved by: https://github.com/janeyx99, https://github.com/malfet
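As a complement to the snippet above, here is a minimal runnable sketch (the hook name and counting logic are illustrative, not part of the PR) showing why annotating a hook against the base `Optimizer` keeps it usable with any concrete optimizer, which is exactly what the improved hints verify:

```python
from typing import Any, Dict, Tuple

import torch
from torch import nn
from torch.optim import Optimizer

# Illustrative hook: counts optimizer steps. Typing the first parameter as
# the base Optimizer means the hook is accepted by any optimizer subclass.
step_counter = {"steps": 0}

def count_steps(optimizer: Optimizer, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    step_counter["steps"] += 1

model = nn.Linear(2, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
handle = opt.register_step_post_hook(count_steps)  # type-checks for any Optimizer

model(torch.randn(4, 2)).sum().backward()
opt.step()                       # count_steps runs after the step
assert step_counter["steps"] == 1
handle.remove()                  # detach the hook when no longer needed
```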
84 lines
2.8 KiB
Python
import enum
from typing import Any, Callable, Dict, List, Optional, overload, Set, Type

import torch
from torch.distributed.algorithms.join import Joinable, JoinHook
from torch.optim import Optimizer
# Join hook that lets ZeroRedundancyOptimizer participate in uneven-input
# ("join") training by shadowing the optimizer step on joined ranks.
class _ZeROJoinHook(JoinHook):
    zero: Any = ...
    def __init__(self, zero: Any) -> None: ...
    def main_hook(self) -> None: ...
# Assignment of one DDP gradient bucket (a contiguous slice of parameters)
# to a single rank when overlapping the optimizer step with DDP backward.
class _DDPBucketAssignment:
    bucket_index: int
    parameters: List[torch.Tensor]
    offset: int
    device: torch.device
    tensor: Optional[torch.Tensor]
# Progress of the setup needed to overlap ZeroRedundancyOptimizer with DDP.
class _OverlapStatus(enum.IntEnum):
    UNINITIALIZED: int = ...
    DDP_HAS_REBUILT_BUCKETS: int = ...
    INITIALIZED: int = ...
# Bookkeeping for overlapping the optimizer step with DDP's backward pass:
# per-bucket parameter lists, per-rank assignments, and pending broadcasts.
class _OverlapInfo:
    status: Any = ...
    params_per_bucket: Any = ...
    params_per_rank: Any = ...
    offsets: Any = ...
    broadcast_handles: Any = ...
    bucket_index_to_future: Any = ...
    bucket_index_to_bucket: Any = ...
    bucket_indices_seen: Any = ...
    assigned_ranks_per_bucket: List[Set[int]] = ...
    total_size: int = ...
    shard_buckets: bool = ...
    def __init__(self) -> None: ...
    def wait_for_broadcasts(self) -> None: ...
    def clear_per_iter_info(self) -> None: ...
# Wraps an arbitrary torch.optim.Optimizer and shards its state across the
# ranks of a process group (ZeRO stage 1), so each rank only materializes
# optimizer state for its own shard of the parameters.
class ZeroRedundancyOptimizer(Optimizer, Joinable):
    functional_optim_map: Any = ...
    initialized: bool = ...
    process_group: Any = ...
    world_size: int = ...
    rank: int = ...
    global_rank: int = ...
    parameters_as_bucket_view: bool = ...
    optim: Any = ...
    _device_to_device_index: Dict[torch.device, int] = ...
    _overlap_with_ddp: bool = ...
    _overlap_info: _OverlapInfo = ...
    _buckets: List[List[torch.Tensor]] = ...
    _bucket_assignments_per_rank: List[Dict[int, _DDPBucketAssignment]] = ...
    def __init__(
        self,
        params: Any,
        optimizer_class: Type[Optimizer],
        process_group: Optional[Any] = ...,
        parameters_as_bucket_view: bool = ...,
        overlap_with_ddp: bool = ...,
        **defaults: Any,
    ) -> None: ...
    def add_param_group(self, param_group: dict) -> None: ...
    def consolidate_state_dict(self, to: int = ...) -> None: ...
    @overload
    def step(self, closure: None = ..., **kwargs: Any) -> None: ...
    @overload
    def step(self, closure: Callable[[], float], **kwargs: Any) -> float: ...
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...
    def state_dict(self) -> Dict[str, Any]: ...
    def _local_step(
        self,
        gradients: Optional[List[Optional[torch.Tensor]]] = None,
        closure: Optional[Callable[[], float]] = None,
        **kwargs: Any,
    ) -> Optional[float]: ...
    def _get_assigned_rank(self, bucket_index: int) -> int: ...
    def _init_zero_for_overlap(self) -> None: ...
    def join_hook(self, **kwargs): ...
    @property
    def join_device(self) -> torch.device: ...
    @property
    def join_process_group(self) -> Any: ...
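For reference, a minimal sketch of how the `ZeroRedundancyOptimizer` API stubbed above is typically driven. It assumes a `torchrun`-style launch that supplies the rendezvous environment variables; the module, backend, hyperparameters, and file name are placeholders:

```python
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

def main() -> None:
    # Assumes launch via torchrun, which sets RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT.
    dist.init_process_group("gloo")

    model = DDP(torch.nn.Linear(8, 8))
    # Wrap a regular optimizer class; the ZeRO wrapper shards its per-parameter
    # state (Adam's exp_avg / exp_avg_sq here) across the ranks of the group.
    optimizer = ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.Adam,
        lr=0.01,
    )

    for _ in range(10):
        optimizer.zero_grad()
        loss = model(torch.randn(4, 8)).sum()
        loss.backward()
        optimizer.step()  # each rank updates its own shard, then parameters are synced

    # Gather the sharded optimizer state onto rank 0 before checkpointing.
    optimizer.consolidate_state_dict(to=0)
    if dist.get_rank() == 0:
        torch.save(optimizer.state_dict(), "zero_optimizer.pt")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```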