Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Move DTensor into the public namespace, so that a documentation page covering all of the public APIs can be formally added. This includes:

* many path renames and import-path fixes
* a dedicated doc page, without much content yet (to be filled in by the next PRs)
* a shim script that redirects the old `torch.distributed._tensor` import paths to the new module, to preserve backward compatibility (BC) for users still on the old path

The BC preservation is evidenced by the fact that all DTensor tests still pass without changing the public imports, so the changes are safe to land.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
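For context, the snippet below is only a minimal sketch of the kind of backward-compatibility shim the commit message describes; the actual shim added by the PR lives in the PyTorch tree and may be organized differently. The legacy file path and the submodule chosen for redirection are assumptions made purely for illustration.

# Hypothetical BC shim sketch (illustrative only; not the PR's actual code).
# If placed at the legacy path torch/distributed/_tensor/__init__.py, it
# re-exports the new public module so that code like
# `from torch.distributed._tensor import DTensor` keeps working.
import sys

# Re-export the public API from its new home.
from torch.distributed.tensor import *  # noqa: F401,F403

# Redirect dotted imports through the old path to the new submodules.
# The submodule listed here is illustrative, not exhaustive.
import torch.distributed.tensor.placement_types as _placement_types

sys.modules["torch.distributed._tensor.placement_types"] = _placement_types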
119 lines | 5.1 KiB | Python
# Copyright (c) Meta Platforms, Inc. and affiliates
from fnmatch import fnmatch
from typing import Dict, Union

import torch
import torch.distributed.tensor.random as random
import torch.nn as nn
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
from torch.distributed.tensor.parallel.style import ParallelStyle
from torch.distributed.tensor.random import (
    is_rng_supported_mesh,
    TensorParallelRNGTracker,
)


__all__ = [
    "parallelize_module",
]


def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
) -> nn.Module:
    """
    Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.

    We parallelize the module or sub_modules based on a parallelize_plan. The parallelize_plan contains
    :class:`ParallelStyle`, which indicates how the user wants the module or sub_module
    to be parallelized.

    Users can also specify a different parallel style per module fully qualified name (FQN).

    Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`. If you have a 2-D or N-D :class:`DeviceMesh`,
    slice the DeviceMesh into a 1-D sub DeviceMesh first and then pass it to this API (i.e. ``device_mesh["tp"]``).

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
            of devices for the DTensor.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]):
            The plan used to parallelize the module. It can be either a
            :class:`ParallelStyle` object, which describes how
            we prepare input/output for Tensor Parallelism, or a
            dict mapping module FQNs to their corresponding :class:`ParallelStyle` objects.
    Return:
        The parallelized :class:`nn.Module` object.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>>
        >>> # Define the module.
        >>> m = Model(...)
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
        >>>

    .. note:: For complex module architectures like Attention or MLP layers, we recommend composing
        different ParallelStyles together (e.g. ``ColwiseParallel`` and ``RowwiseParallel``) and passing
        them as a parallelize_plan to achieve the desired sharding computation.
    """
    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")

    _validate_tp_mesh_dim(device_mesh)

    # instantiate a TP RNG state tracker if it's not there
    if is_rng_supported_mesh(device_mesh) and not isinstance(
        random._rng_tracker, TensorParallelRNGTracker
    ):
        random._rng_tracker = TensorParallelRNGTracker(device_mesh.device_type)
        # TODO: we should allow user to pass in the default seed from a config
        random._rng_tracker._manual_seed(device_mesh, base_seed=1234)
        # By default we execute random ops in non-tensor-parallel region. If users want
        # to execute in tensor-parallel region, they can manually set this field to True
        # after parallelizing the model.
        random._rng_tracker.distribute_region_enabled = False

    if isinstance(parallelize_plan, ParallelStyle):
        return parallelize_plan._apply(module, device_mesh)
    elif isinstance(parallelize_plan, dict):
        for module_path, parallelize_style in parallelize_plan.items():
            path_splits = module_path.split(".")
            if len(path_splits) == 0:
                raise ValueError(
                    "Expect module path to be non-empty, but got empty string!"
                )
            # walk the FQN one atom at a time; each atom is matched with
            # fnmatch, so wildcard patterns such as "layers.*" are supported
            while path_splits:
                atom = path_splits.pop(0)
                matched_children = filter(
                    # `t[0]` is child name
                    lambda t: fnmatch(t[0], atom),
                    module.named_children(),
                )
                # apply the plan to all matched submodules
                for _, submodule in matched_children:
                    if path_splits:
                        # we haven't reached the leaf, apply in dict style
                        leaf_path = ".".join(
                            path_splits
                        )  # rest of the path after `atom`
                        parallelize_module(
                            submodule, device_mesh, {leaf_path: parallelize_style}
                        )
                    else:
                        # otherwise, directly apply style to this submodule
                        parallelize_module(submodule, device_mesh, parallelize_style)
        return module
    else:
        raise TypeError(  # pyre-ignore[7]
            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
            f" parallelize_plan, {type(parallelize_plan)} found!"
        )
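Because each FQN atom in the dict-style plan is matched against child module names with fnmatch, plan keys may contain wildcards. The following is a hedged usage sketch of that behavior; the model and attribute names (Block, layers, wq, wo) are invented for illustration, only parallelize_module, ColwiseParallel, RowwiseParallel, and init_device_mesh come from the real API, and the script assumes it is launched via torchrun on a host with 8 GPUs.

# Hypothetical usage sketch of a wildcard dict-style parallelize_plan
# (illustrative only; assumes torchrun with WORLD_SIZE=8 on CUDA devices).
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)


class Block(nn.Module):
    def __init__(self, dim: int = 256) -> None:
        super().__init__()
        self.wq = nn.Linear(dim, dim)
        self.wo = nn.Linear(dim, dim)


class Model(nn.Module):
    def __init__(self, num_layers: int = 2) -> None:
        super().__init__()
        self.layers = nn.ModuleList([Block() for _ in range(num_layers)])


tp_mesh = init_device_mesh("cuda", (8,))

# Each FQN atom is matched with fnmatch, so "layers.*" applies the same
# style to every Block inside the ModuleList.
model = parallelize_module(
    Model(),
    tp_mesh,
    {
        "layers.*.wq": ColwiseParallel(),
        "layers.*.wo": RowwiseParallel(),
    },
)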