.. role:: hidden
    :class: hidden-section

Tensor Parallelism - torch.distributed.tensor.parallel
=======================================================

Tensor Parallelism (TP) is built on top of DistributedTensor (DTensor) and
provides several parallelism styles: Rowwise, Colwise, and Pairwise Parallelism.

.. warning::
    Tensor Parallelism is experimental and subject to change.

The entrypoint to parallelize your module using Tensor Parallelism is:

.. automodule:: torch.distributed.tensor.parallel

.. currentmodule:: torch.distributed.tensor.parallel

.. autofunction:: parallelize_module
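
The following is a minimal sketch of how ``parallelize_module`` can be used,
assuming ``torch.distributed`` has already been initialized with one GPU per
rank; the two-layer ``MLP`` module and the device mesh construction are
illustrative only, not part of the API:

.. code-block:: python

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed._tensor import DeviceMesh
    from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module

    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.net1 = nn.Linear(16, 32)
            self.net2 = nn.Linear(32, 16)

        def forward(self, x):
            return self.net2(torch.relu(self.net1(x)))

    # Build a 1-D device mesh spanning all ranks; this dimension is used
    # for Tensor Parallelism.
    mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

    # Apply the Pairwise parallel style to the whole module.
    model = parallelize_module(MLP().cuda(), mesh, PairwiseParallel())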

Tensor Parallelism supports the following parallel styles:

.. autoclass:: torch.distributed.tensor.parallel.style.RowwiseParallel
    :members:

.. autoclass:: torch.distributed.tensor.parallel.style.ColwiseParallel
    :members:

.. autoclass:: torch.distributed.tensor.parallel.style.PairwiseParallel
    :members:
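
When finer-grained control is needed, the parallelize plan can also map
submodule names to individual styles. The following is a minimal sketch that
reuses the illustrative ``MLP`` and ``mesh`` from the example above and
assumes the dict form of the plan is supported in your release:

.. code-block:: python

    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    # Colwise-shard the first linear layer and rowwise-shard the second one.
    model = parallelize_module(
        MLP().cuda(),
        mesh,
        {"net1": ColwiseParallel(), "net2": RowwiseParallel()},
    )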

Because Tensor Parallelism uses DTensor internally, we need to specify the
DTensor placements of the module's inputs and outputs so that it interacts
as expected with the modules that come before and after it. The following
functions are used for input/output preparation:

.. currentmodule:: torch.distributed.tensor.parallel.style

.. autofunction:: make_input_replicate_1d

.. autofunction:: make_input_shard_1d

.. autofunction:: make_input_shard_1d_last_dim

.. autofunction:: make_output_replicate_1d

.. autofunction:: make_output_tensor

.. autofunction:: make_output_shard_1d
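
These helpers convert between regular ``torch.Tensor`` and ``DTensor`` at the
module boundary. The following is a minimal sketch that calls two of them
directly, reusing the illustrative ``mesh`` from the examples above; keyword
names may change while the API is experimental:

.. code-block:: python

    import torch
    from torch.distributed.tensor.parallel.style import (
        make_input_replicate_1d,
        make_input_shard_1d,
    )

    x = torch.randn(8, 16, device="cuda")
    # Shard ``x`` along dim 0 across the 1-D mesh ...
    x_sharded = make_input_shard_1d(x, mesh, dim=0)
    # ... or replicate it on every rank.
    x_replicated = make_input_replicate_1d(x, mesh)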

Currently, there are some constraints that make it hard for the
``nn.MultiheadAttention`` module to work out of the box with Tensor
Parallelism, so we built a custom multihead attention module for Tensor
Parallelism users. Also, ``parallelize_module`` automatically swaps
``nn.MultiheadAttention`` with this custom module when ``PairwiseParallel``
is specified.

.. autoclass:: torch.distributed.tensor.parallel.multihead_attention_tp.TensorParallelMultiheadAttention
    :members:
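
The following is a minimal sketch of the automatic swap described above,
reusing the illustrative ``mesh`` from the earlier examples; the ``Block``
module is illustrative only:

.. code-block:: python

    import torch.nn as nn
    from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module
    from torch.distributed.tensor.parallel.multihead_attention_tp import (
        TensorParallelMultiheadAttention,
    )

    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.attn = nn.MultiheadAttention(embed_dim=64, num_heads=8)

        def forward(self, x):
            out, _ = self.attn(x, x, x)
            return out

    block = parallelize_module(Block().cuda(), mesh, PairwiseParallel())
    # Per the note above, ``block.attn`` should now be a
    # TensorParallelMultiheadAttention instance instead of nn.MultiheadAttention.
    print(type(block.attn))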

We also enable 2D parallelism to integrate Tensor Parallelism with
``FullyShardedDataParallel``. Users just need to call the following API
explicitly:

.. currentmodule:: torch.distributed.tensor.parallel.fsdp

.. autofunction:: is_available
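
The following is a minimal sketch of the 2D setup, assuming ``model`` has
already been parallelized with ``parallelize_module`` as above and that
``is_available`` takes no arguments; the exact integration API may differ
between releases:

.. code-block:: python

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.tensor.parallel.fsdp import is_available

    # Only wrap the tensor-parallel model with FSDP when the 2D integration
    # is available.
    if is_available():
        model = FSDP(model)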