[ZeRO] (Reland) Add ctor support for multiple param groups (#72932)

Summary:
Reland of https://github.com/pytorch/pytorch/pull/72578.

**Overview**
Windows CI was failing in the multi-rank, single-GPU case, i.e. when there are more ranks than GPUs (see [here](https://github.com/pytorch/pytorch/runs/5204906995?check_suite_focus=true)).

To address this, I
- added `common_distributed.skip_if_no_gpu` to `test_multiple_param_groups()` to ensure that each rank can safely call `to(self.device)` -- this targets the expected SPSD (single-process, single-device) use case where each rank has its own GPU;
- moved `test_constructor()` back to `TestZeroRedundancyOptimizerSingleRank` to check that construction with multiple parameter groups works even on a single rank (see the sketch below).
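
For context, here is a minimal sketch of the constructor usage that the reworked `test_constructor()` exercises. The single-process gloo setup, port, model shape, and hyperparameters are illustrative only; the `ZeroRedundancyOptimizer` calls and assertions mirror the test in the diff below.

```python
# Minimal sketch (not part of the PR): construct ZeRO with `params` given as
# parameter groups (dicts) instead of a flat iterable of tensors.
import os

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.optim import AdamW

# Single-rank "world" so the sketch runs standalone on CPU with gloo.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Sequential(
    torch.nn.Linear(5, 10),
    torch.nn.Linear(10, 10),
    torch.nn.Linear(10, 10),
)

# `params` may now be an iterable of dicts (one per parameter group), e.g. to
# apply different weight decay to weights and biases.
optim = ZeroRedundancyOptimizer(
    [
        {"params": [layer.weight for layer in model], "weight_decay": 0.0},
        {"params": [layer.bias for layer in model], "weight_decay": 0.01},
    ],
    optimizer_class=AdamW,
    lr=0.01,
)
assert len(optim.param_groups) == 2        # ZeRO-level parameter groups
assert len(optim.optim.param_groups) == 2  # local optimizer's parameter groups

dist.destroy_process_group()
```

Parameter groups can still be added after construction via `add_param_group()`; `test_multiple_param_groups()` below checks that both paths match a non-sharded `AdamW`.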

**Test Plan**
- I ran both tests on CPU and with 1, 2, 4, and 8 GPUs.
- I added the `ciflow/win` label to run the failing Windows CI test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72932

Reviewed By: rohan-varma

Differential Revision: D34281482

Pulled By: awgu

fbshipit-source-id: c4fe604ddd9d2c123c3071249741e6b8a6454b6e
(cherry picked from commit 6bea9bcc63)
Author: Andrew Gu
Date: 2022-02-22 08:07:32 -08:00
Committed by: PyTorch MergeBot
Parent: 1d404727c5
Commit: c30659ffcc
2 changed files with 160 additions and 30 deletions

test/distributed/optim/test_zero_redundancy_optimizer.py

@@ -33,7 +33,7 @@ from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD
from torch.optim import SGD, AdamW
from torch.testing._internal import common_distributed, common_utils
from torch.testing._internal.common_utils import (
TEST_WITH_ASAN,
@@ -249,27 +249,54 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
def test_constructor(self):
"""Check the robustness of the ZeroRedundancyOptimizer constructor by
passing different values for `params`"""
passing different values for the ``params`` argument."""
self.dist_init(self.rank)
m = torch.nn.Linear(1, 1)
# (input, expected error)
inputs = [
m = torch.nn.Sequential(
torch.nn.Linear(5, 10),
torch.nn.Linear(10, 10),
torch.nn.Linear(10, 10),
)
# Test various constructor inputs in the form: (input, expected error)
ctor_inputs = [
([], ValueError), # empty parameter list
(torch.randn(1), TypeError), # non-iterable: `torch.Tensor`
(1.2, TypeError), # non-iterable: `float`
([{"params": m.parameters()}], TypeError), # iterable of dict
(list(m.parameters()) + [42], TypeError), # iterable containing non-`torch.Tensor`
([
{"params": [l.weight for l in m]},
{"params": [l.bias for l in m]},
], None), # iterable of dict
(list(m.parameters()) + [42], TypeError), # iterable containing invalid type
(m.parameters(), None), # `params` as a generator
(list(m.parameters()), None) # `params` as a list
]
for input, error in inputs:
if (error):
for ctor_input, error in ctor_inputs:
if error:
with self.assertRaises(error):
ZeroRedundancyOptimizer(input, optimizer_class=SGD, lr=0.1)
ZeroRedundancyOptimizer(ctor_input, optimizer_class=SGD, lr=0.01)
else:
ZeroRedundancyOptimizer(input, optimizer_class=SGD, lr=0.1)
ZeroRedundancyOptimizer(ctor_input, optimizer_class=SGD, lr=0.01)
# Test constructing with multiple parameter groups more thoroughly
weight_decay = 0.01
lr = 0.01
betas = (0.9, 0.999)
eps = 1e-8
params = [
{"params": [l.weight for l in m], "weight_decay": 0.},
{"params": [l.bias for l in m], "weight_decay": weight_decay},
]
o = ZeroRedundancyOptimizer(
params, optimizer_class=AdamW,
lr=lr, betas=betas, eps=eps,
)
assert len(o.param_groups) == 2, \
f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}"
assert len(o.optim.param_groups) == 2, \
"Expected 2 local optimizer param groups, but got " \
f"{len(o.optim.param_groups)}"
def test_same_dense_param_type(self):
"""Check that ZeroRedundancyOptimizer raises an exception if the input
@@ -459,7 +486,76 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
all_trainable()
some_trainable()
@common_distributed.skip_if_no_gpu
def test_multiple_param_groups(self):
"""
Tests parity between constructing ZeRO with multiple parameter groups
upfront versus adding parameter groups to ZeRO after construction
versus a non-sharded optimizer.
"""
self.dist_init(self.rank)
model1 = torch.nn.Sequential(
torch.nn.Linear(5, 10),
torch.nn.Linear(10, 10),
torch.nn.Linear(10, 5),
)
model2 = copy.deepcopy(model1)
model3 = copy.deepcopy(model1)
model1 = model1.to(self.device)
model2 = model2.to(self.device)
model3 = model3.to(self.device)
batch_size = 8
num_iters = 3
inputs = [
torch.randn(batch_size, 5).to(self.device) for _ in range(num_iters)
]
wd = 0.01
lr = 0.01
# Construct `optim1` with both parameter groups upfront
optim1 = ZeroRedundancyOptimizer(
[
{"params": [l.weight for l in model1], "weight_decay": 0.},
{"params": [l.bias for l in model1], "weight_decay": wd},
],
optimizer_class=AdamW, lr=lr,
)
# Construct `optim2` by adding the second parameter after
optim2 = ZeroRedundancyOptimizer(
[l.weight for l in model2],
optimizer_class=AdamW, lr=lr, weight_decay=0.,
)
optim2.add_param_group(
{"params": [l.bias for l in model2], "weight_decay": wd}
)
# Construct `optim3` as a non-sharded optimizer
optim3 = AdamW(
[
{"params": [l.weight for l in model3], "weight_decay": 0.},
{"params": [l.bias for l in model3], "weight_decay": wd},
], lr=lr,
)
# Check parity over a few iterations
for iter in range(num_iters):
for model, optim in (
(model1, optim1), (model2, optim2), (model3, optim3),
):
optim.zero_grad()
out = model(inputs[iter])
loss = out.sum()
loss.backward()
optim.step()
for layer1, layer2, layer3 in zip(model1, model2, model3):
assert torch.allclose(layer1.weight, layer2.weight)
assert torch.allclose(layer1.weight, layer3.weight)
assert torch.allclose(layer1.bias, layer2.bias)
assert torch.allclose(layer1.bias, layer3.bias)
@common_distributed.skip_if_lt_x_gpu(2)
@common_distributed.skip_if_rocm
def test_collect_shards(self):
""" Check the state consolidation mechanism, and the state dict exposed by ZeroRedundancyOptimizer"""
self.dist_init(self.rank)

torch/distributed/optim/zero_redundancy_optimizer.py

@@ -10,7 +10,16 @@ import inspect
import io
import logging
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Set, Type
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Set,
Type,
Union,
)
import torch
import torch.distributed as dist
@@ -287,7 +296,8 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
Arguments:
params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
giving all parameters, which will be sharded across ranks.
or :class:`dict` s giving all parameters, which will be sharded
across ranks.
Keyword Args:
optimizer_class (:class:`torch.nn.Optimizer`): the class of the local
@@ -364,7 +374,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
**defaults: Any,
):
# Perform type and assumption checks on the input parameters
self._verify_and_init_params(params)
params = self._verify_and_init_params(params)
self._verify_same_dense_param_type()
# NOTE: The parent constructor uses `add_param_group()` which is
@@ -373,7 +383,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
# between the parent and child.
self.initialized = False
Optimizer.__init__(self, self._all_params, defaults)
Optimizer.__init__(self, params, defaults)
Joinable.__init__(self)
# Now, all parameters are held in both `self._all_params` and
# `self.param_groups`
@@ -1289,36 +1299,60 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
offset = offset_next
bucket_assignment.tensor = tensor
def _verify_and_init_params(self, params: Any) -> None:
def _verify_and_init_params(
self, params: Any,
) -> Union[List[torch.Tensor], List[dict]]:
r"""
Verifies the type of ``params`` and initializes ``self._all_params``
if ``params`` is valid.
as a :class:`list` of all parameters if ``params`` is valid.
While :class:`optim.Optimizer <torch.optim.Optimizer>` allows
``params`` to be an iterable of :class:`dict` s, currently
``ZeroRedundancyOptimizer`` strictly requires ``params`` to be an
iterable of :class:`torch.Tensor` s.
Arguments:
params (Any): Candidate parameter list or parameter groups to
verify.
Raises:
TypeError: ``params`` has an invalid type.
ValueError: ``params`` is empty.
Returns:
The persistent form of ``params`` to be passed into the parent
:class:`Optimizer` constructor -- i.e. returns ``params`` as a
:class:`list` to ensure that it can be iterated over again.
"""
if isinstance(params, torch.Tensor):
raise TypeError("params argument should be an iterable of "
raise TypeError("`params` argument should be an iterable of "
f"Tensors, but got {torch.typename(params)}")
try:
self._all_params = list(params)
all_params = list(params)
except TypeError:
raise TypeError("params argument should be an iterable of "
raise TypeError("`params` argument should be an iterable of "
f"Tensors, but got {torch.typename(params)}")
if len(self._all_params) == 0:
if len(all_params) == 0:
raise ValueError("ZeroRedundancyOptimizer got an empty parameter "
"list")
for param in self._all_params:
if not isinstance(param, torch.Tensor):
raise TypeError("params argument should be an iterable of "
"Tensors, but got an iterable containing "
f"{torch.typename(param)}")
all_tensors = True
all_dicts = True
for param in all_params:
all_tensors &= isinstance(param, torch.Tensor)
all_dicts &= isinstance(param, dict)
if not all_tensors and not all_dicts:
raise TypeError("`params` argument should be an iterable of "
"Tensors or dicts")
# Ensure that `self._all_params` contains a list of all parameters
if all_tensors:
self._all_params = all_params
elif all_dicts:
self._all_params = []
# `all_params` contains parameter groups (not parameters)
for param_group in all_params:
if "params" not in param_group:
raise ValueError(
"Each parameter group passed-in via `params` must "
"have a 'params' key mapping to the parameters in "
"the group"
)
self._all_params.extend(param_group["params"])
return all_params
def _verify_same_dense_param_type(self) -> None:
r"""