[Model Averaging] Fix docstring of PeriodicModelAverager (#62392)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62392

The constructor of `PeriodicModelAverager` does not accept the model parameters, so the `params` entry is removed from its docstring and the keyword arguments in the examples are reordered to match the constructor signature.
ghstack-source-id: 134626245
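
For context, a minimal sketch of the corrected usage (the import path and the
`model` name are illustrative assumptions; what the diff confirms is only that the
constructor takes `period` and `warmup_steps` rather than the model parameters):

    # Assumed import path for the averagers module.
    import torch.distributed.algorithms.model_averaging.averagers as averagers

    # Construct the averager without passing the model parameters.
    averager = averagers.PeriodicModelAverager(period=4, warmup_steps=100)
    # The parameters are supplied when averaging is performed instead
    # (assuming a `model` built elsewhere).
    averager.average_parameters(model.parameters())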

Test Plan: buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork --  test_periodic_model_averager

Reviewed By: rohan-varma

Differential Revision: D29986446

fbshipit-source-id: 6a8b709e4383a3c44b9e60955fbb067cd2868e76
Author: Yi Wang, 2021-07-29 16:30:30 -07:00 (committed by Facebook GitHub Bot)
parent 8f519c5e07
commit 9fee176be3
2 changed files with 2 additions and 3 deletions


@@ -35,7 +35,6 @@ class PeriodicModelAverager(ModelAverager):
using the subgroups created by :meth:`~torch.distributed.new_subgroups`.
Args:
- params (Iterator[torch.nn.Parameter]): The model parameters to be averaged.
period (int): The number of steps per model averaging.
Usually the period should be greater than ``1`` to reduce the communication cost.
Otherwise, only DDP needs to be used.
@@ -68,7 +67,7 @@ class PeriodicModelAverager(ModelAverager):
>>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
>>> # After 100 steps, run model averaging every 4 steps.
>>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
- >>> averager = averagers.PeriodicModelAverager(warmup_steps=100, period=4)
+ >>> averager = averagers.PeriodicModelAverager(period=4, warmup_steps=100)
>>> for step in range(0, 20):
>>> optimizer.zero_grad()
>>> loss = loss_fn(output, labels)
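
The docstring example above is cut off by the diff context; a rough sketch of how
each step presumably continues (with `model` and `optimizer` as placeholder names
from the surrounding example):

    # Sketch of the remainder of the per-step loop in the docstring example.
    loss.backward()
    optimizer.step()
    # After `warmup_steps`, average the parameters across ranks every
    # `period` steps; the parameters are passed here, not to the constructor.
    averager.average_parameters(model.parameters())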


@@ -976,7 +976,7 @@ class DistributedTest:
expected_avg_tensor = torch.ones_like(param.data) * sum(range(world_size)) / world_size
period = 4
for warmup_steps in [12, 13, 14, 15]:
- averager = averagers.PeriodicModelAverager(warmup_steps=warmup_steps, period=period)
+ averager = averagers.PeriodicModelAverager(period=period, warmup_steps=warmup_steps)
for step in range(0, 20):
# Reset the parameters at every step.
param.data = copy.deepcopy(tensor)
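
The remainder of the test loop falls outside the diff context; a hypothetical
sketch of the pattern being exercised (the exact assertions and averaging
schedule are defined by the real test and by `PeriodicModelAverager`, so details
may differ):

    # Hypothetical continuation of the test loop, not the actual test body.
    averager.average_parameters([param])
    if step >= warmup_steps and (step - warmup_steps) % period == 0:
        # Averaging ran on this step: every rank should now hold the global
        # average of the per-rank parameter values.
        assert torch.equal(param.data, expected_avg_tensor)
    else:
        # No averaging on this step: the parameter keeps this rank's value.
        assert torch.equal(param.data, tensor)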