pytorch/torch/distributed/tensor/parallel/api.py
Wanchao Liang 28925902fa [TP] fully rewrite Tensor Parallel APIs (#114732)
This PR rewrites the Tensor Parallel implementation. The Tensor Parallel APIs
are supposed to be a very thin wrapper around the DTensor APIs, but the current
implementation has become too messy and buggy, and it is really hard to debug
what went wrong when using it. It is crucial for advanced users and
developers to be able to understand the API and its implementation easily,
without going through many different functions and utils, so that
they can trust what happens under the hood.

In particular this PR:

* Makes ParallelStyle a real contract API for parallelize_module to
  take. Each concrete ParallelStyle only needs to implement `_apply` to
  apply the sharding to an nn.Module; all unnecessary fields are removed.
  This also makes ParallelStyle authoring easier going forward.
* Keeps the ColwiseParallel and RowwiseParallel public interface, but
  refactors them so that parameter sharding and input/output handling
  live within the style itself, making it easy to understand how
  Linear/Embedding layers are sharded and how the input/output
  transformations are performed.
* Removes the private _prepare_input/_prepare_output_fn fields from
  both ColwiseParallel and RowwiseParallel. Deprecation messages have
  been thrown in nightly for a while, TP is a prototype release, and
  the fields are private, so it should be safe to remove them.
* Refactors the recently landed PrepareModuleInput/Output styles: changes
  output_layouts to desired_input/output_layouts and groups the
  functions inside the style itself. These two styles have no default
  arguments, so users must specify them and think about the sharding
  layouts. Also fixes bugs where the `use_local_output` flag was not
  handled.
* Makes default arguments None instead of Placement objects; it is
  standard Python practice not to use custom object instances as
  default arguments.
* Removes all dead APIs (i.e. the PairwiseParallel and SequenceParallel
  styles and all prepare input/output functions), as deprecation
  messages have been thrown for a while; removing them from the tests
  is in progress.
* Throws a deprecation warning for `tp_mesh_dim`, as we recommend using
  device mesh slicing/indexing instead of manually specifying the mesh dim.
* Rewrites the documentation for every ParallelStyle to make it
  clearer what each style does.
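
A minimal sketch of the first and fifth points above, written in plain Python so it runs without a distributed setup. Every name here (`ToyReplicate`, `ToyModule`, `ToyStyle`, `ToyColwise`, `toy_parallelize`) is a hypothetical stand-in, not the real `torch.distributed` classes; it only illustrates the shape of the contract (one method to implement) and the `None`-default pattern.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass(frozen=True)
class ToyReplicate:
    """Hypothetical stand-in for a placement object."""


class ToyModule:
    """Hypothetical stand-in for nn.Module; it just records what was applied."""

    def __init__(self) -> None:
        self.applied: List[Tuple[str, str, ToyReplicate]] = []


class ToyStyle(ABC):
    """The contract: a concrete style only has to implement _apply."""

    @abstractmethod
    def _apply(self, module: ToyModule, mesh: str) -> ToyModule:
        ...


class ToyColwise(ToyStyle):
    def __init__(self, input_layout: Optional[ToyReplicate] = None) -> None:
        # The default is None; the placement instance is created inside the
        # constructor rather than appearing as a default argument value.
        self.input_layout = (
            input_layout if input_layout is not None else ToyReplicate()
        )

    def _apply(self, module: ToyModule, mesh: str) -> ToyModule:
        # All sharding and input/output handling would live here.
        module.applied.append(("colwise", mesh, self.input_layout))
        return module


def toy_parallelize(module: ToyModule, mesh: str, style: ToyStyle) -> ToyModule:
    # The entry point only needs to know about the contract method.
    return style._apply(module, mesh)


if __name__ == "__main__":
    m = toy_parallelize(ToyModule(), "tp", ToyColwise())
    print(m.applied)
```

The entry point never inspects style-specific fields, which is the point of a contract API: a new style can be added without touching the dispatcher.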

TODOs:
* Rewrite the TP tests to account for the changes in this PR
* Add more tests to guard the bug fixes

Differential Revision: [D51761183](https://our.internmc.facebook.com/intern/diff/D51761183)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114732
Approved by: https://github.com/wz337, https://github.com/fduwjj
2023-12-02 08:18:12 +00:00

120 lines
4.9 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict, Union
import torch
import torch.distributed._tensor.random as random
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
)
from torch.distributed._tensor.random import (
    is_rng_supported_mesh,
    TensorParallelRNGTracker,
)
from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh, _validate_tp_mesh_dim, _deprecate_warnings
from torch.distributed.tensor.parallel.style import (
    ParallelStyle,
)


__all__ = [
    "parallelize_module",
]


def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.

    We parallelize the module or its sub_modules based on a parallelize_plan. The parallelize_plan contains
    :class:`ParallelStyle`, which indicates how the user wants the module or sub_module
    to be parallelized.

    Users can also specify a different parallel style per module fully qualified name (FQN).
    The API supports 2D parallelism natively by accepting an n-dimension device_mesh;
    users just need to specify the dimension on which we perform tensor parallelism.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
            of devices for the DTensor.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]):
            The plan used to parallelize the module. It can be either a
            :class:`ParallelStyle` object which contains how
            we prepare input/output for Tensor Parallelism or it can be a
            dict of module FQN and its corresponding :class:`ParallelStyle` object.
        tp_mesh_dim (int):
            The dimension of ``device_mesh`` on which we perform
            Tensor Parallelism. Passing an n-dimension ``device_mesh`` together with
            this argument emits a deprecation warning; prefer slicing the mesh instead.

    Return:
        A :class:`nn.Module` object parallelized.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
        >>>
        >>> # Define the module; ``device_mesh`` is assumed to be an existing 1-D DeviceMesh.
        >>> m = Model(...)
        >>> m = parallelize_module(m, device_mesh, ColwiseParallel())
        >>>

    .. warning::
        Currently, there are some constraints which make it hard for complicated modules
        like ``MultiheadAttention`` to work out of the box for Tensor or Sequence Parallelism.
        We recommend that users try ``ColwiseParallel`` and ``RowwiseParallel`` for each parameter
        or submodule; some code changes might be needed for now.
    """
    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")

    # instantiate a TP RNG state tracker if it's not there
    if is_rng_supported_mesh(device_mesh) and not isinstance(
        random._rng_tracker, TensorParallelRNGTracker
    ):
        random._rng_tracker = TensorParallelRNGTracker(device_mesh.device_type)
        # TODO: we should allow user to pass in the default seed from a config
        random._rng_tracker._manual_seed(
            device_mesh, base_seed=1234, tp_dim=tp_mesh_dim
        )
        # By default we execute random ops in non-tensor-parallel region. If users want
        # to execute in tensor-parallel region, they can manually set this field to True
        # after parallelizing the model.
        random._rng_tracker.distribute_region_enabled = False

    if device_mesh.ndim > 1:
        _deprecate_warnings("tp_mesh_dim", "If you have a 2-D or N-D device_mesh, consider passing in device_mesh[\"tp\"]")
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)
    else:
        _validate_tp_mesh_dim(device_mesh)

    if isinstance(parallelize_plan, ParallelStyle):
        return parallelize_plan._apply(module, device_mesh)
    elif isinstance(parallelize_plan, dict):
        for module_path, parallelize_style in parallelize_plan.items():
            sub_module = module.get_submodule(module_path)
            parent_module = module
            if "." in module_path:
                parent_module_path = ".".join(module_path.split(".")[:-1])
                parent_module = module.get_submodule(parent_module_path)
                module_path = module_path.split(".")[-1]
            parent_module.register_module(  # type: ignore[call-arg] # pyre-ignore[20]
                module_path,
                parallelize_module(  # type: ignore[arg-type]
                    sub_module, device_mesh, parallelize_style  # type: ignore[arg-type] # pyre-ignore[6]
                ),
            )
        return module
    else:
        raise RuntimeError(  # pyre-ignore[7]
            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
            f" parallelize_plan, {type(parallelize_plan)} found!"
        )
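
The dict branch of `parallelize_module` above splits each FQN into a parent path and a child name before re-registering the parallelized submodule on its parent. A minimal standalone sketch of just that splitting step is shown here; `split_fqn` and `split_fqn_join` are hypothetical helper names, not part of the PyTorch API, and the `rpartition` variant is an alternative spelling rather than what the file uses.

```python
from typing import Tuple


def split_fqn(module_path: str) -> Tuple[str, str]:
    """Split 'a.b.c' into ('a.b', 'c'); a top-level name has an empty parent."""
    parent_path, _, child_name = module_path.rpartition(".")
    return parent_path, child_name


def split_fqn_join(module_path: str) -> Tuple[str, str]:
    """Same result, spelled with the split/join logic used in the file above."""
    if "." in module_path:
        parts = module_path.split(".")
        return ".".join(parts[:-1]), parts[-1]
    # No dot: the child hangs directly off the root module.
    return "", module_path


if __name__ == "__main__":
    for path in ["w1", "net1.w1", "layers.0.attn.wq"]:
        print(path, "->", split_fqn(path))
```

An empty parent path corresponds to the case where `parent_module` stays as the root `module` in the loop above.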