[Docs] Fix docstring format (#99396)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99396
Approved by: https://github.com/awgu
This commit is contained in:
zhouzaida 2023-04-28 01:10:03 +00:00 committed by PyTorch MergeBot
parent 64efd88845
commit b51f92ebda
2 changed files with 11 additions and 11 deletions

@@ -52,13 +52,13 @@ class ShardingStrategy(Enum):
       synchronizes them (via all-reduce) after the backward computation. The
       unsharded optimizer states are updated locally per rank.
     - ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate parameters across
-    nodes. This results in reduced communication volume as expensive all-gathers and
-    reduce-scatters are only done within a node, which can be more performant for medium
-    -sized models.
+      nodes. This results in reduced communication volume as expensive all-gathers and
+      reduce-scatters are only done within a node, which can be more performant for medium
+      -sized models.
     - ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and replicate parameters across
-    nodes. This is like ``HYBRID_SHARD``, except this may provide even higher throughput
-    since the unsharded parameters are not freed after the forward pass, saving the
-    all-gathers in the pre-backward.
+      nodes. This is like ``HYBRID_SHARD``, except this may provide even higher throughput
+      since the unsharded parameters are not freed after the forward pass, saving the
+      all-gathers in the pre-backward.
     """
 
     FULL_SHARD = auto()

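As an aside for readers of the strategies touched above: a minimal sketch of picking one of them when wrapping a module, assuming a torchrun launch with one GPU per rank and FSDP's default process-group construction for the hybrid strategies (the toy model and helper name are illustrative only, not part of this PR).

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy


def wrap(model: nn.Module) -> FSDP:
    # FULL_SHARD: shard parameters, gradients, and optimizer state everywhere.
    # SHARD_GRAD_OP: keep parameters unsharded between forward and backward.
    # HYBRID_SHARD / _HYBRID_SHARD_ZERO2: apply the above within a node and
    # replicate across nodes, as described in the docstring above.
    return FSDP(model.cuda(), sharding_strategy=ShardingStrategy.HYBRID_SHARD)


if __name__ == "__main__":
    dist.init_process_group("nccl")  # assumes launch via torchrun
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    fsdp_model = wrap(
        nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 8))
    )
```

Swapping in ``_HYBRID_SHARD_ZERO2`` is the only change needed to trade the extra memory for the saved pre-backward all-gathers described above.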
@@ -216,7 +216,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
     Args:
         module (nn.Module):
             This is the module to be wrapped with FSDP.
-        process_group: Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]
+        process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]):
             This is the process group used for collective communications and
             the one over which the model is sharded. For hybrid sharding strategies such as
             ``ShardingStrategy.HYBRID_SHARD`` users can
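The docstring sentence above continues beyond this hunk; the type hint it fixes shows that ``process_group`` may also be a tuple of two process groups for the hybrid strategies. A hedged sketch of building such a pair with ``torch.distributed.new_group`` follows; the (shard, replicate) ordering, the node-major rank layout, and the ``gpus_per_node`` helper are assumptions, not part of this diff.

```python
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy


def build_hybrid_process_groups(gpus_per_node: int):
    """Return (intra_node_pg, inter_node_pg) for this rank, assuming ranks are
    laid out node-major (ranks 0..gpus_per_node-1 on node 0, and so on)."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    num_nodes = world_size // gpus_per_node

    # new_group() is collective: every rank must create every group, even the
    # groups it does not belong to.
    intra = [
        dist.new_group(list(range(n * gpus_per_node, (n + 1) * gpus_per_node)))
        for n in range(num_nodes)
    ]
    inter = [
        dist.new_group(list(range(local, world_size, gpus_per_node)))
        for local in range(gpus_per_node)
    ]
    return intra[rank // gpus_per_node], inter[rank % gpus_per_node]


# Hypothetical usage with a hybrid strategy:
# shard_pg, replicate_pg = build_hybrid_process_groups(gpus_per_node=8)
# fsdp_model = FSDP(model, process_group=(shard_pg, replicate_pg),
#                   sharding_strategy=ShardingStrategy.HYBRID_SHARD)
```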
@@ -1458,9 +1458,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
                 corresponding to the unflattened parameters and holding the
                 sharded optimizer state.
             model (torch.nn.Module):
-                Refer to :meth:``shard_full_optim_state_dict``.
+                Refer to :meth:`shard_full_optim_state_dict`.
             optim (torch.optim.Optimizer): Optimizer for ``model`` 's
-            parameters.
+                parameters.
 
         Returns:
             Refer to :meth:`shard_full_optim_state_dict`.
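Since both Args entries above defer to :meth:`shard_full_optim_state_dict`, a hedged sketch of that save/load flow is included here for context; the helper names and checkpoint path are placeholders, not APIs from this diff.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def save_full_optim_state(model: FSDP, optim: torch.optim.Optimizer, path: str) -> None:
    # Gather the full, unsharded optimizer state keyed by unflattened parameter
    # names; by default only rank 0 receives the populated dict.
    full_osd = FSDP.full_optim_state_dict(model, optim)
    if dist.get_rank() == 0:
        torch.save(full_osd, path)


def load_full_optim_state(model: FSDP, optim: torch.optim.Optimizer, path: str) -> None:
    # Every rank loads the full dict, then re-shards it to match its own
    # flattened parameters before handing it to the optimizer.
    full_osd = torch.load(path)
    sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, model)
    optim.load_state_dict(sharded_osd)
```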
@@ -1785,7 +1785,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
     ) -> Dict[str, Any]:
         """
         This hook is intended be used by ``torch.distributed.NamedOptimizer``.
-        The functionality is identical to ``:meth:optim_state_dict`` except
+        The functionality is identical to :meth:`optim_state_dict` except
         for the different arguments.
 
         Args:
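The hook edited here mirrors the public :meth:`optim_state_dict` API, so a short, hedged sketch of the public call is given below; the helper name and the rank-0-writes convention are assumptions.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def save_optim_state(model: FSDP, optim: torch.optim.Optimizer, path: str) -> None:
    # Returns an optimizer state dict keyed by unflattened parameter names;
    # whether it is full or sharded follows the model's state_dict_type settings.
    osd = FSDP.optim_state_dict(model, optim)
    if dist.get_rank() == 0:
        torch.save(osd, path)
```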
@@ -1916,7 +1916,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
     ) -> Dict[str, Any]:
         """
         This hook is intended be used by ``torch.distributed.NamedOptimizer``.
-        The functionality is identical to ``:meth:optim_state_dict_to_load``
+        The functionality is identical to :meth:`optim_state_dict_to_load`
         except for the different arguments.
 
         Args:
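And the load-side counterpart for the hook in this last hunk, using :meth:`optim_state_dict_to_load`. The positional argument order of this method has changed across releases, so keywords are used below and the exact signature should be treated as an assumption.

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def load_optim_state(model: FSDP, optim: torch.optim.Optimizer, path: str) -> None:
    osd = torch.load(path)  # the dict produced by FSDP.optim_state_dict()
    # Re-key and re-shard the saved state so it matches this rank's flattened
    # parameters, then load it into the optimizer.
    osd = FSDP.optim_state_dict_to_load(
        model=model, optim=optim, optim_state_dict=osd
    )
    optim.load_state_dict(osd)
```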