DDP Communication Hooks
=======================

DDP communication hook is a generic interface to control how to communicate
gradients across workers by overriding the vanilla allreduce in
`DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.>`_.
A few built-in communication hooks are provided,
and users can easily apply any of these hooks to optimize communication.
In addition, the hook interface supports user-defined communication
strategies for more advanced use cases.

How to Use a Communication Hook?
--------------------------------

To use a communication hook, the user just needs to register the hook on the
DDP model before the training loop, as shown below.

:func:`torch.nn.parallel.DistributedDataParallel.register_comm_hook`
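
For instance, a minimal sketch of registering one of the built-in hooks
(assuming the default process group is already initialized and that ``model``
and ``rank`` are defined as in the end-to-end example at the bottom of this page):

::

    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
    from torch.nn.parallel import DistributedDataParallel

    ddp_model = DistributedDataParallel(model, device_ids=[rank])

    # The built-in default hooks are stateless, so ``state`` can simply be ``None``.
    # ``fp16_compress_hook`` compresses gradients to FP16 before the allreduce.
    ddp_model.register_comm_hook(state=None, hook=default.fp16_compress_hook)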

What Does a Communication Hook Operate On?
------------------------------------------

A communication hook provides a flexible way to allreduce gradients.
Therefore, it mainly operates on the gradients on each replica before allreduce,
which are bucketized to increase the overlap between communication and computation.
Particularly, :class:`torch.distributed.GradBucket` represents a bucket of gradient tensors to be allreduced.

.. autoclass:: torch.distributed.GradBucket

.. autofunction:: torch.distributed.GradBucket.index
.. autofunction:: torch.distributed.GradBucket.buffer
.. autofunction:: torch.distributed.GradBucket.gradients
.. autofunction:: torch.distributed.GradBucket.is_last
.. autofunction:: torch.distributed.GradBucket.set_buffer
.. autofunction:: torch.distributed.GradBucket.parameters
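
For illustration, a user-defined hook receives the hook state and a
:class:`torch.distributed.GradBucket` and returns a
``torch.futures.Future[torch.Tensor]``. The sketch below (not one of the
built-in hooks) averages the bucket's flattened gradients with an asynchronous
allreduce, mirroring what the vanilla allreduce does:

::

    import torch
    import torch.distributed as dist

    def example_allreduce_hook(state, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
        # ``bucket.buffer()`` is the flattened 1D tensor holding this bucket's gradients.
        tensor = bucket.buffer()
        group_to_use = state if state is not None else dist.group.WORLD
        # Pre-divide so that the summing allreduce produces an average.
        tensor.div_(group_to_use.size())
        fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
        # DDP copies the tensor wrapped by the returned future back into the bucket.
        return fut.then(lambda f: f.value()[0])

Such a hook is then registered like any built-in hook, e.g.
``ddp_model.register_comm_hook(state=None, hook=example_allreduce_hook)``.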

Default Communication Hooks
---------------------------

Default communication hooks are simple **stateless** hooks, so the input state
in ``register_comm_hook`` is either a process group or ``None``.
The input ``bucket`` is a :class:`torch.distributed.GradBucket` object.

.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks
.. autofunction:: allreduce_hook
.. autofunction:: fp16_compress_hook
.. autofunction:: bf16_compress_hook
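
As a sketch, a process group can be passed as the hook state when DDP itself
was built over that group (``model`` and ``rank`` are assumed to be defined as
in the end-to-end example at the bottom of this page):

::

    import torch.distributed as dist
    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
    from torch.nn.parallel import DistributedDataParallel

    # Hypothetical group containing every rank; any process group that DDP was
    # constructed with can be used instead.
    pg = dist.new_group(ranks=list(range(dist.get_world_size())))
    ddp_model = DistributedDataParallel(model, device_ids=[rank], process_group=pg)
    # Passing the group as the state makes the default hook allreduce over it;
    # passing ``None`` falls back to the default (world) process group.
    ddp_model.register_comm_hook(state=pg, hook=default.allreduce_hook)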

Additionally, a communication hook wrapper is provided to support
:meth:`~fp16_compress_hook` or :meth:`~bf16_compress_hook` as a wrapper
that can be combined with other communication hooks.

.. autofunction:: fp16_compress_wrapper
.. autofunction:: bf16_compress_wrapper
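
For example, a sketch of combining FP16 compression with the PowerSGD hook
described in the next section (``ddp_model`` is assumed to be an existing
``DistributedDataParallel`` instance):

::

    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
    from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

    state = powerSGD.PowerSGDState(process_group=None, matrix_approximation_rank=1)
    # The wrapper casts the bucket to FP16 before running the wrapped hook and
    # casts the hook's output back to the original dtype afterwards.
    ddp_model.register_comm_hook(state, default.fp16_compress_wrapper(powerSGD.powerSGD_hook))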

PowerSGD Communication Hook
---------------------------

PowerSGD (`Vogels et al., NeurIPS 2019 <https://arxiv.org/abs/1905.13727>`_)
is a gradient compression algorithm, which can provide very high compression
rates and accelerate bandwidth-bound distributed training.
This algorithm needs to maintain both some hyperparameters and the internal
state. Therefore, the PowerSGD communication hook is a **stateful** hook,
and the user needs to provide a state object defined as below.

PowerSGD State
^^^^^^^^^^^^^^^^

.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook
.. autoclass:: PowerSGDState

PowerSGD Hooks
^^^^^^^^^^^^^^^^

.. warning ::
    PowerSGD typically requires extra memory of the same size as the model's
    gradients to enable error feedback, which can compensate for biased
    compressed communication and improve accuracy.

.. warning ::
    PowerSGD hooks may conflict with `Apex automatic mixed precision package <https://github.com/NVIDIA/apex>`_.
    Please use PyTorch `native automatic mixed precision package <https://pytorch.org/docs/stable/amp.html>`_
    instead.

.. autofunction:: powerSGD_hook
.. autofunction:: batched_powerSGD_hook
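
As a sketch, a typical way to construct the state and register the hook (the
hyperparameter values below are purely illustrative, and ``ddp_model`` is
assumed to be an existing ``DistributedDataParallel`` instance):

::

    from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

    state = powerSGD.PowerSGDState(
        process_group=None,           # use the default (world) process group
        matrix_approximation_rank=1,  # lower rank means higher compression
        start_powerSGD_iter=1_000,    # run vanilla allreduce for the first 1K iterations
    )
    ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)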

Debugging Communication Hooks
-----------------------------

As the name implies, debugging communication hooks are **only** used for debugging and performance optimization purposes.

.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks

.. warning ::
    Debugging communication hooks do not necessarily output the correct results.

.. autofunction:: noop_hook
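
For example, a sketch of registering the no-op hook to estimate how much time
gradient communication costs (``ddp_model`` is assumed to exist):

::

    from torch.distributed.algorithms.ddp_comm_hooks import debugging_hooks as debugging

    # Skips gradient communication entirely, so the resulting model is NOT
    # correctly synchronized; useful only for headroom analysis of allreduce.
    ddp_model.register_comm_hook(state=None, hook=debugging.noop_hook)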

Checkpointing of Communication Hooks
------------------------------------

.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook

A stateful communication hook can be saved as a part of model checkpointing to enable trainer restarts.
To make a hook serializable, ``__setstate__`` and ``__getstate__`` should be defined.

.. warning ::
    ``__getstate__`` should exclude non-serializable attributes from a returned dictionary.

.. warning ::
    ``__setstate__`` should properly initialize non-serializable attributes that were excluded from the provided ``state``.

:class:`PowerSGDState` has ``__setstate__`` and ``__getstate__`` implemented and can be used as a reference.

.. class:: PowerSGDState
    :noindex:

.. automethod:: PowerSGDState.__getstate__
.. automethod:: PowerSGDState.__setstate__
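
For a custom stateful hook, a minimal sketch of such a state class might look
as follows (the class and its attributes are hypothetical, not part of PyTorch):

::

    import torch.distributed as dist

    class MyHookState:
        def __init__(self, process_group, compression_ratio=0.5):
            self.process_group = process_group
            self.compression_ratio = compression_ratio

        def __getstate__(self):
            # Exclude the non-serializable process group from the pickled state.
            state = self.__dict__.copy()
            del state["process_group"]
            return state

        def __setstate__(self, state):
            self.__dict__.update(state)
            # Re-initialize the attribute that was excluded from the pickled state.
            self.process_group = dist.group.WORLD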

Here is a simple, end-to-end example of saving and reloading PowerSGD state and hook.

::

    import os
    import sys
    import tempfile
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    import torch.optim as optim

    from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
    from torch.nn.parallel import DistributedDataParallel

    class SimpleModel(nn.Module):
        def __init__(self):
            super(SimpleModel, self).__init__()
            self.fc1 = nn.Linear(24, 24)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(24, 12)

        def forward(self, x):
            return self.fc2(self.relu(self.fc1(x)))

    def setup(rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        # initialize the process group
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    def cleanup():
        dist.destroy_process_group()

    def run_demo(demo_fn, world_size):
        mp.spawn(
            demo_fn,
            args=(world_size,),
            nprocs=world_size,
            join=True)

    def demo_serialization(rank, world_size):
        setup(rank, world_size)

        CHECKPOINT = tempfile.gettempdir() + "/checkpoint.pt"

        model = SimpleModel().to(rank)
        ddp_model = DistributedDataParallel(model, device_ids=[rank])

        powersgd_hook = powerSGD.powerSGD_hook
        powersgd_state = powerSGD.PowerSGDState(process_group=None)

        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
        ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

        # Save the hook and its state alongside the model weights.
        state = {
            'state_dict': ddp_model.state_dict(),
            'comm_hook': powersgd_hook,
            'comm_hook_state': powersgd_state}

        if rank == 0:
            torch.save(state, CHECKPOINT)

        dist.barrier()
        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
        checkpoint = torch.load(CHECKPOINT, map_location=map_location)

        # ``register_comm_hook`` can only be called once per DDP model, so the
        # restored hook is registered on a freshly constructed model, just as a
        # restarted trainer would do.
        new_ddp_model = DistributedDataParallel(
            SimpleModel().to(rank), device_ids=[rank])
        new_ddp_model.load_state_dict(checkpoint['state_dict'])
        powersgd_hook = checkpoint['comm_hook']
        powersgd_state = checkpoint['comm_hook_state']

        new_ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

        if rank == 0:
            os.remove(CHECKPOINT)

        cleanup()

    if __name__ == "__main__":
        n_gpus = torch.cuda.device_count()
        assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
        world_size = n_gpus
        run_demo(demo_serialization, world_size)

Acknowledgements
----------------

Many thanks to PowerSGD paper author **Thijs Vogels** for the code review on
the PowerSGD communication hook, as well as the
`comparison experiments <https://observablehq.com/@tvogels/powersgd-benchmark>`_,
which show that the performance of the PowerSGD communication hook is on par with
the implementation in the original `paper <https://arxiv.org/abs/1905.13727>`_.