diff --git a/test/distributed/_tensor/test_attention.py b/test/distributed/_tensor/test_attention.py
index 3a34af11d9f..db5a26d4385 100644
--- a/test/distributed/_tensor/test_attention.py
+++ b/test/distributed/_tensor/test_attention.py
@@ -10,6 +10,8 @@ from torch.distributed._tensor.experimental.attention import (
     _CausalBehavior,
     _is_causal_behavior,
     _scaled_dot_product_chunk_flash_attention,
+    _scaled_dot_product_ring_efficient_attention,
+    _scaled_dot_product_ring_flash_attention,
     attention_context_parallel,
     AttentionContextParallel,
 )
@@ -295,6 +297,86 @@ class RingAttentionTest(DTensorTestBase):
             },
         )
 
+    @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
+    )
+    @with_comms
+    @parametrize(
+        "attention_fn",
+        [
+            _scaled_dot_product_ring_flash_attention,
+            _scaled_dot_product_ring_efficient_attention,
+            # _scaled_dot_product_ring_cudnn_attention, # TODO: not built by default
+        ],
+    )
+    def test_ring_attention_compile(self, attention_fn: object) -> None:
+        device_mesh = DeviceMesh(
+            self.device_type,
+            torch.arange(0, self.world_size),
+        )
+        dtype = torch.bfloat16
+        bs = 8
+        query_tokens = 8
+        context_tokens = 24
+        dim = 32
+        nheads = 8
+        query = torch.rand(
+            (bs, nheads, self.world_size * query_tokens, dim),
+            device=self.device_type,
+            dtype=dtype,
+            requires_grad=True,
+        )
+        key = torch.rand(
+            (bs, nheads, self.world_size * context_tokens, dim),
+            device=self.device_type,
+            dtype=dtype,
+        )
+        value = torch.rand(
+            (bs, nheads, self.world_size * context_tokens, dim),
+            device=self.device_type,
+            dtype=dtype,
+        )
+
+        query_placement = [Shard(2)]
+        dquery = distribute_tensor(query, device_mesh, query_placement)
+        self.assertEqual(query.shape, (bs, nheads, self.world_size * query_tokens, dim))
+
+        context_placement = [Shard(2)]
+        dkey = distribute_tensor(key, device_mesh, context_placement)
+        dvalue = distribute_tensor(value, device_mesh, context_placement)
+
+        # compiled = attention_fn
+        compiled = torch.compile(attention_fn, fullgraph=True, backend="aot_eager")
+
+        out, lse, *args = compiled(
+            device_mesh.get_group(),
+            dquery.to_local(),
+            dkey.to_local(),
+            dvalue.to_local(),
+        )
+        self.assertEqual(out.shape, (bs, nheads, query_tokens, dim))
+        self.assertIsInstance(lse, torch.Tensor)
+
+        (
+            out_chunk,
+            *others,
+        ) = _scaled_dot_product_chunk_flash_attention(
+            query,
+            key,
+            value,
+            size=self.world_size,
+            is_causal=False,
+        )
+        self.assertEqual(
+            out,
+            out_chunk[
+                :, :, self.rank * query_tokens : (self.rank + 1) * query_tokens, :
+            ],
+        )
+
+        out.sum().backward()
+
 
 
 instantiate_parametrized_tests(RingAttentionTest)
diff --git a/test/distributed/test_functional_api.py b/test/distributed/test_functional_api.py
index d26dcf970a9..f2255637a69 100644
--- a/test/distributed/test_functional_api.py
+++ b/test/distributed/test_functional_api.py
@@ -634,14 +634,14 @@ class TestFunctionalAutograd(MultiThreadedTestCase):
     def test_all_to_all_single(self, compile: bool = True) -> None:
         group = dist.group.WORLD.group_name
 
-        t = torch.rand((self.world_size, 2), requires_grad=True)
+        t = torch.ones((self.world_size, 2), requires_grad=True)
 
         def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
             sizes = [1] * world_size
-            t = t * 10
+            t = t * 2
             assert t.requires_grad
             out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
-            out = out + 2
+            out = out + 0
             return out
 
         if compile:
@@ -650,11 +650,13 @@ class TestFunctionalAutograd(MultiThreadedTestCase):
             compiled = my_func
 
         out = compiled(t, self.world_size)
+        self.assertEqual(out.shape, t.shape)
+        self.assertEqual(out, torch.full_like(t, 2.0))
         self.assertIsNotNone(out.grad_fn)
         self.assertTrue(out.requires_grad)
         loss = out.sum()
         loss.backward()
-        self.assertIsNotNone(t.grad)
+        self.assertEqual(t.grad, torch.full_like(t, 2.0))
 
     def test_all_to_all_single_inductor(self) -> None:
         group = dist.group.WORLD.group_name
@@ -752,5 +754,61 @@ class TestFunctionalAutograd(MultiThreadedTestCase):
         self.assertEqual(input_tensor.grad, torch.full(output_size, fill_value=1.0))
 
 
+class TestFunctionalAutogradWithNCCL(MultiProcessTestCase):
+    def setUp(self):
+        super().setUp()
+        os.environ["WORLD_SIZE"] = str(self.world_size)
+        os.environ["BACKEND"] = dist.Backend.NCCL
+        self._spawn_processes()
+
+    @property
+    def device(self):
+        return torch.device(self.rank)
+
+    @property
+    def world_size(self):
+        return 2
+
+    @property
+    def process_group(self):
+        return dist.group.WORLD
+
+    def dist_init(self):
+        dist.init_process_group(
+            backend=BACKEND,
+            world_size=self.world_size,
+            rank=self.rank,
+            init_method=f"file://{self.file_name}",
+        )
+
+        # set device for nccl pg for collectives
+        if BACKEND == "nccl":
+            torch.cuda.set_device(self.rank)
+
+    def destroy_comms(self):
+        # Wait for all ranks to reach here before starting shutdown.
+        dist.barrier()
+        dist.destroy_process_group()
+
+    @requires_nccl()
+    @with_comms()
+    def test_all_to_all_single(self) -> None:
+        group = self.process_group.group_name
+
+        t = torch.ones((self.world_size, 2), requires_grad=True, device=self.device)
+
+        sizes = [1] * self.world_size
+        assert t.requires_grad
+        out = ft_c.all_to_all_single_autograd(t * 2, sizes, sizes, group) + 0
+
+        self.assertEqual(out.shape, t.shape)
+        self.assertEqual(out, torch.full_like(t, 2.0))
+        self.assertIsNotNone(out.grad_fn)
+        self.assertTrue(out.requires_grad)
+        loss = out.sum()
+        loss.backward()
+        self.assertEqual(t.grad, torch.full_like(t, 2.0))
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/csrc/distributed/c10d/Functional.cpp b/torch/csrc/distributed/c10d/Functional.cpp
index 942ae7358d3..5728774f748 100644
--- a/torch/csrc/distributed/c10d/Functional.cpp
+++ b/torch/csrc/distributed/c10d/Functional.cpp
@@ -409,7 +409,7 @@ class AllToAllSingle : public torch::autograd::Function<AllToAllSingle> {
    const std::string& group_name = ctx->saved_data["group_name"].toStringRef();
 
    DCHECK(grad_out_list.size() == 1);
-    auto grad_out = grad_out_list[0];
+    auto grad_out = grad_out_list[0].contiguous();
 
    auto out =
        c10::Dispatcher::singleton()
@@ -434,7 +434,7 @@ at::Tensor all_to_all_single_autograd(
     const std::vector<int64_t>& input_split_sizes,
     const std::string& group_name) {
   return AllToAllSingle::apply(
-      input, output_split_sizes, input_split_sizes, group_name)[0];
+      input, output_split_sizes, input_split_sizes, group_name);
 }
 
 class ReduceScatterTensor
diff --git a/torch/distributed/_tensor/debug/comm_mode.py b/torch/distributed/_tensor/debug/comm_mode.py
index 43def0b9d64..b195b30154f 100644
--- a/torch/distributed/_tensor/debug/comm_mode.py
+++ b/torch/distributed/_tensor/debug/comm_mode.py
@@ -8,6 +8,7 @@ from torch.utils._python_dispatch import TorchDispatchMode
 
 funcol_native = torch.ops._c10d_functional
 funcol_py = torch.ops.c10d_functional
+funcol_autograd = torch.ops._c10d_functional_autograd
 
 NATIVE_TO_PY_MAPPING = {
     funcol_native.all_gather_into_tensor: funcol_py.all_gather_into_tensor,
@@ -17,6 +18,8 @@ NATIVE_TO_PY_MAPPING = {
     funcol_native.broadcast: funcol_py.broadcast,
     funcol_native.reduce_scatter_tensor: funcol_py.reduce_scatter_tensor,
     funcol_native.reduce_scatter_tensor_coalesced: funcol_py.reduce_scatter_tensor_coalesced,
+    # functional ops
+    funcol_autograd.all_to_all_single: funcol_py.all_to_all_single,
 }
 
 
diff --git a/torch/distributed/_tensor/experimental/attention.py b/torch/distributed/_tensor/experimental/attention.py
index 195a94fed8a..eb7703a96ba 100644
--- a/torch/distributed/_tensor/experimental/attention.py
+++ b/torch/distributed/_tensor/experimental/attention.py
@@ -1,7 +1,7 @@
 import contextlib
 import weakref
 from enum import Enum
-from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, List, Optional, Protocol, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -54,6 +54,10 @@ def _merge_sdpa(
     """
     assert len(chunks) == len(logsumexps)
 
+    # LSE may be padded in the sequence dimension such as with memory efficient attention.
+    seq_len = chunks[0].size(2)
+    logsumexps = [lse[:, :, :seq_len] for lse in logsumexps]
+
     softmax_lse = torch.stack([lse.exp() for lse in logsumexps]).sum(dim=0).log_()
 
     out = []
@@ -80,19 +84,148 @@ def _scaled_dot_product_ring_flash_attention(
     if return_debug_mask:
         raise NotImplementedError("return_debug_mask is not supported yet")
 
+    return _templated_ring_attention(
+        mesh,
+        torch.ops.aten._scaled_dot_product_flash_attention,
+        query=query,
+        key=key,
+        value=value,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        scale=scale,
+    )
+
+
+def _scaled_dot_product_ring_efficient_attention(
+    mesh: DeviceMesh,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_bias: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    compute_log_sumexp: bool = True,
+    *,
+    scale: Optional[float] = None,
+) -> Tuple[torch.Tensor, ...]:
+    if attn_bias is not None:
+        raise NotImplementedError("attn_bias is not supported yet")
+    if not compute_log_sumexp:
+        raise NotImplementedError("compute_log_sumexp must be set")
+
+    return _templated_ring_attention(
+        mesh,
+        torch.ops.aten._scaled_dot_product_efficient_attention,
+        query=query,
+        key=key,
+        value=value,
+        attn_bias=attn_bias,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        scale=scale,
+        compute_log_sumexp=compute_log_sumexp,
+    )
+
+
+def _scaled_dot_product_ring_cudnn_attention(
+    mesh: DeviceMesh,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_bias: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    return_debug_mask: bool = True,
+    *,
+    scale: Optional[float] = None,
+) -> Tuple[torch.Tensor, ...]:
+    if not return_debug_mask:
+        raise NotImplementedError("return_debug_mask must be set")
+
+    return _templated_ring_attention(
+        mesh,
+        torch.ops.aten._scaled_dot_product_cudnn_attention,
+        query=query,
+        key=key,
+        value=value,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        return_debug_mask=return_debug_mask,
+        scale=scale,
+    )
+
+
+def _ring_rotate(block: torch.Tensor, pg: dist.ProcessGroup) -> torch.Tensor:
+    rank = dist.get_rank(pg)
+    size = dist.get_world_size(pg)
+
+    # rank 0 sends to rank 1, rank 1 sends to rank 2, ..., rank n-1 sends to rank 0
+    input_split_sizes = [0] * size
+    input_split_sizes[(rank + 1) % size] = len(block)
+    output_split_sizes = [0] * size
+    output_split_sizes[(rank - 1) % size] = len(block)
+
+    out = ft_c.all_to_all_single_autograd(
+        block, input_split_sizes, output_split_sizes, pg
+    )
+    return out
+
+
+class AttentionOp(Protocol):
+    def __call__(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        *args: object,
+        is_causal: bool = False,
+        **kwargs: object,
+    ) -> Tuple[torch.Tensor, ...]:
+        ...
+
+
+def _templated_ring_attention(
+    mesh: DeviceMesh,
+    op: AttentionOp,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    *args: object,
+    is_causal: bool = False,
+    **kwargs: object,
+) -> Tuple[torch.Tensor, ...]:
+    """
+    This is a generalized ring attention implementation that can support multiple attention ops.
+
+    Parameters
+    ----------
+    op:
+        The attention op to use
+    *args:
+        additional args are passed to the op
+    **kwargs:
+        additional kwargs are passed to the op
+
+    Returns
+    -------
+    out:
+        The merged attention output
+    softmax_lse:
+        The logsumexp of the merged attention output
+    """
     if is_causal and (query.size(2) != key.size(2)):
         raise NotImplementedError(
             "is_causal requires the same query and context sequence lengths"
         )
 
-    pg = mesh.get_group()
-    assert isinstance(pg, dist.ProcessGroup), "must be single dimension"
+    if isinstance(mesh, dist.ProcessGroup):
+        pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = mesh
+    else:
+        pg = mesh.get_group()
+    assert isinstance(pg, dist.ProcessGroup), "process group must be single dimension"
     rank = dist.get_rank(pg)
     size = dist.get_world_size(pg)
 
-    # rank 0 sends to rank 1, rank 1 sends to rank 2, ..., rank n-1 sends to rank 0
-    right_dsts = list(range(1, size)) + [0]
-
     next_kv = None
 
     chunks = []
@@ -106,20 +239,20 @@ def _scaled_dot_product_ring_flash_attention(
 
         if i < (size - 1):
             next_kv = torch.cat([key.flatten(), value.flatten()])
-            next_kv = ft_c.permute_tensor(next_kv, right_dsts, pg)
+            next_kv = _ring_rotate(next_kv, pg)
 
         is_causal_behavior = _is_causal_behavior(
             rank=rank, world_size=size, i=i, is_causal=is_causal
         )
 
         if is_causal_behavior != _CausalBehavior.SKIP:
-            local_results = torch.ops.aten._scaled_dot_product_flash_attention(
+            local_results = op(
                 query,
                 key,
                 value,
-                dropout_p=dropout_p,
+                *args,
                 is_causal=is_causal_behavior.value,
-                scale=scale,
+                **kwargs,
             )
             chunks.append(local_results[0])
             logsumexps.append(local_results[1])
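
The `_ring_rotate` helper added above is the communication core of the templated ring attention: each rank forwards its flattened key/value block to the next rank and receives a block from the previous rank through a single differentiable all-to-all, replacing the earlier `ft_c.permute_tensor` call. Below is a minimal standalone sketch of that split-size pattern, not part of the patch: it assumes an already-initialized process group that supports all-to-all, reuses the argument order from `_ring_rotate` above, and the helper name `rotate_to_next_rank` is illustrative only.

# Sketch only: mirrors the split-size pattern of _ring_rotate in the patch.
# Assumes torch.distributed is already initialized and `pg` supports all-to-all;
# `rotate_to_next_rank` is a hypothetical name, not part of the PR.
from typing import List

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c


def rotate_to_next_rank(block: torch.Tensor, pg: dist.ProcessGroup) -> torch.Tensor:
    rank = dist.get_rank(pg)
    size = dist.get_world_size(pg)

    # Every slot is 0 except one: this rank sends its whole block to rank + 1
    # and expects its whole incoming block from rank - 1 (modulo the ring size).
    input_split_sizes: List[int] = [0] * size
    input_split_sizes[(rank + 1) % size] = len(block)
    output_split_sizes: List[int] = [0] * size
    output_split_sizes[(rank - 1) % size] = len(block)

    # Differentiable all-to-all: the backward pass routes gradients back along
    # the reverse ring, so the rotation can sit inside an autograd graph.
    return ft_c.all_to_all_single_autograd(
        block, input_split_sizes, output_split_sizes, pg
    )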