Fix pydocstyle errors in fully_sharded_data_parallel.py, api.py, graph_utils.py, distribute.py, iter_graph_module.py, comm_tensor.py, experimental_ops.py, batch_dim_utils.py, data_parallel.py, graph_optimization.py (#113216)

Fixes #113191

```
pydocstyle torch/distributed/fsdp/fully_sharded_data_parallel.py --count
```

On master: 80
After my changes on this PR: 3

```
pydocstyle torch/distributed/_spmd/comm_tensor.py --count
```
On master: 5
After my changes on this PR: 3

```
pydocstyle torch/distributed/_spmd/experimental_ops.py --count
```
On master: 3
After my changes on this PR: 1

```
pydocstyle torch/distributed/_spmd/iter_graph_module.py --count
```
On master: 39
After my changes on this PR: 27

```
pydocstyle torch/distributed/_spmd/graph_utils.py --count
```
On master: 16
After my changes on this PR: 4

```
pydocstyle torch/distributed/_spmd/distribute.py --count
```
On master: 19
After my changes on this PR: 10

```
pydocstyle torch/distributed/_spmd/api.py --count
```
On master: 10
After my changes on this PR: 3

```
pydocstyle torch/distributed/_spmd/batch_dim_utils.py --count
```
On master: 14
After my changes on this PR: 3

```
pydocstyle torch/distributed/_spmd/data_parallel.py --count
```
On master: 34
After my changes on this PR: 2

```
pydocstyle torch/distributed/_spmd/graph_optimization.py --count
```
On master: 35
After my changes on this PR: 13
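
To reproduce these numbers in one pass, a small helper along these lines can be used (a sketch only, assuming `pydocstyle` is installed and the script is run from the repository root; the file list mirrors the title of this PR):

```
import subprocess

FILES = [
    "torch/distributed/fsdp/fully_sharded_data_parallel.py",
    "torch/distributed/_spmd/api.py",
    "torch/distributed/_spmd/graph_utils.py",
    "torch/distributed/_spmd/distribute.py",
    "torch/distributed/_spmd/iter_graph_module.py",
    "torch/distributed/_spmd/comm_tensor.py",
    "torch/distributed/_spmd/experimental_ops.py",
    "torch/distributed/_spmd/batch_dim_utils.py",
    "torch/distributed/_spmd/data_parallel.py",
    "torch/distributed/_spmd/graph_optimization.py",
]

for path in FILES:
    # pydocstyle exits non-zero while violations remain, so the return code is
    # ignored; with --count the total is printed as the last line of output.
    result = subprocess.run(
        ["pydocstyle", path, "--count"], capture_output=True, text=True
    )
    lines = result.stdout.strip().splitlines()
    print(f"{path}: {lines[-1] if lines else 0}")
```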

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113216
Approved by: https://github.com/ezyang
Authored by Adrian Wälchli on 2023-11-10 03:08:28 +00:00; committed by PyTorch MergeBot
parent 773b1cbe4f
commit 866457e746
10 changed files with 184 additions and 234 deletions

View File: torch/distributed/_spmd/api.py

@ -32,8 +32,8 @@ from torch.nn.utils._named_member_accessor import NamedMemberAccessor
class Override(ABC):
r"""
Override the tracing and transformation behavior of :meth:`~torch.distributed._spmd.compile`.
r"""Override the tracing and transformation behavior of :meth:`~torch.distributed._spmd.compile`.
This is useful when any part of the model is not traceable or if you prefer
to not trace it due to any reason. More specifically, users can implement
:meth:`torch.distributed._spmd.Override.replacement` to replace an original
@ -47,10 +47,10 @@ class Override(ABC):
@abstractmethod
def replacement(self, fqn: str, orig_submodule: torch.nn.Module) -> torch.nn.Module:
r"""
Implement this method to return a new :class:`nn.Module` instance to
replace the ``orig_submodule`` argument in the model. This helps if
``orig_submodule`` is not traceable or should not be traced.
r"""Implement this method to return a new :class:`nn.Module` instance to replace the ``orig_submodule``
argument in the model.
This helps if ``orig_submodule`` is not traceable or should not be traced.
Args:
fqn (str): fully quantified name of the submodule.
@ -58,6 +58,7 @@ class Override(ABC):
Returns:
A new :class:`nn.Module` instance to replace the original one.
"""
pass
@ -83,6 +84,7 @@ class Override(ABC):
Returns:
The :class:`fx.Graph` after transformation.
"""
pass
@ -98,8 +100,7 @@ class _PyTreeCodeGenOutputsOnly(_PyTreeCodeGen):
def _to_caller_flattened_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""Move the responsibility of flattening the input arguments from the
graph module to the caller.
"""Move the responsibility of flattening the input arguments from the graph module to the caller.
Example:
@ -108,6 +109,7 @@ def _to_caller_flattened_graph_module(gm: torch.fx.GraphModule) -> torch.fx.Grap
gm = gm(to_caller_flattened_graph_module)
output = gm(*pytree.flatten(my_struct)[0])
"""
# pyre-ignore[16]
gm._graph._codegen = _PyTreeCodeGenOutputsOnly(
@ -500,9 +502,9 @@ def compile(
gm_transformation: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
parallel_mode: Optional[ParallelMode] = None,
):
r"""
Compile and optimize a callable, which can be a train step within a training
loop. This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer`
r"""Compile and optimize a callable, which can be a train step within a training loop.
This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer`
instances from the input arguments and trace operations applied to their
parameters and states.
@ -519,6 +521,7 @@ def compile(
that specifies how to parallelize the callable. Each ParallelMode
would have its own strategy to partition the model and the captured
graph (Default: ``None``)
"""
def inner(func: Callable):
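
For readers skimming the hunks: the recurring change in this file and the files below is the docstring layout pydocstyle expects (the summary sits on the same line as the opening quotes, is phrased as one imperative sentence ending in a period, and is separated from any longer description by a blank line, roughly the D205/D212/D415 family of checks). A minimal illustration of the target layout, not taken from the diff:

```
def replacement_example():
    """Return a replacement module for a non-traceable submodule.

    Any longer explanation goes here, separated from the one-line
    summary above by a blank line, which is the shape the hunks in
    this commit rewrite the existing docstrings into.
    """
```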

View File: torch/distributed/_spmd/batch_dim_utils.py

@ -20,9 +20,9 @@ aten = torch.ops.aten
class BatchDimAnalyzer:
"""
This class is used to analyze the batch dimension of each tensor/node in the
graph. We need to know the batch dimension of each tensor/node so that we know
"""This class is used to analyze the batch dimension of each tensor/node in the graph.
We need to know the batch dimension of each tensor/node so that we know
exactly the sharding layout of intermediate tensors.
We possibly should evaluate using symbolic shapes to track the batch dimension.
@ -54,9 +54,7 @@ class BatchDimAnalyzer:
}
def init_batch_dim_size(self, batch_dim_size: int) -> None:
"""
initialize batch dim size base on the first input batch size
"""
"""Initialize batch dim size base on the first input batch size."""
if self.batch_dim_size != -1 and self.batch_dim_size != batch_dim_size:
raise RuntimeError(
f"batch dim size is already initialized! "
@ -74,9 +72,7 @@ class BatchDimAnalyzer:
return self.batch_dim_map[node]
def compute_batch_dim(self, node: fx.Node, full_reduction=False) -> int:
"""
compute the batch dimension for the `node`
"""
"""Compute the batch dimension for the `node`."""
assert self.batch_dim_size != -1, "batch dim size is not initialized!"
if node in self.batch_dim_map:
@ -168,10 +164,7 @@ class BatchDimAnalyzer:
return -2
def compute_act_spec(self, node: fx.Node, mesh: DeviceMesh) -> DTensorSpec:
"""
This function first compute the batch dimension for the current node,
then generate the sharding spec that shards on the batch dimension.
"""
"""Compute the batch dimension for the current node, then generate the sharding spec that shards on the batch dimension."""
node_batch_dim = self.compute_batch_dim(node)
if node_batch_dim == -1:
# indicate this activation is replicated

View File: torch/distributed/_spmd/comm_tensor.py

@ -56,9 +56,9 @@ def _get_tracer() -> Optional[torch.fx.Tracer]:
class CommTensor(torch.Tensor):
r"""
A Tensor subclass to wrap input tensors for collective communications. This
Tensor subclass works for both eager and tracing mode.
A Tensor subclass to wrap input tensors for collective communications.
This Tensor subclass works for both eager and tracing mode.
In eager mode, it will record whether the inplace collective communication
has been launched using this Tensor and remember the corresponding work
handle. If yes, it will explicitly call wait() in the ``__torch_dispatch__``

View File: torch/distributed/_spmd/data_parallel.py

@ -41,7 +41,8 @@ _spmd_lib_impl.impl("tag_grad", lambda x: x, "CompositeExplicitAutograd")
class DataParallelStyle(Enum):
"""
"""This enum represents the style of the data-parallel operation.
We have three types of Data Parallel style:
1. DEFAULT: the default data parallel style, which is to represent a mixed
replicate and fully shard behavior. For each parameter that is able
@ -64,8 +65,8 @@ class DataParallelStyle(Enum):
class NodeType(Enum):
"""
NodeType is a enum that records the type of the tensors in the graph.
"""NodeType is an enum that records the type of the tensors in the graph.
This is used to determine the data parallel strategy.
"""
@ -77,9 +78,8 @@ class NodeType(Enum):
class DataParallelStrategy(OpStrategy):
"""
DataParallelStrategy is a special case of OpStrategy that only records
the "data parallel style" placement strategy for each fx Node.
"""DataParallelStrategy is a special case of OpStrategy that only records the "data parallel style" placement
strategy for each fx Node.
It takes a list of PlacementStrategy, where each PlacementStrategy describes
one way to distribute the tensor and computation. In the DataParallel case,
@ -113,13 +113,10 @@ class DataParallelStrategy(OpStrategy):
@contextmanager
def gradients_tagging(params: Dict[str, torch.Tensor]):
"""
This is a helper function that tags the gradient of the parameters
with a special tag, so that we can identify them during SPMD expansion.
"""Tag the gradient of the parameters with a special tag, so that we can identify them during SPMD expansion.
It's safe to trace those hooks and we would remove those nodes later.
"""
tagging_hooks = []
try:
for p in params.values():
@ -135,9 +132,7 @@ def gradients_tagging(params: Dict[str, torch.Tensor]):
def _gen_shard_strategy(
mesh: DeviceMesh, shard_dim: int, input_specs: Optional[List[DTensorSpec]] = None
) -> PlacementStrategy:
"""
util function to generate a shard strategy on shard_dim
"""
"""Util function to generate a shard strategy on shard_dim."""
return PlacementStrategy(
output_spec=DTensorSpec(mesh=mesh, placements=(Shard(shard_dim),)),
input_specs=input_specs,
@ -147,9 +142,7 @@ def _gen_shard_strategy(
def _gen_replicate_strategy(
mesh: DeviceMesh, input_specs: Optional[List[DTensorSpec]] = None
) -> PlacementStrategy:
"""
util function to generate a replicate strategy
"""
"""Util function to generate a replicate strategy."""
return PlacementStrategy(
output_spec=DTensorSpec(mesh=mesh, placements=(Replicate(),)),
input_specs=input_specs,
@ -157,9 +150,7 @@ def _gen_replicate_strategy(
def _gen_partial_strategy(mesh: DeviceMesh) -> PlacementStrategy:
"""
util function to generate a partial strategy
"""
"""Util function to generate a partial strategy."""
# NOTE: we use AVG by default, avg reduction is needed depending on
# the loss function, for most loss function it should do
# gradient averaging. There might be certain cases it should
@ -180,10 +171,7 @@ def build_data_parallel_strategies(
mesh: DeviceMesh,
batch_dim: int = 0,
) -> Dict[fx.Node, StrategyType]:
"""
This function loop through the train step graph and build the
data parallel strategy for each fx Node
"""
"""Loop through the train step graph and build the data parallel strategy for each fx Node."""
activation_idx = num_params + num_states
non_compute_ops = [
aten.clone.default,
@ -518,9 +506,7 @@ def mark_data_parallel_shardings(
dp_strategy_map: Dict[fx.Node, StrategyType],
parallel_mode: DataParallelStyle = DataParallelStyle.FULLY_SHARD,
) -> None:
"""
This function marks the sharding for the nodes in the train_step_graph
"""
"""Mark the sharding for the nodes in the train_step_graph."""
activation_idx = num_parameters + num_states
placeholder_idx = 0
for node in train_step_graph.graph.nodes:
@ -601,9 +587,7 @@ def mark_data_parallel_shardings(
def _partition_val(val: Any, spec: DTensorSpec) -> Any:
"""
util function to convert a full tensor val to its local component
"""
"""Util function to convert a full tensor val to its local component."""
if isinstance(val, torch.Tensor):
local_shard = val
if val.ndim == 0:
@ -629,10 +613,7 @@ def _partition_val(val: Any, spec: DTensorSpec) -> Any:
def partitioner(graph: GraphModule) -> GraphModule:
"""
Graph partitioner that partitions the single device graph
to distributed graph
"""
"""Graph partitioner that partitions the single device graph to distributed graph."""
shape_adjustment_ops = {
aten._unsafe_view.default: 1,
aten.expand.default: 1,
@ -761,10 +742,9 @@ def partition_data_parallel(
parallel_style: DataParallelStyle,
input_batch_dim: int,
) -> GraphModule:
"""
The entry point function to partition the graph to data parallel
graph, it also shard/replicate the model parameters and optimizer
states to DTensors.
"""Partition the graph to into a data parallel graph.
This function also shards/replicates the model parameters and optimizer states to DTensors.
"""
num_params_buffers = len(params_buffers)
flattened_states = pytree.tree_leaves(named_states)
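
As background for the `gradients_tagging` docstring earlier in this file: the tagging relies on ordinary parameter gradient hooks (see `tagging_hooks` in the hunk above). A minimal, self-contained sketch of that underlying mechanism in plain PyTorch, not the _spmd helper itself:

```
import torch

p = torch.nn.Parameter(torch.randn(4))
# Tensor.register_hook fires with the gradient during backward();
# gradients_tagging installs hooks of this kind so the resulting graph
# nodes can be identified during SPMD expansion and removed afterwards.
handle = p.register_hook(lambda grad: print("saw grad of shape", grad.shape))
(p * 2.0).sum().backward()
handle.remove()
```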

View File: torch/distributed/_spmd/distribute.py

@ -48,9 +48,9 @@ class Schema:
@dataclass
class DSymInt:
"""
DSymInt represents a value retrieved by a SymInt op from a DTensor. DSymInt
helps View and Factory ops to determine the placement and shape of the
"""DSymInt represents a value retrieved by a SymInt op from a DTensor.
DSymInt helps View and Factory ops to determine the placement and shape of the
output tensor, as those operators either do not have an input DTensor or
the input DTensor is insufficient to determine the output tensor's placement.
"""
@ -90,7 +90,7 @@ class DSymInt:
def _is_partial_dtensor(obj: Any) -> bool:
"""check if object is 1) DTensor and 2) with any placement of _Partial"""
"""Check if object is 1) DTensor and 2) with any placement of _Partial."""
if not isinstance(obj, DTensor):
return False
@ -475,8 +475,8 @@ def _get_dtensor_dispatch_graph(
def _build_dummy_add_graph(
dt: DTensor, node_to_obj: Dict[fx.Node, Any]
) -> Tuple[fx.GraphModule, Any]:
"""
Creates a graph for a dummy add function from a partial DTensor.
"""Create a graph for a dummy add function from a partial DTensor.
This dummy add is used for triggering all_reduce on a Partial DTensor
during the DTensor expansion of the traced graph.
Also returns the actual DTensor after resharding.
@ -703,10 +703,12 @@ def _convert_to_distributed(
default_mesh: Optional[DeviceMesh] = None,
_allow_partial: bool = False,
) -> Tuple[fx.GraphModule, Dict[str, Schema]]:
"""
"""Transform a graph module to a distributed graph module.
Returns:
- transformed graph module
- map from output name to DTensorSpec
"""
global logger
logger = get_logger("spmd_exp")

View File: torch/distributed/_spmd/experimental_ops.py

@ -342,10 +342,7 @@ def _prop_native_layer_norm_backward(op_schema: OpSchema) -> OutputSharding:
def _refine_sharding(
op_schema: OpSchema, active_dim: Optional[int]
) -> Sequence[Placement]:
"""
Considers 2 first inputs of op_schema as having same shape,
and returns suggested placement for a pointwise operation.
"""
"""Considers 2 first inputs of op_schema as having same shape, and returns suggested placement for a pointwise operation."""
# consider the operating dimension as a singleton to prevent sharding on it
# however, if active_dim is None, this means the input and output shapes are equal and
# we'll apply exactly the pointwise rule.

View File: torch/distributed/_spmd/graph_optimization.py

@ -64,10 +64,9 @@ def graph_optimization_pass(
prerequisites: Iterable[Callable],
apply_after: Iterable[Callable],
) -> Callable:
"""
The contract of graph optimization pass. All the passes should be wrapped
with this decorator.
"""Define the contract of a graph optimization pass.
All the passes should be wrapped with this decorator.
`prerequisites` is used to annotate the prerequisite passes of the this pass.
`apply_after` means that this wrapped pass must be applied after the passes
in `apply_after`. The difference between `prerequisites` and `apply_after`
@ -158,12 +157,11 @@ class CommBlock:
def get_comm_block(comm_node: fx.Node) -> CommBlock:
"""
Given a collective node (e.g., allreduce), find out all the nodes belong to
this communcation.
"""Find out all the nodes belong to this communcation given a collective node (e.g., allreduce).
Args:
comm_node(fx.Node): The target communication/collective node.
Returns:
The CommBlock that encapsulates the related nodes (e.g., wait_node) of
the given comm_node.
@ -306,10 +304,7 @@ def _scatter_wait_result(
comm_blocks: List[CommBlock],
node_indices: Dict[fx.Node, int],
) -> None:
"""
Scatters the result of the fused communication node to the original users --
splitting the output and reshape each subitem.
"""
"""Scatter the result of the fused communication node to the original users -- splitting the output and reshape each subitem."""
last_wait_node_idx = 0
for node in gm.graph.nodes:
if node == fused_comm_block.comm_node:
@ -371,9 +366,7 @@ def _fuse_with_cat(
comm_blocks: List[CommBlock],
node_indices: Dict[fx.Node, int],
) -> CommBlock:
"""
Given a list of CommBlock (only allreduce), fuse the CommBlocks using concat.
"""
"""Fuse the CommBlocks using concat given a list of CommBlock (only allreduce)."""
# Find the last input node.
last_input_node = comm_blocks[0].inputs[0]
last_input_index = -1
@ -474,8 +467,8 @@ def comm_fusion_with_concat(
gm: IterGraphModule,
bucket_size_mb: int,
) -> None:
"""
Run fuse communication with concat.
"""Run fuse communication with concat.
This implementation uses concat to concat the bucketed gradients.
"""
comm_blocks = get_all_comm_blocks(gm, (CommType.ALLREDUCE, "all_reduce"))
@ -508,9 +501,7 @@ def comm_fusion_with_concat(
apply_after=[],
)
def schedule_comm_wait(gm: IterGraphModule) -> None:
"""
Delay the execution of wait tensors of allreduce until its first user.
"""
"""Delay the execution of wait tensors of allreduce until its first user."""
comm_blocks = get_all_comm_blocks(gm, (CommType.ALLREDUCE, "all_reduce"))
# Find all the end users.
@ -549,8 +540,8 @@ def schedule_comm_wait(gm: IterGraphModule) -> None:
apply_after=[],
)
def remove_copy_from_optimizer(gm: IterGraphModule) -> None:
"""
Erase the orphant copy_ that generated when tracing optimizer.
"""Erase the orphant copy_ that generated when tracing optimizer.
Two reasons why we could not simply use the DCE of fx.Graph.
1. fx.Graph treats copy_ as a side-effect node and does not erase it.
2. Users may want to preserve some orphan `copy_` that is not from the
@ -742,9 +733,7 @@ class FusedOptimizerBlock:
def get_fused_optimizer_block(optim_node: fx.Node) -> FusedOptimizerBlock:
"""
Given a fused optimizer node and return the FusedOptimizerBlock.
"""
"""Given a fused optimizer node and return the FusedOptimizerBlock."""
MAX_STEP_DISTANCE = 5
# Find the step (foreach_add)
nodes = collections.deque([optim_node, None])
@ -783,10 +772,7 @@ def get_fused_optimizer_block(optim_node: fx.Node) -> FusedOptimizerBlock:
def get_all_fused_optimizer_blocks(
gm: IterGraphModule, optim_ops: Union[Tuple[str, ...], str]
) -> List[FusedOptimizerBlock]:
"""
Find all the FusedOptimizerBlock that the optimizer operators are in
`optim_ops`.
"""
"""Find all the FusedOptimizerBlock that the optimizer operators are in `optim_ops`."""
return [
get_fused_optimizer_block(node)
for node in gm.graph.nodes
@ -799,9 +785,9 @@ def _split_fused_adam(
orig_optim_block: FusedOptimizerBlock,
split_gradients: Set[fx.Node],
) -> Tuple[FusedOptimizerBlock, FusedOptimizerBlock]:
"""
Split the `orig_optim_block` into two FusedOptimizerBlock. The first one
will be the optimizer that optimize `split_gradients`. The second one is
"""Split the `orig_optim_block` into two FusedOptimizerBlock.
The first one will be the optimizer that optimize `split_gradients`. The second one is
used to optimize the remaining gradients.
An assert will be raised if one of the optimizer optimize zero gradients.
"""
@ -949,7 +935,8 @@ def iter_move_grads_and_optimizers(
target_comm_node: str,
target_dest_node: str,
) -> None:
"""Function to extract a comm block and split out a new optimizer and step for it.
"""Extract a comm block and split out a new optimizer and step for it.
This subgraph is then moved to the forward graph.
"""
for comm_block in get_all_comm_blocks(gm, "all_reduce"):
@ -982,8 +969,7 @@ def find_all_descendants(
gm: IterGraphModule,
parent_nodes: List[fx.Node],
) -> List[fx.Node]:
"""identifying list of nodes to move during FX graph transformation"""
"""Identify the list of nodes to move during FX graph transformation."""
assert len(parent_nodes) > 0, "No parent nodes are given."
output = get_output(gm.graph)

View File: torch/distributed/_spmd/graph_utils.py

@ -41,9 +41,9 @@ def get_node_tensor_metadata(node: fx.Node, is_required: bool = True) -> TensorM
def get_output(graph: fx.Graph) -> fx.Node:
"""
Take a graphmodule and returns the graph output node. We traverse in reverse
to expedite it, with the idea that last node should be output
"""Take a graphmodule and return the graph output node.
We traverse in reverse to expedite it, with the idea that last node should be output
"""
for node in reversed(graph.nodes):
if node.op == OP.OUTPUT:
@ -54,10 +54,7 @@ def get_output(graph: fx.Graph) -> fx.Node:
def find_node(
graph: fx.Graph, predicate: Callable, reverse_order: bool = False
) -> List[fx.Node]:
"""
Take a predicate and return all the nodes in the `graph` where the predicate
holds.
"""
"""Take a predicate and return all the nodes in the `graph` where the predicate holds."""
nodes = cast(Iterable[fx.Node], graph.nodes)
if reverse_order:
nodes = cast(Iterable[fx.Node], iter(reversed(nodes))) # type: ignore[call-overload]
@ -65,8 +62,8 @@ def find_node(
def is_leaf_subgraph(graph: fx.Graph, subgraph: List[fx.Node]) -> bool:
"""
This function ensures nodes in ``subgraph`` satisfy one of the rules:
"""Ensure nodes in ``subgraph`` satisfy one of the following rules.
1. The user of the node is in ``subgraph``.
2. The user of the node is output.
3. There are no users -- the node is a side-effect node.
@ -85,11 +82,10 @@ def is_leaf_subgraph(graph: fx.Graph, subgraph: List[fx.Node]) -> bool:
def clone_subgraph(
graph: fx.Graph, subgraph: List[fx.Node], target: fx.Node
) -> List[fx.Node]:
"""
Clone the given subgraph and insert it before ``target``.
"""Clone the given subgraph and insert it before ``target``.
This API currently does not support inserting after ``target``.
"""
all_nodes = set(subgraph)
mapping: Dict[fx.Node, fx.Node] = dict()
cloned_subgraph = []
@ -125,12 +121,11 @@ def clone_subgraph(
def rebuild_graph(gm: fx.GraphModule, remove_dead_code: bool = True) -> None:
"""
Runs the required steps to ensure production-ready graph.
note - per the fx docs, eliminate dead code is not very precise.
"""Run the required steps to ensure production-ready graph.
Note - per the fx docs, elimination of dead code is not very precise.
Hence, the flag to make this step optional.
"""
gm.graph.lint()
if remove_dead_code:
gm.graph.eliminate_dead_code()
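
The reverse traversal in `get_output` above works because an `fx.Graph` keeps its single output node at the end of its node list. A small stand-alone illustration using only public `torch.fx` APIs:

```
import torch
import torch.fx as fx

class Tiny(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1

gm = fx.symbolic_trace(Tiny())
# The output node is the last node in the graph, so iterating in reverse
# finds it immediately -- the same idea get_output relies on.
output_node = next(n for n in reversed(gm.graph.nodes) if n.op == "output")
print(output_node.op, output_node.args)
```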

View File: torch/distributed/_spmd/iter_graph_module.py

@ -22,9 +22,9 @@ logger: logging.Logger = logging.getLogger("IterGraphModule")
class IterGraph(fx.Graph):
"""
``IterGraph`` is used to perform cross-iteration optimization. ``IterGraph``
keeps track of the 3 graphs, self (the original graph), setup graph, and
"""``IterGraph`` is used to perform cross-iteration optimization.
``IterGraph`` keeps track of the 3 graphs, self (the original graph), setup graph, and
cleanup graph. The 3 graphs should be identical copies of a ``fx.Graph``.
IterGraph subclass fx.Graph to override the necessary APIs that will be used
@ -127,11 +127,10 @@ class IterGraph(fx.Graph):
def _forward_subgraph_inputs(
self, subgraph: List[fx.Node], graph: fx.Graph, erase_node: bool
) -> int:
"""
This function turns the inputs of a subgraph into the extra output
of the entire graph. If ``erase_node`` is True, the subgraph will be
erased from the graph -- essentially forward the inputs of the subgraph
to the output of the graph.
"""Turn the inputs of a subgraph into the extra output of the entire graph.
If ``erase_node`` is True, the subgraph will be erased from the graph -- essentially forward the inputs
of the subgraph to the output of the graph.
"""
output = get_output(graph)
inputs = []
@ -219,10 +218,10 @@ class IterGraph(fx.Graph):
def _forward_inputs_to_subgraph(
self, subgraph: List[fx.Node], graph: fx.Graph, extra_input: int
) -> None:
"""
This function creates extra input nodes and forward the input nodes to
the ``subgraph``. The external input nodes of ``subgraph`` (nodes that
are not in ``subgraph``) will replaced by the newly created input nodes.
"""Create extra input nodes and forward the input nodes to the ``subgraph``.
The external input nodes of ``subgraph`` (nodes that are not in ``subgraph``) will replaced by the newly
created input nodes.
"""
placeholders = [node for node in graph.nodes if str(node.op) == "placeholder"]
assert placeholders, "No placeholders are found"
@ -275,8 +274,8 @@ class IterGraph(fx.Graph):
def move_to_next_iter_before(
self, subgraph: List[fx.Node], target_node: fx.Node
) -> None:
"""
Move the ``subgraph`` to the next iteration before ``target_node``.
"""Move the ``subgraph`` to the next iteration before ``target_node``.
The ``subgraph`` is a list of fx.Node and must satisfy the following
restrictions:
1. The order of the nodes in ``subgraph`` must obey the topological
@ -633,8 +632,8 @@ class IterGraph(fx.Graph):
class IterGraphModule(nn.Module):
"""
``IterGraphModule`` provides the ability to do cross-iteration optimization.
"""``IterGraphModule`` provides the ability to do cross-iteration optimization.
Given a ``fx.GraphModule``, main_gm, ``IterGraphModule`` internally
duplicate it to 3 copies and redirect the ``forward`` request to a different
``fx.GraphModule`` based on the iteration count. This allows users to do
@ -674,10 +673,9 @@ class IterGraphModule(nn.Module):
self._enable_inductor = enable_inductor
def finalize_setup(self) -> None:
"""
Must be called before the forward() is called. This method setups
the internal states and also get the signal from users that what
is the maximum iteration count.
"""Set up the internal states and also get the signal from users that what is the maximum iteration count.
This method must be called before the forward() is called.
"""
if not self._is_frozen:
self.graph.freeze_cross_iter_movement()

View File: torch/distributed/fsdp/fully_sharded_data_parallel.py

@ -112,14 +112,16 @@ FLAT_PARAM = "_flat_param"
class OptimStateKeyType(Enum):
"""Represents the type of key in an optimizer state-dict."""
PARAM_NAME = auto()
PARAM_ID = auto()
class FullyShardedDataParallel(nn.Module, _FSDPState):
"""
A wrapper for sharding module parameters across data parallel workers. This
is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
"""A wrapper for sharding module parameters across data parallel workers.
This is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
FullyShardedDataParallel is commonly shortened to FSDP.
.. _`Xu et al.`: https://arxiv.org/abs/2004.13336
@ -515,9 +517,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
@property
def module(self) -> nn.Module:
"""
Returns the wrapped module (like :class:`DistributedDataParallel`).
"""
"""Return the wrapped module."""
# FSDP's `.module` must refer to the innermost wrapped module when
# composing with other module wrappers in order for state dict to work
if isinstance(self._fsdp_wrapped_module, ActivationWrapper):
@ -547,6 +547,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
return super().__getitem__(key)
def check_is_root(self) -> bool:
"""Check if this instance is a root FSDP module."""
return _is_fsdp_root(self, self)
@staticmethod
@ -554,9 +555,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
module: nn.Module,
root_only: bool = False,
) -> List["FullyShardedDataParallel"]:
"""
Returns all nested FSDP instances, possibly including ``module`` itself
and only including FSDP root modules if ``root_only=True``.
"""Return all nested FSDP instances.
This possibly includes ``module`` itself and only includes FSDP root modules if ``root_only=True``.
Args:
module (torch.nn.Module): Root module, which may or may not be an
@ -573,9 +574,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
return traversal_utils._get_fsdp_states(module)
def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
as well as self. Typical use includes initializing the parameters of a model
(see also :ref:`nn-init-doc`).
r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.
Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`).
Compared to ``torch.nn.Module.apply``, this version additionally gathers
the full parameters before applying ``fn``. It should not be called from
@ -614,8 +615,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
return ret
def _mixed_precision_enabled_for_buffers(self) -> bool:
"""
Returns if the user explicitly enabled buffer mixed precision.
"""Return whether the user explicitly enabled buffer mixed precision.
NOTE: Unlike parameters and gradient reduction, buffer mixed precision
is applied at the FSDP instance level, not the ``FlatParameter`` level,
@ -624,15 +624,11 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
return self.mixed_precision.buffer_dtype is not None
def _low_precision_hook_enabled(self) -> bool:
"""
Whether a low precision hook is registered or not.
"""
"""Whether a low precision hook is registered or not."""
return self._comm_hook is not None and self._comm_hook in LOW_PRECISION_HOOKS
def _reset_lazy_init(self) -> None:
"""
Reset instance so :func:`_lazy_init` will run on the next forward.
"""
"""Reset instance so :func:`_lazy_init` will run on the next forward."""
self._is_root: Optional[bool] = None
@staticmethod
@ -642,9 +638,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
state_dict_config: Optional[StateDictConfig] = None,
optim_state_dict_config: Optional[OptimStateDictConfig] = None,
) -> StateDictSettings:
"""
Set the ``state_dict_type`` and the corresponding (optional)
configurations of all the descendant FSDP modules of the target module.
"""Set the ``state_dict_type`` of all the descendant FSDP modules of the target module.
Also takes (optional) configuration for the model's and optimizer's state dict.
The target module does not have to be a FSDP module. If the target
module is a FSDP module, its ``state_dict_type`` will also be changed.
@ -747,10 +743,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
@staticmethod
def get_state_dict_type(module: nn.Module) -> StateDictSettings:
"""
Get the state_dict_type and the corresponding configurations
for the FSDP modules rooted at ``module``. The target module
does not have to be an FSDP module.
"""Get the state_dict_type and the corresponding configurations for the FSDP modules rooted at ``module``.
The target module does not have to be an FSDP module.
Returns:
A ``StateDictSettings`` containing the state_dict_type and
@ -797,10 +792,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
state_dict_config: Optional[StateDictConfig] = None,
optim_state_dict_config: Optional[OptimStateDictConfig] = None,
) -> Generator:
"""
A context manager to set the ``state_dict_type`` of all the descendant
FSDP modules of the target module. This context manager has the same
functions as :meth:`set_state_dict_type`. Read the document of
"""Set the ``state_dict_type`` of all the descendant FSDP modules of the target module.
This context manager has the same functions as :meth:`set_state_dict_type`. Read the document of
:meth:`set_state_dict_type` for the detail.
Example::
@ -836,10 +830,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
)
def forward(self, *args: Any, **kwargs: Any) -> Any:
"""
Runs the forward pass for the wrapped module, inserting FSDP-specific
pre- and post-forward sharding logic.
"""
"""Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic."""
handle = self._handle
with torch.autograd.profiler.record_function(
"FullyShardedDataParallel.forward"
@ -875,7 +866,8 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
offload_to_cpu: bool = False,
with_grads: bool = False,
) -> Generator:
r"""A context manager to expose full params for FSDP instances.
r"""Expose full params for FSDP instances with this context manager.
Can be useful *after* forward/backward for a model to get
the params for additional processing or checking. It can take a non-FSDP
module and will summon full params for all contained FSDP modules as
@ -941,9 +933,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
@contextlib.contextmanager
def _deregister_orig_params_ctx(self):
"""
This deregisters the original parameters and exposes the
:class:`FlatParameter` s. If a :class:`FlatParameter` is sharded, then
"""Deregister the original parameters and expose the :class:`FlatParameter`.
If a :class:`FlatParameter` is sharded, then
this refreshes the sharded views before exiting. This method should
only be called when using the original parameters.
"""
@ -961,11 +953,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
_register_orig_params(fsdp_module, fsdp_module)
def _apply(self, *args, **kwargs):
"""
When using the original parameters, this deregisters the original
parameters and exposes the :class:`FlatParameter` s before calling
``_apply()``.
"""
"""Deregister the original parameters and expose the :class:`FlatParameter` s before calling ``_apply()``."""
# When using the original parameters: Since (1) the `FlatParameter`s
# own the storage and (2) `_apply()` is the subroutine underlying the
# most common storage-changing ops like `to()` and `cuda()`, we
@ -985,9 +973,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
*args,
**kwargs,
) -> Iterator[Tuple[str, torch.Tensor]]:
"""
Overrides :meth:`named_buffers()` to intercept buffer names and
remove all occurrences of the FSDP-specific flattened buffer prefix
"""Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself.
Intercepts buffer names and removes all occurrences of the FSDP-specific flattened buffer prefix
when inside the :meth:`summon_full_params` context manager.
"""
should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
@ -1003,9 +991,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
*args,
**kwargs,
) -> Iterator[Tuple[str, torch.nn.Parameter]]:
"""
Overrides :meth:`named_parameters()` to intercept parameter names and
remove all occurrences of the FSDP-specific flattened parameter prefix
"""Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself.
Intercepts parameter names and removes all occurrences of the FSDP-specific flattened parameter prefix
when inside the :meth:`summon_full_params` context manager.
"""
should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
@ -1038,9 +1026,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
@contextmanager
def no_sync(self) -> Generator:
"""
A context manager to disable gradient synchronizations across FSDP
instances. Within this context, gradients will be accumulated in module
"""Disable gradient synchronizations across FSDP instances.
Within this context, gradients will be accumulated in module
variables, which will later be synchronized in the first
forward-backward pass after exiting the context. This should only be
used on the root FSDP instance and will recursively apply to all
@ -1079,9 +1067,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
def clip_grad_norm_(
self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0
) -> torch.Tensor:
"""
Clips the gradient norm of all parameters. The norm is computed over
all parameters' gradients as viewed as a single vector, and the
"""Clip the gradient norm of all parameters.
The norm is computed over all parameters' gradients as viewed as a single vector, and the
gradients are modified in-place.
Args:
@ -1245,8 +1233,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
group: Optional[dist.ProcessGroup] = None,
cpu_offload: bool = True,
) -> Dict[str, Any]:
"""
The internal API that is used by all the optim_state_dict implementations.
"""Transform the state-dict of an optimizer corresponding to a sharded model.
This is the internal API that is used by all the optim_state_dict implementations.
Given model, optim, the original optim_state_dict, this API removes the
FSDP internal information and internal sharding from the optim_state_dict.
"""
@ -1298,7 +1287,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""
The internal API that is used by all the load optim_state_dict implementations.
Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.
This is the internal API that is used by all the load optim_state_dict implementations.
Given model, optim, and the saved optim_state_dict, this API adds the FSDP
internal information and internal sharding to the optim_state_dict.
"""
@ -1352,7 +1343,8 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
rank0_only: bool = True,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""
"""Return the full optimizer state-dict.
Consolidates the full optimizer state on rank 0 and returns it
as a :class:`dict` following the convention of
:meth:`torch.optim.Optimizer.state_dict`, i.e. with keys ``"state"``
@ -1417,7 +1409,8 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
optim: torch.optim.Optimizer,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""
"""Return the optimizer state-dict in its sharded form.
The API is similar to :meth:`full_optim_state_dict` but this API chunks
all non-zero-dimension states to :class:`ShardedTensor` to save memory.
This API should only be used when the model ``state_dict`` is derived
@ -1453,12 +1446,11 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
] = None,
optim: Optional[torch.optim.Optimizer] = None,
) -> Dict[str, Any]:
"""
Shards the full optimizer state dict ``full_optim_state_dict`` by
remapping the state to flattened parameters instead of unflattened
parameters and restricting to only this rank's part of the optimizer
state. The first argument should be the return value of
:meth:`full_optim_state_dict`.
"""Shard a full optimizer state-dict.
Remaps the state in ``full_optim_state_dict`` to flattened parameters instead of unflattened
parameters and restricts to only this rank's part of the optimizer state.
The first argument should be the return value of :meth:`full_optim_state_dict`.
Example::
@ -1525,7 +1517,8 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
model: torch.nn.Module,
optim: torch.optim.Optimizer,
) -> Dict[str, Any]:
"""
"""Flatten a sharded optimizer state-dict.
The API is similar to :meth:`shard_full_optim_state_dict`. The only
difference is that the input ``sharded_optim_state_dict`` should be
returned from :meth:`sharded_optim_state_dict`. Therefore, there will
@ -1568,10 +1561,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
optim: Optional[torch.optim.Optimizer] = None,
group: Optional[Any] = None,
) -> Dict[str, Any]:
"""
Scatters the full optimizer state dict from rank 0 to all other ranks,
returning the sharded optimizer state dict on each rank. The return
value is the same as :meth:`shard_full_optim_state_dict`, and on rank
"""Scatter the full optimizer state dict from rank 0 to all other ranks.
Returns the sharded optimizer state dict on each rank.
The return value is the same as :meth:`shard_full_optim_state_dict`, and on rank
0, the first argument should be the return value of
:meth:`full_optim_state_dict`.
@ -1650,10 +1643,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
] = None,
optim: Optional[torch.optim.Optimizer] = None,
) -> Dict[str, Any]:
"""
Re-keys the optimizer state dict ``optim_state_dict`` to use the key
type ``optim_state_key_type``. This can be used to achieve
compatibility between optimizer state dicts from models with FSDP
"""Re-keys the optimizer state dict ``optim_state_dict`` to use the key type ``optim_state_key_type``.
This can be used to achieve compatibility between optimizer state dicts from models with FSDP
instances and ones without.
To re-key an FSDP full optimizer state dict (i.e. from
@ -1771,9 +1763,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""
Transforms the state_dict of ``optim`` for the ``model`` that is sharded
by FSDP to one of the three types: 1) full optimizer state_dict, 2)
sharded optimizer state_dict, 3) local optimizer state_dict.
Transform the state-dict of an optimizer corresponding to a sharded model.
The given state-dict can be transformed to one of three types:
1) full optimizer state_dict, 2) sharded optimizer state_dict, 3) local optimizer state_dict.
For full optimizer state_dict, all states are unflattened and not sharded.
Rank0 only and CPU only can be specified via :meth:`state_dict_type` to
@ -1867,8 +1860,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""
Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.
Given a ``optim_state_dict`` that is transformed through
:meth:`optim_state_dict`, converts it to the flattened optimizer
:meth:`optim_state_dict`, it gets converted to the flattened optimizer
state_dict that can be loaded to ``optim`` which is the optimizer for
``model``. ``model`` must be sharded by FullyShardedDataParallel.
@ -1946,10 +1941,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
return result
def register_comm_hook(self, state: object, hook: callable):
"""
Registers a communication hook which is an enhancement that provides a
flexible hook to users where they can specify how FSDP aggregates gradients
across multiple workers.
"""Register a communication hook.
This is an enhancement that provides a flexible hook to users where they can specify how FSDP aggregates
gradients across multiple workers.
This hook can be used to implement several algorithms like
`GossipGrad <https://arxiv.org/abs/1803.05880>`_ and gradient compression
which involve different communication strategies for
@ -2008,9 +2003,9 @@ def _get_grad_norm(
norm_type: float,
) -> torch.Tensor:
"""
Returns the gradient norm of parameters ``param`` s, where the gradients
are viewed as a single vector. The returned norm is in FP32 even if
parameters/gradients are in a low precision. This is because the downstream
Return the gradient norm of parameters ``param`` s, where the gradients are viewed as a single vector.
The returned norm is in FP32 even if parameters/gradients are in a low precision. This is because the downstream
use of this return value is a reduction across ranks.
"""
params_with_grad = [param for param in params if param.grad is not None]
@ -2041,8 +2036,9 @@ def _get_param_to_fqn(
model: torch.nn.Module,
) -> Dict[torch.nn.Parameter, str]:
"""
Constructs a mapping from parameters to their parameter names. ``model``
should not contain any :class:`FullyShardedDataParallel` instances, which
Construct a mapping from parameters to their parameter names.
The ``model`` should not contain any :class:`FullyShardedDataParallel` instances, which
means that none of the parameters should be ``FlatParameter`` s. As a
result, compared to :meth:`_get_param_to_fqns`, the mapped
values may be flattened from singleton :class:`list` s to the contained
@ -2071,6 +2067,6 @@ def _get_param_to_fqn(
def _get_fqn_to_param(
model: torch.nn.Module,
) -> Dict[str, torch.nn.Parameter]:
"""Constructs the inverse mapping of :meth:`_get_param_to_fqn`."""
"""Construct the inverse mapping of :meth:`_get_param_to_fqn`."""
param_to_param_name = _get_param_to_fqn(model)
return dict(zip(param_to_param_name.values(), param_to_param_name.keys()))
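
As a usage note for the `no_sync` docstring rewritten above, the gradient-accumulation pattern it describes looks roughly like this (a sketch only: it assumes an already-initialized process group, and `MyModel`, `batch1`, `batch2` are placeholder names, not part of this commit):

```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

fsdp_model = FSDP(MyModel().cuda())      # MyModel is a placeholder module
with fsdp_model.no_sync():               # gradients only accumulate locally here
    fsdp_model(batch1).sum().backward()
fsdp_model(batch2).sum().backward()      # first backward after the context syncs grads
```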