diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py
index de8ea511b63..67c274575d4 100644
--- a/test/distributed/optim/test_zero_redundancy_optimizer.py
+++ b/test/distributed/optim/test_zero_redundancy_optimizer.py
@@ -33,7 +33,7 @@ from torch.distributed.algorithms.join import Join, Joinable, JoinHook
 from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import SGD
+from torch.optim import SGD, AdamW
 from torch.testing._internal import common_distributed, common_utils
 from torch.testing._internal.common_utils import (
     TEST_WITH_ASAN,
@@ -249,27 +249,54 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
     def test_constructor(self):
         """Check the robustness of the ZeroRedundancyOptimizer constructor by
-        passing different values for `params`"""
+        passing different values for the ``params`` argument."""
         self.dist_init(self.rank)
-        m = torch.nn.Linear(1, 1)
-        # (input, expected error)
-        inputs = [
+        m = torch.nn.Sequential(
+            torch.nn.Linear(5, 10),
+            torch.nn.Linear(10, 10),
+            torch.nn.Linear(10, 10),
+        )
+
+        # Test various constructor inputs in the form: (input, expected error)
+        ctor_inputs = [
             ([], ValueError),  # empty parameter list
             (torch.randn(1), TypeError),  # non-iterable: `torch.Tensor`
             (1.2, TypeError),  # non-iterable: `float`
-            ([{"params": m.parameters()}], TypeError),  # iterable of dict
-            (list(m.parameters()) + [42], TypeError),  # iterable containing non-`torch.Tensor`
+            ([
+                {"params": [l.weight for l in m]},
+                {"params": [l.bias for l in m]},
+            ], None),  # iterable of dict
+            (list(m.parameters()) + [42], TypeError),  # iterable containing invalid type
             (m.parameters(), None),  # `params` as a generator
             (list(m.parameters()), None)  # `params` as a list
         ]
-        for input, error in inputs:
-            if (error):
+        for ctor_input, error in ctor_inputs:
+            if error:
                 with self.assertRaises(error):
-                    ZeroRedundancyOptimizer(input, optimizer_class=SGD, lr=0.1)
+                    ZeroRedundancyOptimizer(ctor_input, optimizer_class=SGD, lr=0.01)
             else:
-                ZeroRedundancyOptimizer(input, optimizer_class=SGD, lr=0.1)
+
+        # Test constructing with multiple parameter groups more thoroughly
+        weight_decay = 0.01
+        lr = 0.01
+        betas = (0.9, 0.999)
+        eps = 1e-8
+        params = [
+            {"params": [l.weight for l in m], "weight_decay": 0.},
+            {"params": [l.bias for l in m], "weight_decay": weight_decay},
+        ]
+        o = ZeroRedundancyOptimizer(
+            params, optimizer_class=AdamW,
+            lr=lr, betas=betas, eps=eps,
+        )
+        assert len(o.param_groups) == 2, \
+            f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}"
+        assert len(o.optim.param_groups) == 2, \
+            "Expected 2 local optimizer param groups, but got " \
+            f"{len(o.optim.param_groups)}"
 
     def test_same_dense_param_type(self):
         """Check that ZeroRedundancyOptimizer raises an exception if the input
@@ -459,7 +486,76 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
         all_trainable()
         some_trainable()
 
+    @common_distributed.skip_if_no_gpu
+    def test_multiple_param_groups(self):
+        """
+        Tests parity between constructing ZeRO with multiple parameter groups
+        upfront versus adding parameter groups to ZeRO after construction
+        versus a non-sharded optimizer.
+ """ + self.dist_init(self.rank) + + model1 = torch.nn.Sequential( + torch.nn.Linear(5, 10), + torch.nn.Linear(10, 10), + torch.nn.Linear(10, 5), + ) + model2 = copy.deepcopy(model1) + model3 = copy.deepcopy(model1) + model1 = model1.to(self.device) + model2 = model2.to(self.device) + model3 = model3.to(self.device) + + batch_size = 8 + num_iters = 3 + inputs = [ + torch.randn(batch_size, 5).to(self.device) for _ in range(num_iters) + ] + wd = 0.01 + lr = 0.01 + # Construct `optim1` with both parameter groups upfront + optim1 = ZeroRedundancyOptimizer( + [ + {"params": [l.weight for l in model1], "weight_decay": 0.}, + {"params": [l.bias for l in model1], "weight_decay": wd}, + ], + optimizer_class=AdamW, lr=lr, + ) + # Construct `optim2` by adding the second parameter after + optim2 = ZeroRedundancyOptimizer( + [l.weight for l in model2], + optimizer_class=AdamW, lr=lr, weight_decay=0., + ) + optim2.add_param_group( + {"params": [l.bias for l in model2], "weight_decay": wd} + ) + # Construct `optim3` as a non-sharded optimizer + optim3 = AdamW( + [ + {"params": [l.weight for l in model3], "weight_decay": 0.}, + {"params": [l.bias for l in model3], "weight_decay": wd}, + ], lr=lr, + ) + + # Check parity over a few iterations + for iter in range(num_iters): + for model, optim in ( + (model1, optim1), (model2, optim2), (model3, optim3), + ): + optim.zero_grad() + out = model(inputs[iter]) + loss = out.sum() + loss.backward() + optim.step() + + for layer1, layer2, layer3 in zip(model1, model2, model3): + assert torch.allclose(layer1.weight, layer2.weight) + assert torch.allclose(layer1.weight, layer3.weight) + assert torch.allclose(layer1.bias, layer2.bias) + assert torch.allclose(layer1.bias, layer3.bias) + @common_distributed.skip_if_lt_x_gpu(2) + @common_distributed.skip_if_rocm def test_collect_shards(self): """ Check the state consolidation mechanism, and the state dict exposed by ZeroRedundancyOptimizer""" self.dist_init(self.rank) diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py index 70779eac3f1..a87bfdaf5fd 100644 --- a/torch/distributed/optim/zero_redundancy_optimizer.py +++ b/torch/distributed/optim/zero_redundancy_optimizer.py @@ -10,7 +10,16 @@ import inspect import io import logging from itertools import chain -from typing import Any, Callable, Dict, List, Optional, Set, Type +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Set, + Type, + Union, +) import torch import torch.distributed as dist @@ -287,7 +296,8 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): Arguments: params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s - giving all parameters, which will be sharded across ranks. + or :class:`dict` s giving all parameters, which will be sharded + across ranks. Keyword Args: optimizer_class (:class:`torch.nn.Optimizer`): the class of the local @@ -364,7 +374,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): **defaults: Any, ): # Perform type and assumption checks on the input parameters - self._verify_and_init_params(params) + params = self._verify_and_init_params(params) self._verify_same_dense_param_type() # NOTE: The parent constructor uses `add_param_group()` which is @@ -373,7 +383,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): # between the parent and child. 
         self.initialized = False
 
-        Optimizer.__init__(self, self._all_params, defaults)
+        Optimizer.__init__(self, params, defaults)
         Joinable.__init__(self)
         # Now, all parameters are held in both `self._all_params` and
         # `self.param_groups`
@@ -1289,36 +1299,60 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
                 offset = offset_next
             bucket_assignment.tensor = tensor
 
-    def _verify_and_init_params(self, params: Any) -> None:
+    def _verify_and_init_params(
+        self, params: Any,
+    ) -> Union[List[torch.Tensor], List[dict]]:
         r"""
         Verifies the type of ``params`` and initializes ``self._all_params``
-        if ``params`` is valid.
+        as a :class:`list` of all parameters if ``params`` is valid.
 
-        While :class:`optim.Optimizer <torch.optim.Optimizer>` allows
-        ``params`` to be an iterable of :class:`dict` s, currently
-        ``ZeroRedundancyOptimizer`` strictly requires ``params`` to be an
-        iterable of :class:`torch.Tensor` s.
+        Arguments:
+            params (Any): Candidate parameter list or parameter groups to
+                verify.
 
         Raises:
             TypeError: ``params`` has an invalid type.
             ValueError: ``params`` is empty.
+
+        Returns:
+            The persistent form of ``params`` to be passed into the parent
+            :class:`Optimizer` constructor -- i.e. returns ``params`` as a
+            :class:`list` to ensure that it can be iterated over again.
         """
         if isinstance(params, torch.Tensor):
-            raise TypeError("params argument should be an iterable of "
+            raise TypeError("`params` argument should be an iterable of "
                             f"Tensors, but got {torch.typename(params)}")
         try:
-            self._all_params = list(params)
+            all_params = list(params)
         except TypeError:
-            raise TypeError("params argument should be an iterable of "
+            raise TypeError("`params` argument should be an iterable of "
                             f"Tensors, but got {torch.typename(params)}")
-        if len(self._all_params) == 0:
+        if len(all_params) == 0:
             raise ValueError("ZeroRedundancyOptimizer got an empty parameter "
                              "list")
-        for param in self._all_params:
-            if not isinstance(param, torch.Tensor):
-                raise TypeError("params argument should be an iterable of "
-                                "Tensors, but got an iterable containing "
-                                f"{torch.typename(param)}")
+        all_tensors = True
+        all_dicts = True
+        for param in all_params:
+            all_tensors &= isinstance(param, torch.Tensor)
+            all_dicts &= isinstance(param, dict)
+        if not all_tensors and not all_dicts:
+            raise TypeError("`params` argument should be an iterable of "
+                            "Tensors or dicts")
+        # Ensure that `self._all_params` contains a list of all parameters
+        if all_tensors:
+            self._all_params = all_params
+        elif all_dicts:
+            self._all_params = []
+            # `all_params` contains parameter groups (not parameters)
+            for param_group in all_params:
+                if "params" not in param_group:
+                    raise ValueError(
+                        "Each parameter group passed in via `params` must "
+                        "have a 'params' key mapping to the parameters in "
+                        "the group"
+                    )
+                self._all_params.extend(param_group["params"])
+        return all_params
 
     def _verify_same_dense_param_type(self) -> None:
         r"""
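For reference, a minimal sketch of the usage this patch enables, mirroring `test_constructor` and `test_multiple_param_groups` above. The two-rank `torchrun` launch and the `gloo` backend are illustrative assumptions, not part of the patch:

```python
# Illustrative sketch only: assumes launch via
# `torchrun --nproc_per_node=2 zero_param_groups_example.py`, so that the
# process-group environment variables are already set.
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.optim import AdamW

dist.init_process_group(backend="gloo")

model = torch.nn.Sequential(
    torch.nn.Linear(5, 10),
    torch.nn.Linear(10, 5),
)

# New in this patch: `params` may be an iterable of dicts (parameter groups),
# each with its own hyperparameters, as with a non-sharded optimizer.
optim = ZeroRedundancyOptimizer(
    [
        {"params": [layer.weight for layer in model], "weight_decay": 0.0},
        {"params": [layer.bias for layer in model], "weight_decay": 0.01},
    ],
    optimizer_class=AdamW,
    lr=0.01,
)

# Equivalently, start with one group and add another after construction.
optim2 = ZeroRedundancyOptimizer(
    [layer.weight for layer in model],
    optimizer_class=AdamW,
    lr=0.01,
    weight_decay=0.0,
)
optim2.add_param_group(
    {"params": [layer.bias for layer in model], "weight_decay": 0.01}
)

# One training step with the upfront-groups optimizer.
loss = model(torch.randn(8, 5)).sum()
loss.backward()
optim.step()

dist.destroy_process_group()
```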