Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Fix pydocstyle errors in fully_sharded_data_parallel.py, api.py, graph_utils.py, distribute.py, iter_graph_module.py, comm_tensor.py, experimental_ops.py, batch_dim_utils.py, data_parallel.py, graph_optimization.py (#113216)
Fixes #113191

```
pydocstyle torch/distributed/fsdp/fully_sharded_data_parallel.py --count
```
On master: 80
After my changes on this PR: 3

```
pydocstyle torch/distributed/_spmd/comm_tensor.py --count
```
On master: 5
After my changes on this PR: 3

```
pydocstyle torch/distributed/_spmd/experimental_ops.py --count
```
On master: 3
After my changes on this PR: 1

```
pydocstyle torch/distributed/_spmd/iter_graph_module.py --count
```
On master: 39
After my changes on this PR: 27

```
pydocstyle torch/distributed/_spmd/graph_utils.py --count
```
On master: 16
After my changes on this PR: 4

```
pydocstyle torch/distributed/_spmd/distribute.py --count
```
On master: 19
After my changes on this PR: 10

```
pydocstyle torch/distributed/_spmd/api.py --count
```
On master: 10
After my changes on this PR: 3

```
pydocstyle torch/distributed/_spmd/batch_dim_utils.py --count
```
On master: 14
After my changes on this PR: 3

```
pydocstyle torch/distributed/_spmd/data_parallel.py --count
```
On master: 34
After my changes on this PR: 2

```
pydocstyle torch/distributed/_spmd/graph_optimization.py --count
```
On master: 35
After my changes on this PR: 13

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113216
Approved by: https://github.com/ezyang
This commit is contained in:
parent 773b1cbe4f
commit 866457e746
torch/distributed/_spmd/api.py

@@ -32,8 +32,8 @@ from torch.nn.utils._named_member_accessor import NamedMemberAccessor
class Override(ABC):
r"""
Override the tracing and transformation behavior of :meth:`~torch.distributed._spmd.compile`.
r"""Override the tracing and transformation behavior of :meth:`~torch.distributed._spmd.compile`.
This is useful when any part of the model is not traceable or if you prefer
to not trace it due to any reason. More specifically, users can implement
:meth:`torch.distributed._spmd.Override.replacement` to replace an original

@@ -47,10 +47,10 @@ class Override(ABC):
@abstractmethod
def replacement(self, fqn: str, orig_submodule: torch.nn.Module) -> torch.nn.Module:
r"""
Implement this method to return a new :class:`nn.Module` instance to
replace the ``orig_submodule`` argument in the model. This helps if
``orig_submodule`` is not traceable or should not be traced.
r"""Implement this method to return a new :class:`nn.Module` instance to replace the ``orig_submodule``
argument in the model.
This helps if ``orig_submodule`` is not traceable or should not be traced.
Args:
fqn (str): fully quantified name of the submodule.

@@ -58,6 +58,7 @@ class Override(ABC):
Returns:
A new :class:`nn.Module` instance to replace the original one.
"""
pass

@@ -83,6 +84,7 @@ class Override(ABC):
Returns:
The :class:`fx.Graph` after transformation.
"""
pass

@@ -98,8 +100,7 @@ class _PyTreeCodeGenOutputsOnly(_PyTreeCodeGen):
def _to_caller_flattened_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""Move the responsibility of flattening the input arguments from the
graph module to the caller.
"""Move the responsibility of flattening the input arguments from the graph module to the caller.
Example:

@@ -108,6 +109,7 @@ def _to_caller_flattened_graph_module(gm: torch.fx.GraphModule) -> torch.fx.Grap
gm = gm(to_caller_flattened_graph_module)
output = gm(*pytree.flatten(my_struct)[0])
"""
# pyre-ignore[16]
gm._graph._codegen = _PyTreeCodeGenOutputsOnly(

@@ -500,9 +502,9 @@ def compile(
gm_transformation: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
parallel_mode: Optional[ParallelMode] = None,
):
r"""
Compile and optimize a callable, which can be a train step within a training
loop. This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer`
r"""Compile and optimize a callable, which can be a train step within a training loop.
This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer`
instances from the input arguments and trace operations applied to their
parameters and states.

@@ -519,6 +521,7 @@ def compile(
that specifies how to parallelize the callable. Each ParallelMode
would have its own strategy to partition the model and the captured
graph (Default: ``None``)
"""
def inner(func: Callable):
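The `Override` docstrings above describe a two-part contract: `replacement` returns a traceable stand-in for a submodule, and `transform` post-processes the traced `fx.Graph`. A minimal sketch of a subclass is shown below; it is not taken from this PR, and the exact `transform` signature (graph module plus flattened state) is an assumption.

```python
# Hypothetical Override subclass; only replacement(fqn, orig_submodule) and the
# fx.Graph return type of transform are taken from the docstrings above.
from typing import List

import torch
import torch.fx as fx
from torch.distributed._spmd.api import Override


class TraceableStub(torch.nn.Module):
    """Traceable stand-in for a submodule we do not want to trace."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class SkipUntraceable(Override):
    def replacement(self, fqn: str, orig_submodule: torch.nn.Module) -> torch.nn.Module:
        # Return a new module instance to use in place of `orig_submodule`.
        return TraceableStub()

    def transform(self, gm: fx.GraphModule, flat_state: List[torch.Tensor]) -> fx.Graph:
        # No-op transformation: hand the traced graph back unchanged.
        return gm.graph
```

An instance of such a subclass would then be handed to :meth:`~torch.distributed._spmd.compile` (via its module-override argument) so tracing skips the problematic submodule.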
torch/distributed/_spmd/batch_dim_utils.py

@@ -20,9 +20,9 @@ aten = torch.ops.aten
class BatchDimAnalyzer:
"""
This class is used to analyze the batch dimension of each tensor/node in the
graph. We need to know the batch dimension of each tensor/node so that we know
"""This class is used to analyze the batch dimension of each tensor/node in the graph.
We need to know the batch dimension of each tensor/node so that we know
exactly the sharding layout of intermediate tensors.
We possibly should evaluate using symbolic shapes to track the batch dimension.

@@ -54,9 +54,7 @@ class BatchDimAnalyzer:
}
def init_batch_dim_size(self, batch_dim_size: int) -> None:
"""
initialize batch dim size base on the first input batch size
"""
"""Initialize batch dim size base on the first input batch size."""
if self.batch_dim_size != -1 and self.batch_dim_size != batch_dim_size:
raise RuntimeError(
f"batch dim size is already initialized! "

@@ -74,9 +72,7 @@ class BatchDimAnalyzer:
return self.batch_dim_map[node]
def compute_batch_dim(self, node: fx.Node, full_reduction=False) -> int:
"""
compute the batch dimension for the `node`
"""
"""Compute the batch dimension for the `node`."""
assert self.batch_dim_size != -1, "batch dim size is not initialized!"
if node in self.batch_dim_map:

@@ -168,10 +164,7 @@ class BatchDimAnalyzer:
return -2
def compute_act_spec(self, node: fx.Node, mesh: DeviceMesh) -> DTensorSpec:
"""
This function first compute the batch dimension for the current node,
then generate the sharding spec that shards on the batch dimension.
"""
"""Compute the batch dimension for the current node, then generate the sharding spec that shards on the batch dimension."""
node_batch_dim = self.compute_batch_dim(node)
if node_batch_dim == -1:
# indicate this activation is replicated
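The `BatchDimAnalyzer` docstrings above are about tracking which dimension of each intermediate tensor is the batch dimension so the sharding layout can be derived. The snippet below is a simplified, self-contained illustration of that idea using plain `torch.fx` shape propagation; it is not the PyTorch implementation and uses a naive matching rule.

```python
# Toy batch-dim tracking: propagate shapes through an fx graph and guess which
# dimension of each intermediate matches the known input batch size.
import torch
import torch.fx as fx
from torch.fx.passes.shape_prop import ShapeProp


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x @ x.new_ones(8, 4)).sum(dim=1)


batch_size = 32
gm = fx.symbolic_trace(Toy())
ShapeProp(gm).propagate(torch.randn(batch_size, 8))

for node in gm.graph.nodes:
    meta = node.meta.get("tensor_meta")
    if meta is not None:
        # Naive rule: the first dim equal to the batch size is the batch dim.
        dims = [i for i, s in enumerate(meta.shape) if s == batch_size]
        print(node.name, tuple(meta.shape), "batch dim:", dims[0] if dims else None)
```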
torch/distributed/_spmd/comm_tensor.py

@@ -56,9 +56,9 @@ def _get_tracer() -> Optional[torch.fx.Tracer]:
class CommTensor(torch.Tensor):
r"""
A Tensor subclass to wrap input tensors for collective communications. This
Tensor subclass works for both eager and tracing mode.
A Tensor subclass to wrap input tensors for collective communications.
This Tensor subclass works for both eager and tracing mode.
In eager mode, it will record whether the inplace collective communication
has been launched using this Tensor and remember the corresponding work
handle. If yes, it will explicitly call wait() in the ``__torch_dispatch__``
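Based only on the eager-mode behavior the `CommTensor` docstring describes (remember the in-flight collective's work handle and call `wait()` on first use), a single-process sketch could look like the following. Whether `CommTensor` is constructed directly from a plain tensor, and the single-rank process-group setup, are assumptions.

```python
# Hedged sketch: wrap a tensor, launch an in-place collective, then use the
# result; the wait() is expected to happen inside __torch_dispatch__.
import torch
import torch.distributed as dist
from torch.distributed._spmd.comm_tensor import CommTensor

dist.init_process_group(
    "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

t = CommTensor(torch.ones(4))  # assumed construction from a plain tensor
dist.all_reduce(t)             # in-place collective; work handle is recorded
print(t + 1)                   # first use should trigger wait()

dist.destroy_process_group()
```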
torch/distributed/_spmd/data_parallel.py

@@ -41,7 +41,8 @@ _spmd_lib_impl.impl("tag_grad", lambda x: x, "CompositeExplicitAutograd")
class DataParallelStyle(Enum):
"""
"""This enum represents the style of the data-parallel operation.
We have three types of Data Parallel style:
1. DEFAULT: the default data parallel style, which is to represent a mixed
replicate and fully shard behavior. For each parameter that is able

@@ -64,8 +65,8 @@ class DataParallelStyle(Enum):
class NodeType(Enum):
"""
NodeType is a enum that records the type of the tensors in the graph.
"""NodeType is an enum that records the type of the tensors in the graph.
This is used to determine the data parallel strategy.
"""

@@ -77,9 +78,8 @@ class NodeType(Enum):
class DataParallelStrategy(OpStrategy):
"""
DataParallelStrategy is a special case of OpStrategy that only records
the "data parallel style" placement strategy for each fx Node.
"""DataParallelStrategy is a special case of OpStrategy that only records the "data parallel style" placement
strategy for each fx Node.
It takes a list of PlacementStrategy, where each PlacementStrategy describes
one way to distribute the tensor and computation. In the DataParallel case,

@@ -113,13 +113,10 @@ class DataParallelStrategy(OpStrategy):
@contextmanager
def gradients_tagging(params: Dict[str, torch.Tensor]):
"""
This is a helper function that tags the gradient of the parameters
with a special tag, so that we can identify them during SPMD expansion.
"""Tag the gradient of the parameters with a special tag, so that we can identify them during SPMD expansion.
It's safe to trace those hooks and we would remove those nodes later.
"""
tagging_hooks = []
try:
for p in params.values():

@@ -135,9 +132,7 @@ def gradients_tagging(params: Dict[str, torch.Tensor]):
def _gen_shard_strategy(
mesh: DeviceMesh, shard_dim: int, input_specs: Optional[List[DTensorSpec]] = None
) -> PlacementStrategy:
"""
util function to generate a shard strategy on shard_dim
"""
"""Util function to generate a shard strategy on shard_dim."""
return PlacementStrategy(
output_spec=DTensorSpec(mesh=mesh, placements=(Shard(shard_dim),)),
input_specs=input_specs,

@@ -147,9 +142,7 @@ def _gen_shard_strategy(
def _gen_replicate_strategy(
mesh: DeviceMesh, input_specs: Optional[List[DTensorSpec]] = None
) -> PlacementStrategy:
"""
util function to generate a replicate strategy
"""
"""Util function to generate a replicate strategy."""
return PlacementStrategy(
output_spec=DTensorSpec(mesh=mesh, placements=(Replicate(),)),
input_specs=input_specs,

@@ -157,9 +150,7 @@ def _gen_replicate_strategy(
def _gen_partial_strategy(mesh: DeviceMesh) -> PlacementStrategy:
"""
util function to generate a partial strategy
"""
"""Util function to generate a partial strategy."""
# NOTE: we use AVG by default, avg reduction is needed depending on
# the loss function, for most loss function it should do
# gradient averaging. There might be certain cases it should

@@ -180,10 +171,7 @@ def build_data_parallel_strategies(
mesh: DeviceMesh,
batch_dim: int = 0,
) -> Dict[fx.Node, StrategyType]:
"""
This function loop through the train step graph and build the
data parallel strategy for each fx Node
"""
"""Loop through the train step graph and build the data parallel strategy for each fx Node."""
activation_idx = num_params + num_states
non_compute_ops = [
aten.clone.default,

@@ -518,9 +506,7 @@ def mark_data_parallel_shardings(
dp_strategy_map: Dict[fx.Node, StrategyType],
parallel_mode: DataParallelStyle = DataParallelStyle.FULLY_SHARD,
) -> None:
"""
This function marks the sharding for the nodes in the train_step_graph
"""
"""Mark the sharding for the nodes in the train_step_graph."""
activation_idx = num_parameters + num_states
placeholder_idx = 0
for node in train_step_graph.graph.nodes:

@@ -601,9 +587,7 @@ def mark_data_parallel_shardings(
def _partition_val(val: Any, spec: DTensorSpec) -> Any:
"""
util function to convert a full tensor val to its local component
"""
"""Util function to convert a full tensor val to its local component."""
if isinstance(val, torch.Tensor):
local_shard = val
if val.ndim == 0:

@@ -629,10 +613,7 @@ def _partition_val(val: Any, spec: DTensorSpec) -> Any:
def partitioner(graph: GraphModule) -> GraphModule:
"""
Graph partitioner that partitions the single device graph
to distributed graph
"""
"""Graph partitioner that partitions the single device graph to distributed graph."""
shape_adjustment_ops = {
aten._unsafe_view.default: 1,
aten.expand.default: 1,

@@ -761,10 +742,9 @@ def partition_data_parallel(
parallel_style: DataParallelStyle,
input_batch_dim: int,
) -> GraphModule:
"""
The entry point function to partition the graph to data parallel
graph, it also shard/replicate the model parameters and optimizer
states to DTensors.
"""Partition the graph to into a data parallel graph.
This function also shards/replicates the model parameters and optimizer states to DTensors.
"""
num_params_buffers = len(params_buffers)
flattened_states = pytree.tree_leaves(named_states)
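The `gradients_tagging` helper documented above registers hooks on each parameter so gradients can be identified later during SPMD expansion. The snippet below is a simplified, self-contained version of that hook-plus-contextmanager pattern; the real helper routes gradients through a custom `tag_grad` op rather than recording them in a dict.

```python
# Simplified illustration of the gradient-tagging pattern (not the PyTorch
# implementation): remember which parameter produced which gradient.
from contextlib import contextmanager
from typing import Dict

import torch


@contextmanager
def tag_gradients(params: Dict[str, torch.Tensor]):
    tagged: Dict[str, torch.Tensor] = {}
    handles = []

    def make_hook(name):
        def hook(grad):
            tagged[name] = grad  # remember the gradient for this parameter
            return grad

        return hook

    try:
        for name, p in params.items():
            handles.append(p.register_hook(make_hook(name)))
        yield tagged
    finally:
        for h in handles:
            h.remove()


w = torch.nn.Parameter(torch.randn(3))
with tag_gradients({"w": w}) as tagged:
    (2 * w).sum().backward()
print(tagged["w"])  # gradient of the sum w.r.t. w: a tensor of 2s
```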
torch/distributed/_spmd/distribute.py

@@ -48,9 +48,9 @@ class Schema:
@dataclass
class DSymInt:
"""
DSymInt represents a value retrieved by a SymInt op from a DTensor. DSymInt
helps View and Factory ops to determine the placement and shape of the
"""DSymInt represents a value retrieved by a SymInt op from a DTensor.
DSymInt helps View and Factory ops to determine the placement and shape of the
output tensor, as those operators either do not have an input DTensor or
the input DTensor is insufficient to determine the output tensor's placement.
"""

@@ -90,7 +90,7 @@ class DSymInt:
def _is_partial_dtensor(obj: Any) -> bool:
"""check if object is 1) DTensor and 2) with any placement of _Partial"""
"""Check if object is 1) DTensor and 2) with any placement of _Partial."""
if not isinstance(obj, DTensor):
return False

@@ -475,8 +475,8 @@ def _get_dtensor_dispatch_graph(
def _build_dummy_add_graph(
dt: DTensor, node_to_obj: Dict[fx.Node, Any]
) -> Tuple[fx.GraphModule, Any]:
"""
Creates a graph for a dummy add function from a partial DTensor.
"""Create a graph for a dummy add function from a partial DTensor.
This dummy add is used for triggering all_reduce on a Partial DTensor
during the DTensor expansion of the traced graph.
Also returns the actual DTensor after resharding.

@@ -703,10 +703,12 @@ def _convert_to_distributed(
default_mesh: Optional[DeviceMesh] = None,
_allow_partial: bool = False,
) -> Tuple[fx.GraphModule, Dict[str, Schema]]:
"""
"""Transform a graph module to a distributed graph module.
Returns:
- transformed graph module
- map from output name to DTensorSpec
"""
global logger
logger = get_logger("spmd_exp")
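`_is_partial_dtensor` checks the two conditions spelled out in its docstring: the object is a `DTensor` and at least one placement is partial (pending reduction). A hedged standalone version of that check is below; the `_Partial` import path is an assumption since it has moved between releases.

```python
from typing import Any

from torch.distributed._tensor import DTensor
from torch.distributed._tensor.placement_types import _Partial  # assumed path


def is_partial_dtensor(obj: Any) -> bool:
    if not isinstance(obj, DTensor):
        return False
    # True if any mesh dimension still holds a partial (unreduced) value.
    return any(isinstance(p, _Partial) for p in obj.placements)
```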
torch/distributed/_spmd/experimental_ops.py

@@ -342,10 +342,7 @@ def _prop_native_layer_norm_backward(op_schema: OpSchema) -> OutputSharding:
def _refine_sharding(
op_schema: OpSchema, active_dim: Optional[int]
) -> Sequence[Placement]:
"""
Considers 2 first inputs of op_schema as having same shape,
and returns suggested placement for a pointwise operation.
"""
"""Considers 2 first inputs of op_schema as having same shape, and returns suggested placement for a pointwise operation."""
# consider the operating dimension as a singleton to prevent sharding on it
# however, if active_dim is None, this means the input and output shapes are equal and
# we'll apply exactly the pointwise rule.
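`_refine_sharding` leans on the pointwise sharding rule: operands with the same shape and the same placements produce an output with that placement. A hedged, single-process DTensor illustration of that rule (independent of this file's internals):

```python
# world_size=1 just to make the snippet runnable in one process.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

dist.init_process_group(
    "gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
)
mesh = DeviceMesh("cpu", [0])

a = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])
b = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])
c = a + b                  # pointwise op on identically sharded inputs
print(c.placements)        # sharding is preserved, e.g. (Shard(dim=0),)

dist.destroy_process_group()
```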
torch/distributed/_spmd/graph_optimization.py

@@ -64,10 +64,9 @@ def graph_optimization_pass(
prerequisites: Iterable[Callable],
apply_after: Iterable[Callable],
) -> Callable:
"""
The contract of graph optimization pass. All the passes should be wrapped
with this decorator.
"""Define the contract of a graph optimization pass.
All the passes should be wrapped with this decorator.
`prerequisites` is used to annotate the prerequisite passes of the this pass.
`apply_after` means that this wrapped pass must be applied after the passes
in `apply_after`. The difference between `prerequisites` and `apply_after`

@@ -158,12 +157,11 @@ class CommBlock:
def get_comm_block(comm_node: fx.Node) -> CommBlock:
"""
Given a collective node (e.g., allreduce), find out all the nodes belong to
this communcation.
"""Find out all the nodes belong to this communcation given a collective node (e.g., allreduce).
Args:
comm_node(fx.Node): The target communication/collective node.
Returns:
The CommBlock that encapsulates the related nodes (e.g., wait_node) of
the given comm_node.

@@ -306,10 +304,7 @@ def _scatter_wait_result(
comm_blocks: List[CommBlock],
node_indices: Dict[fx.Node, int],
) -> None:
"""
Scatters the result of the fused communication node to the original users --
splitting the output and reshape each subitem.
"""
"""Scatter the result of the fused communication node to the original users -- splitting the output and reshape each subitem."""
last_wait_node_idx = 0
for node in gm.graph.nodes:
if node == fused_comm_block.comm_node:

@@ -371,9 +366,7 @@ def _fuse_with_cat(
comm_blocks: List[CommBlock],
node_indices: Dict[fx.Node, int],
) -> CommBlock:
"""
Given a list of CommBlock (only allreduce), fuse the CommBlocks using concat.
"""
"""Fuse the CommBlocks using concat given a list of CommBlock (only allreduce)."""
# Find the last input node.
last_input_node = comm_blocks[0].inputs[0]
last_input_index = -1

@@ -474,8 +467,8 @@ def comm_fusion_with_concat(
gm: IterGraphModule,
bucket_size_mb: int,
) -> None:
"""
Run fuse communication with concat.
"""Run fuse communication with concat.
This implementation uses concat to concat the bucketed gradients.
"""
comm_blocks = get_all_comm_blocks(gm, (CommType.ALLREDUCE, "all_reduce"))

@@ -508,9 +501,7 @@ def comm_fusion_with_concat(
apply_after=[],
)
def schedule_comm_wait(gm: IterGraphModule) -> None:
"""
Delay the execution of wait tensors of allreduce until its first user.
"""
"""Delay the execution of wait tensors of allreduce until its first user."""
comm_blocks = get_all_comm_blocks(gm, (CommType.ALLREDUCE, "all_reduce"))
# Find all the end users.

@@ -549,8 +540,8 @@ def schedule_comm_wait(gm: IterGraphModule) -> None:
apply_after=[],
)
def remove_copy_from_optimizer(gm: IterGraphModule) -> None:
"""
Erase the orphant copy_ that generated when tracing optimizer.
"""Erase the orphant copy_ that generated when tracing optimizer.
Two reasons why we could not simply use the DCE of fx.Graph.
1. fx.Graph treats copy_ as a side-effect node and does not erase it.
2. Users may want to preserve some orphan `copy_` that is not from the

@@ -742,9 +733,7 @@ class FusedOptimizerBlock:
def get_fused_optimizer_block(optim_node: fx.Node) -> FusedOptimizerBlock:
"""
Given a fused optimizer node and return the FusedOptimizerBlock.
"""
"""Given a fused optimizer node and return the FusedOptimizerBlock."""
MAX_STEP_DISTANCE = 5
# Find the step (foreach_add)
nodes = collections.deque([optim_node, None])

@@ -783,10 +772,7 @@ def get_fused_optimizer_block(optim_node: fx.Node) -> FusedOptimizerBlock:
def get_all_fused_optimizer_blocks(
gm: IterGraphModule, optim_ops: Union[Tuple[str, ...], str]
) -> List[FusedOptimizerBlock]:
"""
Find all the FusedOptimizerBlock that the optimizer operators are in
`optim_ops`.
"""
"""Find all the FusedOptimizerBlock that the optimizer operators are in `optim_ops`."""
return [
get_fused_optimizer_block(node)
for node in gm.graph.nodes

@@ -799,9 +785,9 @@ def _split_fused_adam(
orig_optim_block: FusedOptimizerBlock,
split_gradients: Set[fx.Node],
) -> Tuple[FusedOptimizerBlock, FusedOptimizerBlock]:
"""
Split the `orig_optim_block` into two FusedOptimizerBlock. The first one
will be the optimizer that optimize `split_gradients`. The second one is
"""Split the `orig_optim_block` into two FusedOptimizerBlock.
The first one will be the optimizer that optimize `split_gradients`. The second one is
used to optimize the remaining gradients.
An assert will be raised if one of the optimizer optimize zero gradients.
"""

@@ -949,7 +935,8 @@ def iter_move_grads_and_optimizers(
target_comm_node: str,
target_dest_node: str,
) -> None:
"""Function to extract a comm block and split out a new optimizer and step for it.
"""Extract a comm block and split out a new optimizer and step for it.
This subgraph is then moved to the forward graph.
"""
for comm_block in get_all_comm_blocks(gm, "all_reduce"):

@@ -982,8 +969,7 @@ def find_all_descendants(
gm: IterGraphModule,
parent_nodes: List[fx.Node],
) -> List[fx.Node]:
"""identifying list of nodes to move during FX graph transformation"""
"""Identify the list of nodes to move during FX graph transformation."""
assert len(parent_nodes) > 0, "No parent nodes are given."
output = get_output(gm.graph)
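Per the `graph_optimization_pass` contract above, every optimization is a callable over an `IterGraphModule` wrapped by the decorator, with `prerequisites` and `apply_after` expressing ordering constraints. A sketch of defining such a pass follows; the pass body and the chosen constraint values are illustrative only.

```python
from torch.distributed._spmd.graph_optimization import (
    comm_fusion_with_concat,
    graph_optimization_pass,
)
from torch.distributed._spmd.iter_graph_module import IterGraphModule


@graph_optimization_pass(
    prerequisites=[comm_fusion_with_concat],  # this pass requires fused comm blocks
    apply_after=[],                           # no extra ordering constraints
)
def my_inspection_pass(gm: IterGraphModule) -> None:
    # Passes mutate the graph in place; here we only walk it.
    for node in gm.graph.nodes:
        pass
```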
torch/distributed/_spmd/graph_utils.py

@@ -41,9 +41,9 @@ def get_node_tensor_metadata(node: fx.Node, is_required: bool = True) -> TensorM
def get_output(graph: fx.Graph) -> fx.Node:
"""
Take a graphmodule and returns the graph output node. We traverse in reverse
to expedite it, with the idea that last node should be output
"""Take a graphmodule and return the graph output node.
We traverse in reverse to expedite it, with the idea that last node should be output
"""
for node in reversed(graph.nodes):
if node.op == OP.OUTPUT:

@@ -54,10 +54,7 @@ def get_output(graph: fx.Graph) -> fx.Node:
def find_node(
graph: fx.Graph, predicate: Callable, reverse_order: bool = False
) -> List[fx.Node]:
"""
Take a predicate and return all the nodes in the `graph` where the predicate
holds.
"""
"""Take a predicate and return all the nodes in the `graph` where the predicate holds."""
nodes = cast(Iterable[fx.Node], graph.nodes)
if reverse_order:
nodes = cast(Iterable[fx.Node], iter(reversed(nodes)))  # type: ignore[call-overload]

@@ -65,8 +62,8 @@ def find_node(
def is_leaf_subgraph(graph: fx.Graph, subgraph: List[fx.Node]) -> bool:
"""
This function ensures nodes in ``subgraph`` satisfy one of the rules:
"""Ensure nodes in ``subgraph`` satisfy one of the following rules.
1. The user of the node is in ``subgraph``.
2. The user of the node is output.
3. There are no users -- the node is a side-effect node.

@@ -85,11 +82,10 @@ def is_leaf_subgraph(graph: fx.Graph, subgraph: List[fx.Node]) -> bool:
def clone_subgraph(
graph: fx.Graph, subgraph: List[fx.Node], target: fx.Node
) -> List[fx.Node]:
"""
Clone the given subgraph and insert it before ``target``.
"""Clone the given subgraph and insert it before ``target``.
This API currently does not support inserting after ``target``.
"""
all_nodes = set(subgraph)
mapping: Dict[fx.Node, fx.Node] = dict()
cloned_subgraph = []

@@ -125,12 +121,11 @@ def clone_subgraph(
def rebuild_graph(gm: fx.GraphModule, remove_dead_code: bool = True) -> None:
"""
Runs the required steps to ensure production-ready graph.
note - per the fx docs, eliminate dead code is not very precise.
"""Run the required steps to ensure production-ready graph.
Note - per the fx docs, elimination of dead code is not very precise.
Hence, the flag to make this step optional.
"""
gm.graph.lint()
if remove_dead_code:
gm.graph.eliminate_dead_code()
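The helpers above are thin conveniences over plain `torch.fx`. The self-contained snippet below shows the equivalent operations directly on a traced module: finding the output node by scanning in reverse, filtering nodes with a predicate, and rebuilding the graph (lint, dead-code elimination, recompile).

```python
import torch
import torch.fx as fx


class Toy(torch.nn.Module):
    def forward(self, x):
        unused = x * 3          # dead code once eliminated below
        return torch.relu(x)


gm = fx.symbolic_trace(Toy())

# get_output equivalent: the last node should be the output node.
output_node = next(n for n in reversed(gm.graph.nodes) if n.op == "output")

# find_node equivalent: all nodes where a predicate holds.
calls = [n for n in gm.graph.nodes if n.op == "call_function"]

# rebuild_graph equivalent.
gm.graph.lint()
gm.graph.eliminate_dead_code()
gm.recompile()
print(output_node, [n.target for n in gm.graph.nodes])
```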
torch/distributed/_spmd/iter_graph_module.py

@@ -22,9 +22,9 @@ logger: logging.Logger = logging.getLogger("IterGraphModule")
class IterGraph(fx.Graph):
"""
``IterGraph`` is used to perform cross-iteration optimization. ``IterGraph``
keeps track of the 3 graphs, self (the original graph), setup graph, and
"""``IterGraph`` is used to perform cross-iteration optimization.
``IterGraph`` keeps track of the 3 graphs, self (the original graph), setup graph, and
cleanup graph. The 3 graphs should be identical copies of a ``fx.Graph``.
IterGraph subclass fx.Graph to override the necessary APIs that will be used

@@ -127,11 +127,10 @@ class IterGraph(fx.Graph):
def _forward_subgraph_inputs(
self, subgraph: List[fx.Node], graph: fx.Graph, erase_node: bool
) -> int:
"""
This function turns the inputs of a subgraph into the extra output
of the entire graph. If ``erase_node`` is True, the subgraph will be
erased from the graph -- essentially forward the inputs of the subgraph
to the output of the graph.
"""Turn the inputs of a subgraph into the extra output of the entire graph.
If ``erase_node`` is True, the subgraph will be erased from the graph -- essentially forward the inputs
of the subgraph to the output of the graph.
"""
output = get_output(graph)
inputs = []

@@ -219,10 +218,10 @@ class IterGraph(fx.Graph):
def _forward_inputs_to_subgraph(
self, subgraph: List[fx.Node], graph: fx.Graph, extra_input: int
) -> None:
"""
This function creates extra input nodes and forward the input nodes to
the ``subgraph``. The external input nodes of ``subgraph`` (nodes that
are not in ``subgraph``) will replaced by the newly created input nodes.
"""Create extra input nodes and forward the input nodes to the ``subgraph``.
The external input nodes of ``subgraph`` (nodes that are not in ``subgraph``) will replaced by the newly
created input nodes.
"""
placeholders = [node for node in graph.nodes if str(node.op) == "placeholder"]
assert placeholders, "No placeholders are found"

@@ -275,8 +274,8 @@ class IterGraph(fx.Graph):
def move_to_next_iter_before(
self, subgraph: List[fx.Node], target_node: fx.Node
) -> None:
"""
Move the ``subgraph`` to the next iteration before ``target_node``.
"""Move the ``subgraph`` to the next iteration before ``target_node``.
The ``subgraph`` is a list of fx.Node and must satisfy the following
restrictions:
1. The order of the nodes in ``subgraph`` must obey the topological

@@ -633,8 +632,8 @@ class IterGraph(fx.Graph):
class IterGraphModule(nn.Module):
"""
``IterGraphModule`` provides the ability to do cross-iteration optimization.
"""``IterGraphModule`` provides the ability to do cross-iteration optimization.
Given a ``fx.GraphModule``, main_gm, ``IterGraphModule`` internally
duplicate it to 3 copies and redirect the ``forward`` request to a different
``fx.GraphModule`` based on the iteration count. This allows users to do

@@ -674,10 +673,9 @@ class IterGraphModule(nn.Module):
self._enable_inductor = enable_inductor
def finalize_setup(self) -> None:
"""
Must be called before the forward() is called. This method setups
the internal states and also get the signal from users that what
is the maximum iteration count.
"""Set up the internal states and also get the signal from users that what is the maximum iteration count.
This method must be called before the forward() is called.
"""
if not self._is_frozen:
self.graph.freeze_cross_iter_movement()
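The `IterGraphModule` docstrings above describe keeping three copies of one traced graph (setup, main, cleanup) and dispatching `forward` by iteration count. The toy below reproduces only that dispatch idea with plain `torch.fx`; it is not the real `IterGraphModule` API, whose constructor and setup calls are not shown in this diff.

```python
import copy

import torch
import torch.fx as fx


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


main_gm = fx.symbolic_trace(Toy())
setup_gm, cleanup_gm = copy.deepcopy(main_gm), copy.deepcopy(main_gm)

num_iters = 4
for i in range(num_iters):
    if i == 0:
        gm = setup_gm            # cross-iteration setup work would live here
    elif i == num_iters - 1:
        gm = cleanup_gm          # and the corresponding cleanup here
    else:
        gm = main_gm
    gm(torch.randn(2, 2))
```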
torch/distributed/fsdp/fully_sharded_data_parallel.py

@@ -112,14 +112,16 @@ FLAT_PARAM = "_flat_param"
class OptimStateKeyType(Enum):
"""Represents the type of key in an optimizer state-dict."""
PARAM_NAME = auto()
PARAM_ID = auto()
class FullyShardedDataParallel(nn.Module, _FSDPState):
"""
A wrapper for sharding module parameters across data parallel workers. This
is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
"""A wrapper for sharding module parameters across data parallel workers.
This is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
FullyShardedDataParallel is commonly shortened to FSDP.
.. _`Xu et al.`: https://arxiv.org/abs/2004.13336

@@ -515,9 +517,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
@property
def module(self) -> nn.Module:
"""
Returns the wrapped module (like :class:`DistributedDataParallel`).
"""
"""Return the wrapped module."""
# FSDP's `.module` must refer to the innermost wrapped module when
# composing with other module wrappers in order for state dict to work
if isinstance(self._fsdp_wrapped_module, ActivationWrapper):

@@ -547,6 +547,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
return super().__getitem__(key)
def check_is_root(self) -> bool:
"""Check if this instance is a root FSDP module."""
return _is_fsdp_root(self, self)
@staticmethod

@@ -554,9 +555,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
module: nn.Module,
root_only: bool = False,
) -> List["FullyShardedDataParallel"]:
"""
Returns all nested FSDP instances, possibly including ``module`` itself
and only including FSDP root modules if ``root_only=True``.
"""Return all nested FSDP instances.
This possibly includes ``module`` itself and only includes FSDP root modules if ``root_only=True``.
Args:
module (torch.nn.Module): Root module, which may or may not be an

@@ -573,9 +574,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
return traversal_utils._get_fsdp_states(module)
def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
as well as self. Typical use includes initializing the parameters of a model
(see also :ref:`nn-init-doc`).
r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.
Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`).
Compared to ``torch.nn.Module.apply``, this version additionally gathers
the full parameters before applying ``fn``. It should not be called from
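A minimal sketch of the wrapping described in the class docstring above, assuming the process group has already been initialized (for example via `torchrun`) and a GPU is available:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = torch.nn.Sequential(
    torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 8)
)
fsdp_model = FSDP(model.cuda())

# `.module` returns the wrapped module; `apply` gathers full parameters before
# running the callback on every submodule.
fsdp_model.apply(lambda m: None)
out = fsdp_model(torch.randn(4, 8, device="cuda"))
out.sum().backward()
```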
@@ -614,8 +615,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
return ret
def _mixed_precision_enabled_for_buffers(self) -> bool:
"""
Returns if the user explicitly enabled buffer mixed precision.
"""Return whether the user explicitly enabled buffer mixed precision.
NOTE: Unlike parameters and gradient reduction, buffer mixed precision
is applied at the FSDP instance level, not the ``FlatParameter`` level,

@@ -624,15 +624,11 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
return self.mixed_precision.buffer_dtype is not None
def _low_precision_hook_enabled(self) -> bool:
"""
Whether a low precision hook is registered or not.
"""
"""Whether a low precision hook is registered or not."""
return self._comm_hook is not None and self._comm_hook in LOW_PRECISION_HOOKS
def _reset_lazy_init(self) -> None:
"""
Reset instance so :func:`_lazy_init` will run on the next forward.
"""
"""Reset instance so :func:`_lazy_init` will run on the next forward."""
self._is_root: Optional[bool] = None
@staticmethod

@@ -642,9 +638,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
state_dict_config: Optional[StateDictConfig] = None,
optim_state_dict_config: Optional[OptimStateDictConfig] = None,
) -> StateDictSettings:
"""
Set the ``state_dict_type`` and the corresponding (optional)
configurations of all the descendant FSDP modules of the target module.
"""Set the ``state_dict_type`` of all the descendant FSDP modules of the target module.
Also takes (optional) configuration for the model's and optimizer's state dict.
The target module does not have to be a FSDP module. If the target
module is a FSDP module, its ``state_dict_type`` will also be changed.

@@ -747,10 +743,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
@staticmethod
def get_state_dict_type(module: nn.Module) -> StateDictSettings:
"""
Get the state_dict_type and the corresponding configurations
for the FSDP modules rooted at ``module``. The target module
does not have to be an FSDP module.
"""Get the state_dict_type and the corresponding configurations for the FSDP modules rooted at ``module``.
The target module does not have to be an FSDP module.
Returns:
A ``StateDictSettings`` containing the state_dict_type and

@@ -797,10 +792,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
state_dict_config: Optional[StateDictConfig] = None,
optim_state_dict_config: Optional[OptimStateDictConfig] = None,
) -> Generator:
"""
A context manager to set the ``state_dict_type`` of all the descendant
FSDP modules of the target module. This context manager has the same
functions as :meth:`set_state_dict_type`. Read the document of
"""Set the ``state_dict_type`` of all the descendant FSDP modules of the target module.
This context manager has the same functions as :meth:`set_state_dict_type`. Read the document of
:meth:`set_state_dict_type` for the detail.
Example::

@@ -836,10 +830,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
)
def forward(self, *args: Any, **kwargs: Any) -> Any:
"""
Runs the forward pass for the wrapped module, inserting FSDP-specific
pre- and post-forward sharding logic.
"""
"""Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic."""
handle = self._handle
with torch.autograd.profiler.record_function(
"FullyShardedDataParallel.forward"

@@ -875,7 +866,8 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
offload_to_cpu: bool = False,
with_grads: bool = False,
) -> Generator:
r"""A context manager to expose full params for FSDP instances.
r"""Expose full params for FSDP instances with this context manager.
Can be useful *after* forward/backward for a model to get
the params for additional processing or checking. It can take a non-FSDP
module and will summon full params for all contained FSDP modules as
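Continuing from the wrapping sketch earlier in this section, `summon_full_params` gathers the unsharded parameters, e.g. for inspection after a step; `writeback` and `offload_to_cpu` are the options shown in the signature above.

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

with FSDP.summon_full_params(fsdp_model, writeback=False, offload_to_cpu=True):
    # Inside the context the clean (unflattened) parameter names are exposed.
    full_norms = {n: p.norm().item() for n, p in fsdp_model.named_parameters()}
print(full_norms)
```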
@@ -941,9 +933,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
@contextlib.contextmanager
def _deregister_orig_params_ctx(self):
"""
This deregisters the original parameters and exposes the
:class:`FlatParameter` s. If a :class:`FlatParameter` is sharded, then
"""Deregister the original parameters and expose the :class:`FlatParameter`.
If a :class:`FlatParameter` is sharded, then
this refreshes the sharded views before exiting. This method should
only be called when using the original parameters.
"""

@@ -961,11 +953,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
_register_orig_params(fsdp_module, fsdp_module)
def _apply(self, *args, **kwargs):
"""
When using the original parameters, this deregisters the original
parameters and exposes the :class:`FlatParameter` s before calling
``_apply()``.
"""
"""Deregister the original parameters and expose the :class:`FlatParameter` s before calling ``_apply()``."""
# When using the original parameters: Since (1) the `FlatParameter`s
# own the storage and (2) `_apply()` is the subroutine underlying the
# most common storage-changing ops like `to()` and `cuda()`, we

@@ -985,9 +973,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
*args,
**kwargs,
) -> Iterator[Tuple[str, torch.Tensor]]:
"""
Overrides :meth:`named_buffers()` to intercept buffer names and
remove all occurrences of the FSDP-specific flattened buffer prefix
"""Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself.
Intercepts buffer names and removes all occurrences of the FSDP-specific flattened buffer prefix
when inside the :meth:`summon_full_params` context manager.
"""
should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS

@@ -1003,9 +991,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
*args,
**kwargs,
) -> Iterator[Tuple[str, torch.nn.Parameter]]:
"""
Overrides :meth:`named_parameters()` to intercept parameter names and
remove all occurrences of the FSDP-specific flattened parameter prefix
"""Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself.
Intercepts parameter names and removes all occurrences of the FSDP-specific flattened parameter prefix
when inside the :meth:`summon_full_params` context manager.
"""
should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS

@@ -1038,9 +1026,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
@contextmanager
def no_sync(self) -> Generator:
"""
A context manager to disable gradient synchronizations across FSDP
instances. Within this context, gradients will be accumulated in module
"""Disable gradient synchronizations across FSDP instances.
Within this context, gradients will be accumulated in module
variables, which will later be synchronized in the first
forward-backward pass after exiting the context. This should only be
used on the root FSDP instance and will recursively apply to all

@@ -1079,9 +1067,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
def clip_grad_norm_(
self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0
) -> torch.Tensor:
"""
Clips the gradient norm of all parameters. The norm is computed over
all parameters' gradients as viewed as a single vector, and the
"""Clip the gradient norm of all parameters.
The norm is computed over all parameters' gradients as viewed as a single vector, and the
gradients are modified in-place.
Args:
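A sketch combining `no_sync` and `clip_grad_norm_` as documented above: accumulate gradients locally for several micro-batches, synchronize on the last backward, clip, then step. `fsdp_model` and `optim` are assumed to be a root FSDP instance and its optimizer on an initialized process group.

```python
import torch

micro_batches = [torch.randn(4, 8, device="cuda") for _ in range(4)]

optim.zero_grad()
for i, batch in enumerate(micro_batches):
    if i < len(micro_batches) - 1:
        with fsdp_model.no_sync():          # skip gradient synchronization
            fsdp_model(batch).sum().backward()
    else:
        fsdp_model(batch).sum().backward()  # this backward synchronizes
fsdp_model.clip_grad_norm_(max_norm=1.0)    # norm over one flattened vector
optim.step()
```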
@@ -1245,8 +1233,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
group: Optional[dist.ProcessGroup] = None,
cpu_offload: bool = True,
) -> Dict[str, Any]:
"""
The internal API that is used by all the optim_state_dict implementations.
"""Transform the state-dict of an optimizer corresponding to a sharded model.
This is the internal API that is used by all the optim_state_dict implementations.
Given model, optim, the original optim_state_dict, this API removes the
FSDP internal information and internal sharding from the optim_state_dict.
"""

@@ -1298,7 +1287,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""
The internal API that is used by all the load optim_state_dict implementations.
Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.
This is the internal API that is used by all the load optim_state_dict implementations.
Given model, optim, and the saved optim_state_dict, this API adds the FSDP
internal information and internal sharding to the optim_state_dict.
"""

@@ -1352,7 +1343,8 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
rank0_only: bool = True,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""
"""Return the full optimizer state-dict.
Consolidates the full optimizer state on rank 0 and returns it
as a :class:`dict` following the convention of
:meth:`torch.optim.Optimizer.state_dict`, i.e. with keys ``"state"``

@@ -1417,7 +1409,8 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
optim: torch.optim.Optimizer,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""
"""Return the optimizer state-dict in its sharded form.
The API is similar to :meth:`full_optim_state_dict` but this API chunks
all non-zero-dimension states to :class:`ShardedTensor` to save memory.
This API should only be used when the model ``state_dict`` is derived

@@ -1453,12 +1446,11 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
] = None,
optim: Optional[torch.optim.Optimizer] = None,
) -> Dict[str, Any]:
"""
Shards the full optimizer state dict ``full_optim_state_dict`` by
remapping the state to flattened parameters instead of unflattened
parameters and restricting to only this rank's part of the optimizer
state. The first argument should be the return value of
:meth:`full_optim_state_dict`.
"""Shard a full optimizer state-dict.
Remaps the state in ``full_optim_state_dict`` to flattened parameters instead of unflattened
parameters and restricts to only this rank's part of the optimizer state.
The first argument should be the return value of :meth:`full_optim_state_dict`.
Example::

@@ -1525,7 +1517,8 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
model: torch.nn.Module,
optim: torch.optim.Optimizer,
) -> Dict[str, Any]:
"""
"""Flatten a sharded optimizer state-dict.
The API is similar to :meth:`shard_full_optim_state_dict`. The only
difference is that the input ``sharded_optim_state_dict`` should be
returned from :meth:`sharded_optim_state_dict`. Therefore, there will
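A sketch of the save/load round trip for the full-optimizer-state APIs documented above; `fsdp_model` and `optim` are assumed as before, and the file name is only an example.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Save: consolidate the full optimizer state on rank 0 and write it out there.
full_osd = FSDP.full_optim_state_dict(fsdp_model, optim)
if dist.get_rank() == 0:
    torch.save(full_osd, "optim_state.pt")

# Load: rank 0 reads the full dict; scatter_full_optim_state_dict distributes
# and reshards it so each rank can load its own part.
full_osd = torch.load("optim_state.pt") if dist.get_rank() == 0 else None
sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, fsdp_model)
optim.load_state_dict(sharded_osd)
```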
@@ -1568,10 +1561,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
optim: Optional[torch.optim.Optimizer] = None,
group: Optional[Any] = None,
) -> Dict[str, Any]:
"""
Scatters the full optimizer state dict from rank 0 to all other ranks,
returning the sharded optimizer state dict on each rank. The return
value is the same as :meth:`shard_full_optim_state_dict`, and on rank
"""Scatter the full optimizer state dict from rank 0 to all other ranks.
Returns the sharded optimizer state dict on each rank.
The return value is the same as :meth:`shard_full_optim_state_dict`, and on rank
0, the first argument should be the return value of
:meth:`full_optim_state_dict`.

@@ -1650,10 +1643,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
] = None,
optim: Optional[torch.optim.Optimizer] = None,
) -> Dict[str, Any]:
"""
Re-keys the optimizer state dict ``optim_state_dict`` to use the key
type ``optim_state_key_type``. This can be used to achieve
compatibility between optimizer state dicts from models with FSDP
"""Re-keys the optimizer state dict ``optim_state_dict`` to use the key type ``optim_state_key_type``.
This can be used to achieve compatibility between optimizer state dicts from models with FSDP
instances and ones without.
To re-key an FSDP full optimizer state dict (i.e. from

@@ -1771,9 +1763,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""
Transforms the state_dict of ``optim`` for the ``model`` that is sharded
by FSDP to one of the three types: 1) full optimizer state_dict, 2)
sharded optimizer state_dict, 3) local optimizer state_dict.
Transform the state-dict of an optimizer corresponding to a sharded model.
The given state-dict can be transformed to one of three types:
1) full optimizer state_dict, 2) sharded optimizer state_dict, 3) local optimizer state_dict.
For full optimizer state_dict, all states are unflattened and not sharded.
Rank0 only and CPU only can be specified via :meth:`state_dict_type` to

@@ -1867,8 +1860,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""
Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.
Given a ``optim_state_dict`` that is transformed through
:meth:`optim_state_dict`, converts it to the flattened optimizer
:meth:`optim_state_dict`, it gets converted to the flattened optimizer
state_dict that can be loaded to ``optim`` which is the optimizer for
``model``. ``model`` must be sharded by FullyShardedDataParallel.
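A sketch of the newer `optim_state_dict` / `optim_state_dict_to_load` pair documented above, driven by `set_state_dict_type`. The argument order of `optim_state_dict_to_load` has changed across releases, so treat it as an assumption and check the installed version.

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

FSDP.set_state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT)

# Save: transform optim's state-dict according to the configured type.
osd = FSDP.optim_state_dict(fsdp_model, optim)

# Load: convert a saved dict into something optim.load_state_dict accepts.
flattened_osd = FSDP.optim_state_dict_to_load(fsdp_model, optim, osd)
optim.load_state_dict(flattened_osd)
```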
@@ -1946,10 +1941,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
return result
def register_comm_hook(self, state: object, hook: callable):
"""
Registers a communication hook which is an enhancement that provides a
flexible hook to users where they can specify how FSDP aggregates gradients
across multiple workers.
"""Register a communication hook.
This is an enhancement that provides a flexible hook to users where they can specify how FSDP aggregates
gradients across multiple workers.
This hook can be used to implement several algorithms like
`GossipGrad <https://arxiv.org/abs/1803.05880>`_ and gradient compression
which involve different communication strategies for

@@ -2008,9 +2003,9 @@ def _get_grad_norm(
norm_type: float,
) -> torch.Tensor:
"""
Returns the gradient norm of parameters ``param`` s, where the gradients
are viewed as a single vector. The returned norm is in FP32 even if
parameters/gradients are in a low precision. This is because the downstream
Return the gradient norm of parameters ``param`` s, where the gradients are viewed as a single vector.
The returned norm is in FP32 even if parameters/gradients are in a low precision. This is because the downstream
use of this return value is a reduction across ranks.
"""
params_with_grad = [param for param in params if param.grad is not None]

@@ -2041,8 +2036,9 @@ def _get_param_to_fqn(
model: torch.nn.Module,
) -> Dict[torch.nn.Parameter, str]:
"""
Constructs a mapping from parameters to their parameter names. ``model``
should not contain any :class:`FullyShardedDataParallel` instances, which
Construct a mapping from parameters to their parameter names.
The ``model`` should not contain any :class:`FullyShardedDataParallel` instances, which
means that none of the parameters should be ``FlatParameter`` s. As a
result, compared to :meth:`_get_param_to_fqns`, the mapped
values may be flattened from singleton :class:`list` s to the contained

@@ -2071,6 +2067,6 @@ def _get_param_to_fqn(
def _get_fqn_to_param(
model: torch.nn.Module,
) -> Dict[str, torch.nn.Parameter]:
"""Constructs the inverse mapping of :meth:`_get_param_to_fqn`."""
"""Construct the inverse mapping of :meth:`_get_param_to_fqn`."""
param_to_param_name = _get_param_to_fqn(model)
return dict(zip(param_to_param_name.values(), param_to_param_name.keys()))