[CP][BE][2/2] Refactor the code structure (#166501)

Our CP codebase now contains several files, and we keep adding more. This
PR refactors the code to consolidate those files into a _context_parallel
folder while keeping the old import paths, so existing users of CP are not
affected.
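
The backward-compatible import path can be kept by leaving the old module in
place as a thin re-export shim. A minimal sketch, assuming the old
_attention.py simply re-exports the implementation that moved into the
_context_parallel folder (the shim's exact contents are not shown in this
diff):

    # _attention.py at the old location -- hypothetical BC shim
    # Re-export the moved implementation so existing imports keep resolving.
    from ._context_parallel._attention import *  # noqa: F401,F403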

Unfortunately, we have to split this change into two PRs: the PyTorch
infra cannot accept a PR with a 3000+ LoC change, and git cannot recognize
that _context_parallel/_attention.py is moved from _attention.py because
the original _attention.py is kept in place to preserve BC.

This is the second PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166501
Approved by: https://github.com/Skylion007
ghstack dependencies: #166456
Author: Chien-Chin Huang
Date: 2025-10-30 09:43:01 -07:00 (committed by PyTorch MergeBot)
Parent: 45c3f02d69
Commit: 7e3b9d105e
3 changed files with 30 additions and 1646 deletions

(File diff suppressed because it is too large)


@@ -29,7 +29,19 @@ from ._cp_custom_ops import flex_cp_allgather
 from ._load_balancer import _create_default_load_balancer, _LoadBalancer
-__all__ = ["context_parallel", "set_rotate_method"]
+__all__ = [
+    "_CausalBehavior",
+    "_context_parallel_shard",
+    "_ContextParallel",
+    "_cp_options",
+    "_disable_context_parallel_dispatcher",
+    "_enable_context_parallel_dispatcher",
+    "_is_causal_behavior",
+    "_RotateMethod",
+    "context_parallel",
+    "context_parallel_unshard",
+    "set_rotate_method",
+]
 class _CausalBehavior(Enum):
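
Note that a star import only picks up underscore-prefixed names when the target
module lists them in __all__, which is presumably why the private names are
added above. A hedged sanity check (the torch.distributed.tensor.experimental
paths are assumptions, not taken from this diff):

    # Old and new module paths should resolve to the same objects after the move.
    from torch.distributed.tensor.experimental import _attention as old_mod
    from torch.distributed.tensor.experimental._context_parallel import _attention as new_mod
    assert old_mod.context_parallel is new_mod.context_parallel
    assert old_mod._cp_options is new_mod._cp_options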


@@ -479,7 +479,7 @@ class _PTRRLoadBalancer(_LoadBalancer):
 def _create_default_load_balancer(
     seq_length: int, world_size: int, device: str | torch.device
 ) -> Optional[_LoadBalancer]:
-    from .._attention import _cp_options
+    from ._attention import _cp_options
 
     if _cp_options.enable_load_balance:
         return _HeadTailLoadBalancer(seq_length, world_size, device)
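
The one-character import fix above follows from the new layout: _load_balancer.py
now sits next to _attention.py inside the _context_parallel folder, so its
sibling is reached with a single-dot relative import instead of going up one
package level. Roughly (layout assumed from the PR description, illustrative
only):

    # experimental/
    #     _attention.py              # old path, kept so existing imports still work
    #     _context_parallel/
    #         _attention.py          # defines _cp_options
    #         _load_balancer.py      # this hunk: `from ._attention import _cp_options`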