make fsdp folder public (#72084)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72084

make fsdp folder public (rename the torch.distributed._fsdp package to the public torch.distributed.fsdp package)
ghstack-source-id: 148173447

Test Plan: unit tests

Reviewed By: mrshenli

Differential Revision: D33903417

fbshipit-source-id: 7852a2adc4af09af48a5ffa52ebf210489f834d5
(cherry picked from commit bd06513cfe)
Yanli Zhao 2022-02-02 07:18:49 -08:00 committed by PyTorch MergeBot
parent ed435e903f
commit 2336571cb7
22 changed files with 85 additions and 44 deletions
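For downstream code the change is mechanical: the FSDP entry points move from the private torch.distributed._fsdp package to the public torch.distributed.fsdp package, and BackwardPrefetch_ is renamed to BackwardPrefetch. A minimal sketch of the updated import and wrapping, assuming torch.distributed is already initialized; the Linear module and device id are placeholders, not part of this commit:

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload

# Old, private path removed by this commit:
#   from torch.distributed._fsdp import FullyShardedDataParallel as FSDP, CPUOffload

torch.cuda.set_device(0)                       # placeholder: one GPU per rank assumed
sharded_module = FSDP(nn.Linear(8, 8).cuda())  # placeholder module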

docs/source/fsdp.rst Normal file
View File

@ -0,0 +1,7 @@
FullyShardedDataParallel
========================
.. automodule:: torch.distributed.fsdp
.. autoclass:: torch.distributed.fsdp.FullyShardedDataParallel
:members:

View File

@ -61,6 +61,7 @@ Features described in this documentation are classified by release status:
torch.distributed <distributed>
torch.distributed.algorithms.join <distributed.algorithms.join>
torch.distributed.elastic <distributed.elastic>
torch.distributed.fsdp <fsdp>
torch.distributed.optim <distributed.optim>
torch.distributions <distributions>
torch.fft <fft>

View File

@ -5,7 +5,7 @@ import unittest
import torch
from torch import distributed as dist
from torch.distributed._fsdp.flatten_params_wrapper import FlattenParamsWrapper
from torch.distributed.fsdp.flatten_params_wrapper import FlattenParamsWrapper
from torch.testing._internal.common_utils import run_tests, TestCase

View File

@ -6,7 +6,7 @@ from functools import partial
import torch
import torch.nn as nn
from torch.distributed._fsdp.fully_sharded_data_parallel import (
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullyShardedDataParallel as FSDP,
CPUOffload,
)

View File

@ -26,8 +26,8 @@ from torch.testing._internal.common_utils import (
run_tests,
)
from torch.distributed._fsdp import CPUOffload
from torch.distributed._fsdp.fully_sharded_data_parallel import BackwardPrefetch_
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch
if not dist.is_available():
@ -69,7 +69,7 @@ class TestParityWithDDP(FSDPTest):
)
@parametrize(
"backward_prefetch",
[BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
[BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
)
def test_nested_wrapped_model(self, cpu_offload, backward_prefetch):
init_modes = self._get_init_modes_for_test(cpu_offload)
@ -89,7 +89,7 @@ class TestParityWithDDP(FSDPTest):
)
@parametrize(
"backward_prefetch",
[BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
[BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
)
def test_nested_all_wrapped_model(self, cpu_offload, backward_prefetch):
init_modes = self._get_init_modes_for_test(cpu_offload)
@ -110,7 +110,7 @@ class TestParityWithDDP(FSDPTest):
)
@parametrize(
"backward_prefetch",
[BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
[BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
)
def test_transformer_parameterized(self, cpu_offload, backward_prefetch):
init_modes = self._get_init_modes_for_test(cpu_offload)
@ -130,7 +130,7 @@ class TestParityWithDDP(FSDPTest):
)
@parametrize(
"backward_prefetch",
[BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
[BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
)
def test_delayed_optim_step(self, cpu_offload, backward_prefetch):
# We use a model with a long CUDA delay right before the optimizer step.
@ -156,7 +156,7 @@ class TestParityWithDDP(FSDPTest):
)
@parametrize(
"backward_prefetch",
[BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
[BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
)
def test_delayed_reduce_scatter(self, cpu_offload, backward_prefetch):
# We insert a delay in the torch.distributed._reduce_scatter_base op, so that
@ -186,7 +186,7 @@ class TestParityWithDDP(FSDPTest):
)
@parametrize(
"backward_prefetch",
[BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
[BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
)
def test_mixture_of_experts(self, cpu_offload, backward_prefetch):
init_modes = self._get_init_modes_for_test(cpu_offload)
@ -209,7 +209,7 @@ class TestParityWithDDP(FSDPTest):
)
@parametrize(
"backward_prefetch",
[BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
[BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
)
def test_mixture_of_experts_with_delay_before_free(self, cpu_offload, backward_prefetch):
init_modes = self._get_init_modes_for_test(cpu_offload)

View File

@ -7,7 +7,7 @@ import torch
import torch.nn as nn
import torch.optim as optim
from torch import distributed as dist
from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (

View File

@ -4,7 +4,7 @@ import sys
import torch
from torch import distributed as dist
from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear, Module
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

View File

@ -6,7 +6,7 @@ import torch
import torch.nn as nn
import torch.optim as optim
from torch import distributed as dist
from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
FSDPTest,

View File

@ -4,7 +4,7 @@ import sys
import torch
from torch import distributed as dist
from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear, Module
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD

View File

@ -4,7 +4,7 @@ import sys
import torch
from torch import distributed as dist
from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear, Module, Sequential
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

View File

@ -9,7 +9,7 @@ import torch
import torch.nn as nn
from torch import distributed as dist
from torch.cuda import Event
from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
FSDPTest,

View File

@ -4,7 +4,7 @@ import sys
import torch
from torch import distributed as dist
from torch.distributed._fsdp import FullyShardedDataParallel as FSDP, CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload
from torch.nn import Linear, Module
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD

View File

@ -4,7 +4,7 @@ import sys
import torch
from torch import distributed as dist
from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

View File

@ -6,7 +6,7 @@ import unittest
import torch
from torch import distributed as dist
from torch.distributed._fsdp.utils import (
from torch.distributed.fsdp.utils import (
_apply_to_tensors,
)
from torch.testing._internal.common_utils import (

View File

@ -9,12 +9,12 @@ import unittest
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._fsdp.fully_sharded_data_parallel import (
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullyShardedDataParallel as FSDP,
CPUOffload,
BackwardPrefetch_,
BackwardPrefetch,
)
from torch.distributed._fsdp.wrap import (
from torch.distributed.fsdp.wrap import (
default_auto_wrap_policy,
enable_wrap,
wrap,
@ -132,7 +132,7 @@ class TestFSDPWrap(FSDPTest):
)
@parametrize(
"backward_prefetch",
[BackwardPrefetch_.BACKWARD_POST, BackwardPrefetch_.BACKWARD_PRE]
[BackwardPrefetch.BACKWARD_POST, BackwardPrefetch.BACKWARD_PRE]
)
@parametrize(
"fsdp_init_mode",

View File

@ -2,7 +2,7 @@ from abc import ABC
import inspect
from typing import Dict, Type
from torch.distributed._fsdp import FullyShardedDataParallel
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.distributed.optim import as_functional_optim

View File

@ -36,11 +36,19 @@ if TYPE_CHECKING:
@dataclass
class CPUOffload:
"""
CPU offloading config. Currently, only parameter and gradient CPU
offload are supported.
offload_params: Offloading parameters to CPUs when these parameters are
not used for computation on GPUs. This implicitly enables gradient
offloading to CPUs as well, so that parameters and gradients end up on
the same device for the optimizer to work with.
"""
offload_params: bool = False
# TODO: state dict offloading, activation offloading
# TODO: state dict offloading
# https://github.com/pytorch/pytorch/issues/67224
class BackwardPrefetch_(Enum):
class BackwardPrefetch(Enum):
"""
Specify where to prefetch the next layer's full parameters
during the backward pass.
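For context, a hedged sketch of how the CPUOffload config above reaches the wrapper, based on the constructor signature shown later in this diff; the module and device setup are placeholders, and an initialized default process group is assumed:

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload

# Offloading parameters implicitly offloads gradients too, keeping parameters
# and gradients on the same device for the optimizer.
torch.cuda.set_device(0)  # placeholder device id
model = FSDP(nn.Linear(16, 16).cuda(), cpu_offload=CPUOffload(offload_params=True))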
@ -88,28 +96,34 @@ class FullyShardedDataParallel(nn.Module):
"""
A wrapper for sharding Module parameters across data parallel workers. This
is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
``FullyShardedDataParallel`` is commonly shortened to FSDP.
FullyShardedDataParallel is commonly shortened to FSDP.
.. _`Xu et al.`: https://arxiv.org/abs/2004.13336
.. _DeepSpeed: https://www.deepspeed.ai/
Example::
import torch
from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
torch.cuda.set_device(device_id)
sharded_module = FSDP(my_module)
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
x = sharded_module(x, y=3, z=torch.Tensor([1]))
loss = x.sum()
loss.backward()
optim.step()
>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()
.. warning::
The optimizer must be initialized *after* the module has been wrapped,
since FSDP will shard parameters in-place and this will break any
previously initialized optimizers.
.. warning::
The module should already be placed on the destination device, or the
device should be set properly using torch.cuda.set_device(device_id).
FSDP determines the compute device from the module first; if the module
device is CPU, it falls back to the current device.
Args:
module (nn.Module):
module to be wrapped with FSDP.
@ -128,11 +142,30 @@ class FullyShardedDataParallel(nn.Module):
Note that this policy currently only applies to child modules of
the passed-in module. The remaining modules are always wrapped in
the returned FSDP root instance.
backward_prefetch: (Optional[BackwardPrefetch_]):
``default_auto_wrap_policy`` in ``torch.distributed.fsdp.wrap`` is an
example of an ``fsdp_auto_wrap_policy`` callable; this policy wraps layers
whose parameter size is larger than 100M. Users can supply their own
``fsdp_auto_wrap_policy`` callable, which should accept the following
arguments: ``module: nn.Module``, ``recurse: bool``, and
``unwrapped_params: int``. Extra customized arguments can be added to
the ``fsdp_auto_wrap_policy`` callable as well.
Example::
>>> def custom_auto_wrap_policy(
>>> module: nn.Module,
>>> recurse: bool,
>>> unwrapped_params: int,
>>> # These are customizable for this policy function.
>>> min_num_params: int = int(1e8),
>>> ) -> bool:
>>> return unwrapped_params >= min_num_params
backward_prefetch (Optional[BackwardPrefetch]):
This is an experimental feature that is subject to change in the
near future. It allows users to enable two different backward prefetch
algorithms to help overlap backward communication and computation.
The pros and cons of each algorithm are explained in the class ``BackwardPrefetch_``.
The pros and cons of each algorithm are explained in the class ``BackwardPrefetch``.
"""
def __init__(
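Following the fsdp_auto_wrap_policy description in the hunk above, a sketch of wiring a policy into the constructor. default_auto_wrap_policy is imported from torch.distributed.fsdp.wrap exactly as the test changes in this diff do; the stacked Linear model is only a stand-in, and an initialized process group is again assumed:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import default_auto_wrap_policy

# Child modules whose unwrapped parameter count exceeds the policy threshold
# become nested FSDP instances; the remaining modules stay in the root wrapper.
big_model = nn.Sequential(*(nn.Linear(1024, 1024) for _ in range(4))).cuda()  # stand-in
sharded = FSDP(big_model, fsdp_auto_wrap_policy=default_auto_wrap_policy)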
@ -141,7 +174,7 @@ class FullyShardedDataParallel(nn.Module):
process_group: Optional[ProcessGroup] = None,
cpu_offload: Optional[CPUOffload] = None,
fsdp_auto_wrap_policy: Optional[Callable] = None,
backward_prefetch: Optional[BackwardPrefetch_] = None,
backward_prefetch: Optional[BackwardPrefetch] = None,
):
torch._C._log_api_usage_once("torch.distributed.fsdp")
super().__init__()
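Given the renamed BackwardPrefetch enum and the signature above, enabling prefetching is a single constructor argument; a sketch using the same import path as the tests in this diff, with a placeholder module (the docstring flags the feature as experimental):

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch

# BACKWARD_PRE prefetches the next layer's full parameters before the current
# layer's gradient computation; BACKWARD_POST prefetches after it completes.
model = FSDP(nn.Linear(16, 16).cuda(), backward_prefetch=BackwardPrefetch.BACKWARD_PRE)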
@ -566,7 +599,7 @@ class FullyShardedDataParallel(nn.Module):
def _need_prefetch_pre_backward_hook(self) -> bool:
if (
self.backward_prefetch == BackwardPrefetch_.BACKWARD_PRE
self.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
and self._fsdp_graph_order is not None
and self._my_fsdp_idx_in_graph is not None and self._my_fsdp_idx_in_graph > 0
and self._fsdp_graph_order[self._my_fsdp_idx_in_graph - 1].training_state != TrainingState_.BACKWARD_POST
@ -577,7 +610,7 @@ class FullyShardedDataParallel(nn.Module):
def _need_prefetch_post_backward_hook(self) -> bool:
if (
self.backward_prefetch == BackwardPrefetch_.BACKWARD_POST
self.backward_prefetch == BackwardPrefetch.BACKWARD_POST
and self._fsdp_graph_order is not None
and self._my_fsdp_idx_in_graph is not None and self._my_fsdp_idx_in_graph > 0
and self._fsdp_graph_order[self._my_fsdp_idx_in_graph - 1].training_state != TrainingState_.BACKWARD_POST

View File

@ -9,8 +9,8 @@ from unittest import mock
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._fsdp import FullyShardedDataParallel, CPUOffload
from torch.distributed._fsdp.fully_sharded_data_parallel import (
from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload
from torch.distributed.fsdp.fully_sharded_data_parallel import (
TrainingState_,
)
from torch.testing._internal.common_distributed import (