.. _elastic_train_script:

Train script
-------------

If your train script works with ``torch.distributed.launch`` it will continue
working with ``torch.distributed.run`` with these differences:

1. No need to manually pass ``RANK``, ``WORLD_SIZE``,
   ``MASTER_ADDR``, and ``MASTER_PORT``.

2. ``rdzv_backend`` and ``rdzv_endpoint`` can be provided. For most users
   this will be set to ``c10d`` (see `rendezvous <rendezvous.html>`_). The default
   ``rdzv_backend`` creates a non-elastic rendezvous where ``rdzv_endpoint`` holds
   the master address.

3. Make sure you have ``load_checkpoint(path)`` and ``save_checkpoint(path)``
   logic in your script. When any number of workers fail, we restart all the
   workers with the same program arguments, so you will lose progress up to the
   most recent checkpoint (see `elastic launch <distributed.html>`_ and the
   checkpoint sketch after this list).

4. The ``use_env`` flag has been removed. If you were reading the local rank from
   the ``--local_rank`` command-line option, you now need to get it from the
   environment variable ``LOCAL_RANK`` (e.g. ``int(os.environ["LOCAL_RANK"])``),
   as shown in the sketch after this list.

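For illustration, a minimal sketch of reading the automatically exported
variables inside a worker process; which of them a script actually needs
depends on how it initializes the process group (the ``gloo`` backend here is
only an example):

.. code-block:: python

    import os

    import torch.distributed as dist

    # torch.distributed.run exports these for every worker process it spawns.
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # MASTER_ADDR and MASTER_PORT are also set, so the default "env://"
    # init method needs no explicit address arguments here.
    dist.init_process_group(backend="gloo")
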
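The checkpoint helpers are not part of ``torch.distributed``; every script
supplies its own. Below is a minimal sketch of the pattern, assuming the
training state fits in a single picklable object (the state layout and
function signatures are illustrative and differ slightly from the expository
example that follows):

.. code-block:: python

    import os

    import torch


    def save_checkpoint(state, path):
        # Persist the full training state; torch.save serializes tensors,
        # state dicts, and plain Python values alike.
        torch.save(state, path)


    def load_checkpoint(path):
        # First run: no checkpoint yet, start from epoch 0.
        if not os.path.exists(path):
            return {"epoch": 0}
        # Resume from the most recent checkpoint written by save_checkpoint.
        return torch.load(path)
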
Below is an expository example of a training script that checkpoints on each
epoch, so the worst-case progress lost on failure is one full epoch's worth
of training.

.. code-block:: python

    import sys

    import torch.distributed


    def main():
        args = parse_args(sys.argv[1:])
        state = load_checkpoint(args.checkpoint_path)
        initialize(state)

        # torch.distributed.run ensures that this will work
        # by exporting all the env vars needed to initialize the process group
        torch.distributed.init_process_group(backend=args.backend)

        for _ in range(state.epoch, state.total_num_epochs):
            for batch in iter(state.dataset):
                train(batch, state.model)

            state.epoch += 1
            save_checkpoint(state)

For concrete examples of torchelastic-compliant train scripts, visit
our `examples <examples.html>`_ page.