This PR adds support for calling `parallelize_module` from within a model definition, turning the model into a parallel one as it is constructed.
Calling `parallelize_module` this way is an alternative to maintaining a set of dedicated modules such as `ColumnWiseLinear`, `RowWiseLinear`, etc., while still being able to author a parallel model directly.
(The motivation for authoring a parallel model is that there may be other distributed operations that are not easily captured by any module; see the forward function below. Put differently, the purpose is to exploit the expressiveness of DTensor -- we need to create DTensors before calling ops on them, and having parallelized modules in the model is one way of creating them.)
For example:
```
class FeedForward(nn.Module):
    def __init__(self, config: TransformerArgs) -> None:
        super().__init__()
        w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w1 = parallelize_module(w1, Colwise)
        self.w2 = parallelize_module(w2, Rowwise)
        self.w3 = parallelize_module(w3, Colwise)

    def forward(self, x: Tensor) -> Tensor:
        y: DTensor = self.w2(F.silu(self.w1(x)) * self.w3(x))
        # y is a DTensor with Partial placement; we can return it as is,
        return y
        # or convert it to Replicate -- there is modeling flexibility here:
        # return y.redistribute(Replicate())


with device_mesh:
    model = FeedForward(config)
    # Now model is a model parallelized onto device_mesh
    y = model(x)
```
The `device_mesh` actually used by `parallelize_module` is retrieved from the ambient context (the enclosing `with device_mesh:` block).
Calling `parallelize_module` from within the model hierarchy also avoids the *FQN*-based plan required in the out-of-model annotation case, as the contrast example below shows.
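For contrast, here is a minimal sketch of the out-of-model annotation style, where the plan is keyed by submodule FQNs and the mesh is passed explicitly. It assumes a plain `FeedForward` defined without the inner `parallelize_module` calls; the names here are illustrative, not part of this PR.
```
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

tp_mesh = init_device_mesh("cuda", (8,))  # 1-D mesh for tensor parallelism
model = FeedForward(config)               # authored as a plain, non-parallel module
model = parallelize_module(
    model,
    tp_mesh,  # mesh passed explicitly instead of taken from the ambient context
    {
        "w1": ColwiseParallel(),
        "w2": RowwiseParallel(),
        "w3": ColwiseParallel(),
    },
)
```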
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134247
Approved by: https://github.com/tianyu-l
# Copyright (c) Meta Platforms, Inc. and affiliates
import warnings
from fnmatch import fnmatch
from typing import Dict, Optional, Union

import torch
import torch.distributed.tensor._random as random
import torch.nn as nn
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.tensor._random import (
    is_rng_supported_mesh,
    TensorParallelRNGTracker,
)
from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
from torch.distributed.tensor.parallel.style import ParallelStyle


__all__ = [
    "parallelize_module",
]


def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
    parallelize_plan: Optional[Union[ParallelStyle, Dict[str, ParallelStyle]]] = None,
) -> nn.Module:
    """
    Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.

    We parallelize the module or sub_modules based on a parallelize_plan. The parallelize_plan contains
    :class:`ParallelStyle`, which indicates how the user wants the module or sub_module
    to be parallelized.

    Users can also specify different parallel styles per module fully qualified name (FQN).

    Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`. If you have a 2-D or N-D :class:`DeviceMesh`,
    slice the DeviceMesh to a 1-D sub-DeviceMesh first and then pass it to this API (i.e. ``device_mesh["tp"]``).

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`, optional):
            Object which describes the mesh topology of devices for the DTensor.
            If not specified, the call must be under a DeviceMesh context.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]], optional):
            The plan used to parallelize the module. It can be either a
            :class:`ParallelStyle` object which contains how we prepare
            input/output for Tensor Parallelism, or it can be a dict of module
            FQN and its corresponding :class:`ParallelStyle` object. If not
            specified, the call will do nothing at the moment.
    Return:
        A :class:`nn.Module` object parallelized.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>>
        >>> # Define the module.
        >>> m = Model(...)
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
        >>>

    .. note:: For complex module architectures like Attention and MLP layers, we recommend composing
        different ParallelStyles together (i.e. ``ColwiseParallel`` and ``RowwiseParallel``) and passing
        them as a parallelize_plan to achieve the desired sharding computation.
    """
    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")

    # fall back to the DeviceMesh from the ambient context if none was passed in
    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
    _validate_tp_mesh_dim(device_mesh)

    if parallelize_plan is None:
        warnings.warn(
            "No parallelize_plan is provided and auto-parallel is not supported "
            "at the moment, so this parallelize_module call will do nothing."
        )
        return module

    # instantiate a TP RNG state tracker if it's not there
    if is_rng_supported_mesh(device_mesh) and not isinstance(
        random._rng_tracker, TensorParallelRNGTracker
    ):
        random._rng_tracker = TensorParallelRNGTracker(device_mesh.device_type)
        # TODO: we should allow user to pass in the default seed from a config
        random._rng_tracker._manual_seed(device_mesh, base_seed=1234)
        # By default we execute random ops in non-tensor-parallel region. If users want
        # to execute in tensor-parallel region, they can manually set this field to True
        # after parallelizing the model.
        random._rng_tracker.distribute_region_enabled = False

    if isinstance(parallelize_plan, ParallelStyle):
        return parallelize_plan._apply(module, device_mesh)
    elif isinstance(parallelize_plan, dict):
        for module_path, parallelize_style in parallelize_plan.items():
            path_splits = module_path.split(".")
            if len(path_splits) == 0:
                raise ValueError(
                    "Expect module path to be non-empty, but got empty string!"
                )
            # walk the FQN path one atom at a time; each atom may be an
            # fnmatch-style wildcard (e.g. "layers.*.w1")
            while path_splits:
                atom = path_splits.pop(0)
                matched_children = filter(
                    # `t[0]` is child name
                    lambda t: fnmatch(t[0], atom),
                    module.named_children(),
                )
                # apply the plan to all matched submodules
                for _, submodule in matched_children:
                    if path_splits:
                        # we haven't reached the leaf, apply in dict style
                        leaf_path = ".".join(
                            path_splits
                        )  # rest of the path after `atom`
                        parallelize_module(
                            submodule, device_mesh, {leaf_path: parallelize_style}
                        )
                    else:
                        # otherwise, directly apply style to this submodule
                        parallelize_module(submodule, device_mesh, parallelize_style)
        return module
    else:
        raise TypeError(  # pyre-ignore[7]
            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
            f" parallelize_plan, {type(parallelize_plan)} found!"
        )
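Below is a minimal usage sketch of the dict-form plan against a toy model; the module and FQN names (`FFN`, `ToyModel`, `layers.*.w1`, etc.) are illustrative, not part of this file. Because submodule matching goes through `fnmatch`, plan keys may contain wildcards, so a single entry can cover every layer.
```
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class FFN(nn.Module):
    def __init__(self, dim: int, hidden: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


class ToyModel(nn.Module):
    def __init__(self, n_layers: int = 4, dim: int = 256, hidden: int = 1024) -> None:
        super().__init__()
        self.layers = nn.ModuleList(FFN(dim, hidden) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


tp_mesh = init_device_mesh("cuda", (8,))  # 1-D mesh; slice an N-D mesh first if needed
model = parallelize_module(
    ToyModel(),
    tp_mesh,
    {
        # nested FQNs are resolved level by level; "*" matches every layer index
        "layers.*.w1": ColwiseParallel(),
        "layers.*.w2": RowwiseParallel(),
    },
)
```
As the comments in the function body note, random ops execute in the non-tensor-parallel region by default after parallelization; setting `random._rng_tracker.distribute_region_enabled = True` afterwards opts into tensor-parallel RNG.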