pytorch/test/distributed/tensor/test_attention.py

# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import itertools
import random
import unittest
from typing import Any, Callable, ClassVar, Optional
import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
import torch.nn.functional as F
from torch import Tensor
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.experimental._attention import (
_CausalBehavior,
_context_parallel_shard,
_ContextParallel,
_cp_options,
_disable_context_parallel_dispatcher,
_enable_context_parallel_dispatcher,
_HeadTailLoadBalancer,
_is_causal_behavior,
_LoadBalancer,
_PerDocumentHeadTailLoadBalancer,
_PTRRLoadBalancer,
_RotateMethod,
context_parallel,
context_parallel_unshard,
set_rotate_method,
)
from torch.distributed.tensor.experimental._context_parallel._cp_custom_ops import (
flex_cp_allgather,
)
from torch.distributed.tensor.parallel import parallelize_module
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.attention.flex_attention import (
_mask_mod_signature,
AuxOutput,
AuxRequest,
BlockMask,
create_block_mask,
flex_attention,
)
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_CUDNN_ATTENTION,
PLATFORM_SUPPORTS_FLASH_ATTENTION,
PLATFORM_SUPPORTS_FUSED_ATTENTION,
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
DTensorTestBase,
map_local_tensor_for_rank,
with_comms,
)
c10d_functional = torch.ops.c10d_functional
backends = []
if PLATFORM_SUPPORTS_FLASH_ATTENTION:
backends.append(SDPBackend.FLASH_ATTENTION)
if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
backends.append(SDPBackend.EFFICIENT_ATTENTION)
if PLATFORM_SUPPORTS_CUDNN_ATTENTION:
backends.append(SDPBackend.CUDNN_ATTENTION)
rotater_enum_to_str = {
_RotateMethod.ALL_GATHER: "allgather",
_RotateMethod.ALL_TO_ALL: "alltoall",
} # mapping from _RotateMethod enum to string
class SDPAWrapper(torch.nn.Module):
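    """Thin wrapper around F.scaled_dot_product_attention that pins a specific
    SDPA backend and optionally runs the op through torch.compile."""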
def __init__(self, compiled: bool, backend: SDPBackend) -> None:
super().__init__()
if compiled:
self.sdpa = torch.compile(
F.scaled_dot_product_attention,
fullgraph=True,
backend="aot_eager",
)
else:
self.sdpa = F.scaled_dot_product_attention
self.backend = backend
def forward(self, *args: object, **kwargs: object) -> torch.Tensor:
with sdpa_kernel(self.backend):
return self.sdpa(*args, **kwargs)
class RingAttentionTest(DTensorTestBase):
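    """End-to-end correctness tests for ring attention (context-parallel SDPA),
    comparing sharded outputs and gradients against a single-rank baseline."""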
@property
def world_size(self) -> int:
return torch.cuda.device_count()
@property
def destroy_pg_upon_exit(self) -> bool:
return False
@skip_if_lt_x_gpu(2)
@skipIfRocm # Missing _c10d_functional_autograd::all_to_all_single
@unittest.skipIf(
not PLATFORM_SUPPORTS_FUSED_ATTENTION,
"Does not support flash nor efficient attention",
)
@with_comms
def test_ring_attention_sdpa(self) -> None:
self.run_subtests(
{
"is_causal": [True, False],
"compiled": [True, False],
"backend": backends,
"load_balance": [True, False],
"rotater": [_RotateMethod.ALL_TO_ALL, _RotateMethod.ALL_GATHER],
"test_forward_only": [True, False],
"use_context": [True, False],
},
self._test_ring_attention_sdpa,
)
def _ring_attention_sdpa(
self,
cp_q: torch.Tensor,
cp_k: torch.Tensor,
cp_v: torch.Tensor,
*,
fn_eval: Callable,
mesh: DeviceMesh,
seq_dim: int,
is_causal: bool,
compiled: bool,
backend: SDPBackend,
rotater: _RotateMethod,
test_forward_only: bool,
load_balance: bool,
use_context: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if not use_context:
cp_plan = _ContextParallel(
seq_dim=seq_dim,
attention_type=_ContextParallel.AttentionType.SDPA,
)
attention = SDPAWrapper(compiled=compiled, backend=backend)
attention = parallelize_module(attention, mesh, cp_plan)
if load_balance:
seq_len = cp_q.size(seq_dim)
load_balancer = _HeadTailLoadBalancer(seq_len, mesh.size(), cp_q.device)
else:
load_balancer = None
cp_q, cp_k, cp_v = _context_parallel_shard(
mesh, (cp_q, cp_k, cp_v), (seq_dim,) * 3, load_balancer=load_balancer
)
_enable_context_parallel_dispatcher()
else:
            # Theoretically, context_parallel() should not be used to shard
            # parameters because resize_ is not allowed when requires_grad is
            # True. But requires_grad of cp_q, cp_k, and cp_v is False here,
            # so we can use context_parallel() to shard q, k, and v. In
            # practice, context_parallel() should only be used to shard the
            # model inputs (batch).
_cp_options.enable_load_balance = load_balance
cp_context = context_parallel(
mesh, buffers=(cp_q, cp_k, cp_v), buffer_seq_dims=(seq_dim,) * 3
)
cp_context.__enter__()
# NOTE: This demonstrates that monkey patching is not fully reliable.
# If we use SDPAWrapper directly, the monkey patching dispatch mode
# does not function correctly. To ensure proper behavior,
# F.scaled_dot_product_attention must be referenced within the
# context_parallel() scope.
attention = F.scaled_dot_product_attention
if compiled:
attention = torch.compile(
attention, fullgraph=True, backend="aot_eager"
)
for target in [cp_q, cp_k, cp_v]:
target.requires_grad = True
with CommDebugMode() as comm_mode:
with sdpa_kernel(backend):
cp_out = fn_eval(
attention,
cp_q,
cp_k,
cp_v,
is_causal=is_causal,
)
if not compiled and rotater == _RotateMethod.ALL_TO_ALL:
# Compiler and CommDebugMode do not work well together.
expect_all2all_count = (
self.world_size - 1
if test_forward_only
else self.world_size * 3 - 2
)
self.assertDictEqual(
comm_mode.get_comm_counts(),
{c10d_functional.all_to_all_single: expect_all2all_count},
)
cp_dq, cp_dk, cp_dv = cp_q.grad, cp_k.grad, cp_v.grad
for target in [cp_q, cp_k, cp_v]:
target.requires_grad = False
if not use_context:
_disable_context_parallel_dispatcher()
else:
cp_context.__exit__(None, None, None)
return cp_out, cp_dq, cp_dk, cp_dv
def _test_ring_attention_sdpa(
self,
is_causal: bool,
compiled: bool,
backend: SDPBackend,
load_balance: bool,
rotater: _RotateMethod,
test_forward_only: bool,
use_context: bool,
) -> None:
def fn_eval(fn, *args, **kwargs):
if test_forward_only:
with torch.no_grad():
return fn(*args, **kwargs)
else:
out = fn(*args, **kwargs)
out.sum().backward()
return out
if load_balance and not is_causal:
return
set_rotate_method(rotater_enum_to_str[rotater])
self.assertEqual(_cp_options.rotate_method, rotater)
device_mesh = DeviceMesh(self.device_type, torch.arange(0, self.world_size))
dtype = torch.bfloat16
bs = 8
seq_length = 1024
seq_dim = 2
dim = 32
nheads = 8
torch.manual_seed(10)
dtype = (
torch.bfloat16
if backend == SDPBackend.FLASH_ATTENTION
or backend == SDPBackend.CUDNN_ATTENTION
else torch.float32
)
q, k, v = [
torch.rand(
(bs, nheads, seq_length * self.world_size, dim),
device=self.device_type,
dtype=dtype,
requires_grad=True,
)
for _ in range(3)
]
# Ensure all ranks have the same initialization data.
with torch.no_grad():
dist.broadcast(q, src=0)
dist.broadcast(k, src=0)
dist.broadcast(v, src=0)
with sdpa_kernel(backend):
out = fn_eval(F.scaled_dot_product_attention, q, k, v, is_causal=is_causal)
cp_q, cp_k, cp_v = [target.detach().clone() for target in [q, k, v]]
cp_out, cp_dq, cp_dk, cp_dv = self._ring_attention_sdpa(
cp_q,
cp_k,
cp_v,
fn_eval=fn_eval,
mesh=device_mesh,
seq_dim=seq_dim,
is_causal=is_causal,
compiled=compiled,
backend=backend,
rotater=rotater,
test_forward_only=test_forward_only,
load_balance=load_balance,
use_context=use_context,
)
# Due to numerical error, we need to choose different atol for different
# attention kernels
(cp_out,) = context_parallel_unshard(device_mesh, [cp_out], [seq_dim])
atol = (
2e-06
if backend == SDPBackend.EFFICIENT_ATTENTION
else 8e-3 * self.world_size
)
rtol = (
1e-05
if backend == SDPBackend.EFFICIENT_ATTENTION
else 1e-3 * self.world_size
)
torch.testing.assert_close(out, cp_out, atol=atol, rtol=rtol)
if test_forward_only:
return
cp_dq, cp_dk, cp_dv = context_parallel_unshard(
device_mesh,
[cp_dq, cp_dk, cp_dv],
[seq_dim] * 3,
)
torch.testing.assert_close(q.grad, cp_dq, atol=atol, rtol=rtol)
torch.testing.assert_close(k.grad, cp_dk, atol=atol, rtol=rtol)
torch.testing.assert_close(v.grad, cp_dv, atol=atol, rtol=rtol)
def test_is_causal_behavior(self) -> None:
_cp_options.enable_load_balance = False
self.assertEqual(
_is_causal_behavior(rank=0, world_size=4, i=0, is_causal=False),
_CausalBehavior.NOT_IS_CAUSAL,
)
ranks = [
[_CausalBehavior.IS_CAUSAL, _CausalBehavior.SKIP],
[_CausalBehavior.IS_CAUSAL, _CausalBehavior.NOT_IS_CAUSAL],
]
for rank, iters in enumerate(ranks):
for i, behavior in enumerate(iters):
self.assertEqual(
_is_causal_behavior(rank=rank, world_size=2, i=i, is_causal=True),
behavior,
)
_cp_options.enable_load_balance = True
ranks = [
[_CausalBehavior.IS_CAUSAL, _CausalBehavior.NOT_IS_CAUSAL],
[_CausalBehavior.IS_CAUSAL, _CausalBehavior.NOT_IS_CAUSAL],
]
for rank, iters in enumerate(ranks):
for i, behavior in enumerate(iters):
self.assertEqual(
_is_causal_behavior(rank=rank, world_size=2, i=i, is_causal=True),
behavior,
)
# Compile the flex_attention function
compiled_flex_attention = torch.compile(flex_attention, dynamic=False, fullgraph=True)
compiled_create_block_mask = torch.compile(
create_block_mask, dynamic=False, fullgraph=True
)
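# Standard causal mask_mod: a query at position q_idx may only attend to keys
# at positions kv_idx <= q_idx.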
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
# copied from https://github.com/meta-pytorch/attention-gym/blob/main/attn_gym/masks/document_mask.py
def generate_random_lengths(total_length, num_documents) -> list[int]:
# Initialize all lengths to 1 to ensure each document has at least one token
lengths = [1] * num_documents
remaining_length = total_length - num_documents
# Randomly distribute the remaining length
for _ in range(remaining_length):
index = random.randint(0, num_documents - 1)
lengths[index] += 1
return lengths
def generate_random_lengths_in_chunks(
total_length, num_documents, chunk_size
) -> list[int]:
    # Generate a list of random document lengths such that each document
    # contains some whole number of chunks of size `chunk_size` (i.e. each
    # document's length is a multiple of `chunk_size`), and the lengths of
    # all documents sum to `total_length`.
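    # For example, total_length=16, num_documents=3, chunk_size=4 yields 4
    # chunks, which might be split as [2, 1, 1], giving lengths [8, 4, 4].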
num_chunks = total_length // chunk_size
assert total_length % chunk_size == 0 and num_chunks >= num_documents
num_chunks_per_document = [1] * num_documents
remaining_chunks = num_chunks - num_documents
# Randomly distribute the remaining chunks
for _ in range(remaining_chunks):
index = random.randint(0, num_documents - 1) # document_id
num_chunks_per_document[index] += 1
    return [n_chunks * chunk_size for n_chunks in num_chunks_per_document]
def length_to_offsets(lengths: list[list[int]], device: str | torch.device) -> Tensor:
"""Converts a list of lengths to a list of offsets.
Args:
        lengths: A list of per-batch document lengths.
        device: The device on which to create the offsets tensor.
"""
offsets = [[0] + lengths_in_batch for lengths_in_batch in lengths]
offsets = torch.tensor(offsets, device=device, dtype=torch.int32)
offsets = torch.cumsum(offsets, dim=-1)
return offsets
def _offsets_to_doc_ids_tensor(offsets):
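    # e.g. offsets [[0, 2, 6, 9]] -> doc ids [[0, 0, 1, 1, 1, 1, 2, 2, 2]]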
doc_ids = []
device = offsets.device
for batch_idx in range(offsets.size(0)):
counts = offsets[batch_idx][1:] - offsets[batch_idx][:-1]
doc_id = torch.repeat_interleave(
torch.arange(len(counts), device=device, dtype=torch.int32), counts
)
doc_ids.append(doc_id)
return torch.stack(doc_ids)
def generate_doc_mask_mod(
mask_mod: _mask_mod_signature, offsets: Tensor
) -> _mask_mod_signature:
"""Generates mask mods that apply to inputs to flex attention in the sequence stacked
format.
Args:
mask_mod: The mask mod to apply to the documents
        offsets: A tensor of shape (batch_size, num_documents + 1) containing
            the cumulative counts of document tokens for each batch element,
e.g. if you have 3 documents of length 2, 4, 3 then
offsets = [0, 2, 6, 9]
Note:
What is the sequence stacked format? When assembling batches of inputs, we
take multiple sequences and stack them together to form 1 large sequence. We then
use masking to ensure that the attention scores are only applied to tokens within
the same document.
"""
document_id = _offsets_to_doc_ids_tensor(offsets)
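    # doc_mask_mod converts the stacked-sequence indices back to per-document
    # ("logical") positions by subtracting each document's starting offset,
    # applies the wrapped mask_mod within a document, and masks out any
    # cross-document attention.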
def doc_mask_mod(b, h, q_idx, kv_idx):
same_doc = document_id[b][q_idx] == document_id[b][kv_idx]
q_logical = q_idx - offsets[b, document_id[b, q_idx]]
kv_logical = kv_idx - offsets[b, document_id[b, kv_idx]]
inner_mask = mask_mod(b, h, q_logical, kv_logical)
return same_doc & inner_mask
return doc_mask_mod
class FlexAttentionWrapper(torch.nn.Module):
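    """Routes every call through a single class-level torch.compile'd
    flex_attention so that all instances share one compiled artifact."""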
_flex_attn: ClassVar[Callable] = torch.compile(flex_attention)
def __init__(self) -> None:
super().__init__()
    def forward(
        self, *args: object, **kwargs: object
    ) -> (
        torch.Tensor
        | tuple[torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, AuxOutput]
    ):
return FlexAttentionWrapper._flex_attn(*args, **kwargs)
class CPFlexAttentionTest(DTensorTestBase):
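    """Correctness tests for context-parallel FlexAttention under different
    block masks and load balancers, compared against single-rank outputs."""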
@property
def world_size(self) -> int:
return 2
def _test_cp_flex_attention(
self,
*,
qkv_size: int,
B: int = 1,
block_mask,
lb_type: str,
document_lengths: Optional[list[list[int]]] = None,
) -> None:
torch.use_deterministic_algorithms(True)
torch.cuda.manual_seed(1234)
dtype = torch.float32
bs = B if B > 1 else 8
dim = 32
nheads = 8
seq_dim = 2
lb = self._get_load_balancer(
lb_type,
{
"seq_length": qkv_size,
"document_lengths": document_lengths,
"block_mask": block_mask,
},
)
qkv = [
torch.rand(
(bs, nheads, qkv_size, dim),
device=self.device_type,
dtype=dtype,
requires_grad=True,
)
for _ in range(3)
]
expect_out, expect_aux = compiled_flex_attention(
*qkv, block_mask=block_mask, return_aux=AuxRequest(lse=True)
)
expect_out.sum().backward()
        # Prepare the components required for CP + FlexAttention:
device_mesh = init_device_mesh(
device_type=self.device_type,
mesh_shape=(self.world_size,),
mesh_dim_names=("cp",),
)
flex_attention_wrapper_module = FlexAttentionWrapper()
cp_plan = _ContextParallel(
seq_dim=seq_dim,
attention_type=_ContextParallel.AttentionType.FLEX,
)
parallelize_module(
flex_attention_wrapper_module,
device_mesh,
cp_plan,
)
*cp_qkv, cp_block_mask = _context_parallel_shard(
device_mesh,
[t.detach().clone() for t in qkv] + [block_mask],
[seq_dim] * 4,
load_balancer=lb,
)
for t in cp_qkv:
t.requires_grad = True
cp_out, cp_aux = flex_attention_wrapper_module(
*cp_qkv,
block_mask=cp_block_mask,
return_aux=AuxRequest(lse=True),
)
# backward run
cp_out.sum().backward()
atol = 2e-06
rtol = 1e-05
# unshard the output
cp_out, cp_lse = context_parallel_unshard(
device_mesh,
buffers=[cp_out, cp_aux.lse],
seq_dims=[seq_dim] * 2,
load_balancer=lb,
)
torch.testing.assert_close(cp_out, expect_out, atol=atol, rtol=rtol)
torch.testing.assert_close(cp_lse, expect_aux.lse, atol=atol, rtol=rtol)
# unshard the gradient
cp_qkv_grad = context_parallel_unshard(
device_mesh,
buffers=[t.grad for t in cp_qkv],
seq_dims=[seq_dim] * 3,
load_balancer=lb,
)
qkv_grad = [t.grad for t in qkv]
for grad, cp_grad in zip(qkv_grad, cp_qkv_grad):
torch.testing.assert_close(grad, cp_grad, atol=atol, rtol=rtol)
def _get_load_balancer(
self, lb_type: str, kwargs: dict[str, Any]
) -> Optional[_LoadBalancer]:
seq_length = kwargs["seq_length"]
document_lengths = kwargs["document_lengths"]
block_mask = kwargs["block_mask"]
# generate load balancer
if lb_type == "None":
load_balancer = None # no load-balance
elif lb_type == "_HeadTailLoadBalancer":
assert isinstance(seq_length, int)
load_balancer = _HeadTailLoadBalancer(
seq_length, self.world_size, torch.device(self.device_type)
)
elif lb_type == "_PerDocumentHeadTailLoadBalancer":
assert isinstance(document_lengths, list)
load_balancer = _PerDocumentHeadTailLoadBalancer(
document_lengths, self.world_size, torch.device(self.device_type)
)
elif lb_type == "_PTRRLoadBalancer":
assert isinstance(block_mask, BlockMask)
load_balancer = _PTRRLoadBalancer(
block_mask,
self.world_size,
)
else:
raise ValueError(f"load_balancer type {lb_type} is not supported!")
return load_balancer
@skip_if_lt_x_gpu(2)
@with_comms
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
)
def test_cp_flex_attention_causal_mask(self) -> None:
seq_length_list = [256 * self.world_size, 2048]
load_balance_type_list = [
"None",
"_HeadTailLoadBalancer",
"_PTRRLoadBalancer",
]
        # NOTE: Each (seq_len, load_balance_type) pair triggers 2 compilations:
        # 1 for the single-rank flex_attention run and 1 for the CP
        # flex_attention run. To avoid the "exceeds_recompile_limit" error, we
        # need to raise cache_size_limit to at least 2 * num_of_sub_test_runs,
        # the total number of compilations in this test.
torch._dynamo.config.cache_size_limit = (len(seq_length_list) + 1) * (
1 + len(load_balance_type_list)
)
for qkv_size, lb_type in itertools.product(
seq_length_list, load_balance_type_list
):
block_mask = compiled_create_block_mask(
causal_mask,
B=1,
H=1,
Q_LEN=qkv_size,
KV_LEN=qkv_size,
device=self.device_type,
)
self._test_cp_flex_attention(
qkv_size=qkv_size, block_mask=block_mask, lb_type=lb_type
)
# NOTE: Context Parallel should not be used for small attentions (block_size < 128)
qkv_size = 64 * self.world_size
block_mask = compiled_create_block_mask(
causal_mask,
B=1,
H=1,
Q_LEN=qkv_size,
KV_LEN=qkv_size,
device=self.device_type,
)
for lb_type in ["None", "_HeadTailLoadBalancer"]:
with self.assertRaisesRegex(
NotImplementedError,
f"Q_LEN {qkv_size} is not divisible",
):
self._test_cp_flex_attention(
qkv_size=qkv_size, block_mask=block_mask, lb_type=lb_type
)
for lb_type in ["_PTRRLoadBalancer"]:
with self.assertRaisesRegex(
NotImplementedError,
"must be divisible by group_size",
):
self._test_cp_flex_attention(
qkv_size=qkv_size, block_mask=block_mask, lb_type=lb_type
)
# TODO: merge with the above test
@skip_if_lt_x_gpu(2)
@with_comms
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
)
def test_cp_flex_attention_document_mask(self) -> None:
random.seed(10)
# parameters for testing
doc_count = 28
batch_size_list = [2, 4, 8]
max_seq_len_list = [
256 * self.world_size,
2048,
# 128 * self.world_size # NOTE: Mismatched elements: 8 / 131072 (0.0%),
]
load_balance_type = [
"None",
"_HeadTailLoadBalancer",
"_PerDocumentHeadTailLoadBalancer",
"_PTRRLoadBalancer",
]
        # NOTE: Each (batch_size, seq_len, load_balance_type) tuple triggers 2
        # compilations: 1 for the single-rank flex_attention run and 1 for the
        # CP flex_attention run. To avoid the "exceeds_recompile_limit" error,
        # we need to raise cache_size_limit to at least
        # 2 * num_of_sub_test_runs, the total number of compilations in this
        # test.
torch._dynamo.config.cache_size_limit = (
2 * len(batch_size_list) * len(max_seq_len_list) * len(load_balance_type)
)
# TODO: change this for-loop to run_subtests
# Use a for-loop instead of run_subtests because we need to initialize the mask
# for each subtest. This can be baked into self._test_cp_flex_attention as
# a str argument denoting mask type.
for batch_size, max_seq_len, lb_type in itertools.product(
batch_size_list,
max_seq_len_list,
load_balance_type,
):
# initialize document mask
lengths = [
(
generate_random_lengths_in_chunks(
max_seq_len, doc_count, chunk_size=2 * self.world_size
)
if lb_type == "_PerDocumentHeadTailLoadBalancer"
else generate_random_lengths(max_seq_len, doc_count)
)
for _ in range(batch_size)
]
offsets = length_to_offsets(lengths, self.device_type)
document_causal_mask = generate_doc_mask_mod(causal_mask, offsets)
block_mask = compiled_create_block_mask(
document_causal_mask,
B=batch_size,
H=1,
Q_LEN=max_seq_len,
KV_LEN=max_seq_len,
device=self.device_type,
)
self._test_cp_flex_attention(
qkv_size=max_seq_len,
B=batch_size,
lb_type=lb_type,
block_mask=block_mask,
document_lengths=lengths,
)
class TestCPCustomOps(DTensorTestBase):
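    """torch.library.opcheck coverage for the CP custom op flex_cp_allgather."""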
@property
def world_size(self) -> int:
return 2
@skip_if_lt_x_gpu(2)
@with_comms
def test_flex_cp_custom_op(self) -> None:
mesh = init_device_mesh(
device_type=self.device_type,
mesh_shape=(self.world_size,),
mesh_dim_names=("cp",),
)
examples_k_v = [
(
torch.randn(8, 8, 8, 8, device=self.device_type),
torch.randn(8, 8, 8, 8, device=self.device_type),
2,
c10d._get_process_group_name(mesh.get_group()),
),
(
torch.randn(8, 8, 8, 8, device=self.device_type, requires_grad=True),
torch.randn(8, 8, 8, 8, device=self.device_type, requires_grad=True),
2,
c10d._get_process_group_name(mesh.get_group()),
),
]
for example in examples_k_v:
torch.library.opcheck(flex_cp_allgather, example)
class TestSharding(DTensorTestBase):
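    """Checks the placement produced by _context_parallel_shard with the
    head-tail load balancer."""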
@property
def world_size(self) -> int:
return 2
@skip_if_lt_x_gpu(2)
@with_comms
def test_context_parallel_shard(self) -> None:
B = 4
seq_len = 32
device_mesh = init_device_mesh(
mesh_shape=(2,), mesh_dim_names=("cp",), device_type=self.device_type
)
freqs_cis = torch.arange(0, seq_len, device=self.device_type)
q = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
k = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
v = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
load_balancer = _HeadTailLoadBalancer(
seq_len, self.world_size, torch.device(self.device_type)
)
freqs_cis_shard, q_shard, k_shard, v_shard = _context_parallel_shard(
device_mesh, [freqs_cis, q, k, v], [0, 1, 1, 1], load_balancer=load_balancer
)
self.assertEqual(freqs_cis_shard.size(), (seq_len // 2,))
chunks = freqs_cis.chunk(self.world_size * 2)
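        # With the head-tail load balancer, the sequence is split into
        # 2 * world_size chunks and each rank holds chunk[rank] plus the
        # mirrored chunk from the tail: for world_size=2, rank 0 holds
        # chunks 0 and 3, while rank 1 holds chunks 1 and 2.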
self.assertEqual(
freqs_cis_shard,
map_local_tensor_for_rank(
chunks,
self.rank,
lambda chunks, rank: torch.cat(
[chunks[rank], chunks[self.world_size * 2 - rank - 1]],
dim=0,
),
),
)
RingAttentionTestWithLocalTensor = create_local_tensor_test_class(
RingAttentionTest,
skipped_tests=[
# Need to make attention implementation local tensor friendly, e.g.
# rewrite "rank local" logic
"test_ring_attention_sdpa",
],
)
CPFlexAttentionTestWithLocalTensor = create_local_tensor_test_class(
CPFlexAttentionTest,
skipped_tests=[
# Missing support for batched tensors
"test_cp_flex_attention_causal_mask",
"test_cp_flex_attention_document_mask",
],
)
TestCPCustomOpsWithLocalTensor = create_local_tensor_test_class(
TestCPCustomOps,
skipped_tests=[
# Missing support for fake tensors
"test_flex_cp_custom_op",
],
)
TestShardingWithLocalTensor = create_local_tensor_test_class(
TestSharding,
)
if __name__ == "__main__":
run_tests()