Introduce the torchrun entrypoint (#64049)

Summary:
This PR introduces a new `torchrun` entrypoint that simply "points" to `python -m torch.distributed.run`. It is shorter and less error-prone to type, and it reads better than the rather cryptic `python -m ...` command line. Along with the new entrypoint, the documentation is updated so that mentions of `torch.distributed.run` are replaced with `torchrun`.
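Concretely, the new entry point is just a thin wrapper. A minimal sketch of what the installed `torchrun` script effectively does is shown below; the file actually generated by pip/setuptools differs in detail, but it calls the same `main()` that `python -m torch.distributed.run` invokes.

```python
# Rough sketch of the generated `torchrun` console script (illustrative only;
# the real wrapper emitted by pip/setuptools looks slightly different).
import sys

from torch.distributed.run import main  # entry point target: torch.distributed.run:main

if __name__ == "__main__":
    sys.exit(main())
```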

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse agolynski SciPioneer H-Huang mrzzd cbalioglu gcramer23

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64049

Reviewed By: cbalioglu

Differential Revision: D30584041

Pulled By: kiukchung

fbshipit-source-id: d99db3b5d12e7bf9676bab70e680d4b88031ae2d
Can Balioglu 2021-08-26 20:16:10 -07:00 committed by Facebook GitHub Bot
parent 510d2ece81
commit 65e6194aeb
6 changed files with 45 additions and 44 deletions

View File

@@ -5,7 +5,7 @@ To launch a **fault-tolerant** job, run the following on all nodes.
.. code-block:: bash
python -m torch.distributed.run
torchrun
--nnodes=NUM_NODES
--nproc_per_node=TRAINERS_PER_NODE
--rdzv_id=JOB_ID
@@ -19,7 +19,7 @@ and at most ``MAX_SIZE`` nodes.
.. code-block:: bash
python -m torch.distributed.run
torchrun
--nnodes=MIN_SIZE:MAX_SIZE
--nproc_per_node=TRAINERS_PER_NODE
--rdzv_id=JOB_ID
@@ -46,6 +46,6 @@ ideally you should pick a node that has a high bandwidth.
Learn more about writing your distributed training script
`here <train_script.html>`_.
If ``torch.distributed.run`` does not meet your requirements you may use our
APIs directly for more powerful customization. Start by taking a look at the
`elastic agent <agent.html>`_ API).
If ``torchrun`` does not meet your requirements you may use our APIs directly
for more powerful customization. Start by taking a look at the
`elastic agent <agent.html>`_ API.
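For completeness, a hedged sketch of the programmatic launcher that sits underneath ``torchrun`` follows; the ``trainer`` function, run id, and rendezvous endpoint are placeholders, and the ``LaunchConfig`` field names should be verified against the installed release.

```python
# Hedged sketch of launching workers programmatically instead of via torchrun.
# LaunchConfig fields and defaults may differ across versions; check
# torch.distributed.launcher.api in your installation.
from torch.distributed.launcher.api import LaunchConfig, elastic_launch


def trainer(step_count: int) -> int:
    # Placeholder training function; runs once per local worker.
    return step_count


if __name__ == "__main__":
    config = LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=2,            # two local workers
        run_id="example_job",        # placeholder job id
        rdzv_backend="c10d",
        rdzv_endpoint="localhost:29400",
    )
    # Returns a dict mapping worker rank -> return value of `trainer`.
    results = elastic_launch(config, trainer)(10)
    print(results)
```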

View File

@@ -1,6 +1,6 @@
.. _launcher-api:
torch.distributed.run (Elastic Launch)
torchrun (Elastic Launch)
======================================
.. automodule:: torch.distributed.run

View File

@@ -4,7 +4,7 @@ Train script
-------------
If your train script works with ``torch.distributed.launch`` it will continue
working with ``torch.distributed.run`` with these differences:
working with ``torchrun`` with these differences:
1. No need to manually pass ``RANK``, ``WORLD_SIZE``,
``MASTER_ADDR``, and ``MASTER_PORT``.
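As a hedged illustration of that point (not taken from the changed files), a script launched by ``torchrun`` can rely on those variables already being set and initialize its process group straight from the environment:

```python
# Minimal sketch of a torchrun-launched script: RANK, WORLD_SIZE, MASTER_ADDR
# and MASTER_PORT are exported by the launcher, so the env:// init method
# picks them up without any command-line plumbing.
import torch.distributed as dist

if __name__ == "__main__":
    dist.init_process_group(backend="gloo", init_method="env://")
    print(f"rank {dist.get_rank()} of {dist.get_world_size()} is up")
    dist.destroy_process_group()
```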

View File

@@ -854,6 +854,7 @@ def configure_extension_build():
'console_scripts': [
'convert-caffe2-to-onnx = caffe2.python.onnx.bin.conversion:caffe2_to_onnx',
'convert-onnx-to-caffe2 = caffe2.python.onnx.bin.conversion:onnx_to_caffe2',
'torchrun = torch.distributed.run:main',
]
}
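As a hypothetical post-install sanity check (not part of this change), the new entry point can be inspected from Python; the ``group=`` filter requires Python 3.10+ (or the ``importlib_metadata`` backport).

```python
# Hypothetical check that the installed distribution exposes the new console
# script and that it maps to torch.distributed.run:main.
from importlib.metadata import entry_points

for ep in entry_points(group="console_scripts"):
    if ep.name == "torchrun":
        print(ep.value)  # expected: torch.distributed.run:main
```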

View File

@@ -4,7 +4,7 @@ training processes on each of the training nodes.
.. warning::
This module is going to be deprecated in favor of :ref:`torch.distributed.run <launcher-api>`.
This module is going to be deprecated in favor of :ref:`torchrun <launcher-api>`.
The utility can be used for single-node distributed training, in which one or
more processes per node will be spawned. The utility can be used for either
@@ -177,8 +177,8 @@ def launch(args):
def main(args=None):
warnings.warn(
"The module torch.distributed.launch is deprecated\n"
"and will be removed in future. Use torch.distributed.run.\n"
"Note that --use_env is set by default in torch.distributed.run.\n"
"and will be removed in future. Use torchrun.\n"
"Note that --use_env is set by default in torchrun.\n"
"If your script expects `--local_rank` argument to be set, please\n"
"change it to read from `os.environ['LOCAL_RANK']` instead. See \n"
"https://pytorch.org/docs/stable/distributed.html#launch-utility for \n"

View File

@@ -7,7 +7,7 @@
# LICENSE file in the root directory of this source tree.
"""
``torch.distributed.run`` provides a superset of the functionality of ``torch.distributed.launch``
``torchrun`` provides a superset of the functionality of ``torch.distributed.launch``
with the following additional functionalities:
1. Worker failures are handled gracefully by restarting all workers.
@@ -18,33 +18,33 @@ with the following additional functionalities:
Transitioning from torch.distributed.launch to torch.distributed.run
Transitioning from torch.distributed.launch to torchrun
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``torch.distributed.run`` supports the same arguments as ``torch.distributed.launch`` **except**
``torchrun`` supports the same arguments as ``torch.distributed.launch`` **except**
for ``--use_env`` which is now deprecated. To migrate from ``torch.distributed.launch``
to ``torch.distributed.run`` follow these steps:
to ``torchrun`` follow these steps:
1. If your training script is already reading ``local_rank`` from the ``LOCAL_RANK`` environment variable,
then you simply need to omit the ``--use_env`` flag, e.g.:
+--------------------------------------------------------------------+------------------------------------------------------+
| ``torch.distributed.launch`` | ``torch.distributed.run`` |
+====================================================================+======================================================+
+--------------------------------------------------------------------+--------------------------------------------+
| ``torch.distributed.launch`` | ``torchrun`` |
+====================================================================+============================================+
| | |
| .. code-block:: shell-session | .. code-block:: shell-session |
| | |
| $ python -m torch.distributed.launch --use_env train_script.py | $ python -m torch.distributed.run train_script.py |
| $ python -m torch.distributed.launch --use_env train_script.py | $ torchrun train_script.py |
| | |
+--------------------------------------------------------------------+------------------------------------------------------+
+--------------------------------------------------------------------+--------------------------------------------+
2. If your training script reads local rank from a ``--local_rank`` cmd argument.
Change your training script to read from the ``LOCAL_RANK`` environment variable as
demonstrated by the following code snippet:
+-------------------------------------------------------+----------------------------------------------------+
| ``torch.distributed.launch`` | ``torch.distributed.run`` |
| ``torch.distributed.launch`` | ``torchrun`` |
+=======================================================+====================================================+
| | |
| .. code-block:: python | .. code-block:: python |
@@ -59,12 +59,12 @@ to ``torch.distributed.run`` follow these steps:
| | |
+-------------------------------------------------------+----------------------------------------------------+
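As a hedged sketch of that change (not taken from the table above, with illustrative names only): drop the ``--local_rank`` command-line argument and read the value from the environment instead.

```python
# Before: local rank passed as a command-line flag by torch.distributed.launch.
# import argparse
# parser = argparse.ArgumentParser()
# parser.add_argument("--local_rank", type=int)
# local_rank = parser.parse_args().local_rank

# After: local rank read from the environment variable set by torchrun.
import os

local_rank = int(os.environ["LOCAL_RANK"])
```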
The aforementioned changes suffice to migrate from ``torch.distributed.launch`` to ``torch.distributed.run``.
To take advantage of new features such as elasticity, fault-tolerance, and error reporting of ``torch.distributed.run``
The aforementioned changes suffice to migrate from ``torch.distributed.launch`` to ``torchrun``.
To take advantage of new features such as elasticity, fault-tolerance, and error reporting of ``torchrun``
please refer to:
* :ref:`elastic_train_script` for more information on authoring training scripts that are ``torch.distributed.run`` compliant.
* the rest of this page for more information on the features of ``torch.distributed.run``.
* :ref:`elastic_train_script` for more information on authoring training scripts that are ``torchrun`` compliant.
* the rest of this page for more information on the features of ``torchrun``.
@@ -75,7 +75,7 @@ Usage
::
>>> python -m torch.distributed.run
>>> torchrun
--standalone
--nnodes=1
--nproc_per_node=$NUM_TRAINERS
@@ -85,7 +85,7 @@ Usage
::
>>> python -m torch.distributed.run
>>> torchrun
--nnodes=$NUM_NODES
--nproc_per_node=$NUM_TRAINERS
--rdzv_id=$JOB_ID
@@ -104,7 +104,7 @@ node in your training cluster, but ideally you should pick a node that has a hig
::
>>> python -m torch.distributed.run
>>> torchrun
--nnodes=1:4
--nproc_per_node=$NUM_TRAINERS
--rdzv_id=$JOB_ID
@@ -186,7 +186,7 @@ The following environment variables are made available to you in your script:
of the worker is specified in the ``WorkerSpec``.
5. ``LOCAL_WORLD_SIZE`` - The local world size (e.g. the number of workers running locally); equal to
``--nproc_per_node`` specified on ``torch.distributed.run``.
``--nproc_per_node`` specified on ``torchrun``.
6. ``WORLD_SIZE`` - The world size (total number of workers in the job).
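For example, a minimal sketch (not part of this diff, assuming one GPU per local worker) of how a worker can use these variables to pick its device:

```python
# Sketch: map each local worker to a GPU using the variables torchrun exports.
import os

import torch

local_rank = int(os.environ["LOCAL_RANK"])
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
world_size = int(os.environ["WORLD_SIZE"])

if torch.cuda.is_available():
    torch.cuda.set_device(local_rank)

print(f"local worker {local_rank}/{local_world_size}, global world size {world_size}")
```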