[Reland] [Model Averaging] Simplify PostLocalSGD Optimizer API (#65197)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65197

1. The constructor accepts a local optimizer instance instead of the local optimizer's class type and its constructor inputs.
2. The parameters are read from the local optimizer's param_groups instead of being passed in as a separate input (a before/after sketch follows below).
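
For illustration, a minimal before/after sketch of constructing the wrapper under the old and new API. The placeholder model, learning rate, and averager settings are taken from the docstring example in this diff; the import path for PostLocalSGDOptimizer is an assumption, and a default process group is expected to have been initialized before any of this runs in a real job.

    import torch
    import torch.distributed.algorithms.model_averaging.averagers as averagers
    from torch.distributed.optim import PostLocalSGDOptimizer  # assumed import path

    model = torch.nn.Linear(2, 2)  # placeholder model

    # Old API (removed by this diff): pass the parameters, the local optimizer's
    # class, and its constructor kwargs, and let the wrapper build the optimizer.
    opt = PostLocalSGDOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.SGD,
        averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100),
        lr=0.01,
    )

    # New API: build the local optimizer yourself and hand over the instance;
    # the wrapper reads the parameters from local_optim.param_groups.
    local_optim = torch.optim.SGD(model.parameters(), lr=0.01)
    opt = PostLocalSGDOptimizer(
        optim=local_optim,
        averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100),
    )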

Proposal: https://github.com/pytorch/pytorch/issues/59699
ghstack-source-id: 138307226

Test Plan: buck test mode/dev-nosan //caffe2/test/distributed:distributed_nccl_spawn -- test_post_localSGD_optimizer_parity

Reviewed By: rohan-varma

Differential Revision: D31007439

fbshipit-source-id: bbb0526e6763ef76775b85088571506b3942c722
Yi Wang, 2021-09-17 10:00:13 -07:00 (committed by Facebook GitHub Bot)
parent 752a820230
commit c1415a0a72
3 changed files with 15 additions and 22 deletions

torch/distributed/algorithms/model_averaging/utils.py

@@ -32,5 +32,5 @@ def average_parameters(
     offset = 0
     for p in params_it2:
-        p.data = flat_params[offset : offset + p.numel()].view_as(p)
+        p.data = flat_params[offset : offset + p.numel()].view_as(p).type_as(p)
         offset += p.numel()
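
For context, the enclosing average_parameters helper roughly packs all parameters into one flat buffer, all-reduces that buffer across the process group, and unpacks the averaged values back into the parameters. Because torch.cat up-casts when the parameters have mixed dtypes, the .type_as(p) added here restores each parameter's original dtype during unpacking. A simplified sketch of that flow (not the verbatim source):

    import torch
    import torch.distributed as dist

    def average_parameters_sketch(params, process_group):
        # Pack: concatenation may up-cast if the parameters have mixed dtypes.
        params = list(params)
        flat_params = torch.cat([p.data.reshape(-1) for p in params])
        # Divide first, then sum-all-reduce, which yields the average.
        flat_params /= dist.get_world_size(process_group)
        dist.all_reduce(flat_params, group=process_group)

        # Unpack: reshape each slice back to the parameter's shape and restore
        # the original dtype (the change made in this hunk).
        offset = 0
        for p in params:
            p.data = flat_params[offset : offset + p.numel()].view_as(p).type_as(p)
            offset += p.numel()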

torch/distributed/optim/post_localSGD_optimizer.py

@@ -1,5 +1,3 @@
-from typing import Any, Iterator, Type
 import torch
 import torch.distributed.algorithms.model_averaging.averagers as averagers
@@ -11,11 +9,8 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
     After the warm-up stage, it averages parameters periodically after the local optimizer is applied.

     Args:
-        params: All the parameters.
-        optimizer_class: The class of the local optimizer.
+        optim: The local optimizer.
         averager: A model averager instance to run post-localSGD algorithm.
-        **defaults: A dict containing default values of optimization options,
-            which are forwarded to the local optimizer.

     Example::
@@ -37,11 +32,10 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
         >>> # Create a post-localSGD optimizer that wraps a local optimizer.
         >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
         >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
+        >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
         >>> opt = PostLocalSGDOptimizer(
-        >>>     model.parameters(),
-        >>>     optimizer_class=torch.optim.SGD,
-        >>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100),
-        >>>     lr=0.01
+        >>>     optim=local_optim,
+        >>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
         >>> )
         >>>
         >>> # In the first 100 steps, DDP runs global gradient averaging at every step.
@@ -59,13 +53,10 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
     def __init__(
         self,
-        params: Iterator[torch.nn.Parameter],
-        optimizer_class: Type[torch.optim.Optimizer],
-        averager: averagers.ModelAverager,
-        **defaults: Any,
+        optim: torch.optim.Optimizer,
+        averager: averagers.ModelAverager
     ):
-        self.params = list(params)
-        self.optim = optimizer_class(iter(self.params), **defaults)
+        self.optim = optim
         self.param_groups = self.optim.param_groups
         self.averager = averager
@@ -87,7 +78,11 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
         Performs a single optimization step (parameter update).
         """
         self.optim.step()
-        self.averager.average_parameters(iter(self.params))
+        for param_group in self.param_groups:
+            for params in param_group["params"]:
+                if params.grad is None:
+                    continue
+                self.averager.average_parameters(iter(params))

     def zero_grad(self):
         self.optim.zero_grad()
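
One consequence of the new constructor, shown here as a small illustrative sketch rather than anything from this commit (the import path and the placeholder model are assumptions, and post-localSGD normally runs after dist.init_process_group): since the wrapper aliases the local optimizer's param_groups, a change made through either handle, such as adjusting the learning rate, is visible to both.

    import torch
    import torch.distributed.algorithms.model_averaging.averagers as averagers
    from torch.distributed.optim import PostLocalSGDOptimizer  # assumed import path

    model = torch.nn.Linear(4, 4)  # placeholder model
    local_optim = torch.optim.SGD(model.parameters(), lr=0.01)
    opt = PostLocalSGDOptimizer(
        optim=local_optim,
        averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100),
    )

    # opt.param_groups and local_optim.param_groups are the same dict objects,
    # so a learning rate adjusted through the wrapper is also seen by the
    # wrapped optimizer (and vice versa).
    for group in opt.param_groups:
        group["lr"] = 0.001
    assert local_optim.param_groups[0]["lr"] == 0.001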

torch/testing/_internal/distributed/distributed_test.py

@@ -4625,12 +4625,10 @@ class DistributedTest:
                 gradient_as_bucket_view=grad_is_view,
             )
             post_localSGD_opt = post_localSGD_optimizer.PostLocalSGDOptimizer(
-                params=post_localSGD_net.parameters(),
-                optimizer_class=torch.optim.SGD,
+                optim=torch.optim.SGD(post_localSGD_net.parameters(), lr=learning_rate),
                 averager=averagers.PeriodicModelAverager(
                     period=period, warmup_steps=warmup_steps
                 ),
-                lr=learning_rate,
             )
             )
             input = torch.randn(dist.get_world_size() * 2, 2).cuda()