make fsdp folder to be public (#72084)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72084
make fsdp folder to be public
ghstack-source-id: 148173447
Test Plan: unit tests
Reviewed By: mrshenli
Differential Revision: D33903417
fbshipit-source-id: 7852a2adc4af09af48a5ffa52ebf210489f834d5
(cherry picked from commit bd06513cfe)
This commit is contained in:
parent ed435e903f
commit 2336571cb7
docs/source/fsdp.rst | 7 (new file)

@@ -0,0 +1,7 @@
+FullyShardedDataParallel
+========================
+
+.. automodule:: torch.distributed.fsdp
+
+.. autoclass:: torch.distributed.fsdp.FullyShardedDataParallel
+    :members:
@@ -61,6 +61,7 @@ Features described in this documentation are classified by release status:
    torch.distributed <distributed>
    torch.distributed.algorithms.join <distributed.algorithms.join>
    torch.distributed.elastic <distributed.elastic>
+   torch.distributed.fsdp <fsdp>
    torch.distributed.optim <distributed.optim>
    torch.distributions <distributions>
    torch.fft <fft>
@@ -5,7 +5,7 @@ import unittest

 import torch
 from torch import distributed as dist
-from torch.distributed._fsdp.flatten_params_wrapper import FlattenParamsWrapper
+from torch.distributed.fsdp.flatten_params_wrapper import FlattenParamsWrapper
 from torch.testing._internal.common_utils import run_tests, TestCase

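Only the import path changes throughout the test files; the wrapped classes and their APIs are untouched. A minimal migration sketch (an illustration, not part of the patch):

# Previously the package lived at the private path torch.distributed._fsdp;
# after this commit it is imported from the public location.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.flatten_params_wrapper import FlattenParamsWrapper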
@@ -6,7 +6,7 @@ from functools import partial

 import torch
 import torch.nn as nn
-from torch.distributed._fsdp.fully_sharded_data_parallel import (
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
     FullyShardedDataParallel as FSDP,
     CPUOffload,
 )

@@ -26,8 +26,8 @@ from torch.testing._internal.common_utils import (
     run_tests,
 )

-from torch.distributed._fsdp import CPUOffload
-from torch.distributed._fsdp.fully_sharded_data_parallel import BackwardPrefetch_
+from torch.distributed.fsdp import CPUOffload
+from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch


 if not dist.is_available():
@@ -69,7 +69,7 @@ class TestParityWithDDP(FSDPTest):
     )
     @parametrize(
         "backward_prefetch",
-        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
+        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
     )
     def test_nested_wrapped_model(self, cpu_offload, backward_prefetch):
         init_modes = self._get_init_modes_for_test(cpu_offload)

@@ -89,7 +89,7 @@ class TestParityWithDDP(FSDPTest):
     )
     @parametrize(
         "backward_prefetch",
-        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
+        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
     )
     def test_nested_all_wrapped_model(self, cpu_offload, backward_prefetch):
         init_modes = self._get_init_modes_for_test(cpu_offload)

@@ -110,7 +110,7 @@ class TestParityWithDDP(FSDPTest):
     )
     @parametrize(
         "backward_prefetch",
-        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
+        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
     )
     def test_transformer_parameterized(self, cpu_offload, backward_prefetch):
         init_modes = self._get_init_modes_for_test(cpu_offload)

@@ -130,7 +130,7 @@ class TestParityWithDDP(FSDPTest):
     )
     @parametrize(
         "backward_prefetch",
-        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
+        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
     )
     def test_delayed_optim_step(self, cpu_offload, backward_prefetch):
         # We use a model with a long CUDA delay right before the optimizer step.

@@ -156,7 +156,7 @@ class TestParityWithDDP(FSDPTest):
     )
     @parametrize(
         "backward_prefetch",
-        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
+        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
     )
     def test_delayed_reduce_scatter(self, cpu_offload, backward_prefetch):
         # We insert a delay in the torch.distributed._reduce_scatter_base op, so that

@@ -186,7 +186,7 @@ class TestParityWithDDP(FSDPTest):
     )
     @parametrize(
         "backward_prefetch",
-        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
+        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
     )
     def test_mixture_of_experts(self, cpu_offload, backward_prefetch):
         init_modes = self._get_init_modes_for_test(cpu_offload)

@@ -209,7 +209,7 @@ class TestParityWithDDP(FSDPTest):
     )
     @parametrize(
         "backward_prefetch",
-        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
+        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
     )
     def test_mixture_of_experts_with_delay_before_free(self, cpu_offload, backward_prefetch):
         init_modes = self._get_init_modes_for_test(cpu_offload)
@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn
 import torch.optim as optim
 from torch import distributed as dist
-from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn.parallel import DistributedDataParallel
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (

@@ -4,7 +4,7 @@ import sys

 import torch
 from torch import distributed as dist
-from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn import Linear, Module
 from torch.optim import SGD
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

@@ -6,7 +6,7 @@ import torch
 import torch.nn as nn
 import torch.optim as optim
 from torch import distributed as dist
-from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     FSDPTest,

@@ -4,7 +4,7 @@ import sys

 import torch
 from torch import distributed as dist
-from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn import Linear, Module
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim import SGD

@@ -4,7 +4,7 @@ import sys

 import torch
 from torch import distributed as dist
-from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn import Linear, Module, Sequential
 from torch.optim import SGD
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

@@ -9,7 +9,7 @@ import torch
 import torch.nn as nn
 from torch import distributed as dist
 from torch.cuda import Event
-from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     FSDPTest,

@@ -4,7 +4,7 @@ import sys

 import torch
 from torch import distributed as dist
-from torch.distributed._fsdp import FullyShardedDataParallel as FSDP, CPUOffload
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload
 from torch.nn import Linear, Module
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim import SGD

@@ -4,7 +4,7 @@ import sys

 import torch
 from torch import distributed as dist
-from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn import Linear
 from torch.optim import SGD
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

@@ -6,7 +6,7 @@ import unittest

 import torch
 from torch import distributed as dist
-from torch.distributed._fsdp.utils import (
+from torch.distributed.fsdp.utils import (
     _apply_to_tensors,
 )
 from torch.testing._internal.common_utils import (
@@ -9,12 +9,12 @@ import unittest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.distributed._fsdp.fully_sharded_data_parallel import (
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
     FullyShardedDataParallel as FSDP,
     CPUOffload,
-    BackwardPrefetch_,
+    BackwardPrefetch,
 )
-from torch.distributed._fsdp.wrap import (
+from torch.distributed.fsdp.wrap import (
     default_auto_wrap_policy,
     enable_wrap,
     wrap,
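The same rename covers the wrapping utilities (``default_auto_wrap_policy``, ``enable_wrap``, ``wrap``). A minimal sketch of using them from the public path (an illustration, not part of the patch; it assumes the helpers keep the fairscale-style interface in which ``enable_wrap`` takes a ``wrapper_cls`` and ``wrap`` reads it from the surrounding context, and that a default process group is already initialized):

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import enable_wrap, wrap

# Inside the context manager, wrap() applies the configured wrapper class,
# so the selected submodule becomes its own FSDP unit.
with enable_wrap(wrapper_cls=FSDP):
    sharded_linear = wrap(nn.Linear(16, 16))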
@@ -132,7 +132,7 @@ class TestFSDPWrap(FSDPTest):
     )
     @parametrize(
         "backward_prefetch",
-        [BackwardPrefetch_.BACKWARD_POST, BackwardPrefetch_.BACKWARD_PRE]
+        [BackwardPrefetch.BACKWARD_POST, BackwardPrefetch.BACKWARD_PRE]
     )
     @parametrize(
         "fsdp_init_mode",
@@ -2,7 +2,7 @@ from abc import ABC
 import inspect
 from typing import Dict, Type

-from torch.distributed._fsdp import FullyShardedDataParallel
+from torch.distributed.fsdp import FullyShardedDataParallel
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim import Optimizer
 from torch.distributed.optim import as_functional_optim
@@ -36,11 +36,19 @@ if TYPE_CHECKING:

 @dataclass
 class CPUOffload:
+    """
+    CPU offloading config. Currently, only parameter and gradient CPU
+    offload are supported.
+    offload_params: Offloading parameters to CPUs when these parameters are
+        not used for computation on GPUs. This implicitly enables
+        gradient offloading to CPUs in order for parameters and
+        gradients to be on the same device to work with the optimizer.
+    """
     offload_params: bool = False
-    # TODO: state dict offloading, activation offloading
+    # TODO: state dict offloading
     # https://github.com/pytorch/pytorch/issues/67224

-class BackwardPrefetch_(Enum):
+class BackwardPrefetch(Enum):
     """
     Specify where to prefetch next layer's full parameters
     during backward pass.
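A minimal sketch of passing the two options documented above to the wrapper (an illustration, not part of the patch; it assumes a default process group has been initialized and a GPU is available):

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch

model = nn.Linear(8, 8).cuda()
fsdp_model = FSDP(
    model,
    # Keep parameters (and hence gradients) on CPU while they are not needed on GPU.
    cpu_offload=CPUOffload(offload_params=True),
    # Prefetch the next layer's full parameters before its backward computation.
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
)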
@@ -88,28 +96,34 @@ class FullyShardedDataParallel(nn.Module):
     """
     A wrapper for sharding Module parameters across data parallel workers. This
     is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
-    ``FullyShardedDataParallel`` is commonly shorten to FSDP.
+    FullyShardedDataParallel is commonly shorten to FSDP.

     .. _`Xu et al.`: https://arxiv.org/abs/2004.13336
     .. _DeepSpeed: https://www.deepspeed.ai/

     Example::
-        import torch
-        from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
-        torch.cuda.set_device(device_id)
-        sharded_module = FSDP(my_module)
-        optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
-        x = sharded_module(x, y=3, z=torch.Tensor([1]))
-        loss = x.sum()
-        loss.backward()
-        optim.step()
+        >>> import torch
+        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+        >>> torch.cuda.set_device(device_id)
+        >>> sharded_module = FSDP(my_module)
+        >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
+        >>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
+        >>> loss = x.sum()
+        >>> loss.backward()
+        >>> optim.step()

     .. warning::
         The optimizer must be initialized *after* the module has been wrapped,
         since FSDP will shard parameters in-place and this will break any
         previously initialized optimizers.

     .. warning:
         Module should be already placed on the destination device or
         device is set properly using torch.cuda.set_device(device_id).
         FSDP will get compute device from module first, if module device
         is CPU, FSDP will then get compute device from current device.

     Args:
         module (nn.Module):
             module to be wrapped with FSDP.
@@ -128,11 +142,30 @@ class FullyShardedDataParallel(nn.Module):
             Note that this policy currently will only apply to child modules of
             the passed in module. The remainder modules are always wrapped in
             the returned FSDP root instance.
-        backward_prefetch: (Optional[BackwardPrefetch_]):
+            ``default_auto_wrap_policy`` written in ``torch.distributed.fsdp.wrap`` is
+            an example of a ``fsdp_auto_wrap_policy`` callable; this policy wraps layers
+            with parameter sizes larger than 100M. Users can supply a customized
+            ``fsdp_auto_wrap_policy`` callable that should accept the following arguments:
+            ``module: nn.Module``, ``recurse: bool``, ``unwrapped_params: int``;
+            extra customized arguments could be added to the customized
+            ``fsdp_auto_wrap_policy`` callable as well.
+
+            Example::
+
+                >>> def custom_auto_wrap_policy(
+                >>>     module: nn.Module,
+                >>>     recurse: bool,
+                >>>     unwrapped_params: int,
+                >>>     # These are customizable for this policy function.
+                >>>     min_num_params: int = int(1e8),
+                >>> ) -> bool:
+                >>>     return unwrapped_params >= min_num_params
+
+        backward_prefetch: (Optional[BackwardPrefetch]):
             This is an experimental feature that is subject to change in the
             the near future. It allows users to enable two different backward_prefetch
             algorithms to help backward communication and computation overlapping.
-            Pros and cons of each algorithm is explained in the class ``BackwardPrefetch_``.
+            Pros and cons of each algorithm is explained in the class ``BackwardPrefetch``.
     """

     def __init__(
@@ -141,7 +174,7 @@ class FullyShardedDataParallel(nn.Module):
         process_group: Optional[ProcessGroup] = None,
         cpu_offload: Optional[CPUOffload] = None,
         fsdp_auto_wrap_policy: Optional[Callable] = None,
-        backward_prefetch: Optional[BackwardPrefetch_] = None,
+        backward_prefetch: Optional[BackwardPrefetch] = None,
     ):
         torch._C._log_api_usage_once("torch.distributed.fsdp")
         super().__init__()
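A sketch of supplying an auto-wrap policy through the constructor above (an illustration, not part of the patch; it assumes ``default_auto_wrap_policy`` accepts a ``min_num_params`` override as in the fairscale version it derives from, and that a default process group is already initialized):

import functools
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import default_auto_wrap_policy

# Lower the size threshold so even small layers get their own FSDP unit.
policy = functools.partial(default_auto_wrap_policy, min_num_params=int(1e5))
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).cuda()
# Child modules matching the policy become nested FSDP instances; the remainder
# is wrapped by the returned root instance, as the docstring describes.
fsdp_model = FSDP(model, fsdp_auto_wrap_policy=policy)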
@@ -566,7 +599,7 @@ class FullyShardedDataParallel(nn.Module):

     def _need_prefetch_pre_backward_hook(self) -> bool:
         if (
-            self.backward_prefetch == BackwardPrefetch_.BACKWARD_PRE
+            self.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
             and self._fsdp_graph_order is not None
             and self._my_fsdp_idx_in_graph is not None and self._my_fsdp_idx_in_graph > 0
             and self._fsdp_graph_order[self._my_fsdp_idx_in_graph - 1].training_state != TrainingState_.BACKWARD_POST

@@ -577,7 +610,7 @@ class FullyShardedDataParallel(nn.Module):

     def _need_prefetch_post_backward_hook(self) -> bool:
         if (
-            self.backward_prefetch == BackwardPrefetch_.BACKWARD_POST
+            self.backward_prefetch == BackwardPrefetch.BACKWARD_POST
             and self._fsdp_graph_order is not None
             and self._my_fsdp_idx_in_graph is not None and self._my_fsdp_idx_in_graph > 0
             and self._fsdp_graph_order[self._my_fsdp_idx_in_graph - 1].training_state != TrainingState_.BACKWARD_POST
@@ -9,8 +9,8 @@ from unittest import mock
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._fsdp import FullyShardedDataParallel, CPUOffload
-from torch.distributed._fsdp.fully_sharded_data_parallel import (
+from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
     TrainingState_,
 )
 from torch.testing._internal.common_distributed import (