This PR rewrites the Tensor Parallel implementation. The Tensor Parallel APIs are supposed to be a very thin wrapper around the DTensor APIs, but the current implementation has become too messy and buggy, and it is hard to debug what goes wrong when using it. It is crucially important that advanced users and developers can understand the API and its implementation without digging through many different functions and utilities, so that they can trust what happens under the hood. In particular, this PR:

* Makes `ParallelStyle` a real contract API for `parallelize_module` to take: each concrete `ParallelStyle` only needs to implement `apply` to apply its sharding to an `nn.Module`, and all unnecessary fields are removed. This also enables easier `ParallelStyle` authoring going forward.
* Keeps the `ColwiseParallel` and `RowwiseParallel` public interfaces, but refactors them so that the parameter sharding and the input/output handling live within the style itself, making it easy to understand how Linear/Embedding layers are sharded and how the input/output transformations are performed.
* Removes the private `_prepare_input`/`_prepare_output_fn` fields from both `ColwiseParallel` and `RowwiseParallel`. Deprecation messages have been shown in nightly builds for a while, TP is a prototype release, and the fields are private, so it should be safe to remove them.
* Refactors the recently landed `PrepareModuleInput`/`PrepareModuleOutput` styles: changes `output_layouts` to `desired_input_layouts`/`desired_output_layouts`, groups the transformation functions inside the style itself, and removes default arguments for these two styles so users must specify the sharding layouts and think about them. Also fixes bugs around not handling the `use_local_output` flag.
* Makes default arguments `None` instead of `Placement` objects, since it is standard Python practice not to use a custom object instance as a default argument.
* Removes all dead APIs (i.e. the `PairwiseParallel` and `SequenceParallel` styles and all prepare input/output functions), as deprecation messages have been shown for a while; they are also being removed from the tests.
* Throws a deprecation warning for `tp_mesh_dim`, since we recommend using device mesh slicing/indexing instead of manually specifying a mesh dim (see the sketch below).
* Rewrites the documentation for every `ParallelStyle` to make it clearer what each style does.

TODOs:

* Rewrite the TP tests to adjust for the changes in this PR.
* Add more tests to guard the bug fixes.

Differential Revision: [D51761183](https://our.internmc.facebook.com/intern/diff/D51761183)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114732

Approved by: https://github.com/wz337, https://github.com/fduwjj
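
As a non-authoritative sketch of the post-refactor usage (the mesh shape, dim names, and module below are assumptions for illustration, not part of this PR):

```python
# Sketch only: assumes a distributed launch (e.g. torchrun on 8 GPUs) and a
# build where init_device_mesh and named mesh slicing are available.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

# Hypothetical 2-D mesh: 2-way data parallel x 4-way tensor parallel.
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

# Instead of the deprecated tp_mesh_dim argument, slice out the TP sub-mesh by
# name and pass it to parallelize_module directly.
tp_mesh = mesh_2d["tp"]
sharded = parallelize_module(nn.Linear(16, 32), tp_mesh, ColwiseParallel())
```
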
.. role:: hidden
    :class: hidden-section

Tensor Parallelism - torch.distributed.tensor.parallel
=======================================================

Tensor Parallelism (TP) is built on top of the PyTorch DistributedTensor
(`DTensor <https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/README.md>`__)
and provides different parallelism styles: Colwise and Rowwise Parallelism.

.. warning::
    Tensor Parallelism APIs are experimental and subject to change.

The entrypoint to parallelize your ``nn.Module`` using Tensor Parallelism is:

.. automodule:: torch.distributed.tensor.parallel

.. currentmodule:: torch.distributed.tensor.parallel

.. autofunction:: parallelize_module
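For illustration only, here is a minimal sketch of the entrypoint, assuming a distributed launch (e.g. via ``torchrun``) and a 1-D device mesh; the toy module and its ``fc`` submodule name are made up for this example:

.. code-block:: python

    # Illustrative sketch: run under a distributed launcher such as torchrun.
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 32)

        def forward(self, x):
            return self.fc(x)

    # 1-D mesh spanning 2 GPUs, used as the tensor parallel mesh.
    tp_mesh = init_device_mesh("cuda", (2,))

    # The plan maps submodule names to ParallelStyle objects; here the weight
    # of ``fc`` is sharded column-wise across the mesh.
    model = parallelize_module(Toy(), tp_mesh, {"fc": ColwiseParallel()})
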
Tensor Parallelism supports the following parallel styles:

.. autoclass:: torch.distributed.tensor.parallel.ColwiseParallel
    :members:
    :undoc-members:

.. autoclass:: torch.distributed.tensor.parallel.RowwiseParallel
    :members:
    :undoc-members:

To simply configure the ``nn.Module``'s inputs and outputs with DTensor layouts
and perform the necessary layout redistributions, without distributing the module
parameters to DTensors, the following classes can be used in
the ``parallelize_plan`` of ``parallelize_module``:

.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleInput
    :members:
    :undoc-members:

.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleOutput
    :members:
    :undoc-members:
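As an illustrative sketch (the submodule names ``norm`` and ``output_proj``, and the layouts chosen below, are assumptions for this example rather than a prescribed recipe):

.. code-block:: python

    # Illustrative sketch: annotate and redistribute a submodule's inputs and
    # outputs without sharding its parameters.
    from torch.distributed._tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import (
        PrepareModuleInput,
        PrepareModuleOutput,
    )

    io_plan = {
        # Incoming plain tensors are treated as DTensors sharded on the
        # sequence dim (Shard(1)) and redistributed to replicated placements
        # before the module runs.
        "norm": PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
        ),
        # Replicated outputs are redistributed to sequence-sharded DTensors
        # and returned as local torch.Tensors.
        "output_proj": PrepareModuleOutput(
            output_layouts=(Replicate(),),
            desired_output_layouts=(Shard(1),),
            use_local_output=True,
        ),
    }
    # This dict would then be merged into the ``parallelize_plan`` passed to
    # ``parallelize_module``.
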
For models like the Transformer, we recommend using ``ColwiseParallel``
and ``RowwiseParallel`` together in the ``parallelize_plan`` to achieve the
desired sharding for the entire model (i.e. Attention and MLP).
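
For illustration, a sketch of such a plan for one Transformer block follows; the submodule names (``attention.wq``, ``feed_forward.w1``, and so on) are assumptions modeled after a Llama-style block, not part of the API:

.. code-block:: python

    # Illustrative sketch of a per-block plan combining both styles.
    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

    block_plan = {
        # Attention: q/k/v projections are sharded column-wise and the output
        # projection row-wise, so the attention math stays local to each rank
        # and a single all-reduce happens at ``wo``.
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        "attention.wo": RowwiseParallel(),
        # MLP: the same colwise/rowwise pairing for the two Linear layers.
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(),
    }
    # Applied per block, e.g.:
    # for block in model.layers:
    #     parallelize_module(block, tp_mesh, block_plan)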