This PR adds a cached propagator for TP use: it caches the sharding propagation decision for the same input sharding on an operator, which can improve eager-mode performance.

Differential Revision: [D42876249](https://our.internmc.facebook.com/intern/diff/D42876249)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90734
Approved by: https://github.com/XilunWu, https://github.com/fduwjj
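The caching idea can be sketched, very roughly, as memoizing the propagation result keyed by the operator and the sharding of its inputs. The snippet below is a toy illustration with made-up names (`cached_sharding_prop`, `_derive_output_sharding`), not the actual `_CachingPropagator`:

from functools import lru_cache
from typing import Tuple

def _derive_output_sharding(op: str, input_sharding: Tuple[str, ...]) -> Tuple[str, ...]:
    # Stand-in for the relatively expensive sharding propagation rule.
    return input_sharding

@lru_cache(maxsize=None)
def cached_sharding_prop(op: str, input_sharding: Tuple[str, ...]) -> Tuple[str, ...]:
    # Repeated calls with the same (op, input_sharding) reuse the cached decision.
    return _derive_output_sharding(op, input_sharding)

A first call with a given (operator, input sharding) pair computes the decision; subsequent identical calls return the cached result, which is what removes the per-op propagation overhead from the TP eager path.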
423 lines · 15 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict, Union

import torch
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
    DTensor,
    distribute_module,
    distribute_tensor,
    Replicate,
    Shard,
)
from torch.distributed._tensor.sharding_prop import _CachingPropagator
from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh
from torch.distributed.tensor.parallel.multihead_attention_tp import (
    TensorParallelMultiheadAttention,
)
from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
    PairwiseParallel,
    ParallelStyle,
    RowwiseParallel,
)


__all__ = [
    "parallelize_module",
]

# switch the DTensor propagator to use the caching propagator to speed up
# the TP eager execution time.
DTensor._propagator = _CachingPropagator(DTensor._propagator.op_to_rules)


def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    The API to apply Tensor Parallelism (TP) in PyTorch. We parallelize the module
    or sub_modules based on a parallelize_plan, which contains the
    :class:`ParallelStyle` indicating how the user wants the module or sub_module
    to be parallelized. Users can also specify a different parallel style per
    module fully qualified name (FQN). The API supports 2D parallelism natively
    by accepting an n-dimensional device_mesh; users just need to specify the
    dimension on which to perform tensor parallelism.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
            of devices for the DTensor.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]):
            The plan used to parallelize the module. It can be either a
            :class:`ParallelStyle` object which contains how
            we prepare input/output for Tensor Parallelism or it can be a
            dict of module FQN and its corresponding :class:`ParallelStyle` object.
        tp_mesh_dim (int):
            The dimension of ``device_mesh`` on which we perform
            Tensor Parallelism.

    Return:
        A parallelized :class:`nn.Module` object.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
        >>>
        >>> # Define the module.
        >>> m = Model(...)
        >>> m = parallelize_module(m, device_mesh, PairwiseParallel())
        >>>

    .. warning::
        ``PairwiseParallel`` comes with constraints for now. If you need finer
        granularity, you need to pass in a dict of module FQN and parallel style instead.
    """

    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    if isinstance(parallelize_plan, ParallelStyle):
        # RowwiseParallel or ColwiseParallel
        if isinstance(parallelize_plan, ColwiseParallel) or isinstance(
            parallelize_plan, RowwiseParallel
        ):
            return _parallelize_linear(module, device_mesh, parallelize_plan)
        # PairwiseParallel
        if _is_mha_for_pairwise_parallel(module):
            return _parallelize_multihead_attn(module, device_mesh)
        elif _is_mlp_for_pairwise_parallel(module):
            return _parallelize_mlp(module, device_mesh)
        else:
            for n, m in module.named_children():
                module.register_module(
                    n, parallelize_module(m, device_mesh, parallelize_plan)
                )
            return module
    elif isinstance(parallelize_plan, dict):
        for module_path, parallelize_style in parallelize_plan.items():
            sub_module = module.get_submodule(module_path)
            module.register_module(  # type: ignore[call-arg] # pyre-ignore[20]
                module_path,
                parallelize_module(  # type: ignore[arg-type]
                    sub_module, device_mesh, parallelize_style  # type: ignore[arg-type] # pyre-ignore[6]
                ),
            )
        return module
    else:
        raise RuntimeError(  # pyre-ignore[7]
            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
            f" parallelize_plan, {type(parallelize_plan)} found!"
        )


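# Example (illustrative only, not part of the original file): a dict-based
# parallelize_plan lets different sub-modules use different parallel styles.
# The sub-module FQNs ("attn", "out_proj") and `world_size` are hypothetical.
#
#   mesh = DeviceMesh("cuda", list(range(world_size)))
#   model = parallelize_module(
#       model,
#       mesh,
#       {"attn": PairwiseParallel(), "out_proj": ColwiseParallel()},
#   )

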
def _is_mha_for_pairwise_parallel(module: nn.Module) -> bool:
    """
    Check whether the given module is a multihead attention (MHA) module that
    can be handled by Pairwise parallel.

    Args:
        module (:class:`nn.Module`):
            Module to be checked.

    Return:
        A bool which specifies whether the module is an MHA module supported by
        Pairwise parallel or not.
    """
    return isinstance(module, TensorParallelMultiheadAttention) or isinstance(
        module, nn.MultiheadAttention
    )


def _is_mlp_for_pairwise_parallel(module: nn.Module) -> bool:
    """
    Traverse all the immediate children of the given module and count the
    number of ``nn.Linear`` modules. If the number is more than one, return True.

    Args:
        module (:class:`nn.Module`):
            Module to be traversed and counted.

    Return:
        A bool which specifies whether the module is an MLP supported by
        Pairwise parallel or not.

    .. warning::
        The traversal is not recursive for now.
    """
    linear_submodules = list(
        filter(lambda x: isinstance(x, nn.Linear), module.children())
    )
    return len(linear_submodules) > 1


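# Example (illustrative only, not part of the original file): a two-layer MLP
# such as nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8)) has two
# immediate nn.Linear children, so _is_mlp_for_pairwise_parallel returns True.
# A module whose Linear layers are nested deeper would not be detected, since
# the traversal is not recursive.

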
def _rowwise_parallelize_linear_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    """
    This function parallelizes the input :class:`nn.Linear` module in
    :class:`RowwiseParallel` style.

    Args:
        name (str):
            Name of the input module.
        module (:class:`nn.Module`):
            The :class:`nn.Linear` module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.

    Returns:
        None
    """

    for name, param in module.named_parameters():
        dist_spec = (
            [Shard(1)] if name == "weight" else [Replicate()]  # type: ignore[list-item]
        )
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, dist_spec)
        )
        module.register_parameter(name, dist_param)


def _colwise_parallelize_linear_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    """
    This function parallelizes the input :class:`nn.Linear` module in
    :class:`ColwiseParallel` style.

    Args:
        name (str):
            Name of the input module.
        module (:class:`nn.Module`):
            The :class:`nn.Linear` module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.

    Returns:
        None
    """

    for name, param in module.named_parameters():
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, [Shard(0)])
        )
        module.register_parameter(name, dist_param)


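# Background note (added for clarity, not part of the original file):
# nn.Linear computes y = x @ weight.t() + bias with weight of shape
# (out_features, in_features). Col-wise parallel shards weight (and bias) on
# dim 0, so each rank produces a distinct slice of the output features.
# Row-wise parallel shards weight on dim 1 (the input features), so each rank
# produces a partial result that is summed across ranks to form the full
# output, which is why the bias is kept replicated rather than sharded.

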
def _parallelize_linear(
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallel_style: ParallelStyle = ColwiseParallel(),
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    This function requires that the input module be an object
    of :class:`nn.Linear`.
    The module will be parallelized over a 1-d :class:`DeviceMesh`
    based on the :class:`ParallelStyle`.

    Args:
        module (:class:`nn.Module`):
            The module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices for the :class:`DTensor`.
            If the mesh is more than 1-dimensional, we will use the mesh dim of
            `device_mesh` specified by `tp_mesh_dim`.
        parallel_style (:class:`ParallelStyle`, optional):
            The object which describes how the :class:`nn.Linear` module
            should be distributed over :class:`DeviceMesh` and how the input
            and output should be prepared for Tensor Parallelism.
            :class:`RowwiseParallel`: weight is sharded on dim 1 and bias is
            replicated.
            :class:`ColwiseParallel`: weight and bias are both sharded on dim 0.
            Default: :class:`ColwiseParallel`
        tp_mesh_dim (int):
            The dimension of :class:`DeviceMesh` on which we
            perform Tensor Parallelism.
            Default: 0

    Return:
        A parallelized :class:`nn.Module` object.
    """

    if not isinstance(module, nn.Linear):
        raise RuntimeError(
            f"Expect a torch.nn.Linear module but received {type(module)}!"
        )

    if not isinstance(parallel_style, ParallelStyle):
        raise RuntimeError(
            "Expect a ParallelStyle object but received" f" {type(parallel_style)}!"
        )

    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    if isinstance(parallel_style, RowwiseParallel):
        distribute_module(
            module,
            device_mesh,
            _rowwise_parallelize_linear_fn,
            input_fn=parallel_style._prepare_input,  # type: ignore[arg-type, misc] # pyre-ignore[6]
            output_fn=parallel_style._prepare_output,  # type: ignore[arg-type, misc] # pyre-ignore[6]
        )
    elif isinstance(parallel_style, ColwiseParallel):
        distribute_module(
            module,
            device_mesh,
            _colwise_parallelize_linear_fn,
            input_fn=parallel_style._prepare_input,  # type: ignore[arg-type, misc] # pyre-ignore[6]
            output_fn=parallel_style._prepare_output,  # type: ignore[arg-type, misc] # pyre-ignore[6]
        )
    else:
        raise RuntimeError(f"{type(parallel_style)} is not supported!")
    return module


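# Example (illustrative only, not part of the original file): sharding a single
# linear layer column-wise over a 1-D device mesh; `world_size` is hypothetical.
#
#   mesh = DeviceMesh("cuda", list(range(world_size)))
#   fc = _parallelize_linear(nn.Linear(1024, 4096), mesh, ColwiseParallel())

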
def _parallelize_multihead_attn(
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallel_style: ParallelStyle = PairwiseParallel(),
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    This function assumes the input module is an :class:`nn.MultiheadAttention`
    or :class:`TensorParallelMultiheadAttention` module and parallelizes it
    based on the given parallel style.
    We don't change the FQN of each sub-module and replace each parameter
    in place.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.
        parallel_style (:class:`ParallelStyle`):
            Object which contains how we prepare input/output
            for Tensor Parallelism.
        tp_mesh_dim (int):
            The dimension of `device_mesh` on which we perform
            Tensor Parallelism.

    Return:
        A parallelized :class:`nn.Module` object.

    .. warning::
        We only support ``PairwiseParallel`` right now.
    """

    if not isinstance(parallel_style, PairwiseParallel):
        raise NotImplementedError(
            "Only support PairwiseParallel for Multihead Attention" " parallelization."
        )

    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    if isinstance(module, nn.MultiheadAttention):
        tp_multi_head_attention = TensorParallelMultiheadAttention(
            module.embed_dim,
            module.num_heads,
            device=torch.device(device_mesh.device_type),
            tp_size=device_mesh.size(tp_mesh_dim),
            add_bias_kv=module.bias_k is not None,
        )
        tp_multi_head_attention.copy(module)
        module = tp_multi_head_attention

    if isinstance(module, TensorParallelMultiheadAttention):  # shard TPMA
        for n, m in module.named_children():
            if n == "qkv":
                # Col-wise Parallelize the qkv layer.
                distribute_module(
                    m,
                    device_mesh,
                    _colwise_parallelize_linear_fn,
                    input_fn=parallel_style._prepare_input,  # type: ignore[arg-type, misc] # pyre-ignore[6]
                )
            elif n == "proj":
                # Row-wise Parallelize the proj layer
                distribute_module(
                    m,
                    device_mesh,
                    _rowwise_parallelize_linear_fn,
                    output_fn=parallel_style._prepare_output,  # type: ignore[arg-type, misc] # pyre-ignore[6]
                )
    return module


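# Background note (added for clarity, not part of the original file): the
# qkv/proj split above follows the usual Megatron-style pairing. The fused qkv
# projection is sharded col-wise so each rank holds a subset of attention
# heads and computes attention locally; the output projection is sharded
# row-wise so its partial results are combined across ranks, keeping the
# intermediate activations sharded between the two projections.

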
def _parallelize_mlp(
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallel_style: ParallelStyle = PairwiseParallel(),
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    This function assumes the input module is a sequence of nn.Linear
    and we parallelize the module based on the given parallel style.
    We don't change the FQN of each sub-module and replace each parameter
    in place.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.
        parallel_style (:class:`ParallelStyle`):
            Object which contains how we prepare input/output
            for Tensor Parallelism.
        tp_mesh_dim (int):
            The dimension of `device_mesh` on which we perform
            Tensor Parallelism.

    Return:
        A parallelized :class:`nn.Module` object.

    .. warning::
        We only support ``PairwiseParallel`` right now.
    """
    if not isinstance(parallel_style, PairwiseParallel):
        raise NotImplementedError(
            "Only support PairwiseParallel for MLP parallelization."
        )

    if not _is_mlp_for_pairwise_parallel(module):
        raise RuntimeError("More than one nn.Linear needed for a MLP.")

    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    linear_submodules = list(
        filter(lambda x: isinstance(x, nn.Linear), module.children())
    )
    mlp_last_even_layer = (len(linear_submodules) // 2) * 2
    for i in range(mlp_last_even_layer):
        m = linear_submodules[i]
        if i % 2 == 0:
            # Col-wise Parallelize the linear layer
            distribute_module(
                m,
                device_mesh,
                _colwise_parallelize_linear_fn,
                input_fn=parallel_style._prepare_input  # type: ignore[arg-type, misc] # pyre-ignore[6]
                if i == 0
                else None,
            )
        else:
            # Row-wise Parallelize the linear layer
            distribute_module(
                m,
                device_mesh,
                _rowwise_parallelize_linear_fn,
                output_fn=parallel_style._prepare_output  # type: ignore[arg-type, misc] # pyre-ignore[6]
                if i == (mlp_last_even_layer - 1)
                else None,
            )
    return module
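# Background note (added for clarity, not part of the original file): pairing
# the linear layers as col-wise followed by row-wise means the first layer's
# sharded output feeds the second layer's sharded input directly, so input
# preparation is only needed for the first layer and output preparation only
# for the last layer of each pair, with no communication in between.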