[ZeRO] (Reland) Add ctor support for multiple param groups (#72932)
Summary:
Reland of https://github.com/pytorch/pytorch/pull/72578.
**Overview**
Windows CI was failing due to the multi-rank single-GPU case (see [here](https://github.com/pytorch/pytorch/runs/5204906995?check_suite_focus=true)).
To address this, I
- added `common_distributed.skip_if_no_gpu` for `test_multiple_param_groups()` to ensure that each rank can safely call `to(self.device)` -- this targets the expected single-process single-device (SPSD) use case, where each rank has its own GPU (see the usage sketch below);
- moved `test_constructor()` back to `TestZeroRedundancyOptimizerSingleRank` to check that the multiple parameter group method for construction works even on a single rank.
**Test Plan**
- I checked both tests for CPU, 1 GPU, 2 GPUs, 4 GPUs, and 8 GPUs.
- I added the `ciflow/win` label to run the failing Windows CI test.
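For additional context, here is a minimal sketch (not part of this patch) of the SPSD-style usage that the multiple-parameter-group constructor targets; the model, hyperparameters, and `torchrun`-style `env://` launch are illustrative assumptions:
```python
# Illustrative SPSD sketch (not part of this commit): one process per GPU,
# constructing ZeRO with two parameter groups upfront.
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.optim import AdamW

def main() -> None:
    dist.init_process_group(backend="nccl")  # assumes env:// vars, e.g. via torchrun
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())

    model = torch.nn.Sequential(
        torch.nn.Linear(5, 10),
        torch.nn.Linear(10, 5),
    ).to(device)

    # New in this change: `params` may be an iterable of parameter-group dicts,
    # mirroring `torch.optim.Optimizer`, instead of only a flat iterable of Tensors.
    optim = ZeroRedundancyOptimizer(
        [
            {"params": [layer.weight for layer in model], "weight_decay": 0.0},
            {"params": [layer.bias for layer in model], "weight_decay": 0.01},
        ],
        optimizer_class=AdamW,
        lr=0.01,
    )

    # Gradient synchronization (e.g. via DDP) is omitted for brevity.
    loss = model(torch.randn(8, 5, device=device)).sum()
    loss.backward()
    optim.step()

if __name__ == "__main__":
    main()
```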
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72932
Reviewed By: rohan-varma
Differential Revision: D34281482
Pulled By: awgu
fbshipit-source-id: c4fe604ddd9d2c123c3071249741e6b8a6454b6e
(cherry picked from commit 6bea9bcc63)
This commit is contained in:
parent 1d404727c5
commit c30659ffcc
test/distributed/optim/test_zero_redundancy_optimizer.py
@@ -33,7 +33,7 @@ from torch.distributed.algorithms.join import Join, Joinable, JoinHook
 from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import SGD
+from torch.optim import SGD, AdamW
 from torch.testing._internal import common_distributed, common_utils
 from torch.testing._internal.common_utils import (
     TEST_WITH_ASAN,
@@ -249,27 +249,54 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):

     def test_constructor(self):
         """Check the robustness of the ZeroRedundancyOptimizer constructor by
-        passing different values for `params`"""
+        passing different values for the ``params`` argument."""
         self.dist_init(self.rank)

-        m = torch.nn.Linear(1, 1)
-        # (input, expected error)
-        inputs = [
+        m = torch.nn.Sequential(
+            torch.nn.Linear(5, 10),
+            torch.nn.Linear(10, 10),
+            torch.nn.Linear(10, 10),
+        )
+
+        # Test various constructor inputs in the form: (input, expected error)
+        ctor_inputs = [
             ([], ValueError),                           # empty parameter list
             (torch.randn(1), TypeError),                # non-iterable: `torch.Tensor`
             (1.2, TypeError),                           # non-iterable: `float`
-            ([{"params": m.parameters()}], TypeError),  # iterable of dict
-            (list(m.parameters()) + [42], TypeError),   # iterable containing non-`torch.Tensor`
+            ([
+                {"params": [l.weight for l in m]},
+                {"params": [l.bias for l in m]},
+            ], None),                                   # iterable of dict
+            (list(m.parameters()) + [42], TypeError),   # iterable containing invalid type
             (m.parameters(), None),                     # `params` as a generator
             (list(m.parameters()), None)                # `params` as a list
         ]

-        for input, error in inputs:
-            if (error):
+        for ctor_input, error in ctor_inputs:
+            if error:
                 with self.assertRaises(error):
-                    ZeroRedundancyOptimizer(input, optimizer_class=SGD, lr=0.1)
+                    ZeroRedundancyOptimizer(ctor_input, optimizer_class=SGD, lr=0.01)
             else:
-                ZeroRedundancyOptimizer(input, optimizer_class=SGD, lr=0.1)
+                ZeroRedundancyOptimizer(ctor_input, optimizer_class=SGD, lr=0.01)
+
+        # Test constructing with multiple parameter groups more thoroughly
+        weight_decay = 0.01
+        lr = 0.01
+        betas = (0.9, 0.999)
+        eps = 1e-8
+        params = [
+            {"params": [l.weight for l in m], "weight_decay": 0.},
+            {"params": [l.bias for l in m], "weight_decay": weight_decay},
+        ]
+        o = ZeroRedundancyOptimizer(
+            params, optimizer_class=AdamW,
+            lr=lr, betas=betas, eps=eps,
+        )
+        assert len(o.param_groups) == 2, \
+            f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}"
+        assert len(o.optim.param_groups) == 2, \
+            "Expected 2 local optimizer param groups, but got " \
+            f"{len(o.optim.param_groups)}"

     def test_same_dense_param_type(self):
         """Check that ZeroRedundancyOptimizer raises an exception if the input
@@ -459,7 +486,76 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
         all_trainable()
         some_trainable()

+    @common_distributed.skip_if_no_gpu
+    def test_multiple_param_groups(self):
+        """
+        Tests parity between constructing ZeRO with multiple parameter groups
+        upfront versus adding parameter groups to ZeRO after construction
+        versus a non-sharded optimizer.
+        """
+        self.dist_init(self.rank)
+
+        model1 = torch.nn.Sequential(
+            torch.nn.Linear(5, 10),
+            torch.nn.Linear(10, 10),
+            torch.nn.Linear(10, 5),
+        )
+        model2 = copy.deepcopy(model1)
+        model3 = copy.deepcopy(model1)
+        model1 = model1.to(self.device)
+        model2 = model2.to(self.device)
+        model3 = model3.to(self.device)
+
+        batch_size = 8
+        num_iters = 3
+        inputs = [
+            torch.randn(batch_size, 5).to(self.device) for _ in range(num_iters)
+        ]
+        wd = 0.01
+        lr = 0.01
+        # Construct `optim1` with both parameter groups upfront
+        optim1 = ZeroRedundancyOptimizer(
+            [
+                {"params": [l.weight for l in model1], "weight_decay": 0.},
+                {"params": [l.bias for l in model1], "weight_decay": wd},
+            ],
+            optimizer_class=AdamW, lr=lr,
+        )
+        # Construct `optim2` by adding the second parameter group after
+        optim2 = ZeroRedundancyOptimizer(
+            [l.weight for l in model2],
+            optimizer_class=AdamW, lr=lr, weight_decay=0.,
+        )
+        optim2.add_param_group(
+            {"params": [l.bias for l in model2], "weight_decay": wd}
+        )
+        # Construct `optim3` as a non-sharded optimizer
+        optim3 = AdamW(
+            [
+                {"params": [l.weight for l in model3], "weight_decay": 0.},
+                {"params": [l.bias for l in model3], "weight_decay": wd},
+            ], lr=lr,
+        )
+
+        # Check parity over a few iterations
+        for iter in range(num_iters):
+            for model, optim in (
+                (model1, optim1), (model2, optim2), (model3, optim3),
+            ):
+                optim.zero_grad()
+                out = model(inputs[iter])
+                loss = out.sum()
+                loss.backward()
+                optim.step()
+
+            for layer1, layer2, layer3 in zip(model1, model2, model3):
+                assert torch.allclose(layer1.weight, layer2.weight)
+                assert torch.allclose(layer1.weight, layer3.weight)
+                assert torch.allclose(layer1.bias, layer2.bias)
+                assert torch.allclose(layer1.bias, layer3.bias)
+
     @common_distributed.skip_if_lt_x_gpu(2)
     @common_distributed.skip_if_rocm
     def test_collect_shards(self):
         """ Check the state consolidation mechanism, and the state dict exposed by ZeroRedundancyOptimizer"""
         self.dist_init(self.rank)
torch/distributed/optim/zero_redundancy_optimizer.py
@@ -10,7 +10,16 @@ import inspect
 import io
 import logging
 from itertools import chain
-from typing import Any, Callable, Dict, List, Optional, Set, Type
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Type,
+    Union,
+)

 import torch
 import torch.distributed as dist
@@ -287,7 +296,8 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):

     Arguments:
         params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
-            giving all parameters, which will be sharded across ranks.
+            or :class:`dict` s giving all parameters, which will be sharded
+            across ranks.

     Keyword Args:
         optimizer_class (:class:`torch.nn.Optimizer`): the class of the local
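To make the updated `params` contract concrete, here is a hedged sketch (not from the patch) of the two construction routes it now documents; it assumes a default process group has already been initialized on this rank, and all names are illustrative:
```python
# Sketch only: assumes `dist.init_process_group(...)` has already run on this rank.
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.optim import AdamW

model = torch.nn.Sequential(torch.nn.Linear(5, 10), torch.nn.Linear(10, 5))

# Route 1: pass both parameter groups to the constructor (enabled by this change).
zero_upfront = ZeroRedundancyOptimizer(
    [
        {"params": [layer.weight for layer in model], "weight_decay": 0.0},
        {"params": [layer.bias for layer in model], "weight_decay": 0.01},
    ],
    optimizer_class=AdamW, lr=0.01,
)

# Route 2: construct with one group of Tensors, then add a second group afterwards.
zero_added = ZeroRedundancyOptimizer(
    [layer.weight for layer in model],
    optimizer_class=AdamW, lr=0.01, weight_decay=0.0,
)
zero_added.add_param_group(
    {"params": [layer.bias for layer in model], "weight_decay": 0.01}
)

assert len(zero_upfront.param_groups) == len(zero_added.param_groups) == 2
```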
@@ -364,7 +374,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
         **defaults: Any,
     ):
         # Perform type and assumption checks on the input parameters
-        self._verify_and_init_params(params)
+        params = self._verify_and_init_params(params)
         self._verify_same_dense_param_type()

         # NOTE: The parent constructor uses `add_param_group()` which is
@@ -373,7 +383,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
         # between the parent and child.
         self.initialized = False

-        Optimizer.__init__(self, self._all_params, defaults)
+        Optimizer.__init__(self, params, defaults)
         Joinable.__init__(self)
         # Now, all parameters are held in both `self._all_params` and
         # `self.param_groups`
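The reason the constructor now re-binds `params` and passes the returned list to `Optimizer.__init__`: a generator argument such as `module.parameters()` would be exhausted by the verification pass, leaving nothing for the parent constructor. A small standalone illustration of that pitfall (illustrative, not from the patch):
```python
# Generators are single-pass, which is why the verified `params` is returned
# as a list and that list is what gets handed to `Optimizer.__init__`.
import torch

m = torch.nn.Linear(2, 2)
gen = m.parameters()              # a generator over the module's parameters
first_pass = list(gen)            # materializing it consumes the generator...
assert len(first_pass) == 2
assert len(list(gen)) == 0        # ...so a second pass over `gen` sees nothing
```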
@@ -1289,36 +1299,60 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
             offset = offset_next
             bucket_assignment.tensor = tensor

-    def _verify_and_init_params(self, params: Any) -> None:
+    def _verify_and_init_params(
+        self, params: Any,
+    ) -> Union[List[torch.Tensor], List[dict]]:
         r"""
         Verifies the type of ``params`` and initializes ``self._all_params``
-        if ``params`` is valid.
+        as a :class:`list` of all parameters if ``params`` is valid.

-        While :class:`optim.Optimizer <torch.optim.Optimizer>` allows
-        ``params`` to be an iterable of :class:`dict` s, currently
-        ``ZeroRedundancyOptimizer`` strictly requires ``params`` to be an
-        iterable of :class:`torch.Tensor` s.
+        Arguments:
+            params (Any): Candidate parameter list or parameter groups to
+                verify.

         Raises:
             TypeError: ``params`` has an invalid type.
             ValueError: ``params`` is empty.
+
+        Returns:
+            The persistent form of ``params`` to be passed into the parent
+            :class:`Optimizer` constructor -- i.e. returns ``params`` as a
+            :class:`list` to ensure that it can be iterated over again.
         """
         if isinstance(params, torch.Tensor):
-            raise TypeError("params argument should be an iterable of "
+            raise TypeError("`params` argument should be an iterable of "
                             f"Tensors, but got {torch.typename(params)}")
         try:
-            self._all_params = list(params)
+            all_params = list(params)
         except TypeError:
-            raise TypeError("params argument should be an iterable of "
+            raise TypeError("`params` argument should be an iterable of "
                             f"Tensors, but got {torch.typename(params)}")
-        if len(self._all_params) == 0:
+        if len(all_params) == 0:
             raise ValueError("ZeroRedundancyOptimizer got an empty parameter "
                              "list")
-        for param in self._all_params:
-            if not isinstance(param, torch.Tensor):
-                raise TypeError("params argument should be an iterable of "
-                                "Tensors, but got an iterable containing "
-                                f"{torch.typename(param)}")
+        all_tensors = True
+        all_dicts = True
+        for param in all_params:
+            all_tensors &= isinstance(param, torch.Tensor)
+            all_dicts &= isinstance(param, dict)
+        if not all_tensors and not all_dicts:
+            raise TypeError("`params` argument should be an iterable of "
+                            "Tensors or dicts")
+        # Ensure that `self._all_params` contains a list of all parameters
+        if all_tensors:
+            self._all_params = all_params
+        elif all_dicts:
+            self._all_params = []
+            # `all_params` contains parameter groups (not parameters)
+            for param_group in all_params:
+                if "params" not in param_group:
+                    raise ValueError(
+                        "Each parameter group passed-in via `params` must "
+                        "have a 'params' key mapping to the parameters in "
+                        "the group"
+                    )
+                self._all_params.extend(param_group["params"])
+        return all_params

     def _verify_same_dense_param_type(self) -> None:
         r"""
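As a standalone illustration of the validation and flattening rule above, here is a hypothetical helper that mirrors the logic of `_verify_and_init_params` outside the class (illustrative only; not the library API):
```python
# Hypothetical mirror of the new validation/flattening logic, for illustration only.
from typing import Any, List
import torch

def flatten_params(params: Any) -> List[torch.Tensor]:
    """Return the flat parameter list that ZeRO shards, accepting either an
    iterable of Tensors or an iterable of parameter-group dicts."""
    if isinstance(params, torch.Tensor):
        raise TypeError("`params` should be an iterable of Tensors or dicts")
    all_params = list(params)  # materialize so it can be iterated again
    if not all_params:
        raise ValueError("got an empty parameter list")
    if all(isinstance(p, torch.Tensor) for p in all_params):
        return all_params
    if all(isinstance(p, dict) for p in all_params):
        flat: List[torch.Tensor] = []
        for group in all_params:
            if "params" not in group:
                raise ValueError("each parameter group needs a 'params' key")
            flat.extend(group["params"])
        return flat
    raise TypeError("`params` must be all Tensors or all dicts, not a mix")

m = torch.nn.Sequential(torch.nn.Linear(5, 10), torch.nn.Linear(10, 5))
assert len(flatten_params(m.parameters())) == 4
assert len(flatten_params([
    {"params": [layer.weight for layer in m]},
    {"params": [layer.bias for layer in m]},
])) == 4
```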