Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[CP][BE][1/2] Refactor the code structure (#166456)
Our context parallel (CP) codebase now spans several files, and more are being added. This PR refactors the code to consolidate those files into a _context_parallel folder while keeping the existing import paths, so current users of CP are not affected. Unfortunately, we have to split the work into two PRs: the PyTorch infra cannot accept a PR with a 3000+ LoC change, and git cannot recognize that _context_parallel/_attention.py is moved from _attention.py because we want to keep backward compatibility (BC).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166456
Approved by: https://github.com/Skylion007
This commit is contained in:
parent ad3a56ab98
commit 56838bad5f
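To make the backward-compatibility claim concrete, here is a hedged, user-level sketch (not part of the PR): on a PyTorch build that already contains this refactor, code importing CP from the old module path should keep working, and the new consolidated package exposes the same names.

```python
# Hedged sketch, not from the PR: both import paths below are expected to
# resolve on a PyTorch build that includes this refactor.

# Old path that existing CP users already rely on; deliberately kept working.
from torch.distributed.tensor.experimental._attention import (
    context_parallel,
    set_rotate_method,
)

# New consolidated package introduced by this PR, exposing the same names.
from torch.distributed.tensor.experimental import _context_parallel as _cp

assert callable(context_parallel) and callable(set_rotate_method)
assert callable(_cp.context_parallel) and callable(_cp.set_rotate_method)
```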
```diff
@@ -20,18 +20,18 @@ from torch.distributed.tensor.experimental._attention import (
     _cp_options,
     _disable_context_parallel_dispatcher,
     _enable_context_parallel_dispatcher,
+    _HeadTailLoadBalancer,
     _is_causal_behavior,
+    _LoadBalancer,
+    _PerDocumentHeadTailLoadBalancer,
+    _PTRRLoadBalancer,
     _RotateMethod,
     context_parallel,
     context_parallel_unshard,
     set_rotate_method,
 )
-from torch.distributed.tensor.experimental._cp_custom_ops import flex_cp_allgather
-from torch.distributed.tensor.experimental._load_balancer import (
-    _HeadTailLoadBalancer,
-    _LoadBalancer,
-    _PerDocumentHeadTailLoadBalancer,
-    _PTRRLoadBalancer,
+from torch.distributed.tensor.experimental._context_parallel._cp_custom_ops import (
+    flex_cp_allgather,
 )
 from torch.distributed.tensor.parallel import parallelize_module
 from torch.nn.attention import sdpa_kernel, SDPBackend
```
```diff
@@ -17,9 +17,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import distribute_tensor, DTensor, Shard
-from torch.distributed.tensor.experimental._load_balancer import (
+from torch.distributed.tensor.experimental._context_parallel._load_balancer import (
     _create_default_load_balancer,
     _HeadTailLoadBalancer,
     _LoadBalancer,
     _PerDocumentHeadTailLoadBalancer,
     _PTRRLoadBalancer,
 )
 from torch.distributed.tensor.parallel import ParallelStyle
 from torch.nn.attention.flex_attention import (
```
```diff
@@ -29,10 +32,26 @@ from torch.nn.attention.flex_attention import (
 )
 from torch.utils._pytree import tree_flatten, tree_unflatten

-from ._cp_custom_ops import flex_cp_allgather
+from ._context_parallel._cp_custom_ops import flex_cp_allgather


-__all__ = ["context_parallel", "set_rotate_method"]
+__all__ = [
+    "_CausalBehavior",
+    "_context_parallel_shard",
+    "_ContextParallel",
+    "_cp_options",
+    "_disable_context_parallel_dispatcher",
+    "_enable_context_parallel_dispatcher",
+    "_is_causal_behavior",
+    "_RotateMethod",
+    "context_parallel",
+    "context_parallel_unshard",
+    "set_rotate_method",
+    "_HeadTailLoadBalancer",
+    "_LoadBalancer",
+    "_PerDocumentHeadTailLoadBalancer",
+    "_PTRRLoadBalancer",
+]


 class _CausalBehavior(Enum):
```
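One detail worth spelling out: the old module's __all__ grows from two entries to the full re-exported surface, which is what keeps wildcard imports and the documented public names stable for existing users. As a reminder of what __all__ controls, here is a small standalone illustration (not from the PR) using a throwaway module name:

```python
import sys
import types

# Standalone illustration (not from the PR) of what __all__ controls: a
# wildcard import only picks up the names listed in __all__, including
# underscore-prefixed names that "import *" would otherwise skip.
mod = types.ModuleType("fake_attention")
exec(
    "__all__ = ['context_parallel', '_cp_options']\n"
    "def context_parallel():\n"
    "    return 'public'\n"
    "_cp_options = object()\n"
    "def _not_exported():\n"
    "    return 'hidden from *'\n",
    mod.__dict__,
)
sys.modules["fake_attention"] = mod

ns = {}
exec("from fake_attention import *", ns)
print("context_parallel" in ns)  # True  -- listed in __all__
print("_cp_options" in ns)       # True  -- underscore name, but listed
print("_not_exported" in ns)     # False -- not listed in __all__
```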
```diff
@@ -0,0 +1,46 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Context Parallel components
+
+from ._attention import (
+    _CausalBehavior,
+    _context_parallel_shard,
+    _ContextParallel,
+    _cp_options,
+    _disable_context_parallel_dispatcher,
+    _enable_context_parallel_dispatcher,
+    _is_causal_behavior,
+    _RotateMethod,
+    context_parallel,
+    context_parallel_unshard,
+    set_rotate_method,
+)
+from ._cp_custom_ops import flex_cp_allgather
+from ._load_balancer import (
+    _HeadTailLoadBalancer,
+    _LoadBalancer,
+    _PerDocumentHeadTailLoadBalancer,
+    _PTRRLoadBalancer,
+)
+
+
+__all__ = [
+    # From _attention
+    "_CausalBehavior",
+    "_context_parallel_shard",
+    "_ContextParallel",
+    "_cp_options",
+    "_disable_context_parallel_dispatcher",
+    "_enable_context_parallel_dispatcher",
+    "_is_causal_behavior",
+    "_RotateMethod",
+    "context_parallel",
+    "context_parallel_unshard",
+    "set_rotate_method",
+    # From _cp_custom_ops
+    "flex_cp_allgather",
+    # From _load_balancer
+    "_HeadTailLoadBalancer",
+    "_LoadBalancer",
+    "_PerDocumentHeadTailLoadBalancer",
+    "_PTRRLoadBalancer",
+]
```
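Assuming the new file above is the package initializer for the _context_parallel folder (the diff view does not show file names), it gives the consolidated code a single import surface covering the attention entry points, the flex-attention custom op, and the load-balancer classes. A hedged usage sketch:

```python
# Hedged sketch, assuming the new file above is
# torch/distributed/tensor/experimental/_context_parallel/__init__.py
# (the diff view does not show the file name).
from torch.distributed.tensor.experimental._context_parallel import (
    _HeadTailLoadBalancer,
    _PerDocumentHeadTailLoadBalancer,
    context_parallel,
    flex_cp_allgather,
    set_rotate_method,
)

# One import site now covers the attention API, the custom op, and the
# load balancers that previously lived in three separate modules.
print(context_parallel, set_rotate_method)
print(flex_cp_allgather, _HeadTailLoadBalancer, _PerDocumentHeadTailLoadBalancer)
```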
File diff suppressed because it is too large.
```diff
@@ -479,7 +479,7 @@ class _PTRRLoadBalancer(_LoadBalancer):
 def _create_default_load_balancer(
     seq_length: int, world_size: int, device: str | torch.device
 ) -> Optional[_LoadBalancer]:
-    from torch.distributed.tensor.experimental._attention import _cp_options
+    from .._attention import _cp_options

     if _cp_options.enable_load_balance:
         return _HeadTailLoadBalancer(seq_length, world_size, device)
```
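From the lines in this hunk, _create_default_load_balancer now reaches _cp_options through a relative import inside the new package, and returns a _HeadTailLoadBalancer when load balancing is enabled. A hedged sketch of calling it, based only on what the hunk shows:

```python
# Hedged sketch based only on the hunk above: with load balancing enabled the
# helper is expected to return a _HeadTailLoadBalancer sized for the sequence
# and the CP world size; the disabled branch is outside the hunk and
# presumably yields None.  Requires a build that contains this refactor.
from torch.distributed.tensor.experimental._context_parallel._load_balancer import (
    _create_default_load_balancer,
)

lb = _create_default_load_balancer(seq_length=8192, world_size=8, device="cpu")
print(type(lb).__name__ if lb is not None else "load balancing disabled")
```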