pytorch/torch/distributed/optim/zero_redundancy_optimizer.py
Benjamin Lefaudeux bbe18e3527 [ZeroRedundancyOptimizer] Elastic and pytorch compatible checkpoints (#50956)
Summary:
- Makes it possible to use non-sharded optimizer checkpoints (as long as the model/param groups are the same, of course)
- Makes it possible to save with a given world size, and load with another world size
- Use the Torch Distributed built-in broadcast object list instead of an ad-hoc version
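
A minimal sketch of the intended checkpoint flow (a hypothetical illustration; `CHECKPOINT_PATH` and the surrounding setup are placeholders, not part of this change):

    # All ranks: gather the sharded state onto the reference rank
    optimizer.consolidate_state_dict(recipient_rank=0)

    # Reference rank only: the consolidated, PyTorch-compatible state can be saved
    if dist.get_rank() == 0:
        torch.save(optimizer.state_dict(), CHECKPOINT_PATH)

    # Later, possibly with a different world size, or starting from a plain
    # (non-sharded) optimizer checkpoint: every rank loads the full state dict
    # and keeps only the shard it owns
    state = torch.load(CHECKPOINT_PATH, map_location="cpu")
    optimizer.load_state_dict(state)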

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50956

Reviewed By: malfet

Differential Revision: D26113953

Pulled By: blefaudeux

fbshipit-source-id: 030bfeee2c34c2d987590d45dc8efe05515f2e5c
2021-02-02 14:32:13 -08:00


# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict, deque
import copy
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Type, Deque
import torch
import torch.distributed as dist
from torch.nn import Parameter
from torch._six import container_abcs
from torch.optim import Optimizer
__all__ = ["ZeroRedundancyOptimizer"]
# Credits: classy_vision/generic/distributed_util.py
def _recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
"""
Recursively searches lists, tuples, dicts and copies tensors to device if
possible. Non-tensor values are passed as-is in the result.
.. note: These are all copies, so if two entries reference the same tensor,
the result will contain two distinct copies on the target device.
"""
if isinstance(value, torch.Tensor):
return value.to(device, non_blocking=non_blocking)
if isinstance(value, (list, tuple)):
values = [_recursive_copy_to_device(val, non_blocking=non_blocking, device=device) for val in value]
return values if isinstance(value, list) else tuple(values)
if isinstance(value, container_abcs.Mapping):
return {
key: _recursive_copy_to_device(val, non_blocking=non_blocking, device=device) for key, val in value.items()
}
return value
def _get_global_rank(group: Any, rank: int) -> int:
return rank if group is dist.group.WORLD else dist.distributed_c10d._get_global_rank(group, rank) # type: ignore
class ZeroRedundancyOptimizer(Optimizer):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
optimizer and shards its state as described by ZeRO_.
::
opt = ZeroRedundancyOptimizer(params, optim=torch.optim.Adam, lr=0.01)
We use a greedy algorithm to pack a number of parameters on each rank.
Each parameter belongs to a single rank and is not divided among ranks.
The partition is arbitrary and does not correspond to the information flow, for instance.
After each rank has completed its parameter update, it broadcasts
the new version of its parameters to all other ranks, so that all
parameters are synchronized for the next forward/backward pass.
Arguments:
params (list of tensors):
parameters to be optimized
Keyword Args:
optim (torch.optim.Optimizer): optimizer class to shard
group (group): torch.distributed group (default: group.WORLD)
bucket_cap_kb (int): the size of the buffer used to batch the small parameter tensors,
in number of elements (default: 2 ** 24, i.e. about 16M elements)
**default: all trailing arguments will be forwarded to the requested optimizer
.. warning: ZeroRedundancyOptimizer is experimental and subject to change.
.. _ZeRO: https://arxiv.org/abs/1910.02054
Example::
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch import optim
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>>
>>> # Problem statement
>>> model = myAwesomeModel().to(rank)
>>> model = DDP(model, device_ids=[rank])
>>> dataloader = mySuperFastDataloader()
>>> loss_fn = myVeryRelevantLoss()
>>>
>>> # optimizer specific arguments e.g. LR, momentum, etc...
>>> base_optimizer_arguments = { "lr": 1e-4, **smart_options}
>>> optimizer = ZeroRedundancyOptimizer(
>>> params=model.parameters(),
>>> optim=optim.AdamW,
>>> **base_optimizer_arguments)
>>>
>>> # Any relevant training loop, almost nothing specific to ZeroRedundancyOptimizer
>>> reference_rank = 0 # This rank will be able to checkpoint
>>> model.train()
>>> for e in range(epochs):
>>> for (data, target) in dataloader:
>>> data, target = data.to(rank), target.to(rank)
>>>
>>> # Train
>>> model.zero_grad()
>>> outputs = model(data)
>>> loss = loss_fn(outputs, target)
>>> loss.backward()
>>> optimizer.step()
>>>
>>> ...
>>> # WARNING: checkpointing has some specific requirements:
>>> # - all ranks: consolidation is needed before any one rank can save
>>> optimizer.consolidate_state_dict(reference_rank)
>>>
>>> # - reference rank: the state can be saved
>>> if rank == reference_rank:
>>> checkpoint = optimizer.state_dict()
>>> ...
>>>
>>> # - all ranks: load a checkpoint, can be a checkpoint from normal PyTorch
>>> optimizer.load_state_dict(a_normal_checkpointed_state)
"""
def __init__(
self,
params,
optim: Type[Optimizer],
group: Optional[Any] = None,
bucket_cap_kb: int = 2 ** 24,
**default: Any,
):
# Hold all the model params in the root .param_groups
# NOTE: the default constructor uses `add_param_group`, which is partially overloaded here.
# We introduce the `initialized` flag to be able to distinguish the behaviour of
# `add_param_group` between the super() constructor and ZeroRedundancyOptimizer
self.initialized = False
super().__init__(params, default)
# Partition information. Lazily evaluated, computed when requested
self._per_device_params: OrderedDict[
torch.device, List[List[Parameter]]
] = OrderedDict() # device, rank, params
self._param_rank: Dict[torch.Tensor, int] = {}
self._partition_parameters: List[List[Dict]] = []
# Build the wrapped optimizer, responsible for a shard of the params
self.group = group if group is not None else dist.group.WORLD
self.world_size = dist.get_world_size(self.group)
self.rank = dist.get_rank(self.group)
self.global_rank = _get_global_rank(self.group, self.rank)
self.optim = optim(self.partition_parameters()[self.rank], **default)
# - Sync local and global param_groups keys
for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
for k, v in local_group.items():
if k != "params":
global_group[k] = v
# Optional consolidated optimizer state
self._all_states: List[Dict[str, Any]] = []
# Current default device is set by the parameters allocated to this rank
self._device = list(self.per_device_params.keys())[0]
self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
self.bucket_max_size = bucket_cap_kb
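# One flag per parameter, following the walk order of `per_device_params`:
# True if the parameter is packed into a bucket, False if it is broadcast on its own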
self.should_bucket_param: List[bool] = []
self.work_handles: Deque[Any] = deque()
self._setup_bucket_strategy()
self.initialized = True
def add_param_group(self, param_group: dict) -> None:
"""Add a param group to the :class:`Optimizer` s `param_groups`.
This can be useful when fine tuning a pre-trained network as frozen layers can be made
trainable and added to the :class:`Optimizer` as training progresses.
Arguments:
param_group (dict): Specifies what Tensors should be optimized along with group
specific optimization options
.. warning: This handles updating the shards on all partitions, but needs to be called on all ranks.
Calling this on a subset of the ranks will cause the training to hang, because communication primitives
are called depending on the managed parameters, and expect all the ranks to participate.
"""
super().add_param_group(param_group)
if self.initialized:
# Force a re-partitioning
self._partition_parameters.clear()
self._per_device_params.clear()
self._param_rank.clear()
param_groups = self.partition_parameters()[self.rank]
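# If repartitioning exposed one extra shard-local group for this rank, hand it to the wrapped optimizer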
if len(param_groups) == len(self.optim.param_groups) + 1:
self.optim.add_param_group(param_groups[-1])
# Update the bucketing strategy accordingly
self._setup_bucket_strategy()
def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
"""Update the consolidated state_dict list, one per rank.
.. warning: This needs to be called on all replicas"""
# Sync lr and other attributes in case they have been updated
self._update_param_groups()
empty_messenger = torch.tensor([0], dtype=torch.uint8, device=self._device)
# Pull the sharded state from all the other replicas
# Store all the states in order, rank by rank
# NOTE: In practice `broadcast` is used, which is wasteful (a gather would have been more appropriate),
# but compatibility issues with some backends make the use of broadcast mandatory for now.
# A possible follow-up would be to move all sharded state management to RPC RRef
self._all_states = []
for rank in range(self.world_size):
global_rank = _get_global_rank(self.group, rank)
# This rank collects the whole state
if self.rank == recipient_rank:
if rank == self.rank:
self._all_states.append(
_recursive_copy_to_device(
self.optim.state_dict(), non_blocking=True, device=torch.device("cpu")
)
)
else:
# Fetch the optim state from the other replicas
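# The placeholder below is overwritten in place by `broadcast_object_list` with the source rank's state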
replica_state = [empty_messenger]
dist.broadcast_object_list(object_list=replica_state, src=global_rank, group=self.group)
self._all_states.append(
_recursive_copy_to_device(replica_state[0], non_blocking=True, device=torch.device("cpu"))
)
else:
# Acknowledge broadcasts, and send this rank's shard when needed
# Default to CPU space to gain some memory headroom
if rank == self.rank:
# Send the state to the reference replica
dist.broadcast_object_list(
object_list=[self.optim.state_dict()], src=self.global_rank, group=self.group
)
elif rank != recipient_rank:
# Discard this tensor/rank; broadcast is only used here for compatibility reasons
dist.broadcast_object_list(object_list=[empty_messenger], src=global_rank, group=self.group)
def partition_parameters(self) -> List[List[Dict]]:
"""Partitions parameters across distributed data parallel ranks.
Returns: a list of ``param_groups`` (which is a list of dict) where each
element of the list contains the param_groups for a rank. Element 0
corresponds to rank 0, etc. We need all the ranks for the broadcast
inside ``step()``.
"""
if len(self._partition_parameters) == 0:
self._partition_parameters = [list() for _ in range(self.world_size)]
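# Running total of the number of elements assigned to each rank, used for greedy bin packing below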
sizes = [0] * self.world_size
for param_group in self.param_groups:
param_lists: List[List] = [list() for _ in range(self.world_size)]
for param in param_group["params"]:
# Add this param to rank with smallest size.
rank = sizes.index(min(sizes))
param_lists[rank].append(param)
sizes[rank] += param.numel()
for rank, params in enumerate(param_lists):
param_group_rank = copy.copy(param_group)
param_group_rank["params"] = params
self._partition_parameters[rank].append(param_group_rank)
return self._partition_parameters
@property
def per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]:
"""Sorted list of all the params, first per device then per rank.
Within each list, params are sorted by number of elements to allow for easy bucketing.
"""
if len(self._per_device_params) == 0:
# Go through all params, log them per device
# The ordering is important here: it needs to be the same on all ranks,
# so that subsequent broadcast calls match up
for param_group in self.param_groups:
for param in param_group["params"]:
device = param.device
if self._per_device_params.get(device) is None:
self._per_device_params[device] = [[] for _ in range(self.world_size)]
self._per_device_params[device][self.param_to_rank[param]] += [param]
# Sort param_lists by size
for k in self._per_device_params.keys():
for r in self._per_device_params[k]:
r.sort(key=lambda x: x.numel())
return self._per_device_params
@property
def param_to_rank(self) -> Dict[torch.Tensor, int]:
"""Look up table to match a given param with a data parallel rank"""
if len(self._param_rank) == 0:
for rank, param_groups in enumerate(self.partition_parameters()):
for param_group in param_groups:
for param in param_group["params"]:
self._param_rank[param] = rank
return self._param_rank
def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
"""Performs a single optimization step (parameter update).
Arguments:
closure (callable): A closure that reevaluates the model and
returns the loss. Optional for most optimizers.
Returns:
optional loss, depends on the underlying optimizer
.. note: Any extra parameter is passed to the base optimizer as-is"""
# Sync oss param_groups attributes in case they've been updated by a scheduler.
self._update_param_groups()
# Run the optimizer step on this shard only:
if closure is not None:
loss = self.optim.step(closure=closure, **kwargs) # type: ignore
else:
loss = self.optim.step(**kwargs)
# Sync all the updated shards in between the ranks
self._broadcast_params()
# Sync any hypothetical new results from the wrapped optimizer to the exposed param_groups
self._update_param_groups(local_to_global=True)
return loss
def state_dict(self) -> Dict[str, Any]:
"""Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
sharded properties are not exposed. It contains two entries:
* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a dict containing all parameter groups
.. warning:
If the state has not been consolidated, this returns a shard's worth, not the global state.
.. warning:
Returning the global state is limited to the replica which was responsible for the consolidation.
The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
"""
if len(self._all_states) == 0:
raise RuntimeError(
"Optimizer state has not been consolidated on this rank. \
Please call `consolidate_state_dict()` on all ranks beforehand if you meant to save the global state"
)
# Unify the shard states and the state that PyTorch would expect, given the model.
# Indexing needs several redirections, since each shard only knows a limited scope of the model
# - get the pytorch compliant parameter indexing
state_dict = super().state_dict()
# - get an id map which links the parameter id to the index in the reference state
global_id_map = {id(p): i for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}
# - go through the per-shard states, which are all indexed locally
for rank, s in enumerate(self._all_states):
# -- match the local indexing and the global partition, update the corresponding saved state globally
for local_pg, global_pg in zip(s["param_groups"], self.partition_parameters()[rank]):
# Go through the parameters indexed locally, pick up the global corresponding param
local_index_to_param_id = {
i_param: id(global_pg["params"][i]) for i, i_param in enumerate(local_pg["params"])
}
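# Re-key each locally-indexed state entry with the matching global (PyTorch-style) parameter index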
for local_param_index in local_pg["params"]:
# Update the state, if any
if local_param_index in s["state"].keys():
global_id = global_id_map[local_index_to_param_id[local_param_index]]
state_dict["state"][global_id] = s["state"][local_param_index]
return state_dict
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Restore the global parameter groups as well as the shard.
Arguments:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`
.. note: all the parameters present in the loaded state dict will be exposed by this optimizer
through the `param_groups` attribute, but the actual parameter update computations will be spread
across the ranks. The per-rank wrapped optimizer will thus only effectively work on a subset
of the parameters.
"""
# Param index to param map
index_to_param = {i: p for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}
# Prune the state dict of the states which this rank does not own, then do a normal base load
for i_param, key in enumerate(state_dict["state"].keys()):
# Check that this rank owns this param, if not remove from the state
if self.param_to_rank[index_to_param[i_param]] != self.rank:
state_dict["state"][key] = None
super().load_state_dict(state_dict)
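# Populate the wrapped (sharded) optimizer's state with the entries for the params this rank owns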
# Set the sharded optimizer state.
# Keep the original type (not respected by PyTorch which casts to the model type)
for k, (_, v) in enumerate(state_dict["state"].items()):
if k in index_to_param:
param = index_to_param[k]
# Only add this state to the sharded optimizer if it owns this param
for pg in self.optim.param_groups:
if id(param) in map(lambda x: id(x), pg["params"]):
self.optim.state[param] = _recursive_copy_to_device(v, non_blocking=True, device=param.device)
# Update the param_group keys
for new_pg, pg in zip(state_dict["param_groups"], self.param_groups):
for key in new_pg.keys():
if key != "params":
pg[key] = new_pg[key]
# Sync with the optimizer param groups
self._update_param_groups(local_to_global=False)
def _broadcast_params(self) -> None:
"""Helper function to broadcast all the parameters from a given device"""
i_param = 0
for (
device,
device_params,
) in self.per_device_params.items(): # all the params on this device (inc all ranks)
buckets = self.buckets[device]
# Bucket and issue all the async calls
for (src_rank, params), bucket in zip(enumerate(device_params), buckets):
global_src_rank = _get_global_rank(self.group, src_rank)
# Direct broadcasts only
for param in params:
if not self.should_bucket_param[i_param]:
self.work_handles.append(
dist.broadcast(tensor=param.data, src=global_src_rank, group=self.group, async_op=True)
)
i_param += 1
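# Params tagged for bucketing are views into `bucket`, so the single broadcast below covers all of them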
# Bucket broadcasts
self.work_handles.append(
dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True)
)
# Consume all async calls
while len(self.work_handles) > 0:
work_handle = self.work_handles.popleft()
work_handle.wait()
def _update_param_groups(self, local_to_global: bool = False) -> None:
"""Sync learning rate and other optimizer attributes (needed to support schedulers).
If the global param groups have been altered, make sure
that the wrapped optimizer uses the up-to-date version. Conversely, if the wrapped
optimizer has new keys, expose them through the global param groups
"""
for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
# Sync everything but the parameters
for k in filter(lambda x: x != "params", local_group.keys()):
if local_to_global:
global_group[k] = local_group[k]
elif k in global_group.keys():
local_group[k] = global_group[k]
def _setup_bucket_strategy(self) -> None:
"""Tag parameters to either bucket them or broadcast/reduce them directly.
The parameters are ordered (smallest first); the bucket will hold the smallest elements,
and the remaining ones will be sent directly.
Generating the partition once and for all allows us to save some time at runtime, and to know when all the
network requests have been issued. The parameters which are part of a bucket become tensor views.
"""
# Allocate one buffer per rank and per device to group the small parameters
for device, per_device in self.per_device_params.items():
self.buckets[device] = [
torch.zeros(self.bucket_max_size, dtype=per_device[0][0].dtype, device=device)
for _ in range(len(per_device))
]
# Pack the smallest elements in a bucket, depending on their owner shard-wise
for device, per_rank_params in self.per_device_params.items():
for dst_rank, params in enumerate(per_rank_params):
offset = 0
for param in params:
# Criteria to decide whether this parameter is to be bucketed or not:
# - enough room in the bucket
if param.requires_grad and (offset + param.numel()) < self.bucket_max_size:
self.should_bucket_param.append(True)
# This parameter becomes a view of the bucket
offset_next = offset + param.numel()
self.buckets[device][dst_rank][offset:offset_next] = param.data.flatten()
param.data = self.buckets[device][dst_rank][offset:offset_next].view_as(param.data)
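# From now on this param shares storage with the bucket, so broadcasting the bucket updates it in place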
offset = offset_next
else:
self.should_bucket_param.append(False)
# Resize the bucket to remove unused space at the end
self.buckets[device][dst_rank].resize_(offset)