Pass _allow_empty_param_list into func opt ctor (#63163)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63163

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D30284615

Pulled By: andwgu

fbshipit-source-id: 4857f5b618ec5b007648737ab532ce605e5d70dc
This commit is contained in:
parent bd81c9178a
commit 28f9e108b1
@@ -967,7 +967,6 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
                 lr=SGD_LR,
                 momentum=SGD_MOMENTUM,
                 weight_decay=SGD_WEIGHT_DECAY,
-                _allow_empty_param_list=True
             )
             ddp_model_overlap.register_comm_hook(
                 None,
@@ -6,6 +6,7 @@
 import collections
 import copy
 import enum
+import inspect
 import io
 import logging
 from itertools import chain
@@ -1375,7 +1376,16 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
         assert len(param_groups) == 1, "Initializing the local " \
             "functional optimizer with more than one parameter group"
         params = param_groups[0]["params"]
-        self.optim: Any = self._optim_constructor(params, **self._optim_defaults)
+        # Try to pass `_allow_empty_param_list=True` to avoid erroring
+        if "_allow_empty_param_list" in inspect.signature(self._optim_constructor).parameters:
+            self.optim: Any = self._optim_constructor(params, **self._optim_defaults, _allow_empty_param_list=True)
+        else:
+            logging.warning(
+                f"{self._optim_constructor} does not support the argument "
+                "`_allow_empty_param_list`; ZeroRedundancyOptimizer may "
+                "error due to an empty parameter list"
+            )
+            self.optim: Any = self._optim_constructor(params, **self._optim_defaults)

         # Log information about the DDP and ZeRO bucketing
         if dist._get_debug_mode() != dist._DistributedDebugLevel.OFF:
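The guarded construction above can be exercised outside of ZeroRedundancyOptimizer. Below is a minimal, self-contained sketch of the same feature check: `inspect.signature` detects whether a constructor declares `_allow_empty_param_list` before the flag is forwarded. The `LegacyOptim`, `NewOptim`, and `construct` names are hypothetical stand-ins for illustration, not part of the PyTorch API.

    import inspect
    import logging

    class LegacyOptim:
        # Hypothetical optimizer that rejects empty parameter lists outright.
        def __init__(self, params, lr=0.01):
            if len(params) == 0:
                raise ValueError("optimizer got an empty parameter list")
            self.params, self.lr = params, lr

    class NewOptim:
        # Hypothetical optimizer that can opt into empty parameter lists.
        def __init__(self, params, lr=0.01, _allow_empty_param_list=False):
            if len(params) == 0 and not _allow_empty_param_list:
                raise ValueError("optimizer got an empty parameter list")
            self.params, self.lr = params, lr

    def construct(optim_cls, params, **defaults):
        # Same check as the commit: forward the flag only when the
        # constructor's signature actually declares it.
        if "_allow_empty_param_list" in inspect.signature(optim_cls).parameters:
            return optim_cls(params, **defaults, _allow_empty_param_list=True)
        logging.warning(
            "%s does not support `_allow_empty_param_list`; construction "
            "may fail on an empty parameter list", optim_cls
        )
        return optim_cls(params, **defaults)

    construct(NewOptim, [], lr=0.1)   # succeeds: the flag is forwarded
    # construct(LegacyOptim, [])      # would raise ValueError

Feature-detecting the keyword keeps backward compatibility with functional optimizers that have not yet added the flag, at the cost of only a logged warning rather than a hard failure when the flag is unsupported.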