Revert "[FSDP2] Move to public torch.distributed.fsdp (#141868)"
This reverts commit 45583a5df9.
Reverted https://github.com/pytorch/pytorch/pull/141868 on behalf of https://github.com/atalman due to failing internally ([comment](https://github.com/pytorch/pytorch/pull/141868#issuecomment-2523925180))
This commit is contained in:
parent 4af7aa5e64
commit bab15df40a
@@ -1,85 +0,0 @@
-torch.distributed.fsdp.fully_shard
-==================================
-
-PyTorch FSDP2 (``fully_shard``)
--------------------------------
-
-PyTorch FSDP2 provides a fully sharded data parallelism (FSDP) implementation
-targeting performant eager-mode while using per-parameter sharding for improved
-usability.
-
-- If you are new to FSDP, we recommend that you start with FSDP2 due to improved
-  usability.
-- If you are currently using FSDP1, consider evaluating the following
-  differences to see if you should switch to FSDP2:
-
-Compared to PyTorch FSDP1 (``FullyShardedDataParallel``):
-
-- FSDP2 uses ``DTensor``-based dim-0 per-parameter sharding for a simpler
-  sharding representation compared to FSDP1's flat-parameter sharding, while
-  preserving similar throughput performance. More specifically, FSDP2 chunks
-  each parameter on dim-0 across the data parallel workers (using
-  ``torch.chunk(dim=0)``), whereas FSDP1 flattens, concatenates, and chunks a
-  group of tensors together, making reasoning about what data is present on
-  each worker and resharding to different parallelisms complex. Per-parameter
-  sharding provides a more intuitive user experience, relaxes constraints
-  around frozen parameters, and allows for communication-free (sharded) state
-  dicts, which otherwise require all-gathers in FSDP1.
-- FSDP2 implements a different memory management approach to handle the
-  multi-stream usages that avoids ``torch.Tensor.record_stream``. This ensures
-  deterministic and expected memory usage and does not require blocking the CPU
-  like in FSDP1's ``limit_all_gathers=True``.
-- FSDP2 exposes APIs for manual control over prefetching and collective
-  scheduling, allowing power users more customization. See the methods on
-  ``FSDPModule`` below for details.
-- FSDP2 simplifies some of the API surface: e.g. FSDP2 does not directly
-  support full state dicts. Instead, users can reshard the sharded state dicts
-  containing ``DTensor`` s to full state dicts themselves using ``DTensor``
-  APIs like ``DTensor.full_tensor()`` or by using higher-level APIs like
-  `PyTorch Distributed Checkpoint <https://pytorch.org/docs/stable/distributed.checkpoint.html>`_ 's
-  distributed state dict APIs. Also, some other args have been removed; see
-  `here <https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md>`_ for
-  details.
-
-If you are onboarding FSDP for the first time or if any of the above appeals to
-your use case, we recommend that you consider using FSDP2.
-
-See `this RFC <https://github.com/pytorch/pytorch/issues/114299>`_ for details
-on system design and implementation.
-
-.. note::
-    ``torch.distributed.fsdp.fully_shard`` is currently in prototype state and
-    under development. The core API will likely not change, but we may make some
-    API changes if necessary.
-
-.. currentmodule:: torch.distributed.fsdp
-
-The frontend API is ``fully_shard`` that can be called on a ``module``:
-
-.. autofunction:: fully_shard
-
-Calling ``fully_shard(module)`` dynamically constructs a new class that
-subclasses ``type(module)`` and an FSDP class ``FSDPModule``. For example, if
-we call ``fully_shard(linear)`` on a module ``linear: nn.Linear``, then FSDP
-constructs a new class ``FSDPLinear`` and changes ``linear`` 's type to this.
-Otherwise, ``fully_shard`` does not change the module structure and parameter
-fully-qualified names. The class ``FSDPModule`` allows providing some
-FSDP-specific methods on the module.
-
-.. autoclass:: FSDPModule
-    :members:
-    :member-order: bysource
-
-.. autoclass:: UnshardHandle
-    :members:
-
-.. autofunction:: register_fsdp_forward_method
-
-.. autoclass:: MixedPrecisionPolicy
-    :members:
-
-.. autoclass:: OffloadPolicy
-    :members:
-
-.. autoclass:: CPUOffloadPolicy
-    :members:
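The deleted documentation above describes calling ``fully_shard`` bottom-up on a model. A minimal sketch of that usage follows, mirroring the small training test removed later in this diff; it assumes ``torch.distributed`` is initialized (for example via ``torchrun``), a CUDA device is available, and that the public ``torch.distributed.fsdp`` import path from the reverted change is importable (otherwise use ``torch.distributed._composable.fsdp``):

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard  # or torch.distributed._composable.fsdp

    # Two-layer toy model; sizes are arbitrary.
    model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))

    # Apply fully_shard bottom-up so each call forms its own communication
    # group: submodules first, then the root module.
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
    fully_shard(model[0], mp_policy=mp_policy)
    fully_shard(model[1], mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy)

    # Parameters are now dim-0 sharded DTensors; forward/backward all-gather
    # and free them via the registered hooks.
    inp = torch.randn((8, 16), device="cuda")
    model(inp).sum().backward()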
@@ -79,7 +79,6 @@ Features described in this documentation are classified by release status:
    torch.distributed.algorithms.join <distributed.algorithms.join>
    torch.distributed.elastic <distributed.elastic>
    torch.distributed.fsdp <fsdp>
-   torch.distributed.fsdp.fully_shard <distributed.fsdp.fully_shard>
    torch.distributed.tensor.parallel <distributed.tensor.parallel>
    torch.distributed.optim <distributed.optim>
    torch.distributed.pipelining <distributed.pipelining>
@@ -10,7 +10,7 @@ from typing import Any, List, Optional, Type, Union
import torch
import torch.distributed as dist
import torch.nn as nn
-from torch.distributed.fsdp import fully_shard
+from torch.distributed._composable.fsdp import fully_shard
from torch.nn.parallel.scatter_gather import _is_namedtuple
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
@@ -7,8 +8,8 @@ from typing import Optional, Union
import torch
import torch.nn as nn
from torch.distributed._composable import replicate
+from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLPStack
@@ -11,30 +11,30 @@ import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._composable import checkpoint, replicate
-from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.distributed.fsdp import (
+from torch.distributed._composable.fsdp import (
    FSDPModule,
    fully_shard,
    MixedPrecisionPolicy,
    OffloadPolicy,
)
-from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
+from torch.distributed._composable.fsdp._fsdp_collectives import (
    _div_if_needed,
    _get_gradient_divide_factors,
    foreach_all_gather,
    foreach_all_gather_copy_out,
    foreach_reduce,
)
-from torch.distributed.fsdp._fully_shard._fsdp_common import FSDPMeshInfo, TrainingState
-from torch.distributed.fsdp._fully_shard._fsdp_init import (
+from torch.distributed._composable.fsdp._fsdp_common import FSDPMeshInfo, TrainingState
+from torch.distributed._composable.fsdp._fsdp_init import (
    _get_post_forward_mesh_info,
    _init_default_fully_shard_mesh,
)
-from torch.distributed.fsdp._fully_shard._fsdp_param import ShardedState
-from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
-from torch.distributed.tensor import DTensor
+from torch.distributed._composable.fsdp._fsdp_param import ShardedState
+from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
+from torch.distributed._tensor import DTensor
+from torch.distributed._tensor.experimental import implicit_replication
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.debug import CommDebugMode
-from torch.distributed.tensor.experimental import implicit_replication
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
@@ -12,19 +12,17 @@ from unittest import mock

import torch
import torch._dynamo.testing
+import torch.distributed._composable.fsdp._fsdp_param
import torch.nn.functional as F
from torch import nn
from torch._dynamo.utils import counters
from torch._inductor import comms
from torch._inductor.utils import is_fallback_op, run_and_get_code
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed._composable.fsdp._fsdp_common import TrainingState
+from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
from torch.distributed._tensor import init_device_mesh
-from torch.distributed.fsdp import (
-    fully_shard,
-    FullyShardedDataParallel as FSDP,
-    ShardingStrategy,
-)
-from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
-from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.testing import FileCheck
from torch.testing._internal.common_distributed import (
    at_least_x_gpu,
@@ -85,7 +83,7 @@ class TestFullyShardCompileCompute(FSDPTest):
    ):
        torch._dynamo.reset()
        trace_rules_check_count = 0
-        HOOKS_FILE_NAME = "torch/distributed/fsdp/_fully_shard/_fsdp_state.py"
+        HOOKS_FILE_NAME = "torch/distributed/_composable/fsdp/_fsdp_state.py"
        HOOK_WRAPPER_NAME = "fsdp_hook_wrapper"

        def patched_trace_rules_check(*args, **kwargs):
@@ -13,8 +13,8 @@ import torch.distributed as dist
import torch.nn as nn
import torch.utils._pytree as pytree
from torch.autograd.grad_mode import _unsafe_preserve_version_counter
+from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
@@ -10,8 +10,8 @@ import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._composable import checkpoint, replicate
-from torch.distributed.fsdp import fully_shard
-from torch.distributed.fsdp._fully_shard._fsdp_param_group import (
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed._composable.fsdp._fsdp_param_group import (
    RegisterPostBackwardFunction,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
@@ -4,8 +4,8 @@ import copy
import torch
import torch.nn as nn
from torch.amp.grad_scaler import GradScaler, OptState
+from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import init_device_mesh
-from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
@@ -9,6 +9,13 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import replicate
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed._composable.fsdp._fsdp_init import (
+    _get_managed_modules,
+    _get_managed_states,
+)
+from torch.distributed._composable.fsdp._fsdp_param import ParamModuleInfo
+from torch.distributed._composable.fsdp._fsdp_param_group import _get_param_module_infos
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_tensor,
@@ -17,15 +24,6 @@ from torch.distributed._tensor import (
    Shard,
)
from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.fsdp import fully_shard
-from torch.distributed.fsdp._fully_shard._fsdp_init import (
-    _get_managed_modules,
-    _get_managed_states,
-)
-from torch.distributed.fsdp._fully_shard._fsdp_param import ParamModuleInfo
-from torch.distributed.fsdp._fully_shard._fsdp_param_group import (
-    _get_param_module_infos,
-)
from torch.distributed.fsdp._init_utils import (
    _init_inter_node_process_group,
    _init_intra_node_process_group,
@@ -1158,26 +1156,5 @@ class TestFullyShardShardPlacementFn(FSDPTestMultiThread):
        fully_shard(model, shard_placement_fn=shard_placement_fn)


-# TODO: Remove this test class once we remove the old import path:
-# torch/distributed/_composable/fsdp
-class TestFullyShardOldImport(FSDPTestMultiThread):
-    @property
-    def world_size(self) -> int:
-        return 2
-
-    @unittest.skipIf(not TEST_CUDA, "no cuda")
-    def test_old_import_training(self):
-        from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
-
-        model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
-        mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
-        fully_shard(model[0], mp_policy=mp_policy)
-        fully_shard(model[1], mp_policy=mp_policy)
-        fully_shard(model, mp_policy=mp_policy)
-
-        inp = torch.randn((8, 16), device="cuda")
-        model(inp).sum().backward()
-
-
if __name__ == "__main__":
    run_tests()
@@ -32,7 +32,7 @@ import logging
import torch
import torch.distributed as dist
import torch.nn as nn
-from torch.distributed.fsdp import fully_shard
+from torch.distributed._composable.fsdp import fully_shard
logger = logging.getLogger("torch.distributed._composable.fsdp")
logger.setLevel(logging.DEBUG)
device = "cuda"
@@ -4,7 +4,11 @@ import functools
import gc

import torch
-from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, OffloadPolicy
+from torch.distributed._composable.fsdp import (
+    CPUOffloadPolicy,
+    fully_shard,
+    OffloadPolicy,
+)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests
@@ -8,8 +8,8 @@ import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
-from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
-from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
+from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+from torch.distributed._composable.fsdp._fsdp_collectives import (
    _get_gradient_divide_factors,
)
from torch.distributed.tensor import Shard
@@ -7,8 +7,8 @@ from typing import Callable
import torch
import torch.distributed as dist
import torch.nn as nn
+from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor.experimental import implicit_replication
-from torch.distributed.fsdp import fully_shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    FSDPTest,
@@ -4,7 +4,7 @@ import copy
import unittest

import torch.nn as nn
-from torch.distributed.fsdp import FSDPModule, fully_shard
+from torch.distributed._composable.fsdp import FSDPModule, fully_shard
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_fsdp import FSDPTestMultiThread, MLP
from torch.testing._internal.common_utils import run_tests
@@ -8,8 +8,8 @@ from typing import Dict, Optional

import torch
import torch.nn as nn
+from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed.tensor import distribute_tensor, DTensor, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
@@ -12,18 +12,18 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import checkpoint, replicate
-from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    _CHECKPOINT_PREFIX,
-    apply_activation_checkpointing,
-)
-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.fsdp import (
+from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    FSDPModule,
    fully_shard,
    OffloadPolicy,
    register_fsdp_forward_method,
)
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    _CHECKPOINT_PREFIX,
+    apply_activation_checkpointing,
+)
+from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, init_device_mesh, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_cuda import TEST_CUDA
@@ -671,7 +671,7 @@ class TestFullyShard1DTrainingCompose(FSDPTest):
        module_grouping: str,
    ):
        assert checkpoint_impl in ("composable", "utils", "wrapper")
-        testing_compile = fully_shard != torch.distributed.fsdp.fully_shard
+        testing_compile = fully_shard != torch.distributed._composable.fsdp.fully_shard
        if testing_compile and checkpoint_impl == "composable":
            return
        torch.manual_seed(42)
@@ -12,6 +12,7 @@ import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._composable import replicate
+from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed._tensor import DTensor, init_device_mesh, Replicate, Shard
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
@@ -21,11 +22,7 @@ from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
)
from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.fsdp import (
-    CPUOffloadPolicy,
-    fully_shard,
-    FullyShardedDataParallel as FSDP,
-)
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import (
    _get_module_fsdp_state,
    clean_tensor_name,
@@ -6,6 +6,10 @@ from typing import TYPE_CHECKING
import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
+from torch.distributed._composable.fsdp.fully_shard import (
+    fully_shard,
+    MixedPrecisionPolicy,
+)
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
@@ -13,7 +17,6 @@ from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
    PipelineScheduleSingle,
@@ -7,9 +7,9 @@ import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
+from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import DTensor
-from torch.distributed.fsdp import fully_shard
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
@@ -6,12 +6,12 @@ import itertools
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._random as random
+from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh
from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
from torch.distributed._tensor.api import distribute_tensor
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed.distributed_c10d import broadcast_object_list
-from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor._random import (
    is_rng_supported_mesh,
    manual_seed,
@@ -6,18 +6,18 @@ from typing import Union
import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint
+from torch.distributed._composable.fsdp import (
+    CPUOffloadPolicy,
+    fully_shard,
+    MixedPrecisionPolicy,
+    OffloadPolicy,
+)
from torch.distributed._tensor import init_device_mesh
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    CheckpointWrapper,
)
-from torch.distributed.fsdp import (
-    CPUOffloadPolicy,
-    fully_shard,
-    MixedPrecisionPolicy,
-    OffloadPolicy,
-)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_utils import run_tests
@@ -6,6 +6,7 @@ import copy
import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
+from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor, init_device_mesh
from torch.distributed._tensor.experimental import implicit_replication
from torch.distributed.checkpoint.state_dict import (
@@ -13,11 +14,7 @@ from torch.distributed.checkpoint.state_dict import (
    get_optimizer_state_dict,
    StateDictOptions,
)
-from torch.distributed.fsdp import (
-    fully_shard,
-    FullyShardedDataParallel as FSDP,
-    StateDictType,
-)
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.fsdp.wrap import always_wrap_policy
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
@@ -10,6 +10,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import replicate
+from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor, init_device_mesh
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -27,7 +28,6 @@ from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
)
from torch.distributed.fsdp import (
-    fully_shard,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    StateDictType,
@@ -3263,7 +3263,7 @@ if torch.distributed.is_available():
        "torch.distributed._composable.replicate",
    }
    if not torch._dynamo.config.skip_fsdp_hooks:
-        LEGACY_MOD_INLINELIST.add("torch.distributed.fsdp._fully_shard")
+        LEGACY_MOD_INLINELIST.add("torch.distributed._composable.fsdp")


# Force inline functions under these modules, even they are in *_SKIPLIST.
@@ -3323,7 +3323,7 @@ MOD_INLINELIST = set(MOD_INLINELIST)
if torch.distributed.is_available():
    MOD_INLINELIST.add("torch.distributed")
    if not torch._dynamo.config.skip_fsdp_hooks:
-        MOD_INLINELIST.add("torch.distributed.fsdp._fully_shard")
+        MOD_INLINELIST.add("torch.distributed._composable.fsdp")


@functools.lru_cache(None)
@@ -994,7 +994,7 @@ class FSDPParamGroupUseTrainingStateVariable(ContextWrappingVariable):
        self.param_group_var.value._training_state = value

    def module_name(self):
-        return "torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup"
+        return "torch.distributed._composable.fsdp._fsdp_param_group.FSDPParamGroup"

    def fn_name(self):
        return "use_training_state"
@@ -30,7 +30,7 @@ from .constant import ConstantVariable


try:
-    from torch.distributed.fsdp._fully_shard import _fsdp_param_group
+    from torch.distributed._composable.fsdp import _fsdp_param_group
except ModuleNotFoundError:
    _fsdp_param_group = None
@@ -305,7 +305,7 @@ class UserFunctionVariable(BaseUserFunctionVariable):
            and not tx.output.current_tracer.allow_side_effects_under_checkpoint
        ):
            try:
-                from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState
+                from torch.distributed._composable.fsdp._fsdp_state import FSDPState
            except Exception:
                FSDPState = None
            if FSDPState is not None and self.fn in [
@@ -56,7 +56,7 @@ except ModuleNotFoundError:
    np = None  # type: ignore[assignment]

try:
-    from torch.distributed.fsdp._fully_shard import _fsdp_param_group
+    from torch.distributed._composable.fsdp import _fsdp_param_group
except ModuleNotFoundError:
    _fsdp_param_group = None  # type: ignore[assignment]
@@ -536,7 +536,7 @@ Graph: {graph}

def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None:
    try:
-        import torch.distributed.fsdp._fully_shard._fsdp_collectives
+        import torch.distributed._composable.fsdp._fsdp_collectives

        assert torch.distributed.is_available()
        # Assert existence of these ops
@@ -1,8 +1,2 @@
-from torch.distributed.fsdp import (
-    CPUOffloadPolicy,
-    FSDPModule,
-    fully_shard,
-    MixedPrecisionPolicy,
-    OffloadPolicy,
-    register_fsdp_forward_method,
-)
+from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
+from .fully_shard import FSDPModule, fully_shard, register_fsdp_forward_method
@@ -57,10 +57,7 @@ class MixedPrecisionPolicy:

@dataclass
class OffloadPolicy:
-    """
-    This base class represents the policy of no offloading and is only used as
-    the default value for the ``offload_policy`` arg.
-    """
+    """This base class represents the policy of no offloading."""


@dataclass
@@ -74,10 +71,10 @@ class CPUOffloadPolicy(OffloadPolicy):

    Attributes:
        pin_memory (bool): Whether to pin sharded parameter and gradient
-            memory. Pinning memory allows both more efficient H2D/D2H copies
-            and for the copies to overlap with compute. However, the pinned
-            memory cannot be used by other processes. Set this to ``False`` if
-            you have insufficient CPU memory. (Default: ``True``)
+            memory. Pinning memory allows H2D/D2H copying without blocking the
+            CPU and in turn, overlap with compute, but pinned memory cannot be
+            used by other processes. Set this to ``False`` if you have
+            insufficient CPU memory. (Default: ``True``)
    """

    pin_memory: bool = True
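To make the pin_memory trade-off described in the docstring concrete, here is a hedged sketch (not part of the patch) of enabling CPU offloading with pinned memory disabled; it assumes an initialized process group/device mesh, and the import path depends on whether the reverted public module is present:

    import torch.nn as nn
    from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard  # or torch.distributed._composable.fsdp

    # Keep sharded parameters and gradients on CPU. Disabling pin_memory saves
    # host RAM at the cost of slower H2D/D2H copies that cannot overlap with
    # compute.
    model = nn.Linear(16, 16)
    fully_shard(model, offload_policy=CPUOffloadPolicy(pin_memory=False))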
@@ -29,7 +29,7 @@ from ._fsdp_common import (
from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState


-logger = logging.getLogger("torch.distributed.fsdp.fully_shard")
+logger = logging.getLogger("torch.distributed._composable.fsdp")

_ModuleToHandleDict = Dict[nn.Module, RemovableHandle]  # for state dict
@@ -42,7 +42,7 @@ if TYPE_CHECKING:
    from ._fsdp_param import FSDPParam


-logger = logging.getLogger("torch.distributed.fsdp.fully_shard")
+logger = logging.getLogger("torch.distributed._composable.fsdp")


class FSDPStateContext:
@@ -34,14 +34,6 @@ from ._fsdp_param_group import FSDPParamGroup
from ._fsdp_state import _get_module_fsdp_state, FSDPState


-__all__ = [
-    "fully_shard",
-    "FSDPModule",
-    "UnshardHandle",
-    "register_fsdp_forward_method",
-]
-
-
cls_to_fsdp_cls: Dict[Type, Type] = {}
@@ -58,80 +50,68 @@ def fully_shard(
    offload_policy: OffloadPolicy = OffloadPolicy(),
):
    """
-    Apply fully sharded data parallelism (FSDP) to ``module``, where FSDP
-    shards module parameters, gradients, and optimizer states across data
-    parallel workers to save memory at the cost of communication.
+    Shard module parameters across data parallel workers.

-    At initialization, FSDP shards the module's parameters across the data
-    parallel workers given by ``mesh``. Before forward, FSDP all-gathers the
-    sharded parameters across the data-parallel workers to get the unsharded
-    parameters for forward computation. If ``reshard_after_forward`` is
-    ``True``, then FSDP frees the unsharded parameters after forward and
-    re-all-gathers them in backward before gradient computation. After gradient
-    computation, FSDP frees the unsharded parameters and reduce-scatters the
-    unsharded gradients across data-parallel workers.
+    This function applies fully sharded data parallelism (FSDP) or a variant to
+    ``module``, a technique for memory savings at the cost of communication.
+    Parameters are sharded across ``mesh``, and in turn, so are their gradients
+    and optimizer states.

-    This implementation represents the sharded parameters as :class:`DTensor` s
-    sharded on dim-0, while the unsharded parameters will be like the original
-    parameters on ``module`` (e.g. :class:`torch.Tensor` if originally
-    :class:`torch.Tensor`). A module
-    `forward pre-hook <https://pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.register_forward_pre_hook>`_
-    on ``module`` all-gathers the parameters, and a module
-    `forward hook <https://pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook>`_
-    on ``module`` frees them (if needed). Similar backward hooks all-gather
-    parameters and later free parameters and reduce-scatter gradients.
+    The sharded parameters are all-gathered to construct the unsharded
+    parameters for forward or backward computation. The unsharded parameters
+    are freed after computation to save memory. The gradients are reduced
+    across the mesh and divided by the mesh size for data parallelism. The
+    optimizer step runs on the sharded parameters.

-    Since grouping multiple tensors together for one collective is critical for
-    communication efficiency, this implementation makes this grouping first
-    class. Calling :meth:`fully_shard` on ``module`` constructs one group that
+    Each call to ``fully_shard`` constructs one communication group that
    includes the parameters in ``module.parameters()`` except those already
-    assigned to a group from an earlier call on a submodule. This means that
-    :meth:`fully_shard` should be called bottom-up on your model. Each group's
-    parameters are all-gathered in one collective, and its gradients are
-    reduce-scattered in one collective. Partitioning the model into multiple
-    groups ("layer by layer") allows for peak memory savings and communication/computation
-    overlap. Users generally should *not* call :meth:`fully_shard` only on the
-    topmost root module.
+    assigned to a group from a nested call. Each group's parameters and its
+    gradients are communicated together in one collective, respectively.
+    Constructing multiple groups across the model (e.g. "layer by layer")
+    allows for peak memory savings and communication/computation overlap.

+    Implementation-wise, the sharded parameters are represented as
+    :class:`DTensor` s, sharded on dim-0, and the unsharded parameters are
+    represented as :class:`Tensor` s. A module forward pre-hook all-gathers the
+    parameters, and a module forward hook frees them. Similar backward hooks
+    gather parameters and later free parameters/reduce gradients.
+
    Args:
        module (Union[nn.Module, List[nn.Module]): The module or modules to
            shard with FSDP and group together for communication.
        mesh (Optional[DeviceMesh]): This data parallel mesh defines the
            sharding and device. If 1D, then parameters are fully sharded
-            across the 1D mesh (FSDP) with ``(Shard(0),)`` placement. If 2D,
-            then parameters are sharded across the 1st dim and replicated
-            across the 0th dim (HSDP) with ``(Replicate(), Shard(0))``
-            placement. The mesh's device type gives the device type used for
-            communication; if a CUDA or CUDA-like device type, then we use the
-            current device.
+            across the 1D mesh (FSDP). If 2D, then parameters are sharded
+            across the 0th dim and replicated across the 1st dim (HSDP). The
+            mesh's device type gives the device type used for communication;
+            if a CUDA or CUDA-like device type, then we use the current device.
        reshard_after_forward (Union[bool, int]): This controls the parameter
            behavior after forward and can trade off memory and communication:

            - If ``True``, then this reshards parameters after forward and
-              re-all-gathers in backward.
+              all-gathers in backward.
            - If ``False``, then this keeps the unsharded parameters in memory
-              after forward and avoids the all-gather in backward.
+              after forward and avoids the all-gather in backward.
            - If an ``int``, then this represents the world size to reshard to
-              after forward. It should be a non-trivial divisor of the ``mesh``
-              shard dim size (i.e. excluding 1 and the dim size itself). A
-              choice may be the intra-node size (e.g. ``torch.cuda.device_count()``).
-              This allows the all-gather in backward to be over a smaller world
-              size at the cost of higher memory usage than setting to ``True``.
+              after forward. It should be a non-trivial divisor of the ``mesh``
+              shard dim size (i.e. excluding 1 and the dim size itself). A choice
+              may be the intra-node size (e.g. ``torch.cuda.device_count()``).
+              This allows the all-gather in backward to be over a smaller world
+              size at the cost of higher memory usage than setting to ``True``.
            - The root FSDP state has its value specially set to ``False`` as a
-              heuristic since its parameters would typically be immediately
-              all-gathered for backward.
+              heuristic since its parameters would typically be immediately
+              all-gathered for backward.
            - After forward, the parameters registered to the module depend on
-              to this: The registered parameters are the sharded parameters if
-              ``True``; unsharded parameters if ``False``; and the paramters
-              resharded to the smaller mesh otherwise. To modify the parameters
-              between forward and backward, the registered parameters must be
-              the sharded parameters. For ``False`` or an ``int``, this can be
-              done by manually resharding via :meth:`reshard`.
+              to this: The registered parameters are the sharded parameters if
+              ``True``; unsharded parameters if ``False``; and the paramters
+              resharded to the smaller mesh otherwise. To modify the parameters
+              between forward and backward, the registered parameters must be the
+              sharded parameters. For ``False`` or an ``int``, this can be done
+              by manually resharding via :meth:`reshard`.
        shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]):
            This callable can be used to override the sharding placement for a
            parameter to shard a parameter on a dimension other than dim-0. If
-            this callable returns a :class:`Shard` placement (not ``None``),
-            then FSDP will shard according to that placement (e.g. ``Shard(1)``).
+            this callable returns a ``Shard`` placement (not ``None``), then
+            FSDP will shard according to that placement (e.g. ``Shard(1)``).
            If sharding on a nonzero dim, we currently require even sharding,
            i.e. the tensor dim size on that dim must be divisible by the FSDP
            shard mesh size.
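As an illustration of the ``shard_placement_fn`` argument documented in the hunk above, the following hedged sketch (not part of the patch) shards wide 2D parameters on dim-1 instead of the default dim-0; the module, names, and threshold are made up, and an initialized process group/device mesh is assumed:

    from typing import Optional

    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard  # or torch.distributed._composable.fsdp
    from torch.distributed.tensor import Shard

    def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
        # Shard wide 2D parameters on dim-1; returning None keeps the default
        # dim-0 sharding. Nonzero-dim sharding requires the dim size to divide
        # evenly by the FSDP shard mesh size.
        if param.ndim == 2 and param.shape[1] >= 1024:
            return Shard(1)
        return None

    model = nn.Linear(1024, 1024)
    fully_shard(model, shard_placement_fn=shard_placement_fn)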
@@ -194,14 +174,14 @@ def fully_shard(
    cls = module.__class__
    new_cls = cls_to_fsdp_cls.get(cls, None)
    if not new_cls:
-        dct = {"__deepcopy__": _unimplemented_deepcopy}
+        dct = {"__deepcopy__": unimplemented_deepcopy}
        new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
        cls_to_fsdp_cls[cls] = new_cls
    module.__class__ = new_cls
    return arg_module


-def _unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
+def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
    raise AssertionError(
        "FSDP does not support deepcopy. Please use state dict for serialization."
    )
@@ -222,9 +202,9 @@ class FSDPModule:

    def reshard(self) -> None:
        """
-        Reshards the module's parameters, freeing the unsharded parameters if
-        they are allocated and registering the sharded parameters to the
-        module. This method is *not* recursive.
+        Reshards the module's parameters, registering the sharded parameters
+        to the module and freeing the unsharded parameters if needed. This
+        method is *not* recursive.
        """
        state = self._get_fsdp_state()
        if fsdp_param_group := state._fsdp_param_group:
@@ -233,9 +213,7 @@ class FSDPModule:
    def unshard(self, async_op: bool = False) -> Optional["UnshardHandle"]:
        """
        Unshards the module's parameters by allocating memory and all-gathering
-        the parameters. This method is *not* recursive. The unshard follows the
-        :class:`MixedPrecisionPolicy`, so it will all-gather following
-        ``param_dtype`` if set.
+        the parameters. This method is *not* recursive.

        Args:
            async_op (bool): If ``True``, then returns a :class:`UnshardHandle`
@@ -243,17 +221,19 @@ class FSDPModule:
                ``False``, then returns ``None`` and waits on the handle inside
                this function.

-        .. note:: If ``async_op=True``, then FSDP will wait on the pending
-            unshard in the module's pre-forward for the user. The user only
-            needs to call :meth:`wait` explicitly if the wait should happen
-            before pre-forward.
+        .. warning:: This method is experimental and subject to change.
+
+        .. note:: If ``async_op=True``, then the user does not have to call
+            :meth:`wait` on the returned handle if waiting on the unshard op
+            in the module's pre-forward is tolerable. FSDP will wait on the
+            pending unshard op in the pre-forward automatically.
        """
        state = self._get_fsdp_state()
        fsdp_param_group = state._fsdp_param_group
        if fsdp_param_group is not None:
            fsdp_param_group.lazy_init()
            fsdp_param_group.unshard(async_op=async_op)
-            handle = _UnshardHandleImpl(fsdp_param_group)
+            handle = UnshardHandle(fsdp_param_group)
        if async_op:
            return handle
        handle.wait()
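A hedged sketch of the manual prefetching flow that ``unshard(async_op=True)`` enables (not from the patch): each element of ``layers`` is assumed to have been wrapped with ``fully_shard`` so it is also an ``FSDPModule``, and the distributed setup is assumed to be initialized:

    from typing import Sequence

    import torch
    import torch.nn as nn

    def forward_with_manual_prefetch(layers: Sequence[nn.Module], x: torch.Tensor) -> torch.Tensor:
        out = x
        for i, layer in enumerate(layers):
            # Start all-gathering the next layer's parameters before running
            # the current layer, then wait on the handle just before that
            # layer's compute so communication overlaps with computation.
            handle = layers[i + 1].unshard(async_op=True) if i + 1 < len(layers) else None
            out = layer(out)
            if handle is not None:
                handle.wait()
        return out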
@@ -261,10 +241,9 @@ class FSDPModule:

    def set_is_last_backward(self, is_last_backward: bool) -> None:
        """
-        Sets whether the next backward is the last one. On the last backward,
-        FSDP waits on pending gradient reduction and clears internal data
-        data structures for backward prefetching. This can be useful for
-        microbatching.
+        Sets whether the next backward is the last one, meaning that FSDP
+        should wait for gradient reduction to finish and clear internal data
+        structures used for explicit prefetching.
        """
        state = self._get_fsdp_state()
        state._state_ctx.is_last_backward = is_last_backward
@@ -274,13 +253,13 @@ class FSDPModule:
    ) -> None:
        """
        Sets if the module should sync gradients. This can be used to implement
-        gradient accumulation *without communication*. For HSDP, this controls
+        gradient accumulation without communication. For HSDP, this controls
        both reduce-scatter and all-reduce together.

        Args:
            requires_gradient_sync (bool): Whether to reduce gradients for the
                module's parameters.
-            recurse (bool): Whether to set for all FSDP submodules or just the
+            recurse (bool): Whether to set for all submodules or just the
                passed-in module.
        """
        self_module = cast(nn.Module, self)
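A hedged sketch of gradient accumulation without communication, using the two setters shown in the surrounding hunks; ``model`` is assumed to be the root module returned by ``fully_shard``, and the optimizer and microbatches are placeholders:

    from typing import Sequence

    import torch

    def accumulate_and_step(model, optim: torch.optim.Optimizer, microbatches: Sequence[torch.Tensor]) -> None:
        for i, microbatch in enumerate(microbatches):
            is_last = i == len(microbatches) - 1
            # Skip reduce-scatter (and all-reduce for HSDP) until the last microbatch.
            model.set_requires_gradient_sync(is_last)
            # Only the last backward should wait on gradient reduction and
            # clear FSDP's prefetching state.
            model.set_is_last_backward(is_last)
            model(microbatch).sum().backward()
        optim.step()
        optim.zero_grad()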
@@ -314,13 +293,12 @@ class FSDPModule:
        """
        Sets if the module should reshard parameters after backward. This can
        be used during gradient accumulation to trade off higher memory for
-        reduced communication since the unsharded parameters do not need to be
-        re-all-gathered before the next forward.
+        reduced communication.

        Args:
            reshard_after_backward (bool): Whether to reshard parameters after
                backward.
-            recurse (bool): Whether to set for all FSDP submodules or just the
+            recurse (bool): Whether to set for all submodules or just the
                passed-in module.
        """
        self_module = cast(nn.Module, self)
@@ -457,22 +435,24 @@ class FSDPModule:

class UnshardHandle:
    """
-    A handle to wait on a :meth:`FSDPModule.unshard` op.
+    A handle to wait on the unshard op.
+
+    Args:
+        fsdp_param_group (FSDPParamGroup, optional): FSDP parameter group to
+            unshard. This should be ``None`` iff the FSDP module does not
+            manage any parameters, meaning the unshard is a no-op.
    """

-    def wait(self) -> None:
-        """
-        Waits on the unshard op. This ensures that the current stream can use
-        the unsharded parameters, which are now registered to the module.
-        """
-        return
-
-
-class _UnshardHandleImpl(UnshardHandle):
    def __init__(self, fsdp_param_group: Optional[FSDPParamGroup]):
        self._fsdp_param_group = fsdp_param_group

+    def wait(self):
+        """
+        Waits on the unshard op.
+
+        This ensures that the current stream can use the unsharded parameters,
+        which are now registered to the module.
+        """
        if self._fsdp_param_group is not None:
            self._fsdp_param_group.wait_for_unshard()
            # Avoid keeping a reference
@@ -481,15 +461,13 @@ class _UnshardHandleImpl(UnshardHandle):

def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None:
    """
-    Registers a method on ``module`` to be considered a forward method for
-    FSDP.
+    Registers a method on ``module`` to be a forward method for FSDP.

-    FSDP all-gathers parameters pre-forward and optionally frees parameters
-    post-forward (depending on ``reshard_after_forward``). FSDP only knows to
-    do this for :meth:`nn.Module.forward` by default. This function patches a
-    user-specified method to run the pre/post-forward hooks before/after the
-    method, respectively. If ``module`` is not an :class:`FSDPModule`, then
-    this is a no-op.
+    FSDP only knows to run its pre-forward and post-forward hooks on the
+    default :meth:`nn.Module.forward` method. This function patches a user
+    specified method to run the pre/post-forward hooks before/after the method,
+    respectively. If ``module`` is not an :class:`FSDPModule`, then this is a
+    no-op.

    Args:
        module (nn.Module): Module to register the forward method on.
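To illustrate ``register_fsdp_forward_method`` as documented above, here is a hedged sketch (not from the patch; the ``Encoder`` module and its ``encode`` method are made up, and distributed initialization is assumed):

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard, register_fsdp_forward_method  # or torch.distributed._composable.fsdp

    class Encoder(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.proj = nn.Linear(16, 16)

        def encode(self, x: torch.Tensor) -> torch.Tensor:
            # A non-forward entry point that FSDP would not hook by default.
            return self.proj(x)

    encoder = Encoder()
    fully_shard(encoder)
    # Run FSDP's pre/post-forward hooks (all-gather, optional reshard) around encode() too.
    register_fsdp_forward_method(encoder, "encode")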
@@ -7,6 +7,8 @@ import torch
import torch.distributed as dist
from torch import nn, optim
from torch._guards import active_fake_mode
+from torch.distributed._composable.fsdp import FSDPModule
+from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
from torch.distributed._tools.mem_tracker import _RefType, _State, MemTracker
from torch.distributed.distributed_c10d import (
    _IllegalWork,
@@ -14,8 +16,6 @@ from torch.distributed.distributed_c10d import (
    ReduceOp,
    Work,
)
-from torch.distributed.fsdp import FSDPModule
-from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
from torch.futures import Future
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
@@ -1,13 +1,4 @@
from ._flat_param import FlatParameter as FlatParameter
-from ._fully_shard import (
-    CPUOffloadPolicy,
-    FSDPModule,
-    fully_shard,
-    MixedPrecisionPolicy,
-    OffloadPolicy,
-    register_fsdp_forward_method,
-    UnshardHandle,
-)
from .fully_sharded_data_parallel import (
    BackwardPrefetch,
    CPUOffload,
@@ -29,7 +20,6 @@ from .fully_sharded_data_parallel import (


__all__ = [
-    # FSDP1
    "BackwardPrefetch",
    "CPUOffload",
    "FullOptimStateDictConfig",
@@ -46,21 +36,4 @@ __all__ = [
    "StateDictConfig",
    "StateDictSettings",
    "StateDictType",
-    # FSDP2
-    "CPUOffloadPolicy",
-    "FSDPModule",
-    "fully_shard",
-    "MixedPrecisionPolicy",
-    "OffloadPolicy",
-    "register_fsdp_forward_method",
-    "UnshardHandle",
]
-
-# Set namespace for exposed private names
-CPUOffloadPolicy.__module__ = "torch.distributed.fsdp"
-FSDPModule.__module__ = "torch.distributed.fsdp"
-fully_shard.__module__ = "torch.distributed.fsdp"
-MixedPrecisionPolicy.__module__ = "torch.distributed.fsdp"
-OffloadPolicy.__module__ = "torch.distributed.fsdp"
-register_fsdp_forward_method.__module__ = "torch.distributed.fsdp"
-UnshardHandle.__module__ = "torch.distributed.fsdp"
@@ -1,18 +0,0 @@
-from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
-from ._fully_shard import (
-    FSDPModule,
-    fully_shard,
-    register_fsdp_forward_method,
-    UnshardHandle,
-)
-
-
-__all__ = [
-    "CPUOffloadPolicy",
-    "FSDPModule",
-    "fully_shard",
-    "MixedPrecisionPolicy",
-    "OffloadPolicy",
-    "register_fsdp_forward_method",
-    "UnshardHandle",
-]
@@ -24,7 +24,7 @@ from typing import (

import torch
import torch.distributed as dist
-from torch.distributed.fsdp import FSDPModule, UnshardHandle
+from torch.distributed._composable.fsdp.fully_shard import FSDPModule, UnshardHandle
from torch.profiler import record_function

from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
@@ -10,7 +10,7 @@ import torch.distributed as dist
import torch.fx as fx
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensor
-from torch.distributed.fsdp import FSDPModule, fully_shard
+from torch.distributed._composable.fsdp.fully_shard import FSDPModule, fully_shard
from torch.fx.node import map_aggregate
from torch.nn.parallel import DistributedDataParallel
from torch.utils._pytree import tree_map_only
@@ -31,17 +31,14 @@ import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._composable import checkpoint
-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.fsdp import (
-    CPUOffload,
-    fully_shard,
-    FullyShardedDataParallel as FSDP,
-)
-from torch.distributed.fsdp._common_utils import TrainingState
-from torch.distributed.fsdp._fully_shard._fsdp_param_group import (
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed._composable.fsdp._fsdp_param_group import (
    FSDPParamGroup,
    RegisterPostBackwardFunction,
)
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp._common_utils import TrainingState
from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
@@ -1487,7 +1484,7 @@ class FSDPTest(MultiProcessTestCase):

def test_compiled_fsdp(compile_compute_on_module: Optional[type] = None):
    def fully_shard_with_compiled_compute(*args, **kwargs):
-        torch.distributed.fsdp.fully_shard(*args, **kwargs)  # type: ignore[operator]
+        torch.distributed._composable.fsdp.fully_shard(*args, **kwargs)  # type: ignore[operator]
        if compile_compute_on_module is None or isinstance(
            args[0], compile_compute_on_module
        ):
@@ -1500,7 +1497,7 @@ def test_compiled_fsdp(compile_compute_on_module: Optional[type] = None):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
-            original_fully_shard = torch.distributed.fsdp.fully_shard
+            original_fully_shard = torch.distributed._composable.fsdp.fully_shard
            for mode in FullyShardMode:
                if mode != FullyShardMode.EAGER and not has_triton():
                    warnings.warn("Inductor on GPU needs Triton and recent GPU arch")