# Copyright (c) Meta Platforms, Inc. and affiliates
from abc import ABC, abstractmethod
from typing import Optional, Union, Tuple
from functools import partial

import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
    DTensor,
    Placement,
    Replicate,
    Shard,
    distribute_tensor,
    distribute_module,
)

__all__ = [
    "ParallelStyle",
    "RowwiseParallel",
    "ColwiseParallel",
    "PrepareModuleInput",
    "PrepareModuleOutput",
]


class ParallelStyle(ABC):
    """
    The parallel style contract defines how the module or submodule should be parallelized.

    It only defines the ``apply`` method for ``parallelize_module`` to use, which allows maximum
    flexibility for different kinds of style implementations.
    """

    @abstractmethod
    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        ...
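
    # A minimal sketch of a custom style under this contract (a hypothetical
    # example, not part of this module): subclass ``ParallelStyle`` and
    # implement ``_apply`` to transform the module, e.g. replicating every
    # parameter across the mesh:
    #
    #   class ReplicateStyle(ParallelStyle):
    #       def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
    #           # with no partition_fn, distribute_module replicates the module parameters
    #           return distribute_module(module, device_mesh)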


class ColwiseParallel(ParallelStyle):
    """
    Partition a compatible nn.Module in a column-wise fashion. Currently supports nn.Linear and nn.Embedding.
    Users can compose it together with RowwiseParallel to achieve the sharding of more complicated modules
    (e.g. MLP, Attention).

    Keyword Args:
        input_layouts (Placement, optional):
            The DTensor layout of the input tensor for the nn.Module, this is used to annotate the input tensor to
            become a DTensor. If not specified, we assume the input tensor to be replicated.
        output_layouts (Placement, optional):
            The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module
            has the user-desired layout. If not specified, the output tensor is sharded on the last dimension.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
    Returns:
        A :class:`ParallelStyle` object that represents Colwise sharding of the nn.Module.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
        >>> ...
        >>> # By default, the input of the "w1" Linear will be annotated as a Replicated DTensor
        >>> # and the output of "w1" will return a :class:`torch.Tensor` that is sharded on the last dim.
        >>>
        >>> parallelize_module(
        >>>     module=block, # this can be a submodule or module
        >>>     ...,
        >>>     parallelize_plan={"w1": ColwiseParallel()},
        >>> )
        >>> ...

    .. note:: By default, the ``ColwiseParallel`` output is sharded on the last dimension if ``output_layouts`` is not
        specified. If there are operators that require a specific tensor shape (e.g. before the paired ``RowwiseParallel``),
        keep in mind that if the output is sharded, the operator might need to be adjusted to the sharded size.
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True
    ):
        super().__init__()
        self.input_layouts = (input_layouts or Replicate(), )
        self.output_layouts = (output_layouts or Shard(-1), )
        # colwise linear runtime sharding (desired sharding):
        # 1. requires replicate input
        # 2. shard output on last dim
        self.desired_input_layouts = (Replicate(), )
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(input_layouts, desired_input_layouts, inputs, device_mesh):
        # TODO: figure out dynamo support for instance method and switch this to instance method

        # annotate module input placements/sharding with input_layouts
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)

        # transform the input layouts to the desired layouts of ColwiseParallel
        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(placements=desired_input_layouts)
        return input_tensor

    def _partition_fn(self, name, module, device_mesh):
        if isinstance(module, nn.Linear):
            # colwise shard weight/bias to Shard(0): since Linear computes
            # input @ weight^T + bias, sharding weight on dim 0 shards
            # weight^T on dim 1, i.e. column-wise
            for name, param in module.named_parameters():
                dist_param = nn.Parameter(
                    distribute_tensor(param, device_mesh, [Shard(0)])
                )
                module.register_parameter(name, dist_param)
        elif isinstance(module, nn.Embedding):
            # colwise shard of embedding.weight is straightforward as Shard(1)
            for name, param in module.named_parameters():
                dist_param = nn.Parameter(
                    distribute_tensor(param, device_mesh, [Shard(1)])
                )
                module.register_parameter(name, dist_param)
        else:
            raise NotImplementedError(
                "ColwiseParallel only supports nn.Linear "
                f"and nn.Embedding for now, but found {type(module)}!"
            )

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, outputs, device_mesh):
        # outputs is a shard on last dimension DTensor, i.e. Shard(-1)
        outputs = outputs.redistribute(placements=output_layouts)
        # back to local tensor
        return outputs.to_local() if use_local_output else outputs
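
    # ``distribute_module`` below shards the parameters via the partition fn at
    # parallelize time and installs the input/output fns as forward pre/post
    # hooks, so activations are annotated and redistributed on every forward.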
    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            self._partition_fn,
            partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
            partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
        )


class RowwiseParallel(ParallelStyle):
    """
    Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear only.
    Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules
    (e.g. MLP, Attention).

    Keyword Args:
        input_layouts (Placement, optional):
            The DTensor layout of the input tensor for the nn.Module, this is used to annotate the input tensor to
            become a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension.
        output_layouts (Placement, optional):
            The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module
            has the user-desired layout. If not specified, the output tensor is replicated.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
    Returns:
        A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
        >>> ...
        >>> # By default, the input of the "w2" Linear will be annotated as a DTensor that is sharded on the last dim
        >>> # and the output of "w2" will return a replicated :class:`torch.Tensor`.
        >>>
        >>> parallelize_module(
        >>>     module=block, # this can be a submodule or module
        >>>     ...,
        >>>     parallelize_plan={"w2": RowwiseParallel()},
        >>> )
        >>> ...
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True
    ):
        super().__init__()
        self.input_layouts = (input_layouts or Shard(-1), )
        self.output_layouts = (output_layouts or Replicate(), )
        # rowwise linear runtime sharding:
        # 1. shard input on last dim
        # 2. partial output, to replicate -> allreduce, to shard -> reduce_scatter
        self.desired_input_layouts = (Shard(-1), )
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(input_layouts, desired_input_layouts, inputs, device_mesh):
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)

        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(placements=desired_input_layouts)
        return input_tensor

    def _partition_fn(self, name, module, device_mesh):
        if isinstance(module, nn.Linear):
            # rowwise shard weight to Shard(1) and bias to Replicate(): since
            # Linear computes input @ weight^T + bias, sharding weight on dim 1
            # shards weight^T on dim 0, i.e. row-wise
            module.register_parameter("weight", nn.Parameter(
                distribute_tensor(module.weight, device_mesh, [Shard(1)])
            ))
            if module.bias is not None:
                module.register_parameter("bias", nn.Parameter(
                    distribute_tensor(module.bias, device_mesh, [Replicate()])
                ))
        else:
            raise NotImplementedError("RowwiseParallel currently only supports nn.Linear!")

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, outputs, device_mesh):
        outputs = outputs.redistribute(placements=output_layouts)
        # back to local tensor if use_local_output is True
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            self._partition_fn,
            partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
            partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
        )
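
# Note: the common composition (a sketch; ``mesh`` is a DeviceMesh and "w1"/"w2"
# are the submodule names used in the docstring examples above) is:
#
#   parallelize_module(block, mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
#
# The colwise output stays sharded on the last dim, which is exactly the input
# layout the rowwise style expects, so no redistribution happens between the two
# linears; only the rowwise output needs a collective (an all-reduce when a
# replicated output is requested).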


class PrepareModuleInput(ParallelStyle):
    """
    Configure the nn.Module's inputs to convert the input tensors of the nn.Module to DTensors at runtime according to
    ``input_layouts``, and perform layout redistribution according to the ``desired_input_layouts``.

    Keyword Args:
        input_layouts (Union[Placement, Tuple[Placement]]):
            The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to
            DTensors. If some inputs are not torch.Tensor or do not need to be converted to DTensors, ``None`` needs
            to be specified as a placeholder.
        desired_input_layouts (Union[Placement, Tuple[Placement]]):
            The desired DTensor layouts of input tensors for the nn.Module, this is used to ensure the inputs of the
            nn.Module have the desired DTensor layouts. This argument needs to have the same length as ``input_layouts``.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False.
    Returns:
        A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
        >>> ...
        >>> # According to the style specified below, the first input of attn will be annotated as a Sharded DTensor
        >>> # and then redistributed to a Replicated DTensor.
        >>> parallelize_module(
        >>>     module=block, # this can be a submodule or module
        >>>     ...,
        >>>     parallelize_plan={
        >>>         "attn": PrepareModuleInput(
        >>>             input_layouts=(Shard(0), None, None, ...),
        >>>             desired_input_layouts=(Replicate(), None, None, ...)
        >>>         ),
        >>>     }
        >>> )
    """

    def __init__(
        self,
        *,
        input_layouts: Union[Placement, Tuple[Placement]],
        desired_input_layouts: Union[Placement, Tuple[Placement]],
        use_local_output: bool = False
    ):
        self.input_layouts = (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts
        self.desired_input_layouts = \
            (desired_input_layouts,) if isinstance(desired_input_layouts, Placement) else desired_input_layouts
        self.use_local_output = use_local_output
        assert len(self.input_layouts) == len(self.desired_input_layouts), \
            "input_layouts and desired_input_layouts should have same length!"

    def _prepare_input_fn(self, inputs, device_mesh):
        prepared_inputs = []
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
        assert len(inputs) == len(self.input_layouts), \
            "module inputs and input_layouts should have same length!"
        for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts):
            if input_layout is not None:
                if isinstance(inp, DTensor):
                    assert inp.placements[0] == input_layout
                    dt_inp = inp
                else:
                    dt_inp = DTensor.from_local(inp, device_mesh, (input_layout,), run_check=False)
                if input_layout != desired_layout:
                    dt_inp = dt_inp.redistribute(placements=(desired_layout,))
                prepared_inputs.append(dt_inp.to_local() if self.use_local_output else dt_inp)
            else:
                prepared_inputs.append(inp)
        return tuple(prepared_inputs)
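
    # Note: the forward pre-hook below only intercepts the module's positional
    # arguments; keyword arguments are passed through to the module unchanged.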
    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh))  # type: ignore[misc, call-arg]
        return module


class PrepareModuleOutput(ParallelStyle):
    """
    Configure the nn.Module's outputs to convert the output tensors of the nn.Module to DTensors at runtime according to
    ``output_layouts``, and perform layout redistribution according to the ``desired_output_layouts``.

    Keyword Args:
        output_layouts (Union[Placement, Tuple[Placement]]):
            The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to
            DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or do not need to be
            converted to DTensors, ``None`` needs to be specified as a placeholder.
        desired_output_layouts (Union[Placement, Tuple[Placement]]):
            The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the
            nn.Module have the desired DTensor layouts.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True.
    Returns:
        A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's outputs.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
        >>> ...
        >>> # According to the style specified below, the output of the "submodule" will be annotated as a Replicated
        >>> # DTensor and then redistributed to a Sharded DTensor.
        >>> parallelize_module(
        >>>     module=block, # this can be a submodule or module
        >>>     ...,
        >>>     parallelize_plan={
        >>>         "submodule": PrepareModuleOutput(
        >>>             output_layouts=Replicate(),
        >>>             desired_output_layouts=Shard(0)
        >>>         ),
        >>>     }
        >>> )
    """

    def __init__(
        self,
        *,
        output_layouts: Union[Placement, Tuple[Placement]],
        desired_output_layouts: Union[Placement, Tuple[Placement]],
        use_local_output: bool = True
    ):
        self.output_layouts = (output_layouts,) if isinstance(output_layouts, Placement) else output_layouts
        self.desired_output_layouts = \
            (desired_output_layouts,) if isinstance(desired_output_layouts, Placement) else desired_output_layouts
        self.use_local_output = use_local_output
        assert len(self.output_layouts) == len(self.desired_output_layouts), \
            "output_layouts and desired_output_layouts should have same length!"

    def _prepare_out_fn(self, outputs, device_mesh):
        prepared_outputs = []
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        assert len(outputs) == len(self.output_layouts), \
            "module outputs and output_layouts should have same length!"
        for out, out_layout, desired_out_layout in zip(outputs, self.output_layouts, self.desired_output_layouts):
            if out_layout is not None:
                if isinstance(out, DTensor):
                    assert out.placements[0] == out_layout
                    dt_out = out
                else:
                    dt_out = DTensor.from_local(out, device_mesh, (out_layout,), run_check=False)

                if out_layout != desired_out_layout:
                    dt_out = dt_out.redistribute(placements=(desired_out_layout,))
                prepared_outputs.append(dt_out.to_local() if self.use_local_output else dt_out)
            else:
                prepared_outputs.append(out)
        if len(prepared_outputs) == 1:
            return prepared_outputs[0]
        else:
            return tuple(prepared_outputs)
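
    # Note: the forward hook below replaces the module's outputs with the
    # prepared ones; ``_prepare_out_fn`` above returns a single (non-tuple)
    # output unwrapped so the module's output structure is preserved.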
    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        module.register_forward_hook(lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh))  # type: ignore[misc, call-arg]
        return module