mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
distributed: templated ring attention (#124215)
This adds a templated version of the ring attention forwards function as well as tests it with memory efficient attention. This doesn't add support for memory efficient attention in DTensor. That will be added in a follow up PR. This templating is also a POC of how to support other attention ops such as Jagged/nested tensor and as well how to implement striped attention in a scalable way. Misc changes: * Fixes all_to_all_single autograd implementation with CUDA + adds NCCL test * Adds compile support to the ring attention implementations (required some tweaks to process groups) Test plan: ``` pytest test/distributed/_tensor/test_attention.py pytest test/distributed/test_functional_api.py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/124215 Approved by: https://github.com/wanchaol
This commit is contained in:
parent
4946638f06
commit
ddd0ed1b43
|
|
@ -10,6 +10,8 @@ from torch.distributed._tensor.experimental.attention import (
|
|||
_CausalBehavior,
|
||||
_is_causal_behavior,
|
||||
_scaled_dot_product_chunk_flash_attention,
|
||||
_scaled_dot_product_ring_efficient_attention,
|
||||
_scaled_dot_product_ring_flash_attention,
|
||||
attention_context_parallel,
|
||||
AttentionContextParallel,
|
||||
)
|
||||
|
|
@ -295,6 +297,86 @@ class RingAttentionTest(DTensorTestBase):
|
|||
},
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@unittest.skipIf(
|
||||
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
|
||||
)
|
||||
@with_comms
|
||||
@parametrize(
|
||||
"attention_fn",
|
||||
[
|
||||
_scaled_dot_product_ring_flash_attention,
|
||||
_scaled_dot_product_ring_efficient_attention,
|
||||
# _scaled_dot_product_ring_cudnn_attention, # TODO: not built by default
|
||||
],
|
||||
)
|
||||
def test_ring_attention_compile(self, attention_fn: object) -> None:
|
||||
device_mesh = DeviceMesh(
|
||||
self.device_type,
|
||||
torch.arange(0, self.world_size),
|
||||
)
|
||||
dtype = torch.bfloat16
|
||||
bs = 8
|
||||
query_tokens = 8
|
||||
context_tokens = 24
|
||||
dim = 32
|
||||
nheads = 8
|
||||
query = torch.rand(
|
||||
(bs, nheads, self.world_size * query_tokens, dim),
|
||||
device=self.device_type,
|
||||
dtype=dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
key = torch.rand(
|
||||
(bs, nheads, self.world_size * context_tokens, dim),
|
||||
device=self.device_type,
|
||||
dtype=dtype,
|
||||
)
|
||||
value = torch.rand(
|
||||
(bs, nheads, self.world_size * context_tokens, dim),
|
||||
device=self.device_type,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
query_placement = [Shard(2)]
|
||||
dquery = distribute_tensor(query, device_mesh, query_placement)
|
||||
self.assertEqual(query.shape, (bs, nheads, self.world_size * query_tokens, dim))
|
||||
|
||||
context_placement = [Shard(2)]
|
||||
dkey = distribute_tensor(key, device_mesh, context_placement)
|
||||
dvalue = distribute_tensor(value, device_mesh, context_placement)
|
||||
|
||||
# compiled = attention_fn
|
||||
compiled = torch.compile(attention_fn, fullgraph=True, backend="aot_eager")
|
||||
|
||||
out, lse, *args = compiled(
|
||||
device_mesh.get_group(),
|
||||
dquery.to_local(),
|
||||
dkey.to_local(),
|
||||
dvalue.to_local(),
|
||||
)
|
||||
self.assertEqual(out.shape, (bs, nheads, query_tokens, dim))
|
||||
self.assertIsInstance(lse, torch.Tensor)
|
||||
|
||||
(
|
||||
out_chunk,
|
||||
*others,
|
||||
) = _scaled_dot_product_chunk_flash_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
size=self.world_size,
|
||||
is_causal=False,
|
||||
)
|
||||
self.assertEqual(
|
||||
out,
|
||||
out_chunk[
|
||||
:, :, self.rank * query_tokens : (self.rank + 1) * query_tokens, :
|
||||
],
|
||||
)
|
||||
|
||||
out.sum().backward()
|
||||
|
||||
|
||||
instantiate_parametrized_tests(RingAttentionTest)
|
||||
|
||||
|
|
|
|||
|
|
@ -634,14 +634,14 @@ class TestFunctionalAutograd(MultiThreadedTestCase):
|
|||
def test_all_to_all_single(self, compile: bool = True) -> None:
|
||||
group = dist.group.WORLD.group_name
|
||||
|
||||
t = torch.rand((self.world_size, 2), requires_grad=True)
|
||||
t = torch.ones((self.world_size, 2), requires_grad=True)
|
||||
|
||||
def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
|
||||
sizes = [1] * world_size
|
||||
t = t * 10
|
||||
t = t * 2
|
||||
assert t.requires_grad
|
||||
out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
|
||||
out = out + 2
|
||||
out = out + 0
|
||||
return out
|
||||
|
||||
if compile:
|
||||
|
|
@ -650,11 +650,13 @@ class TestFunctionalAutograd(MultiThreadedTestCase):
|
|||
compiled = my_func
|
||||
|
||||
out = compiled(t, self.world_size)
|
||||
self.assertEqual(out.shape, t.shape)
|
||||
self.assertEqual(out, torch.full_like(t, 2.0))
|
||||
self.assertIsNotNone(out.grad_fn)
|
||||
self.assertTrue(out.requires_grad)
|
||||
loss = out.sum()
|
||||
loss.backward()
|
||||
self.assertIsNotNone(t.grad)
|
||||
self.assertEqual(t.grad, torch.full_like(t, 2.0))
|
||||
|
||||
def test_all_to_all_single_inductor(self) -> None:
|
||||
group = dist.group.WORLD.group_name
|
||||
|
|
@ -752,5 +754,61 @@ class TestFunctionalAutograd(MultiThreadedTestCase):
|
|||
self.assertEqual(input_tensor.grad, torch.full(output_size, fill_value=1.0))
|
||||
|
||||
|
||||
class TestFunctionalAutogradWithNCCL(MultiProcessTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
os.environ["WORLD_SIZE"] = str(self.world_size)
|
||||
os.environ["BACKEND"] = dist.Backend.NCCL
|
||||
self._spawn_processes()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return torch.device(self.rank)
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return 2
|
||||
|
||||
@property
|
||||
def process_group(self):
|
||||
return dist.group.WORLD
|
||||
|
||||
def dist_init(self):
|
||||
dist.init_process_group(
|
||||
backend=BACKEND,
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
init_method=f"file://{self.file_name}",
|
||||
)
|
||||
|
||||
# set device for nccl pg for collectives
|
||||
if BACKEND == "nccl":
|
||||
torch.cuda.set_device(self.rank)
|
||||
|
||||
def destroy_comms(self):
|
||||
# Wait for all ranks to reach here before starting shutdown.
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl()
|
||||
@with_comms()
|
||||
def test_all_to_all_single(self) -> None:
|
||||
group = self.process_group.group_name
|
||||
|
||||
t = torch.ones((self.world_size, 2), requires_grad=True, device=self.device)
|
||||
|
||||
sizes = [1] * self.world_size
|
||||
assert t.requires_grad
|
||||
out = ft_c.all_to_all_single_autograd(t * 2, sizes, sizes, group) + 0
|
||||
|
||||
self.assertEqual(out.shape, t.shape)
|
||||
self.assertEqual(out, torch.full_like(t, 2.0))
|
||||
self.assertIsNotNone(out.grad_fn)
|
||||
self.assertTrue(out.requires_grad)
|
||||
loss = out.sum()
|
||||
loss.backward()
|
||||
self.assertEqual(t.grad, torch.full_like(t, 2.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -409,7 +409,7 @@ class AllToAllSingle : public torch::autograd::Function<AllToAllSingle> {
|
|||
const std::string& group_name = ctx->saved_data["group_name"].toStringRef();
|
||||
|
||||
DCHECK(grad_out_list.size() == 1);
|
||||
auto grad_out = grad_out_list[0];
|
||||
auto grad_out = grad_out_list[0].contiguous();
|
||||
|
||||
auto out =
|
||||
c10::Dispatcher::singleton()
|
||||
|
|
@ -434,7 +434,7 @@ at::Tensor all_to_all_single_autograd(
|
|||
const std::vector<int64_t>& input_split_sizes,
|
||||
const std::string& group_name) {
|
||||
return AllToAllSingle::apply(
|
||||
input, output_split_sizes, input_split_sizes, group_name)[0];
|
||||
input, output_split_sizes, input_split_sizes, group_name);
|
||||
}
|
||||
|
||||
class ReduceScatterTensor
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from torch.utils._python_dispatch import TorchDispatchMode
|
|||
|
||||
funcol_native = torch.ops._c10d_functional
|
||||
funcol_py = torch.ops.c10d_functional
|
||||
funcol_autograd = torch.ops._c10d_functional_autograd
|
||||
|
||||
NATIVE_TO_PY_MAPPING = {
|
||||
funcol_native.all_gather_into_tensor: funcol_py.all_gather_into_tensor,
|
||||
|
|
@ -17,6 +18,8 @@ NATIVE_TO_PY_MAPPING = {
|
|||
funcol_native.broadcast: funcol_py.broadcast,
|
||||
funcol_native.reduce_scatter_tensor: funcol_py.reduce_scatter_tensor,
|
||||
funcol_native.reduce_scatter_tensor_coalesced: funcol_py.reduce_scatter_tensor_coalesced,
|
||||
# functional ops
|
||||
funcol_autograd.all_to_all_single: funcol_py.all_to_all_single,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import contextlib
|
||||
import weakref
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Generator, List, Optional, Protocol, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -54,6 +54,10 @@ def _merge_sdpa(
|
|||
"""
|
||||
assert len(chunks) == len(logsumexps)
|
||||
|
||||
# LSE may be padded in the sequence dimension such as with memory efficient attention.
|
||||
seq_len = chunks[0].size(2)
|
||||
logsumexps = [lse[:, :, :seq_len] for lse in logsumexps]
|
||||
|
||||
softmax_lse = torch.stack([lse.exp() for lse in logsumexps]).sum(dim=0).log_()
|
||||
|
||||
out = []
|
||||
|
|
@ -80,19 +84,148 @@ def _scaled_dot_product_ring_flash_attention(
|
|||
if return_debug_mask:
|
||||
raise NotImplementedError("return_debug_mask is not supported yet")
|
||||
|
||||
return _templated_ring_attention(
|
||||
mesh,
|
||||
torch.ops.aten._scaled_dot_product_flash_attention,
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
|
||||
def _scaled_dot_product_ring_efficient_attention(
|
||||
mesh: DeviceMesh,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
compute_log_sumexp: bool = True,
|
||||
*,
|
||||
scale: Optional[float] = None,
|
||||
) -> Tuple[torch.Tensor, ...]:
|
||||
if attn_bias is not None:
|
||||
raise NotImplementedError("attn_bias is not supported yet")
|
||||
if not compute_log_sumexp:
|
||||
raise NotImplementedError("compute_log_sumexp must be set")
|
||||
|
||||
return _templated_ring_attention(
|
||||
mesh,
|
||||
torch.ops.aten._scaled_dot_product_efficient_attention,
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_bias=attn_bias,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
compute_log_sumexp=compute_log_sumexp,
|
||||
)
|
||||
|
||||
|
||||
def _scaled_dot_product_ring_cudnn_attention(
|
||||
mesh: DeviceMesh,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
return_debug_mask: bool = True,
|
||||
*,
|
||||
scale: Optional[float] = None,
|
||||
) -> Tuple[torch.Tensor, ...]:
|
||||
if not return_debug_mask:
|
||||
raise NotImplementedError("return_debug_mask must be set")
|
||||
|
||||
return _templated_ring_attention(
|
||||
mesh,
|
||||
torch.ops.aten._scaled_dot_product_cudnn_attention,
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
return_debug_mask=return_debug_mask,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
|
||||
def _ring_rotate(block: torch.Tensor, pg: dist.ProcessGroup) -> torch.Tensor:
|
||||
rank = dist.get_rank(pg)
|
||||
size = dist.get_world_size(pg)
|
||||
|
||||
# rank 0 sends to rank 1, rank 1 sends to rank 2, ..., rank n-1 sends to rank 0
|
||||
input_split_sizes = [0] * size
|
||||
input_split_sizes[(rank + 1) % size] = len(block)
|
||||
output_split_sizes = [0] * size
|
||||
output_split_sizes[(rank - 1) % size] = len(block)
|
||||
|
||||
out = ft_c.all_to_all_single_autograd(
|
||||
block, input_split_sizes, output_split_sizes, pg
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
class AttentionOp(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
*args: object,
|
||||
is_causal: bool = False,
|
||||
**kwargs: object,
|
||||
) -> Tuple[torch.Tensor, ...]:
|
||||
...
|
||||
|
||||
|
||||
def _templated_ring_attention(
|
||||
mesh: DeviceMesh,
|
||||
op: AttentionOp,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
*args: object,
|
||||
is_causal: bool = False,
|
||||
**kwargs: object,
|
||||
) -> Tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
This is a generalized ring attention implementation that can support multiple attention ops.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
op:
|
||||
The attention op to use
|
||||
*args:
|
||||
additional args are passed to the op
|
||||
**kwargs:
|
||||
additional kwargs are passed to the op
|
||||
|
||||
Returns
|
||||
-------
|
||||
out:
|
||||
The merged attention output
|
||||
softmax_lse:
|
||||
The logsumexp of the merged attention output
|
||||
"""
|
||||
if is_causal and (query.size(2) != key.size(2)):
|
||||
raise NotImplementedError(
|
||||
"is_causal requires the same query and context sequence lengths"
|
||||
)
|
||||
|
||||
pg = mesh.get_group()
|
||||
assert isinstance(pg, dist.ProcessGroup), "must be single dimension"
|
||||
if isinstance(mesh, dist.ProcessGroup):
|
||||
pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = mesh
|
||||
else:
|
||||
pg = mesh.get_group()
|
||||
assert isinstance(pg, dist.ProcessGroup), "process group must be single dimension"
|
||||
rank = dist.get_rank(pg)
|
||||
size = dist.get_world_size(pg)
|
||||
|
||||
# rank 0 sends to rank 1, rank 1 sends to rank 2, ..., rank n-1 sends to rank 0
|
||||
right_dsts = list(range(1, size)) + [0]
|
||||
|
||||
next_kv = None
|
||||
|
||||
chunks = []
|
||||
|
|
@ -106,20 +239,20 @@ def _scaled_dot_product_ring_flash_attention(
|
|||
|
||||
if i < (size - 1):
|
||||
next_kv = torch.cat([key.flatten(), value.flatten()])
|
||||
next_kv = ft_c.permute_tensor(next_kv, right_dsts, pg)
|
||||
next_kv = _ring_rotate(next_kv, pg)
|
||||
|
||||
is_causal_behavior = _is_causal_behavior(
|
||||
rank=rank, world_size=size, i=i, is_causal=is_causal
|
||||
)
|
||||
|
||||
if is_causal_behavior != _CausalBehavior.SKIP:
|
||||
local_results = torch.ops.aten._scaled_dot_product_flash_attention(
|
||||
local_results = op(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dropout_p=dropout_p,
|
||||
*args,
|
||||
is_causal=is_causal_behavior.value,
|
||||
scale=scale,
|
||||
**kwargs,
|
||||
)
|
||||
chunks.append(local_results[0])
|
||||
logsumexps.append(local_results[1])
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user