This function is only called for nn.MHA or the custom MHA we use, and if it is the former it is converted to the latter. So this check can actually be an assert.

Differential Revision: [D45300396](https://our.internmc.facebook.com/intern/diff/D45300396/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100046
Approved by: https://github.com/wanchaol
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict, Union

import torch
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
    DTensor,
    distribute_module,
    distribute_tensor,
    Replicate,
    Shard,
)
from torch.distributed._tensor.sharding_prop import _CachingPropagator
from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh
from torch.distributed.tensor.parallel.multihead_attention_tp import (
    TensorParallelMultiheadAttention,
)
from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
    PairwiseParallel,
    ParallelStyle,
    RowwiseParallel,
)


__all__ = [
    "parallelize_module",
]

# switch the DTensor propagator to use the caching propagator to speed up
# the TP eager execution time.
DTensor._propagator = _CachingPropagator(DTensor._propagator.op_to_rules)


def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    The API to apply Tensor Parallelism (TP) in PyTorch. We parallelize a module
    or its sub-modules based on a ``parallelize_plan``. The plan contains
    :class:`ParallelStyle` objects, which indicate how the user wants the module
    or sub-module to be parallelized.

    Users can also specify a different parallel style per module fully qualified
    name (FQN). The API supports 2D parallelism natively by accepting an
    n-dimensional ``device_mesh``; users only need to specify the mesh dimension
    on which tensor parallelism is performed.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
            of devices for the DTensor.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]):
            The plan used to parallelize the module. It can be either a
            :class:`ParallelStyle` object, which contains how
            we prepare input/output for Tensor Parallelism, or a
            dict of module FQNs and their corresponding :class:`ParallelStyle` objects.
        tp_mesh_dim (int):
            The dimension of ``device_mesh`` on which we perform
            Tensor Parallelism.

    Return:
        A parallelized :class:`nn.Module` object.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
        >>>
        >>> # Define the module; assume ``device_mesh`` is an already-initialized DeviceMesh.
        >>> m = Model(...)
        >>> m = parallelize_module(m, device_mesh, PairwiseParallel())
        >>>
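        >>> # A minimal sketch of the dict form of ``parallelize_plan``; the
        >>> # sub-module names "w1" and "w2" are hypothetical and only show how
        >>> # module FQNs map to parallel styles.
        >>> from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
        >>> m = Model(...)
        >>> m = parallelize_module(m, device_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
        >>>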

    .. warning::
        ``PairwiseParallel`` comes with constraints for now. If you need finer
        granularity, you need to pass in a dict of module FQNs and parallel styles instead.
    """

    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    if isinstance(parallelize_plan, ParallelStyle):
        # RowwiseParallel or ColwiseParallel
        if isinstance(parallelize_plan, (ColwiseParallel, RowwiseParallel)):
            return _parallelize_linear(module, device_mesh, parallelize_plan)
        # PairwiseParallel
        if _is_mha_for_pairwise_parallel(module):
            return _parallelize_multihead_attn(module, device_mesh)
        elif _is_mlp_for_pairwise_parallel(module):
            return _parallelize_mlp(module, device_mesh, parallelize_plan)
        else:
            for n, m in module.named_children():
                module.register_module(
                    n, parallelize_module(m, device_mesh, parallelize_plan)
                )
            return module
    elif isinstance(parallelize_plan, dict):
        for module_path, parallelize_style in parallelize_plan.items():
            sub_module = module.get_submodule(module_path)
            parent_module = module
            if "." in module_path:
                parent_module_path = ".".join(module_path.split(".")[:-1])
                parent_module = module.get_submodule(parent_module_path)
                module_path = module_path.split(".")[-1]
            parent_module.register_module(  # type: ignore[call-arg] # pyre-ignore[20]
                module_path,
                parallelize_module(  # type: ignore[arg-type]
                    sub_module, device_mesh, parallelize_style  # type: ignore[arg-type] # pyre-ignore[6]
                ),
            )
        return module
    else:
        raise RuntimeError(  # pyre-ignore[7]
            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
            f" parallelize_plan, {type(parallelize_plan)} found!"
        )


def _is_mha_for_pairwise_parallel(module: nn.Module) -> bool:
    """
    Check whether the module is a multihead attention module that Pairwise
    parallel can handle.

    Args:
        module (:class:`nn.Module`):
            Module to be checked.

    Return:
        A boolean which specifies whether the module is an MHA supported by Pairwise parallel or not.
    """
    return isinstance(module, (TensorParallelMultiheadAttention, nn.MultiheadAttention))


def _is_mlp_for_pairwise_parallel(module: nn.Module) -> bool:
    """
    Traverse all the immediate children of the given module and count the
    number of :class:`nn.Linear` modules. If the number is more than one, we
    return True.

    Args:
        module (:class:`nn.Module`):
            Module to be traversed and counted.

    Return:
        A bool which specifies whether the module is an MLP supported by Pairwise parallel or not.

    .. warning::
        The traversal is not recursive for now.
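
    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> # Minimal illustrative sketch (the modules below are hypothetical,
        >>> # not part of this API): two immediate nn.Linear children count as
        >>> # an MLP, a bare nn.Linear does not.
        >>> mlp = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 8))
        >>> _is_mlp_for_pairwise_parallel(mlp)
        True
        >>> _is_mlp_for_pairwise_parallel(nn.Linear(8, 8))
        False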
    """
    linear_submodules = list(
        filter(lambda x: isinstance(x, nn.Linear), module.children())
    )
    return len(linear_submodules) > 1


def _rowwise_parallelize_linear_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    """
    This function parallelizes the input :class:`nn.Linear` module in
    :class:`RowwiseParallel` style.

    Args:
        name (str):
            Name of the input module.
        module (:class:`nn.Module`):
            The :class:`nn.Linear` module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.

    Returns:
        None
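
    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> # Hedged sketch: ``mesh`` is an assumed 1-D DeviceMesh. After
        >>> # distribute_module with this partition fn, the weight is sharded
        >>> # on dim 1 (the input-feature dim) and the bias is replicated.
        >>> linear = nn.Linear(16, 16)
        >>> distribute_module(linear, mesh, _rowwise_parallelize_linear_fn)
        >>> linear.weight.placements  # expected: [Shard(dim=1)]
        >>> linear.bias.placements    # expected: [Replicate()]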
    """

    for name, param in module.named_parameters():
        dist_spec = (
            [Shard(1)] if name == "weight" else [Replicate()]  # type: ignore[list-item]
        )
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, dist_spec)
        )
        module.register_parameter(name, dist_param)


def _colwise_parallelize_linear_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    """
    This function parallelizes the input :class:`nn.Linear` module in
    :class:`ColwiseParallel` style.

    Args:
        name (str):
            Name of the input module.
        module (:class:`nn.Module`):
            The :class:`nn.Linear` module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.

    Returns:
        None
    """

    for name, param in module.named_parameters():
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, [Shard(0)])
        )
        module.register_parameter(name, dist_param)


def _parallelize_linear(
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallel_style: ParallelStyle = ColwiseParallel(),
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    This function requires that the input module be an object
    of :class:`nn.Linear`.
    The module will be parallelized over a 1-d :class:`DeviceMesh`
    based on the :class:`ParallelStyle`.

    Args:
        module (:class:`nn.Module`):
            The module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices for the :class:`DTensor`.
            If the mesh is more than 1-dimensional, we will use the mesh dim of
            ``device_mesh`` specified by ``tp_mesh_dim``.
        parallel_style (:class:`ParallelStyle`, optional):
            The object which describes how the :class:`nn.Linear` module
            should be distributed over the :class:`DeviceMesh` and how the input
            and output should be prepared for Tensor Parallelism.
            :class:`RowwiseParallel`: the weight is sharded on dim 1 and the bias is
            replicated.
            :class:`ColwiseParallel`: the weight and bias are both sharded on dim 0.
            Default: :class:`ColwiseParallel`
        tp_mesh_dim (int):
            The dimension of the :class:`DeviceMesh` on which we
            perform Tensor Parallelism.
            Default: 0

    Return:
        A parallelized :class:`nn.Module` object.
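
    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> # Hedged sketch: ``mesh`` is an assumed 1-D DeviceMesh. Colwise
        >>> # shards weight and bias on dim 0; Rowwise shards the weight on
        >>> # dim 1 and replicates the bias.
        >>> fc1 = _parallelize_linear(nn.Linear(16, 32), mesh, ColwiseParallel())
        >>> fc2 = _parallelize_linear(nn.Linear(32, 16), mesh, RowwiseParallel())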
    """

    if not isinstance(module, nn.Linear):
        raise RuntimeError(
            f"Expect a torch.nn.Linear module but received {type(module)}!"
        )

    if not isinstance(parallel_style, ParallelStyle):
        raise RuntimeError(
            f"Expect a ParallelStyle object but received {type(parallel_style)}!"
        )

    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    if isinstance(parallel_style, RowwiseParallel):
        distribute_module(
            module,
            device_mesh,
            _rowwise_parallelize_linear_fn,
            input_fn=parallel_style._prepare_input,  # type: ignore[arg-type, misc] # pyre-ignore[6]
            output_fn=parallel_style._prepare_output,  # type: ignore[arg-type, misc] # pyre-ignore[6]
        )
    elif isinstance(parallel_style, ColwiseParallel):
        distribute_module(
            module,
            device_mesh,
            _colwise_parallelize_linear_fn,
            input_fn=parallel_style._prepare_input,  # type: ignore[arg-type, misc] # pyre-ignore[6]
            output_fn=parallel_style._prepare_output,  # type: ignore[arg-type, misc] # pyre-ignore[6]
        )
    else:
        raise RuntimeError(f"{type(parallel_style)} is not supported!")
    return module


def _parallelize_multihead_attn(
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallel_style: ParallelStyle = PairwiseParallel(),
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    This function parallelizes a multihead attention module, which must be
    either :class:`nn.MultiheadAttention` or
    :class:`TensorParallelMultiheadAttention`; the former is converted to the
    latter first. We don't change the FQN of each sub-module and we replace
    each parameter in place.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.
        parallel_style (:class:`ParallelStyle`):
            Object which contains how we prepare input/output
            for Tensor Parallelism.
        tp_mesh_dim (int):
            The dimension of ``device_mesh`` on which we perform
            Tensor Parallelism.

    Return:
        A parallelized :class:`nn.Module` object.

    .. warning::
        We only support ``PairwiseParallel`` right now.
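
    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> # Hedged sketch: ``mesh`` is an assumed 1-D DeviceMesh. The
        >>> # nn.MultiheadAttention is converted to
        >>> # TensorParallelMultiheadAttention, then its qkv projection is
        >>> # column-wise sharded and its output proj is row-wise sharded.
        >>> mha = nn.MultiheadAttention(embed_dim=64, num_heads=8)
        >>> mha = _parallelize_multihead_attn(mha, mesh)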
    """

    if not isinstance(parallel_style, PairwiseParallel):
        raise NotImplementedError(
            "Only support PairwiseParallel for Multihead Attention parallelization."
        )

    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    if isinstance(module, nn.MultiheadAttention):
        tp_multi_head_attention = TensorParallelMultiheadAttention(
            module.embed_dim,
            module.num_heads,
            device=torch.device(device_mesh.device_type),
            tp_size=device_mesh.size(tp_mesh_dim),
            add_bias_kv=module.bias_k is not None,
        )
        tp_multi_head_attention.copy(module)
        module = tp_multi_head_attention

    assert isinstance(
        module, TensorParallelMultiheadAttention
    ), f"Expects TensorParallelMultiheadAttention but got {type(module)}"
    # shard TPMA
    for n, m in module.named_children():
        if n == "qkv":
            # Col-wise parallelize the qkv layer.
            distribute_module(
                m,
                device_mesh,
                _colwise_parallelize_linear_fn,
                input_fn=parallel_style._prepare_input,  # type: ignore[arg-type, misc] # pyre-ignore[6]
            )
        elif n == "proj":
            # Row-wise parallelize the proj layer.
            distribute_module(
                m,
                device_mesh,
                _rowwise_parallelize_linear_fn,
                output_fn=parallel_style._prepare_output,  # type: ignore[arg-type, misc] # pyre-ignore[6]
            )
    return module


def _parallelize_mlp(
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallel_style: ParallelStyle = PairwiseParallel(),
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    This function assumes the input module is a sequence of :class:`nn.Linear`
    layers and parallelizes it based on the given parallel style. We don't
    change the FQN of each sub-module and we replace each parameter in place.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.
        parallel_style (:class:`ParallelStyle`):
            Object which contains how we prepare input/output
            for Tensor Parallelism.
        tp_mesh_dim (int):
            The dimension of ``device_mesh`` on which we perform
            Tensor Parallelism.

    Return:
        A parallelized :class:`nn.Module` object.

    .. warning::
        We only support ``PairwiseParallel`` right now.
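
    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> # Hedged sketch: ``mesh`` is an assumed 1-D DeviceMesh. Linear
        >>> # layers are sharded in pairs: even-indexed ones column-wise,
        >>> # odd-indexed ones row-wise, so the intermediate activation stays
        >>> # sharded between the two layers of each pair.
        >>> mlp = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
        >>> mlp = _parallelize_mlp(mlp, mesh)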
    """
    if not isinstance(parallel_style, PairwiseParallel):
        raise NotImplementedError(
            "Only support PairwiseParallel for MLP parallelization."
        )

    if not _is_mlp_for_pairwise_parallel(module):
        raise RuntimeError("More than one nn.Linear is needed for an MLP.")

    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    linear_submodules = list(
        filter(lambda x: isinstance(x, nn.Linear), module.children())
    )
    mlp_last_even_layer = (len(linear_submodules) // 2) * 2
    for i in range(mlp_last_even_layer):
        m = linear_submodules[i]
        if i % 2 == 0:
            # Col-wise parallelize the linear layer.
            distribute_module(
                m,
                device_mesh,
                _colwise_parallelize_linear_fn,
                input_fn=parallel_style._prepare_input  # type: ignore[arg-type, misc] # pyre-ignore[6]
                if i == 0
                else None,
            )
        else:
            # Row-wise parallelize the linear layer.
            distribute_module(
                m,
                device_mesh,
                _rowwise_parallelize_linear_fn,
                output_fn=parallel_style._prepare_output  # type: ignore[arg-type, misc] # pyre-ignore[6]
                if i == (mlp_last_even_layer - 1)
                else None,
            )
    return module