From d6d6fa26f540c10c57ac80547a9475e9f4c201f2 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 29 Oct 2025 15:10:40 +0000 Subject: [PATCH] Revert "bwd pass (#164504)" This reverts commit f36f372acc28062e0988d84699c62689b0d89a6e. Reverted https://github.com/pytorch/pytorch/pull/164504 on behalf of https://github.com/jeffdaily due to CI had been clean for both cuda and rocm before merge, broke post merge? ([comment](https://github.com/pytorch/pytorch/pull/164504#issuecomment-3462116676)) --- test/test_varlen_attention.py | 233 ++------------------------------- torch/nn/attention/__init__.py | 3 + torch/nn/attention/varlen.py | 160 ++-------------------- 3 files changed, 26 insertions(+), 370 deletions(-) diff --git a/test/test_varlen_attention.py b/test/test_varlen_attention.py index b8399dd8c04..f249adf21a5 100644 --- a/test/test_varlen_attention.py +++ b/test/test_varlen_attention.py @@ -5,12 +5,11 @@ from collections import namedtuple import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn.attention.varlen import varlen_attn +from torch.nn.attention import varlen_attn from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_nn import NNTestCase from torch.testing._internal.common_utils import parametrize, run_tests -from torch.utils._python_dispatch import TorchDispatchMode VarlenShape = namedtuple( @@ -24,18 +23,6 @@ default_tolerances = { } -class OpLoggingMode(TorchDispatchMode): - """Logging mode that captures all dispatched operations""" - - def __init__(self): - self.called_ops = [] - - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - op_name = str(func) - self.called_ops.append(op_name) - return func(*args, **(kwargs or {})) - - class AttentionBlock(nn.Module): def __init__( self, embed_dim: int, num_heads: int, device: torch.device, dtype: torch.dtype @@ -52,9 +39,12 @@ class AttentionBlock(nn.Module): embed_dim, embed_dim, bias=False, device=device, dtype=dtype ) - def get_varlen_qkv( + def forward_varlen( self, x_packed: torch.Tensor, + cu_seq: torch.Tensor, + max_len: int, + is_causal: bool = False, ): qkv = self.qkv_proj(x_packed) q, k, v = qkv.chunk(3, dim=-1) @@ -63,51 +53,24 @@ class AttentionBlock(nn.Module): k = k.view(-1, self.num_heads, self.head_dim) v = v.view(-1, self.num_heads, self.head_dim) - return q, k, v - - def forward_varlen( - self, - x_packed: torch.Tensor, - cu_seq: torch.Tensor, - max_len: int, - is_causal: bool = False, - ): - q, k, v = self.get_varlen_qkv(x_packed) - - attn_out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal) + attn_out = varlen_attn( + q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=is_causal + ) attn_out = attn_out.view(-1, self.embed_dim) return self.out_proj(attn_out) - def forward_sdpa( - self, - x_padded: torch.Tensor, - seq_lengths: torch.Tensor, - dtype: torch.dtype, - is_causal: bool = False, - ): + def forward_sdpa(self, x_padded: torch.Tensor, is_causal: bool = False): batch_size, seq_len, _ = x_padded.shape qkv = self.qkv_proj(x_padded) q, k, v = qkv.chunk(3, dim=-1) - mask = ( - torch.arange(seq_len, device=x_padded.device)[None, :] - < seq_lengths[:, None] - ) - - attn_mask = mask[:, None, None, :].expand( - batch_size, self.num_heads, seq_len, seq_len - ) - q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) k = k.view(batch_size, seq_len, self.num_heads, 
self.head_dim).transpose(1, 2) v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - attn_out = F.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, is_causal=is_causal - ) - + attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal) attn_out = ( attn_out.transpose(1, 2) .contiguous() @@ -128,9 +91,7 @@ def create_variable_length_batch( seq_lengths = torch.tensor(seq_lengths, device=device) total_tokens = seq_lengths.sum().item() - x_packed = torch.randn( - total_tokens, shape.embed_dim, device=device, dtype=dtype, requires_grad=True - ) + x_packed = torch.randn(total_tokens, shape.embed_dim, device=device, dtype=dtype) cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32) cu_seq[1:] = seq_lengths.cumsum(0) @@ -145,7 +106,6 @@ def create_variable_length_batch( end_idx = start_idx + seq_len x_padded[i, :seq_len] = x_packed[start_idx:end_idx] start_idx = end_idx - x_padded = x_padded.clone().detach().requires_grad_() return { "seq_lengths": seq_lengths, @@ -173,11 +133,7 @@ class TestVarlenAttention(NNTestCase): total_tokens = shape.batch_size * shape.max_seq_len x_packed = torch.randn( - total_tokens, - shape.embed_dim, - device=device, - dtype=dtype, - requires_grad=True, + total_tokens, shape.embed_dim, device=device, dtype=dtype ) cu_seq = torch.tensor( [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32 @@ -191,128 +147,6 @@ class TestVarlenAttention(NNTestCase): self.assertEqual(output.device, torch.device(device)) self.assertEqual(output.dtype, dtype) - varlen_grad_out = torch.ones_like(output) - - varlen_grad = torch.autograd.grad( - outputs=output, - inputs=x_packed, - grad_outputs=varlen_grad_out, - retain_graph=True, - create_graph=False, - allow_unused=False, - )[0] - - self.assertIsNotNone(varlen_grad) - self.assertEqual(varlen_grad.shape, x_packed.shape) - self.assertEqual(varlen_grad.dtype, x_packed.dtype) - - @unittest.skipIf( - not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" - ) - @parametrize("dtype", [torch.bfloat16, torch.float16]) - def test_custom_op_compliance(self, device, dtype): - torch.manual_seed(42) - - shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16) - - attention_block = AttentionBlock( - shape.embed_dim, shape.num_heads, device, dtype - ) - - total_tokens = shape.batch_size * shape.max_seq_len - x_packed = torch.randn( - total_tokens, - shape.embed_dim, - device=device, - dtype=dtype, - ) - cu_seq = torch.tensor( - [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32 - ) - - q, k, v = attention_block.get_varlen_qkv(x_packed) - - torch.library.opcheck( - torch.ops.torch_attn._varlen_attn, - (q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False), - ) - - out, lse, rng_state = torch.ops.torch_attn._varlen_attn( - q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False - ) - grad_out = torch.randn_like(out) - - # we don't support double backward - # skipping test_autograd_registration, test_aot_dispatch_dynamic, test_aot_dispatch_static - torch.library.opcheck( - torch.ops.torch_attn._varlen_attn_backward, - ( - grad_out, - q, - k, - v, - out, - lse, - cu_seq, - cu_seq, - shape.max_seq_len, - shape.max_seq_len, - False, - rng_state, - ), - test_utils=["test_schema", "test_faketensor"], - ) - - @unittest.skipIf( - not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" - ) - @parametrize("dtype", [torch.bfloat16, torch.float16]) - def 
test_custom_op_registration(self, device, dtype): - torch.manual_seed(42) - - shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16) - - attention_block = AttentionBlock( - shape.embed_dim, shape.num_heads, device, dtype - ) - - total_tokens = shape.batch_size * shape.max_seq_len - x_packed = torch.randn( - total_tokens, - shape.embed_dim, - device=device, - dtype=dtype, - requires_grad=True, - ) - cu_seq = torch.tensor( - [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32 - ) - - compiled_forward = torch.compile( - attention_block.forward_varlen, backend="eager", fullgraph=True - ) - with OpLoggingMode() as mode: - output = compiled_forward( - x_packed, cu_seq, shape.max_seq_len, is_causal=False - ) - - varlen_grad_out = torch.ones_like(output) - _ = torch.autograd.grad( - outputs=output, - inputs=x_packed, - grad_outputs=varlen_grad_out, - retain_graph=True, - create_graph=False, - allow_unused=False, - )[0] - - called_ops = mode.called_ops - - custom_ops_called = any( - "torch_attn._varlen_attn" in op for op in called_ops - ) and any("torch_attn._varlen_attn_backward" in op for op in called_ops) - assert custom_ops_called - @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" ) @@ -338,10 +172,7 @@ class TestVarlenAttention(NNTestCase): is_causal=is_causal, ) sdpa_output = attention_block.forward_sdpa( - variable_length_batch_data["x_padded"], - variable_length_batch_data["seq_lengths"], - dtype=dtype, - is_causal=is_causal, + variable_length_batch_data["x_padded"], is_causal=is_causal ) tolerances = default_tolerances[dtype] @@ -355,44 +186,6 @@ class TestVarlenAttention(NNTestCase): torch.testing.assert_close(varlen_seq, sdpa_seq, **tolerances) start_idx = end_idx - varlen_grad_out = torch.ones_like(varlen_output) - - sdpa_grad_out = torch.zeros_like(sdpa_output) - - start_idx = 0 - for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]): - end_idx = start_idx + seq_len - sdpa_grad_out[i, :seq_len] = varlen_grad_out[start_idx:end_idx] - start_idx = end_idx - - varlen_grad = torch.autograd.grad( - outputs=varlen_output, - inputs=variable_length_batch_data["x_packed"], - grad_outputs=varlen_grad_out, - retain_graph=True, - create_graph=False, - allow_unused=False, - )[0] - - sdpa_grad = torch.autograd.grad( - outputs=sdpa_output, - inputs=variable_length_batch_data["x_padded"], - grad_outputs=sdpa_grad_out, - retain_graph=True, - create_graph=False, - allow_unused=False, - )[0] - - start_idx = 0 - for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]): - end_idx = start_idx + seq_len - - varlen_grad_seq = varlen_grad[start_idx:end_idx] - sdpa_grad_seq = sdpa_grad[i, :seq_len] - - torch.testing.assert_close(varlen_grad_seq, sdpa_grad_seq, **tolerances) - start_idx = end_idx - device_types = ("cuda",) diff --git a/torch/nn/attention/__init__.py b/torch/nn/attention/__init__.py index 5e6e0fa5fae..9113fd7e379 100644 --- a/torch/nn/attention/__init__.py +++ b/torch/nn/attention/__init__.py @@ -14,11 +14,14 @@ from torch.backends.cuda import ( SDPAParams, ) +from .varlen import varlen_attn + __all__: list[str] = [ "SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS", + "varlen_attn", ] # Note: [SDPA warnings] diff --git a/torch/nn/attention/varlen.py b/torch/nn/attention/varlen.py index 78613364963..7234dd5e791 100644 --- a/torch/nn/attention/varlen.py +++ b/torch/nn/attention/varlen.py @@ -7,7 +7,7 @@ that calls into the optimized Flash Attention kernels. 
import logging from functools import lru_cache -from typing import Any, NamedTuple, Optional, Union +from typing import NamedTuple, Optional, Union import torch @@ -33,7 +33,8 @@ class AuxRequest(NamedTuple): lse: bool = False -@torch.library.custom_op("torch_attn::_varlen_attn", mutates_args={}) +# import failures when I try to register as custom op +# @torch.library.custom_op("torch_nn_attention::_varlen_attn", mutates_args={}) def _varlen_attn( query: torch.Tensor, key: torch.Tensor, @@ -43,7 +44,7 @@ def _varlen_attn( max_q: int, max_k: int, is_causal: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """ Private custom op for variable-length attention. @@ -69,7 +70,7 @@ def _varlen_attn( False, # return_debug_mask ) # cuDNN returns: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask) - output, softmax_lse, rng_state = result[0], result[1], result[6] + output, softmax_lse = result[0], result[1] else: log.info("Using Flash Attention backend for varlen_attn") output, softmax_lse, rng_state, _, _ = torch.ops.aten._flash_attention_forward( @@ -85,13 +86,10 @@ def _varlen_attn( return_debug_mask=False, ) - rng_state_ = torch.zeros( - (2,), dtype=torch.uint64, device=query.device - ) # hardcoded since dropout is hardcoded to 0 - return output, softmax_lse, rng_state_ + return output, softmax_lse -@_varlen_attn.register_fake +# @_varlen_attn.register_fake def _varlen_attn_fake( query: torch.Tensor, key: torch.Tensor, @@ -101,7 +99,7 @@ def _varlen_attn_fake( max_q: int, max_k: int, is_causal: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """ Fake implementation for meta tensor computation and tracing. @@ -119,9 +117,7 @@ def _varlen_attn_fake( (num_heads, total_q), dtype=torch.float, device=query.device ) - rng_state = torch.empty((2,), dtype=torch.uint64, device=query.device) - - return output, logsumexp, rng_state + return output, logsumexp def varlen_attn( @@ -195,145 +191,9 @@ def varlen_attn( ... query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False ... 
) """ - out, lse, _ = torch.ops.torch_attn._varlen_attn( + out, lse = _varlen_attn( query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal ) if return_aux is not None and return_aux.lse: return out, lse return out - - -def _setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None: - query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal = inputs - out, lse, rng_state = output - ctx.query = query - ctx.key = key - ctx.value = value - ctx.cu_seq_q = cu_seq_q - ctx.cu_seq_k = cu_seq_k - ctx.max_q = max_q - ctx.max_k = max_k - ctx.is_causal = is_causal - ctx.output = out - ctx.lse = lse - ctx.rng_state = rng_state - - -@torch.library.custom_op("torch_attn::_varlen_attn_backward", mutates_args={}) -def _varlen_attn_backward( - grad_out: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - out: torch.Tensor, - lse: torch.Tensor, - cu_seq_q: torch.Tensor, - cu_seq_k: torch.Tensor, - max_q: int, - max_k: int, - is_causal: bool, - rng_state: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - unused = torch.empty(0, device=query.device) - - use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index) - if use_cudnn: - log.info("Using cuDNN backend for varlen_attn") - dq, dk, dv = torch.ops.aten._cudnn_attention_backward( - grad_out, - query, - key, - value, - out, - lse, - cu_seq_q, - cu_seq_k, - max_q, - max_k, - 0.0, - is_causal, - rng_state, - unused, - ) - else: - log.info("Using Flash Attention backend for varlen_attn") - dq, dk, dv = torch.ops.aten._flash_attention_backward( - grad_out, - query, - key, - value, - out, - lse, - cu_seq_q, - cu_seq_k, - max_q, - max_k, - 0.0, - is_causal, - rng_state, - unused, - ) - return dq, dk, dv - - -@_varlen_attn_backward.register_fake -def _varlen_attn_backward_fake( - grad_out: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - out: torch.Tensor, - lse: torch.Tensor, - cu_seq_q: torch.Tensor, - cu_seq_k: torch.Tensor, - max_q: int, - max_k: int, - is_causal: bool, - rng_state: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Fake implementation for meta tensor computation and tracing. - """ - - grad_query = torch.empty_like(query) - grad_key = torch.empty_like(key) - grad_value = torch.empty_like(value) - - return grad_query, grad_key, grad_value - - -def _backward( - ctx: Any, grad_out: torch.Tensor, grad_lse: torch.Tensor, grad_rng: torch.Tensor -) -> tuple[Optional[torch.Tensor], ...]: - query = ctx.query - key = ctx.key - value = ctx.value - cu_seq_q = ctx.cu_seq_q - cu_seq_k = ctx.cu_seq_k - max_q = ctx.max_q - max_k = ctx.max_k - is_causal = ctx.is_causal - out = ctx.output - lse = ctx.lse - rng_state = ctx.rng_state - - # rng_state = torch.empty(2, device=query.device) - - dq, dk, dv = torch.ops.torch_attn._varlen_attn_backward( - grad_out, - query, - key, - value, - out, - lse, - cu_seq_q, - cu_seq_k, - max_q, - max_k, - is_causal, - rng_state, - ) - return dq, dk, dv, None, None, None, None, None, None - - -_varlen_attn.register_autograd(_backward, setup_context=_setup_context)