# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Optional, Union

import torch
import torch.nn as nn
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    DTensor,
    Replicate,
    Shard,
)
from torch.distributed.tensor.placement_types import Placement


__all__ = [
    "ParallelStyle",
    "RowwiseParallel",
    "SequenceParallel",
    "ColwiseParallel",
    "PrepareModuleInput",
    "PrepareModuleInputOutput",
    "PrepareModuleOutput",
]


class ParallelStyle(ABC):
    """
    The parallel style contract defines how the module or submodule should be parallelized.

    It only defines the ``_apply`` method for ``parallelize_module`` to use, which allows maximum
    flexibility for different kinds of style implementations.
    """

    src_data_rank: Optional[int] = 0

    @abstractmethod
    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ...


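# NOTE: An illustrative sketch, not part of this module's API. It shows the
# minimal surface a custom style must implement: ``_apply`` receives the module
# and a DeviceMesh and returns the parallelized module. The class name
# ``_ReplicateEverything`` and its partition function are hypothetical.
#
# class _ReplicateEverything(ParallelStyle):
#     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
#         def replicate_params(name, mod, mesh):
#             # replicate every parameter of each submodule across the mesh
#             for p_name, param in mod.named_parameters(recurse=False):
#                 mod.register_parameter(
#                     p_name, nn.Parameter(distribute_tensor(param, mesh, [Replicate()]))
#                 )
#         return distribute_module(module, device_mesh, replicate_params)

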
class ColwiseParallel(ParallelStyle):
    """
    Partition a compatible nn.Module in a column-wise fashion. Currently supports nn.Linear and nn.Embedding.
    Users can compose it with RowwiseParallel to shard more complicated modules (e.g. MLP, Attention).

    Keyword Args:
        input_layouts (Placement, optional):
            The DTensor layout of the input tensor for the nn.Module, used to annotate the input tensor so
            that it becomes a DTensor. If not specified, we assume the input tensor to be replicated.
        output_layouts (Placement, optional):
            The DTensor layout of the output for the nn.Module, used to ensure the output of the nn.Module
            has the user-desired layout. If not specified, the output tensor is sharded on the last dimension.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
    Returns:
        A :class:`ParallelStyle` object that represents Colwise sharding of the nn.Module.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> ...
        >>> m = Model(...)  # m is a nn.Module that contains a "w1" nn.Linear submodule
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>>
        >>> # By default, the input of the "w1" Linear will be converted to a replicated DTensor
        >>> # and the output of "w1" will be a :class:`torch.Tensor` sharded on the last dim.
        >>>
        >>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()})
        >>> ...

    .. note:: By default the ``ColwiseParallel`` output is sharded on the last dimension if ``output_layouts`` is
        not specified. If there are operators that require a specific tensor shape (e.g. before the paired
        ``RowwiseParallel``), keep in mind that if the output is sharded, the operator might need to be adjusted
        to the sharded size.
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True,
    ):
        super().__init__()
        self.input_layouts = (input_layouts or Replicate(),)
        self.output_layouts = (output_layouts or Shard(-1),)
        # colwise linear runtime sharding (desired sharding):
        # 1. requires replicated input
        # 2. shards output on the last dim
        self.desired_input_layouts = (Replicate(),)
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(
        input_layouts, desired_input_layouts, mod, inputs, device_mesh
    ):
        # TODO: figure out dynamo support for instance method and switch this to instance method

        # annotate module input placements/sharding with input_layouts
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(
                input_tensor, device_mesh, input_layouts, run_check=False
            )

        # transform the input layouts to the desired layouts of ColwiseParallel
        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(
                placements=desired_input_layouts, async_op=True
            )
        return input_tensor

    def _partition_linear_fn(self, name, module, device_mesh):
        # Colwise shard weight/bias to Shard(0). Since nn.Linear computes
        # input @ weight^T + bias, sharding weight on dim 0 shards weight^T
        # on dim 1, i.e. column-wise.
        for name, param in module.named_parameters():
            dist_param = nn.Parameter(
                distribute_tensor(
                    param, device_mesh, [Shard(0)], src_data_rank=self.src_data_rank
                )
            )
            module.register_parameter(name, dist_param)

    def _partition_embedding_fn(self, name, module, device_mesh):
        # Colwise sharding of embedding.weight is straightforward: Shard(1)
        for name, param in module.named_parameters():
            dist_param = nn.Parameter(
                distribute_tensor(
                    param, device_mesh, [Shard(1)], src_data_rank=self.src_data_rank
                )
            )
            module.register_parameter(name, dist_param)

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        # outputs is a DTensor sharded on the last dimension, i.e. Shard(-1)
        if outputs.placements != output_layouts:
            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
        # back to local tensor
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        if isinstance(module, nn.Linear):
            partition_fn = self._partition_linear_fn
        elif isinstance(module, nn.Embedding):
            partition_fn = self._partition_embedding_fn
        else:
            raise NotImplementedError(
                "ColwiseParallel currently only supports nn.Linear and nn.Embedding!"
            )

        return distribute_module(
            module,
            device_mesh,
            partition_fn,
            partial(
                self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
            ),
            partial(
                self._prepare_output_fn, self.output_layouts, self.use_local_output
            ),
        )

    def __repr__(self) -> str:
        tmpstr = self.__class__.__name__ + "("
        tmpstr += f"input_layouts={self.input_layouts}, "
        tmpstr += f"output_layouts={self.output_layouts}, "
        tmpstr += f"use_local_output={self.use_local_output}"
        tmpstr += ")"
        return tmpstr


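# NOTE: An illustrative sketch, not part of this module's API. For a Linear
# whose weight has shape (out_features, in_features), Shard(0) splits
# out_features across ranks, so each rank computes a slice of the output
# columns:
#
#     Y = X @ W.T + b,  with W placed Shard(0)  =>  Y placed Shard(-1)
#
# e.g. on a hypothetical 2-rank mesh with out_features=4, each rank holds a
# (2, in_features) weight shard and produces the matching 2 output columns.

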
class RowwiseParallel(ParallelStyle):
    """
    Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding.
    Users can compose it with ColwiseParallel to shard more complicated modules (e.g. MLP, Attention).

    Keyword Args:
        input_layouts (Placement, optional):
            The DTensor layout of the input tensor for the nn.Module, used to annotate the input tensor so
            that it becomes a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension.
        output_layouts (Placement, optional):
            The DTensor layout of the output for the nn.Module, used to ensure the output of the nn.Module
            has the user-desired layout. If not specified, the output tensor is replicated.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
    Returns:
        A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> ...
        >>> m = Model(...)  # m is a nn.Module that contains a "w2" nn.Linear submodule
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>>
        >>> # By default, the input of the "w2" Linear will be converted to a DTensor sharded on the last dim
        >>> # and the output of "w2" will be a replicated :class:`torch.Tensor`.
        >>>
        >>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()})
        >>> ...
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True,
    ):
        super().__init__()
        self.input_layouts = (input_layouts or Shard(-1),)
        self.output_layouts = (output_layouts or Replicate(),)
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(
        input_layouts, desired_input_layouts, mod, inputs, device_mesh
    ):
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(
                input_tensor, device_mesh, input_layouts, run_check=False
            )

        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(
                placements=desired_input_layouts, async_op=True
            )
        return input_tensor

    def _partition_linear_fn(self, name, module, device_mesh):
        # Rowwise shard weight to Shard(1) and bias to Replicate(). Since
        # nn.Linear computes input @ weight^T + bias, sharding weight on dim 1
        # shards weight^T on dim 0, i.e. row-wise.
        module.register_parameter(
            "weight",
            nn.Parameter(
                distribute_tensor(
                    module.weight,
                    device_mesh,
                    [Shard(1)],
                    src_data_rank=self.src_data_rank,
                )
            ),
        )
        if getattr(module, "bias", None) is not None:
            # The Linear module has bias
            module.register_parameter(
                "bias",
                nn.Parameter(
                    distribute_tensor(
                        module.bias,
                        device_mesh,
                        [Replicate()],
                        src_data_rank=self.src_data_rank,
                    )
                ),
            )

    def _partition_embedding_fn(self, name, module, device_mesh):
        # Rowwise sharding of embedding.weight is Shard(0)
        for name, param in module.named_parameters():
            dist_param = nn.Parameter(
                distribute_tensor(
                    param, device_mesh, [Shard(0)], src_data_rank=self.src_data_rank
                )
            )
            module.register_parameter(name, dist_param)

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        # Rowwise sharding produces partial output; depending on output layouts:
        # 1. to replicate -> allreduce
        # 2. to shard -> reduce_scatter
        if outputs.placements != output_layouts:
            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
        # back to local tensor if use_local_output is True
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        if isinstance(module, nn.Linear):
            partition_fn = self._partition_linear_fn
            # rowwise linear runtime sharding requires input tensor sharded on the last dim
            self.desired_input_layouts: tuple[Placement, ...] = (Shard(-1),)
        elif isinstance(module, nn.Embedding):
            partition_fn = self._partition_embedding_fn
            # rowwise embedding runtime sharding requires replicated input tensor
            self.desired_input_layouts = (Replicate(),)
        else:
            raise NotImplementedError(
                "RowwiseParallel currently only supports nn.Linear and nn.Embedding!"
            )

        return distribute_module(
            module,
            device_mesh,
            partition_fn,
            partial(
                self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
            ),
            partial(
                self._prepare_output_fn, self.output_layouts, self.use_local_output
            ),
        )

    def __repr__(self) -> str:
        tmpstr = self.__class__.__name__ + "("
        tmpstr += f"input_layouts={self.input_layouts}, "
        tmpstr += f"output_layouts={self.output_layouts}, "
        tmpstr += f"use_local_output={self.use_local_output}"
        tmpstr += ")"
        return tmpstr


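# NOTE: An illustrative sketch, not part of this module's API. A common pairing
# shards an MLP's first Linear column-wise and the second row-wise, so the
# intermediate activation stays sharded on the last dim and only the second
# matmul's partial output needs an allreduce. Module names "w1"/"w2" are
# hypothetical:
#
#     plan = {
#         "w1": ColwiseParallel(),   # input Replicate() -> output Shard(-1)
#         "w2": RowwiseParallel(),   # input Shard(-1)   -> output Replicate()
#     }
#     # parallelize_module(mlp, tp_mesh, plan)

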
class SequenceParallel(ParallelStyle):
    """
    SequenceParallel replicates the parameters of a compatible ``nn.Module`` and runs the sharded computation with
    input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the
    `RMSNorm python implementation <https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34>`__.

    This style implements the operation that is described in the paper
    `Reducing Activation Recomputation in Large Transformer Models <https://arxiv.org/abs/2205.05198>`__

    If the input passed in to this ``nn.Module`` is a :class:`torch.Tensor`, it assumes that the input is already sharded
    on the sequence dimension and converts the input to a :class:`DTensor` sharded on the sequence dimension. If the input
    passed in to this ``nn.Module`` is already a :class:`DTensor` but is not sharded on the sequence dimension, it
    redistributes the input to be sharded on the sequence dimension.

    The output of the ``nn.Module`` will be sharded on the sequence dimension.

    Keyword Args:
        sequence_dim (int, optional):
            The sequence dimension of the input tensor for the ``nn.Module``, used to annotate the input tensor to
            become a DTensor that is sharded on the sequence dimension, default: 1.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False.
    Returns:
        A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> ...
        >>> m = Model(...)  # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>>
        >>> # By default, the input of the "norm" will be converted to a DTensor sharded on the sequence dim
        >>> # and the output of "norm" will be a :class:`DTensor` sharded on the sequence dimension.
        >>>
        >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()})
        >>> ...

    .. note:: SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e.
        ``nn.LayerNorm`` or ``RMSNorm``, which by default have ones initialization). If you have custom
        inits for the weights on those modules, you need to broadcast the weights before/after parallelizing
        to ensure that they are replicated.
    """

    def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False):
        super().__init__()
        self.sequence_sharding = (Shard(sequence_dim),)
        self.use_local_output = use_local_output

    def _replicate_module_fn(
        self, name: str, module: nn.Module, device_mesh: DeviceMesh
    ):
        for p_name, param in module.named_parameters():
            # simple replication with the fixed ones_ init from LayerNorm/RMSNorm,
            # which allows us to simply use from_local
            replicated_param = torch.nn.Parameter(
                DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
            )
            module.register_parameter(p_name, replicated_param)

    @staticmethod
    def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh):
        input_tensor = inputs[0]
        if isinstance(input_tensor, DTensor):
            # if the passed-in input DTensor is not sharded on the sequence dim, we need to redistribute it
            if input_tensor.placements != sequence_sharding:
                input_tensor = input_tensor.redistribute(
                    placements=sequence_sharding, async_op=True
                )
            return input_tensor
        elif isinstance(input_tensor, torch.Tensor):
            # assume the input passed in is already sharded on the sequence dim and create the DTensor
            return DTensor.from_local(
                input_tensor, device_mesh, sequence_sharding, run_check=False
            )
        else:
            raise ValueError(
                f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}"
            )

    @staticmethod
    def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            self._replicate_module_fn,
            partial(self._prepare_input_fn, self.sequence_sharding),
            partial(self._prepare_output_fn, self.use_local_output),
        )

    def __repr__(self) -> str:
        tmpstr = self.__class__.__name__ + "("
        if len(self.sequence_sharding) == 1:
            tmpstr += f"sequence_dim={self.sequence_sharding[0].dim}, "
        tmpstr += f"use_local_output={self.use_local_output}"
        tmpstr += ")"
        return tmpstr


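# NOTE: An illustrative sketch, not part of this module's API. SequenceParallel
# is typically placed on the norm/dropout layers between tensor-parallel blocks,
# so activations of shape (batch, seq, dim) stay sharded as Shard(1) through
# the elementwise/normalization ops. The module name "norm" is hypothetical:
#
#     plan = {"norm": SequenceParallel()}  # input and output are Shard(1) DTensors
#     # parallelize_module(transformer_block, tp_mesh, plan)

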
class PrepareModuleInput(ParallelStyle):
    """
    Configure the nn.Module's inputs to convert the input tensors of the nn.Module to DTensors at runtime according to
    ``input_layouts``, and perform layout redistribution according to the ``desired_input_layouts``.

    Keyword Args:
        input_layouts (Union[Placement, Tuple[Optional[Placement]]]):
            The DTensor layouts of input tensors for the nn.Module, used to convert the input tensors to
            DTensors. If some inputs are not torch.Tensor or do not need to be converted to DTensors, ``None``
            needs to be specified as a placeholder. default: None.
        desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]):
            The desired DTensor layouts of input tensors for the nn.Module, used to ensure the inputs of the nn.Module
            have the desired DTensor layouts. This argument needs to have the same length as ``input_layouts``. default: None.
        input_kwarg_layouts (Dict[str, Placement]):
            The DTensor layouts of input kwargs for the nn.Module, used to convert the input kwarg tensors to DTensors.
            default: None
        desired_input_kwarg_layouts: (Dict[str, Placement]):
            The desired DTensor layouts of input kwargs for the nn.Module, used to ensure the inputs of the nn.Module
            have the desired DTensor layouts. default: None.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False.
    Returns:
        A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> ...
        >>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>>
        >>> # According to the style specified below, the first input of attn will be annotated as a sharded DTensor
        >>> # and then redistributed to a replicated DTensor.
        >>> parallelize_module(
        >>>     block,  # this can be a submodule or module
        >>>     tp_mesh,
        >>>     parallelize_plan={
        >>>         "attn": PrepareModuleInput(
        >>>             input_layouts=(Shard(0), None, None, ...),
        >>>             desired_input_layouts=(Replicate(), None, None, ...)
        >>>         ),
        >>>     }
        >>> )
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[
            Union[Placement, tuple[Optional[Placement], ...]]
        ] = None,
        desired_input_layouts: Optional[
            Union[Placement, tuple[Optional[Placement], ...]]
        ] = None,
        input_kwarg_layouts: Optional[dict[str, Placement]] = None,
        desired_input_kwarg_layouts: Optional[dict[str, Placement]] = None,
        use_local_output: bool = False,
    ):
        self.input_layouts = (
            (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts
        )
        self.desired_input_layouts = (
            (desired_input_layouts,)
            if isinstance(desired_input_layouts, Placement)
            else desired_input_layouts
        )
        self.use_local_output = use_local_output
        if self.input_layouts is not None:
            assert self.desired_input_layouts is not None, (
                "desired module inputs should not be None!"
            )
            assert len(self.input_layouts) == len(self.desired_input_layouts), (
                "input_layouts and desired_input_layouts should have same length!"
            )
        self.with_kwargs = input_kwarg_layouts is not None
        self.input_kwarg_layouts = input_kwarg_layouts or {}
        self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {}
        if self.with_kwargs:
            assert len(self.input_kwarg_layouts) == len(
                self.desired_input_kwarg_layouts
            ), (
                "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"
            )

    def _prepare_input_arg(
        self,
        input: Any,
        mesh: DeviceMesh,
        input_layout: Optional[Placement],
        desired_layout: Optional[Placement],
    ):
        if input_layout is not None:
            if isinstance(input, DTensor):
                # TODO: re-enable the check once we fix the compile path
                # assert inp.placements[0] == input_layout
                dt_inp = input
            else:
                assert isinstance(input, torch.Tensor), (
                    "expecting input to be a torch.Tensor!"
                )
                dt_inp = DTensor.from_local(
                    input, mesh, (input_layout,), run_check=False
                )

            if desired_layout is not None and input_layout != desired_layout:
                dt_inp = dt_inp.redistribute(placements=(desired_layout,))

            return dt_inp.to_local() if self.use_local_output else dt_inp
        else:
            return input

    def _prepare_input_fn(self, inputs, device_mesh):
        if self.input_layouts is None:
            return inputs
        prepared_inputs = []
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
        if len(inputs) != len(self.input_layouts):
            raise ValueError("module inputs and input_layouts should have same length!")

        assert self.desired_input_layouts is not None, (
            "desired module inputs should not be None!"
        )

        for inp, input_layout, desired_layout in zip(
            inputs, self.input_layouts, self.desired_input_layouts
        ):
            prepared_inputs.append(
                self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout)
            )
        return tuple(prepared_inputs)

    def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh):
        prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh)
        prepared_kwarg_inputs = {}
        for kwarg_key in kwarg_inputs.keys():
            kwarg_val = kwarg_inputs[kwarg_key]
            input_layout = self.input_kwarg_layouts.get(kwarg_key)
            desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key)

            prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(
                kwarg_val, device_mesh, input_layout, desired_input_layout
            )

        return (prepared_arg_inputs, prepared_kwarg_inputs)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        if self.with_kwargs:
            module.register_forward_pre_hook(
                lambda _, inputs, kwargs: self._prepare_input_kwarg_fn(
                    inputs, kwargs, device_mesh
                ),
                with_kwargs=True,
            )  # type: ignore[misc]
        else:
            module.register_forward_pre_hook(
                lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)
            )  # type: ignore[misc, call-arg]
        return module

    def __repr__(self) -> str:
        tmpstr = self.__class__.__name__ + "("
        tmpstr += f"input_layouts={self.input_layouts}, "
        tmpstr += f"desired_input_layouts={self.desired_input_layouts}, "
        tmpstr += f"input_kwarg_layouts={self.input_kwarg_layouts}, "
        tmpstr += f"desired_input_kwarg_layouts={self.desired_input_kwarg_layouts}, "
        tmpstr += f"use_local_output={self.use_local_output}"
        tmpstr += ")"
        return tmpstr


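# NOTE: An illustrative sketch, not part of this module's API. Positional and
# kwarg layouts can be mixed; positional entries left as None pass through
# untouched. The submodule name "attn" and the kwarg name "mask" are
# hypothetical:
#
#     style = PrepareModuleInput(
#         input_layouts=(Shard(0), None),            # annotate 1st arg, skip 2nd
#         desired_input_layouts=(Replicate(), None), # allgather the 1st arg
#         input_kwarg_layouts={"mask": Shard(0)},
#         desired_input_kwarg_layouts={"mask": Replicate()},
#     )
#     # parallelize_module(block, tp_mesh, {"attn": style})

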
class PrepareModuleOutput(ParallelStyle):
    """
    Configure the nn.Module's outputs to convert the output tensors of the nn.Module to DTensors at runtime according to
    ``output_layouts``, and perform layout redistribution according to the ``desired_output_layouts``.

    Keyword Args:
        output_layouts (Union[Placement, Tuple[Placement]]):
            The DTensor layouts of output tensors for the nn.Module, used to convert the output tensors to
            DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or do not need to be
            converted to DTensors, ``None`` needs to be specified as a placeholder.
        desired_output_layouts (Union[Placement, Tuple[Placement]]):
            The desired DTensor layouts of output tensors for the nn.Module, used to ensure the outputs of the nn.Module
            have the desired DTensor layouts.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True.
    Returns:
        A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's outputs.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> ...
        >>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>>
        >>> # According to the style specified below, the output of the TransformerBlock will be converted to a
        >>> # replicated DTensor and then redistributed to a sharded DTensor.
        >>> parallelize_module(
        >>>     block,  # this can be a submodule or module
        >>>     tp_mesh,
        >>>     parallelize_plan=PrepareModuleOutput(
        >>>         output_layouts=Replicate(),
        >>>         desired_output_layouts=Shard(0)
        >>>     )
        >>> )
    """

    def __init__(
        self,
        *,
        output_layouts: Union[Placement, tuple[Optional[Placement], ...]],
        desired_output_layouts: Union[Placement, tuple[Placement, ...]],
        use_local_output: bool = True,
    ):
        self.output_layouts = (
            (output_layouts,)
            if isinstance(output_layouts, Placement)
            else output_layouts
        )
        self.desired_output_layouts = (
            (desired_output_layouts,)
            if isinstance(desired_output_layouts, Placement)
            else desired_output_layouts
        )
        self.use_local_output = use_local_output
        assert len(self.output_layouts) == len(self.desired_output_layouts), (
            "output_layouts and desired_output_layouts should have same length!"
        )

    def _prepare_out_fn(self, outputs, device_mesh):
        prepared_outputs = []
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        if len(outputs) != len(self.output_layouts):
            raise ValueError(
                "module outputs and output_layouts should have same length!"
            )

        for out, out_layout, desired_out_layout in zip(
            outputs, self.output_layouts, self.desired_output_layouts
        ):
            if out_layout is not None:
                if isinstance(out, DTensor):
                    # TODO: re-enable the check once we fix the compile path
                    # assert out.placements[0] == out_layout
                    dt_out = out
                else:
                    dt_out = DTensor.from_local(
                        out, device_mesh, (out_layout,), run_check=False
                    )

                if out_layout != desired_out_layout:
                    dt_out = dt_out.redistribute(placements=(desired_out_layout,))
                prepared_outputs.append(
                    dt_out.to_local() if self.use_local_output else dt_out
                )
            else:
                prepared_outputs.append(out)
        if len(prepared_outputs) == 1:
            return prepared_outputs[0]
        else:
            return tuple(prepared_outputs)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        module.register_forward_hook(
            lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh)
        )  # type: ignore[misc, call-arg]
        return module

    def __repr__(self) -> str:
        tmpstr = self.__class__.__name__ + "("
        tmpstr += f"output_layouts={self.output_layouts}, "
        tmpstr += f"desired_output_layouts={self.desired_output_layouts}, "
        tmpstr += f"use_local_output={self.use_local_output}"
        tmpstr += ")"
        return tmpstr


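# NOTE: An illustrative sketch, not part of this module's API. Converting a
# replicated module output back to a row-sharded local tensor:
#
#     style = PrepareModuleOutput(
#         output_layouts=Replicate(),       # annotate the raw torch.Tensor output
#         desired_output_layouts=Shard(0),  # then redistribute to row shards
#         use_local_output=True,            # hand back a plain torch.Tensor
#     )

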
class PrepareModuleInputOutput(ParallelStyle):
    """
    Configure the nn.Module's inputs (and outputs) to convert the input tensors (and output tensors, respectively) of the nn.Module
    to DTensors at runtime according to ``input_layouts`` (and ``output_layouts``, respectively), and perform layout redistribution
    according to the ``desired_input_layouts`` (and ``desired_output_layouts``, respectively). This is a combination of
    :class:`PrepareModuleInput` and :class:`PrepareModuleOutput`.

    Keyword Args:
        input_layouts (Union[Placement, Tuple[Optional[Placement]]]):
            The DTensor layouts of input tensors for the nn.Module, used to convert the input tensors to
            DTensors. If some inputs are not torch.Tensor or do not need to be converted to DTensors, ``None``
            needs to be specified as a placeholder. default: None.
        desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]):
            The desired DTensor layouts of input tensors for the nn.Module, used to ensure the inputs of the nn.Module
            have the desired DTensor layouts. This argument needs to have the same length as ``input_layouts``. default: None.
        input_kwarg_layouts (Dict[str, Placement]):
            The DTensor layouts of input kwargs for the nn.Module, used to convert the input kwarg tensors to DTensors.
            default: None
        desired_input_kwarg_layouts: (Dict[str, Placement]):
            The desired DTensor layouts of input kwargs for the nn.Module, used to ensure the inputs of the nn.Module
            have the desired DTensor layouts. default: None.
        use_local_input (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False.
        output_layouts (Union[Placement, Tuple[Placement]]):
            The DTensor layouts of output tensors for the nn.Module, used to convert the output tensors to
            DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or do not need to be
            converted to DTensors, ``None`` needs to be specified as a placeholder.
        desired_output_layouts (Union[Placement, Tuple[Placement]]):
            The desired DTensor layouts of output tensors for the nn.Module, used to ensure the outputs of the nn.Module
            have the desired DTensor layouts.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True.
    Returns:
        A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs and outputs.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInputOutput
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> ...
        >>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>>
        >>> # According to the style specified below, the first input of attn will be annotated as a sharded DTensor
        >>> # and then redistributed to a replicated DTensor, and the output of the TransformerBlock will be annotated
        >>> # as a replicated DTensor and then redistributed to a sharded DTensor.
        >>> parallelize_module(
        >>>     block,  # this can be a submodule or module
        >>>     tp_mesh,
        >>>     parallelize_plan={
        >>>         "attn": PrepareModuleInputOutput(
        >>>             input_layouts=(Shard(0), None, None, ...),
        >>>             desired_input_layouts=(Replicate(), None, None, ...),
        >>>             output_layouts=Replicate(),
        >>>             desired_output_layouts=Shard(0),
        >>>         ),
        >>>     }
        >>> )
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[
            Union[Placement, tuple[Optional[Placement], ...]]
        ] = None,
        desired_input_layouts: Optional[
            Union[Placement, tuple[Optional[Placement], ...]]
        ] = None,
        input_kwarg_layouts: Optional[dict[str, Placement]] = None,
        desired_input_kwarg_layouts: Optional[dict[str, Placement]] = None,
        use_local_input: bool = False,
        output_layouts: Union[Placement, tuple[Optional[Placement], ...]],
        desired_output_layouts: Union[Placement, tuple[Placement, ...]],
        use_local_output: bool = True,
    ):
        self.prepare_module_input = PrepareModuleInput(
            input_layouts=input_layouts,
            desired_input_layouts=desired_input_layouts,
            input_kwarg_layouts=input_kwarg_layouts,
            desired_input_kwarg_layouts=desired_input_kwarg_layouts,
            use_local_output=use_local_input,
        )
        self.prepare_module_output = PrepareModuleOutput(
            output_layouts=output_layouts,
            desired_output_layouts=desired_output_layouts,
            use_local_output=use_local_output,
        )

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        self.prepare_module_input._apply(module, device_mesh)
        self.prepare_module_output._apply(module, device_mesh)

        return module

    def __repr__(self) -> str:
        tmpstr = self.__class__.__name__ + "("
        tmpstr += f"input_layouts={self.prepare_module_input.input_layouts}, "
        tmpstr += (
            f"desired_input_layouts={self.prepare_module_input.desired_input_layouts}, "
        )
        tmpstr += (
            f"input_kwarg_layouts={self.prepare_module_input.input_kwarg_layouts}, "
        )
        tmpstr += f"desired_input_kwarg_layouts={self.prepare_module_input.desired_input_kwarg_layouts}, "
        tmpstr += f"use_local_input={self.prepare_module_input.use_local_output}, "
        tmpstr += f"output_layouts={self.prepare_module_output.output_layouts}, "
        tmpstr += f"desired_output_layouts={self.prepare_module_output.desired_output_layouts}, "
        tmpstr += f"use_local_output={self.prepare_module_output.use_local_output}"
        tmpstr += ")"
        return tmpstr
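

# NOTE: An illustrative sketch, not part of this module's API.
# PrepareModuleInputOutput registers the same pair of hooks you would get by
# applying PrepareModuleInput and PrepareModuleOutput to the same module:
#
#     combined = PrepareModuleInputOutput(
#         input_layouts=Shard(0),
#         desired_input_layouts=Replicate(),
#         output_layouts=Replicate(),
#         desired_output_layouts=Shard(0),
#     )
#     # is equivalent to applying both of:
#     #   PrepareModuleInput(input_layouts=Shard(0), desired_input_layouts=Replicate())
#     #   PrepareModuleOutput(output_layouts=Replicate(), desired_output_layouts=Shard(0))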