[CP][BE][2/2] Refactor the code structure (#166501)
Our CP codebase now contains several files, and we are adding more. This PR refactors the code to consolidate those files into a context_parallel folder while keeping the old import paths working, so existing users of CP won't be affected. Unfortunately, we had to split this change into two PRs: the PyTorch infra cannot accept a PR with a 3000+ LoC change, and git cannot recognize that _context_parallel/_attention.py was moved from _attention.py because we want to keep BC. This is the second PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166501
Approved by: https://github.com/Skylion007
ghstack dependencies: #166456
This commit is contained in:
parent 45c3f02d69
commit 7e3b9d105e
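The backward-compatibility approach described above (move the implementation into a _context_parallel folder but keep the old module path alive) is typically done with a thin re-export shim. Below is a minimal sketch of what the old _attention.py could look like after the move; the shim's exact contents and the re-exported names are assumptions pieced together from the commit message and the diff hunks, not the verbatim file:

# Hypothetical BC shim left at the old location, e.g.
# torch/distributed/tensor/experimental/_attention.py (sketch).
# The implementation now lives in _context_parallel/_attention.py; re-export
# the names so existing import sites keep working unchanged.
from ._context_parallel._attention import (  # noqa: F401
    _CausalBehavior,
    _cp_options,
    _is_causal_behavior,
    _RotateMethod,
    context_parallel,
    context_parallel_unshard,
    set_rotate_method,
)

With a shim like this in place, git sees the old path as a still-existing, rewritten file rather than a rename, which is exactly why the move cannot be recognized as one.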
File diff suppressed because it is too large
@@ -29,7 +29,19 @@ from ._cp_custom_ops import flex_cp_allgather
 from ._load_balancer import _create_default_load_balancer, _LoadBalancer
 
 
-__all__ = ["context_parallel", "set_rotate_method"]
+__all__ = [
+    "_CausalBehavior",
+    "_context_parallel_shard",
+    "_ContextParallel",
+    "_cp_options",
+    "_disable_context_parallel_dispatcher",
+    "_enable_context_parallel_dispatcher",
+    "_is_causal_behavior",
+    "_RotateMethod",
+    "context_parallel",
+    "context_parallel_unshard",
+    "set_rotate_method",
+]
 
 
 class _CausalBehavior(Enum):
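A side note on the widened __all__: Python's `from module import *` skips names that begin with an underscore unless __all__ lists them explicitly, so enumerating the private helpers here is what lets a star-import re-export shim carry them over. A self-contained toy demonstration (not PyTorch code):

import sys
import types

# Build a throwaway module with one underscore-prefixed and one public name,
# mirroring the structure of the __all__ list above.
mod = types.ModuleType("toy_cp")
exec(
    "__all__ = ['_cp_options', 'context_parallel']\n"
    "_cp_options = {'enable_load_balance': True}\n"
    "def context_parallel(): pass\n",
    mod.__dict__,
)
sys.modules["toy_cp"] = mod

from toy_cp import *  # noqa: E402,F403

# The underscore name came through only because __all__ names it.
assert "_cp_options" in dir()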
@@ -479,7 +479,7 @@ class _PTRRLoadBalancer(_LoadBalancer):
 def _create_default_load_balancer(
     seq_length: int, world_size: int, device: str | torch.device
 ) -> Optional[_LoadBalancer]:
-    from .._attention import _cp_options
+    from ._attention import _cp_options
 
     if _cp_options.enable_load_balance:
         return _HeadTailLoadBalancer(seq_length, world_size, device)
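The only substantive change in this hunk is the relative-import depth: after the refactor, the load-balancer module and _attention.py sit side by side inside the _context_parallel package, so the parent-package hop (`..`) is no longer needed. For context, here is a minimal self-contained sketch of the factory's dispatch pattern as visible in this hunk; the stand-in classes, the option object's shape, and the final None fallback are assumptions, while the names and signatures come from the diff:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class _CPOptions:
    # Stand-in for the real _cp_options; the field name comes from the diff.
    enable_load_balance: bool = True


_cp_options = _CPOptions()


class _LoadBalancer:
    """Stand-in base class; the real one defines the balancing interface."""


class _HeadTailLoadBalancer(_LoadBalancer):
    # Constructor signature inferred from the call site in the diff.
    def __init__(
        self, seq_length: int, world_size: int, device: str | torch.device
    ) -> None:
        self.seq_length = seq_length
        self.world_size = world_size
        self.device = device


def _create_default_load_balancer(
    seq_length: int, world_size: int, device: str | torch.device
) -> Optional[_LoadBalancer]:
    # Head-tail balancing splits each rank's shard across the head and tail
    # of the sequence to even out causal-attention work across ranks.
    if _cp_options.enable_load_balance:
        return _HeadTailLoadBalancer(seq_length, world_size, device)
    return None  # assumed fallback; the rest of the function is not shown


balancer = _create_default_load_balancer(4096, world_size=8, device="cpu")
assert isinstance(balancer, _HeadTailLoadBalancer)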