import warnings

import torch
import torch.distributed.algorithms.model_averaging.averagers as averagers

class PostLocalSGDOptimizer(torch.optim.Optimizer):
    r"""
    Wraps an arbitrary :class:`torch.optim.Optimizer` and runs `post-local SGD <https://arxiv.org/abs/1808.07217>`_.

    This optimizer runs the local optimizer at every step.
    After the warm-up stage, it averages the parameters periodically after the local optimizer is applied.

    Args:
        optim: The local optimizer.
        averager: A model averager instance to run the post-localSGD algorithm.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.distributed.algorithms.model_averaging.averagers as averagers
        >>> import torch.nn as nn
        >>> from torch.distributed.optim import PostLocalSGDOptimizer
        >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
        >>>     PostLocalSGDState,
        >>>     post_localSGD_hook,
        >>> )
        >>>
        >>> model = nn.parallel.DistributedDataParallel(
        >>>     module, device_ids=[rank], output_device=rank
        >>> )
        >>>
        >>> # Register a post-localSGD communication hook.
        >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
        >>> model.register_comm_hook(state, post_localSGD_hook)
        >>>
        >>> # Create a post-localSGD optimizer that wraps a local optimizer.
        >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
        >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
        >>> opt = PostLocalSGDOptimizer(
        >>>     optim=local_optim,
        >>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
        >>> )
        >>>
        >>> # In the first 100 steps, DDP runs global gradient averaging at every step.
        >>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default),
        >>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
        >>> for step in range(0, 200):
        >>>     opt.zero_grad()
        >>>     loss = loss_fn(output, labels)
        >>>     loss.backward()
        >>>     opt.step()
    """

    def __init__(self, optim: torch.optim.Optimizer, averager: averagers.ModelAverager):
        self.optim = optim
        # Share the wrapped optimizer's param groups so that the averager operates
        # on exactly the parameters the local optimizer updates.
        self.param_groups = self.optim.param_groups
        self.averager = averager

    @property
    def state(self):
        return self.optim.state

    def __repr__(self):
        return self.optim.__repr__()

    def state_dict(self):
        r"""
        This is the same as :class:`torch.optim.Optimizer` :meth:`state_dict`,
        but adds an extra entry to record the model averager's step to the checkpoint
        to ensure that reload does not cause an unnecessary warm-up again.
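
        For illustration, a minimal checkpointing sketch; ``opt`` is assumed to be a
        ``PostLocalSGDOptimizer`` built as in the class example above, and the file
        name is arbitrary::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> checkpoint = opt.state_dict()
            >>> # The returned dict carries the wrapped optimizer's entries plus a
            >>> # "step" entry recording the model averager's current step.
            >>> print(checkpoint["step"])
            >>> torch.save(checkpoint, "post_local_sgd_optim.pt")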
        """
        optim_state_dict = self.optim.state_dict()
        optim_state_dict["step"] = self.averager.step
        return optim_state_dict

    def load_state_dict(self, state_dict):
        r"""
        This is the same as :class:`torch.optim.Optimizer` :meth:`load_state_dict`,
        but also restores the model averager's step value to the one
        saved in the provided ``state_dict``.

        If there is no ``"step"`` entry in ``state_dict``,
        it will emit a warning and initialize the model averager's step to 0.
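
        For illustration, a minimal reload sketch (``opt`` and the checkpoint file
        name are the same assumptions as in the :meth:`state_dict` example above)::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> opt.load_state_dict(torch.load("post_local_sgd_optim.pt"))
            >>> # If the loaded dict had no "step" entry (e.g., it came from the plain
            >>> # local optimizer), a warning would be emitted and the step reset to 0.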
        """
        self.optim.load_state_dict(state_dict)
        if "step" in state_dict:
            self.averager.step = state_dict["step"]
        else:
            warnings.warn(
                "Loaded state dict does not contain a step counter for an averager. "
                "Setting step counter to 0."
            )
            self.averager.step = 0

    def step(self):
        r"""
        Performs a single optimization step (parameter update).
        """
        self.optim.step()
        # After the local optimizer update, let the model averager decide whether to
        # average the parameters (periodically, once past the warm-up stage).
        self.averager.average_parameters(params=self.param_groups)

    def zero_grad(self, set_to_none: bool = True):  # type: ignore[override]
        r"""Clears the gradients of the wrapped local optimizer (sets them to ``None`` by default)."""
        self.optim.zero_grad(set_to_none=set_to_none)

    def add_param_group(self, param_group):
        r"""Adds a param group to the wrapped local optimizer's ``param_groups``."""
        self.optim.add_param_group(param_group)