For pipeline parallelism, we can have multiple FSDP roots (one per chunk):

```
model = nn.Sequential(chunk0, chunk1)
fully_shard(model[0])
fully_shard(model[1])
```

We can then call `share_comm_ctx` to share the all-gather, reduce-scatter, and all-reduce CUDA streams across the roots, which avoids inter-stream memory fragmentation:

```
from torch.distributed.fsdp import share_comm_ctx

share_comm_ctx([model[0], model[1]])
```

Unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165024
Approved by: https://github.com/mori360
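A minimal end-to-end sketch of the pattern above (the process-group setup, layer sizes, and the two-chunk split are illustrative assumptions, not part of this file or PR):

```
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, share_comm_ctx

# Assumes a torchrun-style launch with one GPU per rank.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Two pipeline-parallel chunks (illustrative sizes), each sharded as its own FSDP root.
chunk0 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())
chunk1 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())
model = nn.Sequential(chunk0, chunk1).cuda()
fully_shard(model[0])
fully_shard(model[1])

# Share the all-gather / reduce-scatter / all-reduce streams between the two
# roots to avoid inter-stream memory fragmentation.
share_comm_ctx([model[0], model[1]])
```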
from ._flat_param import FlatParameter as FlatParameter
from ._fully_shard import (
    CPUOffloadPolicy,
    FSDPModule,
    fully_shard,
    MixedPrecisionPolicy,
    OffloadPolicy,
    register_fsdp_forward_method,
    share_comm_ctx,
    UnshardHandle,
)
from .fully_sharded_data_parallel import (
    BackwardPrefetch,
    CPUOffload,
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel,
    LocalOptimStateDictConfig,
    LocalStateDictConfig,
    MixedPrecision,
    OptimStateDictConfig,
    OptimStateKeyType,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    ShardingStrategy,
    StateDictConfig,
    StateDictSettings,
    StateDictType,
)


__all__ = [
    # FSDP1
    "BackwardPrefetch",
    "CPUOffload",
    "FullOptimStateDictConfig",
    "FullStateDictConfig",
    "FullyShardedDataParallel",
    "LocalOptimStateDictConfig",
    "LocalStateDictConfig",
    "MixedPrecision",
    "OptimStateDictConfig",
    "OptimStateKeyType",
    "ShardedOptimStateDictConfig",
    "ShardedStateDictConfig",
    "ShardingStrategy",
    "StateDictConfig",
    "StateDictSettings",
    "StateDictType",
    # FSDP2
    "CPUOffloadPolicy",
    "FSDPModule",
    "fully_shard",
    "MixedPrecisionPolicy",
    "OffloadPolicy",
    "register_fsdp_forward_method",
    "UnshardHandle",
    "share_comm_ctx",
]


# Set namespace for exposed private names
CPUOffloadPolicy.__module__ = "torch.distributed.fsdp"
FSDPModule.__module__ = "torch.distributed.fsdp"
fully_shard.__module__ = "torch.distributed.fsdp"
MixedPrecisionPolicy.__module__ = "torch.distributed.fsdp"
OffloadPolicy.__module__ = "torch.distributed.fsdp"
register_fsdp_forward_method.__module__ = "torch.distributed.fsdp"
UnshardHandle.__module__ = "torch.distributed.fsdp"
share_comm_ctx.__module__ = "torch.distributed.fsdp"