[Reland] [Model Averaging] Simplify PostLocalSGD Optimizer API (#65197)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65197

1. The constructor accepts a local optimizer instance instead of the inputs of the local optimizer constructor and the class type.
2. The parameters are read from the local optimizer's param_groups instead of a separate input.

Proposal: https://github.com/pytorch/pytorch/issues/59699

ghstack-source-id: 138307226

Test Plan: buck test mode/dev-nosan //caffe2/test/distributed:distributed_nccl_spawn -- test_post_localSGD_optimizer_parity

Reviewed By: rohan-varma

Differential Revision: D31007439

fbshipit-source-id: bbb0526e6763ef76775b85088571506b3942c722
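As a usage sketch (not part of the commit), the change amounts to constructing the local optimizer yourself and passing the instance in. The model, learning rate, and averager settings below are illustrative, and the import path for PostLocalSGDOptimizer is assumed to be torch.distributed.optim; only the constructor arguments come from this change.

    import torch
    import torch.distributed.algorithms.model_averaging.averagers as averagers
    from torch.distributed.optim import PostLocalSGDOptimizer

    # Assumes torch.distributed is already initialized and the model is wrapped with
    # DDP plus the post-localSGD communication hook, as in the docstring example below.
    model = torch.nn.Linear(2, 2)  # stand-in model for illustration

    # Old API (removed by this change): the wrapper built the local optimizer itself
    # from a class and keyword defaults.
    # opt = PostLocalSGDOptimizer(
    #     model.parameters(),
    #     optimizer_class=torch.optim.SGD,
    #     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100),
    #     lr=0.01,
    # )

    # New API: create the local optimizer first and hand the instance over; the
    # wrapper reads param_groups from it, so no separate params argument is needed.
    local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
    opt = PostLocalSGDOptimizer(
        optim=local_optim,
        averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100),
    )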
This commit is contained in:
parent 752a820230
commit c1415a0a72
@@ -32,5 +32,5 @@ def average_parameters(
     offset = 0
     for p in params_it2:
-        p.data = flat_params[offset : offset + p.numel()].view_as(p)
+        p.data = flat_params[offset : offset + p.numel()].view_as(p).type_as(p)
         offset += p.numel()
@@ -1,5 +1,3 @@
-from typing import Any, Iterator, Type
-
 import torch
 import torch.distributed.algorithms.model_averaging.averagers as averagers
 
@@ -11,11 +9,8 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
     After the warm-up stage, it averages parameters periodically after the local optimizer is applied.
 
     Args:
-        params: All the parameters.
-        optimizer_class: The class of the local optimizer.
+        optim: The local optimizer.
         averager: A model averager instance to run post-localSGD algorithm.
-        **defaults: A dict containing default values of optimization options,
-            which are forwarded to the local optimizer.
 
     Example::
@@ -37,11 +32,10 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
         >>> # Create a post-localSGD optimizer that wraps a local optimizer.
         >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
         >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
+        >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
         >>> opt = PostLocalSGDOptimizer(
-        >>>     model.parameters(),
-        >>>     optimizer_class=torch.optim.SGD,
-        >>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100),
-        >>>     lr=0.01
+        >>>     optim=local_optim,
+        >>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
         >>> )
         >>>
         >>> # In the first 100 steps, DDP runs global gradient averaging at every step.
@@ -59,13 +53,10 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
 
     def __init__(
         self,
-        params: Iterator[torch.nn.Parameter],
-        optimizer_class: Type[torch.optim.Optimizer],
-        averager: averagers.ModelAverager,
-        **defaults: Any,
+        optim: torch.optim.Optimizer,
+        averager: averagers.ModelAverager
     ):
-        self.params = list(params)
-        self.optim = optimizer_class(iter(self.params), **defaults)
+        self.optim = optim
         self.param_groups = self.optim.param_groups
         self.averager = averager
@@ -87,7 +78,11 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
         Performs a single optimization step (parameter update).
         """
         self.optim.step()
-        self.averager.average_parameters(iter(self.params))
+        for param_group in self.param_groups:
+            for params in param_group["params"]:
+                if params.grad is None:
+                    continue
+                self.averager.average_parameters(iter(params))
 
     def zero_grad(self):
         self.optim.zero_grad()
@@ -4625,12 +4625,10 @@ class DistributedTest:
                     gradient_as_bucket_view=grad_is_view,
                 )
                 post_localSGD_opt = post_localSGD_optimizer.PostLocalSGDOptimizer(
-                    params=post_localSGD_net.parameters(),
-                    optimizer_class=torch.optim.SGD,
+                    optim=torch.optim.SGD(post_localSGD_net.parameters(), lr=learning_rate),
                     averager=averagers.PeriodicModelAverager(
                         period=period, warmup_steps=warmup_steps
                     ),
-                    lr=learning_rate,
                 )
 
                 input = torch.randn(dist.get_world_size() * 2, 2).cuda()