Pass _allow_empty_param_list into func opt ctor (#63163)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63163

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D30284615

Pulled By: andwgu

fbshipit-source-id: 4857f5b618ec5b007648737ab532ce605e5d70dc
This commit is contained in:
Andrew Gu 2021-08-13 08:19:23 -07:00 committed by Facebook GitHub Bot
parent bd81c9178a
commit 28f9e108b1
2 changed files with 11 additions and 2 deletions

View File

@ -967,7 +967,6 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
lr=SGD_LR,
momentum=SGD_MOMENTUM,
weight_decay=SGD_WEIGHT_DECAY,
_allow_empty_param_list=True
)
ddp_model_overlap.register_comm_hook(
None,

View File

@ -6,6 +6,7 @@
import collections
import copy
import enum
import inspect
import io
import logging
from itertools import chain
@ -1375,7 +1376,16 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
assert len(param_groups) == 1, "Initializing the local " \
"functional optimizer with more than one parameter group"
params = param_groups[0]["params"]
self.optim: Any = self._optim_constructor(params, **self._optim_defaults)
# Try to pass `_allow_empty_param_list=True` to avoid erroring
if "_allow_empty_param_list" in inspect.signature(self._optim_constructor).parameters:
self.optim: Any = self._optim_constructor(params, **self._optim_defaults, _allow_empty_param_list=True)
else:
logging.warning(
f"{self._optim_constructor} does not support the argument "
"`_allow_empty_param_list`; ZeroRedundancyOptimizer may "
"error due to an empty parameter list"
)
self.optim: Any = self._optim_constructor(params, **self._optim_defaults)
# Log information about the DDP and ZeRO bucketing
if dist._get_debug_mode() != dist._DistributedDebugLevel.OFF: