make fsdp folder to be public (#72084)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72084
make fsdp folder to be public
ghstack-source-id: 148173447
Test Plan: unit tests
Reviewed By: mrshenli
Differential Revision: D33903417
fbshipit-source-id: 7852a2adc4af09af48a5ffa52ebf210489f834d5
(cherry picked from commit bd06513cfe)
This commit is contained in:
parent ed435e903f
commit 2336571cb7
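In practice, this means downstream code switches from the private torch.distributed._fsdp module path to the public torch.distributed.fsdp path, and from the BackwardPrefetch_ enum name to BackwardPrefetch. A minimal sketch of the resulting import change, based only on the renames visible in the hunks below:

    # Before this change (private paths, removed in the hunks below):
    # from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
    # from torch.distributed._fsdp.fully_sharded_data_parallel import BackwardPrefetch_

    # After this change (public paths, added in the hunks below):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch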
docs/source/fsdp.rst (new file, 7 lines added)

@@ -0,0 +1,7 @@
+FullyShardedDataParallel
+========================
+
+.. automodule:: torch.distributed.fsdp
+
+.. autoclass:: torch.distributed.fsdp.FullyShardedDataParallel
+    :members:
@@ -61,6 +61,7 @@ Features described in this documentation are classified by release status:
    torch.distributed <distributed>
    torch.distributed.algorithms.join <distributed.algorithms.join>
    torch.distributed.elastic <distributed.elastic>
+   torch.distributed.fsdp <fsdp>
    torch.distributed.optim <distributed.optim>
    torch.distributions <distributions>
    torch.fft <fft>
@@ -5,7 +5,7 @@ import unittest
 
 import torch
 from torch import distributed as dist
-from torch.distributed._fsdp.flatten_params_wrapper import FlattenParamsWrapper
+from torch.distributed.fsdp.flatten_params_wrapper import FlattenParamsWrapper
 from torch.testing._internal.common_utils import run_tests, TestCase
 
 

@@ -6,7 +6,7 @@ from functools import partial
 
 import torch
 import torch.nn as nn
-from torch.distributed._fsdp.fully_sharded_data_parallel import (
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
     FullyShardedDataParallel as FSDP,
     CPUOffload,
 )

@@ -26,8 +26,8 @@ from torch.testing._internal.common_utils import (
     run_tests,
 )
 
-from torch.distributed._fsdp import CPUOffload
-from torch.distributed._fsdp.fully_sharded_data_parallel import BackwardPrefetch_
+from torch.distributed.fsdp import CPUOffload
+from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch
 
 
 if not dist.is_available():

@@ -69,7 +69,7 @@ class TestParityWithDDP(FSDPTest):
     )
     @parametrize(
         "backward_prefetch",
-        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
+        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
     )
     def test_nested_wrapped_model(self, cpu_offload, backward_prefetch):
         init_modes = self._get_init_modes_for_test(cpu_offload)

@@ -89,7 +89,7 @@ class TestParityWithDDP(FSDPTest):
     )
     @parametrize(
         "backward_prefetch",
-        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
+        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
     )
     def test_nested_all_wrapped_model(self, cpu_offload, backward_prefetch):
         init_modes = self._get_init_modes_for_test(cpu_offload)

@@ -110,7 +110,7 @@ class TestParityWithDDP(FSDPTest):
     )
     @parametrize(
         "backward_prefetch",
-        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
+        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
     )
     def test_transformer_parameterized(self, cpu_offload, backward_prefetch):
         init_modes = self._get_init_modes_for_test(cpu_offload)

@@ -130,7 +130,7 @@ class TestParityWithDDP(FSDPTest):
     )
     @parametrize(
         "backward_prefetch",
-        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
+        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
     )
     def test_delayed_optim_step(self, cpu_offload, backward_prefetch):
         # We use a model with a long CUDA delay right before the optimizer step.

@@ -156,7 +156,7 @@ class TestParityWithDDP(FSDPTest):
     )
     @parametrize(
         "backward_prefetch",
-        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
+        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
     )
     def test_delayed_reduce_scatter(self, cpu_offload, backward_prefetch):
         # We insert a delay in the torch.distributed._reduce_scatter_base op, so that

@@ -186,7 +186,7 @@ class TestParityWithDDP(FSDPTest):
     )
     @parametrize(
         "backward_prefetch",
-        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
+        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
     )
     def test_mixture_of_experts(self, cpu_offload, backward_prefetch):
         init_modes = self._get_init_modes_for_test(cpu_offload)

@@ -209,7 +209,7 @@ class TestParityWithDDP(FSDPTest):
     )
     @parametrize(
         "backward_prefetch",
-        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
+        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
     )
     def test_mixture_of_experts_with_delay_before_free(self, cpu_offload, backward_prefetch):
         init_modes = self._get_init_modes_for_test(cpu_offload)
@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn
 import torch.optim as optim
 from torch import distributed as dist
-from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn.parallel import DistributedDataParallel
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (

@@ -4,7 +4,7 @@ import sys
 
 import torch
 from torch import distributed as dist
-from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn import Linear, Module
 from torch.optim import SGD
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

@@ -6,7 +6,7 @@ import torch
 import torch.nn as nn
 import torch.optim as optim
 from torch import distributed as dist
-from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     FSDPTest,

@@ -4,7 +4,7 @@ import sys
 
 import torch
 from torch import distributed as dist
-from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn import Linear, Module
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim import SGD

@@ -4,7 +4,7 @@ import sys
 
 import torch
 from torch import distributed as dist
-from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn import Linear, Module, Sequential
 from torch.optim import SGD
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

@@ -9,7 +9,7 @@ import torch
 import torch.nn as nn
 from torch import distributed as dist
 from torch.cuda import Event
-from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     FSDPTest,

@@ -4,7 +4,7 @@ import sys
 
 import torch
 from torch import distributed as dist
-from torch.distributed._fsdp import FullyShardedDataParallel as FSDP, CPUOffload
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload
 from torch.nn import Linear, Module
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim import SGD

@@ -4,7 +4,7 @@ import sys
 
 import torch
 from torch import distributed as dist
-from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn import Linear
 from torch.optim import SGD
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

@@ -6,7 +6,7 @@ import unittest
 
 import torch
 from torch import distributed as dist
-from torch.distributed._fsdp.utils import (
+from torch.distributed.fsdp.utils import (
     _apply_to_tensors,
 )
 from torch.testing._internal.common_utils import (

@@ -9,12 +9,12 @@ import unittest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.distributed._fsdp.fully_sharded_data_parallel import (
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
     FullyShardedDataParallel as FSDP,
     CPUOffload,
-    BackwardPrefetch_,
+    BackwardPrefetch,
 )
-from torch.distributed._fsdp.wrap import (
+from torch.distributed.fsdp.wrap import (
     default_auto_wrap_policy,
     enable_wrap,
     wrap,

@@ -132,7 +132,7 @@ class TestFSDPWrap(FSDPTest):
     )
     @parametrize(
         "backward_prefetch",
-        [BackwardPrefetch_.BACKWARD_POST, BackwardPrefetch_.BACKWARD_PRE]
+        [BackwardPrefetch.BACKWARD_POST, BackwardPrefetch.BACKWARD_PRE]
     )
     @parametrize(
         "fsdp_init_mode",

@@ -2,7 +2,7 @@ from abc import ABC
 import inspect
 from typing import Dict, Type
 
-from torch.distributed._fsdp import FullyShardedDataParallel
+from torch.distributed.fsdp import FullyShardedDataParallel
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim import Optimizer
 from torch.distributed.optim import as_functional_optim
@@ -36,11 +36,19 @@ if TYPE_CHECKING:
 
 @dataclass
 class CPUOffload:
+    """
+    CPU offlaoding config. Currently, only parameter and gradient CPU
+    offload are supported.
+    offload_params: Offloading parameters to CPUs when these parameters are
+                    not used for computation on GPUs. This implicitly enables
+                    gradient offloading to CPUs in order for parameters and
+                    gradients to be on the same device to work with optimizer.
+    """
     offload_params: bool = False
-    # TODO: state dict offloading, activation offloading
+    # TODO: state dict offloading
     # https://github.com/pytorch/pytorch/issues/67224
 
-class BackwardPrefetch_(Enum):
+class BackwardPrefetch(Enum):
     """
     Specify where to prefetch next layer's full parameters
     during backward pass.
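
For reference, the two options touched in this hunk are meant to be passed together to the wrapper's constructor. The following is a minimal, illustrative sketch using the new public names; the keyword arguments cpu_offload and backward_prefetch come from the __init__ hunk further below, while the Linear module, device setup, and process-group initialization are placeholder assumptions, not part of this diff:

    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload
    from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch

    # Assumes torch.distributed.init_process_group() has already been called
    # and the current CUDA device has been selected.
    model = nn.Linear(8, 8).cuda()  # placeholder module for illustration

    sharded = FSDP(
        model,
        # Keep parameters (and, implicitly, gradients) on CPU while they are
        # not needed for GPU computation, per the CPUOffload docstring above.
        cpu_offload=CPUOffload(offload_params=True),
        # Experimental knob; BACKWARD_POST and None are the other choices.
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    )
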
@@ -88,28 +96,34 @@ class FullyShardedDataParallel(nn.Module):
     """
     A wrapper for sharding Module parameters across data parallel workers. This
     is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
-    ``FullyShardedDataParallel`` is commonly shorten to FSDP.
+    FullyShardedDataParallel is commonly shorten to FSDP.
 
     .. _`Xu et al.`: https://arxiv.org/abs/2004.13336
     .. _DeepSpeed: https://www.deepspeed.ai/
 
     Example::
 
-        import torch
-        from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
-        torch.cuda.set_device(device_id)
-        sharded_module = FSDP(my_module)
-        optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
-        x = sharded_module(x, y=3, z=torch.Tensor([1]))
-        loss = x.sum()
-        loss.backward()
-        optim.step()
+        >>> import torch
+        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+        >>> torch.cuda.set_device(device_id)
+        >>> sharded_module = FSDP(my_module)
+        >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
+        >>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
+        >>> loss = x.sum()
+        >>> loss.backward()
+        >>> optim.step()
 
     .. warning::
         The optimizer must be initialized *after* the module has been wrapped,
         since FSDP will shard parameters in-place and this will break any
         previously initialized optimizers.
 
+    .. warning:
+        Module should be already placed on the destination device or
+        device is set properly using torch.cuda.set_device(device_id).
+        FSDP will get compute device from module first, if module device
+        is CPU, FSDP will then get compute device from current device.
+
     Args:
         module (nn.Module):
             module to be wrapped with FSDP.
@@ -128,11 +142,30 @@ class FullyShardedDataParallel(nn.Module):
             Note that this policy currently will only apply to child modules of
             the passed in module. The remainder modules are always wrapped in
             the returned FSDP root instance.
-        backward_prefetch: (Optional[BackwardPrefetch_]):
+            ``default_auto_wrap_policy`` written in ``torch.distributed.fsdp.wrap`` is
+            an example of ``fsdp_auto_wrap_policy`` callable, this policy wraps layers
+            with parameter sizes larger than 100M. Users can supply the customized
+            ``fsdp_auto_wrap_policy`` callable that should accept following arguments:
+            ``module: nn.Module``, ``recurse: bool``, ``unwrapped_params: int``,
+            extra customized arguments could be added to the customized
+            ``fsdp_auto_wrap_policy`` callable as well.
+
+            Example::
+
+                >>> def custom_auto_wrap_policy(
+                >>>     module: nn.Module,
+                >>>     recurse: bool,
+                >>>     unwrapped_params: int,
+                >>>     # These are customizable for this policy function.
+                >>>     min_num_params: int = int(1e8),
+                >>> ) -> bool:
+                >>>     return unwrapped_params >= min_num_params
+
+        backward_prefetch: (Optional[BackwardPrefetch]):
             This is an experimental feature that is subject to change in the
             the near future. It allows users to enable two different backward_prefetch
             algorithms to help backward communication and computation overlapping.
-            Pros and cons of each algorithm is explained in the class ``BackwardPrefetch_``.
+            Pros and cons of each algorithm is explained in the class ``BackwardPrefetch``.
     """
 
     def __init__(
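
The fsdp_auto_wrap_policy argument documented in this hunk accepts the bundled default_auto_wrap_policy from torch.distributed.fsdp.wrap (imported in the test hunk earlier) or a custom callable such as custom_auto_wrap_policy above. A short usage sketch; the Sequential model, device placement, and process-group setup are placeholder assumptions, not part of this diff:

    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import default_auto_wrap_policy

    # Placeholder model; assumes the process group is initialized and the
    # module is already placed on the target CUDA device.
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda()

    # default_auto_wrap_policy wraps child modules with parameter sizes above
    # roughly 100M (per the docstring above); the remaining modules stay in
    # the FSDP root instance returned here.
    sharded = FSDP(model, fsdp_auto_wrap_policy=default_auto_wrap_policy)
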
@@ -141,7 +174,7 @@ class FullyShardedDataParallel(nn.Module):
         process_group: Optional[ProcessGroup] = None,
         cpu_offload: Optional[CPUOffload] = None,
         fsdp_auto_wrap_policy: Optional[Callable] = None,
-        backward_prefetch: Optional[BackwardPrefetch_] = None,
+        backward_prefetch: Optional[BackwardPrefetch] = None,
     ):
         torch._C._log_api_usage_once("torch.distributed.fsdp")
         super().__init__()

@@ -566,7 +599,7 @@ class FullyShardedDataParallel(nn.Module):
 
     def _need_prefetch_pre_backward_hook(self) -> bool:
         if (
-            self.backward_prefetch == BackwardPrefetch_.BACKWARD_PRE
+            self.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
             and self._fsdp_graph_order is not None
             and self._my_fsdp_idx_in_graph is not None and self._my_fsdp_idx_in_graph > 0
             and self._fsdp_graph_order[self._my_fsdp_idx_in_graph - 1].training_state != TrainingState_.BACKWARD_POST

@@ -577,7 +610,7 @@ class FullyShardedDataParallel(nn.Module):
 
     def _need_prefetch_post_backward_hook(self) -> bool:
         if (
-            self.backward_prefetch == BackwardPrefetch_.BACKWARD_POST
+            self.backward_prefetch == BackwardPrefetch.BACKWARD_POST
             and self._fsdp_graph_order is not None
             and self._my_fsdp_idx_in_graph is not None and self._my_fsdp_idx_in_graph > 0
             and self._fsdp_graph_order[self._my_fsdp_idx_in_graph - 1].training_state != TrainingState_.BACKWARD_POST

@@ -9,8 +9,8 @@ from unittest import mock
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._fsdp import FullyShardedDataParallel, CPUOffload
-from torch.distributed._fsdp.fully_sharded_data_parallel import (
+from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
     TrainingState_,
 )
 from torch.testing._internal.common_distributed import (