[CP][BE][1/2] Refactor the code structure (#166456)

Our CP codebase now contains several files, and we are adding more. This PR refactors the code to consolidate those files into a context_parallel folder while keeping the existing import paths, so current users of CP won't be affected.

Unfortunately, we have to split this change into two PRs: the PyTorch infra cannot accept a PR with a 3000+ LoC change, and git cannot recognize that _context_parallel/_attention.py is moved from _attention.py because we keep the original _attention.py for BC.
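
To make the BC intent concrete, here is a minimal sketch based only on the import paths that appear in the diffs below (it is not an exhaustive list of the exported names):

# Existing import path keeps working; _attention still exposes the CP API.
from torch.distributed.tensor.experimental._attention import (
    context_parallel,
    set_rotate_method,
)

# The same names are also re-exported from the new consolidated package
# (aliased here only so the sketch stays runnable in a single file).
from torch.distributed.tensor.experimental._context_parallel import (
    context_parallel as context_parallel_new,
    set_rotate_method as set_rotate_method_new,
)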

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166456
Approved by: https://github.com/Skylion007
Authored by Chien-Chin Huang on 2025-10-30 09:43:01 -07:00, committed by PyTorch MergeBot
parent ad3a56ab98
commit 56838bad5f
6 changed files with 1722 additions and 10 deletions


@@ -20,18 +20,18 @@ from torch.distributed.tensor.experimental._attention import (
     _cp_options,
     _disable_context_parallel_dispatcher,
     _enable_context_parallel_dispatcher,
+    _HeadTailLoadBalancer,
     _is_causal_behavior,
+    _LoadBalancer,
+    _PerDocumentHeadTailLoadBalancer,
+    _PTRRLoadBalancer,
     _RotateMethod,
     context_parallel,
     context_parallel_unshard,
     set_rotate_method,
 )
-from torch.distributed.tensor.experimental._cp_custom_ops import flex_cp_allgather
-from torch.distributed.tensor.experimental._load_balancer import (
-    _HeadTailLoadBalancer,
-    _LoadBalancer,
-    _PerDocumentHeadTailLoadBalancer,
-    _PTRRLoadBalancer,
+from torch.distributed.tensor.experimental._context_parallel._cp_custom_ops import (
+    flex_cp_allgather,
 )
 from torch.distributed.tensor.parallel import parallelize_module
 from torch.nn.attention import sdpa_kernel, SDPBackend


@@ -17,9 +17,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import distribute_tensor, DTensor, Shard
-from torch.distributed.tensor.experimental._load_balancer import (
+from torch.distributed.tensor.experimental._context_parallel._load_balancer import (
     _create_default_load_balancer,
+    _HeadTailLoadBalancer,
     _LoadBalancer,
+    _PerDocumentHeadTailLoadBalancer,
+    _PTRRLoadBalancer,
 )
 from torch.distributed.tensor.parallel import ParallelStyle
 from torch.nn.attention.flex_attention import (
@@ -29,10 +32,26 @@ from torch.nn.attention.flex_attention import (
 )
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
-from ._cp_custom_ops import flex_cp_allgather
+from ._context_parallel._cp_custom_ops import flex_cp_allgather
 
 
-__all__ = ["context_parallel", "set_rotate_method"]
+__all__ = [
+    "_CausalBehavior",
+    "_context_parallel_shard",
+    "_ContextParallel",
+    "_cp_options",
+    "_disable_context_parallel_dispatcher",
+    "_enable_context_parallel_dispatcher",
+    "_is_causal_behavior",
+    "_RotateMethod",
+    "context_parallel",
+    "context_parallel_unshard",
+    "set_rotate_method",
+    "_HeadTailLoadBalancer",
+    "_LoadBalancer",
+    "_PerDocumentHeadTailLoadBalancer",
+    "_PTRRLoadBalancer",
+]
 
 
 class _CausalBehavior(Enum):


@@ -0,0 +1,46 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Context Parallel components
+
+from ._attention import (
+    _CausalBehavior,
+    _context_parallel_shard,
+    _ContextParallel,
+    _cp_options,
+    _disable_context_parallel_dispatcher,
+    _enable_context_parallel_dispatcher,
+    _is_causal_behavior,
+    _RotateMethod,
+    context_parallel,
+    context_parallel_unshard,
+    set_rotate_method,
+)
+from ._cp_custom_ops import flex_cp_allgather
+from ._load_balancer import (
+    _HeadTailLoadBalancer,
+    _LoadBalancer,
+    _PerDocumentHeadTailLoadBalancer,
+    _PTRRLoadBalancer,
+)
+
+
+__all__ = [
+    # From _attention
+    "_CausalBehavior",
+    "_context_parallel_shard",
+    "_ContextParallel",
+    "_cp_options",
+    "_disable_context_parallel_dispatcher",
+    "_enable_context_parallel_dispatcher",
+    "_is_causal_behavior",
+    "_RotateMethod",
+    "context_parallel",
+    "context_parallel_unshard",
+    "set_rotate_method",
+    # From _cp_custom_ops
+    "flex_cp_allgather",
+    # From _load_balancer
+    "_HeadTailLoadBalancer",
+    "_LoadBalancer",
+    "_PerDocumentHeadTailLoadBalancer",
+    "_PTRRLoadBalancer",
+]

File diff suppressed because it is too large.


@@ -479,7 +479,7 @@ class _PTRRLoadBalancer(_LoadBalancer):
 def _create_default_load_balancer(
     seq_length: int, world_size: int, device: str | torch.device
 ) -> Optional[_LoadBalancer]:
-    from torch.distributed.tensor.experimental._attention import _cp_options
+    from .._attention import _cp_options
 
     if _cp_options.enable_load_balance:
         return _HeadTailLoadBalancer(seq_length, world_size, device)