This PR rewrites the Tensor Parallel implementation. The Tensor Parallel APIs are supposed to be a very thin wrapper around the DTensor APIs, but the current implementation got too messy and buggy, and it is really hard to debug what went wrong when using it. It is crucially important for advanced users and developers to be able to understand the API and its implementation without going through many different kinds of functions and utils, so that they can trust what happens under the hood. In particular, this PR:

* Make `ParallelStyle` a real contract API for `parallelize_module` to take: each concrete `ParallelStyle` only needs to implement `apply` to apply the sharding to an `nn.Module`, and all unnecessary fields are removed. This also enables easier `ParallelStyle` authoring going forward.
* Keep the `ColwiseParallel` and `RowwiseParallel` public interfaces, but refactor them so that parameter sharding and input/output handling live within the style itself, making it easy to understand how Linear/Embedding layers are sharded and how the input/output transformations are performed.
* Remove the private `_prepare_input`/`_prepare_output_fn` fields from both `ColwiseParallel` and `RowwiseParallel`. Since we have thrown deprecation messages in nightlies for a while, TP is a prototype release, and the fields are private, it should be safe to remove them.
* Refactor the recently landed `PrepareModuleInput`/`PrepareModuleOutput` styles: change `output_layouts` to `desired_input/output_layouts`, group the functions inside the style itself, and drop the default arguments for these two styles so users have to specify them and think about the sharding layouts. Also fix bugs where the `use_local_output` flag was not handled.
* Make default arguments `None` instead of `Placement` objects; it is standard Python practice not to use a custom object instance as a default argument.
* Remove all dead APIs (i.e. the `PairwiseParallel` and `SequenceParallel` styles and all prepare input/output functions), as we have thrown deprecation messages for a while; we are in the process of removing all of them from the tests.
* Throw a deprecation warning for `tp_mesh_dim`, as we recommend using device mesh slicing/indexing instead of manually specifying a mesh dim.
* Rewrite the documentation for every `ParallelStyle` and make it clearer what each style does.

TODOs:
* Rewrite the TP tests to adjust for the changes in this PR.
* Add more tests to guard the bug fixes.

Differential Revision: [D51761183](https://our.internmc.facebook.com/intern/diff/D51761183)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114732
Approved by: https://github.com/wz337, https://github.com/fduwjj
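To make the refactored surface concrete, here is a minimal usage sketch of the API described above. The toy model, the 2x4 mesh shape, and the FQNs `"0"`/`"1"` are illustrative assumptions (not part of this PR); the snippet assumes an initialized process group with enough GPUs and that `init_device_mesh` is available from `torch.distributed.device_mesh`:

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Hypothetical 2-D mesh: data parallel x tensor parallel.
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

# Toy two-layer model; the submodule FQNs are "0" and "1".
model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 16))

# Per-FQN plan; slice the mesh by name instead of passing tp_mesh_dim.
model = parallelize_module(
    model,
    mesh_2d["tp"],
    {"0": ColwiseParallel(), "1": RowwiseParallel()},
)
```

Passing the named slice `mesh_2d["tp"]` reflects the deprecation of `tp_mesh_dim` mentioned above.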
120 lines
4.9 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict, Union

import torch
import torch.distributed._tensor.random as random
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
)
from torch.distributed._tensor.random import (
    is_rng_supported_mesh,
    TensorParallelRNGTracker,
)
from torch.distributed.tensor.parallel._utils import (
    _create_1d_device_mesh,
    _deprecate_warnings,
    _validate_tp_mesh_dim,
)
from torch.distributed.tensor.parallel.style import (
    ParallelStyle,
)


__all__ = [
    "parallelize_module",
]


def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.

    We parallelize the module or sub-modules based on a ``parallelize_plan``. The plan contains
    :class:`ParallelStyle`, which indicates how the user wants the module or sub-module
    to be parallelized.

    Users can also specify different parallel styles per module fully qualified name (FQN).
    The API supports 2D parallelism natively by accepting an n-dimensional ``device_mesh``;
    users just need to specify the dimension on which tensor parallelism is performed.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
            of devices for the DTensor.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]):
            The plan used to parallelize the module. It can be either a
            :class:`ParallelStyle` object, which contains how
            we prepare input/output for Tensor Parallelism, or it can be a
            dict of module FQNs and their corresponding :class:`ParallelStyle` objects.
        tp_mesh_dim (int):
            The dimension of ``device_mesh`` on which we perform
            Tensor Parallelism. This argument is deprecated; prefer passing in a
            1-D tensor parallel sub-mesh (e.g. ``device_mesh["tp"]``) instead.

    Return:
        The parallelized :class:`nn.Module` object.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
        >>>
        >>> # Define the module.
        >>> m = Model(...)
        >>> # Assume ``tp_mesh`` is a 1-D :class:`DeviceMesh` over the tensor parallel ranks.
        >>> m = parallelize_module(m, tp_mesh, ColwiseParallel())
        >>>

    .. warning::
        Currently, there are some constraints which make it hard for complicated modules
        like ``MultiheadAttention`` to work out of the box for Tensor or Sequence Parallelism.
        We recommend users try ``ColwiseParallel`` and ``RowwiseParallel`` for each parameter
        or submodule; some code changes might be needed for now.
    """
    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")

    # instantiate a TP RNG state tracker if it's not there
    if is_rng_supported_mesh(device_mesh) and not isinstance(
        random._rng_tracker, TensorParallelRNGTracker
    ):
        random._rng_tracker = TensorParallelRNGTracker(device_mesh.device_type)
        # TODO: we should allow user to pass in the default seed from a config
        random._rng_tracker._manual_seed(
            device_mesh, base_seed=1234, tp_dim=tp_mesh_dim
        )
        # By default we execute random ops in non-tensor-parallel region. If users want
        # to execute in tensor-parallel region, they can manually set this field to True
        # after parallelizing the model.
        random._rng_tracker.distribute_region_enabled = False

    if device_mesh.ndim > 1:
        _deprecate_warnings(
            "tp_mesh_dim",
            'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"]',
        )
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)
    else:
        _validate_tp_mesh_dim(device_mesh)

    if isinstance(parallelize_plan, ParallelStyle):
        return parallelize_plan._apply(module, device_mesh)
    elif isinstance(parallelize_plan, dict):
        for module_path, parallelize_style in parallelize_plan.items():
            sub_module = module.get_submodule(module_path)
            parent_module = module
            if "." in module_path:
                parent_module_path = ".".join(module_path.split(".")[:-1])
                parent_module = module.get_submodule(parent_module_path)
                module_path = module_path.split(".")[-1]
            parent_module.register_module(  # type: ignore[call-arg] # pyre-ignore[20]
                module_path,
                parallelize_module(  # type: ignore[arg-type]
                    sub_module, device_mesh, parallelize_style  # type: ignore[arg-type] # pyre-ignore[6]
                ),
            )
        return module
    else:
        raise RuntimeError(  # pyre-ignore[7]
            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
            f" parallelize_plan, {type(parallelize_plan)} found!"
        )
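# ---------------------------------------------------------------------------
# Editorial usage sketch (not part of the module above): how the dict branch of
# parallelize_module resolves a nested FQN. The toy model, the FQN
# "block.proj", and the 1-D mesh are hypothetical assumptions; running this
# requires an initialized process group with four GPUs.
# ---------------------------------------------------------------------------
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()


tp_mesh = init_device_mesh("cuda", (4,))  # 1-D tensor parallel mesh
model = parallelize_module(Model(), tp_mesh, {"block.proj": ColwiseParallel()})
# "block.proj" is split into parent path "block" and leaf name "proj"; the
# Linear is parallelized recursively and re-registered on the parent via
# register_module, so the returned model holds a tensor-parallel projection.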