mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Fix pydocstyle errors in fully_sharded_data_parallel.py, api.py, graph_utils.py, distribute.py, iter_graph_module.py, comm_tensor.py, experimental_ops.py, batch_dim_utils.py, data_parallel.py, graph_optimization.py (#113216)
Fixes #113191

pydocstyle error counts (`pydocstyle <file> --count`), on master vs. after this PR:

- torch/distributed/fsdp/fully_sharded_data_parallel.py: 80 -> 3
- torch/distributed/_spmd/comm_tensor.py: 5 -> 3
- torch/distributed/_spmd/experimental_ops.py: 3 -> 1
- torch/distributed/_spmd/iter_graph_module.py: 39 -> 27
- torch/distributed/_spmd/graph_utils.py: 16 -> 4
- torch/distributed/_spmd/distribute.py: 19 -> 10
- torch/distributed/_spmd/api.py: 10 -> 3
- torch/distributed/_spmd/batch_dim_utils.py: 14 -> 3
- torch/distributed/_spmd/data_parallel.py: 34 -> 2
- torch/distributed/_spmd/graph_optimization.py: 35 -> 13

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113216
Approved by: https://github.com/ezyang
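The bulk of the diff below applies one pattern: move the docstring summary onto the opening line, write it in the imperative, and separate it from the body with a blank line. A minimal before/after sketch of that pattern (hypothetical function, assuming pydocstyle's default checks such as D205/D212; not taken from the diff itself):

```python
# Before: pydocstyle would flag this docstring (summary not on the first line,
# no blank line between summary and description).
def scale(x):
    """
    this function returns twice the input, which callers use
    to scale gradients.
    """
    return 2 * x


# After: summary on the opening line, imperative mood, blank line before the body.
def scale_fixed(x):
    """Return twice the input.

    Callers use the result to scale gradients.
    """
    return 2 * x
```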
This commit is contained in:
parent
773b1cbe4f
commit
866457e746

torch/distributed/_spmd/api.py

@@ -32,8 +32,8 @@ from torch.nn.utils._named_member_accessor import NamedMemberAccessor


 class Override(ABC):
-    r"""
-    Override the tracing and transformation behavior of :meth:`~torch.distributed._spmd.compile`.
+    r"""Override the tracing and transformation behavior of :meth:`~torch.distributed._spmd.compile`.

     This is useful when any part of the model is not traceable or if you prefer
     to not trace it due to any reason. More specifically, users can implement
     :meth:`torch.distributed._spmd.Override.replacement` to replace an original
@@ -47,10 +47,10 @@ class Override(ABC):

     @abstractmethod
     def replacement(self, fqn: str, orig_submodule: torch.nn.Module) -> torch.nn.Module:
-        r"""
-        Implement this method to return a new :class:`nn.Module` instance to
-        replace the ``orig_submodule`` argument in the model. This helps if
-        ``orig_submodule`` is not traceable or should not be traced.
+        r"""Implement this method to return a new :class:`nn.Module` instance to replace the ``orig_submodule``
+        argument in the model.
+
+        This helps if ``orig_submodule`` is not traceable or should not be traced.

         Args:
             fqn (str): fully quantified name of the submodule.
@@ -58,6 +58,7 @@ class Override(ABC):

         Returns:
             A new :class:`nn.Module` instance to replace the original one.
+
         """
         pass

@@ -83,6 +84,7 @@ class Override(ABC):

         Returns:
             The :class:`fx.Graph` after transformation.
+
         """
         pass

@@ -98,8 +100,7 @@ class _PyTreeCodeGenOutputsOnly(_PyTreeCodeGen):


 def _to_caller_flattened_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
-    """Move the responsibility of flattening the input arguments from the
-    graph module to the caller.
+    """Move the responsibility of flattening the input arguments from the graph module to the caller.

     Example:

@@ -108,6 +109,7 @@ def _to_caller_flattened_graph_module(gm: torch.fx.GraphModule) -> torch.fx.Grap
         gm = gm(to_caller_flattened_graph_module)

         output = gm(*pytree.flatten(my_struct)[0])
+
     """
     # pyre-ignore[16]
     gm._graph._codegen = _PyTreeCodeGenOutputsOnly(
@@ -500,9 +502,9 @@ def compile(
     gm_transformation: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
     parallel_mode: Optional[ParallelMode] = None,
 ):
-    r"""
-    Compile and optimize a callable, which can be a train step within a training
-    loop. This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer`
+    r"""Compile and optimize a callable, which can be a train step within a training loop.
+
+    This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer`
     instances from the input arguments and trace operations applied to their
     parameters and states.

@@ -519,6 +521,7 @@ def compile(
            that specifies how to parallelize the callable. Each ParallelMode
            would have its own strategy to partition the model and the captured
            graph (Default: ``None``)
+
     """

     def inner(func: Callable):

torch/distributed/_spmd/batch_dim_utils.py

@@ -20,9 +20,9 @@ aten = torch.ops.aten


 class BatchDimAnalyzer:
-    """
-    This class is used to analyze the batch dimension of each tensor/node in the
-    graph. We need to know the batch dimension of each tensor/node so that we know
+    """This class is used to analyze the batch dimension of each tensor/node in the graph.
+
+    We need to know the batch dimension of each tensor/node so that we know
     exactly the sharding layout of intermediate tensors.

     We possibly should evaluate using symbolic shapes to track the batch dimension.
@@ -54,9 +54,7 @@ class BatchDimAnalyzer:
         }

     def init_batch_dim_size(self, batch_dim_size: int) -> None:
-        """
-        initialize batch dim size base on the first input batch size
-        """
+        """Initialize batch dim size base on the first input batch size."""
         if self.batch_dim_size != -1 and self.batch_dim_size != batch_dim_size:
             raise RuntimeError(
                 f"batch dim size is already initialized! "
@@ -74,9 +72,7 @@ class BatchDimAnalyzer:
         return self.batch_dim_map[node]

     def compute_batch_dim(self, node: fx.Node, full_reduction=False) -> int:
-        """
-        compute the batch dimension for the `node`
-        """
+        """Compute the batch dimension for the `node`."""
         assert self.batch_dim_size != -1, "batch dim size is not initialized!"

         if node in self.batch_dim_map:
@@ -168,10 +164,7 @@ class BatchDimAnalyzer:
         return -2

     def compute_act_spec(self, node: fx.Node, mesh: DeviceMesh) -> DTensorSpec:
-        """
-        This function first compute the batch dimension for the current node,
-        then generate the sharding spec that shards on the batch dimension.
-        """
+        """Compute the batch dimension for the current node, then generate the sharding spec that shards on the batch dimension."""
         node_batch_dim = self.compute_batch_dim(node)
         if node_batch_dim == -1:
             # indicate this activation is replicated

torch/distributed/_spmd/comm_tensor.py

@@ -56,9 +56,9 @@ def _get_tracer() -> Optional[torch.fx.Tracer]:

 class CommTensor(torch.Tensor):
     r"""
-    A Tensor subclass to wrap input tensors for collective communications. This
-    Tensor subclass works for both eager and tracing mode.
-
+    A Tensor subclass to wrap input tensors for collective communications.
+
+    This Tensor subclass works for both eager and tracing mode.
     In eager mode, it will record whether the inplace collective communication
     has been launched using this Tensor and remember the corresponding work
     handle. If yes, it will explicitly call wait() in the ``__torch_dispatch__``

torch/distributed/_spmd/data_parallel.py

@@ -41,7 +41,8 @@ _spmd_lib_impl.impl("tag_grad", lambda x: x, "CompositeExplicitAutograd")


 class DataParallelStyle(Enum):
-    """
+    """This enum represents the style of the data-parallel operation.
+
     We have three types of Data Parallel style:
     1. DEFAULT: the default data parallel style, which is to represent a mixed
         replicate and fully shard behavior. For each parameter that is able
@@ -64,8 +65,8 @@ class DataParallelStyle(Enum):


 class NodeType(Enum):
-    """
-    NodeType is a enum that records the type of the tensors in the graph.
+    """NodeType is an enum that records the type of the tensors in the graph.
+
     This is used to determine the data parallel strategy.
     """

@@ -77,9 +78,8 @@ class NodeType(Enum):


 class DataParallelStrategy(OpStrategy):
-    """
-    DataParallelStrategy is a special case of OpStrategy that only records
-    the "data parallel style" placement strategy for each fx Node.
+    """DataParallelStrategy is a special case of OpStrategy that only records the "data parallel style" placement
+    strategy for each fx Node.

     It takes a list of PlacementStrategy, where each PlacementStrategy describes
     one way to distribute the tensor and computation. In the DataParallel case,
@@ -113,13 +113,10 @@ class DataParallelStrategy(OpStrategy):

 @contextmanager
 def gradients_tagging(params: Dict[str, torch.Tensor]):
-    """
-    This is a helper function that tags the gradient of the parameters
-    with a special tag, so that we can identify them during SPMD expansion.
+    """Tag the gradient of the parameters with a special tag, so that we can identify them during SPMD expansion.

     It's safe to trace those hooks and we would remove those nodes later.
     """
-
     tagging_hooks = []
     try:
         for p in params.values():
@@ -135,9 +132,7 @@ def gradients_tagging(params: Dict[str, torch.Tensor]):
 def _gen_shard_strategy(
     mesh: DeviceMesh, shard_dim: int, input_specs: Optional[List[DTensorSpec]] = None
 ) -> PlacementStrategy:
-    """
-    util function to generate a shard strategy on shard_dim
-    """
+    """Util function to generate a shard strategy on shard_dim."""
     return PlacementStrategy(
         output_spec=DTensorSpec(mesh=mesh, placements=(Shard(shard_dim),)),
         input_specs=input_specs,
@@ -147,9 +142,7 @@ def _gen_shard_strategy(
 def _gen_replicate_strategy(
     mesh: DeviceMesh, input_specs: Optional[List[DTensorSpec]] = None
 ) -> PlacementStrategy:
-    """
-    util function to generate a replicate strategy
-    """
+    """Util function to generate a replicate strategy."""
     return PlacementStrategy(
         output_spec=DTensorSpec(mesh=mesh, placements=(Replicate(),)),
         input_specs=input_specs,
@@ -157,9 +150,7 @@ def _gen_replicate_strategy(


 def _gen_partial_strategy(mesh: DeviceMesh) -> PlacementStrategy:
-    """
-    util function to generate a partial strategy
-    """
+    """Util function to generate a partial strategy."""
     # NOTE: we use AVG by default, avg reduction is needed depending on
     # the loss function, for most loss function it should do
     # gradient averaging. There might be certain cases it should
@@ -180,10 +171,7 @@ def build_data_parallel_strategies(
     mesh: DeviceMesh,
     batch_dim: int = 0,
 ) -> Dict[fx.Node, StrategyType]:
-    """
-    This function loop through the train step graph and build the
-    data parallel strategy for each fx Node
-    """
+    """Loop through the train step graph and build the data parallel strategy for each fx Node."""
     activation_idx = num_params + num_states
     non_compute_ops = [
         aten.clone.default,
@@ -518,9 +506,7 @@ def mark_data_parallel_shardings(
     dp_strategy_map: Dict[fx.Node, StrategyType],
     parallel_mode: DataParallelStyle = DataParallelStyle.FULLY_SHARD,
 ) -> None:
-    """
-    This function marks the sharding for the nodes in the train_step_graph
-    """
+    """Mark the sharding for the nodes in the train_step_graph."""
     activation_idx = num_parameters + num_states
     placeholder_idx = 0
     for node in train_step_graph.graph.nodes:
@@ -601,9 +587,7 @@ def mark_data_parallel_shardings(


 def _partition_val(val: Any, spec: DTensorSpec) -> Any:
-    """
-    util function to convert a full tensor val to its local component
-    """
+    """Util function to convert a full tensor val to its local component."""
     if isinstance(val, torch.Tensor):
         local_shard = val
         if val.ndim == 0:
@@ -629,10 +613,7 @@ def _partition_val(val: Any, spec: DTensorSpec) -> Any:


 def partitioner(graph: GraphModule) -> GraphModule:
-    """
-    Graph partitioner that partitions the single device graph
-    to distributed graph
-    """
+    """Graph partitioner that partitions the single device graph to distributed graph."""
     shape_adjustment_ops = {
         aten._unsafe_view.default: 1,
         aten.expand.default: 1,
@@ -761,10 +742,9 @@ def partition_data_parallel(
     parallel_style: DataParallelStyle,
     input_batch_dim: int,
 ) -> GraphModule:
-    """
-    The entry point function to partition the graph to data parallel
-    graph, it also shard/replicate the model parameters and optimizer
-    states to DTensors.
+    """Partition the graph to into a data parallel graph.
+
+    This function also shards/replicates the model parameters and optimizer states to DTensors.
     """
     num_params_buffers = len(params_buffers)
     flattened_states = pytree.tree_leaves(named_states)

torch/distributed/_spmd/distribute.py

@@ -48,9 +48,9 @@ class Schema:

 @dataclass
 class DSymInt:
-    """
-    DSymInt represents a value retrieved by a SymInt op from a DTensor. DSymInt
-    helps View and Factory ops to determine the placement and shape of the
+    """DSymInt represents a value retrieved by a SymInt op from a DTensor.
+
+    DSymInt helps View and Factory ops to determine the placement and shape of the
     output tensor, as those operators either do not have an input DTensor or
     the input DTensor is insufficient to determine the output tensor's placement.
     """
@@ -90,7 +90,7 @@ class DSymInt:


 def _is_partial_dtensor(obj: Any) -> bool:
-    """check if object is 1) DTensor and 2) with any placement of _Partial"""
+    """Check if object is 1) DTensor and 2) with any placement of _Partial."""
     if not isinstance(obj, DTensor):
         return False

@@ -475,8 +475,8 @@ def _get_dtensor_dispatch_graph(
 def _build_dummy_add_graph(
     dt: DTensor, node_to_obj: Dict[fx.Node, Any]
 ) -> Tuple[fx.GraphModule, Any]:
-    """
-    Creates a graph for a dummy add function from a partial DTensor.
+    """Create a graph for a dummy add function from a partial DTensor.
+
     This dummy add is used for triggering all_reduce on a Partial DTensor
     during the DTensor expansion of the traced graph.
     Also returns the actual DTensor after resharding.
@@ -703,10 +703,12 @@ def _convert_to_distributed(
     default_mesh: Optional[DeviceMesh] = None,
     _allow_partial: bool = False,
 ) -> Tuple[fx.GraphModule, Dict[str, Schema]]:
-    """
+    """Transform a graph module to a distributed graph module.
+
     Returns:
         - transformed graph module
        - map from output name to DTensorSpec
+
     """
     global logger
     logger = get_logger("spmd_exp")

torch/distributed/_spmd/experimental_ops.py

@@ -342,10 +342,7 @@ def _prop_native_layer_norm_backward(op_schema: OpSchema) -> OutputSharding:
 def _refine_sharding(
     op_schema: OpSchema, active_dim: Optional[int]
 ) -> Sequence[Placement]:
-    """
-    Considers 2 first inputs of op_schema as having same shape,
-    and returns suggested placement for a pointwise operation.
-    """
+    """Considers 2 first inputs of op_schema as having same shape, and returns suggested placement for a pointwise operation."""
     # consider the operating dimension as a singleton to prevent sharding on it
     # however, if active_dim is None, this means the input and output shapes are equal and
     # we'll apply exactly the pointwise rule.

torch/distributed/_spmd/graph_optimization.py

@@ -64,10 +64,9 @@ def graph_optimization_pass(
     prerequisites: Iterable[Callable],
     apply_after: Iterable[Callable],
 ) -> Callable:
-    """
-    The contract of graph optimization pass. All the passes should be wrapped
-    with this decorator.
-
+    """Define the contract of a graph optimization pass.
+
+    All the passes should be wrapped with this decorator.
     `prerequisites` is used to annotate the prerequisite passes of the this pass.
     `apply_after` means that this wrapped pass must be applied after the passes
     in `apply_after`. The difference between `prerequisites` and `apply_after`
@@ -158,12 +157,11 @@ class CommBlock:


 def get_comm_block(comm_node: fx.Node) -> CommBlock:
-    """
-    Given a collective node (e.g., allreduce), find out all the nodes belong to
-    this communcation.
+    """Find out all the nodes belong to this communcation given a collective node (e.g., allreduce).

     Args:
         comm_node(fx.Node): The target communication/collective node.

     Returns:
         The CommBlock that encapsulates the related nodes (e.g., wait_node) of
         the given comm_node.
@@ -306,10 +304,7 @@ def _scatter_wait_result(
     comm_blocks: List[CommBlock],
     node_indices: Dict[fx.Node, int],
 ) -> None:
-    """
-    Scatters the result of the fused communication node to the original users --
-    splitting the output and reshape each subitem.
-    """
+    """Scatter the result of the fused communication node to the original users -- splitting the output and reshape each subitem."""
     last_wait_node_idx = 0
     for node in gm.graph.nodes:
         if node == fused_comm_block.comm_node:
@@ -371,9 +366,7 @@ def _fuse_with_cat(
     comm_blocks: List[CommBlock],
     node_indices: Dict[fx.Node, int],
 ) -> CommBlock:
-    """
-    Given a list of CommBlock (only allreduce), fuse the CommBlocks using concat.
-    """
+    """Fuse the CommBlocks using concat given a list of CommBlock (only allreduce)."""
     # Find the last input node.
     last_input_node = comm_blocks[0].inputs[0]
     last_input_index = -1
@@ -474,8 +467,8 @@ def comm_fusion_with_concat(
     gm: IterGraphModule,
     bucket_size_mb: int,
 ) -> None:
-    """
-    Run fuse communication with concat.
+    """Run fuse communication with concat.
+
     This implementation uses concat to concat the bucketed gradients.
     """
     comm_blocks = get_all_comm_blocks(gm, (CommType.ALLREDUCE, "all_reduce"))
@@ -508,9 +501,7 @@ def comm_fusion_with_concat(
     apply_after=[],
 )
 def schedule_comm_wait(gm: IterGraphModule) -> None:
-    """
-    Delay the execution of wait tensors of allreduce until its first user.
-    """
+    """Delay the execution of wait tensors of allreduce until its first user."""
     comm_blocks = get_all_comm_blocks(gm, (CommType.ALLREDUCE, "all_reduce"))

     # Find all the end users.
@@ -549,8 +540,8 @@ def schedule_comm_wait(gm: IterGraphModule) -> None:
     apply_after=[],
 )
 def remove_copy_from_optimizer(gm: IterGraphModule) -> None:
-    """
-    Erase the orphant copy_ that generated when tracing optimizer.
+    """Erase the orphant copy_ that generated when tracing optimizer.
+
     Two reasons why we could not simply use the DCE of fx.Graph.
     1. fx.Graph treats copy_ as a side-effect node and does not erase it.
     2. Users may want to preserve some orphan `copy_` that is not from the
@@ -742,9 +733,7 @@ class FusedOptimizerBlock:


 def get_fused_optimizer_block(optim_node: fx.Node) -> FusedOptimizerBlock:
-    """
-    Given a fused optimizer node and return the FusedOptimizerBlock.
-    """
+    """Given a fused optimizer node and return the FusedOptimizerBlock."""
     MAX_STEP_DISTANCE = 5
     # Find the step (foreach_add)
     nodes = collections.deque([optim_node, None])
@@ -783,10 +772,7 @@ def get_fused_optimizer_block(optim_node: fx.Node) -> FusedOptimizerBlock:
 def get_all_fused_optimizer_blocks(
     gm: IterGraphModule, optim_ops: Union[Tuple[str, ...], str]
 ) -> List[FusedOptimizerBlock]:
-    """
-    Find all the FusedOptimizerBlock that the optimizer operators are in
-    `optim_ops`.
-    """
+    """Find all the FusedOptimizerBlock that the optimizer operators are in `optim_ops`."""
     return [
         get_fused_optimizer_block(node)
         for node in gm.graph.nodes
@@ -799,9 +785,9 @@ def _split_fused_adam(
     orig_optim_block: FusedOptimizerBlock,
     split_gradients: Set[fx.Node],
 ) -> Tuple[FusedOptimizerBlock, FusedOptimizerBlock]:
-    """
-    Split the `orig_optim_block` into two FusedOptimizerBlock. The first one
-    will be the optimizer that optimize `split_gradients`. The second one is
+    """Split the `orig_optim_block` into two FusedOptimizerBlock.
+
+    The first one will be the optimizer that optimize `split_gradients`. The second one is
     used to optimize the remaining gradients.
     An assert will be raised if one of the optimizer optimize zero gradients.
     """
@@ -949,7 +935,8 @@ def iter_move_grads_and_optimizers(
     target_comm_node: str,
     target_dest_node: str,
 ) -> None:
-    """Function to extract a comm block and split out a new optimizer and step for it.
+    """Extract a comm block and split out a new optimizer and step for it.
+
     This subgraph is then moved to the forward graph.
     """
     for comm_block in get_all_comm_blocks(gm, "all_reduce"):
@@ -982,8 +969,7 @@ def find_all_descendants(
     gm: IterGraphModule,
     parent_nodes: List[fx.Node],
 ) -> List[fx.Node]:
-    """identifying list of nodes to move during FX graph transformation"""
-
+    """Identify the list of nodes to move during FX graph transformation."""
     assert len(parent_nodes) > 0, "No parent nodes are given."

     output = get_output(gm.graph)

torch/distributed/_spmd/graph_utils.py

@@ -41,9 +41,9 @@ def get_node_tensor_metadata(node: fx.Node, is_required: bool = True) -> TensorM


 def get_output(graph: fx.Graph) -> fx.Node:
-    """
-    Take a graphmodule and returns the graph output node. We traverse in reverse
-    to expedite it, with the idea that last node should be output
+    """Take a graphmodule and return the graph output node.
+
+    We traverse in reverse to expedite it, with the idea that last node should be output
     """
     for node in reversed(graph.nodes):
         if node.op == OP.OUTPUT:
@@ -54,10 +54,7 @@ def get_output(graph: fx.Graph) -> fx.Node:
 def find_node(
     graph: fx.Graph, predicate: Callable, reverse_order: bool = False
 ) -> List[fx.Node]:
-    """
-    Take a predicate and return all the nodes in the `graph` where the predicate
-    holds.
-    """
+    """Take a predicate and return all the nodes in the `graph` where the predicate holds."""
     nodes = cast(Iterable[fx.Node], graph.nodes)
     if reverse_order:
         nodes = cast(Iterable[fx.Node], iter(reversed(nodes)))  # type: ignore[call-overload]
@@ -65,8 +62,8 @@ def find_node(


 def is_leaf_subgraph(graph: fx.Graph, subgraph: List[fx.Node]) -> bool:
-    """
-    This function ensures nodes in ``subgraph`` satisfy one of the rules:
+    """Ensure nodes in ``subgraph`` satisfy one of the following rules.
+
     1. The user of the node is in ``subgraph``.
     2. The user of the node is output.
     3. There are no users -- the node is a side-effect node.
@@ -85,11 +82,10 @@ def is_leaf_subgraph(graph: fx.Graph, subgraph: List[fx.Node]) -> bool:
 def clone_subgraph(
     graph: fx.Graph, subgraph: List[fx.Node], target: fx.Node
 ) -> List[fx.Node]:
-    """
-    Clone the given subgraph and insert it before ``target``.
+    """Clone the given subgraph and insert it before ``target``.
+
     This API currently does not support inserting after ``target``.
     """
-
     all_nodes = set(subgraph)
     mapping: Dict[fx.Node, fx.Node] = dict()
     cloned_subgraph = []
@@ -125,12 +121,11 @@ def clone_subgraph(


 def rebuild_graph(gm: fx.GraphModule, remove_dead_code: bool = True) -> None:
-    """
-    Runs the required steps to ensure production-ready graph.
-    note - per the fx docs, eliminate dead code is not very precise.
+    """Run the required steps to ensure production-ready graph.
+
+    Note - per the fx docs, elimination of dead code is not very precise.
     Hence, the flag to make this step optional.
     """
-
     gm.graph.lint()
     if remove_dead_code:
         gm.graph.eliminate_dead_code()

torch/distributed/_spmd/iter_graph_module.py

@@ -22,9 +22,9 @@ logger: logging.Logger = logging.getLogger("IterGraphModule")


 class IterGraph(fx.Graph):
-    """
-    ``IterGraph`` is used to perform cross-iteration optimization. ``IterGraph``
-    keeps track of the 3 graphs, self (the original graph), setup graph, and
+    """``IterGraph`` is used to perform cross-iteration optimization.
+
+    ``IterGraph`` keeps track of the 3 graphs, self (the original graph), setup graph, and
     cleanup graph. The 3 graphs should be identical copies of a ``fx.Graph``.

     IterGraph subclass fx.Graph to override the necessary APIs that will be used
@@ -127,11 +127,10 @@ class IterGraph(fx.Graph):
     def _forward_subgraph_inputs(
         self, subgraph: List[fx.Node], graph: fx.Graph, erase_node: bool
     ) -> int:
-        """
-        This function turns the inputs of a subgraph into the extra output
-        of the entire graph. If ``erase_node`` is True, the subgraph will be
-        erased from the graph -- essentially forward the inputs of the subgraph
-        to the output of the graph.
+        """Turn the inputs of a subgraph into the extra output of the entire graph.
+
+        If ``erase_node`` is True, the subgraph will be erased from the graph -- essentially forward the inputs
+        of the subgraph to the output of the graph.
         """
         output = get_output(graph)
         inputs = []
@@ -219,10 +218,10 @@ class IterGraph(fx.Graph):
     def _forward_inputs_to_subgraph(
         self, subgraph: List[fx.Node], graph: fx.Graph, extra_input: int
     ) -> None:
-        """
-        This function creates extra input nodes and forward the input nodes to
-        the ``subgraph``. The external input nodes of ``subgraph`` (nodes that
-        are not in ``subgraph``) will replaced by the newly created input nodes.
+        """Create extra input nodes and forward the input nodes to the ``subgraph``.
+
+        The external input nodes of ``subgraph`` (nodes that are not in ``subgraph``) will replaced by the newly
+        created input nodes.
         """
         placeholders = [node for node in graph.nodes if str(node.op) == "placeholder"]
         assert placeholders, "No placeholders are found"
@@ -275,8 +274,8 @@ class IterGraph(fx.Graph):
     def move_to_next_iter_before(
         self, subgraph: List[fx.Node], target_node: fx.Node
     ) -> None:
-        """
-        Move the ``subgraph`` to the next iteration before ``target_node``.
+        """Move the ``subgraph`` to the next iteration before ``target_node``.
+
         The ``subgraph`` is a list of fx.Node and must satisfy the following
         restrictions:
         1. The order of the nodes in ``subgraph`` must obey the topological
@@ -633,8 +632,8 @@ class IterGraph(fx.Graph):


 class IterGraphModule(nn.Module):
-    """
-    ``IterGraphModule`` provides the ability to do cross-iteration optimization.
+    """``IterGraphModule`` provides the ability to do cross-iteration optimization.
+
     Given a ``fx.GraphModule``, main_gm, ``IterGraphModule`` internally
     duplicate it to 3 copies and redirect the ``forward`` request to a different
     ``fx.GraphModule`` based on the iteration count. This allows users to do
@@ -674,10 +673,9 @@ class IterGraphModule(nn.Module):
         self._enable_inductor = enable_inductor

     def finalize_setup(self) -> None:
-        """
-        Must be called before the forward() is called. This method setups
-        the internal states and also get the signal from users that what
-        is the maximum iteration count.
+        """Set up the internal states and also get the signal from users that what is the maximum iteration count.
+
+        This method must be called before the forward() is called.
         """
         if not self._is_frozen:
             self.graph.freeze_cross_iter_movement()

torch/distributed/fsdp/fully_sharded_data_parallel.py

@@ -112,14 +112,16 @@ FLAT_PARAM = "_flat_param"


 class OptimStateKeyType(Enum):
+    """Represents the type of key in an optimizer state-dict."""
+
     PARAM_NAME = auto()
     PARAM_ID = auto()


 class FullyShardedDataParallel(nn.Module, _FSDPState):
-    """
-    A wrapper for sharding module parameters across data parallel workers. This
-    is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
+    """A wrapper for sharding module parameters across data parallel workers.
+
+    This is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
     FullyShardedDataParallel is commonly shortened to FSDP.

     .. _`Xu et al.`: https://arxiv.org/abs/2004.13336
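The class docstring above summarizes what the wrapper does; for orientation, a minimal usage sketch follows (illustrative only, not part of this diff; it assumes the script is launched with torchrun, an NCCL backend, and one CUDA device per rank):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torchrun has set RANK/WORLD_SIZE/MASTER_ADDR and each rank sees a GPU.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(1024, 1024).cuda()
sharded_model = FSDP(model)  # parameters are sharded across data-parallel workers

optim = torch.optim.Adam(sharded_model.parameters(), lr=1e-4)
loss = sharded_model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()  # FSDP handles gradient reduction and resharding
optim.step()
```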
@@ -515,9 +517,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):

     @property
     def module(self) -> nn.Module:
-        """
-        Returns the wrapped module (like :class:`DistributedDataParallel`).
-        """
+        """Return the wrapped module."""
         # FSDP's `.module` must refer to the innermost wrapped module when
         # composing with other module wrappers in order for state dict to work
         if isinstance(self._fsdp_wrapped_module, ActivationWrapper):
@@ -547,6 +547,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         return super().__getitem__(key)

     def check_is_root(self) -> bool:
+        """Check if this instance is a root FSDP module."""
         return _is_fsdp_root(self, self)

     @staticmethod
@@ -554,9 +555,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         module: nn.Module,
         root_only: bool = False,
     ) -> List["FullyShardedDataParallel"]:
-        """
-        Returns all nested FSDP instances, possibly including ``module`` itself
-        and only including FSDP root modules if ``root_only=True``.
+        """Return all nested FSDP instances.
+
+        This possibly includes ``module`` itself and only includes FSDP root modules if ``root_only=True``.

         Args:
             module (torch.nn.Module): Root module, which may or may not be an
@@ -573,9 +574,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         return traversal_utils._get_fsdp_states(module)

     def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
-        r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
-        as well as self. Typical use includes initializing the parameters of a model
-        (see also :ref:`nn-init-doc`).
+        r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.
+
+        Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`).

         Compared to ``torch.nn.Module.apply``, this version additionally gathers
         the full parameters before applying ``fn``. It should not be called from
@@ -614,8 +615,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         return ret

     def _mixed_precision_enabled_for_buffers(self) -> bool:
-        """
-        Returns if the user explicitly enabled buffer mixed precision.
+        """Return whether the user explicitly enabled buffer mixed precision.

         NOTE: Unlike parameters and gradient reduction, buffer mixed precision
         is applied at the FSDP instance level, not the ``FlatParameter`` level,
@@ -624,15 +624,11 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         return self.mixed_precision.buffer_dtype is not None

     def _low_precision_hook_enabled(self) -> bool:
-        """
-        Whether a low precision hook is registered or not.
-        """
+        """Whether a low precision hook is registered or not."""
         return self._comm_hook is not None and self._comm_hook in LOW_PRECISION_HOOKS

     def _reset_lazy_init(self) -> None:
-        """
-        Reset instance so :func:`_lazy_init` will run on the next forward.
-        """
+        """Reset instance so :func:`_lazy_init` will run on the next forward."""
         self._is_root: Optional[bool] = None

     @staticmethod
@@ -642,9 +638,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         state_dict_config: Optional[StateDictConfig] = None,
         optim_state_dict_config: Optional[OptimStateDictConfig] = None,
     ) -> StateDictSettings:
-        """
-        Set the ``state_dict_type`` and the corresponding (optional)
-        configurations of all the descendant FSDP modules of the target module.
+        """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module.
+
+        Also takes (optional) configuration for the model's and optimizer's state dict.
         The target module does not have to be a FSDP module. If the target
         module is a FSDP module, its ``state_dict_type`` will also be changed.

@@ -747,10 +743,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):

     @staticmethod
     def get_state_dict_type(module: nn.Module) -> StateDictSettings:
-        """
-        Get the state_dict_type and the corresponding configurations
-        for the FSDP modules rooted at ``module``. The target module
-        does not have to be an FSDP module.
+        """Get the state_dict_type and the corresponding configurations for the FSDP modules rooted at ``module``.
+
+        The target module does not have to be an FSDP module.

         Returns:
             A ``StateDictSettings`` containing the state_dict_type and
@@ -797,10 +792,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         state_dict_config: Optional[StateDictConfig] = None,
         optim_state_dict_config: Optional[OptimStateDictConfig] = None,
     ) -> Generator:
-        """
-        A context manager to set the ``state_dict_type`` of all the descendant
-        FSDP modules of the target module. This context manager has the same
-        functions as :meth:`set_state_dict_type`. Read the document of
+        """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module.
+
+        This context manager has the same functions as :meth:`set_state_dict_type`. Read the document of
         :meth:`set_state_dict_type` for the detail.

         Example::
@@ -836,10 +830,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         )

     def forward(self, *args: Any, **kwargs: Any) -> Any:
-        """
-        Runs the forward pass for the wrapped module, inserting FSDP-specific
-        pre- and post-forward sharding logic.
-        """
+        """Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic."""
         handle = self._handle
         with torch.autograd.profiler.record_function(
             "FullyShardedDataParallel.forward"
@@ -875,7 +866,8 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         offload_to_cpu: bool = False,
         with_grads: bool = False,
     ) -> Generator:
-        r"""A context manager to expose full params for FSDP instances.
+        r"""Expose full params for FSDP instances with this context manager.
+
         Can be useful *after* forward/backward for a model to get
         the params for additional processing or checking. It can take a non-FSDP
         module and will summon full params for all contained FSDP modules as
@@ -941,9 +933,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):

     @contextlib.contextmanager
     def _deregister_orig_params_ctx(self):
-        """
-        This deregisters the original parameters and exposes the
-        :class:`FlatParameter` s. If a :class:`FlatParameter` is sharded, then
+        """Deregister the original parameters and expose the :class:`FlatParameter`.
+
+        If a :class:`FlatParameter` is sharded, then
         this refreshes the sharded views before exiting. This method should
         only be called when using the original parameters.
         """
@@ -961,11 +953,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
             _register_orig_params(fsdp_module, fsdp_module)

     def _apply(self, *args, **kwargs):
-        """
-        When using the original parameters, this deregisters the original
-        parameters and exposes the :class:`FlatParameter` s before calling
-        ``_apply()``.
-        """
+        """Deregister the original parameters and expose the :class:`FlatParameter` s before calling ``_apply()``."""
         # When using the original parameters: Since (1) the `FlatParameter`s
         # own the storage and (2) `_apply()` is the subroutine underlying the
         # most common storage-changing ops like `to()` and `cuda()`, we
@@ -985,9 +973,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         *args,
         **kwargs,
     ) -> Iterator[Tuple[str, torch.Tensor]]:
-        """
-        Overrides :meth:`named_buffers()` to intercept buffer names and
-        remove all occurrences of the FSDP-specific flattened buffer prefix
+        """Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself.
+
+        Intercepts buffer names and removes all occurrences of the FSDP-specific flattened buffer prefix
         when inside the :meth:`summon_full_params` context manager.
         """
         should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
@@ -1003,9 +991,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         *args,
         **kwargs,
     ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
-        """
-        Overrides :meth:`named_parameters()` to intercept parameter names and
-        remove all occurrences of the FSDP-specific flattened parameter prefix
+        """Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself.
+
+        Intercepts parameter names and removes all occurrences of the FSDP-specific flattened parameter prefix
         when inside the :meth:`summon_full_params` context manager.
         """
         should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
@@ -1038,9 +1026,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):

     @contextmanager
     def no_sync(self) -> Generator:
-        """
-        A context manager to disable gradient synchronizations across FSDP
-        instances. Within this context, gradients will be accumulated in module
+        """Disable gradient synchronizations across FSDP instances.
+
+        Within this context, gradients will be accumulated in module
         variables, which will later be synchronized in the first
         forward-backward pass after exiting the context. This should only be
         used on the root FSDP instance and will recursively apply to all
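The ``no_sync`` docstring above describes gradient accumulation without per-step synchronization; a short sketch of that usage (illustrative, not part of the diff; ``sharded_model`` is assumed to be a root FSDP instance and ``batches`` an iterable of input tensors already on the right device):

```python
def accumulate_and_step(sharded_model, optim, batches):
    # Accumulate local gradients for all but the last micro-batch.
    *head, last = list(batches)
    with sharded_model.no_sync():
        for batch in head:
            sharded_model(batch).sum().backward()
    # The first backward outside no_sync() triggers gradient synchronization.
    sharded_model(last).sum().backward()
    optim.step()
    optim.zero_grad()
```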
@@ -1079,9 +1067,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
     def clip_grad_norm_(
         self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0
     ) -> torch.Tensor:
-        """
-        Clips the gradient norm of all parameters. The norm is computed over
-        all parameters' gradients as viewed as a single vector, and the
+        """Clip the gradient norm of all parameters.
+
+        The norm is computed over all parameters' gradients as viewed as a single vector, and the
         gradients are modified in-place.

         Args:
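A short sketch of a training step using the method touched above. Not part of the PR; the loss function, inputs, and targets are assumptions.

```python
# Sketch: clip_grad_norm_ must be called on all ranks, since each rank
# only holds a shard of every gradient; it returns the total norm.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def train_step(fsdp_model: FSDP, optimizer: torch.optim.Optimizer,
               inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    loss = torch.nn.functional.mse_loss(fsdp_model(inputs), targets)
    loss.backward()
    total_norm = fsdp_model.clip_grad_norm_(max_norm=1.0, norm_type=2.0)
    optimizer.step()
    optimizer.zero_grad()
    return total_norm
```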
@@ -1245,8 +1233,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         group: Optional[dist.ProcessGroup] = None,
         cpu_offload: bool = True,
     ) -> Dict[str, Any]:
-        """
-        The internal API that is used by all the optim_state_dict implementations.
+        """Transform the state-dict of an optimizer corresponding to a sharded model.
+
+        This is the internal API that is used by all the optim_state_dict implementations.
         Given model, optim, the original optim_state_dict, this API removes the
         FSDP internal information and internal sharding from the optim_state_dict.
         """
@@ -1298,7 +1287,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         group: Optional[dist.ProcessGroup] = None,
     ) -> Dict[str, Any]:
         """
-        The internal API that is used by all the load optim_state_dict implementations.
+        Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.
+
+        This is the internal API that is used by all the load optim_state_dict implementations.
         Given model, optim, and the saved optim_state_dict, this API adds the FSDP
         internal information and internal sharding to the optim_state_dict.
         """
@@ -1352,7 +1343,8 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         rank0_only: bool = True,
         group: Optional[dist.ProcessGroup] = None,
     ) -> Dict[str, Any]:
-        """
+        """Return the full optimizer state-dict.
+
         Consolidates the full optimizer state on rank 0 and returns it
         as a :class:`dict` following the convention of
         :meth:`torch.optim.Optimizer.state_dict`, i.e. with keys ``"state"``
@@ -1417,7 +1409,8 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         optim: torch.optim.Optimizer,
         group: Optional[dist.ProcessGroup] = None,
     ) -> Dict[str, Any]:
-        """
+        """Return the optimizer state-dict in its sharded form.
+
         The API is similar to :meth:`full_optim_state_dict` but this API chunks
         all non-zero-dimension states to :class:`ShardedTensor` to save memory.
         This API should only be used when the model ``state_dict`` is derived
@@ -1453,12 +1446,11 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         ] = None,
         optim: Optional[torch.optim.Optimizer] = None,
     ) -> Dict[str, Any]:
-        """
-        Shards the full optimizer state dict ``full_optim_state_dict`` by
-        remapping the state to flattened parameters instead of unflattened
-        parameters and restricting to only this rank's part of the optimizer
-        state. The first argument should be the return value of
-        :meth:`full_optim_state_dict`.
+        """Shard a full optimizer state-dict.
+
+        Remaps the state in ``full_optim_state_dict`` to flattened parameters instead of unflattened
+        parameters and restricts to only this rank's part of the optimizer state.
+        The first argument should be the return value of :meth:`full_optim_state_dict`.

         Example::

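To connect the ``full_optim_state_dict``/``shard_full_optim_state_dict`` docstrings above, here is a minimal save/resume sketch. It is not from this PR; the checkpoint path and function names are assumptions, and it assumes every rank can read the saved file (see the scatter variant after the next hunk otherwise).

```python
# Sketch: save a consolidated optimizer state dict from rank 0, then
# remap it back to this rank's flattened shards when resuming.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def save_full_optim_state(fsdp_model: FSDP, optim: torch.optim.Optimizer, path: str) -> None:
    # By default only rank 0 receives a non-empty consolidated dict.
    full_osd = FSDP.full_optim_state_dict(fsdp_model, optim)
    if dist.get_rank() == 0:
        torch.save(full_osd, path)


def load_full_optim_state(fsdp_model: FSDP, optim: torch.optim.Optimizer, path: str) -> None:
    # Every rank loads the full dict here, then keeps only its own shard.
    full_osd = torch.load(path)
    sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, fsdp_model)
    optim.load_state_dict(sharded_osd)
```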
@@ -1525,7 +1517,8 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         model: torch.nn.Module,
         optim: torch.optim.Optimizer,
     ) -> Dict[str, Any]:
-        """
+        """Flatten a sharded optimizer state-dict.
+
         The API is similar to :meth:`shard_full_optim_state_dict`. The only
         difference is that the input ``sharded_optim_state_dict`` should be
         returned from :meth:`sharded_optim_state_dict`. Therefore, there will
@@ -1568,10 +1561,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         optim: Optional[torch.optim.Optimizer] = None,
         group: Optional[Any] = None,
     ) -> Dict[str, Any]:
-        """
-        Scatters the full optimizer state dict from rank 0 to all other ranks,
-        returning the sharded optimizer state dict on each rank. The return
-        value is the same as :meth:`shard_full_optim_state_dict`, and on rank
+        """Scatter the full optimizer state dict from rank 0 to all other ranks.
+
+        Returns the sharded optimizer state dict on each rank.
+        The return value is the same as :meth:`shard_full_optim_state_dict`, and on rank
         0, the first argument should be the return value of
         :meth:`full_optim_state_dict`.

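A sketch of the scatter-based variant described above, which avoids materializing the full dict on every rank. Not part of the PR; the path and function name are assumptions.

```python
# Sketch: only rank 0 reads the checkpoint; the other ranks pass None and
# receive their shard via scatter_full_optim_state_dict.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def load_via_scatter(fsdp_model: FSDP, optim: torch.optim.Optimizer, path: str) -> None:
    full_osd = torch.load(path) if dist.get_rank() == 0 else None
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, fsdp_model)
    optim.load_state_dict(sharded_osd)
```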
@@ -1650,10 +1643,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         ] = None,
         optim: Optional[torch.optim.Optimizer] = None,
     ) -> Dict[str, Any]:
-        """
-        Re-keys the optimizer state dict ``optim_state_dict`` to use the key
-        type ``optim_state_key_type``. This can be used to achieve
-        compatibility between optimizer state dicts from models with FSDP
+        """Re-keys the optimizer state dict ``optim_state_dict`` to use the key type ``optim_state_key_type``.
+
+        This can be used to achieve compatibility between optimizer state dicts from models with FSDP
         instances and ones without.

         To re-key an FSDP full optimizer state dict (i.e. from
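A brief sketch of the re-keying described above, for loading an FSDP-produced full optimizer state dict into an optimizer for a non-wrapped copy of the model. Not part of the PR; the function name and the non-wrapped model are assumptions, and the import location of ``OptimStateKeyType`` may differ by release.

```python
# Sketch: FSDP full optimizer state dicts are keyed by parameter name;
# plain optimizers key state by parameter id, so re-key before loading.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, OptimStateKeyType


def to_unwrapped_optim_state(full_osd: dict, nonwrapped_model: torch.nn.Module) -> dict:
    return FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
```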
@@ -1771,9 +1763,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         group: Optional[dist.ProcessGroup] = None,
     ) -> Dict[str, Any]:
         """
-        Transforms the state_dict of ``optim`` for the ``model`` that is sharded
-        by FSDP to one of the three types: 1) full optimizer state_dict, 2)
-        sharded optimizer state_dict, 3) local optimizer state_dict.
+        Transform the state-dict of an optimizer corresponding to a sharded model.
+
+        The given state-dict can be transformed to one of three types:
+        1) full optimizer state_dict, 2) sharded optimizer state_dict, 3) local optimizer state_dict.

         For full optimizer state_dict, all states are unflattened and not sharded.
         Rank0 only and CPU only can be specified via :meth:`state_dict_type` to
@@ -1867,10 +1860,12 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         group: Optional[dist.ProcessGroup] = None,
     ) -> Dict[str, Any]:
         """
+        Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.
+
         Given a ``optim_state_dict`` that is transformed through
-        :meth:`optim_state_dict`, converts it to the flattened optimizer
+        :meth:`optim_state_dict`, it gets converted to the flattened optimizer
         state_dict that can be loaded to ``optim`` which is the optimizer for
         ``model``. ``model`` must be sharded by FullyShardedDataParallel.

             >>> # xdoctest: +SKIP("undefined variables")
             >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
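A sketch of the ``optim_state_dict``/``optim_state_dict_to_load`` round trip covered by the two hunks above. Not part of the PR; the function names are assumptions, and the argument order of ``optim_state_dict_to_load`` has shifted between releases, so check the docs for your version.

```python
# Sketch: the shape of the returned dict (full/sharded/local) follows the
# state_dict_type configured on the FSDP model; loading converts it back
# to the flattened form the wrapped optimizer expects.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def save_optim(fsdp_model: FSDP, optim: torch.optim.Optimizer) -> dict:
    return FSDP.optim_state_dict(fsdp_model, optim)


def load_optim(fsdp_model: FSDP, optim: torch.optim.Optimizer, osd: dict) -> None:
    flattened = FSDP.optim_state_dict_to_load(fsdp_model, optim, osd)
    optim.load_state_dict(flattened)
```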
@@ -1946,10 +1941,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         return result

     def register_comm_hook(self, state: object, hook: callable):
-        """
-        Registers a communication hook which is an enhancement that provides a
-        flexible hook to users where they can specify how FSDP aggregates gradients
-        across multiple workers.
+        """Register a communication hook.
+
+        This is an enhancement that provides a flexible hook to users where they can specify how FSDP aggregates
+        gradients across multiple workers.
         This hook can be used to implement several algorithms like
         `GossipGrad <https://arxiv.org/abs/1803.05880>`_ and gradient compression
         which involve different communication strategies for
@@ -2008,9 +2003,9 @@ def _get_grad_norm(
     norm_type: float,
 ) -> torch.Tensor:
     """
-    Returns the gradient norm of parameters ``param`` s, where the gradients
-    are viewed as a single vector. The returned norm is in FP32 even if
-    parameters/gradients are in a low precision. This is because the downstream
+    Return the gradient norm of parameters ``param`` s, where the gradients are viewed as a single vector.
+
+    The returned norm is in FP32 even if parameters/gradients are in a low precision. This is because the downstream
     use of this return value is a reduction across ranks.
     """
    params_with_grad = [param for param in params if param.grad is not None]
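For readers of the hunk above, here is a compact sketch of the "gradients viewed as a single vector" norm the docstring describes. It is not the internal helper itself, just the underlying math under the assumption of a p-norm.

```python
# Sketch: for p-norms, the norm of the concatenated gradient vector equals
# the p-norm of the per-tensor p-norms; compute in FP32 so a later
# cross-rank reduction stays stable.
from typing import Iterable

import torch


def grad_norm(params: Iterable[torch.nn.Parameter], norm_type: float = 2.0) -> torch.Tensor:
    grads = [p.grad.detach() for p in params if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    per_tensor = torch.stack(
        [torch.linalg.vector_norm(g.to(torch.float32), norm_type) for g in grads]
    )
    return torch.linalg.vector_norm(per_tensor, norm_type)
```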
@@ -2041,8 +2036,9 @@ def _get_param_to_fqn(
     model: torch.nn.Module,
 ) -> Dict[torch.nn.Parameter, str]:
     """
-    Constructs a mapping from parameters to their parameter names. ``model``
-    should not contain any :class:`FullyShardedDataParallel` instances, which
+    Construct a mapping from parameters to their parameter names.
+
+    The ``model`` should not contain any :class:`FullyShardedDataParallel` instances, which
     means that none of the parameters should be ``FlatParameter`` s. As a
     result, compared to :meth:`_get_param_to_fqns`, the mapped
     values may be flattened from singleton :class:`list` s to the contained
@@ -2071,6 +2067,6 @@ def _get_param_to_fqn(
 def _get_fqn_to_param(
     model: torch.nn.Module,
 ) -> Dict[str, torch.nn.Parameter]:
-    """Constructs the inverse mapping of :meth:`_get_param_to_fqn`."""
+    """Construct the inverse mapping of :meth:`_get_param_to_fqn`."""
     param_to_param_name = _get_param_to_fqn(model)
     return dict(zip(param_to_param_name.values(), param_to_param_name.keys()))
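Finally, a rough equivalent of the two FQN helpers touched in the last hunks, for a model without any FSDP wrapping. This is only an illustration of the mapping they describe, not the module's implementation, and it assumes unique, non-flattened parameters.

```python
# Sketch: named_parameters() already yields fully qualified names for a
# plain module, so both directions reduce to dict comprehensions.
from typing import Dict

import torch


def param_to_fqn(model: torch.nn.Module) -> Dict[torch.nn.Parameter, str]:
    return {param: name for name, param in model.named_parameters()}


def fqn_to_param(model: torch.nn.Module) -> Dict[str, torch.nn.Parameter]:
    return {name: param for name, param in model.named_parameters()}
```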