pytorch/torch/distributed/fsdp/__init__.py
Wei Feng 6918f17114 [FSDP2] provide public API to share cuda streams across roots (#165024)
For pipeline parallelism, we can have multiple FSDP roots (one per pipeline chunk):
```
model = nn.Sequential(chunk0, chunk1)
fully_shard(model[0])
fully_shard(model[1])
```

We can then call `share_comm_ctx` to share the all-gather, reduce-scatter, and all-reduce CUDA streams across the roots, which avoids inter-stream memory fragmentation:
```
from torch.distributed.fsdp import share_comm_ctx
share_comm_ctx([model[0], model[1]])
```
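
For reference, a minimal end-to-end sketch of the intended usage. The process-group setup and the toy `nn.Linear` chunks are illustrative assumptions, not part of this PR; run under `torchrun` with one CUDA device per rank:
```
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, share_comm_ctx

# Launch with e.g. `torchrun --nproc-per-node=2 example.py`.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Toy 2-chunk "pipeline" model; layer sizes are illustrative only.
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).cuda()

# Each chunk becomes its own FSDP root.
fully_shard(model[0])
fully_shard(model[1])

# Share the all-gather / reduce-scatter / all-reduce CUDA streams across both
# roots instead of letting each root allocate its own set of streams.
share_comm_ctx([model[0], model[1]])

dist.destroy_process_group()
```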

Unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165024
Approved by: https://github.com/mori360
2025-10-14 17:50:46 +00:00

from ._flat_param import FlatParameter as FlatParameter
from ._fully_shard import (
    CPUOffloadPolicy,
    FSDPModule,
    fully_shard,
    MixedPrecisionPolicy,
    OffloadPolicy,
    register_fsdp_forward_method,
    share_comm_ctx,
    UnshardHandle,
)
from .fully_sharded_data_parallel import (
    BackwardPrefetch,
    CPUOffload,
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel,
    LocalOptimStateDictConfig,
    LocalStateDictConfig,
    MixedPrecision,
    OptimStateDictConfig,
    OptimStateKeyType,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    ShardingStrategy,
    StateDictConfig,
    StateDictSettings,
    StateDictType,
)

__all__ = [
    # FSDP1
    "BackwardPrefetch",
    "CPUOffload",
    "FullOptimStateDictConfig",
    "FullStateDictConfig",
    "FullyShardedDataParallel",
    "LocalOptimStateDictConfig",
    "LocalStateDictConfig",
    "MixedPrecision",
    "OptimStateDictConfig",
    "OptimStateKeyType",
    "ShardedOptimStateDictConfig",
    "ShardedStateDictConfig",
    "ShardingStrategy",
    "StateDictConfig",
    "StateDictSettings",
    "StateDictType",
    # FSDP2
    "CPUOffloadPolicy",
    "FSDPModule",
    "fully_shard",
    "MixedPrecisionPolicy",
    "OffloadPolicy",
    "register_fsdp_forward_method",
    "UnshardHandle",
    "share_comm_ctx",
]
# Set namespace for exposed private names
CPUOffloadPolicy.__module__ = "torch.distributed.fsdp"
FSDPModule.__module__ = "torch.distributed.fsdp"
fully_shard.__module__ = "torch.distributed.fsdp"
MixedPrecisionPolicy.__module__ = "torch.distributed.fsdp"
OffloadPolicy.__module__ = "torch.distributed.fsdp"
register_fsdp_forward_method.__module__ = "torch.distributed.fsdp"
UnshardHandle.__module__ = "torch.distributed.fsdp"
share_comm_ctx.__module__ = "torch.distributed.fsdp"