[CP] Rewrite ring attention backward algorithm and enablement APIs (#131351)

**What does this PR achieve?**
1. This PR rewrites the ring attention backward algorithm to fuse the all-to-all calls and overlap the gradient communication with computation.

2. Enables memory-efficient attention with CP by templating the ring attention backward. Verifying the accuracy in fp32 gives us higher confidence in the correctness of the implementation.

3. Provides some experimental APIs to enable context parallelism (see the usage sketch below).

4. Ensures CP works with torch.compile. The combination of causal masking and torch.compile does not work yet.
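
For item 3, here is a minimal usage sketch of the experimental enablement APIs, mirroring how the test below exercises them. `mesh`, `q`, `k`, `v`, and `encoder_layer` are placeholders, and the underscore-prefixed imports are private, experimental names that may change:

```python
import torch.nn.functional as F
from torch.distributed._tensor.experimental.attention import (
    _AttentionContextParallel,
    context_parallel,
)
from torch.distributed.tensor.parallel import parallelize_module

# Option 1: shard the SDPA inputs on the sequence dim (dim 2) for the duration
# of a region. The buffers are resized in place to their local shards, so they
# must not require grad when entering the context (see the test comment below);
# inside the region, scaled_dot_product_attention dispatches to ring attention.
with context_parallel(mesh, buffers=(q, k, v), buffer_seq_dims=(2, 2, 2)):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Option 2: install context parallelism on a module's attention submodule
# through the usual parallelize_module() plan.
parallelize_module(
    module=encoder_layer,
    device_mesh=mesh,
    parallelize_plan={"self_attn": _AttentionContextParallel()},
)
```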

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131351
Approved by: https://github.com/wanchaol
Authored by Chien-Chin Huang on 2024-08-14 22:42:44 -07:00; committed by PyTorch MergeBot
parent 7470ae85e4
commit 3434a54fba
4 changed files with 700 additions and 547 deletions

View File

@@ -3,17 +3,17 @@
import unittest
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Shard
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed._tensor.experimental.attention import (
_AttentionContextParallel,
_CausalBehavior,
_context_parallel_buffers,
_is_causal_behavior,
_scaled_dot_product_chunk_flash_attention,
_scaled_dot_product_ring_efficient_attention,
_scaled_dot_product_ring_flash_attention,
attention_context_parallel,
AttentionContextParallel,
context_parallel,
)
from torch.distributed.tensor.parallel import parallelize_module
from torch.nn.attention import sdpa_kernel, SDPBackend
@@ -21,7 +21,6 @@ from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FLASH_ATTENTION,
PLATFORM_SUPPORTS_FUSED_ATTENTION,
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
TEST_CUDA,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
@@ -29,7 +28,6 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
skipIfRocm,
TEST_WITH_ROCM,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@@ -40,137 +38,161 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
c10d_functional = torch.ops.c10d_functional
backends = []
if PLATFORM_SUPPORTS_FLASH_ATTENTION:
backends.append(SDPBackend.FLASH_ATTENTION)
if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
backends.append(SDPBackend.EFFICIENT_ATTENTION)
class RingAttentionTest(DTensorTestBase):
@property
def world_size(self) -> int:
return 2
return torch.cuda.device_count()
@skip_if_lt_x_gpu(2)
@skipIfRocm # Missing _c10d_functional_autograd::all_to_all_single
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
not PLATFORM_SUPPORTS_FUSED_ATTENTION,
"Does not support flash nor efficient attention",
)
@with_comms
@parametrize("is_causal", [True, False])
def test_ring_attention_sdpa(self, is_causal: bool) -> None:
device_mesh = DeviceMesh(
self.device_type,
torch.arange(0, self.world_size),
)
@parametrize("compiled", [True, False])
@parametrize("backend", backends)
def test_ring_attention_sdpa(
self, is_causal: bool, compiled: bool, backend: SDPBackend
) -> None:
device_mesh = DeviceMesh(self.device_type, torch.arange(0, self.world_size))
dtype = torch.bfloat16
bs = 8
query_tokens = 8
context_tokens = query_tokens if is_causal else 8
query_tokens = 64
context_tokens = 64
dim = 32
nheads = 8
query = torch.rand(
torch.manual_seed(10)
dtype = (
torch.bfloat16 if backend == SDPBackend.FLASH_ATTENTION else torch.float32
)
if is_causal and compiled and self.world_size > 2:
# TODO: Fix this after we move `wait_tensor` to use `with_effect`.
return
q = torch.rand(
(bs, nheads, self.world_size * query_tokens, dim),
device=self.device_type,
dtype=dtype,
requires_grad=True,
)
key = torch.rand(
k = torch.rand(
(bs, nheads, self.world_size * context_tokens, dim),
device=self.device_type,
dtype=dtype,
requires_grad=True,
)
value = torch.rand(
v = torch.rand(
(bs, nheads, self.world_size * context_tokens, dim),
device=self.device_type,
dtype=dtype,
requires_grad=True,
)
query_placement = [Shard(2)]
dquery = distribute_tensor(query, device_mesh, query_placement)
self.assertEqual(query.shape, (bs, nheads, self.world_size * query_tokens, dim))
# Ensure all ranks have the same initialization data.
with torch.no_grad():
dist.broadcast(q, src=0)
dist.broadcast(k, src=0)
dist.broadcast(v, src=0)
context_placement = [Shard(2)]
dkey = distribute_tensor(key, device_mesh, context_placement)
dvalue = distribute_tensor(value, device_mesh, context_placement)
for t in [dkey, dvalue]:
self.assertEqual(
t.shape, (bs, nheads, context_tokens * self.world_size, dim)
with sdpa_kernel(backend):
out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
out.sum().backward()
local_out, local_dq, local_dk, local_dv = _context_parallel_buffers(
device_mesh,
buffers=(out, q.grad, k.grad, v.grad),
buffer_seq_dims=(2, 2, 2, 2),
)
cp_q = q.clone().detach()
cp_k = k.clone().detach()
cp_v = v.clone().detach()
# Theoretically, context_parallel() should not be used to shard
# parameters because resize_ is not allowed when requires_grad is
# True. But requires_grad of cp_q, cp_k, and cp_v is False here,
# so we can just use context_parallel() to shard q, k, and v.
# In reality, context_parallel() should be used to shard the input.
with context_parallel(
device_mesh, buffers=(cp_q, cp_k, cp_v), buffer_seq_dims=(2, 2, 2)
):
cp_q.requires_grad = True
cp_k.requires_grad = True
cp_v.requires_grad = True
with CommDebugMode() as comm_mode:
with sdpa_kernel(backend):
if compiled:
fn = torch.compile(
F.scaled_dot_product_attention,
fullgraph=True,
backend="aot_eager",
)
else:
fn = F.scaled_dot_product_attention
cp_out = fn(cp_q, cp_k, cp_v, is_causal=is_causal)
cp_out.sum().backward()
if not compiled:
# Compiler and CommDebugMode do not work well together.
self.assertDictEqual(
comm_mode.get_comm_counts(),
{
c10d_functional.all_to_all_single: self.world_size * 3
- 2
},
)
# Due to numerical error, we need to choose different atol for different
# attention kernels
atol = (
1e-08
if backend == SDPBackend.EFFICIENT_ATTENTION
else 1e-3 * self.world_size
)
self.assertEqual(t.to_local().shape, (bs, nheads, context_tokens, dim))
self.assertTrue(torch.allclose(local_out, cp_out, atol=atol))
# local tensors
out, logsumexp, *others = torch.ops.aten._scaled_dot_product_flash_attention(
query, key, value, is_causal=is_causal
)
self.assertEqual(out.shape, (bs, nheads, self.world_size * query_tokens, dim))
out.sum().backward()
out_grad = query.grad
query.grad = None
self.assertIsNotNone(out_grad)
# compute chunked version to compare distributed to chunked implementations
# chunked isn't numerically identical to single operator version
(
out_chunk,
logsumexp_chunk,
*others,
) = _scaled_dot_product_chunk_flash_attention(
query,
key,
value,
size=self.world_size,
is_causal=is_causal,
)
out_chunk.sum().backward()
self.assertEqual(
out_chunk.shape, (bs, nheads, self.world_size * query_tokens, dim)
)
self.assertEqual(logsumexp_chunk, logsumexp)
self.assertEqual(out_chunk, out)
out_chunk_grad = query.grad
query.grad = None
# gradient doesn't match due to numerical issues with chunk size > 1
# self.assertEqual(out_chunk_grad, out_grad)
# parallel behavior
with attention_context_parallel(), CommDebugMode() as comm_mode:
(
out_parallel,
logsumexp_parallel,
*others,
) = torch.ops.aten._scaled_dot_product_flash_attention(
dquery, dkey, dvalue, is_causal=is_causal
atol = (
2e-06
if backend == SDPBackend.EFFICIENT_ATTENTION
else 8e-3 * self.world_size
)
self.assertDictEqual(
comm_mode.get_comm_counts(),
{
c10d_functional.all_to_all_single: self.world_size - 1,
},
)
self.assertEqual(out_parallel.placements, (Shard(2),))
self.assertEqual(
out_parallel._local_tensor.shape, (bs, nheads, query_tokens, dim)
)
self.assertEqual(
out_parallel.shape, (bs, nheads, self.world_size * query_tokens, dim)
)
out_parallel_tensor = out_parallel.full_tensor()
self.assertEqual(out_parallel_tensor, out)
logsumexp_parallel_tensor = logsumexp_parallel.full_tensor()
self.assertEqual(logsumexp_parallel_tensor, logsumexp)
self.assertTrue(torch.allclose(local_dq, cp_q.grad, atol=atol))
self.assertTrue(torch.allclose(local_dk, cp_k.grad, atol=atol))
self.assertTrue(torch.allclose(local_dv, cp_v.grad, atol=atol))
self.assertIsNone(dquery.grad)
with attention_context_parallel(), CommDebugMode() as comm_mode:
out_parallel.sum().backward()
cp_q.grad = None
cp_k.grad = None
cp_v.grad = None
cp_q.requires_grad = False
cp_k.requires_grad = False
cp_v.requires_grad = False
self.assertDictEqual(
comm_mode.get_comm_counts(),
{
c10d_functional.all_to_all_single: (self.world_size - 1) * 2,
},
def test_is_causal_behavior(self) -> None:
self.assertEqual(
_is_causal_behavior(rank=0, world_size=4, i=0, is_causal=False),
_CausalBehavior.NOT_IS_CAUSAL,
)
out_parallel_grad = dquery.grad.full_tensor()
dquery.grad = None
self.assertEqual(out_parallel_grad, out_chunk_grad)
ranks = [
[_CausalBehavior.IS_CAUSAL, _CausalBehavior.SKIP],
[_CausalBehavior.IS_CAUSAL, _CausalBehavior.NOT_IS_CAUSAL],
]
for rank, iters in enumerate(ranks):
for i, behavior in enumerate(iters):
self.assertEqual(
_is_causal_behavior(rank=rank, world_size=2, i=i, is_causal=True),
behavior,
)
@skip_if_lt_x_gpu(2)
@unittest.skipIf(
@@ -201,7 +223,7 @@ class RingAttentionTest(DTensorTestBase):
module=encoder_layer,
device_mesh=device_mesh,
parallelize_plan={
"self_attn": AttentionContextParallel(),
"self_attn": _AttentionContextParallel(),
},
)
model = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
@@ -230,30 +252,11 @@ class RingAttentionTest(DTensorTestBase):
self.assertDictEqual(
comm_mode.get_comm_counts(),
{
c10d_functional.all_to_all_single: (self.world_size - 1)
* 2
c10d_functional.all_to_all_single: (self.world_size * 2 - 1)
* num_layers,
},
)
def test_is_causal_behavior(self) -> None:
# not causal
self.assertEqual(
_is_causal_behavior(rank=0, world_size=4, i=0, is_causal=False),
_CausalBehavior.NOT_IS_CAUSAL,
)
ranks = [
[_CausalBehavior.IS_CAUSAL, _CausalBehavior.SKIP],
[_CausalBehavior.IS_CAUSAL, _CausalBehavior.NOT_IS_CAUSAL],
]
for rank, iters in enumerate(ranks):
for i, behavior in enumerate(iters):
self.assertEqual(
_is_causal_behavior(rank=rank, world_size=2, i=i, is_causal=True),
behavior,
)
@skip_if_lt_x_gpu(2)
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
@@ -275,7 +278,7 @@ class RingAttentionTest(DTensorTestBase):
module=model,
device_mesh=device_mesh,
parallelize_plan={
f"layers.{i}.attention": AttentionContextParallel()
f"layers.{i}.attention": _AttentionContextParallel()
for i in range(args.n_layers)
},
)
@@ -299,105 +302,14 @@ class RingAttentionTest(DTensorTestBase):
self.assertDictEqual(
comm_mode.get_comm_counts(),
{
c10d_functional.all_to_all_single: (self.world_size - 1)
* 2
c10d_functional.all_to_all_single: (self.world_size * 2 - 1)
* args.n_layers,
},
)
@skip_if_lt_x_gpu(2)
@unittest.skipIf(
not PLATFORM_SUPPORTS_FUSED_ATTENTION,
"Does not support flash nor efficient attention",
)
@unittest.skipIf(
TEST_CUDA and not TEST_WITH_ROCM and not PLATFORM_SUPPORTS_FLASH_ATTENTION,
"Does not support flash attention",
) # On CUDA (not ROCM) platform, the UT is skipped if no FA support (even if ME may get supported)
@with_comms
@parametrize(
"attention_fn",
[
_scaled_dot_product_ring_flash_attention
if PLATFORM_SUPPORTS_FLASH_ATTENTION
else None,
_scaled_dot_product_ring_efficient_attention
if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION
else None,
# _scaled_dot_product_ring_cudnn_attention, # TODO: not built by default
],
)
def test_ring_attention_compile(self, attention_fn: object) -> None:
if attention_fn is None:
self.skipTest("Unsupported on current platform")
device_mesh = DeviceMesh(
self.device_type,
torch.arange(0, self.world_size),
)
dtype = torch.bfloat16
bs = 8
query_tokens = 8
context_tokens = 24
dim = 32
nheads = 8
query = torch.rand(
(bs, nheads, self.world_size * query_tokens, dim),
device=self.device_type,
dtype=dtype,
requires_grad=True,
)
key = torch.rand(
(bs, nheads, self.world_size * context_tokens, dim),
device=self.device_type,
dtype=dtype,
)
value = torch.rand(
(bs, nheads, self.world_size * context_tokens, dim),
device=self.device_type,
dtype=dtype,
)
query_placement = [Shard(2)]
dquery = distribute_tensor(query, device_mesh, query_placement)
self.assertEqual(query.shape, (bs, nheads, self.world_size * query_tokens, dim))
context_placement = [Shard(2)]
dkey = distribute_tensor(key, device_mesh, context_placement)
dvalue = distribute_tensor(value, device_mesh, context_placement)
# compiled = attention_fn
compiled = torch.compile(attention_fn, fullgraph=True, backend="aot_eager")
out, lse, *args = compiled(
device_mesh.get_group(),
dquery.to_local(),
dkey.to_local(),
dvalue.to_local(),
)
self.assertEqual(out.shape, (bs, nheads, query_tokens, dim))
self.assertIsInstance(lse, torch.Tensor)
(
out_chunk,
*others,
) = _scaled_dot_product_chunk_flash_attention(
query,
key,
value,
size=self.world_size,
is_causal=False,
)
self.assertEqual(
out,
out_chunk[
:, :, self.rank * query_tokens : (self.rank + 1) * query_tokens, :
],
)
out.sum().backward()
instantiate_parametrized_tests(RingAttentionTest)
if backends:
instantiate_parametrized_tests(RingAttentionTest)
if __name__ == "__main__":
run_tests()

File diff suppressed because it is too large.
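
Since the diff of the main implementation file is suppressed above, here is a rough conceptual sketch (not the PR's actual code) of the overlapped ring backward pattern described in item 1 of the summary: K/V chunks rotate around the ring, the partial dK/dV accumulators follow one hop behind, and each exchange is issued asynchronously so it overlaps with the block-gradient computation. The sketch packs K with V and dK with dV into single buffers in the spirit of the fused all-to-all the summary mentions; the real implementation routes the exchanges through `_c10d_functional_autograd::all_to_all_single`, handles causal masking and block skipping, and differs in many details. Plain `isend`/`irecv` and eager einsum math are used here only to keep the sketch short.

```python
import torch
import torch.distributed as dist


def ring_sdpa_backward_sketch(pg, q, k, v, out, logsumexp, grad_out):
    """Hypothetical overlapped ring-attention backward (non-causal, eager).

    q, k, v, out, grad_out: this rank's (bs, nheads, local_seq, dim) shards.
    logsumexp: (bs, nheads, local_seq) full-row logsumexp from the forward.
    """
    rank, world = dist.get_rank(pg), dist.get_world_size(pg)
    send_to, recv_from = (rank + 1) % world, (rank - 1) % world
    scale = q.shape[-1] ** -0.5
    delta = (grad_out * out).sum(dim=-1, keepdim=True)  # rowsum(dO * O)
    dq = torch.zeros_like(q)
    kv_reqs = dkv_reqs = recv_dkv = None

    for step in range(world):
        if step + 1 < world:
            # Prefetch the next K/V chunk (packed into one buffer); this
            # exchange overlaps with the block-gradient math below.
            send_kv = torch.stack([k, v]).contiguous()
            next_kv = torch.empty_like(send_kv)
            kv_reqs = dist.batch_isend_irecv([
                dist.P2POp(dist.isend, send_kv, send_to, group=pg),
                dist.P2POp(dist.irecv, next_kv, recv_from, group=pg),
            ])

        # FlashAttention-style block backward against the current K/V chunk.
        scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale
        p = torch.exp(scores - logsumexp.unsqueeze(-1))  # full-row softmax
        block_dv = torch.einsum("bhqk,bhqd->bhkd", p, grad_out)
        dp = torch.einsum("bhqd,bhkd->bhqk", grad_out, v)
        dscores = p * (dp - delta)
        dq += torch.einsum("bhqk,bhkd->bhqd", dscores, k) * scale
        block_dk = torch.einsum("bhqk,bhqd->bhkd", dscores, q) * scale

        # Add the partial dK/dV that arrived from the previous rank; it
        # belongs to the same K/V chunk we just processed.
        if dkv_reqs is not None:
            for req in dkv_reqs:
                req.wait()
            block_dk = block_dk + recv_dkv[0]
            block_dv = block_dv + recv_dkv[1]

        # Forward the accumulated dK/dV partials (packed into one buffer);
        # this gradient communication overlaps with the next step's compute.
        send_dkv = torch.stack([block_dk, block_dv]).contiguous()
        recv_dkv = torch.empty_like(send_dkv)
        dkv_reqs = dist.batch_isend_irecv([
            dist.P2POp(dist.isend, send_dkv, send_to, group=pg),
            dist.P2POp(dist.irecv, recv_dkv, recv_from, group=pg),
        ])

        if step + 1 < world:
            for req in kv_reqs:
                req.wait()
            k, v = next_kv[0], next_kv[1]

    # After the final hop, the received buffer holds the complete dK/dV for
    # this rank's own K/V shard.
    for req in dkv_reqs:
        req.wait()
    return dq, recv_dkv[0], recv_dkv[1]
```

If the forward similarly rotates K/V with `world_size - 1` exchanges, a schedule like this totals `3 * world_size - 2` exchanges per SDPA call, which lines up with the `all_to_all_single` count asserted in the test above.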

View File

@@ -2,6 +2,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement matrix related ops for distributed tensor
from typing import List
import torch
from torch.distributed._tensor._op_schema import (
OpSchema,
@@ -344,7 +346,7 @@ def scaled_dot_product_efficient_attention_strategy(
has_attn_bias = op_schema.args_schema[3] is not None
compute_log_sumexp = op_schema.args_schema[4]
single_mesh_dim_strategies = []
single_mesh_dim_strategies: List[PlacementList] = []
# placement list stores placements of [outputs, inputs]
# in the spda case, we have 2 valid tensor outputs and 3 or 4 tensor inputs
@@ -360,6 +362,20 @@
]
if has_attn_bias:
all_replicate.append(Replicate()) # attn bias
# Context Parallelism: shards on the sequence dim
single_mesh_dim_strategies.append(
[
Shard(2), # output
Shard(2), # logsumexp
None, # philox_seed
None, # philox_offset
Shard(2), # q
Shard(2), # k
Shard(2), # v
]
)
single_mesh_dim_strategies.append(all_replicate)
# second we can accept the sharding pattern of tensor parallelism, which
@@ -453,6 +469,27 @@ def scaled_dot_product_efficient_attention_backward_strategy(
num_heads_dim_sharding.extend([Replicate(), Replicate()])
single_mesh_dim_strategies.append(num_heads_dim_sharding)
# Context Parallelism: shards on the sequence dim
seq_dim_sharding: PlacementList = [
Shard(2), # grad_q
Shard(2), # grad_k
Shard(2), # grad_v
Shard(1) if has_attn_bias else None, # grad_bias
Shard(2), # grad_output
Shard(2), # q
Shard(2), # k
Shard(2), # v
Shard(2), # output
Shard(2), # logsumexp
]
# accept Replicate() on the rest of the tensor inputs, potentially
# cum_seq_q, cum_seq_k, philox_seed, and philox_offset
# at indices 6, 7, 12, 13, respectively
if has_attn_bias:
num_heads_dim_sharding.insert(8, Shard(1))
seq_dim_sharding.extend([Replicate(), Replicate()])
single_mesh_dim_strategies.append(seq_dim_sharding)
return expand_to_full_mesh_op_strategy(
mesh,
op_schema,

View File

@@ -4928,7 +4928,7 @@ def _find_or_create_pg_by_ranks_and_tag(
my_ranks = rank_set
assert my_ranks is not None, "rankset doesn't include the current node"
my_ranks.sort()
my_ranks = sorted(my_ranks)
pg = _find_pg_by_ranks_and_tag(tag, my_ranks)
if pg is not None: