Fix pydocstyle errors in fully_sharded_data_parallel.py, api.py, graph_utils.py, distribute.py, iter_graph_module.py, comm_tensor.py, experimental_ops.py, batch_dim_utils.py, data_parallel.py, graph_optimization.py (#113216)

Fixes #113191

```
pydocstyle torch/distributed/fsdp/fully_sharded_data_parallel.py --count
```

On master: 80
After my changes on this PR: 3
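
Most of the changes below follow the same pattern: move the docstring summary onto the first line as a single imperative sentence ending with a period, and separate it from the rest of the docstring with a blank line. A minimal before/after sketch of that pattern (illustrative only; it is modeled on the `get_output` change in graph_utils.py shown in the diff below and is not itself part of this PR):

```
# Illustrative only -- modeled on the get_output() change in graph_utils.py below.

# Before: summary not on the first line, no blank line after it, wording not in
# the imperative mood, and no trailing period.
def get_output_before(graph):
    """
    Takes a graphmodule and returns the graph output node. We traverse in reverse
    to expedite it, with the idea that last node should be output
    """
    for node in reversed(list(graph.nodes)):
        if node.op == "output":
            return node
    raise RuntimeError("no output node found")


# After: one-line imperative summary on the first line, a blank line, then details.
def get_output_after(graph):
    """Take a graphmodule and return the graph output node.

    We traverse in reverse to expedite it, with the idea that last node should be output.
    """
    for node in reversed(list(graph.nodes)):
        if node.op == "output":
            return node
    raise RuntimeError("no output node found")
```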

```
pydocstyle torch/distributed/_spmd/comm_tensor.py --count
```
On master: 5
After my changes on this PR: 3

```
pydocstyle torch/distributed/_spmd/experimental_ops.py --count
```
On master: 3
After my changes on this PR: 1

```
pydocstyle torch/distributed/_spmd/iter_graph_module.py --count
```
On master: 39
After my changes on this PR: 27

```
pydocstyle torch/distributed/_spmd/graph_utils.py --count
```
On master: 16
After my changes on this PR: 4

```
pydocstyle torch/distributed/_spmd/distribute.py --count
```
On master: 19
After my changes on this PR: 10

```
pydocstyle torch/distributed/_spmd/api.py --count
```
On master: 10
After my changes on this PR: 3

```
pydocstyle torch/distributed/_spmd/batch_dim_utils.py  --count
```
On master: 14
After my changes on this PR: 3

```
pydocstyle torch/distributed/_spmd/data_parallel.py --count
```
On master: 34
After my changes on this PR: 2

```
pydocstyle torch/distributed/_spmd/graph_optimization.py --count
```
On master: 35
After my changes on this PR: 13
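
To reproduce the per-file counts above in one go, a small helper along these lines can be used (a sketch, assuming pydocstyle is installed and the script is run from the root of a pytorch checkout; it is not part of this PR):

```
import subprocess

# Files touched by this PR (taken from the title above).
FILES = [
    "torch/distributed/fsdp/fully_sharded_data_parallel.py",
    "torch/distributed/_spmd/api.py",
    "torch/distributed/_spmd/graph_utils.py",
    "torch/distributed/_spmd/distribute.py",
    "torch/distributed/_spmd/iter_graph_module.py",
    "torch/distributed/_spmd/comm_tensor.py",
    "torch/distributed/_spmd/experimental_ops.py",
    "torch/distributed/_spmd/batch_dim_utils.py",
    "torch/distributed/_spmd/data_parallel.py",
    "torch/distributed/_spmd/graph_optimization.py",
]

for path in FILES:
    # With --count, pydocstyle appends the total number of violations to its output.
    # The exit code is non-zero whenever violations remain, so no check=True here.
    result = subprocess.run(
        ["pydocstyle", path, "--count"], capture_output=True, text=True
    )
    lines = result.stdout.strip().splitlines()
    print(f"{path}: {lines[-1] if lines else 0}")
```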

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113216
Approved by: https://github.com/ezyang
Author: Adrian Wälchli
Date: 2023-11-10 03:08:28 +00:00
Committed by: PyTorch MergeBot
Parent: 773b1cbe4f
Commit: 866457e746
10 changed files with 184 additions and 234 deletions

File: torch/distributed/_spmd/api.py

```
@@ -32,8 +32,8 @@ from torch.nn.utils._named_member_accessor import NamedMemberAccessor
 class Override(ABC):
-r"""
-Override the tracing and transformation behavior of :meth:`~torch.distributed._spmd.compile`.
+r"""Override the tracing and transformation behavior of :meth:`~torch.distributed._spmd.compile`.
 This is useful when any part of the model is not traceable or if you prefer
 to not trace it due to any reason. More specifically, users can implement
 :meth:`torch.distributed._spmd.Override.replacement` to replace an original
@@ -47,10 +47,10 @@ class Override(ABC):
 @abstractmethod
 def replacement(self, fqn: str, orig_submodule: torch.nn.Module) -> torch.nn.Module:
-r"""
-Implement this method to return a new :class:`nn.Module` instance to
-replace the ``orig_submodule`` argument in the model. This helps if
-``orig_submodule`` is not traceable or should not be traced.
+r"""Implement this method to return a new :class:`nn.Module` instance to replace the ``orig_submodule``
+argument in the model.
+This helps if ``orig_submodule`` is not traceable or should not be traced.
 Args:
 fqn (str): fully quantified name of the submodule.
@@ -58,6 +58,7 @@ class Override(ABC):
 Returns:
 A new :class:`nn.Module` instance to replace the original one.
 """
 pass
@@ -83,6 +84,7 @@ class Override(ABC):
 Returns:
 The :class:`fx.Graph` after transformation.
 """
 pass
@@ -98,8 +100,7 @@ class _PyTreeCodeGenOutputsOnly(_PyTreeCodeGen):
 def _to_caller_flattened_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
-"""Move the responsibility of flattening the input arguments from the
-graph module to the caller.
+"""Move the responsibility of flattening the input arguments from the graph module to the caller.
 Example:
@@ -108,6 +109,7 @@ def _to_caller_flattened_graph_module(gm: torch.fx.GraphModule) -> torch.fx.Grap
 gm = gm(to_caller_flattened_graph_module)
 output = gm(*pytree.flatten(my_struct)[0])
 """
 # pyre-ignore[16]
 gm._graph._codegen = _PyTreeCodeGenOutputsOnly(
@@ -500,9 +502,9 @@ def compile(
 gm_transformation: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
 parallel_mode: Optional[ParallelMode] = None,
 ):
-r"""
-Compile and optimize a callable, which can be a train step within a training
-loop. This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer`
+r"""Compile and optimize a callable, which can be a train step within a training loop.
+This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer`
 instances from the input arguments and trace operations applied to their
 parameters and states.
@@ -519,6 +521,7 @@ def compile(
 that specifies how to parallelize the callable. Each ParallelMode
 would have its own strategy to partition the model and the captured
 graph (Default: ``None``)
 """
 def inner(func: Callable):
```

File: torch/distributed/_spmd/batch_dim_utils.py

```
@@ -20,9 +20,9 @@ aten = torch.ops.aten
 class BatchDimAnalyzer:
-"""
-This class is used to analyze the batch dimension of each tensor/node in the
-graph. We need to know the batch dimension of each tensor/node so that we know
+"""This class is used to analyze the batch dimension of each tensor/node in the graph.
+We need to know the batch dimension of each tensor/node so that we know
 exactly the sharding layout of intermediate tensors.
 We possibly should evaluate using symbolic shapes to track the batch dimension.
@@ -54,9 +54,7 @@ class BatchDimAnalyzer:
 }
 def init_batch_dim_size(self, batch_dim_size: int) -> None:
-"""
-initialize batch dim size base on the first input batch size
-"""
+"""Initialize batch dim size base on the first input batch size."""
 if self.batch_dim_size != -1 and self.batch_dim_size != batch_dim_size:
 raise RuntimeError(
 f"batch dim size is already initialized! "
@@ -74,9 +72,7 @@ class BatchDimAnalyzer:
 return self.batch_dim_map[node]
 def compute_batch_dim(self, node: fx.Node, full_reduction=False) -> int:
-"""
-compute the batch dimension for the `node`
-"""
+"""Compute the batch dimension for the `node`."""
 assert self.batch_dim_size != -1, "batch dim size is not initialized!"
 if node in self.batch_dim_map:
@@ -168,10 +164,7 @@ class BatchDimAnalyzer:
 return -2
 def compute_act_spec(self, node: fx.Node, mesh: DeviceMesh) -> DTensorSpec:
-"""
-This function first compute the batch dimension for the current node,
-then generate the sharding spec that shards on the batch dimension.
-"""
+"""Compute the batch dimension for the current node, then generate the sharding spec that shards on the batch dimension."""
 node_batch_dim = self.compute_batch_dim(node)
 if node_batch_dim == -1:
 # indicate this activation is replicated
```

File: torch/distributed/_spmd/comm_tensor.py

```
@@ -56,9 +56,9 @@ def _get_tracer() -> Optional[torch.fx.Tracer]:
 class CommTensor(torch.Tensor):
 r"""
-A Tensor subclass to wrap input tensors for collective communications. This
-Tensor subclass works for both eager and tracing mode.
+A Tensor subclass to wrap input tensors for collective communications.
+This Tensor subclass works for both eager and tracing mode.
 In eager mode, it will record whether the inplace collective communication
 has been launched using this Tensor and remember the corresponding work
 handle. If yes, it will explicitly call wait() in the ``__torch_dispatch__``
```

File: torch/distributed/_spmd/data_parallel.py

```
@@ -41,7 +41,8 @@ _spmd_lib_impl.impl("tag_grad", lambda x: x, "CompositeExplicitAutograd")
 class DataParallelStyle(Enum):
-"""
+"""This enum represents the style of the data-parallel operation.
 We have three types of Data Parallel style:
 1. DEFAULT: the default data parallel style, which is to represent a mixed
 replicate and fully shard behavior. For each parameter that is able
@@ -64,8 +65,8 @@ class DataParallelStyle(Enum):
 class NodeType(Enum):
-"""
-NodeType is a enum that records the type of the tensors in the graph.
+"""NodeType is an enum that records the type of the tensors in the graph.
 This is used to determine the data parallel strategy.
 """
@@ -77,9 +78,8 @@ class NodeType(Enum):
 class DataParallelStrategy(OpStrategy):
-"""
-DataParallelStrategy is a special case of OpStrategy that only records
-the "data parallel style" placement strategy for each fx Node.
+"""DataParallelStrategy is a special case of OpStrategy that only records the "data parallel style" placement
+strategy for each fx Node.
 It takes a list of PlacementStrategy, where each PlacementStrategy describes
 one way to distribute the tensor and computation. In the DataParallel case,
@@ -113,13 +113,10 @@ class DataParallelStrategy(OpStrategy):
 @contextmanager
 def gradients_tagging(params: Dict[str, torch.Tensor]):
-"""
-This is a helper function that tags the gradient of the parameters
-with a special tag, so that we can identify them during SPMD expansion.
+"""Tag the gradient of the parameters with a special tag, so that we can identify them during SPMD expansion.
 It's safe to trace those hooks and we would remove those nodes later.
 """
 tagging_hooks = []
 try:
 for p in params.values():
@@ -135,9 +132,7 @@ def gradients_tagging(params: Dict[str, torch.Tensor]):
 def _gen_shard_strategy(
 mesh: DeviceMesh, shard_dim: int, input_specs: Optional[List[DTensorSpec]] = None
 ) -> PlacementStrategy:
-"""
-util function to generate a shard strategy on shard_dim
-"""
+"""Util function to generate a shard strategy on shard_dim."""
 return PlacementStrategy(
 output_spec=DTensorSpec(mesh=mesh, placements=(Shard(shard_dim),)),
 input_specs=input_specs,
@@ -147,9 +142,7 @@ def _gen_shard_strategy(
 def _gen_replicate_strategy(
 mesh: DeviceMesh, input_specs: Optional[List[DTensorSpec]] = None
 ) -> PlacementStrategy:
-"""
-util function to generate a replicate strategy
-"""
+"""Util function to generate a replicate strategy."""
 return PlacementStrategy(
 output_spec=DTensorSpec(mesh=mesh, placements=(Replicate(),)),
 input_specs=input_specs,
@@ -157,9 +150,7 @@ def _gen_replicate_strategy(
 def _gen_partial_strategy(mesh: DeviceMesh) -> PlacementStrategy:
-"""
-util function to generate a partial strategy
-"""
+"""Util function to generate a partial strategy."""
 # NOTE: we use AVG by default, avg reduction is needed depending on
 # the loss function, for most loss function it should do
 # gradient averaging. There might be certain cases it should
@@ -180,10 +171,7 @@ def build_data_parallel_strategies(
 mesh: DeviceMesh,
 batch_dim: int = 0,
 ) -> Dict[fx.Node, StrategyType]:
-"""
-This function loop through the train step graph and build the
-data parallel strategy for each fx Node
-"""
+"""Loop through the train step graph and build the data parallel strategy for each fx Node."""
 activation_idx = num_params + num_states
 non_compute_ops = [
 aten.clone.default,
@@ -518,9 +506,7 @@ def mark_data_parallel_shardings(
 dp_strategy_map: Dict[fx.Node, StrategyType],
 parallel_mode: DataParallelStyle = DataParallelStyle.FULLY_SHARD,
 ) -> None:
-"""
-This function marks the sharding for the nodes in the train_step_graph
-"""
+"""Mark the sharding for the nodes in the train_step_graph."""
 activation_idx = num_parameters + num_states
 placeholder_idx = 0
 for node in train_step_graph.graph.nodes:
@@ -601,9 +587,7 @@ def mark_data_parallel_shardings(
 def _partition_val(val: Any, spec: DTensorSpec) -> Any:
-"""
-util function to convert a full tensor val to its local component
-"""
+"""Util function to convert a full tensor val to its local component."""
 if isinstance(val, torch.Tensor):
 local_shard = val
 if val.ndim == 0:
@@ -629,10 +613,7 @@ def _partition_val(val: Any, spec: DTensorSpec) -> Any:
 def partitioner(graph: GraphModule) -> GraphModule:
-"""
-Graph partitioner that partitions the single device graph
-to distributed graph
-"""
+"""Graph partitioner that partitions the single device graph to distributed graph."""
 shape_adjustment_ops = {
 aten._unsafe_view.default: 1,
 aten.expand.default: 1,
@@ -761,10 +742,9 @@ def partition_data_parallel(
 parallel_style: DataParallelStyle,
 input_batch_dim: int,
 ) -> GraphModule:
-"""
-The entry point function to partition the graph to data parallel
-graph, it also shard/replicate the model parameters and optimizer
-states to DTensors.
+"""Partition the graph to into a data parallel graph.
+This function also shards/replicates the model parameters and optimizer states to DTensors.
 """
 num_params_buffers = len(params_buffers)
 flattened_states = pytree.tree_leaves(named_states)
```

File: torch/distributed/_spmd/distribute.py

```
@@ -48,9 +48,9 @@ class Schema:
 @dataclass
 class DSymInt:
-"""
-DSymInt represents a value retrieved by a SymInt op from a DTensor. DSymInt
-helps View and Factory ops to determine the placement and shape of the
+"""DSymInt represents a value retrieved by a SymInt op from a DTensor.
+DSymInt helps View and Factory ops to determine the placement and shape of the
 output tensor, as those operators either do not have an input DTensor or
 the input DTensor is insufficient to determine the output tensor's placement.
 """
@@ -90,7 +90,7 @@ class DSymInt:
 def _is_partial_dtensor(obj: Any) -> bool:
-"""check if object is 1) DTensor and 2) with any placement of _Partial"""
+"""Check if object is 1) DTensor and 2) with any placement of _Partial."""
 if not isinstance(obj, DTensor):
 return False
@@ -475,8 +475,8 @@ def _get_dtensor_dispatch_graph(
 def _build_dummy_add_graph(
 dt: DTensor, node_to_obj: Dict[fx.Node, Any]
 ) -> Tuple[fx.GraphModule, Any]:
-"""
-Creates a graph for a dummy add function from a partial DTensor.
+"""Create a graph for a dummy add function from a partial DTensor.
 This dummy add is used for triggering all_reduce on a Partial DTensor
 during the DTensor expansion of the traced graph.
 Also returns the actual DTensor after resharding.
@@ -703,10 +703,12 @@ def _convert_to_distributed(
 default_mesh: Optional[DeviceMesh] = None,
 _allow_partial: bool = False,
 ) -> Tuple[fx.GraphModule, Dict[str, Schema]]:
-"""
+"""Transform a graph module to a distributed graph module.
 Returns:
 - transformed graph module
 - map from output name to DTensorSpec
 """
 global logger
 logger = get_logger("spmd_exp")
```

File: torch/distributed/_spmd/experimental_ops.py

```
@@ -342,10 +342,7 @@ def _prop_native_layer_norm_backward(op_schema: OpSchema) -> OutputSharding:
 def _refine_sharding(
 op_schema: OpSchema, active_dim: Optional[int]
 ) -> Sequence[Placement]:
-"""
-Considers 2 first inputs of op_schema as having same shape,
-and returns suggested placement for a pointwise operation.
-"""
+"""Considers 2 first inputs of op_schema as having same shape, and returns suggested placement for a pointwise operation."""
 # consider the operating dimension as a singleton to prevent sharding on it
 # however, if active_dim is None, this means the input and output shapes are equal and
 # we'll apply exactly the pointwise rule.
```

File: torch/distributed/_spmd/graph_optimization.py

```
@@ -64,10 +64,9 @@ def graph_optimization_pass(
 prerequisites: Iterable[Callable],
 apply_after: Iterable[Callable],
 ) -> Callable:
-"""
-The contract of graph optimization pass. All the passes should be wrapped
-with this decorator.
+"""Define the contract of a graph optimization pass.
+All the passes should be wrapped with this decorator.
 `prerequisites` is used to annotate the prerequisite passes of the this pass.
 `apply_after` means that this wrapped pass must be applied after the passes
 in `apply_after`. The difference between `prerequisites` and `apply_after`
@@ -158,12 +157,11 @@ class CommBlock:
 def get_comm_block(comm_node: fx.Node) -> CommBlock:
-"""
-Given a collective node (e.g., allreduce), find out all the nodes belong to
-this communcation.
+"""Find out all the nodes belong to this communcation given a collective node (e.g., allreduce).
 Args:
 comm_node(fx.Node): The target communication/collective node.
 Returns:
 The CommBlock that encapsulates the related nodes (e.g., wait_node) of
 the given comm_node.
@@ -306,10 +304,7 @@ def _scatter_wait_result(
 comm_blocks: List[CommBlock],
 node_indices: Dict[fx.Node, int],
 ) -> None:
-"""
-Scatters the result of the fused communication node to the original users --
-splitting the output and reshape each subitem.
-"""
+"""Scatter the result of the fused communication node to the original users -- splitting the output and reshape each subitem."""
 last_wait_node_idx = 0
 for node in gm.graph.nodes:
 if node == fused_comm_block.comm_node:
@@ -371,9 +366,7 @@ def _fuse_with_cat(
 comm_blocks: List[CommBlock],
 node_indices: Dict[fx.Node, int],
 ) -> CommBlock:
-"""
-Given a list of CommBlock (only allreduce), fuse the CommBlocks using concat.
-"""
+"""Fuse the CommBlocks using concat given a list of CommBlock (only allreduce)."""
 # Find the last input node.
 last_input_node = comm_blocks[0].inputs[0]
 last_input_index = -1
@@ -474,8 +467,8 @@ def comm_fusion_with_concat(
 gm: IterGraphModule,
 bucket_size_mb: int,
 ) -> None:
-"""
-Run fuse communication with concat.
+"""Run fuse communication with concat.
 This implementation uses concat to concat the bucketed gradients.
 """
 comm_blocks = get_all_comm_blocks(gm, (CommType.ALLREDUCE, "all_reduce"))
@@ -508,9 +501,7 @@ def comm_fusion_with_concat(
 apply_after=[],
 )
 def schedule_comm_wait(gm: IterGraphModule) -> None:
-"""
-Delay the execution of wait tensors of allreduce until its first user.
-"""
+"""Delay the execution of wait tensors of allreduce until its first user."""
 comm_blocks = get_all_comm_blocks(gm, (CommType.ALLREDUCE, "all_reduce"))
 # Find all the end users.
@@ -549,8 +540,8 @@ def schedule_comm_wait(gm: IterGraphModule) -> None:
 apply_after=[],
 )
 def remove_copy_from_optimizer(gm: IterGraphModule) -> None:
-"""
-Erase the orphant copy_ that generated when tracing optimizer.
+"""Erase the orphant copy_ that generated when tracing optimizer.
 Two reasons why we could not simply use the DCE of fx.Graph.
 1. fx.Graph treats copy_ as a side-effect node and does not erase it.
 2. Users may want to preserve some orphan `copy_` that is not from the
@@ -742,9 +733,7 @@ class FusedOptimizerBlock:
 def get_fused_optimizer_block(optim_node: fx.Node) -> FusedOptimizerBlock:
-"""
-Given a fused optimizer node and return the FusedOptimizerBlock.
-"""
+"""Given a fused optimizer node and return the FusedOptimizerBlock."""
 MAX_STEP_DISTANCE = 5
 # Find the step (foreach_add)
 nodes = collections.deque([optim_node, None])
@@ -783,10 +772,7 @@ def get_fused_optimizer_block(optim_node: fx.Node) -> FusedOptimizerBlock:
 def get_all_fused_optimizer_blocks(
 gm: IterGraphModule, optim_ops: Union[Tuple[str, ...], str]
 ) -> List[FusedOptimizerBlock]:
-"""
-Find all the FusedOptimizerBlock that the optimizer operators are in
-`optim_ops`.
-"""
+"""Find all the FusedOptimizerBlock that the optimizer operators are in `optim_ops`."""
 return [
 get_fused_optimizer_block(node)
 for node in gm.graph.nodes
@@ -799,9 +785,9 @@ def _split_fused_adam(
 orig_optim_block: FusedOptimizerBlock,
 split_gradients: Set[fx.Node],
 ) -> Tuple[FusedOptimizerBlock, FusedOptimizerBlock]:
-"""
-Split the `orig_optim_block` into two FusedOptimizerBlock. The first one
-will be the optimizer that optimize `split_gradients`. The second one is
+"""Split the `orig_optim_block` into two FusedOptimizerBlock.
+The first one will be the optimizer that optimize `split_gradients`. The second one is
 used to optimize the remaining gradients.
 An assert will be raised if one of the optimizer optimize zero gradients.
 """
@@ -949,7 +935,8 @@ def iter_move_grads_and_optimizers(
 target_comm_node: str,
 target_dest_node: str,
 ) -> None:
-"""Function to extract a comm block and split out a new optimizer and step for it.
+"""Extract a comm block and split out a new optimizer and step for it.
 This subgraph is then moved to the forward graph.
 """
 for comm_block in get_all_comm_blocks(gm, "all_reduce"):
@@ -982,8 +969,7 @@ def find_all_descendants(
 gm: IterGraphModule,
 parent_nodes: List[fx.Node],
 ) -> List[fx.Node]:
-"""identifying list of nodes to move during FX graph transformation"""
+"""Identify the list of nodes to move during FX graph transformation."""
 assert len(parent_nodes) > 0, "No parent nodes are given."
 output = get_output(gm.graph)
```

File: torch/distributed/_spmd/graph_utils.py

```
@@ -41,9 +41,9 @@ def get_node_tensor_metadata(node: fx.Node, is_required: bool = True) -> TensorM
 def get_output(graph: fx.Graph) -> fx.Node:
-"""
-Take a graphmodule and returns the graph output node. We traverse in reverse
-to expedite it, with the idea that last node should be output
+"""Take a graphmodule and return the graph output node.
+We traverse in reverse to expedite it, with the idea that last node should be output
 """
 for node in reversed(graph.nodes):
 if node.op == OP.OUTPUT:
@@ -54,10 +54,7 @@ def get_output(graph: fx.Graph) -> fx.Node:
 def find_node(
 graph: fx.Graph, predicate: Callable, reverse_order: bool = False
 ) -> List[fx.Node]:
-"""
-Take a predicate and return all the nodes in the `graph` where the predicate
-holds.
-"""
+"""Take a predicate and return all the nodes in the `graph` where the predicate holds."""
 nodes = cast(Iterable[fx.Node], graph.nodes)
 if reverse_order:
 nodes = cast(Iterable[fx.Node], iter(reversed(nodes)))  # type: ignore[call-overload]
@@ -65,8 +62,8 @@ def find_node(
 def is_leaf_subgraph(graph: fx.Graph, subgraph: List[fx.Node]) -> bool:
-"""
-This function ensures nodes in ``subgraph`` satisfy one of the rules:
+"""Ensure nodes in ``subgraph`` satisfy one of the following rules.
 1. The user of the node is in ``subgraph``.
 2. The user of the node is output.
 3. There are no users -- the node is a side-effect node.
@@ -85,11 +82,10 @@ def is_leaf_subgraph(graph: fx.Graph, subgraph: List[fx.Node]) -> bool:
 def clone_subgraph(
 graph: fx.Graph, subgraph: List[fx.Node], target: fx.Node
 ) -> List[fx.Node]:
-"""
-Clone the given subgraph and insert it before ``target``.
+"""Clone the given subgraph and insert it before ``target``.
 This API currently does not support inserting after ``target``.
 """
 all_nodes = set(subgraph)
 mapping: Dict[fx.Node, fx.Node] = dict()
 cloned_subgraph = []
@@ -125,12 +121,11 @@ def clone_subgraph(
 def rebuild_graph(gm: fx.GraphModule, remove_dead_code: bool = True) -> None:
-"""
-Runs the required steps to ensure production-ready graph.
-note - per the fx docs, eliminate dead code is not very precise.
+"""Run the required steps to ensure production-ready graph.
+Note - per the fx docs, elimination of dead code is not very precise.
 Hence, the flag to make this step optional.
 """
 gm.graph.lint()
 if remove_dead_code:
 gm.graph.eliminate_dead_code()
```

File: torch/distributed/_spmd/iter_graph_module.py

```
@@ -22,9 +22,9 @@ logger: logging.Logger = logging.getLogger("IterGraphModule")
 class IterGraph(fx.Graph):
-"""
-``IterGraph`` is used to perform cross-iteration optimization. ``IterGraph``
-keeps track of the 3 graphs, self (the original graph), setup graph, and
+"""``IterGraph`` is used to perform cross-iteration optimization.
+``IterGraph`` keeps track of the 3 graphs, self (the original graph), setup graph, and
 cleanup graph. The 3 graphs should be identical copies of a ``fx.Graph``.
 IterGraph subclass fx.Graph to override the necessary APIs that will be used
@@ -127,11 +127,10 @@ class IterGraph(fx.Graph):
 def _forward_subgraph_inputs(
 self, subgraph: List[fx.Node], graph: fx.Graph, erase_node: bool
 ) -> int:
-"""
-This function turns the inputs of a subgraph into the extra output
-of the entire graph. If ``erase_node`` is True, the subgraph will be
-erased from the graph -- essentially forward the inputs of the subgraph
-to the output of the graph.
+"""Turn the inputs of a subgraph into the extra output of the entire graph.
+If ``erase_node`` is True, the subgraph will be erased from the graph -- essentially forward the inputs
+of the subgraph to the output of the graph.
 """
 output = get_output(graph)
 inputs = []
@@ -219,10 +218,10 @@ class IterGraph(fx.Graph):
 def _forward_inputs_to_subgraph(
 self, subgraph: List[fx.Node], graph: fx.Graph, extra_input: int
 ) -> None:
-"""
-This function creates extra input nodes and forward the input nodes to
-the ``subgraph``. The external input nodes of ``subgraph`` (nodes that
-are not in ``subgraph``) will replaced by the newly created input nodes.
+"""Create extra input nodes and forward the input nodes to the ``subgraph``.
+The external input nodes of ``subgraph`` (nodes that are not in ``subgraph``) will replaced by the newly
+created input nodes.
 """
 placeholders = [node for node in graph.nodes if str(node.op) == "placeholder"]
 assert placeholders, "No placeholders are found"
@@ -275,8 +274,8 @@ class IterGraph(fx.Graph):
 def move_to_next_iter_before(
 self, subgraph: List[fx.Node], target_node: fx.Node
 ) -> None:
-"""
-Move the ``subgraph`` to the next iteration before ``target_node``.
+"""Move the ``subgraph`` to the next iteration before ``target_node``.
 The ``subgraph`` is a list of fx.Node and must satisfy the following
 restrictions:
 1. The order of the nodes in ``subgraph`` must obey the topological
@@ -633,8 +632,8 @@ class IterGraph(fx.Graph):
 class IterGraphModule(nn.Module):
-"""
-``IterGraphModule`` provides the ability to do cross-iteration optimization.
+"""``IterGraphModule`` provides the ability to do cross-iteration optimization.
 Given a ``fx.GraphModule``, main_gm, ``IterGraphModule`` internally
 duplicate it to 3 copies and redirect the ``forward`` request to a different
 ``fx.GraphModule`` based on the iteration count. This allows users to do
@@ -674,10 +673,9 @@ class IterGraphModule(nn.Module):
 self._enable_inductor = enable_inductor
 def finalize_setup(self) -> None:
-"""
-Must be called before the forward() is called. This method setups
-the internal states and also get the signal from users that what
-is the maximum iteration count.
+"""Set up the internal states and also get the signal from users that what is the maximum iteration count.
+This method must be called before the forward() is called.
 """
 if not self._is_frozen:
 self.graph.freeze_cross_iter_movement()
```

File: torch/distributed/fsdp/fully_sharded_data_parallel.py

```
@@ -112,14 +112,16 @@ FLAT_PARAM = "_flat_param"
 class OptimStateKeyType(Enum):
+"""Represents the type of key in an optimizer state-dict."""
 PARAM_NAME = auto()
 PARAM_ID = auto()
 class FullyShardedDataParallel(nn.Module, _FSDPState):
-"""
-A wrapper for sharding module parameters across data parallel workers. This
-is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
+"""A wrapper for sharding module parameters across data parallel workers.
+This is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
 FullyShardedDataParallel is commonly shortened to FSDP.
 .. _`Xu et al.`: https://arxiv.org/abs/2004.13336
@@ -515,9 +517,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 @property
 def module(self) -> nn.Module:
-"""
-Returns the wrapped module (like :class:`DistributedDataParallel`).
-"""
+"""Return the wrapped module."""
 # FSDP's `.module` must refer to the innermost wrapped module when
 # composing with other module wrappers in order for state dict to work
 if isinstance(self._fsdp_wrapped_module, ActivationWrapper):
@@ -547,6 +547,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 return super().__getitem__(key)
 def check_is_root(self) -> bool:
+"""Check if this instance is a root FSDP module."""
 return _is_fsdp_root(self, self)
 @staticmethod
@@ -554,9 +555,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 module: nn.Module,
 root_only: bool = False,
 ) -> List["FullyShardedDataParallel"]:
-"""
-Returns all nested FSDP instances, possibly including ``module`` itself
-and only including FSDP root modules if ``root_only=True``.
+"""Return all nested FSDP instances.
+This possibly includes ``module`` itself and only includes FSDP root modules if ``root_only=True``.
 Args:
 module (torch.nn.Module): Root module, which may or may not be an
@@ -573,9 +574,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 return traversal_utils._get_fsdp_states(module)
 def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
-r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
-as well as self. Typical use includes initializing the parameters of a model
-(see also :ref:`nn-init-doc`).
+r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.
+Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`).
 Compared to ``torch.nn.Module.apply``, this version additionally gathers
 the full parameters before applying ``fn``. It should not be called from
@@ -614,8 +615,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 return ret
 def _mixed_precision_enabled_for_buffers(self) -> bool:
-"""
-Returns if the user explicitly enabled buffer mixed precision.
+"""Return whether the user explicitly enabled buffer mixed precision.
 NOTE: Unlike parameters and gradient reduction, buffer mixed precision
 is applied at the FSDP instance level, not the ``FlatParameter`` level,
@@ -624,15 +624,11 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 return self.mixed_precision.buffer_dtype is not None
 def _low_precision_hook_enabled(self) -> bool:
-"""
-Whether a low precision hook is registered or not.
-"""
+"""Whether a low precision hook is registered or not."""
 return self._comm_hook is not None and self._comm_hook in LOW_PRECISION_HOOKS
 def _reset_lazy_init(self) -> None:
-"""
-Reset instance so :func:`_lazy_init` will run on the next forward.
-"""
+"""Reset instance so :func:`_lazy_init` will run on the next forward."""
 self._is_root: Optional[bool] = None
 @staticmethod
@@ -642,9 +638,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 state_dict_config: Optional[StateDictConfig] = None,
 optim_state_dict_config: Optional[OptimStateDictConfig] = None,
 ) -> StateDictSettings:
-"""
-Set the ``state_dict_type`` and the corresponding (optional)
-configurations of all the descendant FSDP modules of the target module.
+"""Set the ``state_dict_type`` of all the descendant FSDP modules of the target module.
+Also takes (optional) configuration for the model's and optimizer's state dict.
 The target module does not have to be a FSDP module. If the target
 module is a FSDP module, its ``state_dict_type`` will also be changed.
@@ -747,10 +743,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 @staticmethod
 def get_state_dict_type(module: nn.Module) -> StateDictSettings:
-"""
-Get the state_dict_type and the corresponding configurations
-for the FSDP modules rooted at ``module``. The target module
-does not have to be an FSDP module.
+"""Get the state_dict_type and the corresponding configurations for the FSDP modules rooted at ``module``.
+The target module does not have to be an FSDP module.
 Returns:
 A ``StateDictSettings`` containing the state_dict_type and
@@ -797,10 +792,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 state_dict_config: Optional[StateDictConfig] = None,
 optim_state_dict_config: Optional[OptimStateDictConfig] = None,
 ) -> Generator:
-"""
-A context manager to set the ``state_dict_type`` of all the descendant
-FSDP modules of the target module. This context manager has the same
-functions as :meth:`set_state_dict_type`. Read the document of
+"""Set the ``state_dict_type`` of all the descendant FSDP modules of the target module.
+This context manager has the same functions as :meth:`set_state_dict_type`. Read the document of
 :meth:`set_state_dict_type` for the detail.
 Example::
@@ -836,10 +830,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 )
 def forward(self, *args: Any, **kwargs: Any) -> Any:
-"""
-Runs the forward pass for the wrapped module, inserting FSDP-specific
-pre- and post-forward sharding logic.
-"""
+"""Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic."""
 handle = self._handle
 with torch.autograd.profiler.record_function(
 "FullyShardedDataParallel.forward"
@@ -875,7 +866,8 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 offload_to_cpu: bool = False,
 with_grads: bool = False,
 ) -> Generator:
-r"""A context manager to expose full params for FSDP instances.
+r"""Expose full params for FSDP instances with this context manager.
 Can be useful *after* forward/backward for a model to get
 the params for additional processing or checking. It can take a non-FSDP
 module and will summon full params for all contained FSDP modules as
@@ -941,9 +933,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 @contextlib.contextmanager
 def _deregister_orig_params_ctx(self):
-"""
-This deregisters the original parameters and exposes the
-:class:`FlatParameter` s. If a :class:`FlatParameter` is sharded, then
+"""Deregister the original parameters and expose the :class:`FlatParameter`.
+If a :class:`FlatParameter` is sharded, then
 this refreshes the sharded views before exiting. This method should
 only be called when using the original parameters.
 """
@@ -961,11 +953,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 _register_orig_params(fsdp_module, fsdp_module)
 def _apply(self, *args, **kwargs):
-"""
-When using the original parameters, this deregisters the original
-parameters and exposes the :class:`FlatParameter` s before calling
-``_apply()``.
-"""
+"""Deregister the original parameters and expose the :class:`FlatParameter` s before calling ``_apply()``."""
 # When using the original parameters: Since (1) the `FlatParameter`s
 # own the storage and (2) `_apply()` is the subroutine underlying the
 # most common storage-changing ops like `to()` and `cuda()`, we
@@ -985,9 +973,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 *args,
 **kwargs,
 ) -> Iterator[Tuple[str, torch.Tensor]]:
-"""
-Overrides :meth:`named_buffers()` to intercept buffer names and
-remove all occurrences of the FSDP-specific flattened buffer prefix
+"""Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself.
+Intercepts buffer names and removes all occurrences of the FSDP-specific flattened buffer prefix
 when inside the :meth:`summon_full_params` context manager.
 """
 should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
@@ -1003,9 +991,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 *args,
 **kwargs,
 ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
-"""
-Overrides :meth:`named_parameters()` to intercept parameter names and
-remove all occurrences of the FSDP-specific flattened parameter prefix
+"""Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself.
+Intercepts parameter names and removes all occurrences of the FSDP-specific flattened parameter prefix
 when inside the :meth:`summon_full_params` context manager.
 """
 should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
@@ -1038,9 +1026,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 @contextmanager
 def no_sync(self) -> Generator:
-"""
-A context manager to disable gradient synchronizations across FSDP
-instances. Within this context, gradients will be accumulated in module
+"""Disable gradient synchronizations across FSDP instances.
+Within this context, gradients will be accumulated in module
 variables, which will later be synchronized in the first
 forward-backward pass after exiting the context. This should only be
 used on the root FSDP instance and will recursively apply to all
@@ -1079,9 +1067,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 def clip_grad_norm_(
 self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0
 ) -> torch.Tensor:
-"""
-Clips the gradient norm of all parameters. The norm is computed over
-all parameters' gradients as viewed as a single vector, and the
+"""Clip the gradient norm of all parameters.
+The norm is computed over all parameters' gradients as viewed as a single vector, and the
 gradients are modified in-place.
 Args:
@@ -1245,8 +1233,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 group: Optional[dist.ProcessGroup] = None,
 cpu_offload: bool = True,
 ) -> Dict[str, Any]:
-"""
-The internal API that is used by all the optim_state_dict implementations.
+"""Transform the state-dict of an optimizer corresponding to a sharded model.
+This is the internal API that is used by all the optim_state_dict implementations.
 Given model, optim, the original optim_state_dict, this API removes the
 FSDP internal information and internal sharding from the optim_state_dict.
 """
@@ -1298,7 +1287,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 group: Optional[dist.ProcessGroup] = None,
 ) -> Dict[str, Any]:
 """
-The internal API that is used by all the load optim_state_dict implementations.
+Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.
+This is the internal API that is used by all the load optim_state_dict implementations.
 Given model, optim, and the saved optim_state_dict, this API adds the FSDP
 internal information and internal sharding to the optim_state_dict.
 """
@@ -1352,7 +1343,8 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 rank0_only: bool = True,
 group: Optional[dist.ProcessGroup] = None,
 ) -> Dict[str, Any]:
-"""
+"""Return the full optimizer state-dict.
 Consolidates the full optimizer state on rank 0 and returns it
 as a :class:`dict` following the convention of
 :meth:`torch.optim.Optimizer.state_dict`, i.e. with keys ``"state"``
@@ -1417,7 +1409,8 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 optim: torch.optim.Optimizer,
 group: Optional[dist.ProcessGroup] = None,
 ) -> Dict[str, Any]:
-"""
+"""Return the optimizer state-dict in its sharded form.
 The API is similar to :meth:`full_optim_state_dict` but this API chunks
 all non-zero-dimension states to :class:`ShardedTensor` to save memory.
 This API should only be used when the model ``state_dict`` is derived
@@ -1453,12 +1446,11 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 ] = None,
 optim: Optional[torch.optim.Optimizer] = None,
 ) -> Dict[str, Any]:
-"""
-Shards the full optimizer state dict ``full_optim_state_dict`` by
-remapping the state to flattened parameters instead of unflattened
-parameters and restricting to only this rank's part of the optimizer
-state. The first argument should be the return value of
-:meth:`full_optim_state_dict`.
+"""Shard a full optimizer state-dict.
+Remaps the state in ``full_optim_state_dict`` to flattened parameters instead of unflattened
+parameters and restricts to only this rank's part of the optimizer state.
+The first argument should be the return value of :meth:`full_optim_state_dict`.
 Example::
@@ -1525,7 +1517,8 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 model: torch.nn.Module,
 optim: torch.optim.Optimizer,
 ) -> Dict[str, Any]:
-"""
+"""Flatten a sharded optimizer state-dict.
 The API is similar to :meth:`shard_full_optim_state_dict`. The only
 difference is that the input ``sharded_optim_state_dict`` should be
 returned from :meth:`sharded_optim_state_dict`. Therefore, there will
@@ -1568,10 +1561,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 optim: Optional[torch.optim.Optimizer] = None,
 group: Optional[Any] = None,
 ) -> Dict[str, Any]:
-"""
-Scatters the full optimizer state dict from rank 0 to all other ranks,
-returning the sharded optimizer state dict on each rank. The return
-value is the same as :meth:`shard_full_optim_state_dict`, and on rank
+"""Scatter the full optimizer state dict from rank 0 to all other ranks.
+Returns the sharded optimizer state dict on each rank.
+The return value is the same as :meth:`shard_full_optim_state_dict`, and on rank
 0, the first argument should be the return value of
 :meth:`full_optim_state_dict`.
@@ -1650,10 +1643,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 ] = None,
 optim: Optional[torch.optim.Optimizer] = None,
 ) -> Dict[str, Any]:
-"""
-Re-keys the optimizer state dict ``optim_state_dict`` to use the key
-type ``optim_state_key_type``. This can be used to achieve
-compatibility between optimizer state dicts from models with FSDP
+"""Re-keys the optimizer state dict ``optim_state_dict`` to use the key type ``optim_state_key_type``.
+This can be used to achieve compatibility between optimizer state dicts from models with FSDP
 instances and ones without.
 To re-key an FSDP full optimizer state dict (i.e. from
@@ -1771,9 +1763,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 group: Optional[dist.ProcessGroup] = None,
 ) -> Dict[str, Any]:
 """
-Transforms the state_dict of ``optim`` for the ``model`` that is sharded
-by FSDP to one of the three types: 1) full optimizer state_dict, 2)
-sharded optimizer state_dict, 3) local optimizer state_dict.
+Transform the state-dict of an optimizer corresponding to a sharded model.
+The given state-dict can be transformed to one of three types:
+1) full optimizer state_dict, 2) sharded optimizer state_dict, 3) local optimizer state_dict.
 For full optimizer state_dict, all states are unflattened and not sharded.
 Rank0 only and CPU only can be specified via :meth:`state_dict_type` to
@@ -1867,10 +1860,12 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 group: Optional[dist.ProcessGroup] = None,
 ) -> Dict[str, Any]:
 """
+Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.
 Given a ``optim_state_dict`` that is transformed through
-:meth:`optim_state_dict`, converts it to the flattened optimizer
+:meth:`optim_state_dict`, it gets converted to the flattened optimizer
 state_dict that can be loaded to ``optim`` which is the optimizer for
 ``model``. ``model`` must be sharded by FullyShardedDataParallel.
 >>> # xdoctest: +SKIP("undefined variables")
 >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -1946,10 +1941,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
 return result
 def register_comm_hook(self, state: object, hook: callable):
-"""
-Registers a communication hook which is an enhancement that provides a
-flexible hook to users where they can specify how FSDP aggregates gradients
-across multiple workers.
+"""Register a communication hook.
+This is an enhancement that provides a flexible hook to users where they can specify how FSDP aggregates
+gradients across multiple workers.
 This hook can be used to implement several algorithms like
 `GossipGrad <https://arxiv.org/abs/1803.05880>`_ and gradient compression
 which involve different communication strategies for
@@ -2008,9 +2003,9 @@ def _get_grad_norm(
 norm_type: float,
 ) -> torch.Tensor:
 """
-Returns the gradient norm of parameters ``param`` s, where the gradients
-are viewed as a single vector. The returned norm is in FP32 even if
-parameters/gradients are in a low precision. This is because the downstream
+Return the gradient norm of parameters ``param`` s, where the gradients are viewed as a single vector.
+The returned norm is in FP32 even if parameters/gradients are in a low precision. This is because the downstream
 use of this return value is a reduction across ranks.
 """
 params_with_grad = [param for param in params if param.grad is not None]
@@ -2041,8 +2036,9 @@ def _get_param_to_fqn(
 model: torch.nn.Module,
 ) -> Dict[torch.nn.Parameter, str]:
 """
-Constructs a mapping from parameters to their parameter names. ``model``
-should not contain any :class:`FullyShardedDataParallel` instances, which
+Construct a mapping from parameters to their parameter names.
+The ``model`` should not contain any :class:`FullyShardedDataParallel` instances, which
 means that none of the parameters should be ``FlatParameter`` s. As a
 result, compared to :meth:`_get_param_to_fqns`, the mapped
 values may be flattened from singleton :class:`list` s to the contained
@@ -2071,6 +2067,6 @@ def _get_param_to_fqn(
 def _get_fqn_to_param(
 model: torch.nn.Module,
 ) -> Dict[str, torch.nn.Parameter]:
-"""Constructs the inverse mapping of :meth:`_get_param_to_fqn`."""
+"""Construct the inverse mapping of :meth:`_get_param_to_fqn`."""
 param_to_param_name = _get_param_to_fqn(model)
 return dict(zip(param_to_param_name.values(), param_to_param_name.keys()))
```