pytorch/torch/distributed/tensor/parallel/api.py
Ke Wen 23c531b3e9 Allow parallelize_module to get device_mesh from ambient context (#134247)
This PR adds support for calling `parallelize_module` from within a model definition, making the model itself a parallel one.

Calling `parallelize_module` is an alternative to maintaining a set of `ColumnWiseLinear`, `RowWiseLinear`, etc. modules, while still allowing one to author a parallel model directly.

(The motivation for authoring a parallel model is that there may be other distributed operations that are not easily captured by any module; see the forward function below. Put differently, the purpose is to exploit the expressiveness of DTensor: DTensors have to be created before ops can be called on them, and having parallelized modules in the model is one way of creating them.)

For example:
```
# Assumed imports for this sketch (names from the public DTensor / tensor parallel APIs):
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed.tensor import DTensor, Replicate
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

class FeedForward(nn.Module):
    def __init__(self, config: TransformerArgs) -> None:
        super().__init__()
        w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w1 = parallelize_module(w1, parallelize_plan=ColwiseParallel())
        self.w2 = parallelize_module(w2, parallelize_plan=RowwiseParallel())
        self.w3 = parallelize_module(w3, parallelize_plan=ColwiseParallel())

    def forward(self, x: Tensor) -> Tensor:
        y: DTensor = self.w2(F.silu(self.w1(x)) * self.w3(x))
        # y is a DTensor with Partial placement; we can return it as is.
        # Alternatively, we could redistribute it to Replicate; there is modeling flexibility here:
        # return y.redistribute(placements=[Replicate()])
        return y

with device_mesh:
    model = FeedForward(config)
    # Now `model` is parallelized onto device_mesh

y = model(x)

```

The `device_mesh` actually used by `parallelize_module` is retrieved from the ambient context, i.e. the enclosing `with device_mesh:` block.

Calling `parallelize_module` from within the model hierarchy also avoids having to spell out *FQNs*, as the out-of-model annotation case requires (see the comparison sketch below).
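
For comparison, here is a minimal sketch of the out-of-model annotation of the same `FeedForward` module, using FQNs and an explicitly passed mesh (the 8-way mesh size is just an assumption for the sketch):

```
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

tp_mesh = init_device_mesh("cuda", (8,))
model = FeedForward(config)
model = parallelize_module(
    model,
    tp_mesh,
    {"w1": ColwiseParallel(), "w2": RowwiseParallel(), "w3": ColwiseParallel()},
)
y = model(x)
```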

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134247
Approved by: https://github.com/tianyu-l
2024-10-09 00:19:03 +00:00


# Copyright (c) Meta Platforms, Inc. and affiliates
import warnings
from fnmatch import fnmatch
from typing import Dict, Optional, Union

import torch
import torch.distributed.tensor._random as random
import torch.nn as nn
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.tensor._random import (
    is_rng_supported_mesh,
    TensorParallelRNGTracker,
)
from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
from torch.distributed.tensor.parallel.style import ParallelStyle


__all__ = [
    "parallelize_module",
]


def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
    parallelize_plan: Optional[Union[ParallelStyle, Dict[str, ParallelStyle]]] = None,
) -> nn.Module:
"""
Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.
We parallelize module or sub_modules based on a parallelize_plan. The parallelize_plan contains
:class:`ParallelStyle`, which indicates how user wants the module or sub_module
to be parallelized.
User can also specify different parallel style per module fully qualified name (FQN).
Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`, if you have a 2-D or N-D :class:`DeviceMesh`,
slice the DeviceMesh to a 1-D sub DeviceMesh first then pass to this API(i.e. ``device_mesh[\"tp\"]``)
Args:
module (:class:`nn.Module`):
Module to be parallelized.
device_mesh (:class:`DeviceMesh`, optional):
Object which describes the mesh topology of devices for the DTensor.
If not specified, the call must be under a DeviceMesh context.
parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]], optional):
The plan used to parallelize the module. It can be either a
:class:`ParallelStyle` object which contains how we prepare
input/output for Tensor Parallelism or it can be a dict of module
FQN and its corresponding :class:`ParallelStyle` object. If not
specified, the call will do nothing at the moment.
Return:
A :class:`nn.Module` object parallelized.
Example::
>>> # xdoctest: +SKIP("distributed")
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>>
>>> # Define the module.
>>> m = Model(...)
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
>>>
.. note:: For complex module architecture like Attention, MLP layers, we recommend composing
different ParallelStyles together (i.e. ``ColwiseParallel`` and ``RowwiseParallel``) and pass
as a parallelize_plan, to achieves the desired sharding computation.
"""
    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")

    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
    _validate_tp_mesh_dim(device_mesh)

    if parallelize_plan is None:
        warnings.warn(
            "No parallelize_plan is provided and auto-parallel is not supported "
            "at the moment, so this parallelize_module call will do nothing."
        )
        return module

    # instantiate a TP RNG state tracker if it's not there
    if is_rng_supported_mesh(device_mesh) and not isinstance(
        random._rng_tracker, TensorParallelRNGTracker
    ):
        random._rng_tracker = TensorParallelRNGTracker(device_mesh.device_type)
        # TODO: we should allow user to pass in the default seed from a config
        random._rng_tracker._manual_seed(device_mesh, base_seed=1234)
        # By default we execute random ops in non-tensor-parallel region. If users want
        # to execute in tensor-parallel region, they can manually set this field to True
        # after parallelizing the model.
        random._rng_tracker.distribute_region_enabled = False

    if isinstance(parallelize_plan, ParallelStyle):
        return parallelize_plan._apply(module, device_mesh)
    elif isinstance(parallelize_plan, dict):
        for module_path, parallelize_style in parallelize_plan.items():
            path_splits = module_path.split(".")
            if len(path_splits) == 0:
                raise ValueError(
                    "Expect module path to be non-empty, but got empty string!"
                )
            while path_splits:
                atom = path_splits.pop(0)
                matched_children = filter(
                    # `t[0]` is child name
                    lambda t: fnmatch(t[0], atom),
                    module.named_children(),
                )
                # apply the plan to all matched submodules
                for _, submodule in matched_children:
                    if path_splits:
                        # we haven't reached the leaf, apply in dict style
                        leaf_path = ".".join(
                            path_splits
                        )  # rest of the path after `atom`
                        parallelize_module(
                            submodule, device_mesh, {leaf_path: parallelize_style}
                        )
                    else:
                        # otherwise, directly apply style to this submodule
                        parallelize_module(submodule, device_mesh, parallelize_style)
        return module
    else:
        raise TypeError(  # pyre-ignore[7]
            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
            f" parallelize_plan, {type(parallelize_plan)} found!"
        )
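
A note on the dict-style plan handled above: because each path component is matched with `fnmatch`, FQN keys may contain wildcards. The sketch below is hypothetical; it assumes a transformer whose blocks live in a `layers` ModuleList and each contain a `feed_forward` submodule with `w1`/`w2`/`w3` linears:

```
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

tp_mesh = init_device_mesh("cuda", (8,))
model = parallelize_module(
    model,
    tp_mesh,
    {
        # "*" matches every child of `layers`, so every block gets the same plan.
        "layers.*.feed_forward.w1": ColwiseParallel(),
        "layers.*.feed_forward.w2": RowwiseParallel(),
        "layers.*.feed_forward.w3": ColwiseParallel(),
    },
)
```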